1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/memory/memory.h"
24 #include "tensorflow/lite/delegates/gpu/common/shape.h"
25 #include "tensorflow/lite/delegates/gpu/common/status.h"
26 #include "tensorflow/lite/delegates/gpu/common/types.h"
27 #include "tensorflow/lite/delegates/gpu/common/util.h"
28 #include "tensorflow/lite/delegates/gpu/gl/variable.h"
29
30 namespace tflite {
31 namespace gpu {
32 namespace gl {
33 namespace {
34
GetMask(int num_channels)35 float4 GetMask(int num_channels) {
36 float4 mask(0.0f);
37 const int remainder = num_channels % 4 == 0 ? 4 : num_channels % 4;
38 for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
39 return mask;
40 }
41
42 class Softmax : public NodeShader {
43 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const44 absl::Status GenerateCode(const GenerationContext& ctx,
45 GeneratedCode* generated_code) const final {
46 const auto& attr = absl::any_cast<const SoftmaxAttributes&>(ctx.op_attr);
47 if (ctx.input_shapes[0] != ctx.output_shapes[0]) {
48 return absl::InvalidArgumentError(
49 "Input and output shapes do not match.");
50 }
51 if (attr.axis != Axis::CHANNELS) {
52 return absl::UnimplementedError(
53 "Softmax is only supported for channels axis.");
54 }
55 return ctx.input_shapes[0][1] == 1 && ctx.input_shapes[0][2] == 1
56 ? GenerateCodeFor1x1(ctx, generated_code)
57 : GenerateCodeGeneral(ctx, generated_code);
58 }
59
60 private:
GenerateCodeFor1x1(const GenerationContext & ctx,GeneratedCode * generated_code) const61 absl::Status GenerateCodeFor1x1(const GenerationContext& ctx,
62 GeneratedCode* generated_code) const {
63 const int depth = DivideRoundUp(ctx.output_shapes[0][3], 4);
64 std::vector<Variable> shared_variables = {
65 {"partial_sum", std::vector<float4>(8)},
66 };
67 std::vector<Variable> uniform_parameters = {
68 {"depth", depth},
69 {"mask", GetMask(ctx.output_shapes[0][3])},
70 };
71 std::string source_code = R"(
72 highp vec4 kOnes = vec4(1.0);
73 int tid = int(gl_LocalInvocationID.x);
74 highp vec4 maxx4 = $input_data_0[0, 0, 0]$;
75 maxx4.y = maxx4.x;
76 maxx4.z = maxx4.x;
77 maxx4.w = maxx4.x;
78 for (int s = tid; s < $depth$; s += 32) {
79 highp vec4 mask_a = s == $depth$ - 1 ? $mask$ : kOnes;
80 highp vec4 mask_b = kOnes - mask_a;
81 highp vec4 src = $input_data_0[0, 0, s]$;
82 src = src * mask_a + mask_b * src.x;
83 maxx4 = max(maxx4, src);
84 }
85 highp float maximum = max(maxx4.x, maxx4.y);
86 maximum = max(maximum, maxx4.z);
87 maximum = max(maximum, maxx4.w);
88 partial_sum[tid / 4][tid % 4] = maximum;
89
90 memoryBarrierShared();
91 barrier();
92
93 if (tid == 0) {
94 maxx4 = max(partial_sum[0], partial_sum[1]);
95 maxx4 = max(maxx4, partial_sum[2]);
96 maxx4 = max(maxx4, partial_sum[3]);
97 maxx4 = max(maxx4, partial_sum[4]);
98 maxx4 = max(maxx4, partial_sum[5]);
99 maxx4 = max(maxx4, partial_sum[6]);
100 maxx4 = max(maxx4, partial_sum[7]);
101 maximum = max(maxx4.x, maxx4.y);
102 maximum = max(maximum, maxx4.z);
103 maximum = max(maximum, maxx4.w);
104 partial_sum[0][0] = maximum;
105 }
106
107 memoryBarrierShared();
108 barrier();
109
110 maximum = partial_sum[0][0];
111
112 highp float sum = 0.0;
113 for (int s = tid; s < $depth$; s += 32) {
114 highp vec4 mask_temp = s == $depth$ - 1 ? $mask$ : kOnes;
115 highp vec4 src = $input_data_0[0, 0, s]$ - vec4(maximum);
116 sum += dot(mask_temp, exp(src));
117 }
118
119 memoryBarrierShared();
120 barrier();
121
122 partial_sum[tid / 4][tid % 4] = sum;
123
124 memoryBarrierShared();
125 barrier();
126
127 if (tid == 0) {
128 sum = dot(kOnes, partial_sum[0]);
129 sum += dot(kOnes, partial_sum[1]);
130 sum += dot(kOnes, partial_sum[2]);
131 sum += dot(kOnes, partial_sum[3]);
132 sum += dot(kOnes, partial_sum[4]);
133 sum += dot(kOnes, partial_sum[5]);
134 sum += dot(kOnes, partial_sum[6]);
135 sum += dot(kOnes, partial_sum[7]);
136 partial_sum[0][0] = 1.0 / sum;
137 }
138
139 memoryBarrierShared();
140 barrier();
141
142 sum = partial_sum[0][0];
143
144 int dst_s = int(gl_GlobalInvocationID.x);
145 if (dst_s < $depth$) {
146 highp vec4 src = $input_data_0[0, 0, dst_s]$ - vec4(maximum);
147 highp vec4 temp = exp(src) * sum;
148 $output_data_0[0, 0, dst_s] = temp$;
149 }
150 )";
151
152 *generated_code = {
153 /*parameters=*/std::move(uniform_parameters),
154 /*objects=*/{},
155 /*shared_variables=*/std::move(shared_variables),
156 /*workload=*/uint3(depth, 1, 1),
157 /*workgroup=*/uint3(32, 1, 1),
158 /*source_code=*/std::move(source_code),
159 /*input=*/IOStructure::ONLY_DEFINITIONS,
160 /*output=*/IOStructure::ONLY_DEFINITIONS,
161 };
162 return absl::OkStatus();
163 }
164
GenerateCodeGeneral(const GenerationContext & ctx,GeneratedCode * generated_code) const165 absl::Status GenerateCodeGeneral(const GenerationContext& ctx,
166 GeneratedCode* generated_code) const {
167 std::vector<Variable> parameters = {
168 {"src_depth",
169 DivideRoundUp(static_cast<int>(ctx.output_shapes[0][3]), 4)},
170 {"mask", GetMask(ctx.output_shapes[0][3])},
171 };
172
173 std::string source_code = R"(
174 highp vec4 kOnes = vec4(1.0);
175 highp float sum = 0.0;
176 highp float maximum = $input_data_0[gid.x, gid.y, 0]$.x;
177 for (int d = 0; d < $src_depth$; ++d) {
178 highp vec4 mask_a = d == $src_depth$ - 1 ? $mask$ : kOnes;
179 highp vec4 mask_b = kOnes - mask_a;
180 highp vec4 src = $input_data_0[gid.x, gid.y, d]$;
181 src = src * mask_a + mask_b * src.x;
182 maximum = max(maximum, src.x);
183 maximum = max(maximum, src.y);
184 maximum = max(maximum, src.z);
185 maximum = max(maximum, src.w);
186 }
187 for (int d = 0; d < $src_depth$; ++d) {
188 highp vec4 mask_temp = d == $src_depth$ - 1 ? $mask$ : kOnes;
189 highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
190 sum += dot(mask_temp, exp(src));
191 }
192 for (int d = 0; d < $src_depth$; ++d) {
193 highp vec4 src = $input_data_0[gid.x, gid.y, d]$ - vec4(maximum);
194 highp vec4 temp_sum = exp(src) / sum;
195 $output_data_0[gid.x, gid.y, d] = temp_sum$;
196 }
197 )";
198 *generated_code = {
199 /*parameters=*/std::move(parameters),
200 /*objects=*/{},
201 /*shared_variables=*/{},
202 /*workload=*/
203 uint3(static_cast<int>(ctx.output_shapes[0][2]),
204 static_cast<int>(ctx.output_shapes[0][1]), 1),
205 /*workgroup=*/uint3(),
206 /*source_code=*/std::move(source_code),
207 /*input=*/IOStructure::ONLY_DEFINITIONS,
208 /*output=*/IOStructure::ONLY_DEFINITIONS,
209 };
210 return absl::OkStatus();
211 }
212 };
213
214 } // namespace
215
NewSoftmaxNodeShader()216 std::unique_ptr<NodeShader> NewSoftmaxNodeShader() {
217 return absl::make_unique<Softmax>();
218 }
219
220 } // namespace gl
221 } // namespace gpu
222 } // namespace tflite
223