• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/api/serialization.h"
18 #include <algorithm>
19 #include <queue>
20 #include "include/api/graph.h"
21 #include "include/api/types.h"
22 #include "include/model.h"
23 #include "src/cxx_api/graph/graph_data.h"
24 #include "src/cxx_api/model/model_impl.h"
25 #include "src/cxx_api/converters.h"
26 #include "src/common/log_adapter.h"
27 
28 namespace mindspore {
Key(const char * dec_key,size_t key_len)29 Key::Key(const char *dec_key, size_t key_len) {
30   len = 0;
31   if (key_len >= max_key_len) {
32     MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
33     return;
34   }
35 
36   memcpy(key, dec_key, key_len);
37   len = key_len;
38 }
39 
Load(const void * model_data,size_t data_size,ModelType model_type,Graph * graph,const Key & dec_key,const std::vector<char> & dec_mode)40 Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
41                            const Key &dec_key, const std::vector<char> &dec_mode) {
42   if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
43     MS_LOG(ERROR) << "Unsupported Feature.";
44     return kLiteError;
45   }
46 
47   if (model_data == nullptr) {
48     MS_LOG(ERROR) << "model data is nullptr.";
49     return kLiteNullptr;
50   }
51   if (graph == nullptr) {
52     MS_LOG(ERROR) << "graph is nullptr.";
53     return kLiteNullptr;
54   }
55   if (model_type != kMindIR) {
56     MS_LOG(ERROR) << "Unsupported IR.";
57     return kLiteInputParamInvalid;
58   }
59 
60   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
61   if (model == nullptr) {
62     MS_LOG(ERROR) << "New model failed.";
63     return kLiteNullptr;
64   }
65   auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
66   if (graph_data == nullptr) {
67     MS_LOG(ERROR) << "New graph data failed.";
68     return kLiteMemoryFailed;
69   }
70   *graph = Graph(graph_data);
71   return kSuccess;
72 }
73 
Load(const std::vector<char> & file,ModelType model_type,Graph * graph,const Key & dec_key,const std::vector<char> & dec_mode)74 Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
75                            const std::vector<char> &dec_mode) {
76   if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
77     MS_LOG(ERROR) << "Unsupported Feature.";
78     return kLiteError;
79   }
80 
81   if (graph == nullptr) {
82     MS_LOG(ERROR) << "graph is nullptr.";
83     return kLiteNullptr;
84   }
85   if (model_type != kMindIR) {
86     MS_LOG(ERROR) << "Unsupported IR.";
87     return kLiteInputParamInvalid;
88   }
89 
90   std::string filename(file.data(), file.size());
91   if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
92     filename = filename + ".ms";
93   }
94 
95   auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
96   if (model == nullptr) {
97     MS_LOG(ERROR) << "New model failed.";
98     return kLiteNullptr;
99   }
100   auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
101   if (graph_data == nullptr) {
102     MS_LOG(ERROR) << "New graph data failed.";
103     return kLiteMemoryFailed;
104   }
105   *graph = Graph(graph_data);
106   return kSuccess;
107 }
108 
Load(const std::vector<std::vector<char>> & files,ModelType model_type,std::vector<Graph> * graphs,const Key & dec_key,const std::vector<char> & dec_mode)109 Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
110                            std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
111   MS_LOG(ERROR) << "Unsupported Feature.";
112   return kLiteError;
113 }
114 
SetParameters(const std::map<std::string,Buffer> & parameters,Model * model)115 Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
116   MS_LOG(ERROR) << "Unsupported feature.";
117   return kMEFailed;
118 }
119 
ExportModel(const Model & model,ModelType model_type,Buffer * model_data)120 Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
121   MS_LOG(ERROR) << "Unsupported feature.";
122   return kMEFailed;
123 }
124 
ExportModel(const Model & model,ModelType model_type,const std::string & model_file,QuantizationType quantization_type,bool export_inference_only,std::vector<std::string> output_tensor_name)125 Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
126                                   QuantizationType quantization_type, bool export_inference_only,
127                                   std::vector<std::string> output_tensor_name) {
128   if (model.impl_ == nullptr) {
129     MS_LOG(ERROR) << "Model implement is null.";
130     return kLiteUninitializedObj;
131   }
132   if (!model.impl_->IsTrainModel()) {
133     MS_LOG(ERROR) << "Model is not TrainModel.";
134     return kLiteError;
135   }
136   if (model_type != kFlatBuffer) {
137     MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
138     return kLiteParamInvalid;
139   }
140   if (model.impl_->session_ == nullptr) {
141     MS_LOG(ERROR) << "Model session is nullptr.";
142     return kLiteError;
143   }
144   auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
145                                            A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name);
146 
147   return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
148 }
149 }  // namespace mindspore
150