1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "include/api/serialization.h"
#include <dirent.h>
#include <sys/stat.h>
#include <fstream>
#include <sstream>
#include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h"
#include "extendrt/cxx_api/graph/graph_data.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "extendrt/cxx_api/dlutils.h"
#endif
#include "utils/crypto.h"
#include "extendrt/cxx_api/file_utils.h"
28
29 namespace mindspore {
RealPath(const std::string & file,std::string * realpath_str)30 static Status RealPath(const std::string &file, std::string *realpath_str) {
31 MS_EXCEPTION_IF_NULL(realpath_str);
32 char real_path_mem[PATH_MAX] = {0};
33 #if defined(_WIN32) || defined(_WIN64)
34 auto real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX);
35 #else
36 auto real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
37 #endif
38 if (real_path_ret == nullptr) {
39 return Status(kMEInvalidInput, "File: " + file + " does not exist.");
40 }
41 *realpath_str = real_path_mem;
42 return kSuccess;
43 }
44
ReadFile(const std::string & file)45 Buffer ReadFile(const std::string &file) {
46 Buffer buffer;
47 if (file.empty()) {
48 MS_LOG(ERROR) << "Pointer file is nullptr";
49 return buffer;
50 }
51
52 std::string real_path;
53 auto status = RealPath(file, &real_path);
54 if (status != kSuccess) {
55 MS_LOG(ERROR) << status.GetErrDescription();
56 return buffer;
57 }
58
59 std::ifstream ifs(real_path);
60 if (!ifs.good()) {
61 MS_LOG(ERROR) << "File: " << real_path << " does not exist";
62 return buffer;
63 }
64
65 if (!ifs.is_open()) {
66 MS_LOG(ERROR) << "File: " << real_path << " open failed";
67 return buffer;
68 }
69
70 (void)ifs.seekg(0, std::ios::end);
71 auto tellg_size = ifs.tellg();
72 if (tellg_size < 0) {
73 MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
74 ifs.close();
75 return buffer;
76 }
77 size_t size = static_cast<size_t>(tellg_size);
78 buffer.ResizeData(size);
79 if (buffer.DataSize() != size) {
80 MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
81 ifs.close();
82 return buffer;
83 }
84
85 (void)ifs.seekg(0, std::ios::beg);
86 (void)ifs.read(reinterpret_cast<char *>(buffer.MutableData()), static_cast<std::streamsize>(size));
87 ifs.close();
88
89 return buffer;
90 }
91
// Returns the names (not full paths) of the regular files directly inside `dir`.
// Subdirectories are not descended into, and symbolic links are not reported.
// Returns an empty vector when the directory cannot be opened.
std::vector<std::string> ReadFileNames(const std::string &dir) {
  DIR *dp = opendir(dir.c_str());
  if (dp == nullptr) {
    return {};
  }
  std::vector<std::string> files;
  for (struct dirent *item = readdir(dp); item != nullptr; item = readdir(dp)) {
    bool is_regular = (item->d_type == DT_REG);
    // Some filesystems do not fill in d_type and report DT_UNKNOWN; readdir(3) requires callers to
    // handle that. Fall back to lstat() (not stat()) so symlinks stay excluded, just as with DT_REG.
    if (item->d_type == DT_UNKNOWN) {
      const std::string full_path = dir + "/" + item->d_name;
      struct stat st {};
      is_regular = (lstat(full_path.c_str(), &st) == 0) && S_ISREG(st.st_mode);
    }
    if (is_regular) {
      files.emplace_back(item->d_name);
    }
  }
  (void)closedir(dp);
  return files;
}
110
Key(const char * dec_key,size_t key_len)111 Key::Key(const char *dec_key, size_t key_len) {
112 len = 0;
113 if (key_len >= max_key_len) {
114 MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
115 return;
116 }
117
118 auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
119 if (sec_ret != EOK) {
120 MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
121 return;
122 }
123
124 len = key_len;
125 }
126
Load(const void * model_data,size_t data_size,ModelType model_type,Graph * graph,const Key & dec_key,const std::vector<char> & dec_mode)127 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
128 const Key &dec_key, const std::vector<char> &dec_mode) {
129 std::stringstream err_msg;
130 if (graph == nullptr) {
131 err_msg << "Output args graph is nullptr.";
132 MS_LOG(ERROR) << err_msg.str();
133 return Status(kMEInvalidInput, err_msg.str());
134 }
135 if (model_type == kMindIR) {
136 FuncGraphPtr anf_graph = nullptr;
137 try {
138 if (dec_key.len > dec_key.max_key_len) {
139 err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
140 MS_LOG(ERROR) << err_msg.str();
141 return Status(kMEInvalidInput, err_msg.str());
142 } else if (dec_key.len == 0) {
143 if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
144 err_msg << "Load model failed. The model_data may be encrypted, please pass in correct key.";
145 MS_LOG(ERROR) << err_msg.str();
146 return Status(kMEInvalidInput, err_msg.str());
147 } else {
148 anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size, true);
149 }
150 } else {
151 size_t plain_data_size;
152 auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data),
153 data_size, dec_key.key, dec_key.len, CharToString(dec_mode));
154 if (plain_data == nullptr) {
155 err_msg << "Load model failed. Please check the valid of dec_key and dec_mode.";
156 MS_LOG(ERROR) << err_msg.str();
157 return Status(kMEInvalidInput, err_msg.str());
158 }
159 anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size, true);
160 }
161 } catch (const std::exception &e) {
162 err_msg << "Load model failed. Please check the valid of dec_key and dec_mode." << e.what();
163 MS_LOG(ERROR) << err_msg.str();
164 return Status(kMEInvalidInput, err_msg.str());
165 }
166
167 *graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
168 return kSuccess;
169 } else if (model_type == kOM) {
170 *graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
171 return kSuccess;
172 }
173
174 err_msg << "Unsupported ModelType " << model_type;
175 MS_LOG(ERROR) << err_msg.str();
176 return Status(kMEInvalidInput, err_msg.str());
177 }
178
Load(const std::vector<char> & file,ModelType model_type,Graph * graph)179 Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
180 return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm));
181 }
182
// Loads a single model file into *graph.
//   file       - path to the model file; canonicalized with RealPath before use.
//   model_type - kMindIR or kOM; any other value yields kMEInvalidInput.
//   graph      - output; must not be nullptr. Only assigned on success.
//   dec_key    - MindIR decryption key; len == 0 means no key was supplied.
//   dec_mode   - MindIR decryption mode, forwarded to MindIRLoader.
// Returns kSuccess; the RealPath error status if the path cannot be resolved; kMEInvalidInput for the
// other checks below. Preprocess-loading failures return through CHECK_FAIL_AND_RELEASE
// (NOTE(review): presumably returns dlret after releasing the handle — confirm in dlutils.h).
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
                           const std::vector<char> &dec_mode) {
  std::stringstream err_msg;
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Output args graph is nullptr.";
    return Status(kMEInvalidInput, "Output args graph is nullptr.");
  }

  std::string file_path;
  auto status = RealPath(CharToString(file), &file_path);
  if (status != kSuccess) {
    MS_LOG(ERROR) << status.GetErrDescription();
    return status;
  }

  if (model_type == kMindIR) {
    FuncGraphPtr anf_graph;
    if (dec_key.len > dec_key.max_key_len) {
      err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
      MS_LOG(ERROR) << err_msg.str();
      return Status(kMEInvalidInput, err_msg.str());
    } else if (dec_key.len == 0 && IsCipherFile(file_path)) {
      // No key was given but the file looks encrypted: fail early with a clear message.
      err_msg << "Load model failed. The file may be encrypted, please pass in correct key.";
      MS_LOG(ERROR) << err_msg.str();
      return Status(kMEInvalidInput, err_msg.str());
    }
    // A null key pointer is passed when no key was supplied.
    // NOTE(review): the meaning of the leading `true` and trailing `false` flags is defined by
    // MindIRLoader (the multi-file overload passes `true` as the last flag) — confirm in load_model.h.
    MindIRLoader mindir_loader(true, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode),
                               false);
    anf_graph = mindir_loader.LoadMindIR(file_path);
    if (anf_graph == nullptr) {
      err_msg << "Load model failed.";
      MS_LOG(ERROR) << err_msg.str();
      return Status(kMEInvalidInput, err_msg.str());
    }
    auto graph_data = std::make_shared<Graph::GraphData>(anf_graph, kMindIR);
