/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tensorrt {

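// Converts a TrtPrecisionMode enum value to its string name ("FP32", "FP16",
// or "INT8"). Returns an OutOfRange error for unrecognized modes.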
Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name) {
  switch (mode) {
    case TrtPrecisionMode::FP32:
      *name = "FP32";
      break;
    case TrtPrecisionMode::FP16:
      *name = "FP16";
      break;
    case TrtPrecisionMode::INT8:
      *name = "INT8";
      break;
    default:
      *name = "UNKNOWN";
      return errors::OutOfRange("Unknown precision mode");
  }
  return Status::OK();
}

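// Parses a precision mode name back into a TrtPrecisionMode. Returns an
// InvalidArgument error if the name is not one of "FP32", "FP16", or "INT8".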
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
  if (name == "FP32") {
    *mode = TrtPrecisionMode::FP32;
  } else if (name == "FP16") {
    *mode = TrtPrecisionMode::FP16;
  } else if (name == "INT8") {
    *mode = TrtPrecisionMode::INT8;
  } else {
    return errors::InvalidArgument("Invalid precision mode name: ", name);
  }
  return Status::OK();
}

#if GOOGLE_CUDA && GOOGLE_TENSORRT

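// The DebugString overloads below render TensorRT and TensorFlow objects as
// human-readable strings for logging and error messages.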
string DebugString(const nvinfer1::DimensionType type) {
  switch (type) {
    case nvinfer1::DimensionType::kSPATIAL:
      return "kSPATIAL";
    case nvinfer1::DimensionType::kCHANNEL:
      return "kCHANNEL";
    case nvinfer1::DimensionType::kINDEX:
      return "kINDEX";
    case nvinfer1::DimensionType::kSEQUENCE:
      return "kSEQUENCE";
    default:
      return StrCat(static_cast<int>(type), "=unknown");
  }
}

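// Formats an nvinfer1::Dims as a string; at VLOG level 2 and above, each
// dimension is additionally annotated with its DimensionType.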
string DebugString(const nvinfer1::Dims& dims) {
  string out = StrCat("nvinfer1::Dims(nbDims=", dims.nbDims, ", d=");
  for (int i = 0; i < dims.nbDims; ++i) {
    StrAppend(&out, dims.d[i]);
    if (VLOG_IS_ON(2)) {
      StrAppend(&out, "[", DebugString(dims.type[i]), "],");
    } else {
      StrAppend(&out, ",");
    }
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const DataType tf_type) {
  switch (tf_type) {
    case DT_FLOAT:
      return "DT_FLOAT";
    case DT_HALF:
      return "DT_HALF";
    case DT_INT32:
      return "DT_INT32";
    case DT_INT8:
      return "DT_INT8";
    default:
      return "Unknown TF DataType";
  }
}

string DebugString(const nvinfer1::DataType trt_dtype) {
  switch (trt_dtype) {
    case nvinfer1::DataType::kFLOAT:
      return "kFLOAT";
    case nvinfer1::DataType::kHALF:
      return "kHALF";
    case nvinfer1::DataType::kINT8:
      return "kINT8";
    case nvinfer1::DataType::kINT32:
      return "kINT32";
    default:
      return "Invalid TRT data type";
  }
}

string DebugString(const TrtPrecisionMode mode) {
  string mode_str;
  TF_CHECK_OK(TrtPrecisionModeToName(mode, &mode_str));
  return StrCat("TrtPrecisionMode::", mode_str);
}

string DebugString(const nvinfer1::Permutation& permutation, int len) {
  string out = "nvinfer1::Permutation(";
  for (int i = 0; i < len; ++i) {
    StrAppend(&out, permutation.order[i], ",");
  }
  StrAppend(&out, ")");
  return out;
}

string DebugString(const nvinfer1::ITensor& tensor) {
  return StrCat("nvinfer1::ITensor(@", reinterpret_cast<uintptr_t>(&tensor),
                ", name=", tensor.getName(),
                ", dtype=", DebugString(tensor.getType()),
                ", dims=", DebugString(tensor.getDimensions()), ")");
}

string DebugString(const std::vector<nvinfer1::Dims>& dimvec) {
  return absl::StrCat("[",
                      absl::StrJoin(dimvec, ",",
                                    [](std::string* out, nvinfer1::Dims in) {
                                      out->append(DebugString(in));
                                    }),
                      "]");
}

string DebugString(const std::vector<TensorShape>& shapes) {
  return TensorShapeUtils::ShapeListString(shapes);
}

string DebugString(const std::vector<PartialTensorShape>& shapes) {
  return PartialTensorShapeUtils::PartialShapeListString(shapes);
}

// Checks whether actual_shapes are compatible with cached_shapes. This should
// only be used in implicit batch mode, which is therefore assumed here; in
// explicit batch mode one needs to check the profile ranges instead. It is
// also assumed that both actual_shapes and cached_shapes have been verified
// by TRTEngineOp::VerifyInputShapes, which ensures that the batch size is the
// same for all tensors.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes) {
  auto match_shape = [](const TensorShape& actual_shape,
                        const TensorShape& cached_shape) {
    // Match the rank.
    if (actual_shape.dims() != cached_shape.dims()) return false;
    // Match the batch size. In implicit batch mode cached_shape.dim_size(0) is
    // the max batch size, which can be larger than the actual batch size.
    if (actual_shape.dim_size(0) > cached_shape.dim_size(0)) return false;
    // Match remaining dimensions.
    for (int i = 1; i < actual_shape.dims(); ++i) {
      if (actual_shape.dim_size(i) != cached_shape.dim_size(i)) return false;
    }
    return true;
  };
  for (int i = 0; i < actual_shapes.size(); ++i) {
    if (!match_shape(actual_shapes[i], cached_shapes[i])) {
      return false;
    }
  }
  return true;
}
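
// Collects the (partial) shapes of all input tensors of the given network
// into input_shapes, in input order.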
Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes) {
  const int n_inputs = network->getNbInputs();
  input_shapes->resize(n_inputs);
  for (int i = 0; i < n_inputs; i++) {
    const nvinfer1::ITensor* input = network->getInput(i);
    const nvinfer1::Dims input_dim = input->getDimensions();
    TF_RETURN_IF_ERROR(TrtDimsToTensorShape(input_dim, &input_shapes->at(i)));
  }
  return Status::OK();
}
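
// Converts a vector of TRT dimension sizes to a TensorShape. If batch_size is
// given, it is inserted as the leading (batch) dimension.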
Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size) {
  TF_RETURN_IF_ERROR(
      TensorShapeUtils::MakeShape(trt_dims.data(), trt_dims.size(), shape));
  if (batch_size) {
    shape->InsertDim(0, batch_size.value());
  }
  return Status::OK();
}

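// Maps a TensorFlow DataType to the corresponding TensorRT data type. Only
// DT_FLOAT, DT_HALF, and DT_INT32 are supported; other types yield an
// InvalidArgument error.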
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type) {
  switch (tf_type) {
    case DT_FLOAT:
      *trt_type = nvinfer1::DataType::kFLOAT;
      break;
    case DT_HALF:
      *trt_type = nvinfer1::DataType::kHALF;
      break;
    case DT_INT32:
      *trt_type = nvinfer1::DataType::kINT32;
      break;
    default:
      return errors::InvalidArgument("Unsupported tensorflow data type ",
                                     DataTypeString(tf_type));
  }
  return Status::OK();
}

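// Maps a TensorRT data type back to the corresponding TensorFlow DataType;
// the inverse of TfTypeToTrtType over the types both support.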
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type) {
  switch (trt_type) {
    case nvinfer1::DataType::kFLOAT:
      *tf_type = DT_FLOAT;
      break;
    case nvinfer1::DataType::kHALF:
      *tf_type = DT_HALF;
      break;
    case nvinfer1::DataType::kINT32:
      *tf_type = DT_INT32;
      break;
    default:
      return errors::InvalidArgument("Invalid TRT data type");
  }
  return Status::OK();
}

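// Returns the number of input tensors of the engine per optimization profile,
// i.e. the total number of input bindings divided by the number of profiles.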
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
  int n_bindings = engine->getNbBindings();
  int n_input = 0;
  for (int i = 0; i < n_bindings; i++) {
    if (engine->bindingIsInput(i)) n_input++;
  }
  // According to the TensorRT 7 doc: "If the engine has been built for K
  // profiles, the first getNbBindings() / K bindings are used by profile
  // number 0, the following getNbBindings() / K bindings are used by profile
  // number 1 etc." Therefore, to get the number of input tensors, we need to
  // divide by the number of profiles.
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  int n_profiles = engine->getNbOptimizationProfiles();
#else
  int n_profiles = 1;
#endif
  return n_input / n_profiles;
}

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

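// Returns the device name of the node: the assigned device if one has been
// set, otherwise the requested device.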
absl::string_view GetDeviceName(const Node* node) {
  if (node->has_assigned_device_name()) {
    return node->assigned_device_name();
  }
  return node->requested_device();
}

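// Parses the node's device name into its components. Returns absl::nullopt if
// the name cannot be parsed.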
absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node) {
  absl::string_view device_name = GetDeviceName(node);
  DeviceNameUtils::ParsedName parsed_name;
  if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) {
    return absl::nullopt;
  }
  return parsed_name;
}

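// Merges the two device names if they are compatible (soft placement is not
// allowed); returns absl::nullopt if they cannot be merged.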
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a,
    const DeviceNameUtils::ParsedName& b) {
  DeviceNameUtils::ParsedName merged_name = a;
  if (!DeviceNameUtils::MergeDevNames(&merged_name, b,
                                      /*allow_soft_placement=*/false)
           .ok()) {
    return absl::nullopt;
  }
  return merged_name;
}

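// Overload of the above that first parses the device name string b; returns
// absl::nullopt if b cannot be parsed.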
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b) {
  DeviceNameUtils::ParsedName b_parsed_name;
  if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) {
    return absl::nullopt;
  }

  return MergeIfCompatible(a, b_parsed_name);
}

}  // namespace tensorrt
}  // namespace tensorflow