1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
17
18 #include "tensorflow/core/lib/core/errors.h"
19 #include "tensorflow/core/lib/core/status.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22
23 namespace tensorflow {
24 namespace tensorrt {
25
TrtPrecisionModeToName(TrtPrecisionMode mode,string * name)26 Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) {
27 switch (mode) {
28 case TrtPrecisionMode::FP32:
29 *name = "FP32";
30 break;
31 case TrtPrecisionMode::FP16:
32 *name = "FP16";
33 break;
34 case TrtPrecisionMode::INT8:
35 *name = "INT8";
36 break;
37 default:
38 return errors::OutOfRange("Unknown precision mode");
39 }
40 return Status::OK();
41 }
42
TrtPrecisionModeFromName(const string & name,TrtPrecisionMode * mode)43 Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
44 if (name == "FP32") {
45 *mode = TrtPrecisionMode::FP32;
46 } else if (name == "FP16") {
47 *mode = TrtPrecisionMode::FP16;
48 } else if (name == "INT8") {
49 *mode = TrtPrecisionMode::INT8;
50 } else {
51 return errors::InvalidArgument("Invalid precision mode name: ", name);
52 }
53 return Status::OK();
54 }
55
56 #if GOOGLE_CUDA && GOOGLE_TENSORRT
57 using absl::StrAppend;
58 using absl::StrCat;
59
DebugString(const nvinfer1::DimensionType type)60 string DebugString(const nvinfer1::DimensionType type) {
61 switch (type) {
62 case nvinfer1::DimensionType::kSPATIAL:
63 return "kSPATIAL";
64 case nvinfer1::DimensionType::kCHANNEL:
65 return "kCHANNEL";
66 case nvinfer1::DimensionType::kINDEX:
67 return "kINDEX";
68 case nvinfer1::DimensionType::kSEQUENCE:
69 return "kSEQUENCE";
70 default:
71 return StrCat(static_cast<int>(type), "=unknown");
72 }
73 }
74
DebugString(const nvinfer1::Dims & dims)75 string DebugString(const nvinfer1::Dims& dims) {
76 string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
77 for (int i = 0; i < dims.nbDims; ++i) {
78 StrAppend(&out, dims.d[i]);
79 if (VLOG_IS_ON(2)) {
80 StrAppend(&out, "[", DebugString(dims.type[i]), "],");
81 } else {
82 StrAppend(&out, ",");
83 }
84 }
85 StrAppend(&out, ")");
86 return out;
87 }
88
DebugString(const nvinfer1::DataType trt_dtype)89 string DebugString(const nvinfer1::DataType trt_dtype) {
90 switch (trt_dtype) {
91 case nvinfer1::DataType::kFLOAT:
92 return "kFLOAT";
93 case nvinfer1::DataType::kHALF:
94 return "kHALF";
95 case nvinfer1::DataType::kINT8:
96 return "kINT8";
97 case nvinfer1::DataType::kINT32:
98 return "kINT32";
99 default:
100 return "Invalid TRT data type";
101 }
102 }
103
DebugString(const nvinfer1::Permutation & permutation,int len)104 string DebugString(const nvinfer1::Permutation& permutation, int len) {
105 string out = "nvinfer1::Permutation(";
106 for (int i = 0; i < len; ++i) {
107 StrAppend(&out, permutation.order[i], ",");
108 }
109 StrAppend(&out, ")");
110 return out;
111 }
112
DebugString(const nvinfer1::ITensor & tensor)113 string DebugString(const nvinfer1::ITensor& tensor) {
114 return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
115 ", name=", tensor.getName(),
116 ", dtype=", DebugString(tensor.getType()),
117 ", dims=", DebugString(tensor.getDimensions()), ")");
118 }
119
120 #endif
121
GetLinkedTensorRTVersion()122 string GetLinkedTensorRTVersion() {
123 int major, minor, patch;
124 #if GOOGLE_CUDA && GOOGLE_TENSORRT
125 major = NV_TENSORRT_MAJOR;
126 minor = NV_TENSORRT_MINOR;
127 patch = NV_TENSORRT_PATCH;
128 #else
129 major = 0;
130 minor = 0;
131 patch = 0;
132 #endif
133 return absl::StrCat(major, ".", minor, ".", patch);
134 }
135
GetLoadedTensorRTVersion()136 string GetLoadedTensorRTVersion() {
137 int major, minor, patch;
138 #if GOOGLE_CUDA && GOOGLE_TENSORRT
139 int ver = getInferLibVersion();
140 major = ver / 1000;
141 ver = ver - major * 1000;
142 minor = ver / 100;
143 patch = ver - minor * 100;
144 #else
145 major = 0;
146 minor = 0;
147 patch = 0;
148 #endif
149 return absl::StrCat(major, ".", minor, ".", patch);
150 }
151
152 } // namespace tensorrt
153 } // namespace tensorflow
154