• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
17 
18 #include <string>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/types.h"
24 
25 namespace tflite {
26 namespace gpu {
27 namespace gl {
28 namespace {
29 
30 class ElementwiseOneArgument : public NodeShader {
31  public:
ElementwiseOneArgument(OperationType operation_type)32   explicit ElementwiseOneArgument(OperationType operation_type)
33       : operation_type_(operation_type) {}
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const34   absl::Status GenerateCode(const GenerationContext& ctx,
35                             GeneratedCode* generated_code) const final {
36     std::string source;
37     switch (operation_type_) {
38       case OperationType::ABS:
39         source = "value_0 = abs(value_0);";
40         break;
41       case OperationType::COS:
42         source = "value_0 = cos(value_0);";
43         break;
44       case OperationType::COPY:
45         source = "value_0 = value_0;";
46         break;
47       case OperationType::ELU:
48         source = R"(
49             value_0.x = value_0.x < 0.0 ? exp(value_0.x) - 1.0 : value_0.x;
50             value_0.y = value_0.y < 0.0 ? exp(value_0.y) - 1.0 : value_0.y;
51             value_0.z = value_0.z < 0.0 ? exp(value_0.z) - 1.0 : value_0.z;
52             value_0.w = value_0.w < 0.0 ? exp(value_0.w) - 1.0 : value_0.w;
53         )";
54         break;
55       case OperationType::EXP:
56         source = "value_0 = exp(value_0);";
57         break;
58       case OperationType::HARD_SWISH:
59         source =
60             "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
61             "vec4(1.0));";
62         break;
63       case OperationType::LOG:
64         source = R"(
65             const float nan = normalize(vec4(0, 0, 0, 0)).x;
66             value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
67             value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan;
68             value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan;
69             value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
70         )";
71         break;
72       case OperationType::NEG:
73         source = "value_0 = -(value_0);";
74         break;
75       case OperationType::RSQRT:
76         source = R"(
77             const float nan = normalize(vec4(0, 0, 0, 0)).x;
78             value_0.x = value_0.x > 0.0 ? 1.0 / sqrt(value_0.x) : nan;
79             value_0.y = value_0.y > 0.0 ? 1.0 / sqrt(value_0.y) : nan;
80             value_0.z = value_0.z > 0.0 ? 1.0 / sqrt(value_0.z) : nan;
81             value_0.w = value_0.w > 0.0 ? 1.0 / sqrt(value_0.w) : nan;
82         )";
83         break;
84       case OperationType::SIGMOID:
85         source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
86         break;
87       case OperationType::SIN:
88         source = "value_0 = sin(value_0);";
89         break;
90       case OperationType::SQRT:
91         source = R"(
92             const float nan = normalize(vec4(0, 0, 0, 0)).x;
93             value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
94             value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
95             value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
96             value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
97         )";
98         break;
99       case OperationType::SQUARE:
100         source = "value_0 = value_0 * value_0;";
101         break;
102       case OperationType::TANH:
103         source = "value_0 = tanh(value_0);";
104         break;
105       default:
106         return absl::InvalidArgumentError(
107             "Incorrect elementwise operation type.");
108     }
109     *generated_code = {
110         /*parameters=*/{},
111         /*objects=*/{},
112         /*shared_variables=*/{},
113         /*workload=*/uint3(),
114         /*workgroup=*/uint3(),
115         source,
116         /*input=*/IOStructure::AUTO,
117         /*output=*/IOStructure::AUTO,
118     };
119     return absl::OkStatus();
120   }
121 
122  private:
123   OperationType operation_type_;
124 };
125 
126 class ElementwiseTwoArguments : public NodeShader {
127  public:
ElementwiseTwoArguments(OperationType operation_type)128   explicit ElementwiseTwoArguments(OperationType operation_type)
129       : operation_type_(operation_type) {}
130 
IsElementwiseSupported(const GenerationContext & ctx) const131   inline bool IsElementwiseSupported(const GenerationContext& ctx) const {
132     return ctx.input_shapes.size() == 2 &&
133            ctx.input_shapes[0] == ctx.input_shapes[1];
134   }
135 
IsBroadcastSupported(const GenerationContext & ctx) const136   inline bool IsBroadcastSupported(const GenerationContext& ctx) const {
137     return ctx.input_shapes.size() == 2 && ctx.input_shapes[1][1] == 1 &&
138            ctx.input_shapes[1][2] == 1 &&
139            ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
140   }
141 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const142   absl::Status GenerateCode(const GenerationContext& ctx,
143                             GeneratedCode* generated_code) const final {
144     std::vector<Variable> parameters;
145     std::vector<std::pair<std::string, Object>> objects;
146     std::string argument0, argument1;
147     if (IsElementwiseSupported(ctx)) {
148       argument0 = "value_0";
149       argument1 = "value_1";
150     } else if (IsBroadcastSupported(ctx)) {
151       argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
152       argument1 = "$input_data_1[0, 0, gid.z]$";
153     } else {  // Scalar of const vector case
154       const auto& attr =
155           absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
156       const auto* tensor =
157           absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
158       const auto* scalar = absl::get_if<float>(&attr.param);
159       if (!tensor && !scalar) {
160         return absl::InvalidArgumentError(
161             "Couldn't read scalar of const vector data from the attributes.");
162       }
163 
164       argument0 = "value_0";
165       if (tensor) {
166         argument1 = "$const_data[gid.z]$";
167         objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
168       } else {
169         argument1 = "vec4($const_data$)";
170         parameters.push_back({"const_data", *scalar});
171       }
172     }
173 
174     std::string source;
175     switch (operation_type_) {
176       case OperationType::DIV: {
177         source = "value_0 = $0/$1;";
178         break;
179       }
180       case OperationType::MAXIMUM: {
181         source = "value_0 = max($0, $1);";
182         break;
183       }
184       case OperationType::MINIMUM: {
185         source = "value_0 = min($0, $1);";
186         break;
187       }
188       case OperationType::SQUARED_DIFF: {
189         source = "value_0 = ($0 - $1) * ($0 - $1);";
190         break;
191       }
192       case OperationType::SUB: {
193         source = "value_0 = $0 - $1;";
194         break;
195       }
196       case OperationType::POW: {
197         source = "value_0 = pow($0, $1);";
198         break;
199       }
200       default:
201         return absl::InvalidArgumentError(
202             "Incorrect elementwise with scalar operation type.");
203     }
204     source = absl::Substitute(source, argument0, argument1);
205     *generated_code = {
206         /*parameters=*/std::move(parameters),
207         /*objects=*/std::move(objects),
208         /*shared_variables=*/{},
209         /*workload=*/uint3(),
210         /*workgroup=*/uint3(),
211         /*source_code=*/source,
212         /*input=*/IOStructure::AUTO,
213         /*output=*/IOStructure::AUTO,
214     };
215     return absl::OkStatus();
216   }
217 
218  private:
219   OperationType operation_type_;
220 };
221 
222 }  // namespace
223 
NewElementwiseNodeShader(OperationType operation_type)224 std::unique_ptr<NodeShader> NewElementwiseNodeShader(
225     OperationType operation_type) {
226   switch (operation_type) {
227     case OperationType::ABS:
228     case OperationType::COS:
229     case OperationType::COPY:
230     case OperationType::ELU:
231     case OperationType::EXP:
232     case OperationType::HARD_SWISH:
233     case OperationType::LOG:
234     case OperationType::NEG:
235     case OperationType::RSQRT:
236     case OperationType::SIGMOID:
237     case OperationType::SIN:
238     case OperationType::SQRT:
239     case OperationType::SQUARE:
240     case OperationType::TANH:
241       return absl::make_unique<ElementwiseOneArgument>(operation_type);
242     case OperationType::DIV:
243     case OperationType::MAXIMUM:
244     case OperationType::MINIMUM:
245     case OperationType::POW:
246     case OperationType::SQUARED_DIFF:
247     case OperationType::SUB:
248       return absl::make_unique<ElementwiseTwoArguments>(operation_type);
249     default:
250       return nullptr;
251   }
252 }
253 
254 }  // namespace gl
255 }  // namespace gpu
256 }  // namespace tflite
257