1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "tools/common/meta_graph_serializer.h"
#include <sys/stat.h>
#ifndef _MSC_VER
#include <unistd.h>
#endif
#include <memory>
#include <string>
#include "flatbuffers/flatbuffers.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
#include "ir/dtype/type_id.h"
#include "src/common/utils.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "src/common/file_utils.h"
30
31 namespace mindspore::lite {
32 namespace {
// Raw tensor bytes (or a packed model) at or above this size (2 GiB) are saved with a separate ".msw" weight file.
constexpr size_t kModelSizeLimit = size_t{1} << 31;
// Byte length of the header at the start of the ".msw" weight file; tensor data is appended right after it.
constexpr size_t kExternalDataHeadSize = 4096;
// Byte width of the magic-number field in the weight-file header; the model check-sum follows with the same width.
constexpr size_t kMagicNumberSize = 4;
// Initial capacity handed to flatbuffers::FlatBufferBuilder.
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
37
// Gives the owner read/write permission on `file_path` again, so an output left owner-read-only by a previous save
// (see the S_IRUSR chmod calls below) can be truncated and rewritten. Does nothing if the file does not exist, and
// is a no-op on MSVC builds.
void ChangeMod(const std::string &file_path) {
#ifndef _MSC_VER
  const char *path = file_path.c_str();
  const bool exists = (access(path, F_OK) == 0);
  if (!exists) {
    return;
  }
  (void)chmod(path, S_IWUSR | S_IRUSR);
#endif
}
45
ReopenFile(const std::string & file_path,std::ios_base::openmode open_mode=std::ios::in|std::ios::out,std::fstream * fs=nullptr)46 std::fstream *ReopenFile(const std::string &file_path, std::ios_base::openmode open_mode = std::ios::in | std::ios::out,
47 std::fstream *fs = nullptr) {
48 if (fs == nullptr) {
49 ChangeMod(file_path);
50 return OpenFile(file_path, open_mode);
51 } else {
52 fs->close();
53 fs->open(file_path, open_mode);
54 if (!fs->good()) {
55 MS_LOG(DEBUG) << "File is not exist: " << file_path;
56 return nullptr;
57 }
58 if (!fs->is_open()) {
59 MS_LOG(DEBUG) << "Can not open file: " << file_path;
60 return nullptr;
61 }
62 return fs;
63 }
64 }
65 } // namespace
66
InitPath(const std::string & output_path)67 bool MetaGraphSerializer::InitPath(const std::string &output_path) {
68 if (!ParserPathAndModelName(output_path, &this->save_path_, &this->model_name_)) {
69 MS_LOG(ERROR) << "parser save path and model name from output_path failed.";
70 return false;
71 }
72 #ifdef _WIN32
73 save_model_path_ = save_path_ + "\\" + model_name_ + ".ms";
74 save_data_path_ = save_path_ + "\\" + model_name_ + ".msw";
75 #else
76 save_model_path_ = save_path_ + "/" + model_name_ + ".ms";
77 save_data_path_ = save_path_ + "/" + model_name_ + ".msw";
78 #endif
79 return true;
80 }
81
Init(const schema::MetaGraphT & graph,bool save_together)82 bool MetaGraphSerializer::Init(const schema::MetaGraphT &graph, bool save_together) {
83 // init file streams
84 ChangeMod(save_model_path_);
85 model_fs_ = OpenFile(save_model_path_, std::ios::out | std::ios::binary | std::ios::trunc);
86 if (model_fs_ == nullptr) {
87 MS_LOG(ERROR) << "Open " << save_model_path_ << " failed";
88 return false;
89 }
90 if (save_together) {
91 return true;
92 }
93
94 ChangeMod(save_data_path_);
95 data_fs_ = OpenFile(save_data_path_, std::ios::out | std::ios::binary | std::ios::trunc);
96 if (data_fs_ == nullptr) {
97 MS_LOG(ERROR) << "Open " << save_data_path_ << " failed";
98 return false;
99 }
100 // write weight file head
101 auto head_data = reinterpret_cast<char *>(malloc(kExternalDataHeadSize));
102 if (head_data == nullptr) {
103 MS_LOG(ERROR) << "Malloc data for file head failed";
104 return false;
105 }
106 if (memset_s(head_data, kExternalDataHeadSize, 0, kExternalDataHeadSize) != 0) {
107 MS_LOG(ERROR) << "memset_s in MetaGraphSerializer init failed.";
108 free(head_data);
109 return false;
110 }
111 // magic number of weight_s file: 0x12345678
112 auto sum_data = reinterpret_cast<uint32_t *>(head_data);
113 sum_data[0] = 0x12345678;
114 data_fs_->write(head_data, kExternalDataHeadSize);
115 if (data_fs_->bad()) {
116 MS_LOG(ERROR) << "Write file head failed";
117 free(head_data);
118 return false;
119 }
120 free(head_data);
121 cur_offset_ = kExternalDataHeadSize;
122 return true;
123 }
124
AddExternalData(const char * data,size_t size)125 schema::ExternalDataT *MetaGraphSerializer::AddExternalData(const char *data, size_t size) {
126 MS_ASSERT(data_fs_ != nullptr);
127 auto external_data = new (std::nothrow) schema::ExternalDataT;
128 if (external_data == nullptr) {
129 MS_LOG(ERROR) << "Create ExternalDataT failed";
130 return nullptr;
131 }
132 external_data->location = model_name_ + ".msw";
133 external_data->offset = cur_offset_;
134 external_data->length = static_cast<int64_t>(size);
135 if (data == nullptr || size == 0) {
136 return external_data;
137 }
138 data_fs_->write(data, static_cast<int64_t>(size));
139 if (data_fs_->bad()) {
140 MS_LOG(ERROR) << "Write file failed";
141 delete external_data;
142 return nullptr;
143 }
144 std::stringstream oss;
145 oss << std::hash<char>()(data[0]);
146 external_data->checkSum = oss.str();
147 cur_offset_ += static_cast<int64_t>(size);
148 return external_data;
149 }
150
ExtraAndSerializeModelWeight(const schema::MetaGraphT & graph)151 bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph) {
152 if (data_fs_ == nullptr) {
153 MS_LOG(ERROR) << "Weight file stream is not inited";
154 return false;
155 }
156 data_fs_ = ReopenFile(save_data_path_, std::ios::out | std::ios::app, data_fs_);
157 if (data_fs_ == nullptr) {
158 MS_LOG(ERROR) << "Reopen weight file stream failed";
159 return false;
160 }
161 if (this->cur_offset_ != kExternalDataHeadSize) {
162 MS_LOG(ERROR) << "Serialized model weight already";
163 return false;
164 }
165 for (const auto &tensor : graph.allTensors) {
166 if (tensor->nodeType == NodeType_CNode) {
167 continue;
168 }
169 if (tensor->dataType == kObjectTypeTensorType) { // not support control-flow now
170 continue;
171 }
172 auto external_data =
173 this->AddExternalData(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size());
174 if (external_data == nullptr) {
175 MS_LOG(ERROR) << "Serialized model weight failed";
176 return false;
177 }
178 tensor->data.clear();
179 tensor->externalData.emplace_back(external_data);
180 }
181 return true;
182 }
183
SerializeModelAndUpdateWeight(const schema::MetaGraphT & meta_graphT,const Byte * key,const size_t key_len,const std::string & enc_mode,size_t * size)184 bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
185 const size_t key_len, const std::string &enc_mode,
186 size_t *size) {
187 // serialize model
188 flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
189 auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
190 builder.Finish(offset);
191 schema::FinishMetaGraphBuffer(builder, offset);
192 *size = builder.GetSize();
193 auto content = builder.GetBufferPointer();
194 if (!SerializeModel(content, *size, key, key_len, enc_mode)) {
195 MS_LOG(ERROR) << "Serialize graph failed";
196 return false;
197 }
198
199 // update weight file using check-sum of model-buffer
200 auto model_crc32 = std::hash<uint8_t>()(content[0]);
201 if (data_fs_ == nullptr) {
202 MS_LOG(ERROR) << "Weight file stream is not inited";
203 return false;
204 }
205 data_fs_ = ReopenFile(save_data_path_, std::ios::in | std::ios::out, data_fs_);
206 if (data_fs_ == nullptr) {
207 MS_LOG(ERROR) << "Reopen weight file stream failed";
208 return false;
209 }
210 data_fs_->seekp(kMagicNumberSize, std::ios::beg);
211 data_fs_->write(reinterpret_cast<const char *>(&model_crc32), kMagicNumberSize);
212 #ifndef _MSC_VER
213 chmod(save_data_path_.c_str(), S_IRUSR);
214 #endif
215 return true;
216 }
217
GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuilder * builder,const schema::MetaGraphT & graph,size_t * data_size)218 uint8_t *MetaGraphSerializer::GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuilder *builder,
219 const schema::MetaGraphT &graph, size_t *data_size) {
220 auto offset = schema::MetaGraph::Pack(*builder, &graph);
221 builder->Finish(offset);
222 schema::FinishMetaGraphBuffer(*builder, offset);
223 *data_size = builder->GetSize();
224 return builder->GetBufferPointer();
225 }
226
Save(const schema::MetaGraphT & graph,const std::string & output_path,const Byte * key,const size_t key_len,const std::string & enc_mode)227 int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
228 const size_t key_len, const std::string &enc_mode) {
229 size_t size = 0;
230 auto ret = MetaGraphSerializer::Save(graph, output_path, &size, key, key_len, enc_mode);
231 return ret;
232 }
233
Save(const schema::MetaGraphT & graph,const std::string & output_path,size_t * size,const Byte * key,const size_t key_len,const std::string & enc_mode)234 int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size,
235 const Byte *key, const size_t key_len, const std::string &enc_mode) {
236 MetaGraphSerializer meta_graph_serializer;
237 *size = 0;
238 flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
239 auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, size);
240 if (!meta_graph_serializer.InitPath(output_path)) {
241 MS_LOG(ERROR) << "Init path failed";
242 return RET_ERROR;
243 }
244 size_t tensors_size = 0;
245 for (auto &tensor : graph.allTensors) {
246 tensors_size += tensor->data.size();
247 }
248
249 auto save_together = (tensors_size < kModelSizeLimit && *size < kModelSizeLimit);
250 if (!meta_graph_serializer.Init(graph, save_together)) {
251 MS_LOG(ERROR) << "Init MetaGraphSerializer failed";
252 return RET_ERROR;
253 }
254 if (save_together) {
255 if (!meta_graph_serializer.SerializeModel(buffer, *size, key, key_len, enc_mode)) {
256 MS_LOG(ERROR) << "Serialize graph failed";
257 return RET_ERROR;
258 }
259 } else {
260 if (!meta_graph_serializer.ExtraAndSerializeModelWeight(graph)) {
261 MS_LOG(ERROR) << "Serialize graph weight failed";
262 return RET_ERROR;
263 }
264 size_t model_size = 0;
265 if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode, &model_size)) {
266 MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
267 return RET_ERROR;
268 }
269 *size = model_size + tensors_size;
270 }
271 return RET_OK;
272 }
273
~MetaGraphSerializer()274 MetaGraphSerializer::~MetaGraphSerializer() {
275 if (model_fs_ != nullptr) {
276 model_fs_->close();
277 delete model_fs_;
278 }
279 if (data_fs_ != nullptr) {
280 data_fs_->close();
281 delete data_fs_;
282 }
283 }
284
SerializeModel(const void * content,size_t size,const Byte * key,const size_t key_len,const std::string & enc_mode)285 bool MetaGraphSerializer::SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
286 const std::string &enc_mode) {
287 MS_ASSERT(model_fs_ != nullptr);
288 if (size == 0 || content == nullptr) {
289 MS_LOG(ERROR) << "Input meta graph buffer is nullptr";
290 return false;
291 }
292 if (key_len > 0) {
293 size_t encrypt_len;
294 auto encrypt_content = Encrypt(&encrypt_len, reinterpret_cast<const Byte *>(content), size, key, key_len, enc_mode);
295 if (encrypt_content == nullptr || encrypt_len == 0) {
296 MS_LOG(ERROR) << "Encrypt failed.";
297 model_fs_->close();
298 return false;
299 }
300 model_fs_->write(reinterpret_cast<const char *>(encrypt_content.get()), encrypt_len);
301 } else {
302 model_fs_->write((const char *)content, static_cast<int64_t>(size));
303 }
304 if (model_fs_->bad()) {
305 MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
306 return false;
307 }
308 #ifndef _MSC_VER
309 chmod(save_model_path_.c_str(), S_IRUSR);
310 #endif
311 return true;
312 }
313 } // namespace mindspore::lite
314