• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_
17 #define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_
18 
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
27 
28 #if GOOGLE_CUDA && GOOGLE_TENSORRT
29 #include "third_party/tensorrt/NvInfer.h"
30 
31 namespace tensorflow {
32 namespace tensorrt {
33 using ::stream_executor::port::StatusOr;
34 
35 // Input/output data format for OpConverterTest::BuildAndRun().
36 struct InputOutputData {
BufferInputOutputData37   void* Buffer() const {
38     return const_cast<char*>(tensor.tensor_data().data());
39   }
40 
TotalBytesInputOutputData41   size_t TotalBytes() const { return tensor.TotalBytes(); }
42 
43   string name;
44   Tensor tensor;
45 };
46 
47 class TRTBaseAllocator;
48 
49 // Keeps track of the TensorRT execution context and the device memory owned by
50 // the context, if any. An execution context owns the device memory that TF-TRT
51 // allocates for the context. In this case, the allocator is not null and is
52 // used to free the device memory. An execution context doesn't own a device
53 // memory (1) if the device memory is allocated through TensorRT, or (2) the
54 // device memory is allocated by TF-TRT for another execution context but
55 // shared with this context. If this case, the device memory is null.
56 //
57 // Currently, the main reason we want to allocate the device memory for an
58 // execution context in TF-TRT is because the TensorRT API to create an
59 // execution context with device memory doesn't handle out of memory properly.
60 //
61 // To support dynamic shapes, we create multiple execution contexts for an
62 // engine and may want to support multiple execution contexts sharing the same
63 // device memory.
64 class ExecutionContext {
65  private:
66   // Records the TensorRT execution context `context`, the device memory
67   // `device_memory` TF-TRT allocates for the context and the device memory
68   // allocator `allocator` used to allocate the memory. If TF-TRT doesn't
69   // allocate any device memory for the context, then `device_memory` is null.
70   // otherwise, allocator should not be null.
ExecutionContext(TRTBaseAllocator * allocator,void * device_memory,nvinfer1::IExecutionContext * context)71   ExecutionContext(TRTBaseAllocator* allocator, void* device_memory,
72                    nvinfer1::IExecutionContext* context)
73       : memory_allocator_(allocator),
74         device_memory_(device_memory),
75         execution_context_(context) {}
76 
77  public:
78   // Disables copy constructors as the object owns the device memory and the
79   // execution context.
80   ExecutionContext(const ExecutionContext&) = delete;
81   ExecutionContext& operator=(const ExecutionContext&) = delete;
82 
ExecutionContext(ExecutionContext && other)83   ExecutionContext(ExecutionContext&& other)
84       : memory_allocator_(other.memory_allocator_),
85         device_memory_(other.device_memory_),
86         execution_context_(other.execution_context_) {
87     other.memory_allocator_ = nullptr;
88     other.device_memory_ = nullptr;
89     other.execution_context_ = nullptr;
90   }
91 
92   ~ExecutionContext();
93 
94   operator nvinfer1::IExecutionContext*() const { return execution_context_; }
GetIExecutionContext()95   nvinfer1::IExecutionContext* GetIExecutionContext() const {
96     return execution_context_;
97   }
98 
99   static StatusOr<ExecutionContext> Create(nvinfer1::ICudaEngine* cuda_engine,
100                                            TRTBaseAllocator* allocator);
101 
102  private:
103   // The allocator used to allocate and free the device memory owned by the
104   // execution context.
105   TRTBaseAllocator* memory_allocator_;
106   // The device memory owned by the execution context.
107   void* device_memory_;
108   // The TensorRT execution context.
109   nvinfer1::IExecutionContext* execution_context_;
110 };
111 
112 // Creates a TensorRT execution context. If an allocator is not given, then the
113 // execution context is created with device memory allocated by TensorRT.
114 // Otherwise, uses the allocator to allocate the needed device memory for the
115 // execution context.
116 //
117 // Returns an ExecutionContext object that wraps the above results. If out of
118 // device memory happens, returns an error status instead.
119 StatusOr<ExecutionContext> CreateExecutionContext(
120     nvinfer1::ICudaEngine* cuda_engine, TRTBaseAllocator* allocator);
121 
122 using DataVec = std::vector<InputOutputData>;
123 
124 // Gets the binding index of a tensor in an engine.
125 //
126 // The binding index is looked up using the tensor's name and the profile index.
127 // Profile index should be set to zero, if we do not have optimization profiles.
128 Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
129                           const nvinfer1::ICudaEngine* cuda_engine,
130                           int* binding_index);
131 
132 // Sets input buffers for TRT from a list of input tensors. The input tensors
133 // are either defined by ctx or by input_vec.
134 Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine,
135                           nvinfer1::IExecutionContext* execution_context,
136                           const int trt_profile_idx,
137                           std::vector<void*>& buffers, bool use_implicit_batch,
138                           int num_batch, OpKernelContext* ctx = nullptr,
139                           const DataVec* input_vec = nullptr);
140 
141 // Returns the shape of a binding from TensorRT.
142 //
143 // The binding is identified by its binding_index. The batch_size argument is
144 // ignored if use_implicit_batch==false. The shape is returned in the last
145 // argument.
146 Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
147                           const nvinfer1::IExecutionContext* execution_context,
148                           int binding_index, bool use_implicit_batch,
149                           int batch_size, TensorShape& shape);
150 
151 // Defines output buffers for TRT. The buffers are allocated by ctx, if ctx is
152 // not null. Otherwise it is expected that the outputs DataVec is not null, and
153 // the Tensors in outputs are already allocated.
154 Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine,
155                            nvinfer1::IExecutionContext* execution_context,
156                            int trt_profile_idx, std::vector<void*>& buffers,
157                            bool use_implicit_batch, int batch_size = 0,
158                            OpKernelContext* ctx = nullptr,
159                            DataVec* outputs = nullptr);
160 
161 // Enqueues TensorRT inference job. The batch_size argument is only relevant in
162 // implicit batch mode.
163 Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context,
164                   std::vector<void*>& buffers, cudaStream_t stream,
165                   bool use_implicit_batch, int batch_size = 1);
166 
167 }  // namespace tensorrt
168 }  // namespace tensorflow
169 
170 #endif  // GOOGLE_CUDA && GOOGLE_TENSORRT
171 
172 #endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_
173