• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "tensorflow/lite/delegates/gpu/common/shape.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/common/util.h"
28 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
29 
30 namespace tflite {
31 namespace gpu {
32 namespace gl {
33 namespace {
34 
GetMask(int num_channels)35 float4 GetMask(int num_channels) {
36   float4 mask(0.0f);
37   const int remainder = num_channels % 4 == 0 ? 4 : num_channels % 4;
38   for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
39   return mask;
40 }
41 
42 class Softmax : public NodeShader {
43  public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const44   absl::Status GenerateCode(const GenerationContext& ctx,
45                             GeneratedCode* generated_code) const final {
46     const auto& attr = absl::any_cast<const SoftmaxAttributes&>(ctx.op_attr);
47     if (ctx.input_shapes[0] != ctx.output_shapes[0]) {
48       return absl::InvalidArgumentError(
49           "Input and output shapes do not match.");
50     }
51     if (attr.axis != Axis::CHANNELS) {
52       return absl::UnimplementedError(
53           "Softmax is only supported for channels axis.");
54     }
55     return ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1
56                ? GenerateCodeFor1x1(ctx, generated_code)
57                : GenerateCodeGeneral(ctx, generated_code);
58   }
59 
60  private:
GenerateCodeFor1x1(const GenerationContext & ctx,GeneratedCode * generated_code) const61   absl::Status GenerateCodeFor1x1(const GenerationContext& ctx,
62                                   GeneratedCode* generated_code) const {
63     const int depth = DivideRoundUp(ctx.output_shapes[0][3], 4);
64     std::vector<Variable> shared_variables = {
65         {"partial_sum", std::vector<float4>(8)},
66     };
67     std::vector<Variable> uniform_parameters = {
68         {"depth", depth},
69         {"mask", GetMask(ctx.output_shapes[0][3])},
70     };
71     std::string source_code = R"(
72   highp vec4 kOnes = vec4(1.0);
73   int tid = int(gl_LocalInvocationID.x);
74   highp vec4 maxx4 = $input_data_0[0, 0, 0]$;
75   maxx4.y = maxx4.x;
76   maxx4.z = maxx4.x;
77   maxx4.w = maxx4.x;
78   for (int s = tid; s < $depth$; s += 32) {
79     highp vec4 mask_a = s == $depth$ - 1 ? $mask$ : kOnes;
80     highp vec4 mask_b = kOnes - mask_a;
81     highp vec4 src = $input_data_0[0, 0, s]$;
82     src = src * mask_a + mask_b * src.x;
83     maxx4 = max(maxx4, src);
84   }
85   highp float maximum = max(maxx4.x, maxx4.y);
86   maximum = max(maximum, maxx4.z);
87   maximum = max(maximum, maxx4.w);
88   partial_sum[tid / 4][tid % 4] = maximum;
89 
90   memoryBarrierShared();
91   barrier();
92 
93   if (tid == 0) {
94     maxx4 = max(partial_sum[0], partial_sum[1]);
95     maxx4 = max(maxx4, partial_sum[2]);
96     maxx4 = max(maxx4, partial_sum[3]);
97     maxx4 = max(maxx4, partial_sum[4]);
98     maxx4 = max(maxx4, partial_sum[5]);
99     maxx4 = max(maxx4, partial_sum[6]);
100     maxx4 = max(maxx4, partial_sum[7]);
101     maximum = max(maxx4.x, maxx4.y);
102     maximum = max(maximum, maxx4.z);
103     maximum = max(maximum, maxx4.w);
104     partial_sum[0][0] = maximum;
105   }
106 
107   memoryBarrierShared();
108   barrier();
109 
110   maximum = partial_sum[0][0];
111 
112   highp float sum = 0.0;
113   for (int s = tid; s < $depth$; s += 32) {
114     highp vec4 mask_temp = s == $depth$ - 1 ? $mask$ : kOnes;
115     highp vec4 src = $input_data_0[0, 0, s]$ - vec4(maximum);
116     sum += dot(mask_temp, exp(src));
117   }
118 
119   memoryBarrierShared();
120   barrier();
121 
122   partial_sum[tid / 4][tid % 4] = sum;
123 
124   memoryBarrierShared();
125   barrier();
126 
127   if (tid == 0) {
128     sum = dot(kOnes, partial_sum[0]);
129     sum += dot(kOnes, partial_sum[1]);
130     sum += dot(kOnes, partial_sum[2]);
131     sum += dot(kOnes, partial_sum[3]);
132     sum += dot(kOnes, partial_sum[4]);
133     sum += dot(kOnes, partial_sum[5]);
134     sum += dot(kOnes, partial_sum[6]);
135     sum += dot(kOnes, partial_sum[7]);
136     partial_sum[0][0] = 1.0 / sum;
137   }
138 
139   memoryBarrierShared();
140   barrier();
141 
142   sum = partial_sum[0][0];
143 
144   int dst_s = int(gl_GlobalInvocationID.x);
145   if (dst_s < $depth$) {
146     highp vec4 src = $input_data_0[0, 0, dst_s]$ - vec4(maximum);
147     highp vec4 temp = exp(src) * sum;
148     $output_data_0[0, 0, dst_s] = temp$;
149   }
150 )";
151 
152     *generated_code = {
153         /*parameters=*/std::move(uniform_parameters),
154         /*objects=*/{},
155         /*shared_variables=*/std::move(shared_variables),
156         /*workload=*/uint3(depth, 1, 1),
157         /*workgroup=*/uint3(32, 1, 1),
158         /*source_code=*/std::move(source_code),
159         /*input=*/IOStructure::ONLY_DEFINITIONS,
160         /*output=*/IOStructure::ONLY_DEFINITIONS,
161     };
162     return absl::OkStatus();
163   }
164 
GenerateCodeGeneral(const GenerationContext & ctx,GeneratedCode * generated_code) const165   absl::Status GenerateCodeGeneral(const GenerationContext& ctx,
166                                    GeneratedCode* generated_code) const {
167     std::vector<Variable> parameters = {
168         {"src_depth",
169          DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4)},
170         {"mask", GetMask(ctx.output_shapes[0][3])},
171     };
172 
173     std::string source_code = R"(
174   highp vec4 kOnes = vec4(1.0);
175   highp float sum = 0.0;
176   highp float maximum = $input_data_0[gid.x, gid.y, 0]$.x;
177   for (int d = 0; d < $src_depth$; ++d) {
178     highp vec4 mask_a = d == $src_depth$ - 1 ? $mask$ : kOnes;
179     highp vec4 mask_b = kOnes - mask_a;
180     highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
181     src = src * mask_a + mask_b * src.x;
182     maximum = max(maximum, src.x);
183     maximum = max(maximum, src.y);
184     maximum = max(maximum, src.z);
185     maximum = max(maximum, src.w);
186   }
187   for (int d = 0; d < $src_depth$; ++d) {
188     highp vec4 mask_temp = d == $src_depth$ - 1 ? $mask$ : kOnes;
189     highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
190     sum += dot(mask_temp, exp(src));
191   }
192   for (int d = 0; d < $src_depth$; ++d) {
193     highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
194     highp vec4 temp_sum = exp(src) / sum;
195     $output_data_0[gid.x, gid.y, d] = temp_sum$;
196   }
197 )";
198     *generated_code = {
199         /*parameters=*/std::move(parameters),
200         /*objects=*/{},
201         /*shared_variables=*/{},
202         /*workload=*/
203         uint3(static_cast<int>(ctx.output_shapes[0][2]),
204               static_cast<int>(ctx.output_shapes[0][1]), 1),
205         /*workgroup=*/uint3(),
206         /*source_code=*/std::move(source_code),
207         /*input=*/IOStructure::ONLY_DEFINITIONS,
208         /*output=*/IOStructure::ONLY_DEFINITIONS,
209     };
210     return absl::OkStatus();
211   }
212 };
213 
214 }  // namespace
215 
NewSoftmaxNodeShader()216 std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
217   return absl::make_unique<Softmax>();
218 }
219 
220 }  // namespace gl
221 }  // namespace gpu
222 }  // namespace tflite
223