1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
17
18 #include <string>
19
20 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
21
22 namespace tflite {
23 namespace gpu {
24
Split(const OperationDef & definition,const SplitAttributes & attr)25 Split::Split(const OperationDef& definition, const SplitAttributes& attr)
26 : GPUOperation(definition), attr_(attr) {
27 work_group_size_ = int3(8, 4, 1);
28 code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode() : GetSplitCode();
29 }
30
GetSplitCode()31 std::string Split::GetSplitCode() {
32 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
33 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
34 AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
35 }
36 const std::string task_width =
37 attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
38 const std::string task_height =
39 attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
40 const std::string task_depth =
41 attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
42 const std::string task_batch =
43 attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
44 const std::string task_slices =
45 attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
46
47 std::string c;
48 c += "MAIN_FUNCTION($0) {\n";
49 c += " int task_width = "
50 ";\n";
51 if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
52 c += " int linear_id = GLOBAL_ID_0;\n";
53 c += " int X = linear_id / " + task_batch + ";\n";
54 c += " int B = linear_id % " + task_batch + ";\n";
55 } else {
56 c += " int X = GLOBAL_ID_0;\n";
57 }
58 if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
59 c += " int linear_id = GLOBAL_ID_1;\n";
60 c += " int Y = linear_id % " + task_height + ";\n";
61 c += " int B = linear_id / " + task_height + ";\n";
62 } else {
63 c += " int Y = GLOBAL_ID_1;\n";
64 }
65 c += " int S = GLOBAL_ID_2;\n";
66 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() || "
67 "S >= args.dst_tensor.Slices()) { \n";
68 c += " return; \n";
69 c += " } \n";
70 c += " int src_counter = 0;\n";
71 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
72 const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
73 c += " for (int i = 0; i < " + dst_name +
74 ".Slices(); ++i, src_counter++) {\n";
75 c += " FLT4 result = args.src_tensor.Read(s_x, s_y, src_counter);\n";
76 c += " " + dst_name + ".Write(result, X, Y, i);\n";
77 c += " }\n";
78 }
79 c += "}\n";
80 return c;
81 }
82
GetSplitChannelsCode()83 std::string Split::GetSplitChannelsCode() {
84 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
85 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
86 AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
87 }
88
89 const std::string batch_coord =
90 definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
91 std::string coords = "X, Y";
92 std::string c;
93 c += "MAIN_FUNCTION($0) {\n";
94 if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
95 c += " int linear_id = GLOBAL_ID_0;\n";
96 c += " int X = linear_id / args.src_tensor.Batch();\n";
97 c += " int B = linear_id % args.src_tensor.Batch();\n";
98 c += " if (X >= args.src_tensor.Width()) return;\n";
99 } else {
100 c += " int X = GLOBAL_ID_0;\n";
101 c += " if (X >= args.src_tensor.Width()) return;\n";
102 }
103 if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
104 c += " int linear_id = GLOBAL_ID_1;\n";
105 c += " int Y = linear_id % args.src_tensor.Height();\n";
106 c += " int Z = linear_id / args.src_tensor.Height();\n";
107 c += " if (Z >= args.src_tensor.Depth()) return;\n";
108 coords += ", Z";
109 } else {
110 c += " int Y = GLOBAL_ID_1;\n";
111 c += " if (Y >= args.src_tensor.Height()) return;\n";
112 }
113 c += " int src_channel = 0;\n";
114 const std::string postfixes[] = {"x", "y", "z", "w"};
115 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
116 const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
117 c += " for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
118 c += " FLT4 result = INIT_FLT4(0.0f);\n";
119 for (int j = 0; j < 4; ++j) {
120 c += " if (i * 4 + " + std::to_string(j) + " < " + dst_name +
121 ".Channels()) {\n";
122 c += " int src_slice = src_channel >> 2;\n";
123 c += " int src_sub_ch = src_channel & 3;\n";
124 c += " FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
125 batch_coord + ");\n";
126 c += " FLT t_ar[4] = {t.x, t.y, t.z, t.w};\n";
127 c += " result." + postfixes[j] + " = t_ar[src_sub_ch];\n";
128 c += " src_channel++;\n";
129 c += " }\n";
130 }
131 c += " " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
132 ");\n";
133 c += " }\n";
134 }
135 c += "}\n";
136 return c;
137 }
138
GetGridSize() const139 int3 Split::GetGridSize() const {
140 const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
141 const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
142 const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
143 const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
144 const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
145 const int grid_x = width * batch;
146 const int grid_y = height * depth;
147 const int grid_z = slices;
148 return int3(grid_x, grid_y, grid_z);
149 }
150
CreateSplit(const OperationDef & definition,const SplitAttributes & attr)151 Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
152 return Split(definition, attr);
153 }
154
155 } // namespace gpu
156 } // namespace tflite
157