• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_OP_H_
17 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_OP_H_
18 
19 #include <utility>
20 #include <NvInfer.h>
21 #include <string>
22 #include <vector>
23 #include <memory>
24 #include "include/api/kernel.h"
25 #include "src/common/log_adapter.h"
26 #include "include/errorcode.h"
27 #include "src/extendrt/delegate/tensorrt/tensorrt_context.h"
28 #include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
29 #include "src/extendrt/delegate/tensorrt/op_registration_factory.h"
30 #include "src/extendrt/delegate/tensorrt/tensor_info.h"
31 #include "src/common/log_util.h"
32 #include "ops/base_operator.h"
33 #include "ops/op_name.h"
34 #include "kernel/kernel.h"
35 #include "include/api/types.h"
36 #include "mindapi/base/types.h"
37 
38 namespace mindspore::lite {
// Named constants for expected input-tensor counts, used instead of magic
// numbers when ops validate the size of their input vectors.
constexpr int INPUT_SIZE2 = 2;
constexpr int INPUT_SIZE3 = 3;
constexpr int INPUT_SIZE4 = 4;
constexpr int INPUT_SIZE5 = 5;
constexpr int INPUT_SIZE6 = 6;
constexpr int INPUT_SIZE7 = 7;
constexpr int INPUT_SIZE8 = 8;
constexpr int INPUT_SIZE9 = 9;
constexpr int INPUT_SIZE10 = 10;
48 
49 struct BindingHelper {
50   std::string name_;
51   const void *data_{nullptr};
52   nvinfer1::DataType data_type_;
53   size_t size_;
54   bool is_input_binding_{false};
55 };
56 
// Flags describing which kinds of dynamic input dimensions an op tolerates.
struct DynamicShapeParams {
  bool support_dynamic_{true};     // op accepts dynamic dimensions at all -- presumably batch; confirm with callers
  bool support_hw_dynamic_{true};  // op additionally accepts dynamic height/width -- assumption from the name, TODO confirm
};
61 
// Forward declaration; the full definition is not needed in this header.
class TensorRTRuntime;

// Shared ownership of a framework operator description.
using BaseOperatorPtr = std::shared_ptr<ops::BaseOperator>;
65 
// Base class for all TensorRT operator wrappers. A concrete TensorRTOp checks
// whether a framework operator is supported (IsSupport) and converts it into
// TensorRT network layers (AddInnerOp) inside a TensorRTContext.
class TensorRTOp {
 public:
  TensorRTOp(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
             const std::vector<TensorInfo> &out_tensors, std::string name);

  virtual ~TensorRTOp() = default;

  // Returns RET_OK when this op can be converted to TensorRT for the given
  // operator and tensors; anything else means "not supported".
  virtual int IsSupport(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &in_tensors,
                        const std::vector<TensorInfo> &out_tensors) = 0;

  // The weight input has been processed internally by the operator. The framework does not
  // need to process the weight input.
  // NOTE(review): "Hanled" is a typo for "Handled"; kept because renaming a
  // virtual would break overrides elsewhere in the project.
  virtual bool IsWeightInputHanledInner() const { return false; }

  // Adds the TensorRT layer(s) implementing this op to the network held by ctx.
  virtual int AddInnerOp(TensorRTContext *ctx) = 0;

  virtual int SetInt8DynamicRange(TensorRTContext *ctx);

  // Called after engine build to set up runtime tensor bindings.
  virtual int Prepare(void **network_tensor_bindings, nvinfer1::ICudaEngine *engine);

  const BaseOperatorPtr &GetBaseOperator();

  // True when the op has at least one constant (weight) input -- assumption
  // from the name; implementation lives in the .cc file.
  virtual bool HasConst() const;

  int ReadyInputsNumber(TensorRTContext *ctx) const;

  std::string GetOpName();

  std::vector<TensorInfo> &inputs();

  // Accessors for the ITensorHelper of the i-th input/output inside ctx.
  ITensorHelper input(TensorRTContext *ctx, size_t i);

  ITensorHelper output(TensorRTContext *ctx, size_t i);

  std::vector<TensorInfo> &outputs();

  const std::string &type() const;

  schema::QuantType GetQuantType() const;

  // Graph links: the ops producing this op's inputs / consuming its outputs.
  void set_in_ops(const std::vector<TensorRTOp *> &in_ops);

  void set_out_ops(const std::vector<TensorRTOp *> &out_ops);

  const std::vector<TensorRTOp *> &in_ops() const;

  const std::vector<TensorRTOp *> &out_ops() const;

  void SetRuntime(TensorRTRuntime *runtime);
  // Handles are owned by runtime_; both return nullptr when no runtime is attached.
  cublasHandle_t GetCublasHandle() { return runtime_ ? runtime_->GetCublasHandle() : nullptr; }
  cublasLtHandle_t GetCublasLtHandle() { return runtime_ ? runtime_->GetCublasLtHandle() : nullptr; }

  DynamicShapeParams GetDynamicShapeParams() const;

  // Last TensorRT layer created by AddInnerOp (nullptr before that).
  nvinfer1::ILayer *layer() { return layer_; }

  bool GetSupportInputBool();
  bool IsDynamicInput(TensorRTContext *ctx, size_t k);

  void SetSupportInputBool(bool support_input_bool);
  // Re-interprets the stored base operator as the concrete ops type OpsT.
  template <class OpsT>
  std::shared_ptr<OpsT> AsOps() {
    return std::make_shared<OpsT>(base_operator_->GetPrim());
  }

  // Static variant of AsOps() for an arbitrary base operator.
  template <class OpsT>
  static std::shared_ptr<OpsT> AsOps(const BaseOperatorPtr &base_operator) {
    return std::make_shared<OpsT>(base_operator->GetPrim());
  }
  void PrintTrtInputs(TensorRTContext *ctx);
  void PrintTrtOutputs(TensorRTContext *ctx);

 private:
  int SetTransposeDynamicRange();

 protected:
  bool IsShapeKnown();

  // Non-owning: TensorRT network/layer objects are owned by the builder/network.
  nvinfer1::ILayer *layer_ = nullptr;

  nvinfer1::IShuffleLayer *transpose_layer_ = nullptr;

  BaseOperatorPtr base_operator_ = nullptr;
  std::vector<TensorInfo> in_tensors_;
  std::vector<TensorInfo> out_tensors_;

  std::vector<TensorRTOp *> in_ops_;

  std::vector<TensorRTOp *> out_ops_;

  std::string op_name_;

  std::string type_;

  schema::QuantType quant_type_ = schema::QuantType_QUANT_NONE;

  std::vector<BindingHelper> op_binding_tensor_;

  // Non-owning; set via SetRuntime().
  TensorRTRuntime *runtime_{nullptr};

  DynamicShapeParams dynamic_shape_params_;

  uint32_t device_id_{0};

  bool support_input_bool_{true};
};
172 
173 template <class T>
GetTensorRTOp(const BaseOperatorPtr & base_operator,const std::vector<TensorInfo> & inputs,const std::vector<TensorInfo> & outputs,const std::string & name)174 TensorRTOp *GetTensorRTOp(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &inputs,
175                           const std::vector<TensorInfo> &outputs, const std::string &name) {
176   auto *op = new (std::nothrow) T(base_operator, inputs, outputs, name);
177   if (op == nullptr) {
178     MS_LOG(WARNING) << "TensorRT is nullptr.";
179     return nullptr;
180   }
181 
182   auto ret = op->IsSupport(base_operator, inputs, outputs);
183   if (ret != RET_OK) {
184     MS_LOG(WARNING) << "TensorRT op is not supported: " << name;
185     delete op;
186     return nullptr;
187   }
188   return op;
189 }
190 typedef TensorRTOp *(*TensorRTGetOp)(const BaseOperatorPtr &base_operator, const std::vector<TensorInfo> &inputs,
191                                      const std::vector<TensorInfo> &outputs, const std::string &name);
192 
// Registers GetTensorRTOp<TENSORRT_OP> as the creator for the string KEY, so
// the delegate can look up op builders by operator type name at runtime.
#define REGISTER_TENSORRT_CREATOR(KEY, TENSORRT_OP) \
  REGISTER_CLASS_CREATOR(std::string, KEY, TensorRTGetOp, GetTensorRTOp<TENSORRT_OP>);

// The registry that REGISTER_TENSORRT_CREATOR populates (key -> creator fn).
using TensorRTRegistrationFactory = AutoRegistrationFactory<std::string, TensorRTGetOp>;
197 }  // namespace mindspore::lite
198 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_OP_TENSORRT_OP_H_
199