• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
17 #define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
18 #include <utility>
19 #include <set>
20 #include <map>
21 #include <string>
22 #include <vector>
23 #include <memory>
24 #include "include/api/kernel.h"
25 #include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h"
26 #include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
27 #include "src/extendrt/delegate/tensorrt/tensorrt_serializer.h"
28 #include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
29 #include "src/extendrt/delegate/parameter_cache/embedding_cache_manager.h"
30 #include "include/api/context.h"
31 #include "common/config_infos.h"
32 
33 namespace mindspore::lite {
34 using mindspore::lite::RET_ERROR;
35 using mindspore::lite::RET_OK;
36 struct CacheTensorInfo {
37   std::vector<mindspore::MSTensor> network_input_tensor_;
38   bool front_op_can_cache_;
39 };
40 
41 class TensorRTSubGraph {
42  public:
43   TensorRTSubGraph(std::vector<TensorRTOp *> ops, const std::vector<TensorInfo> &inputs,
44                    const std::vector<TensorInfo> &outputs, const mindspore::Context *ctx,
45                    std::shared_ptr<GPUDeviceInfo> device_info, TensorRTRuntime *runtime, bool support_resize,
46                    bool support_hw_resize, const ProfileConfigs &trt_profile_config);
47   ~TensorRTSubGraph();
48 
49   int Prepare();
50 
51   int Execute(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs);
52 
53   int Resize(const std::vector<tensor::Tensor> &inputs, const std::vector<ShapeVector> &new_shapes);
54 
55   int BuildTensorRTGraph();
56 
57   int Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);
58 
SetSerializePath(const std::string & path)59   void SetSerializePath(const std::string &path) { serialize_file_path_ = std::move(path); }
60 
61   int VSLPreExectute(const std::vector<tensor::Tensor> &inputs, int i, bool sync, const std::string &tensor_name);
62 
inputs()63   std::vector<TensorInfo> &inputs() { return inputs_; }
64 
outputs()65   std::vector<TensorInfo> &outputs() { return outputs_; }
66 
67  private:
68   int GetInputIndexByName(const std::string &name);
69   int BuildEngine();
70 
71   int SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);
72 
73   bool IsInt8Mode();
74 
75   bool SupportFP16();
76 
77   nvinfer1::ITensor *SetTensorRTNetworkInput(const TensorInfo &in_tensor, int index);
78 
79   ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const TensorInfo &in_tensor);
80 
81   int MarkOutputs();
82 
83   bool IsCached(TensorRTOp *cur_op, const TensorInfo &in_tensor);
84 
85   void FindCacheTensorInfo(TensorRTOp *cur_op, TensorInfo device_cache_tensor);
86 
87   bool CanOpCache(TensorRTOp *cur_op);
88 
89   int HandleCacheTensor(TensorRTOp *cur_op, const TensorInfo &in_tensor);
90 
91   nvinfer1::Dims ParseInputDimsProfile(const TensorInfo &in_tensor, int index);
92   nvinfer1::Dims SetInputDimsProfile(const TensorInfo &in_tensor, int index);
93   int ParseInputsProfile();
94 
95   int PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs,
96                  bool sync = true);
97   int PostExecute(std::vector<tensor::Tensor> *outputs, bool sync = true);
98 
99   int OnNewInputShapes(const std::vector<ShapeVector> &inputs);
100 
101   size_t MaxVolumnProfileIndex() const;
102   int SelectProfile(const std::vector<ShapeVector> &new_shapes) const;
103   int GetProfileBindingIndex(const std::string &name, size_t profile_index);
104   bool ValidInputResizeDims(const nvinfer1::Dims &construct_dims, const std::vector<int64_t> &resize_input_shape);
105   bool IsValidProfileDims() const;
106 
107   std::string name_;
108   std::vector<TensorInfo> inputs_;
109   std::vector<TensorInfo> outputs_;
110 
111   std::vector<TensorRTOp *> all_ops_{};
112   // subgraph input nodes.
113   std::vector<TensorRTOp *> in_ops_{};
114   // subgraph output nodes.
115   std::vector<TensorRTOp *> out_ops_{};
116 
117   void **tensor_bindings_{nullptr};
118 
119   std::shared_ptr<GPUDeviceInfo> device_info_{nullptr};
120 
121   TensorRTRuntime *runtime_{nullptr};  // all subgraph in one delegate share a runtime_
122 
123   std::set<std::string> trt_specific_weight_handled_inner_;
124 
125   // save in/out tensor name for subgraph isolate.
126   std::vector<std::string> trt_in_tensor_name_;
127   std::vector<std::string> trt_out_tensor_name_;
128 
129   nvinfer1::INetworkDefinition *network_{nullptr};
130   nvinfer1::IBuilderConfig *config_{nullptr};
131   nvinfer1::ICudaEngine *engine_{nullptr};
132   nvinfer1::IExecutionContext *trt_context_{nullptr};
133 
134   TensorRTContext *ctx_;
135 
136   // -1 means don't support resize
137   int input_batchsize_index_{0};
138   int output_batchsize_index_{0};
139   int input_hw_index_{0};
140 
141   std::map<std::string, std::vector<mindspore::MSTensor>> model_input_to_cache_tensors_;
142 
143   std::shared_ptr<TensorRTSerializer> serializer_{nullptr};
144 
145   std::string serialize_file_path_;
146   cudaStream_t stream_{nullptr};
147 
148   std::vector<nvinfer1::IOptimizationProfile *> profiles_{};
149   bool using_input_ranges_{false};
150   ProfileConfigs trt_profile_config_;
151   size_t profile_index_{0};
152 };
153 }  // namespace mindspore::lite
154 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
155