1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resize.h"
17
18 #include <string>
19 #include <utility>
20
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23
24 namespace tflite {
25 namespace gpu {
26
Resize(const OperationDef & definition,const Resize2DAttributes & attr)27 Resize::Resize(const OperationDef& definition, const Resize2DAttributes& attr)
28 : GPUOperation(definition), attr_(attr) {
29 code_ = GetResizeCode(definition_, attr_);
30 }
31
Resize(Resize && operation)32 Resize::Resize(Resize&& operation)
33 : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
34
operator =(Resize && operation)35 Resize& Resize::operator=(Resize&& operation) {
36 if (this != &operation) {
37 attr_ = operation.attr_;
38 GPUOperation::operator=(std::move(operation));
39 }
40 return *this;
41 }
42
// Generates the device (shader/kernel) source for a 2D resize, supporting
// NEAREST and bilinear sampling. As a side effect it registers the src/dst
// tensors and the scalar arguments (clamping borders and scale factors) on
// this operation, so it is intended to be called once, from the constructor.
// The numeric argument values themselves are supplied later in BindArguments.
std::string Resize::GetResizeCode(const OperationDef& op_def,
                                  const Resize2DAttributes& attr) {
  auto src_desc = op_def.src_tensors[0];
  if (op_def.IsBatchSupported()) {
    // With batching, batch is folded into the width (X) axis of the tensors.
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  AddSrcTensor("src_tensor", src_desc);
  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  AddDstTensor("dst_tensor", dst_desc);
  // Declared here, bound with concrete values in Resize::BindArguments.
  args_.AddInt("border_x");
  args_.AddInt("border_y");
  args_.AddFloat("scale_factor_x");
  args_.AddFloat("scale_factor_y");

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  // Global ids map to (X, Y, Z) = (dst width [* batch], dst height, slices).
  c += " int Y = GLOBAL_ID_1;\n";
  c += " int Z = GLOBAL_ID_2;\n";
  if (op_def.IsBatchSupported()) {
    // GLOBAL_ID_0 is a linearized (width, batch) index; split it apart.
    c += " int linear_id = GLOBAL_ID_0;\n";
    c += " int X = linear_id / args.dst_tensor.Batch();\n";
    c += " int B = linear_id % args.dst_tensor.Batch();\n";
    c += " if (linear_id >= args.dst_tensor.Width() || Y >= "
         "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
  } else {
    c += " int X = GLOBAL_ID_0;\n";
    c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
         "|| Z >= args.dst_tensor.Slices()) return;\n";
  }
  if (attr.type == SamplingType::NEAREST) {
    // Nearest-neighbor: compute one source coordinate and clamp it to the
    // valid range [0, border].
    std::string fxc;
    std::string fyc;
    if (attr.half_pixel_centers) {
      fxc = "(INIT_FLOAT(X) + 0.5f) * args.scale_factor_x";
      fyc = "(INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y";
    } else {
      fxc = "INIT_FLOAT(X) * args.scale_factor_x";
      fyc = "INIT_FLOAT(Y) * args.scale_factor_y";
    }
    if (attr.align_corners) {
      // +0.5 turns the truncating INIT_INT cast below into round-to-nearest.
      fxc += " + 0.5f";
      fyc += " + 0.5f";
    }
    c += " int2 coord;\n";
    c += " coord.x = INIT_INT(" + fxc + ");\n";
    c += " coord.y = INIT_INT(" + fyc + ");\n";
    c += " coord.x = max(0, coord.x);\n";
    c += " coord.y = max(0, coord.y);\n";
    c += " coord.x = min(coord.x, args.border_x);\n";
    c += " coord.y = min(coord.y, args.border_y);\n";
    if (op_def.IsBatchSupported()) {
      // Re-fold the batch index into the batched-width coordinate.
      c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
      c += " X = X * args.src_tensor.Batch() + B;\n";
    }
    c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
  } else {
    // Bilinear: read the 2x2 neighborhood around the fractional source
    // coordinate and blend with the fractional weights t.
    if (attr.half_pixel_centers) {
      c += " float2 f_coords = (INIT_FLOAT2v2(X, Y) + 0.5f) * "
           "INIT_FLOAT2v2(args.scale_factor_x, args.scale_factor_y) - "
           "0.5f;\n";
    } else {
      c += " float2 f_coords = INIT_FLOAT2v2(X, Y) * "
           "INIT_FLOAT2v2(args.scale_factor_x, "
           "args.scale_factor_y);\n";
    }
    c += " float2 f_coords_floor = floor(f_coords);\n";
    c += " int2 coords_floor = INIT_INT2v2(f_coords_floor.x, "
         "f_coords_floor.y);\n";
    // st.xy = clamped top-left corner, st.zw = clamped bottom-right corner.
    c += " int4 st;\n";
    c += " st.xy = max(coords_floor, INIT_INT2v2(0, 0));\n";
    c += " st.zw = min(coords_floor + INIT_INT2v2(1, 1), "
         "INIT_INT2v2(args.border_x, "
         "args.border_y));\n";
    c += " float2 t = f_coords - f_coords_floor;\n";
    if (op_def.IsBatchSupported()) {
      c += " st.x = st.x * args.src_tensor.Batch() + B;\n";
      c += " st.z = st.z * args.src_tensor.Batch() + B;\n";
      c += " X = X * args.src_tensor.Batch() + B;\n";
    }
    c += " float4 src0 = args.src_tensor.Read<float>(st.x, st.y, Z);\n";
    c += " float4 src1 = args.src_tensor.Read<float>(st.z, st.y, Z);\n";
    c += " float4 src2 = args.src_tensor.Read<float>(st.x, st.w, Z);\n";
    c += " float4 src3 = args.src_tensor.Read<float>(st.z, st.w, Z);\n";
    // Interpolate in x first (top and bottom rows), then in y.
    c += " FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
         "t.y));\n";
  }
  c += " args.dst_tensor.Write(r0, X, Y, Z);\n";
  c += "}\n";
  return c;
}
136
BindArguments(ArgumentsBinder * args)137 absl::Status Resize::BindArguments(ArgumentsBinder* args) {
138 RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
139 RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
140 RETURN_IF_ERROR(args->SetFloat(
141 "scale_factor_x",
142 CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
143 RETURN_IF_ERROR(args->SetFloat(
144 "scale_factor_y",
145 CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
146 return absl::OkStatus();
147 }
148
GetGridSize() const149 int3 Resize::GetGridSize() const {
150 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
151 const int grid_y = dst_[0]->Height();
152 const int grid_z = dst_[0]->Slices();
153 return int3(grid_x, grid_y, grid_z);
154 }
155
CreateResize(const OperationDef & definition,const Resize2DAttributes & attr)156 Resize CreateResize(const OperationDef& definition,
157 const Resize2DAttributes& attr) {
158 return Resize(definition, attr);
159 }
160
Resize3D(const OperationDef & definition,const Resize3DAttributes & attr)161 Resize3D::Resize3D(const OperationDef& definition,
162 const Resize3DAttributes& attr)
163 : GPUOperation(definition), attr_(attr) {
164 code_ = GetResize3DCode(definition_, attr_);
165 }
166
Resize3D(Resize3D && operation)167 Resize3D::Resize3D(Resize3D&& operation)
168 : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
169
operator =(Resize3D && operation)170 Resize3D& Resize3D::operator=(Resize3D&& operation) {
171 if (this != &operation) {
172 attr_ = operation.attr_;
173 GPUOperation::operator=(std::move(operation));
174 }
175 return *this;
176 }
177
// Generates the device (shader/kernel) source for a 3D resize, supporting
// NEAREST and trilinear sampling. As a side effect it registers the src/dst
// tensors and the scalar arguments (clamping borders and scale factors) on
// this operation, so it is intended to be called once, from the constructor.
// The numeric argument values themselves are supplied in BindArguments.
std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
                                      const Resize3DAttributes& attr) {
  auto src_desc = op_def.src_tensors[0];
  if (op_def.IsBatchSupported()) {
    // With batching, batch is folded into the width (X) axis of the tensors.
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  AddSrcTensor("src_tensor", src_desc);
  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  AddDstTensor("dst_tensor", dst_desc);
  // Declared here, bound with concrete values in Resize3D::BindArguments.
  args_.AddInt("border_x");
  args_.AddInt("border_y");
  args_.AddInt("border_z");
  args_.AddFloat("scale_factor_x");
  args_.AddFloat("scale_factor_y");
  args_.AddFloat("scale_factor_z");

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  c += " int Y = GLOBAL_ID_1;\n";
  // GLOBAL_ID_2 linearizes (depth Z, channel-slice S); split it apart.
  c += " int linear_id_z = GLOBAL_ID_2;\n";
  c += " int S = linear_id_z % args.dst_tensor.Slices();\n";
  c += " int Z = linear_id_z / args.dst_tensor.Slices();\n";
  if (op_def.IsBatchSupported()) {
    // GLOBAL_ID_0 is a linearized (width, batch) index; split it apart.
    c += " int linear_id = GLOBAL_ID_0;\n";
    c += " int X = linear_id / args.dst_tensor.Batch();\n";
    c += " int B = linear_id % args.dst_tensor.Batch();\n";
    c += " if (linear_id >= args.dst_tensor.Width() || Y >= "
         "args.dst_tensor.Height() || Z >= args.dst_tensor.Depth()) return;\n";
  } else {
    c += " int X = GLOBAL_ID_0;\n";
    c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
         "|| Z >= args.dst_tensor.Depth()) return;\n";
  }
  if (attr.type == SamplingType::NEAREST) {
    // Nearest-neighbor: compute one source coordinate per axis and clamp it
    // to the valid range [0, border].
    std::string fxc;
    std::string fyc;
    std::string fzc;
    if (attr.half_pixel_centers) {
      fxc = "(INIT_FLOAT(X) + 0.5f) * args.scale_factor_x";
      fyc = "(INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y";
      fzc = "(INIT_FLOAT(Z) + 0.5f) * args.scale_factor_z";
    } else {
      fxc = "INIT_FLOAT(X) * args.scale_factor_x";
      fyc = "INIT_FLOAT(Y) * args.scale_factor_y";
      fzc = "INIT_FLOAT(Z) * args.scale_factor_z";
    }
    if (attr.align_corners) {
      // +0.5 turns the truncating INIT_INT cast below into round-to-nearest.
      fxc += " + 0.5f";
      fyc += " + 0.5f";
      fzc += " + 0.5f";
    }
    c += " int4 coord;\n";
    c += " coord.x = INIT_INT(" + fxc + ");\n";
    c += " coord.y = INIT_INT(" + fyc + ");\n";
    c += " coord.z = INIT_INT(" + fzc + ");\n";
    c += " coord.x = max(0, coord.x);\n";
    c += " coord.y = max(0, coord.y);\n";
    c += " coord.z = max(0, coord.z);\n";
    c += " coord.x = min(coord.x, args.border_x);\n";
    c += " coord.y = min(coord.y, args.border_y);\n";
    c += " coord.z = min(coord.z, args.border_z);\n";
    if (op_def.IsBatchSupported()) {
      // Re-fold the batch index into the batched-width coordinate.
      c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
      c += " X = X * args.src_tensor.Batch() + B;\n";
    }
    c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, coord.z, S);\n";
  } else {
    // Trilinear: read the 2x2x2 neighborhood around the fractional source
    // coordinate and blend in x, then y, then z.
    // NOTE(review): unlike the 2D bilinear branch, no half-pixel-centers
    // offset is applied here (align_corners only enters via the scale factor
    // computed in BindArguments) — confirm this asymmetry is intended.
    c += " float4 f_coords;\n";
    c += " f_coords.x = INIT_FLOAT(X) * args.scale_factor_x;\n";
    c += " f_coords.y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
    c += " f_coords.z = INIT_FLOAT(Z) * args.scale_factor_z;\n";
    // start = truncated corner; end = next sample clamped to the border.
    c += " int4 start = INIT_INT4v4(f_coords.x, f_coords.y, f_coords.z, 0);\n";
    c += " int4 end;\n";
    c += " end.x = min(start.x + 1, args.border_x);\n";
    c += " end.y = min(start.y + 1, args.border_y);\n";
    c += " end.z = min(start.z + 1, args.border_z);\n";
    c += " float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
    if (op_def.IsBatchSupported()) {
      c += " start.x = start.x * args.src_tensor.Batch() + B;\n";
      c += " end.x = end.x * args.src_tensor.Batch() + B;\n";
      c += " X = X * args.src_tensor.Batch() + B;\n";
    }
    c += " float4 src0 = args.src_tensor.Read<float>(start.x, start.y, "
         "start.z, S);\n";
    c += " float4 src1 = args.src_tensor.Read<float>(end.x, start.y, start.z, "
         "S);\n";
    c += " float4 src2 = args.src_tensor.Read<float>(start.x, end.y, start.z, "
         "S);\n";
    c += " float4 src3 = args.src_tensor.Read<float>(end.x, end.y, start.z, "
         "S);\n";
    c += " float4 src4 = args.src_tensor.Read<float>(start.x, start.y, end.z, "
         "S);\n";
    c += " float4 src5 = args.src_tensor.Read<float>(end.x, start.y, end.z, "
         "S);\n";
    c += " float4 src6 = args.src_tensor.Read<float>(start.x, end.y, end.z, "
         "S);\n";
    c += " float4 src7 = args.src_tensor.Read<float>(end.x, end.y, end.z, "
         "S);\n";
    // t0/t1 are the bilinear results on the near/far z planes.
    c +=
        " float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
    c +=
        " float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
    c += " FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
  }
  c += " args.dst_tensor.Write(r0, X, Y, Z, S);\n";
  c += "}\n";
  return c;
}
289
BindArguments(ArgumentsBinder * args)290 absl::Status Resize3D::BindArguments(ArgumentsBinder* args) {
291 RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
292 RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
293 RETURN_IF_ERROR(args->SetInt("border_z", src_[0]->Depth() - 1));
294 RETURN_IF_ERROR(args->SetFloat(
295 "scale_factor_x",
296 CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
297 RETURN_IF_ERROR(args->SetFloat(
298 "scale_factor_y",
299 CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
300 RETURN_IF_ERROR(args->SetFloat(
301 "scale_factor_z",
302 CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
303 return absl::OkStatus();
304 }
305
GetGridSize() const306 int3 Resize3D::GetGridSize() const {
307 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
308 const int grid_y = dst_[0]->Height();
309 const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
310 return int3(grid_x, grid_y, grid_z);
311 }
312
CreateResize3D(const OperationDef & definition,const Resize3DAttributes & attr)313 Resize3D CreateResize3D(const OperationDef& definition,
314 const Resize3DAttributes& attr) {
315 return Resize3D(definition, attr);
316 }
317
318 } // namespace gpu
319 } // namespace tflite
320