1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h"
17
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/gl/kernels/add.h"
31 #include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"
32 #include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h"
33 #include "tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h"
34 #include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h"
35 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
36 #include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"
37 #include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h"
38 #include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h"
39 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
40 #include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h"
41 #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
42 #include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
43 #include "tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.h"
44 #include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h"
45 #include "tensorflow/lite/delegates/gpu/gl/kernels/resampler.h"
46 #include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h"
47 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
48 #include "tensorflow/lite/delegates/gpu/gl/kernels/slice.h"
49 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
50 #include "tensorflow/lite/delegates/gpu/gl/kernels/space_to_depth.h"
51 #include "tensorflow/lite/delegates/gpu/gl/kernels/tile.h"
52 #include "tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h"
53
54 #ifndef TFLITE_GPU_BINARY_RELEASE
55 #include "tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h"
56 #endif // TFLITE_GPU_BINARY_RELEASE
57
58 namespace tflite {
59 namespace gpu {
60 namespace gl {
61 namespace {
62
63 class Registry : public NodeShader {
64 public:
Registry()65 Registry() {
66 using Type = OperationType;
67 using NewShaderFunc = std::function<std::unique_ptr<NodeShader>()>;
68
69 const auto insert_op = [&](Type type, NewShaderFunc func) {
70 shaders_[ToString(type)].push_back(func());
71 };
72 const auto insert_elementwise_op = [&](Type operation_type) {
73 shaders_[ToString(operation_type)].push_back(
74 NewElementwiseNodeShader(operation_type));
75 };
76
77 insert_op(Type::ADD, NewAddNodeShader);
78 insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
79 insert_op(Type::CONCAT, NewFlatConcatNodeShader);
80 insert_op(Type::CONCAT, NewConcatNodeShader);
81 insert_op(Type::CONVOLUTION_2D, NewConvolution1x1NodeShader);
82 insert_op(Type::CONVOLUTION_2D, NewConvolutionNodeShader);
83 insert_op(Type::CONVOLUTION_TRANSPOSED, NewConvolutionTransposedNodeShader);
84 insert_op(Type::DEPTHWISE_CONVOLUTION, NewDepthwiseConvolutionNodeShader);
85 insert_op(Type::DEPTH_TO_SPACE, NewDepthToSpaceNodeShader);
86 insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
87 insert_op(Type::LSTM, NewLstmNodeShader);
88 insert_op(Type::MEAN, NewMeanNodeShader);
89 // TODO(b/162763635): implement MeanStddevNormalization for OpenGL.
90 insert_op(Type::MUL, NewMultiplyNodeShader);
91 insert_op(Type::PAD, NewPadNodeShader);
92 insert_op(Type::POOLING_2D, NewPoolingNodeShader);
93 insert_op(Type::PRELU, NewPReLUNodeShader);
94 insert_op(Type::QUANTIZE_AND_DEQUANTIZE,
95 NewQuantizeAndDequantizeNodeShader);
96 insert_op(Type::RELU, NewReLUNodeShader);
97 insert_op(Type::RESAMPLER, NewResamplerNodeShader);
98 insert_op(Type::RESIZE, NewResizeNodeShader);
99 insert_op(Type::RESHAPE, NewReshapeNodeShader);
100 insert_op(Type::SLICE, NewSliceNodeShader);
101 insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
102 insert_op(Type::SPACE_TO_DEPTH, NewSpaceToDepthNodeShader);
103 insert_op(Type::TILE, NewTileNodeShader);
104
105 insert_elementwise_op(Type::ABS);
106 insert_elementwise_op(Type::COPY);
107 insert_elementwise_op(Type::COS);
108 insert_elementwise_op(Type::DIV);
109 insert_elementwise_op(Type::ELU);
110 insert_elementwise_op(Type::EXP);
111 insert_elementwise_op(Type::FLOOR);
112 insert_elementwise_op(Type::FLOOR_DIV);
113 insert_elementwise_op(Type::FLOOR_MOD);
114 insert_elementwise_op(Type::HARD_SWISH);
115 insert_elementwise_op(Type::LOG);
116 insert_elementwise_op(Type::NEG);
117 insert_elementwise_op(Type::MAXIMUM);
118 insert_elementwise_op(Type::MINIMUM);
119 insert_elementwise_op(Type::POW);
120 insert_elementwise_op(Type::RSQRT);
121 insert_elementwise_op(Type::SIGMOID);
122 insert_elementwise_op(Type::SIN);
123 insert_elementwise_op(Type::SQRT);
124 insert_elementwise_op(Type::SQUARE);
125 insert_elementwise_op(Type::SQUARED_DIFF);
126 insert_elementwise_op(Type::SUB);
127 insert_elementwise_op(Type::TANH);
128
129 #ifndef TFLITE_GPU_BINARY_RELEASE
130 insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader);
131 RegisterCustomOps(&shaders_);
132 #endif // TFLITE_GPU_BINARY_RELEASE
133 }
134
135 ~Registry() final = default;
136
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const137 absl::Status GenerateCode(const GenerationContext& ctx,
138 GeneratedCode* generated_code) const final {
139 auto it = shaders_.find(ctx.op_type);
140 if (it == shaders_.end()) {
141 return absl::NotFoundError(
142 absl::StrCat("No shader implementation for ", ctx.op_type));
143 }
144 std::vector<std::string> errors;
145 for (const auto& shader : it->second) {
146 const auto status = shader->GenerateCode(ctx, generated_code);
147 // Return the first suitable shader.
148 if (status.ok()) return absl::OkStatus();
149 errors.push_back(std::string(status.message()));
150 }
151 return errors.empty() ? absl::OkStatus()
152 : absl::UnknownError(absl::StrJoin(errors, ", "));
153 }
154
155 private:
156 absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>
157 shaders_;
158 };
159
160 } // namespace
161
NewNodeShaderRegistry()162 std::unique_ptr<NodeShader> NewNodeShaderRegistry() {
163 return std::make_unique<Registry>();
164 }
165
166 } // namespace gl
167 } // namespace gpu
168 } // namespace tflite
169