1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
28
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33
GenerateMaxPoolingCode(const Pooling2DAttributes & attr,const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)34 absl::Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr,
35 const NodeShader::GenerationContext& ctx,
36 GeneratedCode* generated_code) {
37 if (attr.padding.prepended.h > attr.kernel.h ||
38 attr.padding.prepended.w > attr.kernel.w) {
39 return absl::InvalidArgumentError("Padding is bigger than kernel.");
40 }
41
42 std::vector<Variable> parameters = {
43 {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
44 {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
45 {"stride", int2(attr.strides.w, attr.strides.h)},
46 {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
47 {"window_h", attr.kernel.h},
48 {"window_w", attr.kernel.w},
49 };
50
51 // Per GLSL_ES 3.1 spec in Issue 13.4
52 // "Floating Point Representation and Functionality" highp floats are
53 // expected to behave as defined in IEEE 754. In particular, signed
54 // infinities are mandated and defined as a number divided by 0.
55 std::string source = R"(
56 const highp float inf = -(1.0f / 0.0f);
57 value_0 = vec4(inf);)";
58 if (attr.output_indices) {
59 source += R"(
60 ivec4 value_1;
61 )";
62 }
63 source += R"(
64 ivec2 base_coord = gid.xy * $stride$ - $offset$;
65 for (int a = 0; a < $window_h$; ++a) {
66 for (int b = 0; b < $window_w$; ++b) {
67 ivec2 coord = base_coord + ivec2(b, a);
68 if (coord.x < 0 || coord.y < 0 || coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) {
69 continue;
70 }
71 vec4 input_ = $input_data_0[coord.x, coord.y, gid.z]$;)";
72 if (attr.output_indices) {
73 source += R"(
74 int window_index = a * $window_w$ + b;
75 if (input_.x > value_0.x) value_1.x = window_index;
76 if (input_.y > value_0.y) value_1.y = window_index;
77 if (input_.z > value_0.z) value_1.z = window_index;
78 if (input_.w > value_0.w) value_1.w = window_index;)";
79 }
80 source += R"(
81 value_0 = max(value_0, input_);
82 }
83 }
84 )";
85 *generated_code = {
86 /*parameters=*/std::move(parameters),
87 /*objects=*/{},
88 /*shared_variables=*/{},
89 /*workload=*/uint3(),
90 /*workgroup=*/uint3(),
91 /*source_code=*/std::move(source),
92 /*input=*/IOStructure::ONLY_DEFINITIONS,
93 /*output=*/IOStructure::AUTO,
94 };
95 return absl::OkStatus();
96 }
97
GenerateAveragePoolingCode(const Pooling2DAttributes & attr,const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)98 absl::Status GenerateAveragePoolingCode(
99 const Pooling2DAttributes& attr, const NodeShader::GenerationContext& ctx,
100 GeneratedCode* generated_code) {
101 std::vector<Variable> parameters = {
102 {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
103 {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
104 {"stride", int2(attr.strides.w, attr.strides.h)},
105 {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
106 {"window_h", attr.kernel.h},
107 {"window_w", attr.kernel.w},
108 };
109
110 // Bounds checking helper functions.
111 auto x_in_bounds = [input_width = ctx.input_shapes[0][2],
112 kernel_width = attr.kernel.w](int64_t x) -> bool {
113 return 0 <= x && x + kernel_width <= input_width;
114 };
115 auto y_in_bounds = [input_height = ctx.input_shapes[0][1],
116 kernel_height = attr.kernel.h](int64_t y) -> bool {
117 return 0 <= y && y + kernel_height <= input_height;
118 };
119
120 // Only include a bounds check in the shader if it will actually be necessary
121 // at run time.
122 const int64_t output_shape_max_y = ctx.output_shapes[0][1] - 1;
123 const int64_t output_shape_max_x = ctx.output_shapes[0][2] - 1;
124 const int64_t base_x = -attr.padding.prepended.w;
125 const int64_t base_y = -attr.padding.prepended.h;
126 const bool bounds_check_necessary =
127 !(x_in_bounds(base_x) &&
128 x_in_bounds(base_x + output_shape_max_x * attr.strides.w) &&
129 y_in_bounds(base_y) &&
130 y_in_bounds(base_y + output_shape_max_y * attr.strides.h));
131
132 std::string source = bounds_check_necessary ?
133 R"(
134 int window_size = 0;
135 for (int a = 0; a < $window_h$; ++a) {
136 for (int b = 0; b < $window_w$; ++b) {
137 ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
138 if (coord.x >= 0 && coord.y >= 0 && coord.x < $input_data_0_w$ && coord.y < $input_data_0_h$) {
139 value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
140 window_size++;
141 }
142 }
143 }
144 // If window_size==0, window covered nothing. This situation is a sign of
145 // incorrectly constructed operation. NaNs are expected as output.
146 value_0 /= float(window_size);
147 )"
148 :
149 R"(
150 for (int a = 0; a < $window_h$; ++a) {
151 for (int b = 0; b < $window_w$; ++b) {
152 ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
153 value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
154 }
155 }
156 // If the denominator is 0, that is a sign of an incorrectly constructed
157 // operation. NaNs are expected as output.
158 value_0 /= float($window_h$ * $window_w$);
159 )";
160
161 *generated_code = {
162 /*parameters=*/std::move(parameters),
163 /*objects=*/{},
164 /*shared_variables=*/{},
165 /*workload=*/uint3(),
166 /*workgroup=*/uint3(),
167 /*source_code=*/std::move(source),
168 /*input=*/IOStructure::ONLY_DEFINITIONS,
169 /*output=*/IOStructure::AUTO,
170 };
171 return absl::OkStatus();
172 }
173
174 class Pooling : public NodeShader {
175 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const176 absl::Status GenerateCode(const GenerationContext& ctx,
177 GeneratedCode* generated_code) const final {
178 const auto& attr = absl::any_cast<const Pooling2DAttributes&>(ctx.op_attr);
179 switch (attr.type) {
180 case PoolingType::AVERAGE:
181 return GenerateAveragePoolingCode(attr, ctx, generated_code);
182 case PoolingType::MAX:
183 return GenerateMaxPoolingCode(attr, ctx, generated_code);
184 default:
185 return absl::InvalidArgumentError("Incorrect attributes' type.");
186 }
187 }
188 };
189
190 } // namespace
191
NewPoolingNodeShader()192 std::unique_ptr<NodeShader> NewPoolingNodeShader() {
193 return absl::make_unique<Pooling>();
194 }
195
196 } // namespace gl
197 } // namespace gpu
198 } // namespace tflite
199