/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_SUB_GRAPH_
#define MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_SUB_GRAPH_
#include <utility>
#include <set>
#include <string>
#include <vector>
#include <memory>
#include "include/api/kernel.h"
#include "src/delegate/tensorrt/tensorrt_runtime.h"
#include "src/delegate/tensorrt/tensorrt_utils.h"
#include "include/api/context.h"

namespace mindspore::lite {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
class TensorRTSubGraph : public kernel::Kernel {
 public:
  TensorRTSubGraph(std::vector<TensorRTOp *> ops, const std::vector<mindspore::MSTensor> &inputs,
                   const std::vector<mindspore::MSTensor> &outputs, const mindspore::Context *ctx,
                   std::shared_ptr<GPUDeviceInfo> device_info, TensorRTRuntime *runtime, bool support_hw_resize)
      : kernel::Kernel(inputs, outputs, nullptr, ctx),
        all_ops_(std::move(ops)),
        device_info_(device_info),
        runtime_(runtime) {
    // Op types whose constant (weight) inputs are consumed by the TensorRT
    // layer conversion itself rather than declared as network inputs.
    trt_specific_weight_nodes_ = {
      schema::PrimitiveType_Conv2DFusion, schema::PrimitiveType_ReduceFusion, schema::PrimitiveType_Transpose,
      schema::PrimitiveType_Gather,       schema::PrimitiveType_Reshape,      schema::PrimitiveType_PowFusion,
      schema::PrimitiveType_AddFusion,    schema::PrimitiveType_DivFusion,    schema::PrimitiveType_SubFusion,
      schema::PrimitiveType_MatMulFusion, schema::PrimitiveType_Eltwise,      schema::PrimitiveType_ScaleFusion,
      schema::PrimitiveType_MulFusion,    schema::PrimitiveType_StridedSlice, schema::PrimitiveType_PadFusion,
      schema::PrimitiveType_FullConnection};
    if (!support_hw_resize) {
      input_hw_index_ = -1;
    }
  }

  ~TensorRTSubGraph() override;
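
  // Plausible call order, inferred from the method names rather than stated in
  // this header (the delegate code that drives these calls is not shown here):
  // Init() and BuildTensorRTGraph() construct the TensorRT network from
  // all_ops_, Prepare() readies the engine and execution context, Execute()
  // runs inference, and ReSize() handles input-shape changes between runs.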

  int Prepare() override;

  int Execute() override;

  int ReSize();

  int BuildTensorRTGraph();

  int Init();
 private:
  // Builds the CUDA engine from network_ using config_.
  int BuildEngine();

  // Sets builder/config options (e.g. precision mode) for the target device.
  int SetDeviceConfig();

  // Whether the target GPU supports native FP16 inference.
  bool SupportFP16();

  // Declares in_tensor as a TensorRT network input and returns the created ITensor.
  nvinfer1::ITensor *SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor);

  // Looks up the ITensor already produced for in_tensor by cur_op's predecessors.
  ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor);

  // Marks the subgraph outputs on the TensorRT network.
  int MarkOutputs();

  std::vector<TensorRTOp *> all_ops_{};
  // subgraph input nodes.
  std::vector<TensorRTOp *> in_ops_{};
  // subgraph output nodes.
  std::vector<TensorRTOp *> out_ops_{};

  // device memory bindings passed to the TensorRT execution context.
  void **tensor_bindings_{nullptr};

  std::shared_ptr<GPUDeviceInfo> device_info_{nullptr};

  TensorRTRuntime *runtime_{nullptr};  // all subgraphs in one delegate share a runtime_

  std::set<mindspore::schema::PrimitiveType> trt_specific_weight_nodes_;

  // save in/out tensor names for subgraph isolation.
  std::vector<std::string> trt_in_tensor_name_;
  std::vector<std::string> trt_out_tensor_name_;

  nvinfer1::INetworkDefinition *network_{nullptr};
  nvinfer1::IBuilderConfig *config_{nullptr};
  nvinfer1::ICudaEngine *engine_{nullptr};
  nvinfer1::IExecutionContext *trt_context_{nullptr};
  nvinfer1::IOptimizationProfile *profile_{nullptr};

  // index of the batch dimension in the input/output shapes.
  int input_batchsize_index_{0};
  int output_batchsize_index_{0};

  // -1 means H/W resize is not supported.
  int input_hw_index_{0};
};
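
// A minimal usage sketch (hypothetical call site; the delegate code that
// partitions the model and supplies these arguments is not shown here):
//
//   auto *subgraph = new (std::nothrow) TensorRTSubGraph(
//     ops, inputs, outputs, ctx, device_info, runtime, /*support_hw_resize=*/true);
//   if (subgraph == nullptr || subgraph->Init() != RET_OK ||
//       subgraph->BuildTensorRTGraph() != RET_OK || subgraph->Prepare() != RET_OK) {
//     // handle construction/build failure
//   }
//   int ret = subgraph->Execute();  // returns RET_OK on success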
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_RUNTIME_DELEGATE_TENSORRT_SUB_GRAPH_