1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <dirent.h>
#include <sys/stat.h>
#include <sys/types.h>

#include <algorithm>
#include <cerrno>
#include <climits>
#include <cstdint>
#include <cstring>
#include <exception>
#include <fstream>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <vector>

#include "load_mindir/load_model.h"
#include "load_mindir/anf_model_parser.h"
#include "proto/mind_ir.pb.h"
#include "utils/crypto.h"
31
32 using std::string;
33 using std::vector;
34
35 namespace mindspore {
ReadProtoFile(const std::string & file)36 std::shared_ptr<std::vector<char>> ReadProtoFile(const std::string &file) {
37 if (file.empty()) {
38 MS_LOG(ERROR) << "file is nullptr";
39 return nullptr;
40 }
41
42 char real_path[PATH_MAX] = {0};
43 #if defined(_WIN32) || defined(_WIN64)
44 if (_fullpath(real_path, file.c_str(), PATH_MAX) == nullptr) {
45 MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
46 return nullptr;
47 }
48 #else
49 if (realpath(file.c_str(), real_path) == nullptr) {
50 MS_LOG(ERROR) << "Get realpath failed, mind ir file is" << file;
51 return nullptr;
52 }
53 #endif
54
55 std::ifstream ifs(real_path);
56 if (!ifs.good()) {
57 MS_LOG(ERROR) << "file: " << real_path << " is not exist";
58 return nullptr;
59 }
60
61 if (!ifs.is_open()) {
62 MS_LOG(ERROR) << "file: " << real_path << "open failed";
63 return nullptr;
64 }
65
66 ifs.seekg(0, std::ios::end);
67 size_t size = ifs.tellg();
68 std::shared_ptr<std::vector<char>> buf(new (std::nothrow) std::vector<char>(size));
69 if (buf == nullptr) {
70 MS_LOG(ERROR) << "malloc buf failed, file: " << real_path;
71 ifs.close();
72 return nullptr;
73 }
74
75 ifs.seekg(0, std::ios::beg);
76 ifs.read(buf->data(), size);
77 ifs.close();
78
79 return buf;
80 }
81
get_all_files(const std::string & dir_in,std::vector<std::string> * files)82 bool get_all_files(const std::string &dir_in, std::vector<std::string> *files) {
83 if (dir_in.empty()) {
84 return false;
85 }
86 struct stat s;
87 int ret = stat(dir_in.c_str(), &s);
88 if (ret != 0) {
89 MS_LOG(ERROR) << "stat error, ret is : " << ret;
90 return false;
91 }
92 if (!S_ISDIR(s.st_mode)) {
93 return false;
94 }
95 DIR *open_dir = opendir(dir_in.c_str());
96 if (open_dir == nullptr) {
97 MS_LOG(EXCEPTION) << "open dir " << dir_in.c_str() << " failed";
98 }
99 dirent *p = nullptr;
100 bool list_ret = true;
101 while ((p = readdir(open_dir)) != nullptr) {
102 struct stat st;
103 if (p->d_name[0] != '.') {
104 std::string name = dir_in + std::string("/") + std::string(p->d_name);
105 ret = stat(name.c_str(), &st);
106 if (ret != 0) {
107 MS_LOG(ERROR) << "stat error, ret is : " << ret;
108 list_ret = false;
109 break;
110 }
111 if (S_ISDIR(st.st_mode)) {
112 bool result = get_all_files(name, files);
113 if (!result) {
114 MS_LOG(ERROR) << "Get files failed.";
115 list_ret = false;
116 break;
117 }
118 } else if (S_ISREG(st.st_mode)) {
119 files->push_back(name);
120 }
121 }
122 }
123 closedir(open_dir);
124 return list_ret;
125 }
126
/**
 * Check whether `s` ends with `sub`.
 *
 * @param s   String to test.
 * @param sub Candidate suffix; an empty suffix always matches.
 * @return 1 if `sub` is a suffix of `s`, otherwise 0.
 */
int endsWith(const std::string s, const std::string sub) {
  // Guard the length first. Otherwise `s.length() - sub.length()` wraps around when `sub` is longer
  // than `s`, and the wrapped value can equal std::string::npos (rfind's "not found" result),
  // producing a false match such as endsWith("graph.mindir", "_graph.mindir") == 1.
  if (sub.length() > s.length()) {
    return 0;
  }
  return s.compare(s.length() - sub.length(), sub.length(), sub) == 0 ? 1 : 0;
}
128
ParseModelProto(mind_ir::ModelProto * model,const std::string & path,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode)129 bool ParseModelProto(mind_ir::ModelProto *model, const std::string &path, const unsigned char *dec_key,
130 const size_t key_len, const std::string &dec_mode) {
131 if (dec_key != nullptr) {
132 size_t plain_len;
133 auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
134 if (plain_data == nullptr) {
135 MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
136 return false;
137 }
138 if (!model->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
139 MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file, dec_key or dec_mode.";
140 return false;
141 }
142 } else {
143 std::fstream input_graph(path, std::ios::in | std::ios::binary);
144 if (!input_graph || !model->ParseFromIstream(&input_graph)) {
145 MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
146 return false;
147 }
148 }
149 return true;
150 }
151
ParseGraphProto(mind_ir::GraphProto * graph,const std::string & path,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode)152 bool ParseGraphProto(mind_ir::GraphProto *graph, const std::string &path, const unsigned char *dec_key,
153 const size_t key_len, const std::string &dec_mode) {
154 if (dec_key != nullptr) {
155 size_t plain_len;
156 auto plain_data = Decrypt(&plain_len, path, dec_key, key_len, dec_mode);
157 if (plain_data == nullptr) {
158 MS_LOG(ERROR) << "Decrypt MindIR file failed, please check the correctness of the dec_key or dec_mode.";
159 return false;
160 }
161 if (!graph->ParseFromArray(reinterpret_cast<char *>(plain_data.get()), static_cast<int32_t>(plain_len))) {
162 MS_LOG(ERROR) << "Load variable file failed, please check the correctness of the mindir's variable file, "
163 "dec_key or dec_mode";
164 return false;
165 }
166 } else {
167 std::fstream input_param(path, std::ios::in | std::ios::binary);
168 if (!input_param || !graph->ParseFromIstream(&input_param)) {
169 MS_LOG(ERROR) << "Load variable file failed, please check the correctness of mindir's variable file.";
170 return false;
171 }
172 }
173 return true;
174 }
175
LoadPreprocess(const std::string & file_name)176 std::string LoadPreprocess(const std::string &file_name) {
177 if (file_name.length() > PATH_MAX) {
178 MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
179 return nullptr;
180 }
181 char abs_path_buff[PATH_MAX];
182
183 #ifdef _WIN32
184 _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
185 #else
186 if (!realpath(file_name.c_str(), abs_path_buff)) {
187 MS_LOG(ERROR) << "Load MindIR get absolute path failed";
188 }
189 #endif
190
191 // Read graph
192 mind_ir::ModelProto origin_model;
193 std::fstream mindir_stream(std::string(std::string(abs_path_buff)), std::ios::in | std::ios::binary);
194 if (!mindir_stream || !origin_model.ParseFromIstream(&mindir_stream)) {
195 MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
196 return std::string();
197 }
198
199 return origin_model.preprocessor();
200 }
201
LoadMindIRs(std::vector<std::string> file_names,bool is_lite,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode)202 std::vector<std::shared_ptr<FuncGraph>> LoadMindIRs(std::vector<std::string> file_names, bool is_lite,
203 const unsigned char *dec_key, const size_t key_len,
204 const std::string &dec_mode) {
205 std::vector<std::shared_ptr<FuncGraph>> funcgraph_vec;
206 for (const auto &file_name : file_names) {
207 MS_LOG(DEBUG) << "Load " << file_name;
208 funcgraph_vec.push_back(LoadMindIR(file_name, is_lite, dec_key, key_len, dec_mode, true));
209 }
210 return funcgraph_vec;
211 }
212
LoadMindIR(const std::string & file_name,bool is_lite,const unsigned char * dec_key,const size_t key_len,const std::string & dec_mode,bool inc_load)213 std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite, const unsigned char *dec_key,
214 const size_t key_len, const std::string &dec_mode, bool inc_load) {
215 if (file_name.length() > PATH_MAX) {
216 MS_LOG(ERROR) << "The length of the file name exceeds the limit.";
217 return nullptr;
218 }
219 char abs_path_buff[PATH_MAX];
220 vector<string> files;
221
222 #ifdef _WIN32
223 _fullpath(abs_path_buff, file_name.c_str(), PATH_MAX);
224 #else
225 if (!realpath(file_name.c_str(), abs_path_buff)) {
226 MS_LOG(ERROR) << "Load MindIR get absolute path failed";
227 }
228 #endif
229 // Read graph
230 mind_ir::ModelProto origin_model;
231 if (!ParseModelProto(&origin_model, std::string(abs_path_buff), dec_key, key_len, dec_mode)) {
232 return nullptr;
233 }
234 // Load parameter into graph
235 if (endsWith(std::string(abs_path_buff), "_graph.mindir") && origin_model.graph().parameter_size() == 0) {
236 if (strlen(abs_path_buff) < strlen("graph.mindir")) {
237 MS_LOG(ERROR) << "The abs_path_buff length is less than 'graph.mindir'.";
238 return nullptr;
239 }
240 int path_len = SizeToInt(strlen(abs_path_buff) - strlen("graph.mindir"));
241 string var_path = std::string(abs_path_buff).substr(0, path_len);
242 var_path += "variables";
243 std::ifstream ifs(var_path);
244 if (ifs.good()) {
245 bool result = get_all_files(var_path, &files);
246 if (!result) {
247 MS_LOG(ERROR) << "Get files failed.";
248 return nullptr;
249 }
250 } else {
251 MS_LOG(ERROR) << "Load graph's variable folder failed, please check the correctness of variable folder.";
252 return nullptr;
253 }
254
255 size_t file_size = files.size();
256 mind_ir::GraphProto *mod_graph = origin_model.mutable_graph();
257 for (size_t file_index = 0; file_index < file_size; file_index++) {
258 mind_ir::GraphProto param_graph;
259 if (!ParseGraphProto(¶m_graph, files[file_index], dec_key, key_len, dec_mode)) {
260 return nullptr;
261 }
262 for (int param_index = 0; param_index < param_graph.parameter_size(); param_index++) {
263 mind_ir::TensorProto *param_proto = mod_graph->add_parameter();
264 param_proto->set_name(param_graph.parameter(param_index).name());
265 param_proto->set_data_type(param_graph.parameter(param_index).data_type());
266 param_proto->set_raw_data(param_graph.parameter(param_index).raw_data());
267 for (const auto &dim : param_graph.parameter(param_index).dims()) {
268 param_proto->add_dims(dim);
269 }
270 }
271 }
272 }
273
274 MSANFModelParser model_parser;
275 if (!inc_load) {
276 MSANFModelParser::LoadTensorMapClear();
277 }
278 if (is_lite) {
279 model_parser.SetLite();
280 }
281 if (inc_load) {
282 model_parser.SetIncLoad();
283 }
284 FuncGraphPtr dstgraph_ptr = model_parser.Parse(origin_model);
285 return dstgraph_ptr;
286 }
287
ConvertStreamToFuncGraph(const char * buf,const size_t buf_size,bool is_lite)288 std::shared_ptr<FuncGraph> ConvertStreamToFuncGraph(const char *buf, const size_t buf_size, bool is_lite) {
289 MS_EXCEPTION_IF_NULL(buf);
290 std::string str(buf, buf_size);
291 mind_ir::ModelProto model_;
292 if (!model_.ParseFromString(str)) {
293 MS_LOG(ERROR) << "Parse model from buffer fail!";
294 }
295 MSANFModelParser model_parser;
296 if (is_lite) {
297 model_parser.SetLite();
298 }
299 FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_);
300 return dstgraph_ptr;
301 }
302 } // namespace mindspore
303