• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/operations.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28 
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33 
34 class Resize : public NodeShader {
35  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const36   absl::Status GenerateCode(const GenerationContext& ctx,
37                             GeneratedCode* generated_code) const final {
38     const auto& attr = absl::any_cast<const Resize2DAttributes&>(ctx.op_attr);
39 
40     if (ctx.input_shapes[0][2] > ctx.output_shapes[0][2] ||
41         ctx.input_shapes[0][1] > ctx.output_shapes[0][1]) {
42       return absl::InvalidArgumentError("Output size is less than input size.");
43     }
44     if (ctx.output_shapes[0][2] != attr.new_shape.w ||
45         ctx.output_shapes[0][1] != attr.new_shape.h) {
46       return absl::InvalidArgumentError(
47           "Output size does not match new_size in attributes.");
48     }
49     if (ctx.input_shapes[0][3] != ctx.output_shapes[0][3]) {
50       return absl::InvalidArgumentError("Input/output channels mismatch.");
51     }
52     if (ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1) {
53       // Copy a single element from input.
54       *generated_code = {
55           /*parameters=*/{},
56           /*objects=*/{},
57           /*shared_variables=*/{},
58           /*workload=*/uint3(),
59           /*workgroup=*/uint3(),
60           /*source_code=*/"value_0 = $input_data_0[0, 0, gid.z]$;",
61           /*input=*/IOStructure::ONLY_DEFINITIONS,
62           /*output=*/IOStructure::AUTO,
63       };
64       return absl::OkStatus();
65     }
66     std::vector<Variable> parameters = {
67         {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
68         {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
69         {"scale_factor",
70          float2(CalculateResizeScale(ctx.input_shapes[0][2],
71                                      ctx.output_shapes[0][2], attr),
72                 CalculateResizeScale(ctx.input_shapes[0][1],
73                                      ctx.output_shapes[0][1], attr))},
74     };
75 
76     std::string source;
77     if (attr.type == SamplingType::BILINEAR) {
78       if (attr.half_pixel_centers) {
79         source = "vec2 coord = (vec2(gid.xy) + 0.5) * $scale_factor$ - 0.5;";
80       } else {
81         source = "vec2 coord = vec2(gid.xy) * $scale_factor$;";
82       }
83       source += R"(
84       vec2 coord_floor = floor(coord);
85       ivec2 icoord_floor = ivec2(coord_floor);
86       ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
87       ivec4 st;
88       st.xy = max(icoord_floor, ivec2(0, 0));
89       st.zw = min(icoord_floor + ivec2(1, 1), borders);
90 
91       vec2 t = coord - coord_floor; // interpolating factors
92 
93       vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
94       vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
95       vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$;
96       vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$;
97 
98       value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
99     } else if (attr.type == SamplingType::NEAREST) {
100       std::string fxc;
101       std::string fyc;
102       if (attr.half_pixel_centers) {
103         fxc = "(float(gid.x) + 0.5) * $scale_factor.x$";
104         fyc = "(float(gid.y) + 0.5) * $scale_factor.y$";
105       } else {
106         fxc = "float(gid.x) * $scale_factor.x$";
107         fyc = "float(gid.y) * $scale_factor.y$";
108       }
109       if (attr.align_corners) {
110         fxc += " + 0.5";
111         fyc += " + 0.5";
112       }
113       source += "  ivec2 coord;\n";
114       source += "  coord.x = int(" + fxc + ");\n";
115       source += "  coord.y = int(" + fyc + ");\n";
116       source += "  coord.x = max(0, coord.x);\n";
117       source += "  coord.y = max(0, coord.y);\n";
118       source += "  coord.x = min(coord.x, $input_data_0_w$ - 1);\n";
119       source += "  coord.y = min(coord.y, $input_data_0_h$ - 1);\n";
120       source += R"(
121       value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
122       )";
123     } else {
124       return absl::InvalidArgumentError("Unknown sampling type");
125     }
126     *generated_code = {
127         /*parameters=*/std::move(parameters),
128         /*objects=*/{},
129         /*shared_variables=*/{},
130         /*workload=*/uint3(),
131         /*workgroup=*/uint3(),
132         /*source_code=*/std::move(source),
133         /*input=*/IOStructure::ONLY_DEFINITIONS,
134         /*output=*/IOStructure::AUTO,
135     };
136     return absl::OkStatus();
137   }
138 };
139 
140 }  // namespace
141 
NewResizeNodeShader()142 std::unique_ptr<NodeShader> NewResizeNodeShader() {
143   return absl::make_unique<Resize>();
144 }
145 
146 }  // namespace gl
147 }  // namespace gpu
148 }  // namespace tflite
149