• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/common/meta_graph_serializer.h"
18 #include <sys/stat.h>
19 #ifndef _MSC_VER
20 #include <unistd.h>
21 #endif
22 #include "flatbuffers/flatbuffers.h"
23 #include "src/common/log_adapter.h"
24 #include "nnacl/op_base.h"
25 #include "ir/dtype/type_id.h"
26 #include "src/common/utils.h"
27 #include "include/errorcode.h"
28 #include "securec/include/securec.h"
29 #include "src/common/file_utils.h"
30 
31 namespace mindspore::lite {
32 namespace {
// 2 GiB: Save() keeps weights inside the model file only while both the packed model and the summed
// tensor data stay below this size.
constexpr size_t kModelSizeLimit = static_cast<size_t>(1) << 31;
// Size in bytes of the zero-padded head written at the start of the external weight (.msw) file.
constexpr size_t kExternalDataHeadSize = 4 * 1024;
// Width in bytes of the magic-number slot at the start of that head (the model check-sum follows it).
constexpr size_t kMagicNumberSize = sizeof(uint32_t);
// Initial capacity handed to flatbuffers::FlatBufferBuilder.
constexpr size_t kFlatbuffersBuilderInitSize = 1024;
37 
// Makes an already-existing file owner-readable and owner-writable (0600) so it can be reopened for
// writing; the serializer leaves finished files read-only. A missing file is silently ignored and the
// chmod result is intentionally discarded. Compiles to a no-op under MSVC.
void ChangeMod(const std::string &file_path) {
#ifndef _MSC_VER
  const char *const path = file_path.c_str();
  if (access(path, F_OK) != 0) {
    return;  // nothing to adjust yet
  }
  (void)chmod(path, S_IRUSR | S_IWUSR);
#endif
}
45 
ReopenFile(const std::string & file_path,std::ios_base::openmode open_mode=std::ios::in|std::ios::out,std::fstream * fs=nullptr)46 std::fstream *ReopenFile(const std::string &file_path, std::ios_base::openmode open_mode = std::ios::in | std::ios::out,
47                          std::fstream *fs = nullptr) {
48   if (fs == nullptr) {
49     ChangeMod(file_path);
50     return OpenFile(file_path, open_mode);
51   } else {
52     fs->close();
53     fs->open(file_path, open_mode);
54     if (!fs->good()) {
55       MS_LOG(DEBUG) << "File is not exist: " << file_path;
56       return nullptr;
57     }
58     if (!fs->is_open()) {
59       MS_LOG(DEBUG) << "Can not open file: " << file_path;
60       return nullptr;
61     }
62     return fs;
63   }
64 }
65 }  // namespace
66 
InitPath(const std::string & output_path)67 bool MetaGraphSerializer::InitPath(const std::string &output_path) {
68   if (!ParserPathAndModelName(output_path, &this->save_path_, &this->model_name_)) {
69     MS_LOG(ERROR) << "parser save path and model name from output_path failed.";
70     return false;
71   }
72 #ifdef _WIN32
73   save_model_path_ = save_path_ + "\\" + model_name_ + ".ms";
74   save_data_path_ = save_path_ + "\\" + model_name_ + ".msw";
75 #else
76   save_model_path_ = save_path_ + "/" + model_name_ + ".ms";
77   save_data_path_ = save_path_ + "/" + model_name_ + ".msw";
78 #endif
79   return true;
80 }
81 
Init(const schema::MetaGraphT & graph,bool save_together)82 bool MetaGraphSerializer::Init(const schema::MetaGraphT &graph, bool save_together) {
83   // init file streams
84   ChangeMod(save_model_path_);
85   model_fs_ = OpenFile(save_model_path_, std::ios::out | std::ios::binary | std::ios::trunc);
86   if (model_fs_ == nullptr) {
87     MS_LOG(ERROR) << "Open " << save_model_path_ << " failed";
88     return false;
89   }
90   if (save_together) {
91     return true;
92   }
93 
94   ChangeMod(save_data_path_);
95   data_fs_ = OpenFile(save_data_path_, std::ios::out | std::ios::binary | std::ios::trunc);
96   if (data_fs_ == nullptr) {
97     MS_LOG(ERROR) << "Open " << save_data_path_ << " failed";
98     return false;
99   }
100   // write weight file head
101   auto head_data = reinterpret_cast<char *>(malloc(kExternalDataHeadSize));
102   if (head_data == nullptr) {
103     MS_LOG(ERROR) << "Malloc data for file head failed";
104     return false;
105   }
106   if (memset_s(head_data, kExternalDataHeadSize, 0, kExternalDataHeadSize) != 0) {
107     MS_LOG(ERROR) << "memset_s in MetaGraphSerializer init failed.";
108     free(head_data);
109     return false;
110   }
111   // magic number of weight_s file: 0x12345678
112   auto sum_data = reinterpret_cast<uint32_t *>(head_data);
113   sum_data[0] = 0x12345678;
114   data_fs_->write(head_data, kExternalDataHeadSize);
115   if (data_fs_->bad()) {
116     MS_LOG(ERROR) << "Write file head failed";
117     free(head_data);
118     return false;
119   }
120   free(head_data);
121   cur_offset_ = kExternalDataHeadSize;
122   return true;
123 }
124 
AddExternalData(const char * data,size_t size)125 schema::ExternalDataT *MetaGraphSerializer::AddExternalData(const char *data, size_t size) {
126   MS_ASSERT(data_fs_ != nullptr);
127   auto external_data = new (std::nothrow) schema::ExternalDataT;
128   if (external_data == nullptr) {
129     MS_LOG(ERROR) << "Create ExternalDataT failed";
130     return nullptr;
131   }
132   external_data->location = model_name_ + ".msw";
133   external_data->offset = cur_offset_;
134   external_data->length = static_cast<int64_t>(size);
135   if (data == nullptr || size == 0) {
136     return external_data;
137   }
138   data_fs_->write(data, static_cast<int64_t>(size));
139   if (data_fs_->bad()) {
140     MS_LOG(ERROR) << "Write file failed";
141     delete external_data;
142     return nullptr;
143   }
144   std::stringstream oss;
145   oss << std::hash<char>()(data[0]);
146   external_data->checkSum = oss.str();
147   cur_offset_ += static_cast<int64_t>(size);
148   return external_data;
149 }
150 
ExtraAndSerializeModelWeight(const schema::MetaGraphT & graph)151 bool MetaGraphSerializer::ExtraAndSerializeModelWeight(const schema::MetaGraphT &graph) {
152   if (data_fs_ == nullptr) {
153     MS_LOG(ERROR) << "Weight file stream is not inited";
154     return false;
155   }
156   data_fs_ = ReopenFile(save_data_path_, std::ios::out | std::ios::app, data_fs_);
157   if (data_fs_ == nullptr) {
158     MS_LOG(ERROR) << "Reopen weight file stream failed";
159     return false;
160   }
161   if (this->cur_offset_ != kExternalDataHeadSize) {
162     MS_LOG(ERROR) << "Serialized model weight already";
163     return false;
164   }
165   for (const auto &tensor : graph.allTensors) {
166     if (tensor->nodeType == NodeType_CNode) {
167       continue;
168     }
169     if (tensor->dataType == kObjectTypeTensorType) {  // not support control-flow now
170       continue;
171     }
172     auto external_data =
173       this->AddExternalData(reinterpret_cast<const char *>(tensor->data.data()), tensor->data.size());
174     if (external_data == nullptr) {
175       MS_LOG(ERROR) << "Serialized model weight failed";
176       return false;
177     }
178     tensor->data.clear();
179     tensor->externalData.emplace_back(external_data);
180   }
181   return true;
182 }
183 
SerializeModelAndUpdateWeight(const schema::MetaGraphT & meta_graphT,const Byte * key,const size_t key_len,const std::string & enc_mode,size_t * size)184 bool MetaGraphSerializer::SerializeModelAndUpdateWeight(const schema::MetaGraphT &meta_graphT, const Byte *key,
185                                                         const size_t key_len, const std::string &enc_mode,
186                                                         size_t *size) {
187   // serialize model
188   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
189   auto offset = schema::MetaGraph::Pack(builder, &meta_graphT);
190   builder.Finish(offset);
191   schema::FinishMetaGraphBuffer(builder, offset);
192   *size = builder.GetSize();
193   auto content = builder.GetBufferPointer();
194   if (!SerializeModel(content, *size, key, key_len, enc_mode)) {
195     MS_LOG(ERROR) << "Serialize graph failed";
196     return false;
197   }
198 
199   // update weight file using check-sum of model-buffer
200   auto model_crc32 = std::hash<uint8_t>()(content[0]);
201   if (data_fs_ == nullptr) {
202     MS_LOG(ERROR) << "Weight file stream is not inited";
203     return false;
204   }
205   data_fs_ = ReopenFile(save_data_path_, std::ios::in | std::ios::out, data_fs_);
206   if (data_fs_ == nullptr) {
207     MS_LOG(ERROR) << "Reopen weight file stream failed";
208     return false;
209   }
210   data_fs_->seekp(kMagicNumberSize, std::ios::beg);
211   data_fs_->write(reinterpret_cast<const char *>(&model_crc32), kMagicNumberSize);
212 #ifndef _MSC_VER
213   chmod(save_data_path_.c_str(), S_IRUSR);
214 #endif
215   return true;
216 }
217 
GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuilder * builder,const schema::MetaGraphT & graph,size_t * data_size)218 uint8_t *MetaGraphSerializer::GetMetaGraphPackedBuff(flatbuffers::FlatBufferBuilder *builder,
219                                                      const schema::MetaGraphT &graph, size_t *data_size) {
220   auto offset = schema::MetaGraph::Pack(*builder, &graph);
221   builder->Finish(offset);
222   schema::FinishMetaGraphBuffer(*builder, offset);
223   *data_size = builder->GetSize();
224   return builder->GetBufferPointer();
225 }
226 
Save(const schema::MetaGraphT & graph,const std::string & output_path,const Byte * key,const size_t key_len,const std::string & enc_mode)227 int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, const Byte *key,
228                               const size_t key_len, const std::string &enc_mode) {
229   size_t size = 0;
230   auto ret = MetaGraphSerializer::Save(graph, output_path, &size, key, key_len, enc_mode);
231   return ret;
232 }
233 
Save(const schema::MetaGraphT & graph,const std::string & output_path,size_t * size,const Byte * key,const size_t key_len,const std::string & enc_mode)234 int MetaGraphSerializer::Save(const schema::MetaGraphT &graph, const std::string &output_path, size_t *size,
235                               const Byte *key, const size_t key_len, const std::string &enc_mode) {
236   MetaGraphSerializer meta_graph_serializer;
237   *size = 0;
238   flatbuffers::FlatBufferBuilder builder(kFlatbuffersBuilderInitSize);
239   auto buffer = meta_graph_serializer.GetMetaGraphPackedBuff(&builder, graph, size);
240   if (!meta_graph_serializer.InitPath(output_path)) {
241     MS_LOG(ERROR) << "Init path failed";
242     return RET_ERROR;
243   }
244   size_t tensors_size = 0;
245   for (auto &tensor : graph.allTensors) {
246     tensors_size += tensor->data.size();
247   }
248 
249   auto save_together = (tensors_size < kModelSizeLimit && *size < kModelSizeLimit);
250   if (!meta_graph_serializer.Init(graph, save_together)) {
251     MS_LOG(ERROR) << "Init MetaGraphSerializer failed";
252     return RET_ERROR;
253   }
254   if (save_together) {
255     if (!meta_graph_serializer.SerializeModel(buffer, *size, key, key_len, enc_mode)) {
256       MS_LOG(ERROR) << "Serialize graph failed";
257       return RET_ERROR;
258     }
259   } else {
260     if (!meta_graph_serializer.ExtraAndSerializeModelWeight(graph)) {
261       MS_LOG(ERROR) << "Serialize graph weight failed";
262       return RET_ERROR;
263     }
264     size_t model_size = 0;
265     if (!meta_graph_serializer.SerializeModelAndUpdateWeight(graph, key, key_len, enc_mode, &model_size)) {
266       MS_LOG(ERROR) << "Serialize graph and adjust weight failed";
267       return RET_ERROR;
268     }
269     *size = model_size + tensors_size;
270   }
271   return RET_OK;
272 }
273 
~MetaGraphSerializer()274 MetaGraphSerializer::~MetaGraphSerializer() {
275   if (model_fs_ != nullptr) {
276     model_fs_->close();
277     delete model_fs_;
278   }
279   if (data_fs_ != nullptr) {
280     data_fs_->close();
281     delete data_fs_;
282   }
283 }
284 
SerializeModel(const void * content,size_t size,const Byte * key,const size_t key_len,const std::string & enc_mode)285 bool MetaGraphSerializer::SerializeModel(const void *content, size_t size, const Byte *key, const size_t key_len,
286                                          const std::string &enc_mode) {
287   MS_ASSERT(model_fs_ != nullptr);
288   if (size == 0 || content == nullptr) {
289     MS_LOG(ERROR) << "Input meta graph buffer is nullptr";
290     return false;
291   }
292   if (key_len > 0) {
293     size_t encrypt_len;
294     auto encrypt_content = Encrypt(&encrypt_len, reinterpret_cast<const Byte *>(content), size, key, key_len, enc_mode);
295     if (encrypt_content == nullptr || encrypt_len == 0) {
296       MS_LOG(ERROR) << "Encrypt failed.";
297       model_fs_->close();
298       return false;
299     }
300     model_fs_->write(reinterpret_cast<const char *>(encrypt_content.get()), encrypt_len);
301   } else {
302     model_fs_->write((const char *)content, static_cast<int64_t>(size));
303   }
304   if (model_fs_->bad()) {
305     MS_LOG(ERROR) << "Write model file failed: " << save_model_path_;
306     return false;
307   }
308 #ifndef _MSC_VER
309   chmod(save_model_path_.c_str(), S_IRUSR);
310 #endif
311   return true;
312 }
313 }  // namespace mindspore::lite
314