1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
18
#include <algorithm>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
29
30 #if GOOGLE_CUDA && GOOGLE_TENSORRT
31 #include "third_party/tensorrt/NvInfer.h"
32 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
33
34 namespace tensorflow {
35 namespace tensorrt {
36
// Name of the attribute that holds the output (destination) dtype of a Cast
// node.
static constexpr char kCastOutputTypeAttrName[] = "DstT";

// Name prefixes for TensorRT input and output tensors ("PH" presumably stands
// for placeholder).
struct IONamePrefixes {
  static constexpr const char* const kInputPHName{"TensorRTInputPH_"};
  static constexpr const char* const kOutputPHName{"TensorRTOutputPH_"};
};
44
// Deleter that releases an object through its own destroy() member function
// rather than `delete`. A null pointer is a no-op.
template <typename T>
struct TrtDestroyer {
  void operator()(T* obj) {
    if (obj == nullptr) return;
    obj->destroy();
  }
};

// std::unique_ptr whose owned object is released via TrtDestroyer, i.e. by
// calling destroy() on it.
template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
54
// Precision modes FP32 / FP16 / INT8 (presumably the numeric precision a
// TensorRT engine is built with — confirm against the converter).
enum class TrtPrecisionMode { FP32, FP16, INT8 };

// Writes a string name for `mode` into `*name`. Defined out of line; the exact
// spellings and the error returned for an unexpected value are not visible
// here.
Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name);

// Parses `name` into `*mode`; presumably the inverse of
// TrtPrecisionModeToName (verify in the definition). Defined out of line.
Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
60
61 // Define a hash function for vector<TensorShape> because it is used as the key
62 // for the engine cache.
63 struct VectorTensorShapeHasher {
operatorVectorTensorShapeHasher64 std::size_t operator()(const std::vector<TensorShape>& key) const {
65 return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
66 }
67 };
68
69 #if GOOGLE_CUDA && GOOGLE_TENSORRT
70
71 using absl::StrAppend;
72 using absl::StrCat;
73
// Evaluates to true iff the TensorRT version described by the
// NV_TENSORRT_{MAJOR,MINOR,PATCH,BUILD} macros is >= major.minor.patch.build.
// Every parameter is parenthesized so that argument expressions (e.g.
// `cond ? 6 : 8`) are compared as a whole instead of binding to the
// surrounding `>` / `==` operators.
#define IS_TRT_VERSION_GE(major, minor, patch, build)                   \
  ((NV_TENSORRT_MAJOR > (major)) ||                                     \
   (NV_TENSORRT_MAJOR == (major) && NV_TENSORRT_MINOR > (minor)) ||     \
   (NV_TENSORRT_MAJOR == (major) && NV_TENSORRT_MINOR == (minor) &&     \
    NV_TENSORRT_PATCH > (patch)) ||                                     \
   (NV_TENSORRT_MAJOR == (major) && NV_TENSORRT_MINOR == (minor) &&     \
    NV_TENSORRT_PATCH == (patch) && NV_TENSORRT_BUILD >= (build)))
81
// This utility template converts an arithmetic type to a string. This function
// is necessary to allow the following function to behave recursively:
// `string DebugString(const std::vector<CType>&)`.
template <typename CType, typename = typename std::enable_if<
                              std::is_arithmetic<CType>::value, CType>::type>
string DebugString(const CType& el) {
  string el_str = std::to_string(el);
  // Prettify std::to_string, which formats floating-point values with six
  // fractional digits (e.g. "1.500000" instead of "1.5"), by removing the
  // trailing zeros of the fractional part. This is only done when a decimal
  // point is present: integral values such as 10, 100 or 0 end in significant
  // zeros and must be returned unchanged.
  if (el_str.find('.') != string::npos) {
    // find_last_not_of cannot return npos here because '.' is present.
    el_str.erase(el_str.find_last_not_of('0') + 1, string::npos);
  }
  return el_str;
}
94 // This utility template converts nested vectors to a string for debug purposes.
95 template <typename CType>
DebugString(const std::vector<CType> & vector)96 string DebugString(const std::vector<CType>& vector) {
97 string tmp_s = "";
98 for (const auto el : vector) {
99 StrAppend(&tmp_s, StrCat(DebugString(el), ", "));
100 }
101 return StrCat("{", tmp_s.substr(0, tmp_s.length() - 2), "}");
102 }
// DebugString overloads for TensorRT and TensorFlow types. They are defined
// out of line; only their signatures are visible in this header.
string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const TrtPrecisionMode mode);
string DebugString(const DataType tf_type);
// NOTE(review): `len` is presumably the number of leading entries of
// `permutation` to print — confirm in the definition.
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);
string DebugString(const std::vector<nvinfer1::Dims>& dimvec);
string DebugString(const std::vector<TensorShape>& shapes);
string DebugString(const std::vector<PartialTensorShape>& shapes);
113
HasStaticShape(const nvinfer1::Dims & dims)114 inline bool HasStaticShape(const nvinfer1::Dims& dims) {
115 if (dims.nbDims < 0) return false;
116 for (int d = 0; d < dims.nbDims; ++d) {
117 if (dims.d[d] < 0) return false;
118 }
119 return true;
120 }
121
// Returns true if no entry of `dims` is negative (negative entries are treated
// as unknown dimensions, as in the nvinfer1::Dims overload). An empty vector
// is considered static. Takes the vector by const reference to avoid copying
// it on every call.
inline bool HasStaticShape(const std::vector<int>& dims) {
  return std::none_of(dims.begin(), dims.end(), [](int d) { return d < 0; });
}
125
126 // Returns whether a shape is compatible with a TRT shape tensor.
127 template <typename TensorShapeType>
IsTrtShapeTensorCompatible(const TensorShapeType & shape)128 inline bool IsTrtShapeTensorCompatible(const TensorShapeType& shape) {
129 return (
130 shape.dims() == 0 ||
131 (shape.dims() == 1 && shape.num_elements() <= nvinfer1::Dims::MAX_DIMS));
132 }
133
134 // Returns whether a TF tensor could be interpreted as a TRT shape tensor.
IsTrtShapeTensorCompatible(const Tensor & tensor)135 inline bool IsTrtShapeTensorCompatible(const Tensor& tensor) {
136 return tensor.dtype() == DT_INT32 &&
137 IsTrtShapeTensorCompatible(tensor.shape());
138 }
139
140 template <typename TensorShapeType>
TensorShapeToTrtDims(const TensorShapeType & shape,bool ignore_first_dim)141 inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
142 bool ignore_first_dim) {
143 nvinfer1::Dims trt_dims;
144 const int offset = (ignore_first_dim ? 1 : 0);
145 for (int i = offset; i < shape.dims(); i++) {
146 trt_dims.d[i - offset] = shape.dim_size(i);
147 }
148 trt_dims.nbDims = shape.dims() - offset;
149 return trt_dims;
150 }
151
// Fills `*input_shapes` with shapes for the inputs of `network`. Defined out
// of line; how unknown dimensions are represented is not visible here.
Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes);

