1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resize.h"
17
18 #include "tensorflow/lite/delegates/gpu/common/operations.h"
19 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
20
21 namespace tflite {
22 namespace gpu {
23
Resize(const OperationDef & definition,const Resize2DAttributes & attr)24 Resize::Resize(const OperationDef& definition, const Resize2DAttributes& attr)
25 : GPUOperation(definition), attr_(attr) {
26 code_ = GetResizeCode(definition_, attr_);
27 }
28
Resize(Resize && operation)29 Resize::Resize(Resize&& operation)
30 : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
31
operator =(Resize && operation)32 Resize& Resize::operator=(Resize&& operation) {
33 if (this != &operation) {
34 attr_ = operation.attr_;
35 GPUOperation::operator=(std::move(operation));
36 }
37 return *this;
38 }
39
GetResizeCode(const OperationDef & op_def,const Resize2DAttributes & attr)40 std::string Resize::GetResizeCode(const OperationDef& op_def,
41 const Resize2DAttributes& attr) {
42 auto src_desc = op_def.src_tensors[0];
43 if (op_def.IsBatchSupported()) {
44 src_desc.SetStateVar("BatchedWidth", "true");
45 }
46 AddSrcTensor("src_tensor", src_desc);
47 auto dst_desc = op_def.dst_tensors[0];
48 if (op_def.IsBatchSupported()) {
49 dst_desc.SetStateVar("BatchedWidth", "true");
50 }
51 AddDstTensor("dst_tensor", dst_desc);
52 args_.AddInt("border_x");
53 args_.AddInt("border_y");
54 args_.AddFloat("scale_factor_x");
55 args_.AddFloat("scale_factor_y");
56
57 std::string c;
58 c += "MAIN_FUNCTION($0) {\n";
59 c += " int Y = GLOBAL_ID_1;\n";
60 c += " int Z = GLOBAL_ID_2;\n";
61 if (op_def.IsBatchSupported()) {
62 c += " int linear_id = GLOBAL_ID_0;\n";
63 c += " int X = linear_id / args.dst_tensor.Batch();\n";
64 c += " int B = linear_id % args.dst_tensor.Batch();\n";
65 c += " if (linear_id >= args.dst_tensor.Width() || Y >= "
66 "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
67 } else {
68 c += " int X = GLOBAL_ID_0;\n";
69 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
70 "|| Z >= args.dst_tensor.Slices()) return;\n";
71 }
72 if (attr.type == SamplingType::NEAREST) {
73 std::string fxc;
74 std::string fyc;
75 if (attr.half_pixel_centers) {
76 fxc = "(X + 0.5f) * args.scale_factor_x";
77 fyc = "(Y + 0.5f) * args.scale_factor_y";
78 } else {
79 fxc = "X * args.scale_factor_x";
80 fyc = "Y * args.scale_factor_y";
81 }
82 if (attr.align_corners) {
83 fxc += " + 0.5f";
84 fyc += " + 0.5f";
85 }
86 c += " int2 coord;\n";
87 c += " coord.x = INIT_INT(" + fxc + ");\n";
88 c += " coord.y = INIT_INT(" + fyc + ");\n";
89 c += " coord.x = max(0, coord.x);\n";
90 c += " coord.y = max(0, coord.y);\n";
91 c += " coord.x = min(coord.x, args.border_x);\n";
92 c += " coord.y = min(coord.y, args.border_y);\n";
93 if (op_def.IsBatchSupported()) {
94 c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
95 c += " X = X * args.src_tensor.Batch() + B;\n";
96 }
97 c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
98 } else {
99 if (attr.half_pixel_centers) {
100 c += " float2 f_coords = (INIT_FLOAT2v2(X, Y) + 0.5f) * "
101 "INIT_FLOAT2v2(args.scale_factor_x, args.scale_factor_y) - "
102 "0.5f;\n";
103 } else {
104 c += " float2 f_coords = INIT_FLOAT2v2(X, Y) * "
105 "INIT_FLOAT2v2(args.scale_factor_x, "
106 "args.scale_factor_y);\n";
107 }
108 c += " float2 f_coords_floor = floor(f_coords);\n";
109 c += " int2 coords_floor = INIT_INT2v2(f_coords_floor.x, "
110 "f_coords_floor.y);\n";
111 c += " int4 st;\n";
112 c += " st.xy = max(coords_floor, INIT_INT2v2(0, 0));\n";
113 c += " st.zw = min(coords_floor + INIT_INT2v2(1, 1), "
114 "INIT_INT2v2(args.border_x, "
115 "args.border_y));\n";
116 c += " float2 t = f_coords - f_coords_floor;\n";
117 if (op_def.IsBatchSupported()) {
118 c += " st.x = st.x * args.src_tensor.Batch() + B;\n";
119 c += " st.z = st.z * args.src_tensor.Batch() + B;\n";
120 c += " X = X * args.src_tensor.Batch() + B;\n";
121 }
122 c += " float4 src0 = args.src_tensor.Read<float>(st.x, st.y, Z);\n";
123 c += " float4 src1 = args.src_tensor.Read<float>(st.z, st.y, Z);\n";
124 c += " float4 src2 = args.src_tensor.Read<float>(st.x, st.w, Z);\n";
125 c += " float4 src3 = args.src_tensor.Read<float>(st.z, st.w, Z);\n";
126 c += " FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
127 "t.y));\n";
128 }
129 c += " args.dst_tensor.Write(r0, X, Y, Z);\n";
130 c += "}\n";
131 return c;
132 }
133
BindArguments(ArgumentsBinder * args)134 absl::Status Resize::BindArguments(ArgumentsBinder* args) {
135 RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
136 RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
137 RETURN_IF_ERROR(args->SetFloat(
138 "scale_factor_x",
139 CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
140 RETURN_IF_ERROR(args->SetFloat(
141 "scale_factor_y",
142 CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
143 return absl::OkStatus();
144 }
145
GetGridSize() const146 int3 Resize::GetGridSize() const {
147 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
148 const int grid_y = dst_[0]->Height();
149 const int grid_z = dst_[0]->Slices();
150 return int3(grid_x, grid_y, grid_z);
151 }
152
CreateResize(const OperationDef & definition,const Resize2DAttributes & attr)153 Resize CreateResize(const OperationDef& definition,
154 const Resize2DAttributes& attr) {
155 return Resize(definition, attr);
156 }
157
Resize3D(const OperationDef & definition,const Resize3DAttributes & attr)158 Resize3D::Resize3D(const OperationDef& definition,
159 const Resize3DAttributes& attr)
160 : GPUOperation(definition), attr_(attr) {
161 code_ = GetResize3DCode(definition_, attr_);
162 }
163
Resize3D(Resize3D && operation)164 Resize3D::Resize3D(Resize3D&& operation)
165 : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
166
operator =(Resize3D && operation)167 Resize3D& Resize3D::operator=(Resize3D&& operation) {
168 if (this != &operation) {
169 attr_ = operation.attr_;
170 GPUOperation::operator=(std::move(operation));
171 }
172 return *this;
173 }
174
GetResize3DCode(const OperationDef & op_def,const Resize3DAttributes & attr)175 std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
176 const Resize3DAttributes& attr) {
177 auto src_desc = op_def.src_tensors[0];
178 if (op_def.IsBatchSupported()) {
179 src_desc.SetStateVar("BatchedWidth", "true");
180 }
181 AddSrcTensor("src_tensor", src_desc);
182 auto dst_desc = op_def.dst_tensors[0];
183 if (op_def.IsBatchSupported()) {
184 dst_desc.SetStateVar("BatchedWidth", "true");
185 }
186 AddDstTensor("dst_tensor", dst_desc);
187 args_.AddInt("border_x");
188 args_.AddInt("border_y");
189 args_.AddInt("border_z");
190 args_.AddFloat("scale_factor_x");
191 args_.AddFloat("scale_factor_y");
192 args_.AddFloat("scale_factor_z");
193
194 std::string c;
195 c += "MAIN_FUNCTION($0) {\n";
196 c += " int Y = GLOBAL_ID_1;\n";
197 c += " int linear_id_z = GLOBAL_ID_2;\n";
198 c += " int S = linear_id_z % args.dst_tensor.Slices();\n";
199 c += " int Z = linear_id_z / args.dst_tensor.Slices();\n";
200 if (op_def.IsBatchSupported()) {
201 c += " int linear_id = GLOBAL_ID_0;\n";
202 c += " int X = linear_id / args.dst_tensor.Batch();\n";
203 c += " int B = linear_id % args.dst_tensor.Batch();\n";
204 c += " if (linear_id >= args.dst_tensor.Width() || Y >= "
205 "args.dst_tensor.Height() || Z >= args.dst_tensor.Depth()) return;\n";
206 } else {
207 c += " int X = GLOBAL_ID_0;\n";
208 c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
209 "|| Z >= args.dst_tensor.Depth()) return;\n";
210 }
211 if (attr.type == SamplingType::NEAREST) {
212 std::string fxc;
213 std::string fyc;
214 std::string fzc;
215 if (attr.half_pixel_centers) {
216 fxc = "(X + 0.5f) * args.scale_factor_x";
217 fyc = "(Y + 0.5f) * args.scale_factor_y";
218 fzc = "(Z + 0.5f) * args.scale_factor_z";
219 } else {
220 fxc = "X * args.scale_factor_x";
221 fyc = "Y * args.scale_factor_y";
222 fzc = "Z * args.scale_factor_z";
223 }
224 if (attr.align_corners) {
225 fxc += " + 0.5f";
226 fyc += " + 0.5f";
227 fzc += " + 0.5f";
228 }
229 c += " int4 coord;\n";
230 c += " coord.x = INIT_INT(" + fxc + ");\n";
231 c += " coord.y = INIT_INT(" + fyc + ");\n";
232 c += " coord.z = INIT_INT(" + fzc + ");\n";
233 c += " coord.x = max(0, coord.x);\n";
234 c += " coord.y = max(0, coord.y);\n";
235 c += " coord.z = max(0, coord.z);\n";
236 c += " coord.x = min(coord.x, args.border_x);\n";
237 c += " coord.y = min(coord.y, args.border_y);\n";
238 c += " coord.z = min(coord.z, args.border_z);\n";
239 if (op_def.IsBatchSupported()) {
240 c += " coord.x = coord.x * args.src_tensor.Batch() + B;\n";
241 c += " X = X * args.src_tensor.Batch() + B;\n";
242 }
243 c += " FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, coord.z, S);\n";
244 } else {
245 c += " float4 f_coords;\n";
246 c += " f_coords.x = INIT_FLOAT(X) * args.scale_factor_x;\n";
247 c += " f_coords.y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
248 c += " f_coords.z = INIT_FLOAT(Z) * args.scale_factor_z;\n";
249 c += " int4 start = INIT_INT4v4(f_coords.x, f_coords.y, f_coords.z, 0);\n";
250 c += " int4 end;\n";
251 c += " end.x = min(start.x + 1, args.border_x);\n";
252 c += " end.y = min(start.y + 1, args.border_y);\n";
253 c += " end.z = min(start.z + 1, args.border_z);\n";
254 c += " float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
255 if (op_def.IsBatchSupported()) {
256 c += " start.x = start.x * args.src_tensor.Batch() + B;\n";
257 c += " end.x = end.x * args.src_tensor.Batch() + B;\n";
258 c += " X = X * args.src_tensor.Batch() + B;\n";
259 }
260 c += " float4 src0 = args.src_tensor.Read<float>(start.x, start.y, "
261 "start.z, S);\n";
262 c += " float4 src1 = args.src_tensor.Read<float>(end.x, start.y, start.z, "
263 "S);\n";
264 c += " float4 src2 = args.src_tensor.Read<float>(start.x, end.y, start.z, "
265 "S);\n";
266 c += " float4 src3 = args.src_tensor.Read<float>(end.x, end.y, start.z, "
267 "S);\n";
268 c += " float4 src4 = args.src_tensor.Read<float>(start.x, start.y, end.z, "
269 "S);\n";
270 c += " float4 src5 = args.src_tensor.Read<float>(end.x, start.y, end.z, "
271 "S);\n";
272 c += " float4 src6 = args.src_tensor.Read<float>(start.x, end.y, end.z, "
273 "S);\n";
274 c += " float4 src7 = args.src_tensor.Read<float>(end.x, end.y, end.z, "
275 "S);\n";
276 c +=
277 " float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
278 c +=
279 " float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
280 c += " FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
281 }
282 c += " args.dst_tensor.Write(r0, X, Y, Z, S);\n";
283 c += "}\n";
284 return c;
285 }
286
BindArguments(ArgumentsBinder * args)287 absl::Status Resize3D::BindArguments(ArgumentsBinder* args) {
288 RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
289 RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
290 RETURN_IF_ERROR(args->SetInt("border_z", src_[0]->Depth() - 1));
291 RETURN_IF_ERROR(args->SetFloat(
292 "scale_factor_x",
293 CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
294 RETURN_IF_ERROR(args->SetFloat(
295 "scale_factor_y",
296 CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
297 RETURN_IF_ERROR(args->SetFloat(
298 "scale_factor_z",
299 CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
300 return absl::OkStatus();
301 }
302
GetGridSize() const303 int3 Resize3D::GetGridSize() const {
304 const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
305 const int grid_y = dst_[0]->Height();
306 const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
307 return int3(grid_x, grid_y, grid_z);
308 }
309
CreateResize3D(const OperationDef & definition,const Resize3DAttributes & attr)310 Resize3D CreateResize3D(const OperationDef& definition,
311 const Resize3DAttributes& attr) {
312 return Resize3D(definition, attr);
313 }
314
315 } // namespace gpu
316 } // namespace tflite
317