1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
17
18 #include <string>
19
20 #include "absl/memory/memory.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/types.h"
24
25 namespace tflite {
26 namespace gpu {
27 namespace gl {
28 namespace {
29
30 class ElementwiseOneArgument : public NodeShader {
31 public:
ElementwiseOneArgument(OperationType operation_type)32 explicit ElementwiseOneArgument(OperationType operation_type)
33 : operation_type_(operation_type) {}
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const34 absl::Status GenerateCode(const GenerationContext& ctx,
35 GeneratedCode* generated_code) const final {
36 std::string source;
37 switch (operation_type_) {
38 case OperationType::ABS:
39 source = "value_0 = abs(value_0);";
40 break;
41 case OperationType::COS:
42 source = "value_0 = cos(value_0);";
43 break;
44 case OperationType::COPY:
45 source = "value_0 = value_0;";
46 break;
47 case OperationType::ELU:
48 source = R"(
49 value_0.x = value_0.x < 0.0 ? exp(value_0.x) - 1.0 : value_0.x;
50 value_0.y = value_0.y < 0.0 ? exp(value_0.y) - 1.0 : value_0.y;
51 value_0.z = value_0.z < 0.0 ? exp(value_0.z) - 1.0 : value_0.z;
52 value_0.w = value_0.w < 0.0 ? exp(value_0.w) - 1.0 : value_0.w;
53 )";
54 break;
55 case OperationType::EXP:
56 source = "value_0 = exp(value_0);";
57 break;
58 case OperationType::HARD_SWISH:
59 source =
60 "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
61 "vec4(1.0));";
62 break;
63 case OperationType::LOG:
64 source = R"(
65 const float nan = normalize(vec4(0, 0, 0, 0)).x;
66 value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
67 value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan;
68 value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan;
69 value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
70 )";
71 break;
72 case OperationType::NEG:
73 source = "value_0 = -(value_0);";
74 break;
75 case OperationType::RSQRT:
76 source = R"(
77 const float nan = normalize(vec4(0, 0, 0, 0)).x;
78 value_0.x = value_0.x > 0.0 ? 1.0 / sqrt(value_0.x) : nan;
79 value_0.y = value_0.y > 0.0 ? 1.0 / sqrt(value_0.y) : nan;
80 value_0.z = value_0.z > 0.0 ? 1.0 / sqrt(value_0.z) : nan;
81 value_0.w = value_0.w > 0.0 ? 1.0 / sqrt(value_0.w) : nan;
82 )";
83 break;
84 case OperationType::SIGMOID:
85 source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
86 break;
87 case OperationType::SIN:
88 source = "value_0 = sin(value_0);";
89 break;
90 case OperationType::SQRT:
91 source = R"(
92 const float nan = normalize(vec4(0, 0, 0, 0)).x;
93 value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
94 value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
95 value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
96 value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
97 )";
98 break;
99 case OperationType::SQUARE:
100 source = "value_0 = value_0 * value_0;";
101 break;
102 case OperationType::TANH:
103 source = "value_0 = tanh(value_0);";
104 break;
105 default:
106 return absl::InvalidArgumentError(
107 "Incorrect elementwise operation type.");
108 }
109 *generated_code = {
110 /*parameters=*/{},
111 /*objects=*/{},
112 /*shared_variables=*/{},
113 /*workload=*/uint3(),
114 /*workgroup=*/uint3(),
115 source,
116 /*input=*/IOStructure::AUTO,
117 /*output=*/IOStructure::AUTO,
118 };
119 return absl::OkStatus();
120 }
121
122 private:
123 OperationType operation_type_;
124 };
125
126 class ElementwiseTwoArguments : public NodeShader {
127 public:
ElementwiseTwoArguments(OperationType operation_type)128 explicit ElementwiseTwoArguments(OperationType operation_type)
129 : operation_type_(operation_type) {}
130
IsElementwiseSupported(const GenerationContext & ctx) const131 inline bool IsElementwiseSupported(const GenerationContext& ctx) const {
132 return ctx.input_shapes.size() == 2 &&
133 ctx.input_shapes[0] == ctx.input_shapes[1];
134 }
135
IsBroadcastSupported(const GenerationContext & ctx) const136 inline bool IsBroadcastSupported(const GenerationContext& ctx) const {
137 return ctx.input_shapes.size() == 2 && ctx.input_shapes[1][1] == 1 &&
138 ctx.input_shapes[1][2] == 1 &&
139 ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
140 }
141
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const142 absl::Status GenerateCode(const GenerationContext& ctx,
143 GeneratedCode* generated_code) const final {
144 std::vector<Variable> parameters;
145 std::vector<std::pair<std::string, Object>> objects;
146 std::string argument0, argument1;
147 if (IsElementwiseSupported(ctx)) {
148 argument0 = "value_0";
149 argument1 = "value_1";
150 } else if (IsBroadcastSupported(ctx)) {
151 argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
152 argument1 = "$input_data_1[0, 0, gid.z]$";
153 } else { // Scalar of const vector case
154 const auto& attr =
155 absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
156 const auto* tensor =
157 absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
158 const auto* scalar = absl::get_if<float>(&attr.param);
159 if (!tensor && !scalar) {
160 return absl::InvalidArgumentError(
161 "Couldn't read scalar of const vector data from the attributes.");
162 }
163
164 argument0 = "value_0";
165 if (tensor) {
166 argument1 = "$const_data[gid.z]$";
167 objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
168 } else {
169 argument1 = "vec4($const_data$)";
170 parameters.push_back({"const_data", *scalar});
171 }
172 }
173
174 std::string source;
175 switch (operation_type_) {
176 case OperationType::DIV: {
177 source = "value_0 = $0/$1;";
178 break;
179 }
180 case OperationType::MAXIMUM: {
181 source = "value_0 = max($0, $1);";
182 break;
183 }
184 case OperationType::MINIMUM: {
185 source = "value_0 = min($0, $1);";
186 break;
187 }
188 case OperationType::SQUARED_DIFF: {
189 source = "value_0 = ($0 - $1) * ($0 - $1);";
190 break;
191 }
192 case OperationType::SUB: {
193 source = "value_0 = $0 - $1;";
194 break;
195 }
196 case OperationType::POW: {
197 source = "value_0 = pow($0, $1);";
198 break;
199 }
200 default:
201 return absl::InvalidArgumentError(
202 "Incorrect elementwise with scalar operation type.");
203 }
204 source = absl::Substitute(source, argument0, argument1);
205 *generated_code = {
206 /*parameters=*/std::move(parameters),
207 /*objects=*/std::move(objects),
208 /*shared_variables=*/{},
209 /*workload=*/uint3(),
210 /*workgroup=*/uint3(),
211 /*source_code=*/source,
212 /*input=*/IOStructure::AUTO,
213 /*output=*/IOStructure::AUTO,
214 };
215 return absl::OkStatus();
216 }
217
218 private:
219 OperationType operation_type_;
220 };
221
222 } // namespace
223
NewElementwiseNodeShader(OperationType operation_type)224 std::unique_ptr<NodeShader> NewElementwiseNodeShader(
225 OperationType operation_type) {
226 switch (operation_type) {
227 case OperationType::ABS:
228 case OperationType::COS:
229 case OperationType::COPY:
230 case OperationType::ELU:
231 case OperationType::EXP:
232 case OperationType::HARD_SWISH:
233 case OperationType::LOG:
234 case OperationType::NEG:
235 case OperationType::RSQRT:
236 case OperationType::SIGMOID:
237 case OperationType::SIN:
238 case OperationType::SQRT:
239 case OperationType::SQUARE:
240 case OperationType::TANH:
241 return absl::make_unique<ElementwiseOneArgument>(operation_type);
242 case OperationType::DIV:
243 case OperationType::MAXIMUM:
244 case OperationType::MINIMUM:
245 case OperationType::POW:
246 case OperationType::SQUARED_DIFF:
247 case OperationType::SUB:
248 return absl::make_unique<ElementwiseTwoArguments>(operation_type);
249 default:
250 return nullptr;
251 }
252 }
253
254 } // namespace gl
255 } // namespace gpu
256 } // namespace tflite
257