/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/delegate/tensorrt/tensorrt_subgraph.h"
#include <cuda_runtime_api.h>
#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
#include <set>
#include "src/delegate/delegate_utils.h"

namespace mindspore::lite {
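// Release all TensorRT objects and delegate ops owned by this subgraph.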
TensorRTSubGraph::~TensorRTSubGraph() {
  if (network_ != nullptr) {
    network_->destroy();
    network_ = nullptr;
  }
  if (config_ != nullptr) {
    config_->destroy();
    config_ = nullptr;
  }
  if (trt_context_ != nullptr) {
    trt_context_->destroy();
    trt_context_ = nullptr;
  }
  if (engine_ != nullptr) {
    engine_->destroy();
    engine_ = nullptr;
  }
  if (tensor_bindings_ != nullptr) {
    delete[] tensor_bindings_;
    tensor_bindings_ = nullptr;
  }
  for (auto op : all_ops_) {
    delete op;
  }
}

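// Collect the subgraph's boundary ops and create the TensorRT network definition
// (explicit batch) plus an optimization profile for dynamic shapes.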
int TensorRTSubGraph::Init() {
  auto ret = GetGraphInOutOps(inputs_, outputs_, &in_ops_, &out_ops_, all_ops_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Get TensorRT subgraph input and output ops failed.";
    return RET_ERROR;
  }
  this->network_ = runtime_->GetBuilder()->createNetworkV2(
    1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
  if (this->network_ == nullptr) {
    MS_LOG(ERROR) << "New network failed.";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    if (inputs_[i].Shape().size() != DIMENSION_4D) {
      MS_LOG(WARNING) << "HW dims resize is unsupported for non-4D input.";
      input_hw_index_ = -1;
    }
  }
  if (SetDeviceConfig() != RET_OK) {
    MS_LOG(WARNING) << "set tensorrt config failed.";
  }
  profile_ = runtime_->GetBuilder()->createOptimizationProfile();
  return RET_OK;
}

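// Attach the optimization profile to the builder config and build the CUDA engine
// from the populated network definition.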
int TensorRTSubGraph::BuildEngine() {
  if (this->config_->addOptimizationProfile(profile_) == -1) {
    MS_LOG(ERROR) << "addOptimizationProfile failed.";
    return RET_ERROR;
  }
  MS_LOG(INFO) << "build engine for tensorrt network: " << this->network_->getName();
  // print all network ops
  for (int i = 0; i < this->network_->getNbLayers(); i++) {
    MS_LOG(DEBUG) << "tensorrt op: " << this->network_->getLayer(i)->getName();
  }
  MS_LOG(DEBUG) << "end of tensorrt network: " << this->network_->getName();

  this->engine_ = runtime_->GetBuilder()->buildEngineWithConfig(*this->network_, *this->config_);
  if (this->engine_ == nullptr) {
    MS_LOG(ERROR) << "Create engine failed in TensorRT network";
    return RET_ERROR;
  }
  return RET_OK;
}

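// Create the builder config: enable FP16 when both the user and the platform allow it,
// and cap the builder workspace size.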
int TensorRTSubGraph::SetDeviceConfig() {
  this->config_ = runtime_->GetBuilder()->createBuilderConfig();
  if (this->config_ == nullptr) {
    MS_LOG(ERROR) << "create builder config failed.";
    return RET_ERROR;
  }
  // set fp16
  if (device_info_->GetEnableFP16() && runtime_->GetBuilder()->platformHasFastFp16()) {
    MS_LOG(INFO) << "set fp16 flag successfully for tensorrt.";
    config_->setFlag(nvinfer1::BuilderFlag::kFP16);
    input_hw_index_ = -1;
  }

  // limit the max builder workspace size to 32 MB
  config_->setMaxWorkspaceSize(32 * (1 << 20));
  return RET_OK;
}

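// Check whether any visible CUDA device has a compute capability in the whitelist of
// versions known to support native FP16.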
bool TensorRTSubGraph::SupportFP16() {
  int deviceCnt = 0;

  cudaError ret = cudaGetDeviceCount(&deviceCnt);
  if (ret != cudaSuccess) {
    MS_LOG(ERROR) << "cudaGetDeviceCount failed.";
    return false;
  }
  std::vector<std::string> supportFP16_versions{"5.3", "6.0", "6.2", "7.0", "7.2", "7.5", "8.0", "8.6"};
  cudaDeviceProp prop;
  std::string version;
  for (int dev = 0; dev < deviceCnt; dev++) {
    ret = cudaGetDeviceProperties(&prop, dev);
    if (ret != cudaSuccess) {
      MS_LOG(ERROR) << "cudaGetDeviceProperties failed.";
      return false;
    }
    version = std::to_string(prop.major) + "." + std::to_string(prop.minor);
    if (std::find(supportFP16_versions.begin(), supportFP16_versions.end(), version) != supportFP16_versions.end()) {
      MS_LOG(INFO) << "cuda device version is: " << version << ", supports FP16, set enable FP16 tag successful";
      return true;
    }
  }
  MS_LOG(WARNING) << "cuda device version is: " << version << ", does not support FP16, set enable FP16 tag failed";
  return false;
}

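// Register a subgraph input with the TensorRT network: mark the batch dim (and, when
// HW resize is enabled, the H/W dims) as dynamic, and record min/opt/max shapes in the
// optimization profile.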
nvinfer1::ITensor *TensorRTSubGraph::SetTensorRTNetworkInput(const mindspore::MSTensor &in_tensor) {
  for (int i = 0; i < this->network_->getNbInputs(); i++) {
    if (in_tensor.Name().compare(this->network_->getInput(i)->getName()) == 0) {
      MS_LOG(INFO) << "input tensor is already added in network: " << in_tensor.Name();
      return this->network_->getInput(i);
    }
  }

  auto cuda_dtype = ConvertDataType(in_tensor.DataType());
  if (static_cast<int>(cuda_dtype) == -1) {
    MS_LOG(ERROR) << "Unsupported input data type " << static_cast<int>(in_tensor.DataType());
    return nullptr;
  }
  nvinfer1::Dims input_dims = ConvertCudaDims(in_tensor.Shape());
  if (runtime_->GetBatchSize() == 0) {
    runtime_->SetBatchSize(input_dims.d[0]);
    MS_LOG(INFO) << "batch size init as " << runtime_->GetBatchSize();
    input_dims.d[0] = -1;  // dynamic batch size with wildcard N; the first dim is treated as the batch by default
    input_batchsize_index_ = 0;
  } else {
    for (int n = 0; n < input_dims.nbDims; n++) {
      if (input_dims.d[n] == runtime_->GetBatchSize()) {
        // this dim matches the batch size, mark it dynamic
        input_dims.d[n] = -1;
        input_batchsize_index_ = n;
        break;
      }
    }
  }
  // mark H/W dims dynamic when HW resize is enabled (only 4D inputs)
  if (input_hw_index_ != -1) {
    MS_LOG(INFO) << "input tensor format is (NHWC:1, NCHW:0): " << in_tensor.format();
    input_hw_index_ = in_tensor.format() == Format::NHWC ? 1 : 2;  // H index is 2 for NCHW
    input_dims.d[input_hw_index_] = -1;
    input_dims.d[input_hw_index_ + 1] = -1;
  }
  // register min/opt/max shapes of the dynamic dims in the optimization profile
  nvinfer1::Dims input_dims_min = ConvertCudaDims(in_tensor.Shape());
  input_dims_min.d[input_batchsize_index_] = 1;
  if (input_hw_index_ != -1) {
    input_dims_min.d[input_hw_index_] = 1;
    input_dims_min.d[input_hw_index_ + 1] = 1;
  }
  if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMIN, input_dims_min)) {
    MS_LOG(ERROR) << "setDimensions of kMIN failed.";
    return nullptr;
  }
  nvinfer1::Dims input_dims_opt = ConvertCudaDims(in_tensor.Shape());
  if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kOPT, input_dims_opt)) {
    MS_LOG(ERROR) << "setDimensions of kOPT failed.";
    return nullptr;
  }
  nvinfer1::Dims input_dims_max = ConvertCudaDims(in_tensor.Shape());
  // input_dims_max should be the same as the network input dims
  if (!profile_->setDimensions(in_tensor.Name().c_str(), nvinfer1::OptProfileSelector::kMAX, input_dims_max)) {
    MS_LOG(ERROR) << "setDimensions of kMAX failed.";
    return nullptr;
  }

  return this->network_->addInput(in_tensor.Name().c_str(), cuda_dtype, input_dims);
}

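// Translate every delegate op into TensorRT layers: wire up subgraph inputs, constant
// (weight) tensors and inter-op tensors, then mark outputs and build the engine.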
int TensorRTSubGraph::BuildTensorRTGraph() {
  MS_ASSERT(!all_ops_.empty());
  // Connect the network.
  int ret;
  for (auto cur_op : all_ops_) {
    for (auto in_tensor : cur_op->inputs()) {
      // data from CPU: subgraph input tensor
      if (IsSubGraphInputTensor(this->inputs(), in_tensor)) {
        nvinfer1::ITensor *trt_tensor = SetTensorRTNetworkInput(in_tensor);
        if (trt_tensor == nullptr) {
          MS_LOG(ERROR) << "SetTensorRTNetworkInput failed for " << in_tensor.Name();
          return RET_ERROR;
        }
        cur_op->AddInnerInTensors(ITensorHelper{trt_tensor, in_tensor.format()});
        continue;
      }

      ITensorHelper trt_tensor = FindTensorRTInputs(cur_op, in_tensor);
      if (trt_tensor.trt_tensor_ == nullptr) {
        // weight tensor
        if (trt_specific_weight_nodes_.find(cur_op->type()) == trt_specific_weight_nodes_.end()) {
          if (in_tensor.Data() == nullptr) {
            MS_LOG(ERROR) << "Weight Tensor data is nullptr.";
            return RET_ERROR;
          }
          trt_tensor.trt_tensor_ = lite::ConvertConstantTensor(this->network_, in_tensor);
          trt_tensor.format_ = Format::NHWC;
          MS_LOG(INFO) << "auto convert constant tensor for: " << in_tensor.Name();
          cur_op->AddInnerInTensors(trt_tensor);
        }
      } else {
        cur_op->AddInnerInTensors(trt_tensor);
      }
    }

    ret = cur_op->AddInnerOp(this->network_);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Add op failed in TensorRT network";
      return RET_ERROR;
    }
  }

  ret = MarkOutputs();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "MarkOutputs failed in TensorRT network";
    return ret;
  }
  std::string network_name =
    "network_" + std::string(network_->getInput(0)->getName()) + "_" + std::string(network_->getOutput(0)->getName());
  network_->setName(network_name.c_str());
  this->name_ = network_name;

  ret = BuildEngine();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Create engine failed in TensorRT network";
    return ret;
  }
  return RET_OK;
}

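// Mark subgraph output tensors on the TensorRT network, inserting an NCHW->NHWC
// transpose when the inner layout differs from the expected output shape, and record
// which output dim carries the dynamic batch.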
int TensorRTSubGraph::MarkOutputs() {
  // mark network output tensors
  for (auto out_tensor : outputs_) {
    for (auto out_op : this->out_ops_) {
      for (size_t index = 0; index < out_op->outputs().size(); index++) {
        if (out_op->outputs()[index] == out_tensor) {
          nvinfer1::ITensor *out_trt_tensor = out_op->GetInnerOutTensor()[index].trt_tensor_;
          if (out_op->GetInnerOutTensor()[index].trt_tensor_->getDimensions().nbDims == DIMENSION_4D &&
              out_op->GetInnerOutTensor()[index].format_ == Format::NCHW &&
              !SameDims(out_op->GetInnerOutTensor()[index].trt_tensor_->getDimensions(), out_tensor.Shape())) {
            // transpose subgraph output from nchw to nhwc
            nvinfer1::IShuffleLayer *transpose_layer_out =
              NCHW2NHWC(network_, *out_op->GetInnerOutTensor()[index].trt_tensor_);
            if (transpose_layer_out == nullptr) {
              MS_LOG(ERROR) << "create NCHW to NHWC transpose layer for output failed";
              return RET_ERROR;
            }
            transpose_layer_out->setName((out_tensor.Name() + "_transpose2NHWC").c_str());
            out_trt_tensor = transpose_layer_out->getOutput(0);
          }

          out_trt_tensor->setName(out_tensor.Name().c_str());
          MS_LOG(INFO) << "markOutput for: " << out_tensor.Name();
          this->network_->markOutput(*out_trt_tensor);
          for (int n = 0; n < out_trt_tensor->getDimensions().nbDims; n++) {
            if (out_trt_tensor->getDimensions().d[n] == -1) {
              output_batchsize_index_ = n;
              break;
            }
          }
        }
      }
    }
  }
  return RET_OK;
}

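// Create the execution context, allocate device memory for every input and output
// binding, and set the concrete input binding dimensions.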
int TensorRTSubGraph::Prepare() {
  lite::SetCudaDevice(device_info_);
  if (this->engine_ == nullptr) {
    MS_LOG(ERROR) << "engine_ is null in this builder_";
    return RET_ERROR;
  }
  this->trt_context_ = this->engine_->createExecutionContext();
  if (this->trt_context_ == nullptr) {
    MS_LOG(ERROR) << "TensorRTSubGraph create context failed.";
    return RET_ERROR;
  }
  int binding_num = this->engine_->getNbBindings();
  tensor_bindings_ = new (std::nothrow) void *[binding_num];
  if (tensor_bindings_ == nullptr) {
    MS_LOG(ERROR) << "malloc tensor binding array failed.";
    return RET_ERROR;
  }

  for (auto tensor : inputs_) {
    auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor, tensor.DataSize());
    if (device_ptr == nullptr) {
      MS_LOG(ERROR) << "malloc for inputs tensor device memory failed.";
      return RET_ERROR;
    }
    int index = this->engine_->getBindingIndex(tensor.Name().c_str());
    tensor_bindings_[index] = device_ptr;
    trt_in_tensor_name_.push_back(tensor.Name());
    nvinfer1::Dims input_dims = ConvertCudaDims(tensor.Shape());
    for (int od = 0; od < input_dims.nbDims; od++) {
      MS_LOG(DEBUG) << "in tensor " << tensor.Name() << " dims at " << od << " is " << input_dims.d[od];
    }

    if (!this->trt_context_->setBindingDimensions(index, input_dims)) {
      MS_LOG(ERROR) << "invalid input dims of " << tensor.Name();
      return RET_ERROR;
    }
  }

  if (!this->trt_context_->allInputDimensionsSpecified()) {
    MS_LOG(ERROR) << "input dims need to be specified.";
    return RET_ERROR;
  }

  for (auto tensor : outputs_) {
    tensor.MutableData();  // make sure host memory for the output tensor is allocated
    auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(tensor, tensor.DataSize());
    if (device_ptr == nullptr) {
      MS_LOG(ERROR) << "malloc for outputs tensor device memory failed.";
      return RET_ERROR;
    }
    int index = this->engine_->getBindingIndex(tensor.Name().c_str());
    tensor_bindings_[index] = device_ptr;
    trt_out_tensor_name_.push_back(tensor.Name());
  }
  return RET_OK;
}

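// Validate the resized input shapes (only the batch dim, and H/W for resizable 4D
// inputs, may change), reallocate device memory and update the binding dimensions.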
int TensorRTSubGraph::ReSize() {
  for (size_t i = 0; i < trt_in_tensor_name_.size(); i++) {
    // only batch size (and optionally HW) resize is supported
    for (int j = 0; j < this->network_->getNbInputs(); j++) {
      if (std::strcmp(this->network_->getInput(j)->getName(), trt_in_tensor_name_[i].c_str()) != 0) {
        continue;
      }
      nvinfer1::Dims construct_dim = this->network_->getInput(j)->getDimensions();
      if (static_cast<size_t>(construct_dim.nbDims) != inputs_[i].Shape().size()) {
        MS_LOG(ERROR) << "invalid resize input.";
        return RET_ERROR;
      }
      if (input_hw_index_ == -1) {
        // HW resize is disabled, so only the batch dim is allowed to change
        for (int d = 0; d < construct_dim.nbDims; d++) {
          if (d != input_batchsize_index_ && construct_dim.d[d] != inputs_[i].Shape()[d]) {
            MS_LOG(ERROR) << "only support dynamic batch size resize input.";
            return RET_ERROR;
          }
        }
      } else {
        if (construct_dim.d[DIMENSION_4D - 1] != inputs_[i].Shape()[DIMENSION_4D - 1]) {
          MS_LOG(ERROR) << "don't support dynamic channel resize input.";
          return RET_ERROR;
        }
      }
    }
    MS_LOG(INFO) << "resize at input_batch_index " << input_batchsize_index_ << ", update batch size to "
                 << inputs_[i].Shape()[input_batchsize_index_];
    runtime_->SetBatchSize(inputs_[i].Shape()[input_batchsize_index_]);

    // inputs_ is duplicated by mindrt, so the tensor name is unreliable; use the recorded name instead
    auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(trt_in_tensor_name_[i], inputs_[i].DataSize(),
                                                                ConvertDataType(inputs_[i].DataType()));
    if (device_ptr == nullptr) {
      MS_LOG(ERROR) << "realloc for input tensor device memory failed.";
      return RET_ERROR;
    }
    int index = this->engine_->getBindingIndex(trt_in_tensor_name_[i].c_str());
    tensor_bindings_[index] = device_ptr;
    // set actual input size
    nvinfer1::Dims input_dims = ConvertCudaDims(inputs_[i].Shape());
    for (int od = 0; od < input_dims.nbDims; od++) {
      MS_LOG(DEBUG) << "in tensor " << trt_in_tensor_name_[i] << " dims at " << od << " is " << input_dims.d[od];
    }

    if (!this->trt_context_->setBindingDimensions(index, input_dims)) {
      MS_LOG(ERROR) << "invalid input dims of " << inputs_[i].Name();
      return RET_ERROR;
    }
  }
  if (!this->trt_context_->allInputDimensionsSpecified()) {
    MS_LOG(ERROR) << "input dims need to be specified.";
    return RET_ERROR;
  }

  for (size_t i = 0; i < trt_out_tensor_name_.size(); i++) {
    int index = this->engine_->getBindingIndex(trt_out_tensor_name_[i].c_str());
    auto device_ptr = runtime_->GetAllocator()->MallocDeviceMem(trt_out_tensor_name_[i], outputs_[i].DataSize(),
                                                                ConvertDataType(outputs_[i].DataType()));
    if (device_ptr == nullptr) {
      MS_LOG(ERROR) << "realloc for outputs tensor device memory failed.";
      return RET_ERROR;
    }
    tensor_bindings_[index] = device_ptr;
  }
  return RET_OK;
}

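// Copy input data to device memory when needed, run the TensorRT execution context,
// then refresh the output shapes and copy results back to host memory.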
int TensorRTSubGraph::Execute() {
  lite::SetCudaDevice(device_info_);
  if (runtime_->GetBatchSize() <= 0) {
    MS_LOG(ERROR) << "TensorRTSubGraph has invalid batch size.";
    return RET_ERROR;
  }
  for (size_t i = 0; i < inputs_.size(); i++) {
    if (runtime_->GetAllocator()->GetMemIsValid(trt_in_tensor_name_[i])) {
      MS_LOG(INFO) << "no need memcpy to cuda for input tensor: " << trt_in_tensor_name_[i];
      continue;
    }
    int ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(inputs_[i], trt_in_tensor_name_[i], true);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "sync mem from host to device failed for " << trt_in_tensor_name_[i];
      return ret;
    }
    runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], true);
  }

  auto ret = this->trt_context_->executeV2(tensor_bindings_);
  if (!ret) {
    MS_LOG(ERROR) << "TensorRT execute failed.";
    return RET_ERROR;
  }

  for (size_t i = 0; i < trt_out_tensor_name_.size(); i++) {
    int index = this->engine_->getBindingIndex(trt_out_tensor_name_[i].c_str());
    // actual output tensor dims
    auto out_dims = this->trt_context_->getBindingDimensions(index);
    std::vector<int64_t> new_shape = lite::ConvertMSShape(out_dims);
    // after a batch-size resize, the output batch dim follows the runtime batch size
    if (runtime_->GetBatchSize() != new_shape[output_batchsize_index_]) {
      new_shape[output_batchsize_index_] = runtime_->GetBatchSize();
    }
    for (int od = 0; od < out_dims.nbDims; od++) {
      MS_LOG(DEBUG) << "out tensor " << trt_out_tensor_name_[i] << " dims at " << od << " is " << new_shape[od];
    }
    outputs_[i].SetShape(new_shape);

    if (outputs_[i].MutableData() == nullptr) {
      MS_LOG(ERROR) << "realloc for outputs tensor failed.";
      return RET_ERROR;
    }
    runtime_->GetAllocator()->MarkMemValid(trt_out_tensor_name_[i], true);
    int sync_ret = runtime_->GetAllocator()->SyncMemInHostAndDevice(outputs_[i], trt_out_tensor_name_[i], false);
    if (sync_ret != RET_OK) {
      MS_LOG(ERROR) << "sync mem from device to host failed for " << trt_out_tensor_name_[i];
      return sync_ret;
    }
    runtime_->GetAllocator()->MarkMemValid(trt_out_tensor_name_[i], false);
  }
  // mark device mem invalid to prepare for the next execute
  for (size_t i = 0; i < inputs_.size(); i++) {
    runtime_->GetAllocator()->MarkMemValid(trt_in_tensor_name_[i], false);
  }
  return RET_OK;
}

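// Look up the ITensorHelper produced by one of cur_op's upstream ops whose output
// matches in_tensor by name; an empty helper means no upstream op produces it.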
ITensorHelper TensorRTSubGraph::FindTensorRTInputs(TensorRTOp *cur_op, const mindspore::MSTensor &in_tensor) {
  for (auto input_op : cur_op->in_ops()) {
    for (size_t i = 0; i < input_op->outputs().size(); i++) {
      auto out_tensor = input_op->outputs().at(i);
      if (in_tensor.Name().compare(out_tensor.Name()) == 0) {
        return input_op->GetInnerOutTensor().at(i);
      }
    }
  }
  return ITensorHelper{};
}
}  // namespace mindspore::lite