#if !defined(_WIN32) && !defined(_WIN64)
    // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
    // The dataengine library is only located and opened when the model actually carries a preprocessor.
    std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(file_path);
    if (!preprocessor.empty()) {
      std::string dataengine_so_path;
      Status dlret = DLSoPath({"libmindspore.so"}, "_c_dataengine", &dataengine_so_path);
      CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());

      void *handle = nullptr;
      void *function = nullptr;
      dlret = DLSoOpen(dataengine_so_path, "ParseMindIRPreprocess_C", &handle, &function);
      CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ParseMindIRPreprocess_C failed: " + dlret.GetErrDescription());
      // Reinterpret the resolved symbol as the parser's C entry point.
      auto ParseMindIRPreprocessFun =
        (void (*)(const std::vector<std::string> &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
                  Status *))(function);

      std::vector<std::shared_ptr<dataset::Execute>> data_graph;
      // The parser reports its result through dlret (out-parameter).
      ParseMindIRPreprocessFun(preprocessor, &data_graph, &dlret);
      CHECK_FAIL_AND_RELEASE(dlret, handle, "Load preprocess failed: " + dlret.GetErrDescription());
      DLSoClose(handle);
      if (!data_graph.empty()) {
        graph_data->SetPreprocess(data_graph);
      }
    }
