• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
17 
18 #include <string>
19 
20 #include "absl/memory/memory.h"
21 #include "absl/strings/substitute.h"
22 #include "tensorflow/lite/delegates/gpu/common/status.h"
23 #include "tensorflow/lite/delegates/gpu/common/types.h"
24 
25 namespace tflite {
26 namespace gpu {
27 namespace gl {
28 namespace {
29 
30 class ElementwiseOneArgument : public NodeShader {
31  public:
ElementwiseOneArgument(OperationType operation_type)32   explicit ElementwiseOneArgument(OperationType operation_type)
33       : operation_type_(operation_type) {}
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const34   absl::Status GenerateCode(const GenerationContext& ctx,
35                             GeneratedCode* generated_code) const final {
36     std::string source;
37     switch (operation_type_) {
38       case OperationType::ABS:
39         source = "value_0 = abs(value_0);";
40         break;
41       case OperationType::COS:
42         source = "value_0 = cos(value_0);";
43         break;
44       case OperationType::COPY:
45         source = "value_0 = value_0;";
46         break;
47       case OperationType::ELU:
48         source = R"(
49             value_0.x = value_0.x < 0.0 ? exp(value_0.x) - 1.0 : value_0.x;
50             value_0.y = value_0.y < 0.0 ? exp(value_0.y) - 1.0 : value_0.y;
51             value_0.z = value_0.z < 0.0 ? exp(value_0.z) - 1.0 : value_0.z;
52             value_0.w = value_0.w < 0.0 ? exp(value_0.w) - 1.0 : value_0.w;
53         )";
54         break;
55       case OperationType::EXP:
56         source = "value_0 = exp(value_0);";
57         break;
58       case tflite::gpu::OperationType::FLOOR:
59         source = "value_0 = floor(value_0);";
60         break;
61       case OperationType::HARD_SWISH:
62         source =
63             "value_0 *= clamp(value_0 / 6.0 + vec4(0.5), vec4(0.0), "
64             "vec4(1.0));";
65         break;
66       case OperationType::LOG:
67         source = R"(
68             const float nan = normalize(vec4(0, 0, 0, 0)).x;
69             value_0.x = value_0.x > 0.0 ? log(value_0.x) : nan;
70             value_0.y = value_0.y > 0.0 ? log(value_0.y) : nan;
71             value_0.z = value_0.z > 0.0 ? log(value_0.z) : nan;
72             value_0.w = value_0.w > 0.0 ? log(value_0.w) : nan;
73         )";
74         break;
75       case OperationType::NEG:
76         source = "value_0 = -(value_0);";
77         break;
78       case OperationType::RSQRT:
79         source = R"(
80             const float nan = normalize(vec4(0, 0, 0, 0)).x;
81             value_0.x = value_0.x > 0.0 ? 1.0 / sqrt(value_0.x) : nan;
82             value_0.y = value_0.y > 0.0 ? 1.0 / sqrt(value_0.y) : nan;
83             value_0.z = value_0.z > 0.0 ? 1.0 / sqrt(value_0.z) : nan;
84             value_0.w = value_0.w > 0.0 ? 1.0 / sqrt(value_0.w) : nan;
85         )";
86         break;
87       case OperationType::SIGMOID:
88         source = "value_0 = 1.0 / (1.0 + exp(-1.0 * value_0));";
89         break;
90       case OperationType::SIN:
91         source = "value_0 = sin(value_0);";
92         break;
93       case OperationType::SQRT:
94         source = R"(
95             const float nan = normalize(vec4(0, 0, 0, 0)).x;
96             value_0.x = value_0.x >= 0.0 ? sqrt(value_0.x) : nan;
97             value_0.y = value_0.y >= 0.0 ? sqrt(value_0.y) : nan;
98             value_0.z = value_0.z >= 0.0 ? sqrt(value_0.z) : nan;
99             value_0.w = value_0.w >= 0.0 ? sqrt(value_0.w) : nan;
100         )";
101         break;
102       case OperationType::SQUARE:
103         source = "value_0 = value_0 * value_0;";
104         break;
105       case OperationType::TANH:
106         source = "value_0 = tanh(value_0);";
107         break;
108       default:
109         return absl::InvalidArgumentError(
110             "Incorrect elementwise operation type.");
111     }
112     *generated_code = {
113         /*parameters=*/{},
114         /*objects=*/{},
115         /*shared_variables=*/{},
116         /*workload=*/uint3(),
117         /*workgroup=*/uint3(),
118         source,
119         /*input=*/IOStructure::AUTO,
120         /*output=*/IOStructure::AUTO,
121     };
122     return absl::OkStatus();
123   }
124 
125  private:
126   OperationType operation_type_;
127 };
128 
129 class ElementwiseTwoArguments : public NodeShader {
130  public:
ElementwiseTwoArguments(OperationType operation_type)131   explicit ElementwiseTwoArguments(OperationType operation_type)
132       : operation_type_(operation_type) {}
133 
IsElementwiseSupported(const GenerationContext & ctx) const134   inline bool IsElementwiseSupported(const GenerationContext& ctx) const {
135     return ctx.input_shapes.size() == 2 &&
136            ctx.input_shapes[0] == ctx.input_shapes[1];
137   }
138 
IsBroadcastSupported(const GenerationContext & ctx) const139   inline bool IsBroadcastSupported(const GenerationContext& ctx) const {
140     return ctx.input_shapes.size() == 2 && ctx.input_shapes[1][1] == 1 &&
141            ctx.input_shapes[1][2] == 1 &&
142            ctx.input_shapes[0][3] == ctx.input_shapes[1][3];
143   }
144 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const145   absl::Status GenerateCode(const GenerationContext& ctx,
146                             GeneratedCode* generated_code) const final {
147     std::vector<Variable> parameters;
148     std::vector<std::pair<std::string, Object>> objects;
149     std::string argument0, argument1;
150     if (IsElementwiseSupported(ctx)) {
151       argument0 = "value_0";
152       argument1 = "value_1";
153     } else if (IsBroadcastSupported(ctx)) {
154       argument0 = "$input_data_0[gid.x, gid.y, gid.z]$";
155       argument1 = "$input_data_1[0, 0, gid.z]$";
156     } else {  // Scalar of const vector case
157       const auto& attr =
158           absl::any_cast<const ElementwiseAttributes&>(ctx.op_attr);
159       const auto* tensor =
160           absl::get_if<Tensor<Linear, DataType::FLOAT32>>(&attr.param);
161       const auto* scalar = absl::get_if<float>(&attr.param);
162       if (!tensor && !scalar) {
163         return absl::InvalidArgumentError(
164             "Couldn't read scalar of const vector data from the attributes.");
165       }
166 
167       argument0 = "value_0";
168       if (tensor) {
169         argument1 = "$const_data[gid.z]$";
170         objects.push_back({"const_data", MakeReadonlyObject(tensor->data)});
171       } else {
172         argument1 = "vec4($const_data$)";
173         parameters.push_back({"const_data", *scalar});
174       }
175     }
176 
177     std::string source;
178     switch (operation_type_) {
179       case OperationType::DIV: {
180         source = "value_0 = $0/$1;";
181         break;
182       }
183       case tflite::gpu::OperationType::FLOOR_DIV:
184         source = "value_0 = floor($0 / $1);";
185         break;
186       case tflite::gpu::OperationType::FLOOR_MOD:
187         source = "value_0 = $0 - floor($0 / $1) * $1;";
188         break;
189       case OperationType::MAXIMUM: {
190         source = "value_0 = max($0, $1);";
191         break;
192       }
193       case OperationType::MINIMUM: {
194         source = "value_0 = min($0, $1);";
195         break;
196       }
197       case OperationType::SQUARED_DIFF: {
198         source = "value_0 = ($0 - $1) * ($0 - $1);";
199         break;
200       }
201       case OperationType::SUB: {
202         source = "value_0 = $0 - $1;";
203         break;
204       }
205       case OperationType::POW: {
206         source = "value_0 = pow($0, $1);";
207         break;
208       }
209       default:
210         return absl::InvalidArgumentError(
211             "Incorrect elementwise with scalar operation type.");
212     }
213     source = absl::Substitute(source, argument0, argument1);
214     *generated_code = {
215         /*parameters=*/std::move(parameters),
216         /*objects=*/std::move(objects),
217         /*shared_variables=*/{},
218         /*workload=*/uint3(),
219         /*workgroup=*/uint3(),
220         /*source_code=*/source,
221         /*input=*/IOStructure::AUTO,
222         /*output=*/IOStructure::AUTO,
223     };
224     return absl::OkStatus();
225   }
226 
227  private:
228   OperationType operation_type_;
229 };
230 
231 }  // namespace
232 
NewElementwiseNodeShader(OperationType operation_type)233 std::unique_ptr<NodeShader> NewElementwiseNodeShader(
234     OperationType operation_type) {
235   switch (operation_type) {
236     case OperationType::ABS:
237     case OperationType::COS:
238     case OperationType::COPY:
239     case OperationType::ELU:
240     case OperationType::EXP:
241     case OperationType::FLOOR:
242     case OperationType::HARD_SWISH:
243     case OperationType::LOG:
244     case OperationType::NEG:
245     case OperationType::RSQRT:
246     case OperationType::SIGMOID:
247     case OperationType::SIN:
248     case OperationType::SQRT:
249     case OperationType::SQUARE:
250     case OperationType::TANH:
251       return absl::make_unique<ElementwiseOneArgument>(operation_type);
252     case OperationType::DIV:
253     case OperationType::FLOOR_DIV:
254     case OperationType::FLOOR_MOD:
255     case OperationType::MAXIMUM:
256     case OperationType::MINIMUM:
257     case OperationType::POW:
258     case OperationType::SQUARED_DIFF:
259     case OperationType::SUB:
260       return absl::make_unique<ElementwiseTwoArguments>(operation_type);
261     default:
262       return nullptr;
263   }
264 }
265 
266 }  // namespace gl
267 }  // namespace gpu
268 }  // namespace tflite
269