/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

using absl::StrCat;

ExecutionContext::~ExecutionContext() {
  if (device_memory_) {
    DCHECK(memory_allocator_) << "Internal error: Device memory with address "
                              << (char*)device_memory_ << " is not freed";
    memory_allocator_->free(device_memory_);
  }
  if (execution_context_) {
    execution_context_->destroy();
  }
}

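// Example usage (a minimal sketch; `engine` and `allocator` stand for a
// caller-owned nvinfer1::ICudaEngine* and TRTBaseAllocator* created
// elsewhere):
//
//   StatusOr<ExecutionContext> context =
//       ExecutionContext::Create(engine, allocator);
//   TF_RETURN_IF_ERROR(context.status());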
StatusOr<ExecutionContext> ExecutionContext::Create(
    nvinfer1::ICudaEngine* cuda_engine, TRTBaseAllocator* allocator) {
  void* device_memory = nullptr;
  nvinfer1::IExecutionContext* execution_context;
  if (allocator == nullptr) {
    execution_context = cuda_engine->createExecutionContext();
  } else {
    execution_context =
        cuda_engine->createExecutionContextWithoutDeviceMemory();
    size_t device_memory_size = cuda_engine->getDeviceMemorySize();
    VLOG(2) << "Device memory size for cuda engine " << device_memory_size;

    if (device_memory_size > 0) {
      device_memory = allocator->allocate(device_memory_size,
                                          /*unused alignment=*/0, /*flags=*/0);
      if (device_memory == nullptr) {
        return errors::InvalidArgument(
            "Out of GPU memory when creating execution context");
      }
    }
    execution_context->setDeviceMemory(device_memory);
  }
  return ExecutionContext(allocator, device_memory, execution_context);
}

Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
                          const nvinfer1::IExecutionContext* execution_context,
                          int binding_index, bool use_implicit_batch,
                          int batch_size, TensorShape& shape) {
  nvinfer1::Dims dims;
  if (use_implicit_batch) {
    dims = cuda_engine->getBindingDimensions(binding_index);
  } else {
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
    // Get dims from the context instead of the engine in explicit batch mode,
    // because the engine might have dynamic shapes.
    dims = execution_context->getBindingDimensions(binding_index);
    if (dims.nbDims == -1) {
      // Invalid dimensions. There can be multiple reasons for this, for
      // example incompatible input shapes (the network is invalid for the
      // current profile).
      return errors::Internal(
          "Binding index out of range. This can happen if profile is not set, "
          "or the network is invalid for the current profile.");
    }
#else
    return errors::Internal(
        "Explicit batch mode is only supported with TensorRT 6 and above.");
#endif
  }
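  // For example (illustrative values only, assuming TrtDimsToTensorShape
  // prepends the batch dimension when one is given): in implicit batch mode,
  // binding dims {28, 28, 3} with batch_size 8 yield TensorShape
  // [8, 28, 28, 3]; in explicit batch mode the binding dims already contain
  // the batch dimension and are converted as-is.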
  TF_RETURN_IF_ERROR(TrtDimsToTensorShape(
      dims, &shape,
      use_implicit_batch ? absl::optional<int>(batch_size) : absl::nullopt));
  return Status::OK();
}

Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index) {
  // If the engine has been built for K profiles, the first getNbBindings() / K
  // bindings are used by profile number 0, the following getNbBindings() / K
  // bindings are used by profile number 1, etc.
  //
  // GetBindingIndex(tensor_name) returns the binding index for profile 0. We
  // can also consider it as a "binding_index_within_profile".
  *binding_index = cuda_engine->getBindingIndex(tensor_name);
  if (*binding_index == -1) {
    const string msg = StrCat("Input node ", tensor_name, " not found");
    LOG(ERROR) << msg;
    return errors::NotFound(msg);
  }
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  int n_profiles = cuda_engine->getNbOptimizationProfiles();
#else
  int n_profiles = 1;
#endif
  // If we have more than one optimization profile, then we need to shift the
  // binding index according to the following formula:
  //   binding_index_within_engine = binding_index_within_profile +
  //                                 profile_index * bindings_per_profile
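  // For example (a worked illustration, not tied to any particular engine):
  // with getNbBindings() == 6 and 3 optimization profiles,
  // bindings_per_profile is 2, so binding_index_within_profile 1 under
  // profile_index 2 maps to engine binding index 1 + 2 * 2 = 5.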
  const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles;
  *binding_index = *binding_index + profile_index * bindings_per_profile;
  return Status::OK();
}

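// Typical calling sequence for the three helpers below (a sketch based on
// their signatures in this file; `stream`, `ctx`, `profile_idx`,
// `use_implicit_batch` and `batch_size` are assumed to be provided by the
// caller, e.g. TRTEngineOp):
//
//   std::vector<void*> buffers(cuda_engine->getNbBindings());
//   TF_RETURN_IF_ERROR(SetTrtEngineInputs(cuda_engine, execution_context,
//                                         profile_idx, buffers,
//                                         use_implicit_batch, batch_size,
//                                         ctx, /*input_vec=*/nullptr));
//   TF_RETURN_IF_ERROR(SetTrtEngineOutputs(cuda_engine, execution_context,
//                                          profile_idx, buffers,
//                                          use_implicit_batch, batch_size,
//                                          ctx, /*outputs=*/nullptr));
//   TF_RETURN_IF_ERROR(TrtEnqueue(execution_context, buffers, stream,
//                                 use_implicit_batch, batch_size));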
Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine,
                          nvinfer1::IExecutionContext* execution_context,
                          const int trt_profile_idx,
                          std::vector<void*>& buffers, bool use_implicit_batch,
                          int num_batch, OpKernelContext* ctx,
                          const DataVec* input_vec) {
  int n_inputs = ctx ? ctx->num_inputs() : (input_vec ? input_vec->size() : 0);
  // Setup engine inputs.
  for (int i = 0; i < n_inputs; i++) {
    const string input_name =
        ctx ? StrCat(IONamePrefixes::kInputPHName, i) : input_vec->at(i).name;
    int binding_index;
    TF_RETURN_IF_ERROR(GetTrtBindingIndex(input_name.c_str(), trt_profile_idx,
                                          cuda_engine, &binding_index));
    const Tensor& input_tensor = ctx ? ctx->input(i) : input_vec->at(i).tensor;
    const TensorShape& input_shape = input_tensor.shape();

    if (use_implicit_batch && ctx) {
      // Ensure all inputs have the same batch size.
      if (num_batch != input_shape.dim_size(0)) {
        const string msg =
            StrCat("Input data has inconsistent batch size: ", num_batch,
                   " vs ", input_shape.dim_size(0));
        return errors::NotFound(msg);
      }
    }
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
    // Set known input dimensions. This is necessary because the TRT network
    // could have been built with dynamic dimensions.
    if (!use_implicit_batch) {
      nvinfer1::Dims trt_dims;
      trt_dims.nbDims = input_shape.dims();
      for (int k = 0; k < input_shape.dims(); k++) {
        trt_dims.d[k] = input_shape.dim_size(k);
      }
      bool ret =
          execution_context->setBindingDimensions(binding_index, trt_dims);
      if (!ret) {
        VLOG(2) << "Error setting engine input " << binding_index << " "
                << DebugString(trt_dims);
        return errors::Internal(
            "Binding dimension does not fit selected profile.");
      }
    }
#endif
    // Setup input bindings.
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(input_tensor.flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        buffers[binding_index] =
            const_cast<Eigen::half*>(input_tensor.flat<Eigen::half>().data());
        break;
      case nvinfer1::DataType::kINT8:
        return errors::Internal("INT8 inputs are not supported yet!");
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(input_tensor.flat<int32>().data());
        break;
      default:
        return errors::Internal("Unknown TRT data type: ",
                                static_cast<int>(dtype));
    }
  }

#if IS_TRT_VERSION_GE(6, 0, 0, 0)
  // Ensure all network dynamic dimensions (if any) are set in the execution
  // context.
  if (!execution_context->allInputDimensionsSpecified()) {
    return errors::Internal(
        "Failed to set dimensions for all dynamic input tensors");
  }
  if (!execution_context->allInputShapesSpecified()) {
    return errors::Internal(
        "Failed to set dimensions for all shape input tensors.");
  }
#endif
  return Status::OK();
}

Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine,
                           nvinfer1::IExecutionContext* execution_context,
                           int trt_profile_idx, std::vector<void*>& buffers,
                           bool use_implicit_batch, int batch_size,
                           OpKernelContext* ctx, DataVec* outputs) {
  // Either ctx or outputs should be specified.
  int n_outputs = ctx ? ctx->num_outputs() : (outputs ? outputs->size() : 0);
  for (int i = 0; i < n_outputs; i++) {
    const string output_name =
        ctx ? StrCat(IONamePrefixes::kOutputPHName, i) : outputs->at(i).name;
    int binding_index;
    TF_RETURN_IF_ERROR(GetTrtBindingIndex(output_name.c_str(), trt_profile_idx,
                                          cuda_engine, &binding_index));

    // Get TRT output shapes for allocating output memory.
    TensorShape output_shape;
    TF_RETURN_IF_ERROR(GetTrtBindingShape(cuda_engine, execution_context,
                                          binding_index, use_implicit_batch,
                                          batch_size, output_shape));

    // Allocate the output tensor of TRTEngineOp.
    Tensor* output_tensor = nullptr;
    if (ctx) {
      TF_RETURN_IF_ERROR(ctx->allocate_output(i, output_shape, &output_tensor));
    } else {
      // This path is used for unit tests. The tensor is already allocated, but
      // its shape is not necessarily set correctly, so we fix that here.
      VLOG(2) << "Applying shape " << output_shape.DebugString()
              << " on output.";
      output_tensor = &(outputs->at(i).tensor);
      bool status = output_tensor->CopyFrom(*output_tensor, output_shape);
      if (!status) {
        return errors::Internal(
            "Buffer sizes do not match while reshaping output tensors");
      }
    }

    // Setup output bindings.
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(output_tensor->flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        buffers[binding_index] = const_cast<Eigen::half*>(
            output_tensor->flat<Eigen::half>().data());
        break;
      case nvinfer1::DataType::kINT8:
        return errors::Internal("int8 is not supported yet!");
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(output_tensor->flat<int32>().data());
        break;
      default:
        return errors::Internal("Unknown TRT data type: ",
                                static_cast<int>(dtype));
    }
  }
  return Status::OK();
}

Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context,
                  std::vector<void*>& buffers, cudaStream_t stream,
                  bool use_implicit_batch, int batch_size) {
  bool ret = false;
  if (use_implicit_batch) {
    ret = execution_context->enqueue(batch_size, &buffers[0], stream, nullptr);
    VLOG(1) << "Called IExecutionContext::enqueue";
  } else {
#if IS_TRT_VERSION_GE(6, 0, 0, 0)
    ret = execution_context->enqueueV2(&buffers[0], stream, nullptr);
    VLOG(1) << "Called IExecutionContext::enqueueV2";
#else
    return errors::Internal(
        "Explicit batch mode is only supported with TensorRT 6 and above.");
#endif
  }
  if (!ret) {
    return errors::Internal("Failed to enqueue batch for TRT engine");
  }
  // Synchronization will be done by TF.
  return Status::OK();
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT