1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/split.h"
17
18 #include <string>
19
20 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
21
22 namespace tflite {
23 namespace gpu {
24
Split(const OperationDef & definition,const SplitAttributes & attr)25 Split::Split(const OperationDef& definition, const SplitAttributes& attr)
26 : GPUOperation(definition), attr_(attr) {
27 work_group_size_ = int3(8, 4, 1);
28 code_ = attr.axis == Axis::CHANNELS ? GetSplitChannelsCode() : GetSplitCode();
29 }
30
GetSplitCode()31 std::string Split::GetSplitCode() {
32 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
33 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
34 AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
35 }
36 const std::string task_width =
37 attr_.axis == Axis::WIDTH ? "1" : "args.src_tensor.Width()";
38 const std::string task_height =
39 attr_.axis == Axis::HEIGHT ? "1" : "args.src_tensor.Height()";
40 const std::string task_depth =
41 attr_.axis == Axis::DEPTH ? "1" : "args.src_tensor.Depth()";
42 const std::string task_batch =
43 attr_.axis == Axis::BATCH ? "1" : "args.src_tensor.Batch()";
44 const std::string task_slices =
45 attr_.axis == Axis::CHANNELS ? "1" : "args.src_tensor.Slices()";
46
47 std::map<Axis, std::string> axis_to_selector = {
48 {Axis::WIDTH, "Width"}, {Axis::HEIGHT, "Height"},
49 {Axis::DEPTH, "Depth"}, {Axis::CHANNELS, "Slices"},
50 {Axis::BATCH, "Batch"},
51 };
52 std::map<Axis, std::string> axis_to_coord = {
53 {Axis::WIDTH, "X"}, {Axis::HEIGHT, "Y"}, {Axis::DEPTH, "D"},
54 {Axis::CHANNELS, "S"}, {Axis::BATCH, "B"},
55 };
56
57 std::string c;
58 c += "MAIN_FUNCTION($0) {\n";
59 if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
60 c += " int linear_id = GLOBAL_ID_0;\n";
61 c += " int X = linear_id / " + task_batch + ";\n";
62 c += " int B = linear_id % " + task_batch + ";\n";
63 c += " if (X >= " + task_width + ") return;\n";
64 } else {
65 c += " int X = GLOBAL_ID_0;\n";
66 c += " if (X >= " + task_width + ") return;\n";
67 }
68 if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
69 c += " int linear_id = GLOBAL_ID_1;\n";
70 c += " int Y = linear_id % " + task_height + ";\n";
71 c += " int D = linear_id / " + task_height + ";\n";
72 c += " if (D >= " + task_depth + ") return;\n";
73 } else {
74 c += " int Y = GLOBAL_ID_1;\n";
75 c += " if (Y >= " + task_height + ") return;\n";
76 }
77 c += " int S = GLOBAL_ID_2;\n";
78 c += " if (S >= " + task_slices + ") return;\n";
79 c += " int src_counter = 0;\n";
80 std::vector<std::string> src_coords;
81 for (auto axis :
82 {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS, Axis::BATCH}) {
83 if (definition_.src_tensors[0].HasAxis(axis)) {
84 const std::string coord_name =
85 attr_.axis == axis ? "src_counter" : axis_to_coord[axis];
86 src_coords.push_back(coord_name);
87 }
88 }
89 std::string src_coords_str = src_coords[0];
90 for (int i = 1; i < src_coords.size(); ++i) {
91 src_coords_str += ", " + src_coords[i];
92 }
93 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
94 std::vector<std::string> dst_coords;
95 for (auto axis : {Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH, Axis::CHANNELS,
96 Axis::BATCH}) {
97 if (definition_.dst_tensors[i].HasAxis(axis)) {
98 const std::string coord_name =
99 attr_.axis == axis ? "i" : axis_to_coord[axis];
100 dst_coords.push_back(coord_name);
101 }
102 }
103 std::string dst_coords_str = dst_coords[0];
104 for (int j = 1; j < dst_coords.size(); ++j) {
105 dst_coords_str += ", " + dst_coords[j];
106 }
107 const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
108 c += " for (int i = 0; i < " + dst_name + "." +
109 axis_to_selector[attr_.axis] + "(); ++i, src_counter++) {\n";
110 c += " FLT4 result = args.src_tensor.Read(" + src_coords_str + ");\n";
111 c += " " + dst_name + ".Write(result, " + dst_coords_str + ");\n";
112 c += " }\n";
113 }
114 c += "}\n";
115 return c;
116 }
117
GetSplitChannelsCode()118 std::string Split::GetSplitChannelsCode() {
119 AddSrcTensor("src_tensor", definition_.src_tensors[0]);
120 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
121 AddDstTensor("dst_tensor_" + std::to_string(i), definition_.dst_tensors[i]);
122 }
123
124 const std::string batch_coord =
125 definition_.src_tensors[0].HasAxis(Axis::BATCH) ? ", B" : "";
126 std::string coords = "X, Y";
127 std::string c;
128 c += "MAIN_FUNCTION($0) {\n";
129 if (definition_.src_tensors[0].HasAxis(Axis::BATCH)) {
130 c += " int linear_id = GLOBAL_ID_0;\n";
131 c += " int X = linear_id / args.src_tensor.Batch();\n";
132 c += " int B = linear_id % args.src_tensor.Batch();\n";
133 c += " if (X >= args.src_tensor.Width()) return;\n";
134 } else {
135 c += " int X = GLOBAL_ID_0;\n";
136 c += " if (X >= args.src_tensor.Width()) return;\n";
137 }
138 if (definition_.src_tensors[0].HasAxis(Axis::DEPTH)) {
139 c += " int linear_id = GLOBAL_ID_1;\n";
140 c += " int Y = linear_id % args.src_tensor.Height();\n";
141 c += " int Z = linear_id / args.src_tensor.Height();\n";
142 c += " if (Z >= args.src_tensor.Depth()) return;\n";
143 coords += ", Z";
144 } else {
145 c += " int Y = GLOBAL_ID_1;\n";
146 c += " if (Y >= args.src_tensor.Height()) return;\n";
147 }
148 c += " int src_channel = 0;\n";
149 const std::string postfixes[] = {"x", "y", "z", "w"};
150 for (int i = 0; i < definition_.dst_tensors.size(); ++i) {
151 const std::string dst_name = "args.dst_tensor_" + std::to_string(i);
152 c += " for (int i = 0; i < " + dst_name + ".Slices(); ++i) {\n";
153 c += " FLT4 result = INIT_FLT4(0.0f);\n";
154 for (int j = 0; j < 4; ++j) {
155 c += " if (i * 4 + " + std::to_string(j) + " < " + dst_name +
156 ".Channels()) {\n";
157 c += " int src_slice = src_channel >> 2;\n";
158 c += " int src_sub_ch = src_channel & 3;\n";
159 c += " FLT4 t = args.src_tensor.Read(" + coords + ", src_slice" +
160 batch_coord + ");\n";
161 c += " result." + postfixes[j] +
162 " = SELECT_BY_INDEX_FROM_FLT4(t, src_sub_ch);\n";
163 c += " src_channel++;\n";
164 c += " }\n";
165 }
166 c += " " + dst_name + ".Write(result, " + coords + ", i" + batch_coord +
167 ");\n";
168 c += " }\n";
169 }
170 c += "}\n";
171 return c;
172 }
173
GetGridSize() const174 int3 Split::GetGridSize() const {
175 const int width = attr_.axis == Axis::WIDTH ? 1 : src_[0]->Width();
176 const int height = attr_.axis == Axis::HEIGHT ? 1 : src_[0]->Height();
177 const int depth = attr_.axis == Axis::DEPTH ? 1 : src_[0]->Depth();
178 const int batch = attr_.axis == Axis::BATCH ? 1 : src_[0]->Batch();
179 const int slices = attr_.axis == Axis::CHANNELS ? 1 : src_[0]->Slices();
180 const int grid_x = width * batch;
181 const int grid_y = height * depth;
182 const int grid_z = slices;
183 return int3(grid_x, grid_y, grid_z);
184 }
185
CreateSplit(const OperationDef & definition,const SplitAttributes & attr)186 Split CreateSplit(const OperationDef& definition, const SplitAttributes& attr) {
187 return Split(definition, attr);
188 }
189
190 } // namespace gpu
191 } // namespace tflite
192