/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

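// Converts a TrtPrecisionMode enum value to its string name ("FP32", "FP16",
// or "INT8"). Returns an OutOfRange error for unrecognized values.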
Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name) {
  switch (mode) {
    case TrtPrecisionMode::FP32:
      *name = "FP32";
      break;
    case TrtPrecisionMode::FP16:
      *name = "FP16";
      break;
    case TrtPrecisionMode::INT8:
      *name = "INT8";
      break;
    default:
      *name = "UNKNOWN";
      return errors::OutOfRange("Unknown precision mode");
  }
  return Status::OK();
}

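// Parses a precision mode name back into the enum; the inverse of
// TrtPrecisionModeToName. For example, TrtPrecisionModeFromName("FP16", &m)
// sets m to TrtPrecisionMode::FP16.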
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
  if (name == "FP32") {
    *mode = TrtPrecisionMode::FP32;
  } else if (name == "FP16") {
    *mode = TrtPrecisionMode::FP16;
  } else if (name == "INT8") {
    *mode = TrtPrecisionMode::INT8;
  } else {
    return errors::InvalidArgument("Invalid precision mode name: ", name);
  }
  return Status::OK();
}

#if GOOGLE_CUDA && GOOGLE_TENSORRT

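// The DebugString overloads below format TRT and TF objects as human-readable
// strings for VLOG output and error messages.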
string DebugString(const nvinfer1::DimensionType type) {
  switch (type) {
    case nvinfer1::DimensionType::kSPATIAL:
      return "kSPATIAL";
    case nvinfer1::DimensionType::kCHANNEL:
      return "kCHANNEL";
    case nvinfer1::DimensionType::kINDEX:
      return "kINDEX";
    case nvinfer1::DimensionType::kSEQUENCE:
      return "kSEQUENCE";
    default:
      return StrCat(static_cast<int>(type), "=unknown");
  }
}

string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&out, dims.d[i]);
    if (VLOG_IS_ON(2)) {
      StrAppend(&out, "[", DebugString(dims.type[i]), "],");
    } else {
      StrAppend(&out, ",");
    }
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const DataType tf_type) {
  switch (tf_type) {
    case DT_FLOAT:
      return "DT_FLOAT";
    case DT_HALF:
      return "DT_HALF";
    case DT_INT32:
      return "DT_INT32";
    case DT_INT8:
      return "DT_INT8";
    default:
      return "Unknown TF DataType";
  }
}

string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    default:
      return "Invalid TRT data type";
  }
}

string DebugString(const TrtPrecisionMode mode) {
  string mode_str;
  TF_CHECK_OK(TrtPrecisionModeToName(mode, &mode_str));
  return StrCat("TrtPrecisionMode::", mode_str);
}

string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

string DebugString(const std::vector<nvinfer1::Dims>& dimvec) {
  return absl::StrCat("[",
                      absl::StrJoin(dimvec, ",",
                                    [](std::string* out, nvinfer1::Dims in) {
                                      out->append(DebugString(in));
                                    }),
                      "]");
}

string DebugString(const std::vector<TensorShape>& shapes) {
  return TensorShapeUtils::ShapeListString(shapes);
}

string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}

// Checks whether actual_shapes are compatible with cached_shapes. This should
// only be used in implicit batch mode (in explicit batch mode one needs to
// check the profile ranges). Therefore implicit batch mode is assumed.
// It is also assumed that both actual_shapes and cached_shapes have been
// verified by TRTEngineOp::VerifyInputShapes, which ensures that the batch
// size is the same for all tensors.
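// For example, an actual shape [2, 3, 4] is compatible with a cached shape
// [5, 3, 4] (equal rank, batch size 2 <= max batch size 5, equal remaining
// dims), whereas [2, 3, 5] is not.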
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes) {
  auto match_shape = [](const TensorShape& actual_shape,
                        const TensorShape& cached_shape) {
    // Match the rank.
    if (actual_shape.dims() != cached_shape.dims()) return false;
    // Match the batch size. In implicit batch mode cached_shape.dim_size(0) is
    // the max batch size, which can be larger than the actual batch size.
    if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
    // Match remaining dimensions.
    for (int i = 1; i < actual_shape.dims(); ++i) {
      if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
    }
    return true;
  };
  for (int i = 0; i < actual_shapes.size(); ++i) {
    if (!match_shape(actual_shapes[i], cached_shapes[i])) {
      return false;
    }
  }
  return true;
}
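
// Fills input_shapes with the shape of every input tensor of the given TRT
// network, in input order.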
Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes) {
  const int n_inputs = network->getNbInputs();
  input_shapes->resize(n_inputs);
  for (int i = 0; i < n_inputs; i++) {
    const nvinfer1::ITensor* input = network->getInput(i);
    const nvinfer1::Dims input_dim = input->getDimensions();
    TF_RETURN_IF_ERROR(TrtDimsToTensorShape(input_dim, &input_shapes->at(i)));
  }
  return Status::OK();
}
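
// Converts a vector of TRT dimension sizes to a TensorShape. If batch_size is
// given, it is inserted as dimension 0.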
Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size) {
  TF_RETURN_IF_ERROR(
      TensorShapeUtils::MakeShape(trt_dims.data(), trt_dims.size(), shape));
  if (batch_size) {
    shape->InsertDim(0, batch_size.value());
  }
  return Status::OK();
}

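// Maps a TF DataType to the corresponding TRT data type. Returns an
// InvalidArgument error for TF types TRT does not support.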
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
  switch (tf_type) {
    case DT_FLOAT:
      *trt_type = nvinfer1::DataType::kFLOAT;
      break;
    case DT_HALF:
      *trt_type = nvinfer1::DataType::kHALF;
      break;
    case DT_INT32:
      *trt_type = nvinfer1::DataType::kINT32;
      break;
    default:
      return errors::InvalidArgument("Unsupported tensorflow data type ",
                                     DataTypeString(tf_type));
  }
  return Status::OK();
}

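// Maps a TRT data type back to the corresponding TF DataType.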
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
  switch (trt_type) {
    case nvinfer1::DataType::kFLOAT:
      *tf_type = DT_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *tf_type = DT_HALF;
      break;
    case nvinfer1::DataType::kINT32:
      *tf_type = DT_INT32;
      break;
    default:
      return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}

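// Returns the number of distinct input tensors of the engine, counting each
// input once even if the engine was built with multiple optimization
// profiles.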
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to the TensorRT 7 doc: "If the engine has been built for K
  // profiles, the first getNbBindings() / K bindings are used by profile
  // number 0, the following getNbBindings() / K bindings are used by profile
  // number 1 etc." Therefore, to get the number of input tensors, we need to
  // divide by the number of profiles.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  int n_profiles = engine->getNbOptimizationProfiles();
#else
  int n_profiles = 1;
#endif
  return n_input / n_profiles;
}

#endif

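// Returns the device name of the node: the assigned device if the node has
// one, otherwise the requested device.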
absl::string_view GetDeviceName(const Node* node) {
  if (node->has_assigned_device_name()) {
    return node->assigned_device_name();
  }
  return node->requested_device();
}

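// Parses the node's device name into its components. Returns absl::nullopt if
// the name cannot be parsed.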
absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node) {
  absl::string_view device_name = GetDeviceName(node);
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return absl::nullopt;
  }
  return parsed_name;
}

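// Merges two device specifications if they are compatible (without soft
// placement); returns absl::nullopt if they conflict.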
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a,
    const DeviceNameUtils::ParsedName& b) {
  DeviceNameUtils::ParsedName merged_name = a;
  if (!DeviceNameUtils::MergeDevNames(&merged_name, b,
                                      /*allow_soft_placement=*/false)
           .ok()) {
    return absl::nullopt;
  }
  return merged_name;
}

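// Same as above, but first parses b from a full device name string.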
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b) {
  DeviceNameUtils::ParsedName b_parsed_name;
  if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) {
    return absl::nullopt;
  }

  return MergeIfCompatible(a, b_parsed_name);
}

}  // namespace tensorrt
}  // namespace tensorflow