// Converts TRT dims given as a plain vector into `*shape`. Presumably mirrors
// the nvinfer1::Dims overload, including prepending `batch_size` when it is
// set — verify in the out-of-line definition.
Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size = absl::nullopt);
158
159 template <typename TensorShapeType>
160 Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
161 TensorShapeType* shape,
162 absl::optional<int> batch_size = absl::nullopt) {
163 TF_RETURN_IF_ERROR(
164 TensorShapeUtils::MakeShape(trt_dims.d, trt_dims.nbDims, shape));
165 if (batch_size) {
166 shape->InsertDim(0, batch_size.value());
167 }
168 return Status::OK();
169 }
170
// Conversions between TensorFlow and TensorRT data types. Defined out of line;
// presumably they return an error Status for types without a counterpart —
// confirm in the definitions.
Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);

// Returns true if an engine built for cached_shapes can also run actual_shapes.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes);

// Returns the number of inputs for the engine, which also corresponds to the
// number of input tensors for the network. This can differ from the number of
// input bindings, because the total number of input bindings equals the number
// of profiles times the number of engine inputs.
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine);

// Returns the string representation for the assigned device or the requested
// device of the given node.
absl::string_view GetDeviceName(const Node* node);

// Returns the ParsedName representation for the assigned device or the
// requested device string of the given node. If the device string is invalid,
// returns absl::nullopt.
absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node);

// If the two given device assignments are compatible, returns the merge of the
// two assignments. Otherwise, returns absl::nullopt.
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b);
// Similar to the above, except that the second device assignment is represented
// by a string_view.
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b);
202
203 #endif // GOOGLE_CUDA && GOOGLE_TENSORRT
204
205 } // namespace tensorrt
206 } // namespace tensorflow
207
208 #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
209