• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "include/api/serialization.h"
#include <dirent.h>
#include <sys/stat.h>
#include <fstream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "utils/log_adapter.h"
#include "mindspore/core/load_mindir/load_model.h"
#include "extendrt/cxx_api/graph/graph_data.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "extendrt/cxx_api/dlutils.h"
#endif
#include "utils/crypto.h"
#include "extendrt/cxx_api/file_utils.h"
28 
29 namespace mindspore {
RealPath(const std::string & file,std::string * realpath_str)30 static Status RealPath(const std::string &file, std::string *realpath_str) {
31   MS_EXCEPTION_IF_NULL(realpath_str);
32   char real_path_mem[PATH_MAX] = {0};
33 #if defined(_WIN32) || defined(_WIN64)
34   auto real_path_ret = _fullpath(real_path_mem, common::SafeCStr(file), PATH_MAX);
35 #else
36   auto real_path_ret = realpath(common::SafeCStr(file), real_path_mem);
37 #endif
38   if (real_path_ret == nullptr) {
39     return Status(kMEInvalidInput, "File: " + file + " does not exist.");
40   }
41   *realpath_str = real_path_mem;
42   return kSuccess;
43 }
44 
ReadFile(const std::string & file)45 Buffer ReadFile(const std::string &file) {
46   Buffer buffer;
47   if (file.empty()) {
48     MS_LOG(ERROR) << "Pointer file is nullptr";
49     return buffer;
50   }
51 
52   std::string real_path;
53   auto status = RealPath(file, &real_path);
54   if (status != kSuccess) {
55     MS_LOG(ERROR) << status.GetErrDescription();
56     return buffer;
57   }
58 
59   std::ifstream ifs(real_path);
60   if (!ifs.good()) {
61     MS_LOG(ERROR) << "File: " << real_path << " does not exist";
62     return buffer;
63   }
64 
65   if (!ifs.is_open()) {
66     MS_LOG(ERROR) << "File: " << real_path << " open failed";
67     return buffer;
68   }
69 
70   (void)ifs.seekg(0, std::ios::end);
71   auto tellg_size = ifs.tellg();
72   if (tellg_size < 0) {
73     MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
74     ifs.close();
75     return buffer;
76   }
77   size_t size = static_cast<size_t>(tellg_size);
78   buffer.ResizeData(size);
79   if (buffer.DataSize() != size) {
80     MS_LOG(ERROR) << "Malloc buf failed, file: " << real_path;
81     ifs.close();
82     return buffer;
83   }
84 
85   (void)ifs.seekg(0, std::ios::beg);
86   (void)ifs.read(reinterpret_cast<char *>(buffer.MutableData()), static_cast<std::streamsize>(size));
87   ifs.close();
88 
89   return buffer;
90 }
91 
/// Lists the names of the regular files directly inside `dir` (non-recursive).
/// Directories, symbolic links and other special entries are skipped.
/// @param dir  directory to scan.
/// @return file names (not full paths) in readdir order; empty if `dir` cannot be opened.
std::vector<std::string> ReadFileNames(const std::string &dir) {
  // Own the DIR handle so it is closed on every path, including when push_back throws.
  std::unique_ptr<DIR, int (*)(DIR *)> dp(opendir(dir.c_str()), &closedir);
  if (dp == nullptr) {
    return {};
  }
  std::vector<std::string> files;
  for (struct dirent *item = readdir(dp.get()); item != nullptr; item = readdir(dp.get())) {
    bool is_regular = (item->d_type == DT_REG);
    if (item->d_type == DT_UNKNOWN) {
      // Some file systems do not fill in d_type. Fall back to lstat (not stat) so that,
      // exactly like the DT_REG test, symbolic links are not reported as regular files.
      struct stat st {};
      const std::string full_path = dir + "/" + item->d_name;
      is_regular = (lstat(full_path.c_str(), &st) == 0 && S_ISREG(st.st_mode));
    }
    if (is_regular) {
      files.emplace_back(item->d_name);
    }
  }
  return files;
}
110 
Key(const char * dec_key,size_t key_len)111 Key::Key(const char *dec_key, size_t key_len) {
112   len = 0;
113   if (key_len >= max_key_len) {
114     MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
115     return;
116   }
117 
118   auto sec_ret = memcpy_s(key, max_key_len, dec_key, key_len);
119   if (sec_ret != EOK) {
120     MS_LOG(ERROR) << "memcpy_s failed, src_len = " << key_len << ", dst_len = " << max_key_len << ", ret = " << sec_ret;
121     return;
122   }
123 
124   len = key_len;
125 }
126 
Load(const void * model_data,size_t data_size,ModelType model_type,Graph * graph,const Key & dec_key,const std::vector<char> & dec_mode)127 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
128                            const Key &dec_key, const std::vector<char> &dec_mode) {
129   std::stringstream err_msg;
130   if (graph == nullptr) {
131     err_msg << "Output args graph is nullptr.";
132     MS_LOG(ERROR) << err_msg.str();
133     return Status(kMEInvalidInput, err_msg.str());
134   }
135   if (model_type == kMindIR) {
136     FuncGraphPtr anf_graph = nullptr;
137     try {
138       if (dec_key.len > dec_key.max_key_len) {
139         err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
140         MS_LOG(ERROR) << err_msg.str();
141         return Status(kMEInvalidInput, err_msg.str());
142       } else if (dec_key.len == 0) {
143         if (IsCipherFile(reinterpret_cast<const unsigned char *>(model_data))) {
144           err_msg << "Load model failed. The model_data may be encrypted, please pass in correct key.";
145           MS_LOG(ERROR) << err_msg.str();
146           return Status(kMEInvalidInput, err_msg.str());
147         } else {
148           anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(model_data), data_size, true);
149         }
150       } else {
151         size_t plain_data_size;
152         auto plain_data = mindspore::Decrypt(&plain_data_size, reinterpret_cast<const unsigned char *>(model_data),
153                                              data_size, dec_key.key, dec_key.len, CharToString(dec_mode));
154         if (plain_data == nullptr) {
155           err_msg << "Load model failed. Please check the valid of dec_key and dec_mode.";
156           MS_LOG(ERROR) << err_msg.str();
157           return Status(kMEInvalidInput, err_msg.str());
158         }
159         anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(plain_data.get()), plain_data_size, true);
160       }
161     } catch (const std::exception &e) {
162       err_msg << "Load model failed. Please check the valid of dec_key and dec_mode." << e.what();
163       MS_LOG(ERROR) << err_msg.str();
164       return Status(kMEInvalidInput, err_msg.str());
165     }
166 
167     *graph = Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
168     return kSuccess;
169   } else if (model_type == kOM) {
170     *graph = Graph(std::make_shared<Graph::GraphData>(Buffer(model_data, data_size), kOM));
171     return kSuccess;
172   }
173 
174   err_msg << "Unsupported ModelType " << model_type;
175   MS_LOG(ERROR) << err_msg.str();
176   return Status(kMEInvalidInput, err_msg.str());
177 }
178 
Load(const std::vector<char> & file,ModelType model_type,Graph * graph)179 Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph) {
180   return Load(file, model_type, graph, Key{}, StringToChar(kDecModeAesGcm));
181 }
182 
Load(const std::vector<char> & file,ModelType model_type,Graph * graph,const Key & dec_key,const std::vector<char> & dec_mode)183 Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
184                            const std::vector<char> &dec_mode) {
185   std::stringstream err_msg;
186   if (graph == nullptr) {
187     MS_LOG(ERROR) << "Output args graph is nullptr.";
188     return Status(kMEInvalidInput, "Output args graph is nullptr.");
189   }
190 
191   std::string file_path;
192   auto status = RealPath(CharToString(file), &file_path);
193   if (status != kSuccess) {
194     MS_LOG(ERROR) << status.GetErrDescription();
195     return status;
196   }
197 
198   if (model_type == kMindIR) {
199     FuncGraphPtr anf_graph;
200     if (dec_key.len > dec_key.max_key_len) {
201       err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
202       MS_LOG(ERROR) << err_msg.str();
203       return Status(kMEInvalidInput, err_msg.str());
204     } else if (dec_key.len == 0 && IsCipherFile(file_path)) {
205       err_msg << "Load model failed. The file may be encrypted, please pass in correct key.";
206       MS_LOG(ERROR) << err_msg.str();
207       return Status(kMEInvalidInput, err_msg.str());
208     }
209     MindIRLoader mindir_loader(true, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode),
210                                false);
211     anf_graph = mindir_loader.LoadMindIR(file_path);
212     if (anf_graph == nullptr) {
213       err_msg << "Load model failed.";
214       MS_LOG(ERROR) << err_msg.str();
215       return Status(kMEInvalidInput, err_msg.str());
216     }
217     auto graph_data = std::make_shared<Graph::GraphData>(anf_graph, kMindIR);
218 #if !defined(_WIN32) && !defined(_WIN64)
219     // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
220     std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(file_path);
221     if (!preprocessor.empty()) {
222       std::string dataengine_so_path;
223       Status dlret = DLSoPath({"libmindspore.so"}, "_c_dataengine", &dataengine_so_path);
224       CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());
225 
226       void *handle = nullptr;
227       void *function = nullptr;
228       dlret = DLSoOpen(dataengine_so_path, "ParseMindIRPreprocess_C", &handle, &function);
229       CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ParseMindIRPreprocess_C failed: " + dlret.GetErrDescription());
230       auto ParseMindIRPreprocessFun =
231         (void (*)(const std::vector<std::string> &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
232                   Status *))(function);
233 
234       std::vector<std::shared_ptr<dataset::Execute>> data_graph;
235       ParseMindIRPreprocessFun(preprocessor, &data_graph, &dlret);
236       CHECK_FAIL_AND_RELEASE(dlret, handle, "Load preprocess failed: " + dlret.GetErrDescription());
237       DLSoClose(handle);
238       if (!data_graph.empty()) {
239         graph_data->SetPreprocess(data_graph);
240       }
241     }
242 #endif
243     *graph = Graph(graph_data);
244     return kSuccess;
245   } else if (model_type == kOM) {
246     Buffer data = ReadFile(file_path);
247     if (data.Data() == nullptr) {
248       err_msg << "Read file " << file_path << " failed.";
249       MS_LOG(ERROR) << err_msg.str();
250       return Status(kMEInvalidInput, err_msg.str());
251     }
252     *graph = Graph(std::make_shared<Graph::GraphData>(data, kOM));
253     return kSuccess;
254   }
255 
256   err_msg << "Unsupported ModelType " << model_type;
257   MS_LOG(ERROR) << err_msg.str();
258   return Status(kMEInvalidInput, err_msg.str());
259 }
260 
Load(const std::vector<std::vector<char>> & files,ModelType model_type,std::vector<Graph> * graphs,const Key & dec_key,const std::vector<char> & dec_mode)261 Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
262                            std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
263   std::stringstream err_msg;
264   if (graphs == nullptr) {
265     MS_LOG(ERROR) << "Output args graph is nullptr.";
266     return Status(kMEInvalidInput, "Output args graph is nullptr.");
267   }
268 
269   if (files.size() == 1) {
270     std::vector<Graph> result(files.size());
271     auto ret = Load(files[0], model_type, &result[0], dec_key, dec_mode);
272     *graphs = std::move(result);
273     return ret;
274   }
275 
276   std::vector<std::string> files_path;
277   for (const auto &file : files) {
278     std::string file_path;
279     auto status = RealPath(CharToString(file), &file_path);
280     if (status != kSuccess) {
281       MS_LOG(ERROR) << status.GetErrDescription();
282       return status;
283     }
284     files_path.emplace_back(std::move(file_path));
285   }
286 
287   if (model_type == kMindIR) {
288     if (dec_key.len > dec_key.max_key_len) {
289       err_msg << "The key length exceeds maximum length: " << dec_key.max_key_len;
290       MS_LOG(ERROR) << err_msg.str();
291       return Status(kMEInvalidInput, err_msg.str());
292     }
293     MindIRLoader mindir_loader(true, dec_key.len == 0 ? nullptr : dec_key.key, dec_key.len, CharToString(dec_mode),
294                                true);
295     auto anf_graphs = mindir_loader.LoadMindIRs(files_path);
296     if (anf_graphs.size() != files_path.size()) {
297       err_msg << "Load model failed, " << files_path.size() << " files got " << anf_graphs.size() << " graphs.";
298       MS_LOG(ERROR) << err_msg.str();
299       return Status(kMEInvalidInput, err_msg.str());
300     }
301 #if !defined(_WIN32) && !defined(_WIN64)
302     // Dataset so loading
303     std::string dataengine_so_path;
304     Status dlret = DLSoPath({"libmindspore.so"}, "_c_dataengine", &dataengine_so_path);
305     CHECK_FAIL_AND_RELEASE(dlret, nullptr, "Parse dataengine_so failed: " + dlret.GetErrDescription());
306 
307     void *handle = nullptr;
308     void *function = nullptr;
309     dlret = DLSoOpen(dataengine_so_path, "ParseMindIRPreprocess_C", &handle, &function);
310     CHECK_FAIL_AND_RELEASE(dlret, handle, "Parse ParseMindIRPreprocess_C failed: " + dlret.GetErrDescription());
311 
312     auto ParseMindIRPreprocessFun =
313       (void (*)(const std::vector<std::string> &, std::vector<std::shared_ptr<mindspore::dataset::Execute>> *,
314                 Status *))(function);
315 #endif
316     std::vector<Graph> results;
317     for (size_t i = 0; i < anf_graphs.size(); ++i) {
318       if (anf_graphs[i] == nullptr) {
319         if (dec_key.len == 0 && IsCipherFile(files_path[i])) {
320           err_msg << "Load model failed. The file " << files_path[i] << " be encrypted, please pass in correct key.";
321         } else {
322           err_msg << "Load model " << files_path[i] << " failed.";
323         }
324         MS_LOG(ERROR) << err_msg.str();
325         return Status(kMEInvalidInput, err_msg.str());
326       }
327       auto graph_data = std::make_shared<Graph::GraphData>(anf_graphs[i], kMindIR);
328 #if !defined(_WIN32) && !defined(_WIN64)
329       // Config preprocessor, temporary way to let mindspore.so depends on _c_dataengine
330       std::vector<std::string> preprocessor = mindir_loader.LoadPreprocess(files_path[i]);
331       if (!preprocessor.empty()) {
332         std::vector<std::shared_ptr<dataset::Execute>> data_graph;
333         ParseMindIRPreprocessFun(preprocessor, &data_graph, &dlret);
334         CHECK_FAIL_AND_RELEASE(dlret, handle, "Load preprocess failed: " + dlret.GetErrDescription());
335         if (!data_graph.empty()) {
336           graph_data->SetPreprocess(data_graph);
337         }
338       }
339 #endif
340       results.emplace_back(graph_data);
341     }
342 #if !defined(_WIN32) && !defined(_WIN64)
343     // Dataset so release
344     DLSoClose(handle);
345 #endif
346     *graphs = std::move(results);
347     return kSuccess;
348   }
349 
350   err_msg << "Unsupported ModelType " << model_type;
351   MS_LOG(ERROR) << err_msg.str();
352   return Status(kMEInvalidInput, err_msg.str());
353 }
354 
SetParameters(const std::map<std::vector<char>,Buffer> &,Model *)355 Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &, Model *) {
356   MS_LOG(ERROR) << "Unsupported feature.";
357   return kMEFailed;
358 }
359 
ExportModel(const Model &,ModelType,Buffer *,QuantizationType,bool,const std::vector<std::vector<char>> &)360 Status Serialization::ExportModel(const Model &, ModelType, Buffer *, QuantizationType, bool,
361                                   const std::vector<std::vector<char>> & /* output_tensor_name */) {
362   MS_LOG(ERROR) << "Unsupported feature.";
363   return kMEFailed;
364 }
365 
ExportModel(const Model &,ModelType,const std::vector<char> &,QuantizationType,bool,const std::vector<std::vector<char>> & output_tensor_name)366 Status Serialization::ExportModel(const Model &, ModelType, const std::vector<char> &, QuantizationType, bool,
367                                   const std::vector<std::vector<char>> &output_tensor_name) {
368   MS_LOG(ERROR) << "Unsupported feature.";
369   return kMEFailed;
370 }
371 
ExportWeightsCollaborateWithMicro(const Model &,ModelType,const std::vector<char> &,bool,bool,const std::vector<std::vector<char>> &)372 Status Serialization::ExportWeightsCollaborateWithMicro(const Model &, ModelType, const std::vector<char> &, bool, bool,
373                                                         const std::vector<std::vector<char>> &) {
374   MS_LOG(ERROR) << "Unsupported feature.";
375   return kMEFailed;
376 }
377 }  // namespace mindspore
378