• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_
#include <NvInfer.h>
#include <string>
#include "include/errorcode.h"
#include "src/extendrt/delegate/tensorrt/tensorrt_allocator.h"
#include "src/extendrt/delegate/tensorrt/cuda_impl/cublas_utils.h"
#include "src/common/log_adapter.h"
#define MAX_BATCH_SIZE 64

using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;

namespace mindspore::lite {
30 class TensorRTLogger : public nvinfer1::ILogger {
log(Severity severity,const char * msg)31   void log(Severity severity, const char *msg) noexcept override {
32     if (severity == Severity::kINTERNAL_ERROR || severity == Severity::kERROR) {
33       MS_LOG(WARNING) << msg;
34     } else if (severity == Severity::kWARNING) {
35       MS_LOG(WARNING) << msg;
36     } else if (severity == Severity::kINFO) {
37       MS_LOG(INFO) << msg;
38     } else {
39       MS_LOG(DEBUG) << msg;
40     }
41   }
42 };
43 
44 enum RuntimePrecisionMode : int { RuntimePrecisionMode_FP32, RuntimePrecisionMode_FP16 };
45 
46 class TensorRTRuntime {
47  public:
48   TensorRTRuntime() = default;
49 
50   ~TensorRTRuntime();
51 
52   int Init();
53 
GetBuilder()54   nvinfer1::IBuilder *GetBuilder() { return this->builder_; }
55 
GetBatchSize()56   int GetBatchSize() { return batch_size_; }
57 
SetBatchSize(int batch_size)58   void SetBatchSize(int batch_size) { batch_size_ = batch_size; }
59 
SetCudaStream(cudaStream_t stream,cublasHandle_t cublas_handle,cublasLtHandle_t cublaslt_handle)60   void SetCudaStream(cudaStream_t stream, cublasHandle_t cublas_handle, cublasLtHandle_t cublaslt_handle) {
61     allocator_->SetCudaStream(stream);
62     cublas_handle_ = cublas_handle;
63     cublaslt_handle_ = cublaslt_handle;
64   }
65 
GetRuntimePrecisionMode()66   RuntimePrecisionMode GetRuntimePrecisionMode() { return runtime_percision_mode_; }
67 
GetTransformerEncoderInputIdx()68   int GetTransformerEncoderInputIdx() { return transformer_encoder_input_idx_; }
69 
GetTransformerDecoderInputIdx()70   int GetTransformerDecoderInputIdx() { return transformer_decoder_input_idx_; }
71 
GetTransformerFfnFp16()72   bool GetTransformerFfnFp16() { return transformer_ffn_fp16_; }
73 
GetTransformerOptimize()74   std::string GetTransformerOptimize() const { return optimize_transformer_; }
75 
GetVslEncoderPluginId()76   int GetVslEncoderPluginId() { return vsl_encoder_plugin_id_; }
77 
GetVslDecoderPluginId()78   int GetVslDecoderPluginId() { return vsl_decoder_plugin_id_; }
79 
SetRuntimePrecisionMode(RuntimePrecisionMode runtime_percision_mode)80   void SetRuntimePrecisionMode(RuntimePrecisionMode runtime_percision_mode) {
81     runtime_percision_mode_ = runtime_percision_mode;
82   }
83 
SetTransformerEncoderInputIdx(int transformer_encoder_input_idx)84   void SetTransformerEncoderInputIdx(int transformer_encoder_input_idx) {
85     transformer_encoder_input_idx_ = transformer_encoder_input_idx;
86   }
87 
SetTransformerDecoderInputIdx(int transformer_decoder_input_idx)88   void SetTransformerDecoderInputIdx(int transformer_decoder_input_idx) {
89     transformer_decoder_input_idx_ = transformer_decoder_input_idx;
90   }
SetTransformerFfnFp16(bool is_ffn_fp16)91   void SetTransformerFfnFp16(bool is_ffn_fp16) { transformer_ffn_fp16_ = is_ffn_fp16; }
SetTransformerOptimize(const std::string & optimize_transformer)92   void SetTransformerOptimize(const std::string &optimize_transformer) { optimize_transformer_ = optimize_transformer; }
93 
IsTransformerOptimizeSigma()94   bool IsTransformerOptimizeSigma() {
95     std::string pangu_sigma("pangu_sigma");
96     return (optimize_transformer_ == pangu_sigma) ? true : false;
97   }
SetVslEncoderPluginId(int plugin_id)98   void SetVslEncoderPluginId(int plugin_id) { vsl_encoder_plugin_id_ = plugin_id; }
99 
SetVslDecoderPluginId(int plugin_id)100   void SetVslDecoderPluginId(int plugin_id) { vsl_decoder_plugin_id_ = plugin_id; }
101 
GetAllocator()102   TensorRTAllocator *GetAllocator() { return this->allocator_; }
103 
SetDeviceID(uint32_t device_id)104   void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
105 
GetDeviceID()106   uint32_t GetDeviceID() { return device_id_; }
GetCublasHandle()107   cublasHandle_t GetCublasHandle() { return cublas_handle_; }
GetCublasLtHandle()108   cublasLtHandle_t GetCublasLtHandle() { return cublaslt_handle_; }
109 
110  private:
111   bool is_init_{false};
112   nvinfer1::IBuilder *builder_{nullptr};
113   TensorRTLogger logger_;
114   TensorRTAllocator *allocator_{nullptr};
115   int batch_size_{0};
116   uint32_t device_id_{0};
117   RuntimePrecisionMode runtime_percision_mode_{RuntimePrecisionMode::RuntimePrecisionMode_FP32};
118   int transformer_encoder_input_idx_{-1};
119   int transformer_decoder_input_idx_{-1};
120   bool transformer_ffn_fp16_{true};
121   std::string optimize_transformer_{""};
122   int vsl_encoder_plugin_id_{-1};
123   int vsl_decoder_plugin_id_{-1};
124   cublasHandle_t cublas_handle_{nullptr};
125   cublasLtHandle_t cublaslt_handle_{nullptr};
126 };
}  // namespace mindspore::lite
#endif  // MINDSPORE_LITE_SRC_EXTENDRT_DELEGATE_TENSORRT_TENSORRT_RUNTIME_H_