• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resize.h"
17 
18 #include "tensorflow/lite/delegates/gpu/common/operations.h"
19 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
20 
21 namespace tflite {
22 namespace gpu {
23 
// Constructs the 2D resize operation: stores the attributes and generates
// the backend-agnostic shader code for the given operation definition.
Resize::Resize(const OperationDef& definition, const Resize2DAttributes& attr)
    : GPUOperation(definition), attr_(attr) {
  code_ = GetResizeCode(definition_, attr_);
}
28 
// Move constructor: moves the base GPUOperation state; the attribute struct
// is cheap and copied as-is.
Resize::Resize(Resize&& operation)
    : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
31 
// Move assignment with self-assignment guard: copies the attributes and
// delegates the rest of the state transfer to the base class.
Resize& Resize::operator=(Resize&& operation) {
  if (this != &operation) {
    attr_ = operation.attr_;
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}
39 
GetResizeCode(const OperationDef & op_def,const Resize2DAttributes & attr)40 std::string Resize::GetResizeCode(const OperationDef& op_def,
41                                   const Resize2DAttributes& attr) {
42   auto src_desc = op_def.src_tensors[0];
43   if (op_def.IsBatchSupported()) {
44     src_desc.SetStateVar("BatchedWidth", "true");
45   }
46   AddSrcTensor("src_tensor", src_desc);
47   auto dst_desc = op_def.dst_tensors[0];
48   if (op_def.IsBatchSupported()) {
49     dst_desc.SetStateVar("BatchedWidth", "true");
50   }
51   AddDstTensor("dst_tensor", dst_desc);
52   args_.AddInt("border_x");
53   args_.AddInt("border_y");
54   args_.AddFloat("scale_factor_x");
55   args_.AddFloat("scale_factor_y");
56 
57   std::string c;
58   c += "MAIN_FUNCTION($0) {\n";
59   c += "  int Y = GLOBAL_ID_1;\n";
60   c += "  int Z = GLOBAL_ID_2;\n";
61   if (op_def.IsBatchSupported()) {
62     c += "  int linear_id = GLOBAL_ID_0;\n";
63     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
64     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
65     c += "  if (linear_id >= args.dst_tensor.Width() || Y >= "
66          "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
67   } else {
68     c += "  int X = GLOBAL_ID_0;\n";
69     c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
70          "|| Z >= args.dst_tensor.Slices()) return;\n";
71   }
72   if (attr.type == SamplingType::NEAREST) {
73     std::string fxc;
74     std::string fyc;
75     if (attr.half_pixel_centers) {
76       fxc = "(X + 0.5f) * args.scale_factor_x";
77       fyc = "(Y + 0.5f) * args.scale_factor_y";
78     } else {
79       fxc = "X * args.scale_factor_x";
80       fyc = "Y * args.scale_factor_y";
81     }
82     if (attr.align_corners) {
83       fxc += " + 0.5f";
84       fyc += " + 0.5f";
85     }
86     c += "  int2 coord;\n";
87     c += "  coord.x = INIT_INT(" + fxc + ");\n";
88     c += "  coord.y = INIT_INT(" + fyc + ");\n";
89     c += "  coord.x = max(0, coord.x);\n";
90     c += "  coord.y = max(0, coord.y);\n";
91     c += "  coord.x = min(coord.x, args.border_x);\n";
92     c += "  coord.y = min(coord.y, args.border_y);\n";
93     if (op_def.IsBatchSupported()) {
94       c += "  coord.x = coord.x * args.src_tensor.Batch() + B;\n";
95       c += "  X = X * args.src_tensor.Batch() + B;\n";
96     }
97     c += "  FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
98   } else {
99     if (attr.half_pixel_centers) {
100       c += "  float2 f_coords = (INIT_FLOAT2v2(X, Y) + 0.5f) * "
101            "INIT_FLOAT2v2(args.scale_factor_x, args.scale_factor_y) - "
102            "0.5f;\n";
103     } else {
104       c += "  float2 f_coords = INIT_FLOAT2v2(X, Y) * "
105            "INIT_FLOAT2v2(args.scale_factor_x, "
106            "args.scale_factor_y);\n";
107     }
108     c += "  float2 f_coords_floor = floor(f_coords);\n";
109     c += "  int2 coords_floor = INIT_INT2v2(f_coords_floor.x, "
110          "f_coords_floor.y);\n";
111     c += "  int4 st;\n";
112     c += "  st.xy = max(coords_floor, INIT_INT2v2(0, 0));\n";
113     c += "  st.zw = min(coords_floor + INIT_INT2v2(1, 1), "
114          "INIT_INT2v2(args.border_x, "
115          "args.border_y));\n";
116     c += "  float2 t = f_coords - f_coords_floor;\n";
117     if (op_def.IsBatchSupported()) {
118       c += "  st.x = st.x * args.src_tensor.Batch() + B;\n";
119       c += "  st.z = st.z * args.src_tensor.Batch() + B;\n";
120       c += "  X = X * args.src_tensor.Batch() + B;\n";
121     }
122     c += "  float4 src0 = args.src_tensor.Read<float>(st.x, st.y, Z);\n";
123     c += "  float4 src1 = args.src_tensor.Read<float>(st.z, st.y, Z);\n";
124     c += "  float4 src2 = args.src_tensor.Read<float>(st.x, st.w, Z);\n";
125     c += "  float4 src3 = args.src_tensor.Read<float>(st.z, st.w, Z);\n";
126     c += "  FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
127          "t.y));\n";
128   }
129   c += "  args.dst_tensor.Write(r0, X, Y, Z);\n";
130   c += "}\n";
131   return c;
132 }
133 
BindArguments(ArgumentsBinder * args)134 absl::Status Resize::BindArguments(ArgumentsBinder* args) {
135   RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
136   RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
137   RETURN_IF_ERROR(args->SetFloat(
138       "scale_factor_x",
139       CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
140   RETURN_IF_ERROR(args->SetFloat(
141       "scale_factor_y",
142       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
143   return absl::OkStatus();
144 }
145 
GetGridSize() const146 int3 Resize::GetGridSize() const {
147   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
148   const int grid_y = dst_[0]->Height();
149   const int grid_z = dst_[0]->Slices();
150   return int3(grid_x, grid_y, grid_z);
151 }
152 
// Factory helper that constructs a ready-to-use 2D resize operation.
Resize CreateResize(const OperationDef& definition,
                    const Resize2DAttributes& attr) {
  return Resize(definition, attr);
}
157 
// Constructs the 3D resize operation: stores the attributes and generates
// the shader code for the given operation definition.
Resize3D::Resize3D(const OperationDef& definition,
                   const Resize3DAttributes& attr)
    : GPUOperation(definition), attr_(attr) {
  code_ = GetResize3DCode(definition_, attr_);
}
163 
// Move constructor: moves the base GPUOperation state; the attribute struct
// is cheap and copied as-is.
Resize3D::Resize3D(Resize3D&& operation)
    : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
166 
// Move assignment with self-assignment guard: copies the attributes and
// delegates the rest of the state transfer to the base class.
Resize3D& Resize3D::operator=(Resize3D&& operation) {
  if (this != &operation) {
    attr_ = operation.attr_;
    GPUOperation::operator=(std::move(operation));
  }
  return *this;
}
174 
GetResize3DCode(const OperationDef & op_def,const Resize3DAttributes & attr)175 std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
176                                       const Resize3DAttributes& attr) {
177   auto src_desc = op_def.src_tensors[0];
178   if (op_def.IsBatchSupported()) {
179     src_desc.SetStateVar("BatchedWidth", "true");
180   }
181   AddSrcTensor("src_tensor", src_desc);
182   auto dst_desc = op_def.dst_tensors[0];
183   if (op_def.IsBatchSupported()) {
184     dst_desc.SetStateVar("BatchedWidth", "true");
185   }
186   AddDstTensor("dst_tensor", dst_desc);
187   args_.AddInt("border_x");
188   args_.AddInt("border_y");
189   args_.AddInt("border_z");
190   args_.AddFloat("scale_factor_x");
191   args_.AddFloat("scale_factor_y");
192   args_.AddFloat("scale_factor_z");
193 
194   std::string c;
195   c += "MAIN_FUNCTION($0) {\n";
196   c += "  int Y = GLOBAL_ID_1;\n";
197   c += "  int linear_id_z = GLOBAL_ID_2;\n";
198   c += "  int S = linear_id_z % args.dst_tensor.Slices();\n";
199   c += "  int Z = linear_id_z / args.dst_tensor.Slices();\n";
200   if (op_def.IsBatchSupported()) {
201     c += "  int linear_id = GLOBAL_ID_0;\n";
202     c += "  int X = linear_id / args.dst_tensor.Batch();\n";
203     c += "  int B = linear_id % args.dst_tensor.Batch();\n";
204     c += "  if (linear_id >= args.dst_tensor.Width() || Y >= "
205          "args.dst_tensor.Height() || Z >= args.dst_tensor.Depth()) return;\n";
206   } else {
207     c += "  int X = GLOBAL_ID_0;\n";
208     c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
209          "|| Z >= args.dst_tensor.Depth()) return;\n";
210   }
211   if (attr.type == SamplingType::NEAREST) {
212     std::string fxc;
213     std::string fyc;
214     std::string fzc;
215     if (attr.half_pixel_centers) {
216       fxc = "(X + 0.5f) * args.scale_factor_x";
217       fyc = "(Y + 0.5f) * args.scale_factor_y";
218       fzc = "(Z + 0.5f) * args.scale_factor_z";
219     } else {
220       fxc = "X * args.scale_factor_x";
221       fyc = "Y * args.scale_factor_y";
222       fzc = "Z * args.scale_factor_z";
223     }
224     if (attr.align_corners) {
225       fxc += " + 0.5f";
226       fyc += " + 0.5f";
227       fzc += " + 0.5f";
228     }
229     c += "  int4 coord;\n";
230     c += "  coord.x = INIT_INT(" + fxc + ");\n";
231     c += "  coord.y = INIT_INT(" + fyc + ");\n";
232     c += "  coord.z = INIT_INT(" + fzc + ");\n";
233     c += "  coord.x = max(0, coord.x);\n";
234     c += "  coord.y = max(0, coord.y);\n";
235     c += "  coord.z = max(0, coord.z);\n";
236     c += "  coord.x = min(coord.x, args.border_x);\n";
237     c += "  coord.y = min(coord.y, args.border_y);\n";
238     c += "  coord.z = min(coord.z, args.border_z);\n";
239     if (op_def.IsBatchSupported()) {
240       c += "  coord.x = coord.x * args.src_tensor.Batch() + B;\n";
241       c += "  X = X * args.src_tensor.Batch() + B;\n";
242     }
243     c += "  FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, coord.z, S);\n";
244   } else {
245     c += "  float4 f_coords;\n";
246     c += "  f_coords.x = INIT_FLOAT(X) * args.scale_factor_x;\n";
247     c += "  f_coords.y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
248     c += "  f_coords.z = INIT_FLOAT(Z) * args.scale_factor_z;\n";
249     c += "  int4 start = INIT_INT4v4(f_coords.x, f_coords.y, f_coords.z, 0);\n";
250     c += "  int4 end;\n";
251     c += "  end.x = min(start.x + 1, args.border_x);\n";
252     c += "  end.y = min(start.y + 1, args.border_y);\n";
253     c += "  end.z = min(start.z + 1, args.border_z);\n";
254     c += "  float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
255     if (op_def.IsBatchSupported()) {
256       c += "  start.x = start.x * args.src_tensor.Batch() + B;\n";
257       c += "  end.x = end.x * args.src_tensor.Batch() + B;\n";
258       c += "  X = X * args.src_tensor.Batch() + B;\n";
259     }
260     c += "  float4 src0 = args.src_tensor.Read<float>(start.x, start.y, "
261          "start.z, S);\n";
262     c += "  float4 src1 = args.src_tensor.Read<float>(end.x, start.y, start.z, "
263          "S);\n";
264     c += "  float4 src2 = args.src_tensor.Read<float>(start.x, end.y, start.z, "
265          "S);\n";
266     c += "  float4 src3 = args.src_tensor.Read<float>(end.x, end.y, start.z, "
267          "S);\n";
268     c += "  float4 src4 = args.src_tensor.Read<float>(start.x, start.y, end.z, "
269          "S);\n";
270     c += "  float4 src5 = args.src_tensor.Read<float>(end.x, start.y, end.z, "
271          "S);\n";
272     c += "  float4 src6 = args.src_tensor.Read<float>(start.x, end.y, end.z, "
273          "S);\n";
274     c += "  float4 src7 = args.src_tensor.Read<float>(end.x, end.y, end.z, "
275          "S);\n";
276     c +=
277         "  float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
278     c +=
279         "  float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
280     c += "  FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
281   }
282   c += "  args.dst_tensor.Write(r0, X, Y, Z, S);\n";
283   c += "}\n";
284   return c;
285 }
286 
BindArguments(ArgumentsBinder * args)287 absl::Status Resize3D::BindArguments(ArgumentsBinder* args) {
288   RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
289   RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
290   RETURN_IF_ERROR(args->SetInt("border_z", src_[0]->Depth() - 1));
291   RETURN_IF_ERROR(args->SetFloat(
292       "scale_factor_x",
293       CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
294   RETURN_IF_ERROR(args->SetFloat(
295       "scale_factor_y",
296       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
297   RETURN_IF_ERROR(args->SetFloat(
298       "scale_factor_z",
299       CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
300   return absl::OkStatus();
301 }
302 
GetGridSize() const303 int3 Resize3D::GetGridSize() const {
304   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
305   const int grid_y = dst_[0]->Height();
306   const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
307   return int3(grid_x, grid_y, grid_z);
308 }
309 
// Factory helper that constructs a ready-to-use 3D resize operation.
Resize3D CreateResize3D(const OperationDef& definition,
                        const Resize3DAttributes& attr) {
  return Resize3D(definition, attr);
}
314 
315 }  // namespace gpu
316 }  // namespace tflite
317