/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_INCLUDE_API_SERIALIZATION_H
#define MINDSPORE_INCLUDE_API_SERIALIZATION_H

#include <string>
#include <vector>
#include <map>
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/graph.h"
#include "include/api/dual_abi_helper.h"

namespace mindspore {
/// \brief The Serialization class is used to summarize methods for reading and writing model files.
class MS_API Serialization {
 public:
  /// \brief Loads a model file from a memory buffer.
  ///
  /// \param[in] model_data A buffer that holds the model file data.
  /// \param[in] data_size The size of the buffer.
  /// \param[in] model_type The type of the model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object that stores the graph data.
  /// \param[in] dec_key The decryption key, key length is 16, 24, or 32. Not supported on MindSpore Lite.
  /// \param[in] dec_mode The decryption mode, options are AES-GCM, AES-CBC. Not supported on MindSpore Lite.
  ///
  /// \return Status.
  inline static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);
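  // A minimal usage sketch for loading from a memory buffer (illustrative only, not part of the
  // original header; the file name and error handling are assumptions, and <fstream> is required):
  //
  //   std::ifstream ifs("model.mindir", std::ios::in | std::ios::binary | std::ios::ate);
  //   std::vector<char> buf(ifs.tellg());
  //   ifs.seekg(0, std::ios::beg);
  //   ifs.read(buf.data(), buf.size());
  //   mindspore::Graph graph;
  //   auto ret = mindspore::Serialization::Load(buf.data(), buf.size(), mindspore::ModelType::kMindIR, &graph);
  //   if (ret != mindspore::kSuccess) { /* handle the load failure */ }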

  /// \brief Loads a model file from a path.
  ///
  /// \param[in] file The path of the model file.
  /// \param[in] model_type The type of the model file, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graph The output parameter, an object that stores the graph data.
  /// \param[in] dec_key The decryption key, key length is 16, 24, or 32. Not supported on MindSpore Lite.
  /// \param[in] dec_mode The decryption mode, options are AES-GCM, AES-CBC. Not supported on MindSpore Lite.
  ///
  /// \return Status.
  inline static Status Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key = {},
                            const std::string &dec_mode = kDecModeAesGcm);
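  // A minimal usage sketch for loading from a path (illustrative only; the file name is an
  // assumption):
  //
  //   mindspore::Graph graph;
  //   auto ret = mindspore::Serialization::Load("model.mindir", mindspore::ModelType::kMindIR, &graph);
  //   if (ret != mindspore::kSuccess) { /* handle the load failure */ }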

  /// \brief Loads multiple models from multiple files. MindSpore Lite does not provide this feature.
  ///
  /// \param[in] files The paths of the model files.
  /// \param[in] model_type The type of the model files, options are ModelType::kMindIR, ModelType::kOM.
  /// \param[out] graphs The output parameter, objects that store the graph data.
  /// \param[in] dec_key The decryption key, key length is 16, 24, or 32.
  /// \param[in] dec_mode The decryption mode, options are AES-GCM, AES-CBC.
  ///
  /// \return Status.
  inline static Status Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
                            const Key &dec_key = {}, const std::string &dec_mode = kDecModeAesGcm);

  /// \brief Configures model parameters. MindSpore Lite does not provide this feature.
  ///
  /// \param[in] parameters The parameters.
  /// \param[in] model The model.
  ///
  /// \return Status.
  inline static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);

  /// \brief Exports a training model to a memory buffer. MindSpore Lite does not provide this feature.
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[out] model_data The model buffer.
  /// \param[in] quantization_type The quantization type.
  /// \param[in] export_inference_only Whether to export an inference-only model.
  /// \param[in] output_tensor_name The names of the output tensors of the exported inference model. Defaults to
  /// empty, which exports the complete inference model.
  ///
  /// \return Status.
  inline static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                   QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                   const std::vector<std::string> &output_tensor_name = {});

  /// \brief Exports a training model to a file.
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[in] model_file The path of the exported model file.
  /// \param[in] quantization_type The quantization type.
  /// \param[in] export_inference_only Whether to export an inference-only model.
  /// \param[in] output_tensor_name The names of the output tensors of the exported inference model. Defaults to
  /// empty, which exports the complete inference model.
  ///
  /// \return Status.
  inline static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
                                   QuantizationType quantization_type = kNoQuant, bool export_inference_only = true,
                                   std::vector<std::string> output_tensor_name = {});
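  // A minimal usage sketch for exporting a training model to a file (illustrative only; the model
  // is assumed to be already built and trained, and the output path is an assumption):
  //
  //   auto ret = mindspore::Serialization::ExportModel(model, mindspore::ModelType::kMindIR,
  //                                                    "exported.mindir", mindspore::kNoQuant,
  //                                                    /*export_inference_only=*/true);
  //   if (ret != mindspore::kSuccess) { /* handle the export failure */ }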

  /// \brief Experimental feature. Exports the model's weights, which can be used in micro only.
  ///
  /// \param[in] model The model data.
  /// \param[in] model_type The model file type.
  /// \param[in] weight_file The path of the exported weight file.
  /// \param[in] is_inference Whether to export weights from an inference model. Currently, only `true` is supported.
  /// \param[in] enable_fp16 Whether to save float weights in float16 format.
  /// \param[in] changeable_weights_name The names of the weight tensors whose shapes are changeable.
  ///
  /// \return Status.
  inline static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                         const std::string &weight_file, bool is_inference = true,
                                                         bool enable_fp16 = false,
                                                         const std::vector<std::string> &changeable_weights_name = {});
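  // A minimal usage sketch for exporting weights for micro code generation (illustrative only; the
  // model is assumed to be already built, and the weight file path is an assumption):
  //
  //   auto ret = mindspore::Serialization::ExportWeightsCollaborateWithMicro(
  //       model, mindspore::ModelType::kMindIR, "net.bin", /*is_inference=*/true, /*enable_fp16=*/false);
  //   if (ret != mindspore::kSuccess) { /* handle the export failure */ }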

 private:
  static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
  static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
  static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
                     const std::vector<char> &dec_mode);
  static Status Load(const std::vector<std::vector<char>> &files, ModelType model_type, std::vector<Graph> *graphs,
                     const Key &dec_key, const std::vector<char> &dec_mode);
  static Status SetParameters(const std::map<std::vector<char>, Buffer> &parameters, Model *model);
  static Status ExportModel(const Model &model, ModelType model_type, const std::vector<char> &model_file,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
  static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                            QuantizationType quantization_type, bool export_inference_only,
                            const std::vector<std::vector<char>> &output_tensor_name);
  static Status ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                  const std::vector<char> &weight_file, bool is_inference,
                                                  bool enable_fp16,
                                                  const std::vector<std::vector<char>> &changeable_weights_name);
};

Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                           const Key &dec_key, const std::string &dec_mode) {
  return Load(model_data, data_size, model_type, graph, dec_key, StringToChar(dec_mode));
}

Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph, const Key &dec_key,
                           const std::string &dec_mode) {
  return Load(StringToChar(file), model_type, graph, dec_key, StringToChar(dec_mode));
}

Status Serialization::Load(const std::vector<std::string> &files, ModelType model_type, std::vector<Graph> *graphs,
                           const Key &dec_key, const std::string &dec_mode) {
  return Load(VectorStringToChar(files), model_type, graphs, dec_key, StringToChar(dec_mode));
}

Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
  return SetParameters(MapStringToChar<Buffer>(parameters), model);
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  std::vector<std::string> output_tensor_name) {
  return ExportModel(model, model_type, StringToChar(model_file), quantization_type, export_inference_only,
                     VectorStringToChar(output_tensor_name));
}

Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  const std::vector<std::string> &output_tensor_name) {
  return ExportModel(model, model_type, model_data, quantization_type, export_inference_only,
                     VectorStringToChar(output_tensor_name));
}

Status Serialization::ExportWeightsCollaborateWithMicro(const Model &model, ModelType model_type,
                                                        const std::string &weight_file, bool is_inference,
                                                        bool enable_fp16,
                                                        const std::vector<std::string> &changeable_weights_name) {
  return ExportWeightsCollaborateWithMicro(model, model_type, StringToChar(weight_file), is_inference, enable_fp16,
                                           VectorStringToChar(changeable_weights_name));
}
}  // namespace mindspore
#endif  // MINDSPORE_INCLUDE_API_SERIALIZATION_H