1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/gl/kernels/add.h"
31 #include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"
32 #include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h"
33 #include "tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h"
34 #include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h"
35 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
36 #include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"
37 #include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h"
38 #include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h"
39 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
40 #include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h"
41 #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
42 #include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
43 #include "tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.h"
44 #include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h"
45 #include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h"
46 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
47 #include "tensorflow/lite/delegates/gpu/gl/kernels/slice.h"
48 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
49 #include "tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h"
50
51 #ifndef TFLITE_GPU_BINARY_RELEASE
52 #include "tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h"
53 #endif // TFLITE_GPU_BINARY_RELEASE
54
55 namespace tflite {
56 namespace gpu {
57 namespace gl {
58 namespace {
59
60 class Registry : public NodeShader {
61 public:
Registry()62 Registry() {
63 using Type = OperationType;
64 using NewShaderFunc = std::function<std::unique_ptr<NodeShader>()>;
65
66 const auto insert_op = [&](Type type, NewShaderFunc func) {
67 shaders_[ToString(type)].push_back(func());
68 };
69 const auto insert_elementwise_op = [&](Type operation_type) {
70 shaders_[ToString(operation_type)].push_back(
71 NewElementwiseNodeShader(operation_type));
72 };
73
74 insert_op(Type::ADD, NewAddNodeShader);
75 insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
76 insert_op(Type::CONCAT, NewFlatConcatNodeShader);
77 insert_op(Type::CONCAT, NewConcatNodeShader);
78 insert_op(Type::CONVOLUTION_2D, NewConvolution1x1NodeShader);
79 insert_op(Type::CONVOLUTION_2D, NewConvolutionNodeShader);
80 insert_op(Type::CONVOLUTION_TRANSPOSED, NewConvolutionTransposedNodeShader);
81 insert_op(Type::DEPTHWISE_CONVOLUTION, NewDepthwiseConvolutionNodeShader);
82 insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
83 insert_op(Type::LSTM, NewLstmNodeShader);
84 insert_op(Type::MEAN, NewMeanNodeShader);
85 // TODO(b/162763635): implement MeanStddevNormalization for OpenGL.
86 insert_op(Type::MUL, NewMultiplyNodeShader);
87 insert_op(Type::PAD, NewPadNodeShader);
88 insert_op(Type::POOLING_2D, NewPoolingNodeShader);
89 insert_op(Type::PRELU, NewPReLUNodeShader);
90 insert_op(Type::QUANTIZE_AND_DEQUANTIZE,
91 NewQuantizeAndDequantizeNodeShader);
92 insert_op(Type::RELU, NewReLUNodeShader);
93 insert_op(Type::RESIZE, NewResizeNodeShader);
94 insert_op(Type::RESHAPE, NewReshapeNodeShader);
95 insert_op(Type::SLICE, NewSliceNodeShader);
96 insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
97
98 insert_elementwise_op(Type::ABS);
99 insert_elementwise_op(Type::COPY);
100 insert_elementwise_op(Type::COS);
101 insert_elementwise_op(Type::DIV);
102 insert_elementwise_op(Type::ELU);
103 insert_elementwise_op(Type::EXP);
104 insert_elementwise_op(Type::HARD_SWISH);
105 insert_elementwise_op(Type::LOG);
106 insert_elementwise_op(Type::NEG);
107 insert_elementwise_op(Type::MAXIMUM);
108 insert_elementwise_op(Type::MINIMUM);
109 insert_elementwise_op(Type::POW);
110 insert_elementwise_op(Type::RSQRT);
111 insert_elementwise_op(Type::SIGMOID);
112 insert_elementwise_op(Type::SIN);
113 insert_elementwise_op(Type::SQRT);
114 insert_elementwise_op(Type::SQUARE);
115 insert_elementwise_op(Type::SQUARED_DIFF);
116 insert_elementwise_op(Type::SUB);
117 insert_elementwise_op(Type::TANH);
118
119 #ifndef TFLITE_GPU_BINARY_RELEASE
120 insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader);
121 RegisterCustomOps(&shaders_);
122 #endif // TFLITE_GPU_BINARY_RELEASE
123 }
124
125 ~Registry() final = default;
126
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const127 absl::Status GenerateCode(const GenerationContext& ctx,
128 GeneratedCode* generated_code) const final {
129 auto it = shaders_.find(ctx.op_type);
130 if (it == shaders_.end()) {
131 return absl::NotFoundError(
132 absl::StrCat("No shader implementation for ", ctx.op_type));
133 }
134 std::vector<std::string> errors;
135 for (const auto& shader : it->second) {
136 const auto status = shader->GenerateCode(ctx, generated_code);
137 // Return the first suitable shader.
138 if (status.ok()) return absl::OkStatus();
139 errors.push_back(std::string(status.message()));
140 }
141 return errors.empty() ? absl::OkStatus()
142 : absl::UnknownError(absl::StrJoin(errors, ", "));
143 }
144
145 private:
146 absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>
147 shaders_;
148 };
149
150 } // namespace
151
NewNodeShaderRegistry()152 std::unique_ptr<NodeShader> NewNodeShaderRegistry() {
153 return absl::make_unique<Registry>();
154 }
155
156 } // namespace gl
157 } // namespace gpu
158 } // namespace tflite
159