/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace tensorrt {

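// Converts a TrtPrecisionMode enum value to its canonical string name
// ("FP32", "FP16", or "INT8"). Returns an OutOfRange error for any other
// value.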
Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) {
  switch (mode) {
    case TrtPrecisionMode::FP32:
      *name = "FP32";
      break;
    case TrtPrecisionMode::FP16:
      *name = "FP16";
      break;
    case TrtPrecisionMode::INT8:
      *name = "INT8";
      break;
    default:
      return errors::OutOfRange("Unknown precision mode");
  }
  return Status::OK();
}

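// Parses a precision mode name produced by TrtPrecisionModeToName back into
// the corresponding enum value. Returns an InvalidArgument error for
// unrecognized names.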
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
  if (name == "FP32") {
    *mode = TrtPrecisionMode::FP32;
  } else if (name == "FP16") {
    *mode = TrtPrecisionMode::FP16;
  } else if (name == "INT8") {
    *mode = TrtPrecisionMode::INT8;
  } else {
    return errors::InvalidArgument("Invalid precision mode name: ", name);
  }
  return Status::OK();
}

#if GOOGLE_CUDA && GOOGLE_TENSORRT
using absl::StrAppend;
using absl::StrCat;

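// The DebugString overloads below render TensorRT types as human-readable
// strings for logging and error messages.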
string DebugString(const nvinfer1::DimensionType type) {
  switch (type) {
    case nvinfer1::DimensionType::kSPATIAL:
      return "kSPATIAL";
    case nvinfer1::DimensionType::kCHANNEL:
      return "kCHANNEL";
    case nvinfer1::DimensionType::kINDEX:
      return "kINDEX";
    case nvinfer1::DimensionType::kSEQUENCE:
      return "kSEQUENCE";
    default:
      return StrCat(static_cast<int>(type), "=unknown");
  }
}

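// Formats an nvinfer1::Dims as a string; the per-dimension DimensionType is
// appended only when verbose logging (VLOG level 2) is enabled.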
string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&out, dims.d[i]);
    if (VLOG_IS_ON(2)) {
      StrAppend(&out, "[", DebugString(dims.type[i]), "],");
    } else {
      StrAppend(&out, ",");
    }
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    default:
      return "Invalid TRT data type";
  }
}

string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

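// Returns the TensorRT version this binary was built against, as
// "major.minor.patch". Returns "0.0.0" when built without TensorRT support.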
string GetLinkedTensorRTVersion() {
  int major, minor, patch;
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  major = NV_TENSORRT_MAJOR;
  minor = NV_TENSORRT_MINOR;
  patch = NV_TENSORRT_PATCH;
#else
  major = 0;
  minor = 0;
  patch = 0;
#endif
  return absl::StrCat(major, ".", minor, ".", patch);
}

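// Returns the version of the TensorRT library loaded at runtime, as
// "major.minor.patch". getInferLibVersion() reports the version packed as
// major * 1000 + minor * 100 + patch. Returns "0.0.0" when built without
// TensorRT support.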
string GetLoadedTensorRTVersion() {
  int major, minor, patch;
#if GOOGLE_CUDA && GOOGLE_TENSORRT
  int ver = getInferLibVersion();
  major = ver / 1000;
  ver = ver - major * 1000;
  minor = ver / 100;
  patch = ver - minor * 100;
#else
  major = 0;
  minor = 0;
  patch = 0;
#endif
  return absl::StrCat(major, ".", minor, ".", patch);
}

}  // namespace tensorrt
}  // namespace tensorflow