/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_

#include <memory>
#include <vector>

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"
#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

namespace tensorflow {
namespace tensorrt {

class IONamePrefixes {
 public:
  static constexpr const char* const kInputPHName = "TensorRTInputPH_";
  static constexpr const char* const kOutputPHName = "TensorRTOutputPH_";
};

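// Example (illustrative sketch, not part of this header's API surface):
// engine I/O placeholder node names are typically formed by appending a port
// index to these prefixes, e.g. with absl::StrCat:
//
//   string input_name = absl::StrCat(IONamePrefixes::kInputPHName, 0);
//   // input_name == "TensorRTInputPH_0"
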
template <typename T>
struct TrtDestroyer {
  void operator()(T* t) {
    if (t) t->destroy();
  }
};

template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;

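// Example (illustrative sketch; assumes a caller that owns an
// nvinfer1::ILogger named `logger`): TensorRT objects are released with
// destroy() rather than delete, so they are held via TrtUniquePtrType:
//
//   TrtUniquePtrType<nvinfer1::IBuilder> builder(
//       nvinfer1::createInferBuilder(logger));
//   // builder->destroy() is invoked automatically by TrtDestroyer when
//   // `builder` goes out of scope.
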
enum class TrtPrecisionMode { FP32, FP16, INT8 };

Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name);

Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);

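// Example (illustrative sketch of the enum/name round trip; the exact name
// strings are defined in the corresponding .cc file):
//
//   string name;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeToName(TrtPrecisionMode::FP16, &name));
//   TrtPrecisionMode mode;
//   TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(name, &mode));
//   // mode == TrtPrecisionMode::FP16
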
// Define a hash function for vector<TensorShape> because it is used as the key
// for the engine cache.
struct VectorTensorShapeHasher {
  std::size_t operator()(const std::vector<TensorShape>& key) const {
    return std::hash<std::string>()(TensorShapeUtils::ShapeListString(key));
  }
};

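// Example (illustrative sketch; `EngineContext` stands in for whatever cached
// value type the engine cache stores): the hasher lets a vector of input
// shapes be used as an unordered_map key:
//
//   std::unordered_map<std::vector<TensorShape>, EngineContext,
//                      VectorTensorShapeHasher> engine_cache;
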
#if GOOGLE_CUDA && GOOGLE_TENSORRT

#define IS_TRT_VERSION_GE(major, minor, patch, build)           \
  ((NV_TENSORRT_MAJOR > major) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR > minor) || \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH > patch) ||                               \
   (NV_TENSORRT_MAJOR == major && NV_TENSORRT_MINOR == minor && \
    NV_TENSORRT_PATCH == patch && NV_TENSORRT_BUILD >= build))

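// Example (illustrative sketch): the macro compares the compile-time TensorRT
// version tuple lexicographically, so code paths can be gated on a minimum
// version (the version numbers below are just an example):
//
//   #if IS_TRT_VERSION_GE(5, 1, 5, 0)
//   // Use an API that is only available in TensorRT >= 5.1.5.0.
//   #endif
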
string DebugString(const nvinfer1::DimensionType type);
string DebugString(const nvinfer1::Dims& dims);
string DebugString(const nvinfer1::DataType trt_dtype);
string DebugString(const nvinfer1::Permutation& permutation, int len);
string DebugString(const nvinfer1::ITensor& tensor);

inline bool HasStaticShape(const nvinfer1::Dims& dims) {
  if (dims.nbDims < 0) return false;
  for (int d = 0; d < dims.nbDims; ++d) {
    if (dims.d[d] < 0) return false;
  }
  return true;
}

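// Example (illustrative sketch): a negative extent marks a dynamic (unknown)
// dimension, so HasStaticShape() can guard code that requires fully known
// shapes:
//
//   nvinfer1::Dims dims;
//   dims.nbDims = 2;
//   dims.d[0] = -1;  // e.g. a dynamic batch dimension
//   dims.d[1] = 128;
//   // HasStaticShape(dims) == false
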
template <typename TensorShapeType>
inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
                                           bool ignore_first_dim) {
  nvinfer1::Dims trt_dims;
  const int offset = (ignore_first_dim ? 1 : 0);
  for (int i = offset; i < shape.dims(); i++) {
    trt_dims.d[i - offset] = shape.dim_size(i);
  }
  trt_dims.nbDims = shape.dims() - offset;
  return trt_dims;
}

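// Example (illustrative sketch): ignore_first_dim is typically set when the
// batch dimension is handled implicitly by the TensorRT engine and must not
// appear in the converted dims:
//
//   TensorShape shape({8, 224, 224, 3});
//   nvinfer1::Dims dims =
//       TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true);
//   // dims.nbDims == 3, dims.d == {224, 224, 3}
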
// Return a string that includes compile-time
// TensorRT library version information {Maj, Min, Patch}.
string GetLinkedTensorRTVersion();

// Return a string that includes run-time (loaded)
// TensorRT library version information {Maj, Min, Patch}.
string GetLoadedTensorRTVersion();

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_UTILS_H_