• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/lite/delegates/gpu/common/operations.h"
29 #include "tensorflow/lite/delegates/gpu/common/status.h"
30 #include "tensorflow/lite/delegates/gpu/gl/kernels/add.h"
31 #include "tensorflow/lite/delegates/gpu/gl/kernels/concat.h"
32 #include "tensorflow/lite/delegates/gpu/gl/kernels/conv.h"
33 #include "tensorflow/lite/delegates/gpu/gl/kernels/custom_registry.h"
34 #include "tensorflow/lite/delegates/gpu/gl/kernels/depthwise_conv.h"
35 #include "tensorflow/lite/delegates/gpu/gl/kernels/elementwise.h"
36 #include "tensorflow/lite/delegates/gpu/gl/kernels/fully_connected.h"
37 #include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h"
38 #include "tensorflow/lite/delegates/gpu/gl/kernels/mean.h"
39 #include "tensorflow/lite/delegates/gpu/gl/kernels/mul.h"
40 #include "tensorflow/lite/delegates/gpu/gl/kernels/pad.h"
41 #include "tensorflow/lite/delegates/gpu/gl/kernels/pooling.h"
42 #include "tensorflow/lite/delegates/gpu/gl/kernels/prelu.h"
43 #include "tensorflow/lite/delegates/gpu/gl/kernels/quantize_and_dequantize.h"
44 #include "tensorflow/lite/delegates/gpu/gl/kernels/relu.h"
45 #include "tensorflow/lite/delegates/gpu/gl/kernels/reshape.h"
46 #include "tensorflow/lite/delegates/gpu/gl/kernels/resize.h"
47 #include "tensorflow/lite/delegates/gpu/gl/kernels/slice.h"
48 #include "tensorflow/lite/delegates/gpu/gl/kernels/softmax.h"
49 #include "tensorflow/lite/delegates/gpu/gl/kernels/transpose_conv.h"
50 
51 #ifndef TFLITE_GPU_BINARY_RELEASE
52 #include "tensorflow/lite/delegates/gpu/gl/kernels/max_unpooling.h"
53 #endif  // TFLITE_GPU_BINARY_RELEASE
54 
55 namespace tflite {
56 namespace gpu {
57 namespace gl {
58 namespace {
59 
60 class Registry : public NodeShader {
61  public:
Registry()62   Registry() {
63     using Type = OperationType;
64     using NewShaderFunc = std::function<std::unique_ptr<NodeShader>()>;
65 
66     const auto insert_op = [&](Type type, NewShaderFunc func) {
67       shaders_[ToString(type)].push_back(func());
68     };
69     const auto insert_elementwise_op = [&](Type operation_type) {
70       shaders_[ToString(operation_type)].push_back(
71           NewElementwiseNodeShader(operation_type));
72     };
73 
74     insert_op(Type::ADD, NewAddNodeShader);
75     insert_op(Type::CONCAT, NewAlignedConcatNodeShader);
76     insert_op(Type::CONCAT, NewFlatConcatNodeShader);
77     insert_op(Type::CONCAT, NewConcatNodeShader);
78     insert_op(Type::CONVOLUTION_2D, NewConvolution1x1NodeShader);
79     insert_op(Type::CONVOLUTION_2D, NewConvolutionNodeShader);
80     insert_op(Type::CONVOLUTION_TRANSPOSED, NewConvolutionTransposedNodeShader);
81     insert_op(Type::DEPTHWISE_CONVOLUTION, NewDepthwiseConvolutionNodeShader);
82     insert_op(Type::FULLY_CONNECTED, NewFullyConnectedNodeShader);
83     insert_op(Type::LSTM, NewLstmNodeShader);
84     insert_op(Type::MEAN, NewMeanNodeShader);
85     // TODO(b/162763635): implement MeanStddevNormalization for OpenGL.
86     insert_op(Type::MUL, NewMultiplyNodeShader);
87     insert_op(Type::PAD, NewPadNodeShader);
88     insert_op(Type::POOLING_2D, NewPoolingNodeShader);
89     insert_op(Type::PRELU, NewPReLUNodeShader);
90     insert_op(Type::QUANTIZE_AND_DEQUANTIZE,
91               NewQuantizeAndDequantizeNodeShader);
92     insert_op(Type::RELU, NewReLUNodeShader);
93     insert_op(Type::RESIZE, NewResizeNodeShader);
94     insert_op(Type::RESHAPE, NewReshapeNodeShader);
95     insert_op(Type::SLICE, NewSliceNodeShader);
96     insert_op(Type::SOFTMAX, NewSoftmaxNodeShader);
97 
98     insert_elementwise_op(Type::ABS);
99     insert_elementwise_op(Type::COPY);
100     insert_elementwise_op(Type::COS);
101     insert_elementwise_op(Type::DIV);
102     insert_elementwise_op(Type::ELU);
103     insert_elementwise_op(Type::EXP);
104     insert_elementwise_op(Type::HARD_SWISH);
105     insert_elementwise_op(Type::LOG);
106     insert_elementwise_op(Type::NEG);
107     insert_elementwise_op(Type::MAXIMUM);
108     insert_elementwise_op(Type::MINIMUM);
109     insert_elementwise_op(Type::POW);
110     insert_elementwise_op(Type::RSQRT);
111     insert_elementwise_op(Type::SIGMOID);
112     insert_elementwise_op(Type::SIN);
113     insert_elementwise_op(Type::SQRT);
114     insert_elementwise_op(Type::SQUARE);
115     insert_elementwise_op(Type::SQUARED_DIFF);
116     insert_elementwise_op(Type::SUB);
117     insert_elementwise_op(Type::TANH);
118 
119 #ifndef TFLITE_GPU_BINARY_RELEASE
120     insert_op(Type::MAX_UNPOOLING_2D, NewMaxUnpoolingNodeShader);
121     RegisterCustomOps(&shaders_);
122 #endif  // TFLITE_GPU_BINARY_RELEASE
123   }
124 
125   ~Registry() final = default;
126 
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const127   absl::Status GenerateCode(const GenerationContext& ctx,
128                             GeneratedCode* generated_code) const final {
129     auto it = shaders_.find(ctx.op_type);
130     if (it == shaders_.end()) {
131       return absl::NotFoundError(
132           absl::StrCat("No shader implementation for ", ctx.op_type));
133     }
134     std::vector<std::string> errors;
135     for (const auto& shader : it->second) {
136       const auto status = shader->GenerateCode(ctx, generated_code);
137       // Return the first suitable shader.
138       if (status.ok()) return absl::OkStatus();
139       errors.push_back(std::string(status.message()));
140     }
141     return errors.empty() ? absl::OkStatus()
142                           : absl::UnknownError(absl::StrJoin(errors, ", "));
143   }
144 
145  private:
146   absl::flat_hash_map<std::string, std::vector<std::unique_ptr<NodeShader>>>
147       shaders_;
148 };
149 
150 }  // namespace
151 
NewNodeShaderRegistry()152 std::unique_ptr<NodeShader> NewNodeShaderRegistry() {
153   return absl::make_unique<Registry>();
154 }
155 
156 }  // namespace gl
157 }  // namespace gpu
158 }  // namespace tflite
159