1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/operations.h"
26 #include "tensorflow/lite/delegates/gpu/common/status.h"
27 #include "tensorflow/lite/delegates/gpu/common/types.h"
28
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33
34 class Resize : public NodeShader {
35 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const36 absl::Status GenerateCode(const GenerationContext& ctx,
37 GeneratedCode* generated_code) const final {
38 const auto& attr = absl::any_cast<const Resize2DAttributes&>(ctx.op_attr);
39
40 if (ctx.input_shapes[0][2] > ctx.output_shapes[0][2] ||
41 ctx.input_shapes[0][1] > ctx.output_shapes[0][1]) {
42 return absl::InvalidArgumentError("Output size is less than input size.");
43 }
44 if (ctx.output_shapes[0][2] != attr.new_shape.w ||
45 ctx.output_shapes[0][1] != attr.new_shape.h) {
46 return absl::InvalidArgumentError(
47 "Output size does not match new_size in attributes.");
48 }
49 if (ctx.input_shapes[0][3] != ctx.output_shapes[0][3]) {
50 return absl::InvalidArgumentError("Input/output channels mismatch.");
51 }
52 if (ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1) {
53 // Copy a single element from input.
54 *generated_code = {
55 /*parameters=*/{},
56 /*objects=*/{},
57 /*shared_variables=*/{},
58 /*workload=*/uint3(),
59 /*workgroup=*/uint3(),
60 /*source_code=*/"value_0 = $input_data_0[0, 0, gid.z]$;",
61 /*input=*/IOStructure::ONLY_DEFINITIONS,
62 /*output=*/IOStructure::AUTO,
63 };
64 return absl::OkStatus();
65 }
66 std::vector<Variable> parameters = {
67 {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
68 {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
69 {"scale_factor",
70 float2(CalculateResizeScale(ctx.input_shapes[0][2],
71 ctx.output_shapes[0][2], attr),
72 CalculateResizeScale(ctx.input_shapes[0][1],
73 ctx.output_shapes[0][1], attr))},
74 };
75
76 std::string source;
77 if (attr.type == SamplingType::BILINEAR) {
78 if (attr.half_pixel_centers) {
79 source = "vec2 coord = (vec2(gid.xy) + 0.5) * $scale_factor$ - 0.5;";
80 } else {
81 source = "vec2 coord = vec2(gid.xy) * $scale_factor$;";
82 }
83 source += R"(
84 vec2 coord_floor = floor(coord);
85 ivec2 icoord_floor = ivec2(coord_floor);
86 ivec2 borders = ivec2($input_data_0_w$, $input_data_0_h$) - ivec2(1, 1);
87 ivec4 st;
88 st.xy = max(icoord_floor, ivec2(0, 0));
89 st.zw = min(icoord_floor + ivec2(1, 1), borders);
90
91 vec2 t = coord - coord_floor; // interpolating factors
92
93 vec4 tex11 = $input_data_0[st.x, st.y, gid.z]$;
94 vec4 tex21 = $input_data_0[st.z, st.y, gid.z]$;
95 vec4 tex12 = $input_data_0[st.x, st.w, gid.z]$;
96 vec4 tex22 = $input_data_0[st.z, st.w, gid.z]$;
97
98 value_0 = mix(mix(tex11, tex21, t.x), mix(tex12, tex22, t.x), t.y);)";
99 } else if (attr.type == SamplingType::NEAREST) {
100 std::string fxc;
101 std::string fyc;
102 if (attr.half_pixel_centers) {
103 fxc = "(float(gid.x) + 0.5) * $scale_factor.x$";
104 fyc = "(float(gid.y) + 0.5) * $scale_factor.y$";
105 } else {
106 fxc = "float(gid.x) * $scale_factor.x$";
107 fyc = "float(gid.y) * $scale_factor.y$";
108 }
109 if (attr.align_corners) {
110 fxc += " + 0.5";
111 fyc += " + 0.5";
112 }
113 source += " ivec2 coord;\n";
114 source += " coord.x = int(" + fxc + ");\n";
115 source += " coord.y = int(" + fyc + ");\n";
116 source += " coord.x = max(0, coord.x);\n";
117 source += " coord.y = max(0, coord.y);\n";
118 source += " coord.x = min(coord.x, $input_data_0_w$ - 1);\n";
119 source += " coord.y = min(coord.y, $input_data_0_h$ - 1);\n";
120 source += R"(
121 value_0 = $input_data_0[coord.x, coord.y, gid.z]$;
122 )";
123 } else {
124 return absl::InvalidArgumentError("Unknown sampling type");
125 }
126 *generated_code = {
127 /*parameters=*/std::move(parameters),
128 /*objects=*/{},
129 /*shared_variables=*/{},
130 /*workload=*/uint3(),
131 /*workgroup=*/uint3(),
132 /*source_code=*/std::move(source),
133 /*input=*/IOStructure::ONLY_DEFINITIONS,
134 /*output=*/IOStructure::AUTO,
135 };
136 return absl::OkStatus();
137 }
138 };
139
140 } // namespace
141
NewResizeNodeShader()142 std::unique_ptr<NodeShader> NewResizeNodeShader() {
143 return absl::make_unique<Resize>();
144 }
145
146 } // namespace gl
147 } // namespace gpu
148 } // namespace tflite
149