/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
#include <utility>
#include <set>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "include/api/kernel.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_utils.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_serializer.h"
#include "src/extendrt/delegate/tensorrt/op/tensorrt_op.h"
#include "src/extendrt/delegate/parameter_cache/embedding_cache_manager.h"
#include "include/api/context.h"
#include "common/config_infos.h"

namespace mindspore::lite {
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
struct CacheTensorInfo {
  std::vector<mindspore::MSTensor> network_input_tensor_;
  bool front_op_can_cache_;
};

class TensorRTSubGraph {
 public:
  TensorRTSubGraph(std::vector<TensorRTOp *> ops, const std::vector<TensorInfo> &inputs,
                   const std::vector<TensorInfo> &outputs, const mindspore::Context *ctx,
                   std::shared_ptr<GPUDeviceInfo> device_info, TensorRTRuntime *runtime, bool support_resize,
                   bool support_hw_resize, const ProfileConfigs &trt_profile_config);
  ~TensorRTSubGraph();

  int Prepare();

  int Execute(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs);

  int Resize(const std::vector<tensor::Tensor> &inputs, const std::vector<ShapeVector> &new_shapes);

  int BuildTensorRTGraph();

  int Init(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);

  void SetSerializePath(const std::string &path) { serialize_file_path_ = std::move(path); }

  int VSLPreExectute(const std::vector<tensor::Tensor> &inputs, int i, bool sync, const std::string &tensor_name);

  std::vector<TensorInfo> &inputs() { return inputs_; }

  std::vector<TensorInfo> &outputs() { return outputs_; }

 private:
  int GetInputIndexByName(const std::string &name);
  int BuildEngine();

  int SetDeviceConfig(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle);

  bool IsInt8Mode();

  bool SupportFP16();

  nvinfer1::ITensor *SetTensorRTNetworkInput(const TensorInfo &in_tensor, int index);

  ITensorHelper FindTensorRTInputs(TensorRTOp *cur_op, const TensorInfo &in_tensor);

  int MarkOutputs();

  bool IsCached(TensorRTOp *cur_op, const TensorInfo &in_tensor);

  void FindCacheTensorInfo(TensorRTOp *cur_op, TensorInfo device_cache_tensor);

  bool CanOpCache(TensorRTOp *cur_op);

  int HandleCacheTensor(TensorRTOp *cur_op, const TensorInfo &in_tensor);

  nvinfer1::Dims ParseInputDimsProfile(const TensorInfo &in_tensor, int index);
  nvinfer1::Dims SetInputDimsProfile(const TensorInfo &in_tensor, int index);
  int ParseInputsProfile();

  int PreExecute(const std::vector<tensor::Tensor> &inputs, const std::vector<tensor::Tensor> &outputs,
                 bool sync = true);
  int PostExecute(std::vector<tensor::Tensor> *outputs, bool sync = true);

  int OnNewInputShapes(const std::vector<ShapeVector> &inputs);

  size_t MaxVolumnProfileIndex() const;
  int SelectProfile(const std::vector<ShapeVector> &new_shapes) const;
  int GetProfileBindingIndex(const std::string &name, size_t profile_index);
  bool ValidInputResizeDims(const nvinfer1::Dims &construct_dims, const std::vector<int64_t> &resize_input_shape);
  bool IsValidProfileDims() const;

  std::string name_;
  std::vector<TensorInfo> inputs_;
  std::vector<TensorInfo> outputs_;

  std::vector<TensorRTOp *> all_ops_{};
  // subgraph input nodes.
  std::vector<TensorRTOp *> in_ops_{};
  // subgraph output nodes.
  std::vector<TensorRTOp *> out_ops_{};

  void **tensor_bindings_{nullptr};

  std::shared_ptr<GPUDeviceInfo> device_info_{nullptr};

  TensorRTRuntime *runtime_{nullptr};  // all subgraphs in one delegate share a runtime_

  std::set<std::string> trt_specific_weight_handled_inner_;

  // save in/out tensor names for subgraph isolation.
  std::vector<std::string> trt_in_tensor_name_;
  std::vector<std::string> trt_out_tensor_name_;

  nvinfer1::INetworkDefinition *network_{nullptr};
  nvinfer1::IBuilderConfig *config_{nullptr};
  nvinfer1::ICudaEngine *engine_{nullptr};
  nvinfer1::IExecutionContext *trt_context_{nullptr};

  TensorRTContext *ctx_;

  // -1 means resize is not supported
  int input_batchsize_index_{0};
  int output_batchsize_index_{0};
  int input_hw_index_{0};

  std::map<std::string, std::vector<mindspore::MSTensor>> model_input_to_cache_tensors_;

  std::shared_ptr<TensorRTSerializer> serializer_{nullptr};

  std::string serialize_file_path_;
  cudaStream_t stream_{nullptr};

  std::vector<nvinfer1::IOptimizationProfile *> profiles_{};
  bool using_input_ranges_{false};
  ProfileConfigs trt_profile_config_;
  size_t profile_index_{0};
};
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_SUBGRAPH_H_
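
// A minimal usage sketch (comment only, not part of the compiled header): how a delegate-side
// caller might drive this class. The variable names (ops, inputs, outputs, ctx, device_info,
// runtime, profile_config, stream, cublas_handle, cublaslt_handle, in_tensors) are placeholders
// assumed to be prepared by the delegate, and the call order Init -> BuildTensorRTGraph ->
// Prepare -> Execute is an assumption based on the interface declared above, not a guaranteed
// contract; error handling is elided.
//
//   auto subgraph = std::make_shared<mindspore::lite::TensorRTSubGraph>(
//     ops, inputs, outputs, ctx, device_info, runtime,
//     /* support_resize */ true, /* support_hw_resize */ true, profile_config);
//   if (subgraph->Init(stream, cublas_handle, cublaslt_handle) != mindspore::lite::RET_OK ||
//       subgraph->BuildTensorRTGraph() != mindspore::lite::RET_OK ||
//       subgraph->Prepare() != mindspore::lite::RET_OK) {
//     // engine build or preparation failed
//   }
//   std::vector<mindspore::tensor::Tensor> out_tensors;
//   subgraph->Execute(in_tensors, &out_tensors);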