1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/converter.h"
17 #include "include/api/data_type.h"
18 #include "tools/converter/cxx_api/converter_para.h"
19 #include "tools/converter/converter_context.h"
20 #include "tools/converter/converter.h"
21 #include "src/common/log_adapter.h"
22
23 namespace mindspore {
24 namespace {
25 constexpr size_t kMaxSectionNum = 100;
26 constexpr size_t kMaxConfigNumPerSection = 1000;
27 } // namespace
28 namespace lite {
29 int RunConverter(const std::shared_ptr<ConverterPara> &data_);
30 }
Converter()31 Converter::Converter() {
32 data_ = std::make_shared<ConverterPara>();
33 if (data_ == nullptr) {
34 MS_LOG(ERROR) << "Create ConverterPara failed";
35 }
36 }
37
Converter(converter::FmkType fmk_type,const std::vector<char> & model_file,const std::vector<char> & output_file,const std::vector<char> & weight_file)38 Converter::Converter(converter::FmkType fmk_type, const std::vector<char> &model_file,
39 const std::vector<char> &output_file, const std::vector<char> &weight_file) {
40 data_ = std::make_shared<ConverterPara>();
41 if (data_ != nullptr) {
42 data_->fmk_type = fmk_type;
43 data_->model_file = CharToString(model_file);
44 data_->output_file = CharToString(output_file);
45 data_->weight_file = CharToString(weight_file);
46 } else {
47 MS_LOG(ERROR) << "Create ConverterPara failed";
48 }
49 }
50
SetConfigFile(const std::vector<char> & config_file)51 void Converter::SetConfigFile(const std::vector<char> &config_file) {
52 if (data_ != nullptr) {
53 data_->config_file = CharToString(config_file);
54 }
55 }
56
GetConfigFileChar() const57 std::vector<char> Converter::GetConfigFileChar() const {
58 std::string cfg_file = "";
59 if (data_ != nullptr) {
60 cfg_file = data_->config_file;
61 }
62 return StringToChar(cfg_file);
63 }
64
SetConfigInfo(const std::vector<char> & section,const std::map<std::vector<char>,std::vector<char>> & config)65 void Converter::SetConfigInfo(const std::vector<char> §ion,
66 const std::map<std::vector<char>, std::vector<char>> &config) {
67 auto section_str = CharToString(section);
68 auto config_str = MapVectorCharToString(config);
69 if (data_ != nullptr) {
70 if (data_->config_param.size() > kMaxSectionNum) {
71 MS_LOG(ERROR) << "Section num " << data_->config_param.size() << "exceeds max num " << kMaxSectionNum;
72 return;
73 }
74 if (data_->config_param.find(section_str) != data_->config_param.end()) {
75 MS_LOG(WARNING) << "Section " << section_str << "already exists, "
76 << "value will be overwrite.";
77 }
78 if (config.size() > kMaxConfigNumPerSection) {
79 MS_LOG(ERROR) << "Config num " << config.size() << " exceeds max num " << kMaxConfigNumPerSection << " in "
80 << section_str;
81 return;
82 }
83 data_->config_param[section_str] = config_str;
84 }
85 }
86
GetConfigInfoChar() const87 std::map<std::vector<char>, std::map<std::vector<char>, std::vector<char>>> Converter::GetConfigInfoChar() const {
88 return MapMapStringToChar(data_->config_param);
89 }
90
SetWeightFp16(bool weight_fp16)91 void Converter::SetWeightFp16(bool weight_fp16) {
92 if (data_ != nullptr) {
93 data_->weight_fp16 = weight_fp16;
94 }
95 }
96
GetWeightFp16() const97 bool Converter::GetWeightFp16() const {
98 if (data_ != nullptr) {
99 return data_->weight_fp16;
100 } else {
101 return false;
102 }
103 }
104
SetInputShape(const std::map<std::vector<char>,std::vector<int64_t>> & input_shape)105 void Converter::SetInputShape(const std::map<std::vector<char>, std::vector<int64_t>> &input_shape) {
106 auto input_shape_str = MapCharToString(input_shape);
107 if (data_ != nullptr) {
108 for (auto &it : input_shape_str) {
109 lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(it.first, it.second);
110 }
111 data_->input_shape = input_shape_str;
112 }
113 }
114
GetInputShapeChar() const115 std::map<std::vector<char>, std::vector<int64_t>> Converter::GetInputShapeChar() const {
116 std::map<std::string, std::vector<int64_t>> input_shape = {};
117 if (data_ != nullptr) {
118 input_shape = data_->input_shape;
119 }
120 return MapStringToChar(input_shape);
121 }
122
SetInputFormat(Format format)123 void Converter::SetInputFormat(Format format) {
124 if (data_ != nullptr) {
125 if (format != DEFAULT_FORMAT) {
126 data_->input_format = format;
127 }
128 data_->spec_input_format = format;
129 }
130 }
131
GetInputFormat() const132 Format Converter::GetInputFormat() const {
133 if (data_ != nullptr) {
134 return data_->input_format;
135 } else {
136 return DEFAULT_FORMAT;
137 }
138 }
139
SetOutputFormat(Format format)140 void Converter::SetOutputFormat(Format format) {
141 if (data_ != nullptr) {
142 data_->spec_output_format = format;
143 }
144 }
145
SetInputDataType(DataType data_type)146 void Converter::SetInputDataType(DataType data_type) {
147 if (data_ != nullptr) {
148 data_->input_data_type = data_type;
149 }
150 }
151
GetInputDataType()152 DataType Converter::GetInputDataType() {
153 if (data_ != nullptr) {
154 return data_->input_data_type;
155 } else {
156 return DataType::kTypeUnknown;
157 }
158 }
159
SetOutputDataType(DataType data_type)160 void Converter::SetOutputDataType(DataType data_type) {
161 if (data_ != nullptr) {
162 data_->output_data_type = data_type;
163 }
164 }
165
GetOutputDataType()166 DataType Converter::GetOutputDataType() {
167 if (data_ != nullptr) {
168 return data_->output_data_type;
169 } else {
170 return DataType::kTypeUnknown;
171 }
172 }
173
SetSaveType(ModelType save_type)174 void Converter::SetSaveType(ModelType save_type) {
175 if (data_ != nullptr) {
176 data_->save_type = save_type;
177 }
178 }
179
GetSaveType() const180 ModelType Converter::GetSaveType() const {
181 if (data_ != nullptr) {
182 return data_->save_type;
183 } else {
184 return kMindIR_Lite;
185 }
186 }
187
SetDecryptKey(const std::vector<char> & key)188 void Converter::SetDecryptKey(const std::vector<char> &key) {
189 if (data_ != nullptr) {
190 data_->decrypt_key = CharToString(key);
191 }
192 }
193
GetDecryptKeyChar() const194 std::vector<char> Converter::GetDecryptKeyChar() const {
195 std::string decrypt_key = "";
196 if (data_ != nullptr) {
197 decrypt_key = data_->decrypt_key;
198 }
199 return StringToChar(decrypt_key);
200 }
201
SetDecryptMode(const std::vector<char> & mode)202 void Converter::SetDecryptMode(const std::vector<char> &mode) {
203 if (data_ != nullptr) {
204 data_->decrypt_mode = CharToString(mode);
205 }
206 }
207
GetDecryptModeChar() const208 std::vector<char> Converter::GetDecryptModeChar() const {
209 std::string decrypt_mode = "";
210 if (data_ != nullptr) {
211 decrypt_mode = data_->decrypt_mode;
212 }
213 return StringToChar(decrypt_mode);
214 }
215
SetEnableEncryption(bool encryption)216 void Converter::SetEnableEncryption(bool encryption) {
217 if (data_ != nullptr) {
218 data_->enable_encryption = encryption;
219 }
220 }
221
GetEnableEncryption() const222 bool Converter::GetEnableEncryption() const {
223 if (data_ != nullptr) {
224 return data_->enable_encryption;
225 } else {
226 return false;
227 }
228 }
229
SetEncryptKey(const std::vector<char> & key)230 void Converter::SetEncryptKey(const std::vector<char> &key) {
231 if (data_ != nullptr) {
232 data_->encrypt_key = CharToString(key);
233 }
234 }
235
GetEncryptKeyChar() const236 std::vector<char> Converter::GetEncryptKeyChar() const {
237 std::string encrypt_key = "";
238 if (data_ != nullptr) {
239 encrypt_key = data_->encrypt_key;
240 }
241 return StringToChar(encrypt_key);
242 }
243
SetInfer(bool infer)244 void Converter::SetInfer(bool infer) {
245 if (data_ != nullptr) {
246 data_->pre_infer = infer;
247 }
248 }
249
GetInfer() const250 bool Converter::GetInfer() const {
251 if (data_ != nullptr) {
252 return data_->pre_infer;
253 } else {
254 return false;
255 }
256 }
257
SetTrainModel(bool train_model)258 void Converter::SetTrainModel(bool train_model) {
259 if (data_ != nullptr) {
260 data_->train_model = train_model;
261 }
262 }
263
GetTrainModel() const264 bool Converter::GetTrainModel() const {
265 if (data_ != nullptr) {
266 return data_->train_model;
267 } else {
268 return false;
269 }
270 }
271
SetNoFusion(bool no_fusion)272 void Converter::SetNoFusion(bool no_fusion) {
273 if (data_ != nullptr) {
274 data_->no_fusion = no_fusion;
275 }
276 }
277
GetNoFusion()278 bool Converter::GetNoFusion() {
279 if (data_ != nullptr) {
280 return data_->no_fusion;
281 } else {
282 return false;
283 }
284 }
285
SetOptimizeTransformer(bool optimizeTransformer)286 void Converter::SetOptimizeTransformer(bool optimizeTransformer) {
287 if (data_ != nullptr) {
288 data_->optimize_transformer = optimizeTransformer;
289 }
290 }
291
GetOptimizeTransformer()292 bool Converter::GetOptimizeTransformer() {
293 if (data_ != nullptr) {
294 return data_->optimize_transformer;
295 } else {
296 return false;
297 }
298 }
299
SetDevice(const std::vector<char> & device)300 void Converter::SetDevice(const std::vector<char> &device) {
301 if (data_ != nullptr) {
302 data_->device = CharToString(device);
303 }
304 }
305
GetDeviceChar()306 std::vector<char> Converter::GetDeviceChar() {
307 std::string device = "";
308 if (data_ != nullptr) {
309 device = data_->device;
310 }
311 return StringToChar(device);
312 }
313
SetDeviceId(int32_t device_id)314 void Converter::SetDeviceId(int32_t device_id) {
315 if (data_ != nullptr) {
316 data_->aclModelOptionCfgParam.device_id = device_id;
317 }
318 }
319
GetDeviceId()320 int32_t Converter::GetDeviceId() {
321 if (data_ != nullptr) {
322 return data_->aclModelOptionCfgParam.device_id;
323 }
324 return 0;
325 }
326
SetRankId(int32_t rank_id)327 void Converter::SetRankId(int32_t rank_id) {
328 if (data_ != nullptr) {
329 data_->aclModelOptionCfgParam.rank_id = rank_id;
330 }
331 }
332
GetRankId()333 int32_t Converter::GetRankId() {
334 if (data_ != nullptr) {
335 return data_->aclModelOptionCfgParam.rank_id;
336 }
337 return 0;
338 }
339
SetProvider(const std::vector<char> & provider)340 void Converter::SetProvider(const std::vector<char> &provider) {
341 if (data_ != nullptr) {
342 data_->provider = CharToString(provider);
343 }
344 }
345
GetProviderChar()346 std::vector<char> Converter::GetProviderChar() {
347 std::string provider = "";
348 if (data_ != nullptr) {
349 provider = data_->provider;
350 }
351 return StringToChar(provider);
352 }
353
SetChipName(const std::vector<char> & chip_name)354 void Converter::SetChipName(const std::vector<char> &chip_name) {
355 if (data_ != nullptr) {
356 data_->chip_name = CharToString(chip_name);
357 }
358 }
359
GetChipNameChar()360 std::vector<char> Converter::GetChipNameChar() {
361 std::string chip_name = "";
362 if (data_ != nullptr) {
363 chip_name = data_->chip_name;
364 }
365 return StringToChar(chip_name);
366 }
367
Convert()368 Status Converter::Convert() {
369 if (data_ != nullptr) {
370 Status ret = Status(static_cast<StatusCode>(lite::RunConverter(data_, nullptr, nullptr, false)));
371 data_->decrypt_key.clear(); // clear key
372 data_->encrypt_key.clear(); // clear key
373 if (ret != kSuccess) {
374 MS_LOG(ERROR) << "Convert model failed, ret=" << ret;
375 }
376 return ret;
377 } else {
378 return kLiteError;
379 }
380 }
381
Convert(size_t * data_size)382 void *Converter::Convert(size_t *data_size) {
383 void *model_data = nullptr;
384 if (data_ != nullptr) {
385 Status ret = Status(static_cast<StatusCode>(lite::RunConverter(data_, &model_data, data_size, true)));
386 data_->decrypt_key.clear(); // clear key
387 data_->encrypt_key.clear(); // clear key
388 if (ret != kSuccess) {
389 MS_LOG(ERROR) << "Convert model failed, ret=" << ret;
390 }
391 } else {
392 MS_LOG(ERROR) << "Convert model failed, data is null.";
393 }
394 return model_data;
395 }
396
Convert(converter::FmkType fmk_type,const std::vector<char> & model_file,const std::vector<char> & output_file,const std::vector<char> & weight_file)397 Status Converter::Convert(converter::FmkType fmk_type, const std::vector<char> &model_file,
398 const std::vector<char> &output_file, const std::vector<char> &weight_file) {
399 if (data_ != nullptr) {
400 data_->fmk_type = fmk_type;
401 data_->model_file = CharToString(model_file);
402 data_->output_file = CharToString(output_file);
403 data_->weight_file = CharToString(weight_file);
404 Status ret = Converter::Convert();
405 if (ret != kSuccess) {
406 MS_LOG(ERROR) << "Convert model " << CharToString(model_file) << " failed, ret=" << ret;
407 }
408 lite::ConverterInnerContext::GetInstance()->Free();
409 return ret;
410 } else {
411 MS_LOG(ERROR) << "Convert model " << CharToString(model_file) << " failed, data is null.";
412 return kLiteError;
413 }
414 }
415 } // namespace mindspore
416