#endif
    *graph = Graph(graph_data);
    return kSuccess;
  } else if (model_type == kOM) {
    // OM models are kept as raw bytes; ReadFile signals failure with a Buffer whose Data() is nullptr.
    Buffer data = ReadFile(file_path);
    if (data.Data() == nullptr) {
      err_msg << "Read file " << file_path << " failed.";
      MS_LOG(ERROR) << err_msg.str();
      return Status(kMEInvalidInput, err_msg.str());
    }
    *graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
    return kSuccess;
  }

  err_msg << "Unsupported ModelType " << model_type;
  MS_LOG(ERROR) << err_msg.str();
  return Status(kMEInvalidInput, err_msg.str());
}
260
Load(const std::vector<std::vector<char>> & files,ModelType model_type,std::vector<Graph> * graphs,const Key & dec_key,const std::vector<char> & dec_mode)261 Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
262 std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
263 std::stringstream err_msg;
264 if (graphs == nullptr) {
265 MS_LOG(ERROR) << "Output args graph is nullptr.";
266 return Status(kMEInvalidInput, "Output args graph is nullptr.");
267 }
268
269 if (files.size() == 1) {
270 std::vector<Graph> result(files.size());
271 auto ret = Load(files[0], model_type, &result[0], dec_key, dec_mode);
272 *graphs = std::move(result);
273 return ret;
274 }
275
276 std::vector<std::string> files_path;
277 for (const auto &file : files) {
278 std::string file_path;
279 auto status = RealPath(CharToString(file), &file_path);
280 if (status != kSuccess) {
281 MS_LOG(ERROR) << status.GetErrDescription();
282 return status;
283 }
284 files_path.emplace_back(std::move(file_path));
285 }
286
287 if (model_type == kMindIR) {
288 if (dec_key.len > dec_key.max_key_len) {
289 err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
290 MS_LOG(ERROR) << err_msg.str();
291 return Status(kMEInvalidInput, err_msg.str());
292 }
293 MindIRLoader mindir_loader(true, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode),
294 true);
295 auto anf_graphs = mindir_loader.LoadMindIRs(files_path);
296 if (anf_graphs.size() != files_path.size()) {
297 err_msg << "Load model failed, " << files_path.size() << " files got " << anf_graphs.size() << " graphs.";
298 MS_LOG(ERROR) << err_msg.str();
299 return Status(kMEInvalidInput, err_msg.str());
300 }
301 #if !defined(_WIN32) && !defined(_WIN64)
302 // Dataset so loading
303 std::string dataengine_so_path;
304 Status dlret = DLSoPath({"libmindspore.so"}, "_c_dataengine", &dataengine_so_path);
305 CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());
306
307 void *handle = nullptr;
308 void *function = nullptr;
309 dlret = DLSoOpen(dataengine_so_path, "ParseMindIRPreprocess_C", &handle, &function);
310 CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ParseMindIRPreprocess_C failed: " + dlret.GetErrDescription());
311
312 auto ParseMindIRPreprocessFun =
313 (void (*)(const std::vector<std::string> &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
314 Status *))(function);
315 #endif
316 std::vector<Graph> results;
317 for (size_t i = 0; i < anf_graphs.size(); ++i) {
318 if (anf_graphs[i] == nullptr) {
319 if (dec_key.len == 0 && IsCipherFile(files_path[i])) {
320 err_msg << "Load model failed. The file " << files_path[i] << " be encrypted, please pass in correct key.";
321 } else {
322 err_msg << "Load model " << files_path[i] << " failed.";
323 }
324 MS_LOG(ERROR) << err_msg.str();
325 return Status(kMEInvalidInput, err_msg.str());
326 }
327 auto graph_data = std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR);
328 #if !defined(_WIN32) && !defined(_WIN64)
329 // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
330 std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(files_path[i]);
331 if (!preprocessor.empty()) {
332 std::vector<std::shared_ptr<dataset::Execute>> data_graph;
333 ParseMindIRPreprocessFun(preprocessor, &data_graph, &dlret);
334 CHECK_FAIL_AND_RELEASE(dlret, handle, "Load preprocess failed: " + dlret.GetErrDescription());
335 if (!data_graph.empty()) {
336 graph_data->SetPreprocess(data_graph);
337 }
338 }
339 #endif
340 results.emplace_back(graph_data);
341 }
342 #if !defined(_WIN32) && !defined(_WIN64)
343 // Dataset so release
344 DLSoClose(handle);
345 #endif
346 *graphs = std::move(results);
347 return kSuccess;
348 }
349
350 err_msg << "Unsupported ModelType " << model_type;
351 MS_LOG(ERROR) << err_msg.str();
352 return Status(kMEInvalidInput, err_msg.str());
353 }
354
SetParameters(const std::map<std::vector<char>,Buffer> &,Model *)355 Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &, Model *) {
356 MS_LOG(ERROR) << "Unsupported feature.";
357 return kMEFailed;
358 }
359
ExportModel(const Model &,ModelType,Buffer *,QuantizationType,bool,const std::vector<std::vector<char>> &)360 Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
361 const std::vector<std::vector<char>> & /* output_tensor_name */) {
362 MS_LOG(ERROR) << "Unsupported feature.";
363 return kMEFailed;
364 }
365
ExportModel(const Model &,ModelType,const std::vector<char> &,QuantizationType,bool,const std::vector<std::vector<char>> & output_tensor_name)366 Status Serialization::ExportModel(const Model &, ModelType, const std::vector<char> &, QuantizationType, bool,
367 const std::vector<std::vector<char>> &output_tensor_name) {
368 MS_LOG(ERROR) << "Unsupported feature.";
369 return kMEFailed;
370 }
371
ExportWeightsCollaborateWithMicro(const Model &,ModelType,const std::vector<char> &,bool,bool,const std::vector<std::vector<char>> &)372 Status Serialization::ExportWeightsCollaborateWithMicro(const Model &, ModelType, const std::vector<char> &, bool, bool,
373 const std::vector<std::vector<char>> &) {
374 MS_LOG(ERROR) << "Unsupported feature.";
375 return kMEFailed;
376 }
377 } // namespace mindspore
378