• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/extendrt/delegate/tensorrt/tensorrt_serializer.h"
18 #include "src/extendrt/delegate/tensorrt/tensorrt_runtime.h"
19 #include "src/common/file_utils.h"
20 
21 namespace mindspore::lite {
GetSerializedEngine()22 nvinfer1::ICudaEngine *TensorRTSerializer::GetSerializedEngine() {
23   if (serialize_file_path_.empty()) {
24     return nullptr;
25   }
26   char *trt_model_stream{nullptr};
27   size_t size{0};
28   trt_model_stream = ReadFile(serialize_file_path_.c_str(), &size);
29   if (trt_model_stream == nullptr || size == 0) {
30     MS_LOG(WARNING) << "read engine file failed : " << serialize_file_path_;
31     return nullptr;
32   }
33   nvinfer1::IRuntime *runtime = nvinfer1::createInferRuntime(logger_);
34   if (runtime == nullptr) {
35     delete[] trt_model_stream;
36     MS_LOG(ERROR) << "createInferRuntime failed.";
37     return nullptr;
38   }
39   nvinfer1::ICudaEngine *engine = runtime->deserializeCudaEngine(trt_model_stream, size, nullptr);
40   delete[] trt_model_stream;
41   runtime->destroy();
42   return engine;
43 }
SaveSerializedEngine(nvinfer1::ICudaEngine * engine)44 void TensorRTSerializer::SaveSerializedEngine(nvinfer1::ICudaEngine *engine) {
45   if (serialize_file_path_.size() == 0) {
46     return;
47   }
48   nvinfer1::IHostMemory *ptr = engine->serialize();
49   if (ptr == nullptr) {
50     MS_LOG(ERROR) << "serialize engine failed";
51     return;
52   }
53 
54   int ret = WriteToBin(serialize_file_path_, ptr->data(), ptr->size());
55   if (ret != RET_OK) {
56     MS_LOG(ERROR) << "save engine failed " << serialize_file_path_;
57   } else {
58     MS_LOG(INFO) << "save engine to " << serialize_file_path_;
59   }
60   ptr->destroy();
61   return;
62 }
63 }  // namespace mindspore::lite
64