• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <dirent.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <cerrno>
#include <cstring>
#include <string>
#include <memory>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <limits>
#include <new>
#include <utility>
#include <vector>

#include "load_mindir/load_model.h"
#include "load_mindir/anf_model_parser.h"
#include "proto/mind_ir.pb.h"
#include "utils/crypto.h"
31 
32 using std::string;
33 using std::vector;
34 
35 namespace mindspore {
ReadProtoFile(const std::string & file)36 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
37   if (file.empty()) {
38     MS_LOG(ERROR) << "file is nullptr";
39     return nullptr;
40   }
41 
42   char real_path[PATH_MAX] = {0};
43 #if defined(_WIN32) || defined(_WIN64)
44   if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) {
45     MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
46     return nullptr;
47   }
48 #else
49   if (realpath(file.c_str(), real_path) == nullptr) {
50     MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
51     return nullptr;
52   }
53 #endif
54 
55   std::ifstream ifs(real_path);
56   if (!ifs.good()) {
57     MS_LOG(ERROR) << "file: " << real_path << " is not exist";
58     return nullptr;
59   }
60 
61   if (!ifs.is_open()) {
62     MS_LOG(ERROR) << "file: " << real_path << "open failed";
63     return nullptr;
64   }
65 
66   ifs.seekg(0, std::ios::end);
67   size_t size = ifs.tellg();
68   std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
69   if (buf == nullptr) {
70     MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
71     ifs.close();
72     return nullptr;
73   }
74 
75   ifs.seekg(0, std::ios::beg);
76   ifs.read(buf->data(), size);
77   ifs.close();
78 
79   return buf;
80 }
81 
get_all_files(const std::string & dir_in,std::vector<std::string> * files)82 bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
83   if (dir_in.empty()) {
84     return false;
85   }
86   struct stat s;
87   int ret = stat(dir_in.c_str(), &s);
88   if (ret != 0) {
89     MS_LOG(ERROR) << "stat error, ret is : " << ret;
90     return false;
91   }
92   if (!S_ISDIR(s.st_mode)) {
93     return false;
94   }
95   DIR *open_dir = opendir(dir_in.c_str());
96   if (open_dir == nullptr) {
97     MS_LOG(EXCEPTION) << "open dir " << dir_in.c_str() << " failed";
98   }
99   dirent *p = nullptr;
100   bool list_ret = true;
101   while ((p = readdir(open_dir)) != nullptr) {
102     struct stat st;
103     if (p->d_name[0] != '.') {
104       std::string name = dir_in + std::string("/") + std::string(p->d_name);
105       ret = stat(name.c_str(), &st);
106       if (ret != 0) {
107         MS_LOG(ERROR) << "stat error, ret is : " << ret;
108         list_ret = false;
109         break;
110       }
111       if (S_ISDIR(st.st_mode)) {
112         bool result = get_all_files(name, files);
113         if (!result) {
114           MS_LOG(ERROR) << "Get files failed.";
115           list_ret = false;
116           break;
117         }
118       } else if (S_ISREG(st.st_mode)) {
119         files->push_back(name);
120       }
121     }
122   }
123   closedir(open_dir);
124   return list_ret;
125 }
126 
/// @brief Tests whether @p s ends with @p sub.
/// @return 1 if it does (an empty @p sub always matches), 0 otherwise.
int endsWith(const std::string s, const std::string sub) {
  // Compare lengths first: computing s.length() - sub.length() with sub longer than s wraps around
  // (unsigned), and with an rfind()-based check that wrapped value can equal npos, giving a false match.
  if (sub.size() > s.size()) {
    return 0;
  }
  return s.compare(s.size() - sub.size(), sub.size(), sub) == 0 ? 1 : 0;
}
128 
ParseModelProto(mind_ir::ModelProto * model,const std::string & path,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode)129 bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path, const unsigned char *dec_key,
130                      const size_t key_len, const std::string &dec_mode) {
131   if (dec_key != nullptr) {
132     size_t plain_len;
133     auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
134     if (plain_data == nullptr) {
135       MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
136       return false;
137     }
138     if (!model->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
139       MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode.";
140       return false;
141     }
142   } else {
143     std::fstream input_graph(path, std::ios::in | std::ios::binary);
144     if (!input_graph || !model->ParseFromIstream(&input_graph)) {
145       MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
146       return false;
147     }
148   }
149   return true;
150 }
151 
ParseGraphProto(mind_ir::GraphProto * graph,const std::string & path,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode)152 bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const unsigned char *dec_key,
153                      const size_t key_len, const std::string &dec_mode) {
154   if (dec_key != nullptr) {
155     size_t plain_len;
156     auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
157     if (plain_data == nullptr) {
158       MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
159       return false;
160     }
161     if (!graph->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
162       MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, "
163                        "dec_key or dec_mode";
164       return false;
165     }
166   } else {
167     std::fstream input_param(path, std::ios::in | std::ios::binary);
168     if (!input_param || !graph->ParseFromIstream(&input_param)) {
169       MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
170       return false;
171     }
172   }
173   return true;
174 }
175 
LoadPreprocess(const std::string & file_name)176 std::string LoadPreprocess(const std::string &file_name) {
177   if (file_name.length() > PATH_MAX) {
178     MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
179     return nullptr;
180   }
181   char abs_path_buff[PATH_MAX];
182 
183 #ifdef _WIN32
184   _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
185 #else
186   if (!realpath(file_name.c_str(), abs_path_buff)) {
187     MS_LOG(ERROR) << "Load MindIR get absolute path failed";
188   }
189 #endif
190 
191   // Read graph
192   mind_ir::ModelProto origin_model;
193   std::fstream mindir_stream(std::string(std::string(abs_path_buff)), std::ios::in | std::ios::binary);
194   if (!mindir_stream || !origin_model.ParseFromIstream(&mindir_stream)) {
195     MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
196     return std::string();
197   }
198 
199   return origin_model.preprocessor();
200 }
201 
LoadMindIRs(std::vector<std::string> file_names,bool is_lite,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode)202 std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
203                                                     const unsigned char *dec_key, const size_t key_len,
204                                                     const std::string &dec_mode) {
205   std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
206   for (const auto &file_name : file_names) {
207     MS_LOG(DEBUG) << "Load " << file_name;
208     funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, true));
209   }
210   return funcgraph_vec;
211 }
212 
LoadMindIR(const std::string & file_name,bool is_lite,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode,bool inc_load)213 std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
214                                       const size_t key_len, const std::string &dec_mode, bool inc_load) {
215   if (file_name.length() > PATH_MAX) {
216     MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
217     return nullptr;
218   }
219   char abs_path_buff[PATH_MAX];
220   vector<string> files;
221 
222 #ifdef _WIN32
223   _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
224 #else
225   if (!realpath(file_name.c_str(), abs_path_buff)) {
226     MS_LOG(ERROR) << "Load MindIR get absolute path failed";
227   }
228 #endif
229   // Read graph
230   mind_ir::ModelProto origin_model;
231   if (!ParseModelProto(&origin_model, std::string(abs_path_buff), dec_key, key_len, dec_mode)) {
232     return nullptr;
233   }
234   // Load parameter into graph
235   if (endsWith(std::string(abs_path_buff), "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
236     if (strlen(abs_path_buff) < strlen("graph.mindir")) {
237       MS_LOG(ERROR) << "The abs_path_buff length is less than 'graph.mindir'.";
238       return nullptr;
239     }
240     int path_len = SizeToInt(strlen(abs_path_buff) - strlen("graph.mindir"));
241     string var_path = std::string(abs_path_buff).substr(0, path_len);
242     var_path += "variables";
243     std::ifstream ifs(var_path);
244     if (ifs.good()) {
245       bool result = get_all_files(var_path, &files);
246       if (!result) {
247         MS_LOG(ERROR) << "Get files failed.";
248         return nullptr;
249       }
250     } else {
251       MS_LOG(ERROR) << "Load graph's variable folder failed, please check the correctness of variable folder.";
252       return nullptr;
253     }
254 
255     size_t file_size = files.size();
256     mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
257     for (size_t file_index = 0; file_index < file_size; file_index++) {
258       mind_ir::GraphProto param_graph;
259       if (!ParseGraphProto(&param_graph, files[file_index], dec_key, key_len, dec_mode)) {
260         return nullptr;
261       }
262       for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) {
263         mind_ir::TensorProto *param_proto = mod_graph->add_parameter();
264         param_proto->set_name(param_graph.parameter(param_index).name());
265         param_proto->set_data_type(param_graph.parameter(param_index).data_type());
266         param_proto->set_raw_data(param_graph.parameter(param_index).raw_data());
267         for (const auto &dim : param_graph.parameter(param_index).dims()) {
268           param_proto->add_dims(dim);
269         }
270       }
271     }
272   }
273 
274   MSANFModelParser model_parser;
275   if (!inc_load) {
276     MSANFModelParser::LoadTensorMapClear();
277   }
278   if (is_lite) {
279     model_parser.SetLite();
280   }
281   if (inc_load) {
282     model_parser.SetIncLoad();
283   }
284   FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
285   return dstgraph_ptr;
286 }
287 
ConvertStreamToFuncGraph(const char * buf,const size_t buf_size,bool is_lite)288 std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
289   MS_EXCEPTION_IF_NULL(buf);
290   std::string str(buf, buf_size);
291   mind_ir::ModelProto model_;
292   if (!model_.ParseFromString(str)) {
293     MS_LOG(ERROR) << "Parse model from buffer fail!";
294   }
295   MSANFModelParser model_parser;
296   if (is_lite) {
297     model_parser.SetLite();
298   }
299   FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
300   return dstgraph_ptr;
301 }
302 }  // namespace mindspore
303