1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/converter/converter.h"
19 #include <memory>
20 #include <vector>
21 #include <set>
22 #include <map>
23 #include <tuple>
24 #include <algorithm>
25 #include <utility>
26 #include "src/common/log_adapter.h"
27 #include "tools/common/meta_graph_serializer.h"
28 #include "tools/lite_exporter/anf_exporter.h"
29 #include "tools/graph_kernel/converter/graph_kernel_optimization.h"
30 #ifdef SUPPORT_TRAIN
31 #include "src/train/train_populate_parameter.h"
32 #endif
33 #include "include/registry/model_parser_registry.h"
34 #include "include/api/format.h"
35 #include "src/common/dynamic_library_loader.h"
36 #include "src/common/log_util.h"
37 #include "tools/converter/parser/parser_utils.h"
38 #include "tools/converter/import/mindspore_importer.h"
39 #include "nnacl/op_base.h"
40 #include "tools/converter/micro/coder/coder.h"
41 #include "src/common/prim_util.h"
42 #include "src/common/version_manager.h"
43 #include "tools/common/tensor_util.h"
44 #include "include/api/model.h"
45 #include "tools/mindir_exporter/mindir_serializer.h"
46 #include "src/common/primitive_t_utils.h"
47 #include "tools/converter/config_parser/acl_option_param_parser.h"
48 #include "tools/converter/config_parser/micro_param_parser.h"
49 #include "tools/converter/config_parser/preprocess_parser.h"
50 #include "tools/converter/config_parser/quant_param_parser.h"
51 #include "tools/converter/config_parser/graph_kernel_param_parser.h"
52 #include "tools/converter/config_parser/third_party_param_parser.h"
53 #include "tools/converter/converter_funcgraph.h"
54 #include "tools/converter/converter_metagraph.h"
55 #include "tools/common/string_util.h"
56 #include "src/common/file_utils.h"
57 #include "ops/dynamic_shape.h"
58 #include "tools/common/parse_config_utils.h"
59 #include "src/common/file_utils.h"
60 #include "tools/converter/converter_packed_node.h"
61 #include "tools/converter/config_parser/cpu_option_param_parser.h"
62 #include "tools/converter/export_model.h"
63
64 namespace mindspore {
// Maps the user-facing input format names accepted by the converter ("NHWC", "NCHW") to the Format enum.
std::map<std::string, Format> StrToEnumFormatMap = {{"NHWC", Format::NHWC}, {"NCHW", Format::NCHW}};
// Log-initialization hook with C linkage; only declared here, its definition lives elsewhere.
extern "C" {
void mindspore_log_init();
}
69 namespace lite {
// Logs `str` at ERROR level through MS_LOG and echoes the same text to stdout.
// `str` is pasted unparenthesized after `<<` in both statements, so a `<<`-chained argument such as
// "ret: " << ret is streamed piece by piece as intended.
// NOTE(review): the `;` after `while (0)` defeats the do/while(0) idiom: `CONVERTER_LOG_ERROR(x);` expands to two
// statements, which breaks an unbraced `if (...) CONVERTER_LOG_ERROR(x); else ...`. Dropping it requires
// confirming that every call site supplies its own trailing `;`.
#define CONVERTER_LOG_ERROR(str)   \
  do {                             \
    MS_LOG(ERROR) << str;          \
    std::cout << str << std::endl; \
  } while (0);
75
76 namespace {
// Initial buffer size (bytes) for the FlatBufferBuilder instances created in this file.
constexpr size_t kMaxNum1024 = 1024;
// Maximum number of ';'-separated plugin libraries accepted from the registry configuration.
constexpr size_t kPluginPathMaxNum = 10;
// maxLength handed to _fullpath() on Windows when resolving config file paths.
constexpr int kPathLengthUpperLimit = 1024;
// NOTE(review): not referenced in the visible part of this file; presumably a length limit related to
// encryption keys — confirm against its users.
constexpr size_t kEncMaxLen = 16;
// Keys under which StoreConverterParameters publishes the converter parameters.
constexpr auto kFmk = "fmk";
constexpr auto kModelFile = "modelFile";
constexpr auto kOutputFile = "outputFile";
constexpr auto kWeightFile = "weightFile";
constexpr auto kFp16 = "fp16";
constexpr auto kInputshape = "inputShape";
constexpr auto kInputDataFormat = "inputDataFormat";
constexpr auto kEncryptKey = "encryptKey";
constexpr auto kEncryption = "encryption";
constexpr auto kInputDataType = "inputDataType";
constexpr auto kOutputDataType = "outputDataType";
constexpr auto kInfer = "infer";
// Maps the `fmk` model_param value to the source framework type (used by InitModelFmk).
std::map<std::string, FmkType> StrToEnumFmkTypeMap = {
  {"CAFFE", FmkType::kFmkTypeCaffe}, {"MINDIR", FmkType::kFmkTypeMs}, {"TFLITE", FmkType::kFmkTypeTflite},
  {"ONNX", FmkType::kFmkTypeOnnx},   {"TF", FmkType::kFmkTypeTf},     {"PYTORCH", FmkType::kFmkTypePytorch},
  {"MSLITE", FmkType::kFmkTypeMsLite}};
// Maps inputDataType / outputDataType model_param values to DataType; "DEFAULT" maps to kTypeUnknown.
std::map<std::string, DataType> StrToEnumDataTypeMap = {{"FLOAT", DataType::kNumberTypeFloat32},
                                                        {"INT8", DataType::kNumberTypeInt8},
                                                        {"UINT8", DataType::kNumberTypeUInt8},
                                                        {"DEFAULT", DataType::kTypeUnknown}};

// Platform-specific path separator.
#if defined(_WIN32) || defined(_WIN64)
static const char kSlash[] = "\\";
#else
static const char kSlash[] = "/";
#endif

// Deal with early release of 3rd-party plugin library.
// Presumably holding the loaders in this static vector keeps the plugin libraries loaded until process exit —
// confirm against the code that populates it.
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
// Returns true when `path` can be opened for reading; the open attempt itself is the existence probe.
bool FileExists(const std::string &path) {
  const std::ifstream probe(path);
  return probe.good();
}
114
InitModelFmk(const std::string & value,const std::shared_ptr<ConverterPara> & param)115 int InitModelFmk(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
116 if (StrToEnumFmkTypeMap.find(value) != StrToEnumFmkTypeMap.end()) {
117 param->fmk_type = StrToEnumFmkTypeMap.at(value);
118 } else {
119 std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|MSLITE" << std::endl;
120 return RET_INPUT_PARAM_INVALID;
121 }
122 return RET_OK;
123 }
124
InitModelFile(const std::string & value,const std::shared_ptr<ConverterPara> & param)125 int InitModelFile(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
126 if (value.empty() || !FileExists(value)) {
127 MS_LOG(ERROR) << "model file path is empty or invalid";
128 return RET_INPUT_PARAM_INVALID;
129 }
130 param->model_file = value;
131 return RET_OK;
132 }
133
InitModelInputDataType(const std::string & value,const std::shared_ptr<ConverterPara> & param)134 int InitModelInputDataType(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
135 if (StrToEnumDataTypeMap.find(value) == StrToEnumDataTypeMap.end()) {
136 std::cerr << "INPUT INVALID: inputDataType is invalid, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT"
137 << std::endl;
138 return RET_INPUT_PARAM_INVALID;
139 }
140 param->input_data_type = StrToEnumDataTypeMap.at(value);
141 return RET_OK;
142 }
143
InitModelOutputDataType(const std::string & value,const std::shared_ptr<ConverterPara> & param)144 int InitModelOutputDataType(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
145 if (StrToEnumDataTypeMap.find(value) == StrToEnumDataTypeMap.end()) {
146 std::cerr << "OUTPUT INVALID: outputDataType is invalid, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT"
147 << std::endl;
148 return RET_INPUT_PARAM_INVALID;
149 }
150 param->output_data_type = StrToEnumDataTypeMap.at(value);
151 return RET_OK;
152 }
153
InitModelSaveFP16(const std::string & value,const std::shared_ptr<ConverterPara> & param)154 int InitModelSaveFP16(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
155 if (value == "on") {
156 param->weight_fp16 = true;
157 } else if (value.empty() || value == "off") {
158 param->weight_fp16 = false;
159 } else {
160 std::cerr << "Init save_fp16 failed." << std::endl;
161 return RET_INPUT_PARAM_INVALID;
162 }
163 return RET_OK;
164 }
165
InitModelTrainMode(const std::string & value,const std::shared_ptr<ConverterPara> & param)166 int InitModelTrainMode(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
167 if (value == "true") {
168 param->train_model = true;
169 } else if (value.empty() || value == "false") {
170 param->train_model = false;
171 } else {
172 std::cerr << "INPUT ILLEGAL: trainModel must be true|false " << std::endl;
173 return RET_INPUT_PARAM_INVALID;
174 }
175 return RET_OK;
176 }
177
InitModelInputShape(const std::string & value,const std::shared_ptr<ConverterPara> & param)178 int InitModelInputShape(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
179 if (value.empty()) {
180 return RET_OK;
181 }
182 std::vector<int64_t> shape;
183 auto shape_strs = lite::StrSplit(value, std::string(";"));
184 for (const auto &shape_str : shape_strs) {
185 if (shape_str.empty()) {
186 continue;
187 }
188 shape.clear();
189 auto string_split = lite::StrSplit(shape_str, std::string(":"));
190 constexpr int kMinShapeSizeInStr = 2;
191 if (string_split.size() < kMinShapeSizeInStr) {
192 MS_LOG(ERROR) << "shape size must not be less than " << kMinShapeSizeInStr;
193 return lite::RET_INPUT_PARAM_INVALID;
194 }
195 auto name = string_split[0];
196 for (size_t i = 1; i < string_split.size() - 1; ++i) {
197 name += ":" + string_split[i];
198 }
199 if (name.empty()) {
200 MS_LOG(ERROR) << "input tensor name is empty";
201 return lite::RET_INPUT_PARAM_INVALID;
202 }
203 auto dim_strs = string_split[string_split.size() - 1];
204 if (dim_strs.empty()) {
205 MS_LOG(ERROR) << "input tensor dim string is empty";
206 return lite::RET_INPUT_PARAM_INVALID;
207 }
208 auto dims = lite::StrSplit(dim_strs, std::string(","));
209 if (dims.empty()) {
210 MS_LOG(ERROR) << "input tensor dim is empty";
211 return lite::RET_INPUT_PARAM_INVALID;
212 }
213 for (const auto &dim : dims) {
214 int64_t dim_value;
215 try {
216 dim_value = std::stoi(dim);
217 } catch (const std::exception &e) {
218 MS_LOG(ERROR) << "Get dim failed: " << e.what();
219 return lite::RET_INPUT_PARAM_INVALID;
220 }
221 shape.push_back(dim_value);
222 }
223 param->input_shape[name] = shape;
224 }
225 return RET_OK;
226 }
227
InitModelInputDataFormat(const std::string & value,const std::shared_ptr<ConverterPara> & param)228 int InitModelInputDataFormat(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
229 if (StrToEnumFormatMap.find(value) != StrToEnumFormatMap.end()) {
230 param->input_format = StrToEnumFormatMap.at(value);
231 } else if (!value.empty()) {
232 MS_LOG(ERROR) << "Input format is invalid.";
233 return RET_INPUT_PARAM_INVALID;
234 }
235 return RET_OK;
236 }
237
InitModelInfer(const std::string & value,const std::shared_ptr<ConverterPara> & param)238 int InitModelInfer(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
239 if (value == "true") {
240 param->pre_infer = true;
241 } else if (value == "false" || value.empty()) {
242 param->pre_infer = false;
243 } else {
244 std::cerr << "INPUT ILLEGAL: infer must be true|false " << std::endl;
245 return RET_INPUT_PARAM_INVALID;
246 }
247 return RET_OK;
248 }
249
InitModelNoFusion(const std::string & value,const std::shared_ptr<ConverterPara> & param)250 int InitModelNoFusion(const std::string &value, const std::shared_ptr<ConverterPara> ¶m) {
251 if (value == "true") {
252 param->no_fusion = true;
253 } else if (value == "false") {
254 param->no_fusion = false;
255 } else if (!value.empty()) {
256 std::cerr << "INPUT ILLEGAL: NoFusion must be true|false " << std::endl;
257 return RET_INPUT_PARAM_INVALID;
258 }
259 return RET_OK;
260 }
261
CreateConvertParam(const std::map<std::string,string> & model_params)262 std::shared_ptr<ConverterPara> CreateConvertParam(const std::map<std::string, string> &model_params) {
263 auto parm = std::make_shared<ConverterPara>();
264 std::map<std::string, std::function<int(const std::string &, const std::shared_ptr<ConverterPara> &)>> parse_funcs = {
265 {"fmk", InitModelFmk},
266 {"modelFile", InitModelFile},
267 {"inputDataType", InitModelInputDataType},
268 {"outputDataType", InitModelOutputDataType},
269 {"fp16", InitModelSaveFP16},
270 {"trainModel", InitModelTrainMode},
271 {"inputShape", InitModelInputShape},
272 {"inputDataFormat", InitModelInputDataFormat},
273 {"infer", InitModelInfer},
274 {"NoFusion", InitModelNoFusion}};
275 if (model_params.find("fmk") == model_params.end() || model_params.find("modelFile") == model_params.end()) {
276 MS_LOG(ERROR) << "INPUT ILLEGAL: fmk and modelFile must be set in [model_param].";
277 return nullptr;
278 }
279 for (auto &pair : model_params) {
280 if (parse_funcs.find(pair.first) == parse_funcs.end()) {
281 MS_LOG(ERROR) << "INPUT ILLEGAL: `" << pair.first << "` is not supported in [model_param]";
282 return nullptr;
283 }
284 if (parse_funcs[pair.first](pair.second, parm) != RET_OK) {
285 MS_LOG(ERROR) << pair.first << "'value is invalid";
286 return nullptr;
287 }
288 }
289 return parm;
290 }
291 } // namespace
292
StoreConverterParameters(const std::shared_ptr<ConverterPara> & param)293 STATUS StoreConverterParameters(const std::shared_ptr<ConverterPara> ¶m) {
294 if (param == nullptr) {
295 MS_LOG(ERROR) << "Input param is nullptr";
296 return RET_INPUT_PARAM_INVALID;
297 }
298 std::string param_input_shape;
299 for (auto i = param->input_shape.cbegin(); i != param->input_shape.cend(); ++i) {
300 std::stringstream input_shape_ss;
301 string input_shape_str;
302 (void)copy(i->second.begin(), i->second.end(), std::ostream_iterator<int>(input_shape_ss, ","));
303 input_shape_str = input_shape_ss.str();
304 input_shape_str.erase(input_shape_str.end() - 1);
305 param_input_shape += i->first + ":" + input_shape_str + ";";
306 }
307 std::map<std::string, std::map<std::string, std::string>> conver_param_maps;
308 conver_param_maps[mindspore::converter::KConverterParam][kFmk] = std::to_string(param->fmk_type);
309 conver_param_maps[mindspore::converter::KConverterParam][kModelFile] = param->model_file;
310 conver_param_maps[mindspore::converter::KConverterParam][kOutputFile] = param->output_file;
311 conver_param_maps[mindspore::converter::KConverterParam][kWeightFile] = param->weight_file;
312 std::stringstream weight_fp16_ss;
313 weight_fp16_ss << std::boolalpha << param->weight_fp16;
314 conver_param_maps[mindspore::converter::KConverterParam][kFp16] = weight_fp16_ss.str();
315 conver_param_maps[mindspore::converter::KConverterParam][kInputshape] = param_input_shape;
316 conver_param_maps[mindspore::converter::KConverterParam][kInputDataFormat] = std::to_string(param->input_format);
317 conver_param_maps[mindspore::converter::KConverterParam][kEncryptKey] = param->encrypt_key;
318 std::stringstream encryption_ss;
319 encryption_ss << std::boolalpha << param->enable_encryption;
320 conver_param_maps[mindspore::converter::KConverterParam][kEncryption] = encryption_ss.str();
321 conver_param_maps[mindspore::converter::KConverterParam][kInputDataType] =
322 std::to_string(static_cast<int>(param->input_data_type));
323 conver_param_maps[mindspore::converter::KConverterParam][kOutputDataType] =
324 std::to_string(static_cast<int>(param->output_data_type));
325 std::stringstream pre_infer_ss;
326 pre_infer_ss << std::boolalpha << param->pre_infer;
327 conver_param_maps[mindspore::converter::KConverterParam][kInfer] = pre_infer_ss.str();
328 for (const auto &config_info : conver_param_maps) {
329 ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
330 }
331 return RET_OK;
332 }
333
CheckExistCustomOps(const schema::MetaGraphT * meta_graph,bool * exist_custom_nodes)334 int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom_nodes) {
335 MS_CHECK_TRUE_MSG(meta_graph != nullptr && exist_custom_nodes != nullptr, RET_ERROR, "input params contain nullptr.");
336 flatbuffers::FlatBufferBuilder fbb(kMaxNum1024);
337 for (const auto &node : meta_graph->nodes) {
338 MS_CHECK_TRUE_RET(node != nullptr, RET_ERROR);
339 auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
340 if (prim == nullptr) {
341 MS_LOG(ERROR) << "get primitive failed.";
342 fbb.Clear();
343 return RET_ERROR;
344 }
345 if (IsCustomNode(prim, static_cast<int>(SCHEMA_CUR))) {
346 *exist_custom_nodes = true;
347 break;
348 }
349 }
350 fbb.Clear();
351 return RET_OK;
352 }
353
PreInference(const schema::MetaGraphT & meta_graph,bool train_model)354 int PreInference(const schema::MetaGraphT &meta_graph, bool train_model) {
355 if (train_model) {
356 MS_LOG(WARNING) << "train model dont support pre-infer.";
357 return RET_OK;
358 }
359
360 bool exist_custom_nodes = false;
361 auto check_ret = CheckExistCustomOps(&meta_graph, &exist_custom_nodes);
362 if (check_ret == RET_ERROR) {
363 MS_LOG(ERROR) << "CheckExistCustomOps failed.";
364 return RET_ERROR;
365 }
366 if (exist_custom_nodes) {
367 MS_LOG(WARNING) << "exist custom nodes and will not be pre-infer.";
368 return RET_OK;
369 }
370 mindspore::Model model;
371 flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
372 auto offset = schema::MetaGraph::Pack(builder, &meta_graph);
373 builder.Finish(offset);
374 schema::FinishMetaGraphBuffer(builder, offset);
375 int size = builder.GetSize();
376 auto content = builder.GetBufferPointer();
377 if (content == nullptr) {
378 MS_LOG(ERROR) << "GetBufferPointer nullptr";
379 return RET_ERROR;
380 }
381 auto context = std::make_shared<mindspore::Context>();
382 if (context == nullptr) {
383 MS_LOG(ERROR) << "New context failed while running ";
384 std::cerr << "New context failed while running " << std::endl;
385 return RET_ERROR;
386 }
387 std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
388 auto &device_list = context->MutableDeviceInfo();
389 device_list.push_back(device_info);
390
391 auto ret = model.Build(content, size, kMindIR_Lite, context);
392 if (ret != kSuccess) {
393 MS_LOG(ERROR) << "Build error ";
394 std::cerr << "Build error " << std::endl;
395 return RET_ERROR;
396 }
397 for (auto &tensor : model.GetInputs()) {
398 if (tensor.Shape().empty() || tensor.DataSize() == 0 ||
399 std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) {
400 MS_LOG(WARNING) << tensor.Name() << " is dynamic shape and will not be pre-infer.";
401 return RET_OK;
402 }
403 auto status = GenerateRandomData(&tensor);
404 if (status != RET_OK) {
405 MS_LOG(ERROR) << tensor.Name() << "GenerateRandomData failed.";
406 return status;
407 }
408 }
409 std::vector<MSTensor> outputs;
410 ret = model.Predict(model.GetInputs(), &outputs);
411 if (ret != kSuccess) {
412 MS_LOG(ERROR) << "Inference error ";
413 std::cerr << "Inference error " << std::endl;
414 return RET_ERROR;
415 }
416 return RET_OK;
417 }
418
InitConfigParam(const std::shared_ptr<ConverterPara> & param,std::map<int,std::map<std::string,std::string>> * model_param_infos)419 int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> ¶m,
420 std::map<int, std::map<std::string, std::string>> *model_param_infos) {
421 model_param_infos->clear();
422 lite::ConfigFileParser config_parser;
423 std::map<std::string, std::map<std::string, std::string>> maps;
424 auto ret = RET_OK;
425 auto parse_map_ret = RET_OK;
426 if (!param->config_file.empty()) {
427 ret = config_parser.ParseConfigFile(param->config_file, nullptr);
428 parse_map_ret = mindspore::lite::ParseConfigFile(param->config_file, &maps, model_param_infos);
429 } else {
430 ret = config_parser.ParseConfigParam(¶m->config_param);
431 }
432 if (ret != RET_OK || parse_map_ret != RET_OK) {
433 MS_LOG(ERROR) << "Parse config param failed.";
434 return ret;
435 }
436 if (model_param_infos->empty()) {
437 ret = lite::PreprocessParser::ParsePreprocess(config_parser.GetDataPreProcessString(), ¶m->dataPreProcessParam);
438 if (ret != RET_OK) {
439 MS_LOG(ERROR) << "Parse preprocess failed.";
440 return ret;
441 }
442 ret = lite::QuantParamParser::ParseCommonQuant(config_parser.GetCommonQuantString(), ¶m->commonQuantParam);
443 if (ret != RET_OK) {
444 MS_LOG(ERROR) << "Parse common quant param failed.";
445 return ret;
446 }
447 ret = lite::QuantParamParser::ParseFullQuant(config_parser.GetFullQuantString(), ¶m->fullQuantParam);
448 if (ret != RET_OK) {
449 MS_LOG(ERROR) << "Parse full quant param failed.";
450 return ret;
451 }
452 ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(),
453 ¶m->mixedBitWeightQuantParam);
454 if (ret != RET_OK) {
455 MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
456 return ret;
457 }
458 if (param->fmk_type == converter::kFmkTypeThirdParty) {
459 ret = lite::ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), ¶m->thirdPartyModelParam);
460 if (ret != RET_OK) {
461 MS_LOG(ERROR) << "Parse third party param failed.";
462 return ret;
463 }
464 }
465 ret = InitExtendedIntegrationInfo(param, config_parser);
466 if (ret != RET_OK) {
467 MS_LOG(ERROR) << "Parse extended integration info failed.";
468 return ret;
469 }
470 std::string output_file = param->output_file;
471 param->aclModelOptionCfgParam.om_file_path = output_file;
472 auto dir_pos = output_file.find_last_of('/');
473 param->aclModelOptionCfgParam.dump_model_name =
474 dir_pos != std::string::npos ? output_file.substr(dir_pos + 1) : output_file;
475 lite::AclOptionParamParser acl_param_parser;
476 ret = acl_param_parser.ParseAclOptionCfg(config_parser.GetAclOptionCfgString(), ¶m->aclModelOptionCfgParam);
477 if (ret != RET_OK) {
478 MS_LOG(ERROR) << "Parse acl option param failed.";
479 return ret;
480 }
481 // parse ascend_context in config file, the priority is higher
482 if (maps.find("ascend_context") != maps.end()) {
483 auto map = maps.at("ascend_context");
484 if (!config_parser.SetParamByConfigfile(param, map)) {
485 MS_LOG(ERROR) << "Failed to parse config item ascend_context";
486 return RET_ERROR;
487 }
488 }
489 if (!param->config_file.empty()) {
490 (void)CheckOfflineParallelConfig(param->config_file, ¶m->parallel_split_config);
491 }
492
493 lite::CpuOptionParamParser cpu_param_parser;
494 ret = cpu_param_parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), ¶m->cpuOptionCfgParam);
495 if (ret != RET_OK) {
496 MS_LOG(ERROR) << "Parse cpu option param failed.";
497 return ret;
498 }
499 }
500 MS_LOG(INFO)
501 << "If there are multi models, only support micro_param and model_param, other configure can not take effect";
502
503 lite::MicroParamParser micro_param_parser;
504 ret = micro_param_parser.ParseMicroParam(config_parser.GetMicroParamString(), ¶m->microParam);
505 if (ret != RET_OK) {
506 MS_LOG(ERROR) << "Parse micro param failed.";
507 return ret;
508 }
509 ret =
510 lite::QuantParamParser::ParseTransformQuant(config_parser.GetTransformQuantString(), ¶m->transformQuantParam);
511 if (ret != RET_OK) {
512 MS_LOG(ERROR) << "Parse transform quant param failed.";
513 return ret;
514 }
515 ret = lite::QuantParamParser::ParseDynamicQuant(config_parser.GetDynamicQuantString(), ¶m->dynamicQuantParam);
516 if (ret != RET_OK) {
517 MS_LOG(ERROR) << "Parse dynamic quant param failed.";
518 return ret;
519 }
520 lite::GraphKernelParamParser graph_kernel_parser;
521 ret = graph_kernel_parser.ParseGraphKernelCfg(config_parser.GetGraphKernelString(), ¶m->graphKernelParam);
522 if (ret != RET_OK) {
523 MS_LOG(ERROR) << "Parse graph kernel param failed.";
524 return ret;
525 }
526 return RET_OK;
527 }
528
InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPara> & param,const lite::ConfigFileParser & config_parser)529 int ConverterImpl::InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPara> ¶m,
530 const lite::ConfigFileParser &config_parser) {
531 auto extended_info = config_parser.GetRegistryInfoString();
532 if (!extended_info.plugin_path.empty()) {
533 const char delimiter = ';';
534 auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, delimiter);
535 if (relative_path.size() > kPluginPathMaxNum) {
536 MS_LOG(ERROR) << "extended plugin library's num is too big, which shouldn't be larger than " << kPluginPathMaxNum;
537 return RET_INPUT_PARAM_INVALID;
538 }
539 for (auto &i : relative_path) {
540 param->plugins_path.push_back(lite::RealPath(i.c_str()));
541 }
542 }
543
544 if (!extended_info.disable_fusion.empty()) {
545 if (extended_info.disable_fusion == "on") {
546 param->no_fusion = true;
547 } else if (extended_info.disable_fusion == "off") {
548 param->no_fusion = false;
549 } else {
550 std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl;
551 return RET_INPUT_PARAM_INVALID;
552 }
553 }
554
555 if (!extended_info.fusion_blacklists.empty()) {
556 std::vector<std::string> fusions = SplitStringToVector(extended_info.fusion_blacklists, ",");
557 for (const auto &fusion : fusions) {
558 bool inserted = false;
559 std::tie(std::ignore, inserted) = param->fusion_blacklists.insert(fusion);
560 if (inserted) {
561 MS_LOG(DEBUG) << "Value was inserted successfully.";
562 }
563 }
564 }
565 return RET_OK;
566 }
567
CheckOfflineParallelConfig(const std::string & file,ParallelSplitConfig * parallel_split_config)568 bool ConverterImpl::CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
569 // device: [device0 device1] ---> {cpu, gpu}
570 // computeRate: [x: y] x >=0 && y >=0 && x/y < 10
571 MS_ASSERT(parallel_split_config != nullptr);
572 std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
573 auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
574 if (compute_rate_result.empty()) {
575 return false;
576 }
577 std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
578 if (device0_result.empty()) {
579 return false;
580 }
581 std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
582 if (device1_result.empty()) {
583 return false;
584 }
585 bool device0_flag = false;
586 bool device1_flag = false;
587 for (const auto &device : config_devices) {
588 if (device == device0_result) {
589 device0_flag = true;
590 }
591 if (device == device1_result) {
592 device1_flag = true;
593 }
594 }
595 if (!device0_flag || !device1_flag) {
596 return false;
597 }
598 const char delimiter = ';';
599 std::vector<std::string> device_rates = lite::SplitStringToVector(compute_rate_result, delimiter);
600 const char colon = ':';
601 for (const auto &device : device_rates) {
602 std::vector<std::string> rate = lite::SplitStringToVector(device, colon);
603 int64_t compute_rate = 0;
604 try {
605 compute_rate = std::stoi(rate.back());
606 } catch (const std::exception &e) {
607 MS_LOG(ERROR) << "Get compute rate failed: " << e.what();
608 return false;
609 }
610 parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
611 }
612 const size_t support_rates_num = 2;
613 if (parallel_split_config->parallel_compute_rates_.size() != support_rates_num) {
614 return false;
615 }
616 int64_t bigger_rate = INT32_MIN;
617 int64_t smaller_rate = INT32_MAX;
618 for (const auto &rate : parallel_split_config->parallel_compute_rates_) {
619 if (rate <= 0 || rate > INT32_MAX) {
620 return false;
621 }
622 bigger_rate = std::max(rate, bigger_rate);
623 smaller_rate = std::min(rate, smaller_rate);
624 }
625 parallel_split_config->parallel_devices_.push_back(device0_result);
626 parallel_split_config->parallel_devices_.push_back(device1_result);
627 // parall_split_type will extend by other user's attr
628 parallel_split_config->parallel_split_type_ = SplitByUserRatio;
629 if (smaller_rate == 0) {
630 MS_LOG(ERROR) << "smaller_rate is zero";
631 return false;
632 }
633 return bigger_rate / smaller_rate <= kMaxSplitRatio;
634 }
635
GetStrFromConfigFile(const std::string & file,const std::string & target_key)636 std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
637 std::string res;
638 if (file.empty()) {
639 MS_LOG(ERROR) << "file is nullptr";
640 return res;
641 }
642 auto resolved_path = std::make_unique<char[]>(PATH_MAX);
643 if (resolved_path == nullptr) {
644 MS_LOG(ERROR) << "new resolved_path failed";
645 return "";
646 }
647
648 #ifdef _WIN32
649 auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
650 #else
651 char *real_path = realpath(file.c_str(), resolved_path.get());
652 #endif
653 if (real_path == nullptr || strlen(real_path) == 0) {
654 MS_LOG(ERROR) << "file path is not valid : " << file;
655 return "";
656 }
657 std::ifstream ifs(resolved_path.get());
658 if (!ifs.good()) {
659 MS_LOG(ERROR) << "file: " << real_path << " is not exist";
660 return res;
661 }
662 if (!ifs.is_open()) {
663 MS_LOG(ERROR) << "file: " << real_path << "open failed";
664 return res;
665 }
666 std::string line;
667 while (std::getline(ifs, line)) {
668 lite::Trim(&line);
669 if (line.empty() || line.at(0) == '#' || line.at(0) == '[') {
670 continue;
671 }
672 auto index = line.find('=');
673 if (index == std::string::npos) {
674 MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
675 return "";
676 }
677 auto key = line.substr(0, index);
678 auto value = line.substr(index + 1);
679 lite::Trim(&key);
680 lite::Trim(&value);
681 if (key == target_key) {
682 return value;
683 }
684 }
685 return res;
686 }
687
CheckFmkType(const std::shared_ptr<ConverterPara> & param)688 int CheckFmkType(const std::shared_ptr<ConverterPara> ¶m) {
689 if (param != nullptr) {
690 return RET_OK;
691 }
692 std::set kValidFmkTypes = {FmkType::kFmkTypeTf, FmkType::kFmkTypeCaffe, FmkType::kFmkTypeOnnx,
693 FmkType::kFmkTypeMs, FmkType::kFmkTypeTflite, FmkType::kFmkTypePytorch,
694 FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty};
695 if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
696 MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be "
697 "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY"
698 << ", but got " << param->fmk_type;
699 return RET_INPUT_PARAM_INVALID;
700 }
701 if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
702 MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is not a valid flag";
703 return RET_INPUT_PARAM_INVALID;
704 }
705 return RET_OK;
706 }
707
CheckModelFile(const std::shared_ptr<ConverterPara> & param)708 int CheckModelFile(const std::shared_ptr<ConverterPara> ¶m) {
709 if (param != nullptr) {
710 if (param->model_file.empty()) {
711 MS_LOG(ERROR) << "INPUT MISSING: model file path is necessary";
712 return RET_INPUT_PARAM_INVALID;
713 }
714 }
715 return RET_OK;
716 }
717
CheckOutputFile(const std::shared_ptr<ConverterPara> & param)718 int CheckOutputFile(const std::shared_ptr<ConverterPara> ¶m) {
719 if (param != nullptr && param->aclModelOptionCfgParam.offline) {
720 if (param->output_file.empty()) {
721 MS_LOG(ERROR) << "INPUT MISSING: output file path is necessary";
722 return RET_INPUT_PARAM_INVALID;
723 }
724
725 #ifdef _WIN32
726 replace(param->output_file.begin(), param->output_file.end(), '/', '\\');
727 #endif
728
729 if (param->output_file.rfind('/') == param->output_file.length() - 1 ||
730 param->output_file.rfind('\\') == param->output_file.length() - 1) {
731 MS_LOG(ERROR) << "INPUT ILLEGAL: output file must be a valid file path";
732 return RET_INPUT_PARAM_INVALID;
733 }
734 param->aclModelOptionCfgParam.om_file_path = param->output_file;
735 }
736 return RET_OK;
737 }
738
CheckInputShape(const std::shared_ptr<ConverterPara> & param)739 int CheckInputShape(const std::shared_ptr<ConverterPara> ¶m) {
740 if (param != nullptr) {
741 if (param->input_shape.empty()) {
742 return RET_OK;
743 }
744 for (const auto &elem : param->input_shape) {
745 std::vector<int64_t> dims = elem.second;
746 if (dims.empty()) {
747 MS_LOG(ERROR) << "INPUT MISSING: input tensor dim is empty";
748 return lite::RET_INPUT_PARAM_INVALID;
749 }
750 bool has_negative_dim = std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim < 0; });
751 if (has_negative_dim) {
752 MS_LOG(ERROR) << "INPUT ILLEGAL: Unsupported dim < 0.";
753 return lite::RET_INPUT_PARAM_INVALID;
754 }
755 }
756 }
757 return RET_OK;
758 }
759
CheckInputFormat(const std::shared_ptr<ConverterPara> & param)760 int CheckInputFormat(const std::shared_ptr<ConverterPara> ¶m) {
761 if (param != nullptr) {
762 std::set valid_values = {NHWC, NCHW};
763 if (std::find(valid_values.begin(), valid_values.end(), param->input_format) == valid_values.end()) {
764 MS_LOG(ERROR) << "INPUT ILLEGAL: input_format is not in {NHWC, NCHW}, but got " << param->input_format;
765 return RET_INPUT_PARAM_INVALID;
766 }
767 }
768 return RET_OK;
769 }
770
CheckInputOutputDataType(const std::shared_ptr<ConverterPara> & param)771 int CheckInputOutputDataType(const std::shared_ptr<ConverterPara> ¶m) {
772 if (param != nullptr) {
773 std::set input_valid_values = {
774 DataType::kNumberTypeFloat16, DataType::kNumberTypeFloat32, DataType::kNumberTypeInt8, DataType::kNumberTypeUInt8,
775 DataType::kNumberTypeInt32, DataType::kNumberTypeInt64, DataType::kTypeUnknown};
776 if (std::find(input_valid_values.begin(), input_valid_values.end(), param->input_data_type) ==
777 input_valid_values.end()) {
778 MS_LOG(ERROR) << "INPUT ILLEGAL: input_data_type is not in {kNumberTypeFloat32, kNumberTypeInt8, "
779 << "kNumberTypeUInt8, kTypeUnknown}, but got " << param->input_data_type;
780 return RET_INPUT_PARAM_INVALID;
781 }
782
783 std::set output_valid_values = {DataType::kNumberTypeFloat16, DataType::kNumberTypeFloat32,
784 DataType::kNumberTypeInt8, DataType::kNumberTypeUInt8, DataType::kTypeUnknown};
785 if (std::find(output_valid_values.begin(), output_valid_values.end(), param->output_data_type) ==
786 output_valid_values.end()) {
787 MS_LOG(ERROR) << "INPUT ILLEGAL: output_data_type is not in {kNumberTypeFloat32, kNumberTypeInt8, "
788 << "kNumberTypeUInt8, kTypeUnknown}, but got " << param->output_data_type;
789 return RET_INPUT_PARAM_INVALID;
790 }
791 }
792 return RET_OK;
793 }
794
CheckSaveType(const std::shared_ptr<ConverterPara> & param)795 int CheckSaveType(const std::shared_ptr<ConverterPara> ¶m) {
796 if (param != nullptr) {
797 std::set valid_values = {kMindIR, kMindIR_Lite};
798 if (std::find(valid_values.begin(), valid_values.end(), param->save_type) == valid_values.end()) {
799 MS_LOG(ERROR) << "INPUT ILLEGAL: save_type is not in {kMindIR, kMindIR_Lite}, but got " << param->save_type;
800 return RET_INPUT_PARAM_INVALID;
801 }
802 }
803 return RET_OK;
804 }
805
CheckEncrypt(const std::shared_ptr<ConverterPara> & param)806 int CheckEncrypt(const std::shared_ptr<ConverterPara> ¶m) {
807 if (param != nullptr) {
808 if (param->enable_encryption) {
809 if (param->encrypt_key.empty()) {
810 MS_LOG(ERROR) << "encryption param is true and encrypt_key param must be set. If you don't "
811 "need to use model encryption, please set encryption param to false.";
812 return RET_INPUT_PARAM_INVALID;
813 }
814 }
815 }
816 return RET_OK;
817 }
818
CheckTrainModel(const std::shared_ptr<ConverterPara> & param)819 int CheckTrainModel(const std::shared_ptr<ConverterPara> ¶m) {
820 if (param != nullptr) {
821 if (param->train_model) {
822 if (param->fmk_type != converter::FmkType::kFmkTypeMs) {
823 MS_LOG(ERROR) << "INPUT ILLEGAL: train model converter supporting only MINDIR format";
824 return RET_INPUT_PARAM_INVALID;
825 }
826 if ((param->input_data_type != DataType::kNumberTypeFloat32) &&
827 (param->input_data_type != DataType::kTypeUnknown)) {
828 MS_LOG(ERROR) << "INPUT ILLEGAL: train model converter supporting only FP32 input tensors";
829 return RET_INPUT_PARAM_INVALID;
830 }
831 if ((param->output_data_type != DataType::kNumberTypeFloat32) &&
832 (param->output_data_type != DataType::kTypeUnknown)) {
833 MS_LOG(ERROR) << "INPUT ILLEGAL: train model converter supporting only FP32 output tensors";
834 return RET_INPUT_PARAM_INVALID;
835 }
836 }
837 }
838 return RET_OK;
839 }
840
CheckDevice(const std::shared_ptr<ConverterPara> & param)841 int CheckDevice(const std::shared_ptr<ConverterPara> ¶m) {
842 if (param != nullptr && !(param->device.empty())) {
843 std::set valid_values = {"Ascend310", "Ascend310P", "Ascend", "GPU"};
844 if (std::find(valid_values.begin(), valid_values.end(), param->device) == valid_values.end()) {
845 MS_LOG(ERROR) << "INPUT ILLEGAL: device is not in {GPU, Ascend, Ascend310, Ascend310P}, got " << param->device;
846 return RET_INPUT_PARAM_INVALID;
847 }
848 }
849 return RET_OK;
850 }
851
CheckValueParam(const std::shared_ptr<ConverterPara> & param,bool is_multi_model)852 int CheckValueParam(const std::shared_ptr<ConverterPara> ¶m, bool is_multi_model) {
853 if (param == nullptr) {
854 MS_LOG(ERROR) << "INPUT MISSING: param is nullptr.";
855 return RET_INPUT_PARAM_INVALID;
856 }
857
858 auto ret = CheckFmkType(param);
859 if (ret != RET_OK) {
860 MS_LOG(ERROR) << "Check value of fmk_type failed.";
861 return RET_INPUT_PARAM_INVALID;
862 }
863
864 ret = CheckModelFile(param);
865 if (ret != RET_OK) {
866 MS_LOG(ERROR) << "Check value of model_file failed.";
867 return RET_INPUT_PARAM_INVALID;
868 }
869
870 if (!is_multi_model) {
871 ret = CheckOutputFile(param);
872 if (ret != RET_OK) {
873 MS_LOG(ERROR) << "Check value of output_file failed.";
874 return RET_INPUT_PARAM_INVALID;
875 }
876 }
877
878 ret = CheckInputShape(param);
879 if (ret != RET_OK) {
880 MS_LOG(ERROR) << "Check value of input_shape failed.";
881 return RET_INPUT_PARAM_INVALID;
882 }
883
884 ret = CheckInputFormat(param);
885 if (ret != RET_OK) {
886 MS_LOG(ERROR) << "Check value of input_format failed.";
887 return RET_INPUT_PARAM_INVALID;
888 }
889
890 ret = CheckInputOutputDataType(param);
891 if (ret != RET_OK) {
892 MS_LOG(ERROR) << "Check value of input_data_type or output_data_type failed.";
893 return RET_INPUT_PARAM_INVALID;
894 }
895
896 ret = CheckSaveType(param);
897 if (ret != RET_OK) {
898 MS_LOG(ERROR) << "Check value of save_type failed.";
899 return RET_INPUT_PARAM_INVALID;
900 }
901
902 ret = CheckEncrypt(param);
903 if (ret != RET_OK) {
904 MS_LOG(ERROR) << "Check value of encrypt failed.";
905 return RET_INPUT_PARAM_INVALID;
906 }
907
908 ret = CheckTrainModel(param);
909 if (ret != RET_OK) {
910 MS_LOG(ERROR) << "Check value of train model failed.";
911 return RET_INPUT_PARAM_INVALID;
912 }
913
914 ret = CheckDevice(param);
915 if (ret != RET_OK) {
916 MS_LOG(ERROR) << "Check device failed.";
917 return RET_INPUT_PARAM_INVALID;
918 }
919
920 return RET_OK;
921 }
922
LoadPluginLib(const std::shared_ptr<ConverterPara> & param)923 int ConverterImpl::LoadPluginLib(const std::shared_ptr<ConverterPara> ¶m) {
924 if (!param->plugins_path.empty()) {
925 for (auto &path : param->plugins_path) {
926 auto dl_loader = std::make_shared<DynamicLibraryLoader>();
927 MS_CHECK_TRUE_RET(dl_loader != nullptr, RET_ERROR);
928 auto status = dl_loader->Open(path);
929 if (status != RET_OK) {
930 MS_LOG(ERROR) << "open dynamic library failed. " << path;
931 return RET_ERROR;
932 }
933 dl_loaders.emplace_back(dl_loader);
934 }
935 }
936 return RET_OK;
937 }
938
Convert(const std::shared_ptr<ConverterPara> & param,void ** model_data,size_t * data_size,bool not_save)939 int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> ¶m, void **model_data, size_t *data_size,
940 bool not_save) {
941 if (param == nullptr) {
942 MS_LOG(ERROR) << "Input param is nullptr";
943 return RET_ERROR;
944 }
945 std::map<int, std::map<std::string, std::string>> model_param_infos; // {model_index, {param_key:param_value}}
946 auto ret = InitConfigParam(param, &model_param_infos);
947 if (ret != RET_OK) {
948 MS_LOG(ERROR) << "Init config file failed: " << ret << " " << GetErrorInfo(ret);
949 return ret;
950 }
951 ret = StoreConverterParameters(param);
952 if (ret != RET_OK) {
953 MS_LOG(ERROR) << "Get converter parameter failed: " << ret << " " << GetErrorInfo(ret);
954 return ret;
955 }
956
957 ret = LoadPluginLib(param);
958 if (ret != RET_OK) {
959 MS_LOG(ERROR) << "Load plugin lib failed: " << ret << " " << GetErrorInfo(ret);
960 return ret;
961 }
962
963 if (model_param_infos.empty()) {
964 ret = CheckValueParam(param, false);
965 if (ret != RET_OK) {
966 MS_LOG(ERROR) << "Converter input parameters check valid failed";
967 return ret;
968 }
969 ret = HandleGraphCommon(param, model_data, data_size, not_save, false);
970 if (ret != RET_OK) {
971 MS_LOG(ERROR) << "Handle graph failed: " << ret << " " << GetErrorInfo(ret);
972 return ret;
973 }
974 } else {
975 size_t model_index = 0;
976 for (auto pair : model_param_infos) {
977 auto convert_param = CreateConvertParam(pair.second);
978 convert_param->microParam = param->microParam;
979 ret = CheckValueParam(convert_param, true);
980 if (ret != RET_OK) {
981 MS_LOG(ERROR) << "For model" << pair.first << ", converter input parameters check valid failed";
982 return ret;
983 }
984 if (model_index == model_param_infos.size() - 1) {
985 convert_param->microParam.is_last_model = true;
986 }
987 ret = HandleGraphCommon(convert_param, model_data, data_size, not_save, true);
988 if (ret != RET_OK) {
989 MS_LOG(ERROR) << "Handle graph failed: " << ret << " " << GetErrorInfo(ret);
990 return ret;
991 }
992 model_index++;
993 }
994 }
995 return RET_OK;
996 }
997
HandleGraphCommon(const std::shared_ptr<ConverterPara> & param,void ** model_data,size_t * data_size,bool not_save,bool is_multi_model)998 int ConverterImpl::HandleGraphCommon(const std::shared_ptr<ConverterPara> ¶m, void **model_data, size_t *data_size,
999 bool not_save, bool is_multi_model) {
1000 if (param->fmk_type == converter::kFmkTypeMsLite) {
1001 if (!param->microParam.enable_micro) {
1002 MS_LOG(ERROR) << "When fmk is set to MSLITE, only support micronization.";
1003 return RET_NOT_SUPPORT;
1004 }
1005 auto ret = ExecuteMicro(nullptr, param, is_multi_model);
1006 if (ret != RET_OK) {
1007 MS_LOG(ERROR) << "Micronize msLite-model failed.";
1008 return ret;
1009 }
1010 return ret;
1011 }
1012 auto graph = ConverterFuncGraph::Build(param);
1013 if (graph == nullptr) {
1014 MS_LOG(ERROR) << "Build func graph failed";
1015 return RET_ERROR;
1016 }
1017
1018 int ret = ConverterFuncGraph::Optimize(param, graph);
1019 if (ret != RET_OK) {
1020 MS_LOG(ERROR) << "Optimize func graph failed: " << ret << " " << GetErrorInfo(ret);
1021 return ret;
1022 }
1023
1024 ret = SaveGraph(graph, param, model_data, data_size, not_save, is_multi_model);
1025 if (ret != RET_OK) {
1026 MS_LOG(ERROR) << "Save graph failed: " << ret << " " << GetErrorInfo(ret);
1027 return ret;
1028 }
1029 return RET_OK;
1030 }
1031
ExecuteMicro(const schema::MetaGraphT * meta_graph,const std::shared_ptr<ConverterPara> & param,bool is_multi_model)1032 int ConverterImpl::ExecuteMicro(const schema::MetaGraphT *meta_graph, const std::shared_ptr<ConverterPara> ¶m,
1033 bool is_multi_model) {
1034 std::string output_path = param->output_file;
1035 if (!is_multi_model) {
1036 param->microParam.is_last_model = true;
1037 } else {
1038 if (param->microParam.save_path.empty() || param->microParam.project_name.empty()) {
1039 MS_LOG(ERROR) << "Micro param for invalid: save_path or project name is needed";
1040 return RET_ERROR;
1041 }
1042 output_path = param->microParam.save_path + param->microParam.project_name;
1043 if (param->microParam.save_path[param->microParam.save_path.size() - 1] != '/' ||
1044 param->microParam.save_path[param->microParam.save_path.size() - 1] != '\\') {
1045 output_path = param->microParam.save_path + kSlash + param->microParam.project_name;
1046 }
1047 }
1048 auto status =
1049 meta_graph != nullptr
1050 ? micro::Coder::MicroSourceCodeGeneration(*meta_graph, output_path, ¶m->microParam, param->weight_fp16)
1051 : micro::Coder::MicroSourceCodeGeneration(param->model_file, output_path, ¶m->microParam, param->weight_fp16);
1052 if (status != RET_OK) {
1053 MS_LOG(ERROR) << "Execute Micro failed.";
1054 }
1055 return status;
1056 }
1057
SaveGraph(FuncGraphPtr graph,const std::shared_ptr<ConverterPara> & param,void ** model_data,size_t * data_size,bool not_save,bool is_multi_model)1058 int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> ¶m, void **model_data,
1059 size_t *data_size, bool not_save, bool is_multi_model) {
1060 int status = RET_ERROR;
1061 if (param->save_type == kMindIR) {
1062 status = SaveMindIRModel(graph, param, model_data, data_size);
1063 if (status != RET_OK) {
1064 MS_LOG(ERROR) << "Save mindir model failed :" << status << " " << GetErrorInfo(status);
1065 return RET_ERROR;
1066 }
1067 return RET_OK;
1068 }
1069
1070 auto meta_graph = ConverterToMetaGraph::Build(param, graph);
1071 if (meta_graph == nullptr) {
1072 MS_LOG(ERROR) << "Convert to meta graph failed";
1073 return RET_ERROR;
1074 }
1075
1076 if (!param->cpuOptionCfgParam.architecture.empty()) {
1077 std::string cpu_option = param->cpuOptionCfgParam.architecture + param->cpuOptionCfgParam.instruction;
1078 status = ConverterPackedNode(meta_graph, cpu_option);
1079 if (status != RET_OK) {
1080 MS_LOG(ERROR) << "save pack info failed.";
1081 return status;
1082 }
1083 }
1084
1085 meta_graph->version = Version();
1086
1087 if (param->pre_infer) {
1088 status = PreInference(*meta_graph, param->train_model);
1089 if (status != RET_OK) {
1090 MS_LOG(ERROR) << "Preinference failed: " << status << " " << GetErrorInfo(status);
1091 delete meta_graph;
1092 return status;
1093 }
1094 }
1095
1096 if (param->microParam.enable_micro) {
1097 status = ExecuteMicro(meta_graph, param, is_multi_model);
1098 } else {
1099 status = ConverterToMetaGraph::Save(meta_graph, param, model_data, data_size, not_save);
1100 }
1101 delete meta_graph;
1102 if (status != RET_OK) {
1103 MS_LOG(ERROR) << "Save failed:" << status << " " << GetErrorInfo(status);
1104 return status;
1105 }
1106 return RET_OK;
1107 }
1108
SaveMindIRModel(FuncGraphPtr graph,const std::shared_ptr<ConverterPara> & param,void ** model_data,size_t * data_size)1109 int ConverterImpl::SaveMindIRModel(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> ¶m, void **model_data,
1110 size_t *data_size) {
1111 int status = RET_OK;
1112 if (param->pre_infer) {
1113 schema::MetaGraphT *meta_graph = nullptr;
1114 auto new_param = std::make_shared<ConverterPara>();
1115 new_param->fmk_type = converter::kFmkTypeMs;
1116 new_param->save_type = kMindIR;
1117
1118 std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
1119 auto mirror_graph = lite::CloneFuncGraph(graph, new_param, &cloned_func_graph);
1120 if (mirror_graph == nullptr) {
1121 MS_LOG(ERROR) << "Mirror graph is nullptr";
1122 return RET_ERROR;
1123 }
1124 meta_graph = lite::ConverterToMetaGraph::Build(new_param, mirror_graph);
1125 if (meta_graph == nullptr) {
1126 MS_LOG(ERROR) << "FuncGraph convert to meta graph failed";
1127 return RET_ERROR;
1128 }
1129 status = PreInference(*meta_graph, param->train_model);
1130 if (status != RET_OK) {
1131 MS_LOG(ERROR) << "PreInferenceMindIR failed: " << status << " " << GetErrorInfo(status);
1132 return status;
1133 }
1134 }
1135 status = ConverterFuncGraph::Save(param, graph, model_data, data_size);
1136 if (status != RET_OK) {
1137 MS_LOG(ERROR) << "Export to mindir failed: " << status << " " << GetErrorInfo(status);
1138 return RET_ERROR;
1139 }
1140 return RET_OK;
1141 }
1142
RunConverter(const std::shared_ptr<ConverterPara> & param,void ** model_data,size_t * data_size,bool not_save)1143 int RunConverter(const std::shared_ptr<ConverterPara> ¶m, void **model_data, size_t *data_size, bool not_save) {
1144 mindspore::mindspore_log_init();
1145
1146 param->aclModelOptionCfgParam.offline = !not_save;
1147 int status = RET_OK;
1148 ConverterImpl converter_impl;
1149 try {
1150 status = converter_impl.Convert(param, model_data, data_size, not_save);
1151 } catch (const std::exception &e) {
1152 MS_LOG(ERROR) << "Exception occurred: " << e.what();
1153 status = RET_ERROR;
1154 }
1155 if (status != RET_OK) {
1156 MS_LOG(ERROR) << "Convert model failed";
1157 NotSupportOp::GetInstance()->PrintOps();
1158 return status;
1159 }
1160
1161 MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status;
1162 std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
1163 return status;
1164 }
1165 } // namespace lite
1166 } // namespace mindspore
1167