1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
17
18 #include <string>
19
20 #include "absl/memory/memory.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/types.h"
24
25 namespace tflite {
26 namespace gpu {
27 namespace gl {
28 namespace {
29
30 class ElementwiseOneArgument : public NodeShader {
31 public:
ElementwiseOneArgument(OperationType operation_type)32 explicit ElementwiseOneArgument(OperationType operation_type)
33 : operation_type_(operation_type) {}
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const34 absl::Status GenerateCode(const GenerationContext& ctx,
35 GeneratedCode* generated_code) const final {
36 std::string source;
37 switch (operation_type_) {
38 case OperationType::ABS:
39 source = "value_0 = abs(value_0);";
40 break;
41 case OperationType::COS:
42 source = "value_0 = cos(value_0);";
43 break;
44 case OperationType::COPY:
45 source = "value_0 = value_0;";
46 break;
47 case OperationType::ELU:
48 source = R"(
49 value_0.x = value_0.x < 0.0 ? exp(value_0.x) - 1.0 : value_0.x;
50 value_0.y = value_0.y < 0.0 ? exp(value_0.y) - 1.0 : value_0.y;
51 value_0.z = value_0.z < 0.0 ? exp(value_0.z) - 1.0 : value_0.z;
52 value_0.w = value_0.w < 0.0 ? exp(value_0.w) - 1.0 : value_0.w;
53 )";
54 break;
55 case OperationType::EXP:
56 source = "value_0 = exp(value_0);";
57 break;
58 case tflite::gpu::OperationType::FLOOR:
59 source = "value_0 = floor(value_0);";
60 break;
61 case OperationType::HARD_SWISH:
62 source =
63 "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
64 "vec4(1.0));";
65 break;
66 case OperationType::LOG:
67 source = R"(
68 const float nan = normalize(vec4(0, 0, 0, 0)).x;
69 value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
70 value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan;
71 value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan;
72 value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
73 )";
74 break;
75 case OperationType::NEG:
76 source = "value_0 = -(value_0);";
77 break;
78 case OperationType::RSQRT:
79 source = R"(
80 const float nan = normalize(vec4(0, 0, 0, 0)).x;
81 value_0.x = value_0.x > 0.0 ? 1.0 / sqrt(value_0.x) : nan;
82 value_0.y = value_0.y > 0.0 ? 1.0 / sqrt(value_0.y) : nan;
83 value_0.z = value_0.z > 0.0 ? 1.0 / sqrt(value_0.z) : nan;
84 value_0.w = value_0.w > 0.0 ? 1.0 / sqrt(value_0.w) : nan;
85 )";
86 break;
87 case OperationType::SIGMOID:
88 source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
89 break;
90 case OperationType::SIN:
91 source = "value_0 = sin(value_0);";
92 break;
93 case OperationType::SQRT:
94 source = R"(
95 const float nan = normalize(vec4(0, 0, 0, 0)).x;
96 value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
97 value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
98 value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
99 value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
100 )";
101 break;
102 case OperationType::SQUARE:
103 source = "value_0 = value_0 * value_0;";
104 break;
105 case OperationType::TANH:
106 source = "value_0 = tanh(value_0);";
107 break;
108 default:
109 return absl::InvalidArgumentError(
110 "Incorrect elementwise operation type.");
111 }
112 *generated_code = {
113 /*parameters=*/{},
114 /*objects=*/{},
115 /*shared_variables=*/{},
116 /*workload=*/uint3(),
117 /*workgroup=*/uint3(),
118 source,
119 /*input=*/IOStructure::AUTO,
120 /*output=*/IOStructure::AUTO,
121 };
122 return absl::OkStatus();
123 }
124
125 private:
126 OperationType operation_type_;
127 };
128
129 class ElementwiseTwoArguments : public NodeShader {
130 public:
ElementwiseTwoArguments(OperationType operation_type)131 explicit ElementwiseTwoArguments(OperationType operation_type)
132 : operation_type_(operation_type) {}
133
IsElementwiseSupported(const GenerationContext & ctx) const134 inline bool IsElementwiseSupported(const GenerationContext& ctx) const {
135 return ctx.input_shapes.size() == 2 &&
136 ctx.input_shapes[0] == ctx.input_shapes[1];
137 }
138
IsBroadcastSupported(const GenerationContext & ctx) const139 inline bool IsBroadcastSupported(const GenerationContext& ctx) const {
140 return ctx.input_shapes.size() == 2 && ctx.input_shapes[1][1] == 1 &&
141 ctx.input_shapes[1][2] == 1 &&
142 ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
143 }
144
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const145 absl::Status GenerateCode(const GenerationContext& ctx,
146 GeneratedCode* generated_code) const final {
147 std::vector<Variable> parameters;
148 std::vector<std::pair<std::string, Object>> objects;
149 std::string argument0, argument1;
150 if (IsElementwiseSupported(ctx)) {
151 argument0 = "value_0";
152 argument1 = "value_1";
153 } else if (IsBroadcastSupported(ctx)) {
154 argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
155 argument1 = "$input_data_1[0, 0, gid.z]$";
156 } else { // Scalar of const vector case
157 const auto& attr =
158 absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
159 const auto* tensor =
160 absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
161 const auto* scalar = absl::get_if<float>(&attr.param);
162 if (!tensor && !scalar) {
163 return absl::InvalidArgumentError(
164 "Couldn't read scalar of const vector data from the attributes.");
165 }
166
167 argument0 = "value_0";
168 if (tensor) {
169 argument1 = "$const_data[gid.z]$";
170 objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
171 } else {
172 argument1 = "vec4($const_data$)";
173 parameters.push_back({"const_data", *scalar});
174 }
175 }
176
177 std::string source;
178 switch (operation_type_) {
179 case OperationType::DIV: {
180 source = "value_0 = $0/$1;";
181 break;
182 }
183 case tflite::gpu::OperationType::FLOOR_DIV:
184 source = "value_0 = floor($0 / $1);";
185 break;
186 case tflite::gpu::OperationType::FLOOR_MOD:
187 source = "value_0 = $0 - floor($0 / $1) * $1;";
188 break;
189 case OperationType::MAXIMUM: {
190 source = "value_0 = max($0, $1);";
191 break;
192 }
193 case OperationType::MINIMUM: {
194 source = "value_0 = min($0, $1);";
195 break;
196 }
197 case OperationType::SQUARED_DIFF: {
198 source = "value_0 = ($0 - $1) * ($0 - $1);";
199 break;
200 }
201 case OperationType::SUB: {
202 source = "value_0 = $0 - $1;";
203 break;
204 }
205 case OperationType::POW: {
206 source = "value_0 = pow($0, $1);";
207 break;
208 }
209 default:
210 return absl::InvalidArgumentError(
211 "Incorrect elementwise with scalar operation type.");
212 }
213 source = absl::Substitute(source, argument0, argument1);
214 *generated_code = {
215 /*parameters=*/std::move(parameters),
216 /*objects=*/std::move(objects),
217 /*shared_variables=*/{},
218 /*workload=*/uint3(),
219 /*workgroup=*/uint3(),
220 /*source_code=*/source,
221 /*input=*/IOStructure::AUTO,
222 /*output=*/IOStructure::AUTO,
223 };
224 return absl::OkStatus();
225 }
226
227 private:
228 OperationType operation_type_;
229 };
230
231 } // namespace
232
NewElementwiseNodeShader(OperationType operation_type)233 std::unique_ptr<NodeShader> NewElementwiseNodeShader(
234 OperationType operation_type) {
235 switch (operation_type) {
236 case OperationType::ABS:
237 case OperationType::COS:
238 case OperationType::COPY:
239 case OperationType::ELU:
240 case OperationType::EXP:
241 case OperationType::FLOOR:
242 case OperationType::HARD_SWISH:
243 case OperationType::LOG:
244 case OperationType::NEG:
245 case OperationType::RSQRT:
246 case OperationType::SIGMOID:
247 case OperationType::SIN:
248 case OperationType::SQRT:
249 case OperationType::SQUARE:
250 case OperationType::TANH:
251 return absl::make_unique<ElementwiseOneArgument>(operation_type);
252 case OperationType::DIV:
253 case OperationType::FLOOR_DIV:
254 case OperationType::FLOOR_MOD:
255 case OperationType::MAXIMUM:
256 case OperationType::MINIMUM:
257 case OperationType::POW:
258 case OperationType::SQUARED_DIFF:
259 case OperationType::SUB:
260 return absl::make_unique<ElementwiseTwoArguments>(operation_type);
261 default:
262 return nullptr;
263 }
264 }
265
266 } // namespace gl
267 } // namespace gpu
268 } // namespace tflite
269