• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
19 
20 #include <utility>
21 #include <map>
22 #include <vector>
23 #include <tuple>
24 #include <algorithm>
25 #include <memory>
26 #include <string>
27 #include <variant>
28 #include <NvInfer.h>
29 #include "utils/log_adapter.h"
30 #include "utils/singleton.h"
31 #include "utils/convert_utils_base.h"
32 #include "utils/shape_utils.h"
33 #include "ir/dtype/type.h"
34 
35 namespace mindspore {
36 class TrtUtils {
37  public:
TrtDtypeToMsDtype(const nvinfer1::DataType & trt_dtype)38   static TypeId TrtDtypeToMsDtype(const nvinfer1::DataType &trt_dtype) {
39     static std::map<nvinfer1::DataType, TypeId> type_list = {{nvinfer1::DataType::kFLOAT, TypeId::kNumberTypeFloat32},
40                                                              {nvinfer1::DataType::kHALF, TypeId::kNumberTypeFloat16},
41                                                              {nvinfer1::DataType::kINT8, TypeId::kNumberTypeInt8},
42                                                              {nvinfer1::DataType::kINT32, TypeId::kNumberTypeInt32}};
43 
44     auto iter = type_list.find(trt_dtype);
45     if (iter == type_list.end()) {
46       MS_LOG(EXCEPTION) << "Invalid Tensor-RT dtype: " << trt_dtype;
47     }
48     return iter->second;
49   }
50 
MsDtypeToTrtDtype(const TypeId & ms_dtype)51   static std::variant<bool, nvinfer1::DataType> MsDtypeToTrtDtype(const TypeId &ms_dtype) {
52     static std::map<TypeId, nvinfer1::DataType> type_list = {{TypeId::kNumberTypeFloat32, nvinfer1::DataType::kFLOAT},
53                                                              {TypeId::kNumberTypeFloat16, nvinfer1::DataType::kHALF},
54                                                              {TypeId::kNumberTypeInt8, nvinfer1::DataType::kINT8},
55                                                              {TypeId::kNumberTypeInt, nvinfer1::DataType::kINT32},
56                                                              {TypeId::kNumberTypeInt32, nvinfer1::DataType::kINT32}};
57     auto iter = type_list.find(ms_dtype);
58     if (iter == type_list.end()) {
59       MS_LOG(WARNING) << "data type not support: " << ms_dtype;
60       return false;
61     }
62     return iter->second;
63   }
64 
65   static nvinfer1::Dims MsDimsToTrtDims(const std::vector<size_t> &ms_shape, bool ignore_batch_dim = false) {
66     nvinfer1::Dims trt_dims;
67     size_t offset = ignore_batch_dim ? 1 : 0;
68     for (size_t i = offset; i < ms_shape.size(); ++i) {
69       trt_dims.d[i - offset] = SizeToInt(ms_shape[i]);
70     }
71     trt_dims.nbDims = ms_shape.size() - offset;
72     return trt_dims;
73   }
74 
75   static nvinfer1::Dims MsDimsToTrtDims(const ShapeVector &ms_shape, bool ignore_batch_dim = false) {
76     nvinfer1::Dims trt_dims;
77     size_t offset = ignore_batch_dim ? 1 : 0;
78     for (size_t i = offset; i < ms_shape.size(); ++i) {
79       trt_dims.d[i - offset] = LongToInt(ms_shape[i]);
80     }
81     trt_dims.nbDims = ms_shape.size() - offset;
82     return trt_dims;
83   }
84 
TrtDimsToMsDims(const nvinfer1::Dims & trt_dims)85   static ShapeVector TrtDimsToMsDims(const nvinfer1::Dims &trt_dims) {
86     ShapeVector shape;
87     std::transform(trt_dims.d, trt_dims.d + trt_dims.nbDims, std::back_inserter(shape),
88                    [](const uint32_t &value) { return static_cast<int64_t>(value); });
89     return shape;
90   }
91 
IsSameShape(const nvinfer1::Dims & lhs,const nvinfer1::Dims & rhs)92   static bool IsSameShape(const nvinfer1::Dims &lhs, const nvinfer1::Dims &rhs) {
93     if (lhs.nbDims != rhs.nbDims) {
94       return false;
95     }
96 
97     for (int32_t i = 0; i < lhs.nbDims; i++) {
98       if (lhs.d[i] != rhs.d[i]) {
99         return false;
100       }
101     }
102 
103     return true;
104   }
105 };
106 
// Bridges TensorRT's internal logger (nvinfer1::ILogger) to MindSpore's
// GLOG-based logging, filtering messages by a verbosity level taken from the
// GLOG_v environment variable.
class TrtLogger : public nvinfer1::ILogger {
 public:
  // Initializes the filter level. Defaults to WARNING; if GLOG_v holds a
  // single digit that maps into [kDebug, kException], that level is used.
  TrtLogger() {
    log_level_ = MsLogLevel::kWarning;  // set default log level to WARNING
    const char *glog_config = std::getenv("GLOG_v");
    if (glog_config == nullptr) {
      return;  // env var unset: keep the WARNING default
    }

    std::string str_level{glog_config};
    if (str_level.size() == 1) {  // only single-character values are recognized
      int ch = str_level.c_str()[0];
      ch = ch - '0';  // subtract ASCII code of '0', which is 48
      if (ch >= MsLogLevel::kDebug && ch <= MsLogLevel::kException) {
        log_level_ = static_cast<MsLogLevel>(ch);
      }
    }
  }
  // Redirect Tensor-RT inner log to GLOG
  // No-op unless built with USE_GLOG. Messages whose mapped MindSpore level is
  // below log_level_ are discarded; unknown severities are logged as warnings.
  void log(Severity severity, const char *msg) noexcept override {
#ifdef USE_GLOG
// NOTE(review): remaps the `google` namespace token — presumably glog is
// vendored under `mindspore_private` in this build; confirm against the
// build configuration.
#define google mindspore_private
    // severity -> (MindSpore level, glog level, human-readable label)
    static std::map<Severity, std::tuple<MsLogLevel, int, std::string>> logger_map = {
      {Severity::kVERBOSE, {MsLogLevel::kDebug, google::GLOG_INFO, "VERBOSE"}},
      {Severity::kINFO, {MsLogLevel::kInfo, google::GLOG_INFO, "INFO"}},
      {Severity::kWARNING, {MsLogLevel::kWarning, google::GLOG_WARNING, "WARNING"}},
      {Severity::kERROR, {MsLogLevel::kError, google::GLOG_ERROR, "ERROR"}},
      {Severity::kINTERNAL_ERROR, {MsLogLevel::kError, google::GLOG_ERROR, "INTERNAL ERROR"}}};

    // Indices of the tuple elements stored in logger_map.
    static const size_t kMsLogLevelIndex = 0;
    static const size_t kGoogleLogLevelIndex = 1;
    static const size_t kLogLevelDescriptionIndex = 2;

    auto iter = logger_map.find(severity);
    if (iter == logger_map.end()) {
      google::LogMessage("", 0, google::GLOG_WARNING).stream() << "Unrecognized severity type: " << msg << std::endl;
      return;
    }

    auto level = iter->second;
    // discard log
    if (std::get<kMsLogLevelIndex>(level) < log_level_) {
      return;
    }

    google::LogMessage("", 0, std::get<kGoogleLogLevelIndex>(level)).stream()
      << "[TensorRT " << std::get<kLogLevelDescriptionIndex>(level) << "] " << msg << std::endl;
#undef google
#endif  // USE_GLOG
  }

 private:
  MsLogLevel log_level_;  // minimum MindSpore level that gets forwarded to glog
};
161 
// Wrap a raw TensorRT object in a std::shared_ptr whose custom deleter calls
// the object's destroy() method, so the object is released automatically when
// the last reference goes out of scope (RAII — avoids TensorRT object leaks).
template <typename T>
inline std::shared_ptr<T> TrtPtr(T *obj) {
  auto destroy_deleter = [](T *ptr) {
    if (ptr != nullptr) {
      ptr->destroy();
    }
  };
  return std::shared_ptr<T>(obj, destroy_deleter);
}
169 
// Early-return `ret` from the enclosing function when the std::variant `input`
// does not currently hold the alternative at index `expect` (e.g. when a
// conversion helper stored its failure flag at index 0 instead of a value).
#define TRT_VARIANT_CHECK(input, expect, ret) \
  do {                                        \
    if ((input.index()) != (expect)) {        \
      return ret;                             \
    }                                         \
  } while (0)
176 }  // namespace mindspore
177 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_TRT_UTILS_H_
178