/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/types.h"
#include "tensorflow/lite/delegates/gpu/gl/variable.h"
28 
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33 
34 class AlignedConcatByChannels : public NodeShader {
35  public:
IsSupported(const GenerationContext & ctx)36   static bool IsSupported(const GenerationContext& ctx) {
37     const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);
38 
39     // Implementation supports concatenation by channels only.
40     if (attr.axis != Axis::CHANNELS) return false;
41 
42     // Implementation supports concatenation of 2 tensors only.
43     if (ctx.input_shapes.size() != 2) return false;
44 
45     // H and W must be the same for every concatenated tensor.
46     for (int i = 1; i < ctx.input_shapes.size(); i++) {
47       if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
48           ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
49         return false;
50       }
51     }
52 
53     // Channels must be aligned by 4 for every concatenated tensor.
54     for (const auto& shape : ctx.input_shapes) {
55       if (shape[3] % 4 != 0) return false;
56     }
57 
58     return true;
59   }
60 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const61   absl::Status GenerateCode(const GenerationContext& ctx,
62                             GeneratedCode* generated_code) const final {
63     if (!IsSupported(ctx)) {
64       return absl::InvalidArgumentError(
65           "This case is not supported by aligned concat");
66     }
67 
68     // Shader below concatenates 2 tensors which channels are aligned by 4
69     std::string source = R"(
70       if (gid.z < $border$) {
71         value_0 = $input_data_0[gid.x, gid.y, gid.z]$;
72       } else {
73         int z = gid.z - $border$;
74         value_0 = $input_data_1[gid.x, gid.y, z]$;
75       }
76 )";
77     *generated_code = {
78         /*parameters=*/{
79             {"border", static_cast<int>(ctx.input_shapes[0][3]) / 4}},
80         /*objects=*/{},
81         /*shared_variables=*/{},
82         /*workload=*/uint3(),
83         /*workgroup=*/uint3(),
84         /*source_code=*/std::move(source),
85         /*input=*/IOStructure::ONLY_DEFINITIONS,
86         /*output=*/IOStructure::AUTO,
87     };
88     return absl::OkStatus();
89   }
90 };
91 
92 class ConcatByAnyChannel : public NodeShader {
93  public:
IsSupported(const GenerationContext & ctx)94   static bool IsSupported(const GenerationContext& ctx) {
95     const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);
96 
97     // Implementation supports concatenation by channels only.
98     if (attr.axis != Axis::CHANNELS) return false;
99 
100     // Implementation supports concatenation of more that 1 tensors only.
101     if (ctx.input_shapes.size() <= 1) return false;
102 
103     // H and W must be the same for every concatenated tensor.
104     for (int i = 1; i < ctx.input_shapes.size(); i++) {
105       if (ctx.input_shapes[0][1] != ctx.input_shapes[i][1] ||
106           ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
107         return false;
108       }
109     }
110 
111     return true;
112   }
113 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const114   absl::Status GenerateCode(const GenerationContext& ctx,
115                             GeneratedCode* generated_code) const final {
116     if (!IsSupported(ctx)) {
117       return absl::UnimplementedError("This case is not supported by concat");
118     }
119 
120     std::string code = DeclareVariables();
121 
122     // "already_written" is used to keep the amount of already joined channels
123     int already_written = 0;
124     // "t" is an id of the next temp* variable.
125     // Generally, temp* variables are used in macros
126     // READ_BUFFER_VEC4(buff, addr, var).
127     // This macros instantiate the variable "var" and
128     // reads the value from buffer "buff" by address "addr"
129     int t = 0;
130     for (int current_input_id = 0; current_input_id < ctx.input_shapes.size();
131          current_input_id++) {
132       // Start joining next inout tensor
133 
134       // Grab channels amount
135       int in_ch = ctx.input_shapes[current_input_id][3];
136       code += PrintStartMessage(current_input_id, in_ch, already_written);
137 
138       // Construct the buffer name associated with this tensor
139       std::string input = "input_data_" + std::to_string(current_input_id);
140 
141       // "reminder" shows us how many cells in 4-element vector are left after
142       // the last write. As example, if we join two tensors both with
143       // 3 channels, after joining the first one we come to this line again
144       // and, when joining the second tensor, the reminder value
145       // will be equal to 1
146       int reminder = already_written % 4;
147 
148       if (reminder == 0) {
149         code += AlignedCase(in_ch, input);
150       } else {
151         code += UnalignedCase(reminder, in_ch, input, &t);
152       }
153       already_written += in_ch;
154     }
155 
156     *generated_code = {
157         /*parameters=*/{},
158         /*objects=*/{},
159         /*shared_variables=*/{},
160         /*workload=*/
161         uint3(static_cast<int>(ctx.output_shapes[0][2]),
162               static_cast<int>(ctx.output_shapes[0][1]), 1),
163         /*workgroup=*/uint3(),
164         /*source_code=*/std::move(code),
165         /*input=*/IOStructure::ONLY_DEFINITIONS,
166         /*output=*/IOStructure::ONLY_DEFINITIONS,
167     };
168     return absl::OkStatus();
169   }
170 
171  private:
172   // Utility function
temp(int t) const173   std::string temp(int t) const { return "temp" + std::to_string(t); }
174 
DeclareVariables() const175   std::string DeclareVariables() const {
176     // "val" is used to collect useful information before the next
177     // upcoming write.
178     return R"(
179 int z = gid.z;
180 vec4 val = vec4(0.0f);
181 
182 )";
183   }
184 
PrintStartMessage(int current_input_id,int in_ch,int already_written) const185   std::string PrintStartMessage(int current_input_id, int in_ch,
186                                 int already_written) const {
187     return "//              Joining " + std::to_string(current_input_id) +
188            " tensor with " + std::to_string(in_ch) +
189            " channels\n//  * * * *\\n// Already wrote " +
190            std::to_string(already_written) + " elements\n\n";
191   }
192 
AlignedCase(int in_ch,const std::string & input) const193   std::string AlignedCase(int in_ch, const std::string& input) const {
194     std::string code;
195     // This branch is for aligned reading and writing, when we can copy
196     // all 4 components at once. Address of the first element to write
197     // should be aligned.
198     // Visual examples:
199     // 1) when copy input_data_0
200     //
201     //       | * * * * | * * * @ | @ @ . . .
202     //         ^
203     // 2) when in the middle of joining process:
204     //
205     //       | X X X X | * * * @ | @ @ . . .
206     //                   ^
207     // Note that amount of * equals to the in_ch
208     //
209     // X - cells were written before
210     // * - you are going to write into these cells
211     // @ - you will fill these cells next cycles
212     // ^ - first elem you start writing from
213     int blocks_amount = DivideRoundUp<int>(in_ch, 4);
214     code += "// Aligned case\n";
215     code += "// I'm going to make " + std::to_string(blocks_amount) +
216             " write(s)\n\n";
217     for (int block = 0; block < blocks_amount; block++) {
218       // Copy full 4-element vector
219       code += "val = $" + input + "[gid.x, gid.y, " + std::to_string(block) +
220               "]$;\n" +
221               "$output_data_0[gid.x, gid.y, z] = val$;\n"
222               // calculate next address to write
223               + "z++; \n\n";
224     }
225     return code;
226   }
227 
UnalignedCase(int reminder,int in_ch,const std::string & input,int * t) const228   std::string UnalignedCase(int reminder, int in_ch, const std::string& input,
229                             int* t) const {
230     // This branch is for copying cell-by-cell. It will never start from the
231     // first tensor input_data_0. This function is splitting in two stages:
232     // 1) Copy the "leftovers" for the previous cells
233     // 2) Copy all other
234     // Visual examples:
235     //
236     //        Stage 1       Stage 2
237     //        -----------   -------------------------
238     // . . X | X  X  X *1 | *2 *2 *2  @ | @  @  . . .
239     //               ^
240     // . . X | X  X *1 *1 | *2 *2 *2 *2 | *2 *2 . . .
241     //             ^
242     // . . X | X *1 *1 *1 | *2  @  @  @ | @  @  . . .
243     //           ^
244     // Note that amount of * equals to the in_ch
245     //
246     // X - cells were written before
247     // *1 - write there at the Stage 1
248     // *2 - write there at the Stage 2
249     // @ - you will fill these cells next cycles
250     // ^ - first elem you start writing from
251 
252     std::string code = "// Unaligned case\n";
253 
254     // Variable "shift" showes how many "empty" cells are left after previous
255     // write. Remember, that this case should is unaligned.
256     // shift now can only be 1, 2 or 3
257     int shift = 4 - reminder;
258     if (shift > in_ch) {
259       shift = in_ch;
260     }
261     code += "\n// Stage 1\n";
262     code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, 0]$;\n";
263     for (int i = 0; i < shift; i++) {
264       // Note that reminder + i has implicitly added 1, cause
265       // reminder by it's nature is an amount, not an index
266       code += "val[" + std::to_string(reminder + i) + "] = " + temp(*t) + "[" +
267               std::to_string(i) + "];\n";
268     }
269     // Rewrite previous value with updated last cells
270     code += "$output_data_0[gid.x, gid.y, z - 1] = val$;\n";
271     (*t)++;
272 
273     // "left_blocks" is equal to an amount of WRITE_BUFFER_VEC4 calls
274     // which will are left for this input to be finally copied
275     int left_blocks = (in_ch - shift) / 4;
276     if ((in_ch - shift) % 4 != 0) {
277       left_blocks++;
278     }
279     if (left_blocks) {
280       code += "\n// Stage 2\n";
281       for (int block = 0; block < left_blocks; block++) {
282         for (int elem = 0; elem < 4; elem++) {
283           if (shift % 4 == 0) {
284             code += "vec4 " + temp(*t) + " = $" + input + "[gid.x, gid.y, " +
285                     std::to_string(block + 1) + "]$;\n";
286             (*t)++;
287           }
288           code += "val[" + std::to_string(elem) + "] = " + temp(*t - 1) + "[" +
289                   std::to_string(shift % 4) + "];\n";
290           if (shift == in_ch) {
291             break;
292           }
293           shift++;
294         }
295         code += "$output_data_0[gid.x, gid.y, z] = val$;\n";
296         code += "z++;\n";
297       }
298     } else {
299       code += "// No Stage 2\n";
300     }
301     return code;
302   }
303 };
304 
305 class FlatConcatByHeight : public NodeShader {
306  public:
IsSupported(const GenerationContext & ctx)307   static bool IsSupported(const GenerationContext& ctx) {
308     const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);
309 
310     // Implementation supports concatenation by height only.
311     if (attr.axis != Axis::HEIGHT) return false;
312 
313     // Implementation supports concatenation of more that 1 tensors only.
314     if (ctx.input_shapes.size() <= 1) return false;
315 
316     // C and W must be the same for every concatenated tensor.
317     for (int i = 1; i < ctx.input_shapes.size(); i++) {
318       if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
319           ctx.input_shapes[0][2] != ctx.input_shapes[i][2]) {
320         return false;
321       }
322     }
323 
324     return true;
325   }
326 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const327   absl::Status GenerateCode(const GenerationContext& ctx,
328                             GeneratedCode* generated_code) const final {
329     std::string code;
330     std::vector<Variable> params;
331     for (int i = 0, shift = 0; i < ctx.input_shapes.size();
332          shift += ctx.input_shapes[i][1], i++) {
333       code += "if (";
334       if (i != 0) {
335         code += "$input_data_" + std::to_string(i - 1) + "_h$ <= gid.y && ";
336       }
337       code +=
338           "gid.y < " + std::to_string(shift + ctx.input_shapes[i][1]) + ") {\n";
339       code += "if (gid.y - " + std::to_string(shift) + " >= $input_data_" +
340               std::to_string(i) + "_h$) return;\n";
341       code += "value_0 = $input_data_" + std::to_string(i) +
342               "[gid.x, gid.y - " + std::to_string(shift) + ", gid.z]$;\n}\n";
343       if (i != ctx.input_shapes.size() - 1) {
344         code += " else ";
345       }
346       params.push_back({"input_data_" + std::to_string(i) + "_h",
347                         static_cast<int>(ctx.input_shapes[i][1])});
348     }
349 
350     *generated_code = {
351         /*parameters=*/std::move(params),
352         /*objects=*/{},
353         /*shared_variables=*/{},
354         /*workload=*/uint3(),
355         /*workgroup=*/uint3(),
356         /*source_code=*/std::move(code),
357         /*input=*/IOStructure::ONLY_DEFINITIONS,
358         /*output=*/IOStructure::AUTO,
359     };
360     return absl::OkStatus();
361   }
362 };
363 
364 class FlatConcatByWidth : public NodeShader {
365  public:
IsSupported(const GenerationContext & ctx)366   static bool IsSupported(const GenerationContext& ctx) {
367     const auto& attr = absl::any_cast<const ConcatAttributes&>(ctx.op_attr);
368 
369     // Implementation supports concatenation by width only.
370     if (attr.axis != Axis::WIDTH) return false;
371 
372     // Implementation supports concatenation of more that 1 tensors only.
373     if (ctx.input_shapes.size() <= 1) return false;
374 
375     // C and H must be the same for every concatenated tensor.
376     for (int i = 1; i < ctx.input_shapes.size(); i++) {
377       if (ctx.input_shapes[0][3] != ctx.input_shapes[i][3] ||
378           ctx.input_shapes[0][1] != ctx.input_shapes[i][1]) {
379         return false;
380       }
381     }
382 
383     return true;
384   }
385 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const386   absl::Status GenerateCode(const GenerationContext& ctx,
387                             GeneratedCode* generated_code) const final {
388     std::string code;
389     std::vector<Variable> params;
390     for (int i = 0, shift = 0; i < ctx.input_shapes.size();
391          shift += ctx.input_shapes[i][2], i++) {
392       code += "if (";
393       if (i != 0) {
394         code += "$input_data_" + std::to_string(i - 1) + "_w$ <= gid.x && ";
395       }
396       code +=
397           "gid.x < " + std::to_string(shift + ctx.input_shapes[i][2]) + ") {\n";
398       code += "if (gid.x - " + std::to_string(shift) + " >= $input_data_" +
399               std::to_string(i) + "_w$) return;\n";
400       code += "value_0 = $input_data_" + std::to_string(i) + "[gid.x - " +
401               std::to_string(shift) + ", gid.y, gid.z]$;\n}\n";
402       if (i != ctx.input_shapes.size() - 1) {
403         code += " else ";
404       }
405       params.push_back({"input_data_" + std::to_string(i) + "_w",
406                         static_cast<int>(ctx.input_shapes[i][2])});
407     }
408 
409     *generated_code = {
410         /*parameters=*/std::move(params),
411         /*objects=*/{},
412         /*shared_variables=*/{},
413         /*workload=*/uint3(),
414         /*workgroup=*/uint3(),
415         /*source_code=*/std::move(code),
416         /*input=*/IOStructure::ONLY_DEFINITIONS,
417         /*output=*/IOStructure::AUTO,
418     };
419     return absl::OkStatus();
420   }
421 };
422 
423 class FlatConcat : public NodeShader {
424  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const425   absl::Status GenerateCode(const GenerationContext& ctx,
426                             GeneratedCode* generated_code) const final {
427     if (FlatConcatByHeight::IsSupported(ctx)) {
428       return flat_concat_by_height_.GenerateCode(ctx, generated_code);
429     }
430     if (FlatConcatByWidth::IsSupported(ctx)) {
431       return flat_concat_by_width_.GenerateCode(ctx, generated_code);
432     }
433     return absl::InvalidArgumentError(
434         "This case is not supported by flat concat");
435   }
436 
437  private:
438   FlatConcatByHeight flat_concat_by_height_;
439   FlatConcatByWidth flat_concat_by_width_;
440 };
441 
442 }  // namespace
443 
NewAlignedConcatNodeShader()444 std::unique_ptr<NodeShader> NewAlignedConcatNodeShader() {
445   return absl::make_unique<AlignedConcatByChannels>();
446 }
447 
NewConcatNodeShader()448 std::unique_ptr<NodeShader> NewConcatNodeShader() {
449   return absl::make_unique<ConcatByAnyChannel>();
450 }
451 
NewFlatConcatNodeShader()452 std::unique_ptr<NodeShader> NewFlatConcatNodeShader() {
453   return absl::make_unique<FlatConcat>();
454 }
455 
456 }  // namespace gl
457 }  // namespace gpu
458 }  // namespace tflite
459