/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_SERIALIZATION_H
#define MINDSPORE_INCLUDE_API_SERIALIZATION_H

#include <string>
#include <vector>
#include <map>
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
/// \brief The Serialization class is used to summarize methods for reading and writing model files.
class MS_API Serialization {
 public:
  /// \brief Loads a model file from a memory buffer.
  ///
  /// \param[in] model_data A buffer holding the model file content.
  /// \param[in] data_size The size of the buffer.
  /// \param[in] model_type The type of the model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object that holds the graph data.
  /// \param[in] dec_key The decryption key, whose length is 16, 24, or 32. Not supported on MindSpore Lite.
  /// \param[in] dec_mode The decryption mode, options are AES-GCM and AES-CBC. Not supported on MindSpore Lite.
  ///
  /// \return Status.
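  ///
  /// \par Example
  /// A minimal usage sketch, assuming the caller has already read a MindIR model file into memory; the names
  /// `model_buf` and `buf_size` are placeholders for that buffer and its size.
  /// \code
  /// // model_buf / buf_size: placeholders for the in-memory model file content and its length.
  /// mindspore::Graph graph;
  /// auto ret = mindspore::Serialization::Load(model_buf, buf_size, mindspore::ModelType::kMindIR, &graph);
  /// if (ret != mindspore::kSuccess) {
  ///   // Loading failed; handle the error.
  /// }
  /// \endcode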
  inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);

  /// \brief Loads a model file from a path.
  ///
  /// \param[in] file The path of the model file.
  /// \param[in] model_type The type of the model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object that holds the graph data.
  /// \param[in] dec_key The decryption key, whose length is 16, 24, or 32. Not supported on MindSpore Lite.
  /// \param[in] dec_mode The decryption mode, options are AES-GCM and AES-CBC. Not supported on MindSpore Lite.
  ///
  /// \return Status.
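  ///
  /// \par Example
  /// A minimal usage sketch; the file name "net.mindir" is a placeholder.
  /// \code
  /// // "net.mindir" is a placeholder path to a MindIR model file.
  /// mindspore::Graph graph;
  /// auto ret = mindspore::Serialization::Load("net.mindir", mindspore::ModelType::kMindIR, &graph);
  /// if (ret != mindspore::kSuccess) {
  ///   // Loading failed; handle the error.
  /// }
  /// \endcode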
  inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
                            const std::string &dec_mode = kDecModeAesGcm);

  /// \brief Loads multiple models from multiple files. MindSpore Lite does not provide this feature.
  ///
  /// \param[in] files The paths of the model files.
  /// \param[in] model_type The type of the model files, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graphs The output parameter, a vector that holds the graph data.
  /// \param[in] dec_key The decryption key, whose length is 16, 24, or 32.
  /// \param[in] dec_mode The decryption mode, options are AES-GCM and AES-CBC.
  ///
  /// \return Status.
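  ///
  /// \par Example
  /// A minimal usage sketch; the file names are placeholders.
  /// \code
  /// // The file names below are placeholders for real MindIR model files.
  /// std::vector<std::string> files = {"net_a.mindir", "net_b.mindir"};
  /// std::vector<mindspore::Graph> graphs;
  /// auto ret = mindspore::Serialization::Load(files, mindspore::ModelType::kMindIR, &graphs);
  /// if (ret != mindspore::kSuccess) {
  ///   // Loading failed; handle the error.
  /// }
  /// \endcode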
  inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);

  /// \brief Configures model parameters. MindSpore Lite does not provide this feature.
  ///
  /// \param[in] parameters The parameters to set, keyed by parameter name.
  /// \param[in] model The model.
  ///
  /// \return Status.
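  ///
  /// \par Example
  /// A minimal sketch, assuming `model` is an existing mindspore::Model, `weight_buf` is a mindspore::Buffer
  /// already filled with the new parameter data, and "fc.weight" is a placeholder parameter name.
  /// \code
  /// // model, weight_buf and "fc.weight" are placeholders supplied by the caller.
  /// std::map<std::string, mindspore::Buffer> parameters = {{"fc.weight", weight_buf}};
  /// auto ret = mindspore::Serialization::SetParameters(parameters, &model);
  /// \endcode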
  inline static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);

  /// \brief Exports a training model to a memory buffer. MindSpore Lite does not provide this feature.
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[out] model_data The model buffer.
  /// \param[in] quantization_type The quantization type.
  /// \param[in] export_inference_only Whether to export an inference-only model.
  /// \param[in] output_tensor_name The names of the output tensors of the exported inference model. Empty (the
  /// default) means exporting the complete inference model.
  ///
  /// \return Status.
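  ///
  /// \par Example
  /// A minimal sketch, assuming `model` is a training mindspore::Model that has already been built.
  /// \code
  /// // model is a placeholder for an already built training model.
  /// mindspore::Buffer model_data;
  /// auto ret = mindspore::Serialization::ExportModel(model, mindspore::ModelType::kMindIR, &model_data);
  /// \endcode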
  inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                   QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                   const std::vector<std::string> &output_tensor_name = {});

  /// \brief Exports a training model to a file.
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[in] model_file The path of the exported model file.
  /// \param[in] quantization_type The quantization type.
  /// \param[in] export_inference_only Whether to export an inference-only model.
  /// \param[in] output_tensor_name The names of the output tensors of the exported inference model. Empty (the
  /// default) means exporting the complete inference model.
  ///
  /// \return Status.
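  ///
  /// \par Example
  /// A minimal sketch, assuming `model` is a training mindspore::Model; the output path "export.mindir" is a
  /// placeholder. The explicit arguments repeat the defaults (no quantization, inference-only export).
  /// \code
  /// // model and "export.mindir" are placeholders supplied by the caller.
  /// auto ret = mindspore::Serialization::ExportModel(model, mindspore::ModelType::kMindIR, "export.mindir",
  ///                                                  mindspore::kNoQuant, true);
  /// \endcode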
  inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
                                   QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                   std::vector<std::string> output_tensor_name = {});

  /// \brief Experimental feature. Exports the model's weights, for use with micro only.
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[in] weight_file The path of the exported weight file.
  /// \param[in] is_inference Whether to export weights from an inference model. Currently, only `true` is supported.
  /// \param[in] enable_fp16 Whether to save floating-point weights in float16 format.
  /// \param[in] changeable_weights_name The names of the weight tensors whose shapes are changeable.
  ///
  /// \return Status.
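  ///
  /// \par Example
  /// A minimal sketch, assuming `model` is an inference mindspore::Model and "net_weights.bin" is a placeholder
  /// output path.
  /// \code
  /// // model and "net_weights.bin" are placeholders supplied by the caller.
  /// auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(model, mindspore::ModelType::kMindIR,
  ///                                                                        "net_weights.bin");
  /// \endcode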
  inline static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                         const std::string &weight_file, bool is_inference = true,
                                                         bool enable_fp16 = false,
                                                         const std::vector<std::string> &changeable_weights_name = {});

 private:
  static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
  static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
  static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
  static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
                     const Key &dec_key, const std::vector<char> &dec_mode);
  static Status SetParameters(const std::map<std::vector<char>, Buffer> &parameters, Model *model);
  static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
  static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                  const std::vector<char> &weight_file, bool is_inference,
                                                  bool enable_fp16,
                                                  const std::vector<std::vector<char>> &changeable_weights_name);
};

Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                           const Key &dec_key, const std::string &dec_mode) {
  return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
}

Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
                           const std::string &dec_mode) {
  return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
}

Status Serialization::Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
                           const Key &dec_key, const std::string &dec_mode) {
  return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
}

Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
  return SetParameters(MapStringToChar<Buffer>(parameters), model);
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  std::vector<std::string> output_tensor_name) {
  return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only,
                     VectorStringToChar(output_tensor_name));
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  const std::vector<std::string> &output_tensor_name) {
  return ExportModel(model, model_type, model_data, quantization_type, export_inference_only,
                     VectorStringToChar(output_tensor_name));
}

Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                        const std::string &weight_file, bool is_inference,
                                                        bool enable_fp16,
                                                        const std::vector<std::string> &changeable_weights_name) {
  return ExportWeightsCollaborateWithMicro(model, model_type, StringToChar(weight_file), is_inference, enable_fp16,
                                           VectorStringToChar(changeable_weights_name));
}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H