/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/api/serialization.h"
#include <algorithm>
#include <cstring>
#include <queue>
#include "include/api/graph.h"
#include "include/api/types.h"
#include "include/model.h"
#include "src/cxx_api/graph/graph_data.h"
#include "src/cxx_api/model/model_impl.h"
#include "src/cxx_api/converters.h"
#include "src/common/log_adapter.h"

namespace mindspore {
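// Copies a caller-provided decryption key of key_len bytes into the fixed-size key buffer.
// Keys of max_key_len bytes or more are rejected and len stays at 0.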
Key::Key(const char *dec_key, size_t key_len) {
  len = 0;
  if (key_len >= max_key_len) {
    MS_LOG(ERROR) << "Invalid key len " << key_len << " is more than max key len " << max_key_len;
    return;
  }

  memcpy(key, dec_key, key_len);
  len = key_len;
}

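// Loads a MindIR model from an in-memory buffer into *graph. Decryption is not implemented here:
// dec_key must be empty and dec_mode must equal kDecModeAesGcm, otherwise kLiteError is returned.
//
// Illustrative caller-side sketch (assumes the public header defaults dec_key/dec_mode; the buffer
// names are placeholders, not part of this file):
//   Graph graph;
//   Status ret = Serialization::Load(model_buf, model_buf_size, kMindIR, &graph);
//   if (ret != kSuccess) { /* handle load failure */ }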
Status Serialization::Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph,
                           const Key &dec_key, const std::vector<char> &dec_mode) {
  if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
    MS_LOG(ERROR) << "Unsupported Feature.";
    return kLiteError;
  }

  if (model_data == nullptr) {
    MS_LOG(ERROR) << "model data is nullptr.";
    return kLiteNullptr;
  }
  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr.";
    return kLiteNullptr;
  }
  if (model_type != kMindIR) {
    MS_LOG(ERROR) << "Unsupported IR.";
    return kLiteInputParamInvalid;
  }

  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(static_cast<const char *>(model_data), data_size));
  if (model == nullptr) {
    MS_LOG(ERROR) << "New model failed.";
    return kLiteNullptr;
  }
  auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
  if (graph_data == nullptr) {
    MS_LOG(ERROR) << "New graph data failed.";
    return kLiteMemoryFailed;
  }
  *graph = Graph(graph_data);
  return kSuccess;
}

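// Loads a model from a file path (passed as a char vector) into *graph. A ".ms" suffix is appended
// when the path does not already end with one. Decryption is not supported; see the overload above.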
Status Serialization::Load(const std::vector<char> &file, ModelType model_type, Graph *graph, const Key &dec_key,
                           const std::vector<char> &dec_mode) {
  if (dec_key.len != 0 || CharToString(dec_mode) != kDecModeAesGcm) {
    MS_LOG(ERROR) << "Unsupported Feature.";
    return kLiteError;
  }

  if (graph == nullptr) {
    MS_LOG(ERROR) << "graph is nullptr.";
    return kLiteNullptr;
  }
  if (model_type != kMindIR) {
    MS_LOG(ERROR) << "Unsupported IR.";
    return kLiteInputParamInvalid;
  }

  std::string filename(file.data(), file.size());
  if (filename.substr(filename.find_last_of(".") + 1) != "ms") {
    filename = filename + ".ms";
  }

  auto model = std::shared_ptr<lite::Model>(lite::Model::Import(filename.c_str()));
  if (model == nullptr) {
    MS_LOG(ERROR) << "New model failed.";
    return kLiteNullptr;
  }
  auto graph_data = std::shared_ptr<Graph::GraphData>(new (std::nothrow) Graph::GraphData(model));
  if (graph_data == nullptr) {
    MS_LOG(ERROR) << "New graph data failed.";
    return kLiteMemoryFailed;
  }
  *graph = Graph(graph_data);
  return kSuccess;
}

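// Batch loading of multiple models is not supported in this build; always returns kLiteError.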
Status Serialization::Load(const std::vector<std::vector<char>> &files, ModelType model_type,
                           std::vector<Graph> *graphs, const Key &dec_key, const std::vector<char> &dec_mode) {
  MS_LOG(ERROR) << "Unsupported Feature.";
  return kLiteError;
}

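// Updating model parameters from a Buffer map is not supported in this build; always returns kMEFailed.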
Status Serialization::SetParameters(const std::map<std::string, Buffer> &parameters, Model *model) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return kMEFailed;
}

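// Exporting a model to an in-memory Buffer is not supported in this build; always returns kMEFailed.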
Status Serialization::ExportModel(const Model &model, ModelType model_type, Buffer *model_data) {
  MS_LOG(ERROR) << "Unsupported feature.";
  return kMEFailed;
}

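// Exports a trained model to a FlatBuffer (.ms) file. The model must be a train model with a live
// session, and model_type must be kFlatBuffer. export_inference_only selects MT_INFERENCE versus
// MT_TRAIN; quantization_type and output_tensor_name are forwarded to the session's Export call.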
Status Serialization::ExportModel(const Model &model, ModelType model_type, const std::string &model_file,
                                  QuantizationType quantization_type, bool export_inference_only,
                                  std::vector<std::string> output_tensor_name) {
  if (model.impl_ == nullptr) {
    MS_LOG(ERROR) << "Model implement is null.";
    return kLiteUninitializedObj;
  }
  if (!model.impl_->IsTrainModel()) {
    MS_LOG(ERROR) << "Model is not TrainModel.";
    return kLiteError;
  }
  if (model_type != kFlatBuffer) {
    MS_LOG(ERROR) << "Unsupported Export Format " << model_type;
    return kLiteParamInvalid;
  }
  if (model.impl_->session_ == nullptr) {
    MS_LOG(ERROR) << "Model session is nullptr.";
    return kLiteError;
  }
  auto ret = model.impl_->session_->Export(model_file, export_inference_only ? lite::MT_INFERENCE : lite::MT_TRAIN,
                                           A2L_ConvertQT(quantization_type), lite::FT_FLATBUFFERS, output_tensor_name);

  return (ret == mindspore::lite::RET_OK) ? kSuccess : kLiteError;
}
}  // namespace mindspore