/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#include "absl/strings/ascii.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

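// Converts between TrtPrecisionMode enum values and their string names
// ("FP32", "FP16", "INT8"). A brief usage sketch, derived from the code
// below rather than from any particular call site:
//   string name;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(TrtPrecisionMode::FP16, &name));
//   // name == "FP16"
// Unknown enum values yield an OutOfRange error; unknown names yield an
// InvalidArgument error.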
Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name) {
  switch (mode) {
    case TrtPrecisionMode::FP32:
      *name = "FP32";
      break;
    case TrtPrecisionMode::FP16:
      *name = "FP16";
      break;
    case TrtPrecisionMode::INT8:
      *name = "INT8";
      break;
    default:
      *name = "UNKNOWN";
      return errors::OutOfRange("Unknown precision mode");
  }
  return Status::OK();
}

Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
  if (name == "FP32") {
    *mode = TrtPrecisionMode::FP32;
  } else if (name == "FP16") {
    *mode = TrtPrecisionMode::FP16;
  } else if (name == "INT8") {
    *mode = TrtPrecisionMode::INT8;
  } else {
    return errors::InvalidArgument("Invalid precision mode name: ", name);
  }
  return Status::OK();
}

#if GOOGLE_CUDA && GOOGLE_TENSORRT

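// The DebugString overloads below format TensorRT and TensorFlow objects as
// human-readable strings for logging and error messages.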
string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&out, dims.d[i]);
    StrAppend(&out, ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const DataType tf_type) {
  switch (tf_type) {
    case DT_FLOAT:
      return "DT_FLOAT";
    case DT_HALF:
      return "DT_HALF";
    case DT_INT32:
      return "DT_INT32";
    case DT_INT8:
      return "DT_INT8";
    default:
      return "Unknown TF DataType";
  }
}

string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    default:
      return "Invalid TRT data type";
  }
}

string DebugString(const TrtPrecisionMode mode) {
  string mode_str;
  TF_CHECK_OK(TrtPrecisionModeToName(mode, &mode_str));
  return StrCat("TrtPrecisionMode::", mode_str);
}

string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const ITensorProxyPtr& tensor) {
  return DebugString(*tensor->trt_tensor());
}

string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

string DebugString(const std::vector<nvinfer1::Dims>& dimvec) {
  return absl::StrCat("[",
                      absl::StrJoin(dimvec, ",",
                                    [](std::string* out, nvinfer1::Dims in) {
                                      out->append(DebugString(in));
                                    }),
                      "]");
}

string DebugString(const std::vector<TensorShape>& shapes) {
  return TensorShapeUtils::ShapeListString(shapes);
}

string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}

// Checks whether actual_shapes are compatible with cached_shapes. This should
// only be used in implicit batch mode (in explicit batch mode one needs to
// check the profile ranges). Therefore implicit batch mode is assumed.
// It is also assumed that both actual_shapes and cached_shapes have been
// verified by TRTEngineOp::VerifyInputShapes, which ensures that the batch
// size for all tensors is the same.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes) {
  auto match_shape = [](const TensorShape& actual_shape,
                        const TensorShape& cached_shape) {
    // Match the rank.
    if (actual_shape.dims() != cached_shape.dims()) return false;
    // Match the batch size. In implicit batch mode cached_shape.dim_size(0) is
    // the max batch size, which can be larger than the actual batch size.
    if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
    // Match remaining dimensions.
    for (int i = 1; i < actual_shape.dims(); ++i) {
      if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
    }
    return true;
  };
  for (int i = 0; i < actual_shapes.size(); ++i) {
    if (!match_shape(actual_shapes[i], cached_shapes[i])) {
      return false;
    }
  }
  return true;
}

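// Collects the dimensions of every input of `network` into *input_shapes,
// one PartialTensorShape per network input.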
Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes) {
  const int n_inputs = network->getNbInputs();
  input_shapes->resize(n_inputs);
  for (int i = 0; i < n_inputs; i++) {
    const ITensorProxyPtr input = network->getInput(i);
    const nvinfer1::Dims input_dim = input->getDimensions();
    TF_RETURN_IF_ERROR(TrtDimsToTensorShape(input_dim, &input_shapes->at(i)));
  }
  return Status::OK();
}

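// Builds a TensorShape from a vector of TensorRT dimension sizes, optionally
// prepending `batch_size` as the outermost dimension.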
Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size) {
  TF_RETURN_IF_ERROR(
      TensorShapeUtils::MakeShape(trt_dims.data(), trt_dims.size(), shape));
  if (batch_size) {
    shape->InsertDim(0, batch_size.value());
  }
  return Status::OK();
}

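// Maps between TensorFlow DataType and TensorRT nvinfer1::DataType. Only
// float32, float16, and int32 are handled here; anything else yields an
// InvalidArgument error.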
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
  switch (tf_type) {
    case DT_FLOAT:
      *trt_type = nvinfer1::DataType::kFLOAT;
      break;
    case DT_HALF:
      *trt_type = nvinfer1::DataType::kHALF;
      break;
    case DT_INT32:
      *trt_type = nvinfer1::DataType::kINT32;
      break;
    default:
      return errors::InvalidArgument("Unsupported tensorflow data type ",
                                     DataTypeString(tf_type));
  }
  return Status::OK();
}

Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
  switch (trt_type) {
    case nvinfer1::DataType::kFLOAT:
      *tf_type = DT_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *tf_type = DT_HALF;
      break;
    case nvinfer1::DataType::kINT32:
      *tf_type = DT_INT32;
      break;
    default:
      return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}

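// Returns the number of input tensors for a single optimization profile of
// `engine`: the total count of input bindings divided by the number of
// profiles.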
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to TensorRT 7 doc: "If the engine has been built for K profiles,
  // the first getNbBindings() / K bindings are used by profile number 0, the
  // following getNbBindings() / K bindings are used by profile number 1 etc."
  // Therefore, to get the number of input tensors, we need to divide by the
  // number of profiles.
  int n_profiles = engine->getNbOptimizationProfiles();
  return n_input / n_profiles;
}

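// Converts between ProfileStrategy enum values and their display names. The
// reverse lookup is case-insensitive, e.g. "range+optimal" maps to
// ProfileStrategy::kRangeOptimal.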
string ProfileStrategyToName(const ProfileStrategy strategy) {
  switch (strategy) {
    case ProfileStrategy::kRange:
      return "Range";
    case ProfileStrategy::kOptimal:
      return "Optimal";
    case ProfileStrategy::kRangeOptimal:
      return "Range+Optimal";
    case ProfileStrategy::kImplicitBatchModeCompatible:
      return "ImplicitBatchModeCompatible";
  }
  return "Unknown";
}

Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy) {
  string name_lowercase(name);
  std::transform(name.begin(), name.end(), name_lowercase.begin(),
                 [](unsigned char c) { return std::tolower(c); });
  if (name_lowercase == "range") {
    *strategy = ProfileStrategy::kRange;
  } else if (name_lowercase == "optimal") {
    *strategy = ProfileStrategy::kOptimal;
  } else if (name_lowercase == "range+optimal") {
    *strategy = ProfileStrategy::kRangeOptimal;
  } else if (name_lowercase == "implicitbatchmodecompatible") {
    *strategy = ProfileStrategy::kImplicitBatchModeCompatible;
  } else {
    return errors::InvalidArgument("Invalid profile strategy: ", name);
  }
  return Status::OK();
}

#endif

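// Returns the device string of the node: the assigned device if one has been
// set, otherwise the requested device.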
absl::string_view GetDeviceName(const Node* node) {
  if (node->has_assigned_device_name()) {
    return node->assigned_device_name();
  }
  return node->requested_device();
}

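// Parses the node's device string into a DeviceNameUtils::ParsedName; returns
// nullopt if the string cannot be parsed as a full device name.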
absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node) {
  absl::string_view device_name = GetDeviceName(node);
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return absl::nullopt;
  }
  return parsed_name;
}

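// Merges two device specifications if they are compatible (soft placement is
// not allowed); returns nullopt when they conflict, or, in the string
// overload, when `b` cannot be parsed.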
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a,
    const DeviceNameUtils::ParsedName& b) {
  DeviceNameUtils::ParsedName merged_name = a;
  if (!DeviceNameUtils::MergeDevNames(&merged_name, b,
                                      /*allow_soft_placement=*/false)
           .ok()) {
    return absl::nullopt;
  }
  return merged_name;
}

absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b) {
  DeviceNameUtils::ParsedName b_parsed_name;
  if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) {
    return absl::nullopt;
  }

  return MergeIfCompatible(a, b_parsed_name);
}

}  // namespace tensorrt
}  // namespace tensorflow