/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

static constexpr char kCastOutputTypeAttrName[] = "DstT";

class IONamePrefixes {
 public:
  static constexpr const char* const kInputPHName = "TensorRTInputPH_";
  static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
};

template <typename T>
struct TrtDestroyer {
  void operator()(T* t) {
    if (t) t->destroy();
  }
};

template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;
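
// Example usage (an illustrative sketch, not part of the original header):
// TensorRT objects are released via destroy() rather than operator delete, so
// TrtUniquePtrType gives them unique_ptr semantics. The `logger` object below
// is an assumption.
//
//   TrtUniquePtrType<nvinfer1::IBuilder> builder(
//       nvinfer1::createInferBuilder(logger));
//   // builder->destroy() runs automatically when `builder` goes out of scope.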

enum class TrtPrecisionMode { FP32, FP16, INT8 };

Status TrtPrecisionModeToName(const TrtPrecisionMode mode, string* name);

Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
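
// Example round trip (illustrative, not part of the original header):
//   string name;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(TrtPrecisionMode::FP16, &name));
//   TrtPrecisionMode mode;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(name, &mode));
//   // mode == TrtPrecisionMode::FP16 again.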

// Define a hash function for vector<TensorShape> because it is used as the key
// for the engine cache.
struct VectorTensorShapeHasher {
  std::size_t operator()(const std::vector<TensorShape>& key) const {
    return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
  }
};
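
// Illustrative sketch (not part of the original header): the hasher lets an
// engine cache be keyed by the vector of input shapes, e.g.
//
//   std::unordered_map<std::vector<TensorShape>, EngineContext,
//                      VectorTensorShapeHasher> cache;
//
// where `EngineContext` stands in for whatever value type the cache stores.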

#if GOOGLE_CUDA && GOOGLE_TENSORRT

using absl::StrAppend;
using absl::StrCat;

#define IS_TRT_VERSION_GE(major, minor, patch, build)           \
  ((NV_TENSORRT_MAJOR > major) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH > patch) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))
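
// Example usage (illustrative, not part of the original header): guard code
// that requires TensorRT 6.0 or newer at compile time.
//
//   #if IS_TRT_VERSION_GE(6, 0, 0, 0)
//   // ... code using TensorRT 6+ APIs ...
//   #endif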

// This utility template converts an arithmetic type to a string. This function
// is necessary to allow the following function to behave recursively:
// `string DebugString(const std::vector<CType>&)`.
template <typename CType, typename = typename std::enable_if<
                              std::is_arithmetic<CType>::value, CType>::type>
string DebugString(const CType& el) {
  string el_str = std::to_string(el);
  // Prettify std::to_string, which can return 1.500000 instead of 1.5: strip
  // trailing 0s (and a dangling '.') from floating-point values only, so that
  // integral values such as 100 are not truncated to "1".
  if (std::is_floating_point<CType>::value &&
      el_str.find('.') != std::string::npos) {
    el_str.erase(el_str.find_last_not_of('0') + 1, std::string::npos);
    if (!el_str.empty() && el_str.back() == '.') el_str.pop_back();
  }
  return el_str;
}
// This utility template converts nested vectors to a string for debug
// purposes.
template <typename CType>
string DebugString(const std::vector<CType>& vector) {
  string tmp_s = "";
  for (const auto& el : vector) {
    StrAppend(&tmp_s, StrCat(DebugString(el), ", "));
  }
  return StrCat("{", tmp_s.substr(0, tmp_s.length() - 2), "}");
}
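
// For example (illustrative, not part of the original header):
//   DebugString(std::vector<int>{1, 2, 3}) returns "{1, 2, 3}", and
//   DebugString(std::vector<float>{1.5f}) returns "{1.5}".
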
string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const TrtPrecisionMode mode);
string DebugString(const DataType tf_type);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);
string DebugString(const std::vector<nvinfer1::Dims>& dimvec);
string DebugString(const std::vector<TensorShape>& shapes);
string DebugString(const std::vector<PartialTensorShape>& shapes);

inline bool HasStaticShape(const nvinfer1::Dims& dims) {
  if (dims.nbDims < 0) return false;
  for (int d = 0; d < dims.nbDims; ++d) {
    if (dims.d[d] < 0) return false;
  }
  return true;
}

inline bool HasStaticShape(const std::vector<int>& dims) {
  return !absl::c_any_of(dims, [](int i) { return i < 0; });
}
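
// For example (illustrative): dims {3, 224, 224} is fully static, while
// {-1, 224, 224} has an unknown (dynamic) batch dimension and is not.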

// Returns whether a shape is compatible with a TRT shape tensor.
template <typename TensorShapeType>
inline bool IsTrtShapeTensorCompatible(const TensorShapeType& shape) {
  return (
      shape.dims() == 0 ||
      (shape.dims() == 1 && shape.num_elements() <= nvinfer1::Dims::MAX_DIMS));
}

// Returns whether a TF tensor could be interpreted as a TRT shape tensor.
inline bool IsTrtShapeTensorCompatible(const Tensor& tensor) {
  return tensor.dtype() == DT_INT32 &&
         IsTrtShapeTensorCompatible(tensor.shape());
}
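
// For example (illustrative): a DT_INT32 tensor of shape [4] can act as a
// shape tensor describing a 4-D shape, while a rank-2 tensor cannot.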

template <typename TensorShapeType>
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                           bool ignore_first_dim) {
  nvinfer1::Dims trt_dims;
  const int offset = (ignore_first_dim ? 1 : 0);
  for (int i = offset; i < shape.dims(); i++) {
    trt_dims.d[i - offset] = shape.dim_size(i);
  }
  trt_dims.nbDims = shape.dims() - offset;
  return trt_dims;
}
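
// For example (illustrative, not part of the original header):
//   TensorShapeToTrtDims(TensorShape({8, 3, 224}),
//                        /*ignore_first_dim=*/true)
// yields nvinfer1::Dims with nbDims == 2 and d == {3, 224}, dropping the
// batch dimension.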

Status GetNetworkInputShapes(const nvinfer1::INetworkDefinition* network,
                             std::vector<PartialTensorShape>* input_shapes);

Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
                            TensorShape* shape,
                            absl::optional<int> batch_size = absl::nullopt);

template <typename TensorShapeType>
Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
                            TensorShapeType* shape,
                            absl::optional<int> batch_size = absl::nullopt) {
  TF_RETURN_IF_ERROR(
      TensorShapeUtils::MakeShape(trt_dims.d, trt_dims.nbDims, shape));
  if (batch_size) {
    shape->InsertDim(0, batch_size.value());
  }
  return Status::OK();
}
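
// For example (illustrative): converting Dims with nbDims == 2 and
// d == {3, 224}, with batch_size = 8, produces the TensorShape [8, 3, 224];
// without a batch_size it produces [3, 224].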

Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type);
Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type);

// Returns true if an engine built for cached_shapes can also run
// actual_shapes.
bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
                         const std::vector<TensorShape>& cached_shapes);

// Returns the number of inputs for the engine, which also corresponds to the
// number of input tensors for the network. This can differ from the number of
// input bindings, because the total number of input bindings equals the number
// of optimization profiles times the number of engine inputs.
int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine);
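
// For example (illustrative): an engine with 2 network inputs built with 3
// optimization profiles has 3 * 2 = 6 input bindings, but
// GetNumberOfEngineInputs returns 2.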

// Returns the string representation for the assigned device or the requested
// device of the given node.
absl::string_view GetDeviceName(const Node* node);

// Returns the ParsedName representation for the assigned device or the
// requested device string of the given node. If the device string is invalid,
// returns absl::nullopt.
absl::optional<DeviceNameUtils::ParsedName> GetDeviceParsedName(
    const Node* node);

// If the given two device assignments are compatible, returns the merged
// assignment. Otherwise, returns absl::nullopt.
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b);
// Similar to the above, except that the second device assignment is
// represented by a string_view.
absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
    const DeviceNameUtils::ParsedName& a, absl::string_view b);
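
// For example (illustrative, assuming the usual DeviceNameUtils merge
// semantics): merging "/device:GPU:0" with the fully specified
// "/job:worker/replica:0/task:0/device:GPU:0" succeeds, while "/device:GPU:0"
// and "/device:GPU:1" conflict and yield absl::nullopt.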

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_