• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <cstring>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
28 
29 namespace tflite {
30 namespace gpu {
31 namespace gl {
32 namespace {
33 
GenerateMaxPoolingCode(const Pooling2DAttributes & attr,const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)34 absl::Status GenerateMaxPoolingCode(const Pooling2DAttributes& attr,
35                                     const NodeShader::GenerationContext& ctx,
36                                     GeneratedCode* generated_code) {
37   if (attr.padding.prepended.h > attr.kernel.h ||
38       attr.padding.prepended.w > attr.kernel.w) {
39     return absl::InvalidArgumentError("Padding is bigger than kernel.");
40   }
41 
42   std::vector<Variable> parameters = {
43       {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
44       {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
45       {"stride", int2(attr.strides.w, attr.strides.h)},
46       {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
47       {"window_h", attr.kernel.h},
48       {"window_w", attr.kernel.w},
49   };
50 
51   // Per GLSL_ES 3.1 spec in Issue 13.4
52   // "Floating Point Representation and Functionality" highp floats are
53   // expected to behave as defined in IEEE 754. In particular, signed
54   // infinities are mandated and defined as a number divided by 0.
55   std::string source = R"(
56   const highp float inf = -(1.0f / 0.0f);
57   value_0 = vec4(inf);)";
58   if (attr.output_indices) {
59     source += R"(
60   ivec4 value_1;
61 )";
62   }
63   source += R"(
64   ivec2 base_coord = gid.xy * $stride$ - $offset$;
65   for (int a = 0; a < $window_h$; ++a) {
66     for (int b = 0; b < $window_w$; ++b) {
67       ivec2 coord = base_coord + ivec2(b, a);
68       if (coord.x < 0 || coord.y < 0 || coord.x >= $input_data_0_w$ || coord.y >= $input_data_0_h$) {
69         continue;
70       }
71       vec4 input_ = $input_data_0[coord.x, coord.y, gid.z]$;)";
72   if (attr.output_indices) {
73     source += R"(
74       int window_index = a * $window_w$ + b;
75       if (input_.x > value_0.x) value_1.x = window_index;
76       if (input_.y > value_0.y) value_1.y = window_index;
77       if (input_.z > value_0.z) value_1.z = window_index;
78       if (input_.w > value_0.w) value_1.w = window_index;)";
79   }
80   source += R"(
81       value_0 = max(value_0, input_);
82     }
83   }
84 )";
85   *generated_code = {
86       /*parameters=*/std::move(parameters),
87       /*objects=*/{},
88       /*shared_variables=*/{},
89       /*workload=*/uint3(),
90       /*workgroup=*/uint3(),
91       /*source_code=*/std::move(source),
92       /*input=*/IOStructure::ONLY_DEFINITIONS,
93       /*output=*/IOStructure::AUTO,
94   };
95   return absl::OkStatus();
96 }
97 
GenerateAveragePoolingCode(const Pooling2DAttributes & attr,const NodeShader::GenerationContext & ctx,GeneratedCode * generated_code)98 absl::Status GenerateAveragePoolingCode(
99     const Pooling2DAttributes& attr, const NodeShader::GenerationContext& ctx,
100     GeneratedCode* generated_code) {
101   std::vector<Variable> parameters = {
102       {"input_data_0_h", static_cast<int>(ctx.input_shapes[0][1])},
103       {"input_data_0_w", static_cast<int>(ctx.input_shapes[0][2])},
104       {"stride", int2(attr.strides.w, attr.strides.h)},
105       {"offset", int2(attr.padding.prepended.w, attr.padding.prepended.h)},
106       {"window_h", attr.kernel.h},
107       {"window_w", attr.kernel.w},
108   };
109 
110   // Bounds checking helper functions.
111   auto x_in_bounds = [input_width = ctx.input_shapes[0][2],
112                       kernel_width = attr.kernel.w](int64_t x) -> bool {
113     return 0 <= x && x + kernel_width <= input_width;
114   };
115   auto y_in_bounds = [input_height = ctx.input_shapes[0][1],
116                       kernel_height = attr.kernel.h](int64_t y) -> bool {
117     return 0 <= y && y + kernel_height <= input_height;
118   };
119 
120   // Only include a bounds check in the shader if it will actually be necessary
121   // at run time.
122   const int64_t output_shape_max_y = ctx.output_shapes[0][1] - 1;
123   const int64_t output_shape_max_x = ctx.output_shapes[0][2] - 1;
124   const int64_t base_x = -attr.padding.prepended.w;
125   const int64_t base_y = -attr.padding.prepended.h;
126   const bool bounds_check_necessary =
127       !(x_in_bounds(base_x) &&
128         x_in_bounds(base_x + output_shape_max_x * attr.strides.w) &&
129         y_in_bounds(base_y) &&
130         y_in_bounds(base_y + output_shape_max_y * attr.strides.h));
131 
132   std::string source = bounds_check_necessary ?
133                                               R"(
134   int window_size = 0;
135   for (int a = 0; a < $window_h$; ++a) {
136     for (int b = 0; b < $window_w$; ++b) {
137       ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
138       if (coord.x >= 0 && coord.y >= 0 && coord.x < $input_data_0_w$ && coord.y < $input_data_0_h$) {
139         value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
140         window_size++;
141       }
142     }
143   }
144   // If window_size==0, window covered nothing. This situation is a sign of
145   // incorrectly constructed operation. NaNs are expected as output.
146   value_0 /= float(window_size);
147 )"
148                                               :
149                                               R"(
150   for (int a = 0; a < $window_h$; ++a) {
151     for (int b = 0; b < $window_w$; ++b) {
152       ivec2 coord = gid.xy * $stride$ - $offset$ + ivec2(b, a);
153       value_0 += $input_data_0[coord.x, coord.y, gid.z]$;
154     }
155   }
156   // If the denominator is 0, that is a sign of an incorrectly constructed
157   // operation. NaNs are expected as output.
158   value_0 /= float($window_h$ * $window_w$);
159 )";
160 
161   *generated_code = {
162       /*parameters=*/std::move(parameters),
163       /*objects=*/{},
164       /*shared_variables=*/{},
165       /*workload=*/uint3(),
166       /*workgroup=*/uint3(),
167       /*source_code=*/std::move(source),
168       /*input=*/IOStructure::ONLY_DEFINITIONS,
169       /*output=*/IOStructure::AUTO,
170   };
171   return absl::OkStatus();
172 }
173 
174 class Pooling : public NodeShader {
175  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const176   absl::Status GenerateCode(const GenerationContext& ctx,
177                             GeneratedCode* generated_code) const final {
178     const auto& attr = absl::any_cast<const Pooling2DAttributes&>(ctx.op_attr);
179     switch (attr.type) {
180       case PoolingType::AVERAGE:
181         return GenerateAveragePoolingCode(attr, ctx, generated_code);
182       case PoolingType::MAX:
183         return GenerateMaxPoolingCode(attr, ctx, generated_code);
184       default:
185         return absl::InvalidArgumentError("Incorrect attributes' type.");
186     }
187   }
188 };
189 
190 }  // namespace
191 
NewPoolingNodeShader()192 std::unique_ptr<NodeShader> NewPoolingNodeShader() {
193   return absl::make_unique<Pooling>();
194 }
195 
196 }  // namespace gl
197 }  // namespace gpu
198 }  // namespace tflite
199