/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_tensor_proxy.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

static constexpr char kCastOutputTypeAttrName[] = "DstT";

class IONamePrefixes {
 public:
  static constexpr const char* const kInputPHName = "TensorRTInputPH_";
  static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
};

// Deleter for TensorRT objects, which must be released with their destroy()
// method rather than `delete`.
template <typename T>
struct TrtDestroyer {
  void operator()(T* t) {
    if (t) t->destroy();
  }
};

template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
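
// Example usage (a minimal sketch; `logger` is assumed to be some
// nvinfer1::ILogger implementation available at the call site):
//
//   TrtUniquePtrType<nvinfer1::IBuilder> builder(
//       nvinfer1::createInferBuilder(logger));
//   // builder->destroy() runs automatically when `builder` goes out of scope.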

enum class TrtPrecisionMode { FP32, FP16, INT8 };

Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name);

Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
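
// Example round trip (a sketch; assumes the conversion uses the names "FP32",
// "FP16" and "INT8"):
//
//   string name;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(TrtPrecisionMode::FP16, &name));
//   TrtPrecisionMode mode;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(name, &mode));  // FP16 again.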

// Define a hash function for vector<TensorShape> because it is used as the key
// for the engine cache.
struct VectorTensorShapeHasher {
  std::size_t operator()(const std::vector<TensorShape>& key) const {
    return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
  }
};
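
// Example (a sketch; `CachedEngine` is a hypothetical stand-in for whatever
// value type the engine cache stores):
//
//   std::unordered_map<std::vector<TensorShape>, CachedEngine,
//                      VectorTensorShapeHasher>
//       engine_cache;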

#if GOOGLE_CUDA && GOOGLE_TENSORRT

using absl::StrAppend;
using absl::StrCat;

// This utility template converts an arithmetic type to a string. This function
// is necessary to allow the following function to behave recursively:
// `string DebugString(const std::vector<CType>&)`.
template <typename CType, typename = typename std::enable_if<
                              std::is_arithmetic<CType>::value, CType>::type>
string DebugString(const CType& el) {
  string el_str = std::to_string(el);
  // Prettify std::to_string, which can return 1.500000 instead of 1.5: strip
  // trailing zeros (and a dangling decimal point), but only for values that
  // were formatted with a decimal point, so integers such as 100 stay intact.
  if (el_str.find('.') != std::string::npos) {
    el_str.erase(el_str.find_last_not_of('0') + 1, std::string::npos);
    if (el_str.back() == '.') el_str.pop_back();
  }
  return el_str;
}
// This utility template converts nested vectors to a string for debug
// purposes.
template <typename CType>
string DebugString(const std::vector<CType>& vector) {
  string tmp_s = "";
  for (const auto& el : vector) {
    StrAppend(&tmp_s, StrCat(DebugString(el), ", "));
  }
  return StrCat("{", tmp_s.substr(0, tmp_s.length() - 2), "}");
}
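
// For example, DebugString(std::vector<float>{1.5f, 2.0f}) is expected to
// return "{1.5, 2}".
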
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const TrtPrecisionMode mode);
string DebugString(const DataType tf_type);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const ITensorProxyPtr& tensor);
string DebugString(const nvinfer1::ITensor& tensor);
string DebugString(const std::vector<nvinfer1::Dims>& dimvec);
string DebugString(const std::vector<TensorShape>& shapes);
string DebugString(const std::vector<PartialTensorShape>& shapes);

inline bool HasStaticShape(const nvinfer1::Dims& dims) {
  if (dims.nbDims < 0) return false;
  for (int d = 0; d < dims.nbDims; ++d) {
    if (dims.d[d] < 0) return false;
  }
  return true;
}

template <typename T>
bool HasStaticShape(const T& dims) {
  return !absl::c_any_of(dims, [](int i) { return i < 0; });
}
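
// Example (a sketch; in TensorRT a dimension of -1 marks a dynamic size):
//
//   nvinfer1::Dims dims{3, {8, -1, 16}};
//   HasStaticShape(dims);  // false: the middle dimension is dynamic.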

// Returns whether a shape is compatible with a TRT shape tensor.
template <typename TensorShapeType>
inline bool IsTrtShapeTensorCompatible(const TensorShapeType& shape) {
  return (
      shape.dims() == 0 ||
      (shape.dims() == 1 && shape.num_elements() <= nvinfer1::Dims::MAX_DIMS));
}

// Returns whether a TF tensor could be interpreted as a TRT shape tensor.
inline bool IsTrtShapeTensorCompatible(const Tensor& tensor) {
  return tensor.dtype() == DT_INT32 &&
         IsTrtShapeTensorCompatible(tensor.shape());
}
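
// For example, a rank-1 DT_INT32 tensor with four elements (say holding
// {1, 28, 28, 3}) has shape [4] and therefore qualifies:
//
//   Tensor t(DT_INT32, TensorShape({4}));
//   IsTrtShapeTensorCompatible(t);  // true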

// Converts a shape container (e.g. a TensorShape's dim_sizes()) to TensorRT
// dims, optionally dropping the first (batch) dimension.
template <typename Container>
Status ContainerToTrtDims(const Container& shape, nvinfer1::Dims* trt_dims,
                          bool ignore_first_dim = false) {
  if (shape.size() == 0) {
    // Scalar input.
    if (ignore_first_dim) {
      return errors::Internal(
          "Scalars cannot be represented in implicit batch mode");
    }
    *trt_dims = {0, {1}};
  } else {
    const int offset = (ignore_first_dim ? 1 : 0);
    for (int i = offset; i < shape.size(); i++) {
      trt_dims->d[i - offset] = shape.at(i);
    }
    trt_dims->nbDims = shape.size() - offset;
  }
  return Status::OK();
}
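
// Example usage (a sketch; drops the batch dimension of an NHWC shape):
//
//   std::vector<int> shape = {8, 224, 224, 3};
//   nvinfer1::Dims dims;
//   TF_RETURN_IF_ERROR(
//       ContainerToTrtDims(shape, &dims, /*ignore_first_dim=*/true));
//   // Now dims.nbDims == 3 and dims.d holds {224, 224, 3}.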

template <typename TensorShapeType>
Status TensorShapeToTrtDims(const TensorShapeType& shape, bool ignore_first_dim,
                            nvinfer1::Dims* trt_dims) {
  if (shape.dims() == -1) {
    // The shape has unknown rank; propagate that to the TRT dims.
    trt_dims->nbDims = -1;
    return Status::OK();
  }
  return ContainerToTrtDims(shape.dim_sizes(), trt_dims, ignore_first_dim);
}

Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes);

Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size = absl::nullopt);

template <typename TensorShapeType>
Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
                            TensorShapeType* shape,
                            absl::optional<int> batch_size = absl::nullopt) {
  TF_RETURN_IF_ERROR(
      TensorShapeUtils::MakeShape(trt_dims.d, trt_dims.nbDims, shape));
  if (batch_size) {
    shape->InsertDim(0, batch_size.value());
  }
  return Status::OK();
}
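
// Example round trip (a sketch, continuing the ContainerToTrtDims example
// above):
//
//   TensorShape shape;
//   TF_RETURN_IF_ERROR(TrtDimsToTensorShape(dims, &shape, /*batch_size=*/8));
//   // shape is [8, 224, 224, 3] again.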

Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);

// Returns true if an engine built for cached_shapes can also run
// actual_shapes.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes);
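
// For example, an engine cached for {[-1, 224, 224, 3]} would be expected to
// be compatible with actual shapes {[8, 224, 224, 3]}: dimensions must match
// except where the cached shape is dynamic (-1).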

// Returns the number of inputs for the engine, which also corresponds to the
// number of input tensors for the network. This can differ from the number of
// input bindings, because the number of total input bindings equals the number
// of profiles times the number of engine inputs.
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine);

// Returns the string representation for the assigned device or the requested
// device of the given node.
absl::string_view GetDeviceName(const Node* node);

// Returns the ParsedName representation for the assigned device or the
// requested device string of the given node. If the device string is invalid,
// returns absl::nullopt.
absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node);

// If the given two device assignments are compatible, returns the merge of the
// two assignments. Otherwise, returns absl::nullopt.
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b);

// Similar to the above, except that the second device assignment is
// represented by a string_view.
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b);
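
// For example, merging "/job:worker/device:GPU:0" with "/device:GPU:0" would
// be expected to succeed and keep the more specific
// "/job:worker/device:GPU:0", while two assignments naming different GPU ids
// would yield absl::nullopt.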

// Optimization profile generation strategies.
enum class ProfileStrategy {
  kRange,
  kOptimal,
  kRangeOptimal,
  kImplicitBatchModeCompatible,
};

string ProfileStrategyToName(const ProfileStrategy strategy);
Status ProfileStrategyFromName(const string& name, ProfileStrategy* strategy);

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_