/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"

#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_execution_context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"

#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

using absl::StrCat;

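// Creates an execution context for `cuda_engine` without allocating device
// memory for activations. The caller is expected to provide scratch memory
// (e.g. via IExecutionContext::setDeviceMemory) before the engine is
// enqueued.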
ExecutionContext ExecutionContext::Create(nvinfer1::ICudaEngine* cuda_engine) {
  nvinfer1::IExecutionContext* execution_context =
      cuda_engine->createExecutionContextWithoutDeviceMemory();
  return ExecutionContext(execution_context);
}

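// Returns in `shape` the shape of the binding at `binding_index`. In implicit
// batch mode the dimensions come from the engine and are prepended with
// `batch_size`; in explicit batch mode the (possibly dynamic) dimensions are
// read from the execution context.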
Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine,
                          const nvinfer1::IExecutionContext* execution_context,
                          int binding_index, bool use_implicit_batch,
                          int batch_size, TensorShape& shape) {
  nvinfer1::Dims dims;
  if (use_implicit_batch) {
    dims = cuda_engine->getBindingDimensions(binding_index);
  } else {
    // Get dims from the context instead of the engine in explicit batch mode
    // because the engine might have dynamic shapes.
    dims = execution_context->getBindingDimensions(binding_index);
    if (dims.nbDims == -1) {
      // Invalid dimensions. There can be multiple reasons for this, e.g.
      // incompatible input shapes (the network is invalid for the current
      // profile).
      return errors::Internal(
          "Binding index out of range. This can happen if profile is not set, "
          "or the network is invalid for the current profile.");
    }
  }
  TF_RETURN_IF_ERROR(TrtDimsToTensorShape(
      dims, &shape,
      use_implicit_batch ? absl::optional<int>(batch_size) : absl::nullopt));
  return Status::OK();
}

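// Looks up the engine binding index of `tensor_name` for the given
// optimization profile and stores it in `binding_index`. Returns NotFound if
// the engine has no binding with that name.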
Status GetTrtBindingIndex(const char* tensor_name, int profile_index,
                          const nvinfer1::ICudaEngine* cuda_engine,
                          int* binding_index) {
  // If the engine has been built for K profiles, the first getNbBindings() / K
  // bindings are used by profile number 0, the following getNbBindings() / K
  // bindings are used by profile number 1, etc.
  //
  // getBindingIndex(tensor_name) returns the binding index for profile 0. We
  // can also consider it as a "binding_index_within_profile".
  *binding_index = cuda_engine->getBindingIndex(tensor_name);
  if (*binding_index == -1) {
    const string msg = StrCat("Input node ", tensor_name, " not found");
    LOG(ERROR) << msg;
    return errors::NotFound(msg);
  }
  int n_profiles = cuda_engine->getNbOptimizationProfiles();
  // If we have more than one optimization profile, then we need to shift the
  // binding index according to the following formula:
  // binding_index_within_engine = binding_index_within_profile +
  //                               profile_index * bindings_per_profile
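  // For example, with 8 bindings and 2 profiles there are 4 bindings per
  // profile, so the binding at index 1 within profile 1 maps to engine
  // binding index 1 + 1 * 4 = 5.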
  const int bindings_per_profile = cuda_engine->getNbBindings() / n_profiles;
  *binding_index = *binding_index + profile_index * bindings_per_profile;
  return Status::OK();
}

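// Binds the TF input tensors to the TRT engine input bindings in `buffers`
// and, in explicit batch mode, sets the runtime input shapes on the execution
// context. Inputs are taken from `ctx` if it is non-null, otherwise from
// `input_vec`.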
Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine,
                          nvinfer1::IExecutionContext* execution_context,
                          const int trt_profile_idx,
                          std::vector<void*>& buffers, bool use_implicit_batch,
                          int num_batch,
                          const TrtShapeOptimizationProfile& profiles,
                          OpKernelContext* ctx, const DataVec* input_vec) {
  int n_inputs = ctx ? ctx->num_inputs() : (input_vec ? input_vec->size() : 0);
  // Set up engine inputs.
  for (int i = 0; i < n_inputs; i++) {
    const string input_name =
        ctx ? StrCat(IONamePrefixes::kInputPHName, i) : input_vec->at(i).name;
    int binding_index;
    TF_RETURN_IF_ERROR(GetTrtBindingIndex(input_name.c_str(), trt_profile_idx,
                                          cuda_engine, &binding_index));
    const Tensor& input_tensor = ctx ? ctx->input(i) : input_vec->at(i).tensor;
    const TensorShape& input_shape = input_tensor.shape();

    if (use_implicit_batch && ctx) {
      // Ensure that all inputs have the same batch size.
      if (num_batch != input_shape.dim_size(0)) {
        const string msg =
            StrCat("Input data has inconsistent batch size: ", num_batch,
                   " vs ", input_shape.dim_size(0));
        return errors::NotFound(msg);
      }
    }
    // Set known input dimensions. This is necessary because the TRT network
    // could have been built with dynamic dimensions.
    if (!use_implicit_batch) {
      TF_RETURN_IF_ERROR(profiles.SetInputShapeBinding(
          i, binding_index, cuda_engine, execution_context));

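      // Shape-tensor values are handled by SetInputShapeBinding above; for
      // bindings that carry tensor data (execution bindings), the runtime
      // dimensions still need to be set on the execution context.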
      if (cuda_engine->isExecutionBinding(binding_index)) {
        nvinfer1::Dims trt_dims;
        TF_RETURN_IF_ERROR(TensorShapeToTrtDims(input_shape, false, &trt_dims));
        VLOG(2) << "Setting binding dimensions for idx " << binding_index;
        bool ret =
            execution_context->setBindingDimensions(binding_index, trt_dims);
        if (!ret) {
          VLOG(2) << "Error setting engine input " << binding_index << " "
                  << DebugString(trt_dims);
          return errors::Internal(
              "Binding dimension does not fit selected profile.");
        }
      }
    }
    // Set up the input binding: point the binding slot at the TF tensor's
    // data buffer, according to the binding's TRT data type.
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(input_tensor.flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        buffers[binding_index] =
            const_cast<Eigen::half*>(input_tensor.flat<Eigen::half>().data());
        break;
      case nvinfer1::DataType::kINT8:
        return errors::Internal("INT8 inputs are not supported yet!");
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(input_tensor.flat<int32>().data());
        break;
      default:
        return errors::Internal("Unknown TRT data type: ",
                                static_cast<int>(dtype));
    }
  }

  // Ensure that all network dynamic dimensions (if any) are set in the
  // execution context.
  if (!execution_context->allInputDimensionsSpecified()) {
    return errors::Internal(
        "Failed to set dimensions for all dynamic input tensors.");
  }
  if (!execution_context->allInputShapesSpecified()) {
    return errors::Internal(
        "Failed to set dimensions for all shape input tensors.");
  }
  return Status::OK();
}

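// Binds the TF output tensors to the TRT engine output bindings in `buffers`,
// allocating (or reshaping) the output tensors according to the binding
// shapes reported by the engine/context. Outputs go to `ctx` if it is
// non-null, otherwise to `outputs`.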
Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine,
                           nvinfer1::IExecutionContext* execution_context,
                           int trt_profile_idx, std::vector<void*>& buffers,
                           bool use_implicit_batch, int batch_size,
                           OpKernelContext* ctx, DataVec* outputs) {
  // Exactly one of ctx or outputs should be specified.
  int n_outputs = ctx ? ctx->num_outputs() : (outputs ? outputs->size() : 0);
  for (int i = 0; i < n_outputs; i++) {
    const string output_name =
        ctx ? StrCat(IONamePrefixes::kOutputPHName, i) : outputs->at(i).name;
    int binding_index;
    TF_RETURN_IF_ERROR(GetTrtBindingIndex(output_name.c_str(), trt_profile_idx,
                                          cuda_engine, &binding_index));

    // Get TRT output shapes for allocating output memory.
    TensorShape output_shape;
    TF_RETURN_IF_ERROR(GetTrtBindingShape(cuda_engine, execution_context,
                                          binding_index, use_implicit_batch,
                                          batch_size, output_shape));

    // Allocate output tensor of TRTEngineOp.
    Tensor* output_tensor = nullptr;
    if (ctx) {
      TF_RETURN_IF_ERROR(ctx->allocate_output(i, output_shape, &output_tensor));
    } else {
      // This path is used for unit tests. The tensor is already allocated,
      // but its shape is not necessarily set correctly, so we fix that here.
      VLOG(2) << "Applying shape " << output_shape.DebugString()
              << " on output.";
      output_tensor = &(outputs->at(i).tensor);
      bool status = output_tensor->CopyFrom(*output_tensor, output_shape);
      if (!status) {
        return errors::Internal(
            "Buffer size (", output_tensor->NumElements(),
            ") does not match while reshaping output tensors to shape ",
            output_shape.DebugString());
      }
    }

    // Set up the output binding: point the binding slot at the TF output
    // tensor's data buffer, according to the binding's TRT data type.
    auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            const_cast<float*>(output_tensor->flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        buffers[binding_index] =
            const_cast<Eigen::half*>(output_tensor->flat<Eigen::half>().data());
        break;
      case nvinfer1::DataType::kINT8:
        return errors::Internal("INT8 outputs are not supported yet!");
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            const_cast<int32*>(output_tensor->flat<int32>().data());
        break;
      default:
        return errors::Internal("Unknown TRT data type: ",
                                static_cast<int>(dtype));
    }
  }
  return Status::OK();
}

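// Enqueues the engine for asynchronous execution on the given CUDA stream,
// using the already populated binding buffers. Uses enqueue() in implicit
// batch mode and enqueueV2() in explicit batch mode.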
Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context,
                  std::vector<void*>& buffers, cudaStream_t stream,
                  bool use_implicit_batch, int batch_size) {
  bool ret = false;
  if (use_implicit_batch) {
    ret = execution_context->enqueue(batch_size, &buffers[0], stream, nullptr);
    VLOG(1) << "Called IExecutionContext::enqueue";
  } else {
    ret = execution_context->enqueueV2(&buffers[0], stream, nullptr);
    VLOG(1) << "Called IExecutionContext::enqueueV2";
  }
  if (!ret) {
    return errors::Internal("Failed to enqueue batch for TRT engine");
  }
  // Synchronization will be done by TF.
  return Status::OK();
}

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_CUDA && GOOGLE_TENSORRT