/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/api/serialization.h"
#include <algorithm>
#include <cstring>  // memcpy
#include <limits>   // std::numeric_limits
#include <queue>
#include "include/api/graph.h"
#include "include/api/types.h"
#include "include/model.h"
#include "src/litert/cxx_api/graph/graph_data.h"
#include "src/litert/cxx_api/model/model_impl.h"
#include "src/litert/cxx_api/converters.h"
#include "src/common/log_adapter.h"
#include "src/litert/lite_session.h"

namespace mindspore {
Key::Key(const char *dec_key, size_t key_len) {
  len = 0;
  if (dec_key == nullptr) {
    MS_LOG(ERROR) << "dec_key is nullptr.";
    return;
  }
  if (key_len >= max_key_len) {
    MS_LOG(ERROR) << "Invalid key len " << key_len << " exceeds max key len " << max_key_len;
    return;
  }

  (void)memcpy(key, dec_key, key_len);
  len = key_len;
}

Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                           const Key &dec_key, const std::vector<char> &dec_mode) {
  // Decryption is not supported here: only an empty key with the default
  // AES-GCM mode string is accepted.
  if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
    MS_LOG(ERROR) << "Unsupported feature.";
    return kLiteError;
  }

  if (model_data == nullptr) {
    MS_LOG(ERROR) << "model data is nullptr.";
    return kLiteNullptr;
  }
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr.";
    return kLiteNullptr;
  }
  if (model_type != kMindIR && model_type != kMindIR_Lite) {
    MS_LOG(ERROR) << "Unsupported IR.";
    return kLiteInputParamInvalid;
  }

  size_t lite_buf_size = 0;
  char *lite_buf = nullptr;
  lite::LiteSession session;
  auto buf_model_type = session.LoadModelByBuff(reinterpret_cast<const char *>(model_data), data_size, &lite_buf,
                                                &lite_buf_size, model_type);
  if (buf_model_type == mindspore::ModelType::kUnknownType || lite_buf == nullptr) {
    MS_LOG(ERROR) << "Invalid model_buf";
    return kLiteNullptr;
  }
  // Import from the (possibly converted) buffer, using the converted buffer's
  // size rather than the original data_size.
  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(lite_buf), lite_buf_size));
  // LoadModelByBuff allocates a new buffer when it converts a kMindIR input;
  // release it whether or not the import succeeded, to avoid a leak.
  if (buf_model_type == mindspore::ModelType::kMindIR) {
    free(lite_buf);
    lite_buf = nullptr;
  }
  if (model == nullptr) {
    MS_LOG(ERROR) << "New model failed.";
    return kLiteNullptr;
  }
  auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
  if (graph_data == nullptr) {
    MS_LOG(ERROR) << "New graph data failed.";
    return kLiteMemoryFailed;
  }
  *graph = Graph(graph_data);
  return kSuccess;
}

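// A minimal usage sketch for the buffer overload (not part of this file's
// code path; the read helper and file name are hypothetical):
//
//   std::string buf = /* read "model.ms" or a MindIR file into memory */;
//   mindspore::Graph graph;
//   auto status = mindspore::Serialization::Load(buf.data(), buf.size(), mindspore::kMindIR, &graph);
//   if (status == mindspore::kSuccess) {
//     mindspore::Model model;
//     auto context = std::make_shared<mindspore::Context>();
//     (void)model.Build(mindspore::GraphCell(graph), context);
//   }
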
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
                           const std::vector<char> &dec_mode) {
  if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
    MS_LOG(ERROR) << "Unsupported feature.";
    return kLiteError;
  }

  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr.";
    return kLiteNullptr;
  }
  if (model_type != kMindIR && model_type != kMindIR_Lite) {
    MS_LOG(ERROR) << "Unsupported IR.";
    return kLiteInputParamInvalid;
  }

  std::string filename(file.data(), file.size());
  if (filename.size() > static_cast<size_t>((std::numeric_limits<int>::max)())) {
    MS_LOG(ERROR) << "file name is too long.";
    return kLiteInputParamInvalid;
  }
  // Append the ".ms" suffix when it is missing.
  auto pos = filename.find_last_of('.');
  if (pos == std::string::npos || filename.substr(pos + 1) != "ms") {
    filename = filename + ".ms";
  }

  size_t model_size = 0;
  lite::LiteSession session;
  auto model_buf = session.LoadModelByPath(filename, model_type, &model_size, false);
  if (model_buf == nullptr) {
    MS_LOG(ERROR) << "Read model file failed";
    return kLiteNullptr;
  }
  // ImportFromBuffer takes ownership of model_buf (take_buf == true).
  auto model =
    std::shared_ptr<lite::Model>(lite::ImportFromBuffer(static_cast<const char *>(model_buf), model_size, true));
  if (model == nullptr) {
    MS_LOG(ERROR) << "New model failed.";
    return kLiteNullptr;
  }
  auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
  if (graph_data == nullptr) {
    MS_LOG(ERROR) << "New graph data failed.";
    return kLiteMemoryFailed;
  }
  *graph = Graph(graph_data);
  return kSuccess;
}

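// Usage sketch for the file overload (the path is hypothetical). A ".ms"
// suffix is appended automatically when missing, so this reads "lenet.ms":
//
//   mindspore::Graph graph;
//   auto status = mindspore::Serialization::Load("lenet", mindspore::kMindIR_Lite, &graph);
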
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
                           std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return kLiteError;
}

Status Serialization::SetParameters(const std::map<std::vector<char>, Buffer> &parameters, Model *model) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return kMEFailed;
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  const std::vector<std::vector<char>> &output_tensor_name) {
  if (model.impl_ == nullptr) {
    MS_LOG(ERROR) << "Model implementation is null.";
    return kLiteUninitializedObj;
  }
  if (!model.impl_->IsTrainModel()) {
    MS_LOG(ERROR) << "Model is not a TrainModel.";
    return kLiteError;
  }
  if (model_data == nullptr) {
    MS_LOG(ERROR) << "model_data is nullptr.";
    return kLiteParamInvalid;
  }
  if (model_type != kMindIR && model_type != kMindIR_Lite) {
    MS_LOG(ERROR) << "Unsupported export format " << model_type;
    return kLiteParamInvalid;
  }
  if (model.impl_->session_ == nullptr) {
    MS_LOG(ERROR) << "Model session is nullptr.";
    return kLiteError;
  }
  auto ret = model.impl_->session_->Export(model_data, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
                                           A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS,
                                           VectorCharToString(output_tensor_name));

  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}

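// Sketch: export a trained model's inference graph to an in-memory Buffer
// (assumes `model` was built for training, e.g. with a TrainCfg, so that
// IsTrainModel() holds):
//
//   mindspore::Buffer buf;
//   auto status = mindspore::Serialization::ExportModel(model, mindspore::kMindIR, &buf,
//                                                       mindspore::kNoQuant,
//                                                       /*export_inference_only=*/true);
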
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  const std::vector<std::vector<char>> &output_tensor_name) {
  if (model.impl_ == nullptr) {
    MS_LOG(ERROR) << "Model implementation is null.";
    return kLiteUninitializedObj;
  }
  if (model.impl_->session_ == nullptr) {
    MS_LOG(ERROR) << "Model hasn't been built.";
    return kLiteError;
  }
  if (!model.impl_->IsTrainModel()) {
    MS_LOG(ERROR) << "Model is not a TrainModel.";
    return kLiteError;
  }
  if (model_type != kMindIR && model_type != kMindIR_Lite) {
    MS_LOG(ERROR) << "Unsupported export format " << model_type;
    return kLiteParamInvalid;
  }
  auto ret = model.impl_->session_->Export(
    CharToString(model_file), export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
    A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, VectorCharToString(output_tensor_name));

  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}

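// Sketch: export the full training graph to a file instead (the path is a
// hypothetical placeholder):
//
//   auto status = mindspore::Serialization::ExportModel(model, mindspore::kMindIR, "trained.ms",
//                                                       mindspore::kNoQuant,
//                                                       /*export_inference_only=*/false);
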
Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                        const std::vector<char> &weight_file, bool is_inference,
                                                        bool enable_fp16,
                                                        const std::vector<std::vector<char>> &changeable_weights_name) {
  if (model.impl_ == nullptr) {
    MS_LOG(ERROR) << "Model implementation is null.";
    return kLiteUninitializedObj;
  }
  if (model.impl_->session_ == nullptr) {
    MS_LOG(ERROR) << "Model hasn't been built.";
    return kLiteError;
  }
  if (!model.impl_->IsTrainModel()) {
    MS_LOG(ERROR) << "Model is not a TrainModel.";
    return kLiteError;
  }
  if (model_type != kMindIR && model_type != kMindIR_Lite) {
    MS_LOG(ERROR) << "Model type is not kMindIR or kMindIR_Lite.";
    return kLiteParamInvalid;
  }
  if (!is_inference) {
    MS_LOG(ERROR) << "Currently only the weights of an inference model can be exported.";
    return kLiteNotSupport;
  }
  auto ret = model.impl_->session_->ExportWeightsCollaborateWithMicro(CharToString(weight_file), lite::MT_INFERENCE,
                                                                      lite::FT_FLATBUFFERS, enable_fp16,
                                                                      VectorCharToString(changeable_weights_name));

  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}
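
// Sketch: dump weights for collaboration with micro codegen; the file name
// and weight name below are hypothetical placeholders:
//
//   auto status = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
//     model, mindspore::kMindIR, "net_weights.bin", /*is_inference=*/true,
//     /*enable_fp16=*/false, {"fc.weight"});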
}  // namespace mindspore