/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#include <algorithm>
#include <cctype>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name) {
  switch (mode) {
    case TrtPrecisionMode::FP32:
      *name = "FP32";
      break;
    case TrtPrecisionMode::FP16:
      *name = "FP16";
      break;
    case TrtPrecisionMode::INT8:
      *name = "INT8";
      break;
    default:
      *name = "UNKNOWN";
      return errors::OutOfRange("Unknown precision mode");
  }
  return Status::OK();
}

Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
  if (name == "FP32") {
    *mode = TrtPrecisionMode::FP32;
  } else if (name == "FP16") {
    *mode = TrtPrecisionMode::FP16;
  } else if (name == "INT8") {
    *mode = TrtPrecisionMode::INT8;
  } else {
    return errors::InvalidArgument("Invalid precision mode name: ", name);
  }
  return Status::OK();
}

#if GOOGLE_CUDA && GOOGLE_TENSORRT

string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&out, dims.d[i]);
    StrAppend(&out, ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const DataType tf_type) {
  switch (tf_type) {
    case DT_FLOAT:
      return "DT_FLOAT";
    case DT_HALF:
      return "DT_HALF";
    case DT_INT32:
      return "DT_INT32";
    case DT_INT8:
      return "DT_INT8";
    default:
      return "Unknown TF DataType";
  }
}

string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    default:
      return "Invalid TRT data type";
  }
}

string DebugString(const TrtPrecisionMode mode) {
  string mode_str;
  TF_CHECK_OK(TrtPrecisionModeToName(mode, &mode_str));
  return StrCat("TrtPrecisionMode::", mode_str);
}

string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const ITensorProxyPtr& tensor) {
  return DebugString(*tensor->trt_tensor());
}

string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

string DebugString(const std::vector<nvinfer1::Dims>& dimvec) {
  return absl::StrCat("[",
                      absl::StrJoin(dimvec, ",",
                                    [](std::string* out, nvinfer1::Dims in) {
                                      out->append(DebugString(in));
                                    }),
                      "]");
}

string DebugString(const std::vector<TensorShape>& shapes) {
  return TensorShapeUtils::ShapeListString(shapes);
}

string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}

// Checks whether actual_shapes are compatible with cached_shapes. This should
// only be used in implicit batch mode (in explicit batch mode one needs to
// check the profile ranges). Therefore implicit batch mode is assumed.
// It is also assumed that both actual_shapes and cached_shapes have been
// verified by TRTEngineOp::VerifyInputShapes, which ensures that the batch
// size is the same for all tensors.
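// Illustrative example of the rule implemented below (hypothetical shapes,
// batch dimension first): actual [3, 224, 224] matches cached [8, 224, 224]
// because the actual batch size 3 fits under the cached max batch size 8 and
// the remaining dimensions are equal, whereas actual [3, 224, 224] does not
// match cached [8, 112, 224] because dimension 1 differs.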
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes) {
  auto match_shape = [](const TensorShape& actual_shape,
                        const TensorShape& cached_shape) {
    // Match the rank.
    if (actual_shape.dims() != cached_shape.dims()) return false;
    // Match the batch size. In implicit batch mode cached_shape.dim_size(0) is
    // the max batch size, which can be larger than the actual batch size.
    if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
    // Match remaining dimensions.
    for (int i = 1; i < actual_shape.dims(); ++i) {
      if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
    }
    return true;
  };
  for (int i = 0; i < actual_shapes.size(); ++i) {
    if (!match_shape(actual_shapes[i], cached_shapes[i])) {
      return false;
    }
  }
  return true;
}

Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes) {
  const int n_inputs = network->getNbInputs();
  input_shapes->resize(n_inputs);
  for (int i = 0; i < n_inputs; i++) {
    const ITensorProxyPtr input = network->getInput(i);
    const nvinfer1::Dims input_dim = input->getDimensions();
    TF_RETURN_IF_ERROR(TrtDimsToTensorShape(input_dim, &input_shapes->at(i)));
  }
  return Status::OK();
}

Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size) {
  TF_RETURN_IF_ERROR(
      TensorShapeUtils::MakeShape(trt_dims.data(), trt_dims.size(), shape));
  if (batch_size) {
    shape->InsertDim(0, batch_size.value());
  }
  return Status::OK();
}

Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
  switch (tf_type) {
    case DT_FLOAT:
      *trt_type = nvinfer1::DataType::kFLOAT;
      break;
    case DT_HALF:
      *trt_type = nvinfer1::DataType::kHALF;
      break;
    case DT_INT32:
      *trt_type = nvinfer1::DataType::kINT32;
      break;
    default:
      return errors::InvalidArgument("Unsupported tensorflow data type ",
                                     DataTypeString(tf_type));
  }
  return Status::OK();
}

Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
  switch (trt_type) {
    case nvinfer1::DataType::kFLOAT:
      *tf_type = DT_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *tf_type = DT_HALF;
      break;
    case nvinfer1::DataType::kINT32:
      *tf_type = DT_INT32;
      break;
    default:
      return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}

int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to the TensorRT 7 documentation: "If the engine has been built
  // for K profiles, the first getNbBindings() / K bindings are used by profile
  // number 0, the following getNbBindings() / K bindings are used by profile
  // number 1 etc." Therefore, to get the number of input tensors, we need to
  // divide the input binding count by the number of profiles.
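  // Illustrative example (hypothetical engine): with 3 optimization profiles,
  // 2 network inputs, and 1 network output there are 3 * (2 + 1) = 9 bindings,
  // 6 of which are inputs, so this function returns 6 / 3 = 2.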
  int n_profiles = engine->getNbOptimizationProfiles();
  return n_input / n_profiles;
}

string ProfileStrategyToName(const ProfileStrategy strategy) {
  switch (strategy) {
    case ProfileStrategy::kRange:
      return "Range";
    case ProfileStrategy::kOptimal:
      return "Optimal";
    case ProfileStrategy::kRangeOptimal:
      return "Range+Optimal";
    case ProfileStrategy::kImplicitBatchModeCompatible:
      return "ImplicitBatchModeCompatible";
  }
  return "Unknown";
}

Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) {
  string name_lowercase(name);
  std::transform(name.begin(), name.end(), name_lowercase.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  if (name_lowercase == "range") {
    *strategy = ProfileStrategy::kRange;
  } else if (name_lowercase == "optimal") {
    *strategy = ProfileStrategy::kOptimal;
  } else if (name_lowercase == "range+optimal") {
    *strategy = ProfileStrategy::kRangeOptimal;
  } else if (name_lowercase == "implicitbatchmodecompatible") {
    *strategy = ProfileStrategy::kImplicitBatchModeCompatible;
  } else {
    return errors::InvalidArgument("Invalid profile strategy: ", name);
  }
  return Status::OK();
}

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

absl::string_view GetDeviceName(const Node* node) {
  if (node->has_assigned_device_name()) {
    return node->assigned_device_name();
  }
  return node->requested_device();
}

absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node) {
  absl::string_view device_name = GetDeviceName(node);
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return absl::nullopt;
  }
  return parsed_name;
}

absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a,
    const DeviceNameUtils::ParsedName& b) {
  DeviceNameUtils::ParsedName merged_name = a;
  if (!DeviceNameUtils::MergeDevNames(&merged_name, b,
                                      /*allow_soft_placement=*/false)
           .ok()) {
    return absl::nullopt;
  }
  return merged_name;
}

absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b) {
  DeviceNameUtils::ParsedName b_parsed_name;
  if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) {
    return absl::nullopt;
  }

  return MergeIfCompatible(a, b_parsed_name);
}

}  // namespace tensorrt
}  // namespace tensorflow