• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
16 #include "tensorflow/lite/delegates/gpu/common/tasks/resize.h"
17 
18 #include <string>
19 #include <utility>
20 
21 #include "tensorflow/lite/delegates/gpu/common/operations.h"
22 #include "tensorflow/lite/delegates/gpu/common/task/work_group_picking.h"
23 
24 namespace tflite {
25 namespace gpu {
26 
// Constructs a 2D resize operation; the kernel source is generated once,
// here, from the definition and attributes.
Resize::Resize(const OperationDef& definition, const Resize2DAttributes& attr)
    : GPUOperation(definition), attr_(attr) {
  code_ = GetResizeCode(definition_, attr_);
}
31 
// Move constructor. attr_ is a plain value type, so it is copied; the heavy
// state lives in the GPUOperation base, which is moved.
Resize::Resize(Resize&& operation)
    : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
34 
operator =(Resize && operation)35 Resize& Resize::operator=(Resize&& operation) {
36   if (this != &operation) {
37     attr_ = operation.attr_;
38     GPUOperation::operator=(std::move(operation));
39   }
40   return *this;
41 }
42 
// Generates the kernel source for 2D resize (nearest or bilinear sampling).
//
// One work item handles one (X, Y, Z-slice) of the destination tensor. The
// destination coordinate is mapped to a (possibly fractional) source
// coordinate via "scale_factor_x/y", and source reads are clamped to
// [0, border_x/y]; all four scalar args are bound later in BindArguments().
std::string Resize::GetResizeCode(const OperationDef& op_def,
                                  const Resize2DAttributes& attr) {
  auto src_desc = op_def.src_tensors[0];
  if (op_def.IsBatchSupported()) {
    // With batching the batch index is folded into the width axis; the
    // generated code then rescales X coordinates by Batch() below.
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  AddSrcTensor("src_tensor", src_desc);
  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  AddDstTensor("dst_tensor", dst_desc);
  // Scalar kernel arguments; values are supplied per-dispatch by
  // BindArguments().
  args_.AddInt("border_x");
  args_.AddInt("border_y");
  args_.AddFloat("scale_factor_x");
  args_.AddFloat("scale_factor_y");

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  c += "  int Y = GLOBAL_ID_1;\n";
  c += "  int Z = GLOBAL_ID_2;\n";
  if (op_def.IsBatchSupported()) {
    // Global id 0 is a linearized (X, batch) index: X = id / B, batch = id % B.
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  if (linear_id >= args.dst_tensor.Width() || Y >= "
         "args.dst_tensor.Height() || Z >= args.dst_tensor.Slices()) return;\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
    c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
         "|| Z >= args.dst_tensor.Slices()) return;\n";
  }
  if (attr.type == SamplingType::NEAREST) {
    // Nearest-neighbor: compute the source coordinate expression as a string,
    // then truncate to int and clamp to the valid source range.
    std::string fxc;
    std::string fyc;
    if (attr.half_pixel_centers) {
      fxc = "(INIT_FLOAT(X) + 0.5f) * args.scale_factor_x";
      fyc = "(INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y";
    } else {
      fxc = "INIT_FLOAT(X) * args.scale_factor_x";
      fyc = "INIT_FLOAT(Y) * args.scale_factor_y";
    }
    if (attr.align_corners) {
      // +0.5 before the int truncation turns it into round-to-nearest.
      fxc += " + 0.5f";
      fyc += " + 0.5f";
    }
    c += "  int2 coord;\n";
    c += "  coord.x = INIT_INT(" + fxc + ");\n";
    c += "  coord.y = INIT_INT(" + fyc + ");\n";
    c += "  coord.x = max(0, coord.x);\n";
    c += "  coord.y = max(0, coord.y);\n";
    c += "  coord.x = min(coord.x, args.border_x);\n";
    c += "  coord.y = min(coord.y, args.border_y);\n";
    if (op_def.IsBatchSupported()) {
      // Re-fold the batch index into the batched-width X coordinates.
      c += "  coord.x = coord.x * args.src_tensor.Batch() + B;\n";
      c += "  X = X * args.src_tensor.Batch() + B;\n";
    }
    c += "  FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, Z);\n";
  } else {
    // Bilinear: read the 2x2 neighborhood around the fractional source
    // coordinate and blend with the fractional parts t.x / t.y.
    if (attr.half_pixel_centers) {
      c += "  float2 f_coords = (INIT_FLOAT2v2(X, Y) + 0.5f) * "
           "INIT_FLOAT2v2(args.scale_factor_x, args.scale_factor_y) - "
           "0.5f;\n";
    } else {
      c += "  float2 f_coords = INIT_FLOAT2v2(X, Y) * "
           "INIT_FLOAT2v2(args.scale_factor_x, "
           "args.scale_factor_y);\n";
    }
    c += "  float2 f_coords_floor = floor(f_coords);\n";
    c += "  int2 coords_floor = INIT_INT2v2(f_coords_floor.x, "
         "f_coords_floor.y);\n";
    // st.xy = clamped top-left corner, st.zw = clamped bottom-right corner.
    c += "  int4 st;\n";
    c += "  st.xy = max(coords_floor, INIT_INT2v2(0, 0));\n";
    c += "  st.zw = min(coords_floor + INIT_INT2v2(1, 1), "
         "INIT_INT2v2(args.border_x, "
         "args.border_y));\n";
    c += "  float2 t = f_coords - f_coords_floor;\n";
    if (op_def.IsBatchSupported()) {
      c += "  st.x = st.x * args.src_tensor.Batch() + B;\n";
      c += "  st.z = st.z * args.src_tensor.Batch() + B;\n";
      c += "  X = X * args.src_tensor.Batch() + B;\n";
    }
    // Blend in float regardless of the tensor's storage precision.
    c += "  float4 src0 = args.src_tensor.Read<float>(st.x, st.y, Z);\n";
    c += "  float4 src1 = args.src_tensor.Read<float>(st.z, st.y, Z);\n";
    c += "  float4 src2 = args.src_tensor.Read<float>(st.x, st.w, Z);\n";
    c += "  float4 src3 = args.src_tensor.Read<float>(st.z, st.w, Z);\n";
    c += "  FLT4 r0 = TO_FLT4(mix(mix(src0, src1, t.x), mix(src2, src3, t.x), "
         "t.y));\n";
  }
  c += "  args.dst_tensor.Write(r0, X, Y, Z);\n";
  c += "}\n";
  return c;
}
136 
BindArguments(ArgumentsBinder * args)137 absl::Status Resize::BindArguments(ArgumentsBinder* args) {
138   RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
139   RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
140   RETURN_IF_ERROR(args->SetFloat(
141       "scale_factor_x",
142       CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
143   RETURN_IF_ERROR(args->SetFloat(
144       "scale_factor_y",
145       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
146   return absl::OkStatus();
147 }
148 
GetGridSize() const149 int3 Resize::GetGridSize() const {
150   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
151   const int grid_y = dst_[0]->Height();
152   const int grid_z = dst_[0]->Slices();
153   return int3(grid_x, grid_y, grid_z);
154 }
155 
CreateResize(const OperationDef & definition,const Resize2DAttributes & attr)156 Resize CreateResize(const OperationDef& definition,
157                     const Resize2DAttributes& attr) {
158   return Resize(definition, attr);
159 }
160 
// Constructs a 3D resize operation; the kernel source is generated once,
// here, from the definition and attributes.
Resize3D::Resize3D(const OperationDef& definition,
                   const Resize3DAttributes& attr)
    : GPUOperation(definition), attr_(attr) {
  code_ = GetResize3DCode(definition_, attr_);
}
166 
// Move constructor. attr_ is a plain value type, so it is copied; the heavy
// state lives in the GPUOperation base, which is moved.
Resize3D::Resize3D(Resize3D&& operation)
    : GPUOperation(std::move(operation)), attr_(operation.attr_) {}
169 
operator =(Resize3D && operation)170 Resize3D& Resize3D::operator=(Resize3D&& operation) {
171   if (this != &operation) {
172     attr_ = operation.attr_;
173     GPUOperation::operator=(std::move(operation));
174   }
175   return *this;
176 }
177 
// Generates the kernel source for 3D resize (nearest or trilinear sampling).
//
// One work item handles one (X, Y, depth Z, slice S) destination element;
// global id 2 linearizes (depth, slice). Source reads are clamped to
// [0, border_{x,y,z}]; the six scalar args are bound in BindArguments().
std::string Resize3D::GetResize3DCode(const OperationDef& op_def,
                                      const Resize3DAttributes& attr) {
  auto src_desc = op_def.src_tensors[0];
  if (op_def.IsBatchSupported()) {
    // With batching the batch index is folded into the width axis; the
    // generated code then rescales X coordinates by Batch() below.
    src_desc.SetStateVar("BatchedWidth", "true");
  }
  AddSrcTensor("src_tensor", src_desc);
  auto dst_desc = op_def.dst_tensors[0];
  if (op_def.IsBatchSupported()) {
    dst_desc.SetStateVar("BatchedWidth", "true");
  }
  AddDstTensor("dst_tensor", dst_desc);
  // Scalar kernel arguments; values are supplied per-dispatch by
  // BindArguments().
  args_.AddInt("border_x");
  args_.AddInt("border_y");
  args_.AddInt("border_z");
  args_.AddFloat("scale_factor_x");
  args_.AddFloat("scale_factor_y");
  args_.AddFloat("scale_factor_z");

  std::string c;
  c += "MAIN_FUNCTION($0) {\n";
  c += "  int Y = GLOBAL_ID_1;\n";
  // Global id 2 packs (depth Z, slice S): S = id % Slices, Z = id / Slices.
  c += "  int linear_id_z = GLOBAL_ID_2;\n";
  c += "  int S = linear_id_z % args.dst_tensor.Slices();\n";
  c += "  int Z = linear_id_z / args.dst_tensor.Slices();\n";
  if (op_def.IsBatchSupported()) {
    c += "  int linear_id = GLOBAL_ID_0;\n";
    c += "  int X = linear_id / args.dst_tensor.Batch();\n";
    c += "  int B = linear_id % args.dst_tensor.Batch();\n";
    c += "  if (linear_id >= args.dst_tensor.Width() || Y >= "
         "args.dst_tensor.Height() || Z >= args.dst_tensor.Depth()) return;\n";
  } else {
    c += "  int X = GLOBAL_ID_0;\n";
    c += "  if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
         "|| Z >= args.dst_tensor.Depth()) return;\n";
  }
  if (attr.type == SamplingType::NEAREST) {
    // Nearest-neighbor: compute the source coordinate expression as a string,
    // then truncate to int and clamp to the valid source range.
    std::string fxc;
    std::string fyc;
    std::string fzc;
    if (attr.half_pixel_centers) {
      fxc = "(INIT_FLOAT(X) + 0.5f) * args.scale_factor_x";
      fyc = "(INIT_FLOAT(Y) + 0.5f) * args.scale_factor_y";
      fzc = "(INIT_FLOAT(Z) + 0.5f) * args.scale_factor_z";
    } else {
      fxc = "INIT_FLOAT(X) * args.scale_factor_x";
      fyc = "INIT_FLOAT(Y) * args.scale_factor_y";
      fzc = "INIT_FLOAT(Z) * args.scale_factor_z";
    }
    if (attr.align_corners) {
      // +0.5 before the int truncation turns it into round-to-nearest.
      fxc += " + 0.5f";
      fyc += " + 0.5f";
      fzc += " + 0.5f";
    }
    c += "  int4 coord;\n";
    c += "  coord.x = INIT_INT(" + fxc + ");\n";
    c += "  coord.y = INIT_INT(" + fyc + ");\n";
    c += "  coord.z = INIT_INT(" + fzc + ");\n";
    c += "  coord.x = max(0, coord.x);\n";
    c += "  coord.y = max(0, coord.y);\n";
    c += "  coord.z = max(0, coord.z);\n";
    c += "  coord.x = min(coord.x, args.border_x);\n";
    c += "  coord.y = min(coord.y, args.border_y);\n";
    c += "  coord.z = min(coord.z, args.border_z);\n";
    if (op_def.IsBatchSupported()) {
      // Re-fold the batch index into the batched-width X coordinates.
      c += "  coord.x = coord.x * args.src_tensor.Batch() + B;\n";
      c += "  X = X * args.src_tensor.Batch() + B;\n";
    }
    c += "  FLT4 r0 = args.src_tensor.Read(coord.x, coord.y, coord.z, S);\n";
  } else {
    // Trilinear: read the 2x2x2 neighborhood around the fractional source
    // coordinate and blend with the fractional parts t.x / t.y / t.z.
    // NOTE(review): unlike the 2D bilinear path above, this branch never
    // applies attr.half_pixel_centers or attr.align_corners offsets —
    // confirm whether that asymmetry is intentional.
    c += "  float4 f_coords;\n";
    c += "  f_coords.x = INIT_FLOAT(X) * args.scale_factor_x;\n";
    c += "  f_coords.y = INIT_FLOAT(Y) * args.scale_factor_y;\n";
    c += "  f_coords.z = INIT_FLOAT(Z) * args.scale_factor_z;\n";
    c += "  int4 start = INIT_INT4v4(f_coords.x, f_coords.y, f_coords.z, 0);\n";
    c += "  int4 end;\n";
    c += "  end.x = min(start.x + 1, args.border_x);\n";
    c += "  end.y = min(start.y + 1, args.border_y);\n";
    c += "  end.z = min(start.z + 1, args.border_z);\n";
    // NOTE(review): OpenCL-style "(float4)(...)" cast here, rather than an
    // INIT_FLOAT4v4 macro as used elsewhere — verify this compiles on all
    // targeted backends.
    c += "  float4 t = f_coords - (float4)(start.x, start.y, start.z, 0.0f);\n";
    if (op_def.IsBatchSupported()) {
      c += "  start.x = start.x * args.src_tensor.Batch() + B;\n";
      c += "  end.x = end.x * args.src_tensor.Batch() + B;\n";
      c += "  X = X * args.src_tensor.Batch() + B;\n";
    }
    // Eight corner reads in float regardless of storage precision.
    c += "  float4 src0 = args.src_tensor.Read<float>(start.x, start.y, "
         "start.z, S);\n";
    c += "  float4 src1 = args.src_tensor.Read<float>(end.x, start.y, start.z, "
         "S);\n";
    c += "  float4 src2 = args.src_tensor.Read<float>(start.x, end.y, start.z, "
         "S);\n";
    c += "  float4 src3 = args.src_tensor.Read<float>(end.x, end.y, start.z, "
         "S);\n";
    c += "  float4 src4 = args.src_tensor.Read<float>(start.x, start.y, end.z, "
         "S);\n";
    c += "  float4 src5 = args.src_tensor.Read<float>(end.x, start.y, end.z, "
         "S);\n";
    c += "  float4 src6 = args.src_tensor.Read<float>(start.x, end.y, end.z, "
         "S);\n";
    c += "  float4 src7 = args.src_tensor.Read<float>(end.x, end.y, end.z, "
         "S);\n";
    c +=
        "  float4 t0 = mix(mix(src0, src1, t.x), mix(src2, src3, t.x), t.y);\n";
    c +=
        "  float4 t1 = mix(mix(src4, src5, t.x), mix(src6, src7, t.x), t.y);\n";
    c += "  FLT4 r0 = TO_FLT4(mix(t0, t1, t.z));\n";
  }
  c += "  args.dst_tensor.Write(r0, X, Y, Z, S);\n";
  c += "}\n";
  return c;
}
289 
BindArguments(ArgumentsBinder * args)290 absl::Status Resize3D::BindArguments(ArgumentsBinder* args) {
291   RETURN_IF_ERROR(args->SetInt("border_x", src_[0]->Width() - 1));
292   RETURN_IF_ERROR(args->SetInt("border_y", src_[0]->Height() - 1));
293   RETURN_IF_ERROR(args->SetInt("border_z", src_[0]->Depth() - 1));
294   RETURN_IF_ERROR(args->SetFloat(
295       "scale_factor_x",
296       CalculateResizeScale(src_[0]->Width(), dst_[0]->Width(), attr_)));
297   RETURN_IF_ERROR(args->SetFloat(
298       "scale_factor_y",
299       CalculateResizeScale(src_[0]->Height(), dst_[0]->Height(), attr_)));
300   RETURN_IF_ERROR(args->SetFloat(
301       "scale_factor_z",
302       CalculateResizeScale(src_[0]->Depth(), dst_[0]->Depth(), attr_)));
303   return absl::OkStatus();
304 }
305 
GetGridSize() const306 int3 Resize3D::GetGridSize() const {
307   const int grid_x = dst_[0]->Width() * dst_[0]->Batch();
308   const int grid_y = dst_[0]->Height();
309   const int grid_z = dst_[0]->Slices() * dst_[0]->Depth();
310   return int3(grid_x, grid_y, grid_z);
311 }
312 
CreateResize3D(const OperationDef & definition,const Resize3DAttributes & attr)313 Resize3D CreateResize3D(const OperationDef& definition,
314                         const Resize3DAttributes& attr) {
315   return Resize3D(definition, attr);
316 }
317 
318 }  // namespace gpu
319 }  // namespace tflite
320