/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
19
#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>
#include <NvInfer.h>
#include "utils/log_adapter.h"
#include "utils/singleton.h"
#include "utils/convert_utils_base.h"
#include "utils/shape_utils.h"
#include "ir/dtype/type.h"
34
35 namespace mindspore {
36 class TrtUtils {
37 public:
TrtDtypeToMsDtype(const nvinfer1::DataType & trt_dtype)38 static TypeId TrtDtypeToMsDtype(const nvinfer1::DataType &trt_dtype) {
39 static std::map<nvinfer1::DataType, TypeId> type_list = {{nvinfer1::DataType::kFLOAT, TypeId::kNumberTypeFloat32},
40 {nvinfer1::DataType::kHALF, TypeId::kNumberTypeFloat16},
41 {nvinfer1::DataType::kINT8, TypeId::kNumberTypeInt8},
42 {nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt32}};
43
44 auto iter = type_list.find(trt_dtype);
45 if (iter == type_list.end()) {
46 MS_LOG(EXCEPTION) << "Invalid Tensor-RT dtype: " << trt_dtype;
47 }
48 return iter->second;
49 }
50
MsDtypeToTrtDtype(const TypeId & ms_dtype)51 static std::variant<bool, nvinfer1::DataType> MsDtypeToTrtDtype(const TypeId &ms_dtype) {
52 static std::map<TypeId, nvinfer1::DataType> type_list = {{TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
53 {TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
54 {TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},
55 {TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32},
56 {TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}};
57 auto iter = type_list.find(ms_dtype);
58 if (iter == type_list.end()) {
59 MS_LOG(WARNING) << "data type not support: " << ms_dtype;
60 return false;
61 }
62 return iter->second;
63 }
64
65 static nvinfer1::Dims MsDimsToTrtDims(const std::vector<size_t> &ms_shape, bool ignore_batch_dim = false) {
66 nvinfer1::Dims trt_dims;
67 size_t offset = ignore_batch_dim ? 1 : 0;
68 for (size_t i = offset; i < ms_shape.size(); ++i) {
69 trt_dims.d[i - offset] = SizeToInt(ms_shape[i]);
70 }
71 trt_dims.nbDims = ms_shape.size() - offset;
72 return trt_dims;
73 }
74
75 static nvinfer1::Dims MsDimsToTrtDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
76 nvinfer1::Dims trt_dims;
77 size_t offset = ignore_batch_dim ? 1 : 0;
78 for (size_t i = offset; i < ms_shape.size(); ++i) {
79 trt_dims.d[i - offset] = LongToInt(ms_shape[i]);
80 }
81 trt_dims.nbDims = ms_shape.size() - offset;
82 return trt_dims;
83 }
84
TrtDimsToMsDims(const nvinfer1::Dims & trt_dims)85 static ShapeVector TrtDimsToMsDims(const nvinfer1::Dims &trt_dims) {
86 ShapeVector shape;
87 std::transform(trt_dims.d, trt_dims.d + trt_dims.nbDims, std::back_inserter(shape),
88 [](const uint32_t &value) { return static_cast<int64_t>(value); });
89 return shape;
90 }
91
IsSameShape(const nvinfer1::Dims & lhs,const nvinfer1::Dims & rhs)92 static bool IsSameShape(const nvinfer1::Dims &lhs, const nvinfer1::Dims &rhs) {
93 if (lhs.nbDims != rhs.nbDims) {
94 return false;
95 }
96
97 for (int32_t i = 0; i < lhs.nbDims; i++) {
98 if (lhs.d[i] != rhs.d[i]) {
99 return false;
100 }
101 }
102
103 return true;
104 }
105 };
106
107 class TrtLogger : public nvinfer1::ILogger {
108 public:
TrtLogger()109 TrtLogger() {
110 log_level_ = MsLogLevel::WARNING; // set default log level to WARNING
111 const char *glog_config = std::getenv("GLOG_v");
112 if (glog_config == nullptr) {
113 return;
114 }
115
116 std::string str_level{glog_config};
117 if (str_level.size() == 1) {
118 int ch = str_level.c_str()[0];
119 ch = ch - '0'; // subtract ASCII code of '0', which is 48
120 if (ch >= mindspore::DEBUG && ch <= mindspore::ERROR) {
121 log_level_ = static_cast<MsLogLevel>(ch);
122 }
123 }
124 }
125 // Redirect Tensor-RT inner log to GLOG
log(Severity severity,const char * msg)126 void log(Severity severity, const char *msg) override {
127 #ifdef USE_GLOG
128 #define google mindspore_private
129 static std::map<Severity, std::tuple<MsLogLevel, int, std::string>> logger_map = {
130 {Severity::kVERBOSE, {MsLogLevel::DEBUG, google::GLOG_INFO, "VERBOSE"}},
131 {Severity::kINFO, {MsLogLevel::INFO, google::GLOG_INFO, "INFO"}},
132 {Severity::kWARNING, {MsLogLevel::WARNING, google::GLOG_WARNING, "WARNING"}},
133 {Severity::kERROR, {MsLogLevel::ERROR, google::GLOG_ERROR, "ERROR"}},
134 {Severity::kINTERNAL_ERROR, {MsLogLevel::ERROR, google::GLOG_ERROR, "INTERNAL ERROR"}}};
135
136 auto iter = logger_map.find(severity);
137 if (iter == logger_map.end()) {
138 google::LogMessage("", 0, google::GLOG_WARNING).stream() << "Unrecognized severity type: " << msg << std::endl;
139 return;
140 }
141
142 auto level = iter->second;
143 // discard log
144 if (std::get<0>(level) < log_level_) {
145 return;
146 }
147
148 google::LogMessage("", 0, std::get<1>(level)).stream()
149 << "[TensorRT " << std::get<2>(level) << "] " << msg << std::endl;
150 #undef google
151 #endif // USE_GLOG
152 }
153
154 private:
155 MsLogLevel log_level_;
156 };
157
// RAII wrapper for TensorRT objects: ties the object's lifetime to a
// shared_ptr whose deleter calls the TensorRT-specific destroy() (TensorRT
// objects must not be released with plain `delete`).
template <typename T>
inline std::shared_ptr<T> TrtPtr(T *raw) {
  auto deleter = [](T *ptr) {
    if (ptr != nullptr) {
      ptr->destroy();
    }
  };
  return std::shared_ptr<T>(raw, deleter);
}
165
// Returns `ret` from the enclosing function when the variant `input` does not
// currently hold the alternative at index `expect` — e.g. used to bail out
// when MsDtypeToTrtDtype yields its `false` error marker (index 0) instead of
// a nvinfer1::DataType (index 1).
#define TRT_VARIANT_CHECK(input, expect, ret) \
  do {                                        \
    if ((input.index()) != (expect)) {        \
      return ret;                             \
    }                                         \
  } while (0)
172 } // namespace mindspore
173 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
174