/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/converter/converter.h"
#include <memory>
#include <vector>
#include <set>
#include <map>
#include <tuple>
#include <algorithm>
#include <utility>
#include "src/common/log_adapter.h"
#include "tools/common/meta_graph_serializer.h"
#include "tools/lite_exporter/anf_exporter.h"
#include "tools/graph_kernel/converter/graph_kernel_optimization.h"
#ifdef SUPPORT_TRAIN
#include "src/train/train_populate_parameter.h"
#endif
#include "include/registry/model_parser_registry.h"
#include "include/api/format.h"
#include "src/common/dynamic_library_loader.h"
#include "src/common/log_util.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/import/mindspore_importer.h"
#include "nnacl/op_base.h"
#include "tools/converter/micro/coder/coder.h"
#include "src/common/prim_util.h"
#include "src/common/version_manager.h"
#include "tools/common/tensor_util.h"
#include "include/api/model.h"
#include "tools/mindir_exporter/mindir_serializer.h"
#include "src/common/primitive_t_utils.h"
#include "tools/converter/config_parser/acl_option_param_parser.h"
#include "tools/converter/config_parser/micro_param_parser.h"
#include "tools/converter/config_parser/preprocess_parser.h"
#include "tools/converter/config_parser/quant_param_parser.h"
#include "tools/converter/config_parser/graph_kernel_param_parser.h"
#include "tools/converter/config_parser/third_party_param_parser.h"
#include "tools/converter/converter_funcgraph.h"
#include "tools/converter/converter_metagraph.h"
#include "tools/common/string_util.h"
#include "src/common/file_utils.h"
#include "ops/dynamic_shape.h"
#include "tools/common/parse_config_utils.h"
#include "tools/converter/converter_packed_node.h"
#include "tools/converter/config_parser/cpu_option_param_parser.h"
#include "tools/converter/export_model.h"

namespace mindspore {
std::map<std::string, Format> StrToEnumFormatMap = {{"NHWC", Format::NHWC}, {"NCHW", Format::NCHW}};
extern "C" {
void mindspore_log_init();
}
namespace lite {
#define CONVERTER_LOG_ERROR(str)   \
  do {                             \
    MS_LOG(ERROR) << str;          \
    std::cout << str << std::endl; \
  } while (0)

namespace {
constexpr size_t kMaxNum1024 = 1024;
constexpr size_t kPluginPathMaxNum = 10;
constexpr int kPathLengthUpperLimit = 1024;
constexpr size_t kEncMaxLen = 16;
constexpr auto kFmk = "fmk";
constexpr auto kModelFile = "modelFile";
constexpr auto kOutputFile = "outputFile";
constexpr auto kWeightFile = "weightFile";
constexpr auto kFp16 = "fp16";
constexpr auto kInputshape = "inputShape";
constexpr auto kInputDataFormat = "inputDataFormat";
constexpr auto kEncryptKey = "encryptKey";
constexpr auto kEncryption = "encryption";
constexpr auto kInputDataType = "inputDataType";
constexpr auto kOutputDataType = "outputDataType";
constexpr auto kInfer = "infer";
std::map<std::string, FmkType> StrToEnumFmkTypeMap = {
  {"CAFFE", FmkType::kFmkTypeCaffe},  {"MINDIR", FmkType::kFmkTypeMs}, {"TFLITE", FmkType::kFmkTypeTflite},
  {"ONNX", FmkType::kFmkTypeOnnx},    {"TF", FmkType::kFmkTypeTf},     {"PYTORCH", FmkType::kFmkTypePytorch},
  {"MSLITE", FmkType::kFmkTypeMsLite}};
std::map<std::string, DataType> StrToEnumDataTypeMap = {{"FLOAT", DataType::kNumberTypeFloat32},
                                                        {"INT8", DataType::kNumberTypeInt8},
                                                        {"UINT8", DataType::kNumberTypeUInt8},
                                                        {"DEFAULT", DataType::kTypeUnknown}};

#if defined(_WIN32) || defined(_WIN64)
static const char kSlash[] = "\\";
#else
static const char kSlash[] = "/";
#endif

// Deal with early release of 3rd-party plugin library.
static std::vector<std::shared_ptr<DynamicLibraryLoader>> dl_loaders;
bool FileExists(const std::string &path) {
  std::ifstream file(path);
  return file.good();
}

int InitModelFmk(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (StrToEnumFmkTypeMap.find(value) != StrToEnumFmkTypeMap.end()) {
    param->fmk_type = StrToEnumFmkTypeMap.at(value);
  } else {
    std::cerr << "INPUT ILLEGAL: fmk must be TF|TFLITE|CAFFE|MINDIR|ONNX|PYTORCH|MSLITE" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int InitModelFile(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (value.empty() || !FileExists(value)) {
    MS_LOG(ERROR) << "model file path is empty or invalid";
    return RET_INPUT_PARAM_INVALID;
  }
  param->model_file = value;
  return RET_OK;
}

int InitModelInputDataType(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (StrToEnumDataTypeMap.find(value) == StrToEnumDataTypeMap.end()) {
    std::cerr << "INPUT INVALID: inputDataType is invalid, supported inputDataType: FLOAT | INT8 | UINT8 | DEFAULT"
              << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  param->input_data_type = StrToEnumDataTypeMap.at(value);
  return RET_OK;
}

int InitModelOutputDataType(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (StrToEnumDataTypeMap.find(value) == StrToEnumDataTypeMap.end()) {
    std::cerr << "OUTPUT INVALID: outputDataType is invalid, supported outputDataType: FLOAT | INT8 | UINT8 | DEFAULT"
              << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  param->output_data_type = StrToEnumDataTypeMap.at(value);
  return RET_OK;
}

int InitModelSaveFP16(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (value == "on") {
    param->weight_fp16 = true;
  } else if (value.empty() || value == "off") {
    param->weight_fp16 = false;
  } else {
    std::cerr << "INPUT ILLEGAL: fp16 must be on|off" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int InitModelTrainMode(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (value == "true") {
    param->train_model = true;
  } else if (value.empty() || value == "false") {
    param->train_model = false;
  } else {
    std::cerr << "INPUT ILLEGAL: trainModel must be true|false" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

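// Parse the inputShape flag, e.g. "in_0:1,224,224,3;in_1:1,10" (the tensor names here
// are illustrative). A tensor name may itself contain ':'; only the last ':' of each
// entry separates the name from its comma-separated dims.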
int InitModelInputShape(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (value.empty()) {
    return RET_OK;
  }
  std::vector<int64_t> shape;
  auto shape_strs = lite::StrSplit(value, std::string(";"));
  for (const auto &shape_str : shape_strs) {
    if (shape_str.empty()) {
      continue;
    }
    shape.clear();
    auto string_split = lite::StrSplit(shape_str, std::string(":"));
    constexpr int kMinShapeSizeInStr = 2;
    if (string_split.size() < kMinShapeSizeInStr) {
      MS_LOG(ERROR) << "shape size must not be less than " << kMinShapeSizeInStr;
      return lite::RET_INPUT_PARAM_INVALID;
    }
    auto name = string_split[0];
    for (size_t i = 1; i < string_split.size() - 1; ++i) {
      name += ":" + string_split[i];
    }
    if (name.empty()) {
      MS_LOG(ERROR) << "input tensor name is empty";
      return lite::RET_INPUT_PARAM_INVALID;
    }
    auto dim_strs = string_split[string_split.size() - 1];
    if (dim_strs.empty()) {
      MS_LOG(ERROR) << "input tensor dim string is empty";
      return lite::RET_INPUT_PARAM_INVALID;
    }
    auto dims = lite::StrSplit(dim_strs, std::string(","));
    if (dims.empty()) {
      MS_LOG(ERROR) << "input tensor dim is empty";
      return lite::RET_INPUT_PARAM_INVALID;
    }
    for (const auto &dim : dims) {
      int64_t dim_value;
      try {
        dim_value = std::stoi(dim);
      } catch (const std::exception &e) {
        MS_LOG(ERROR) << "Get dim failed: " << e.what();
        return lite::RET_INPUT_PARAM_INVALID;
      }
      shape.push_back(dim_value);
    }
    param->input_shape[name] = shape;
  }
  return RET_OK;
}

int InitModelInputDataFormat(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (StrToEnumFormatMap.find(value) != StrToEnumFormatMap.end()) {
    param->input_format = StrToEnumFormatMap.at(value);
  } else if (!value.empty()) {
    MS_LOG(ERROR) << "Input format is invalid.";
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int InitModelInfer(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (value == "true") {
    param->pre_infer = true;
  } else if (value == "false" || value.empty()) {
    param->pre_infer = false;
  } else {
    std::cerr << "INPUT ILLEGAL: infer must be true|false" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int InitModelNoFusion(const std::string &value, const std::shared_ptr<ConverterPara> &param) {
  if (value == "true") {
    param->no_fusion = true;
  } else if (value == "false") {
    param->no_fusion = false;
  } else if (!value.empty()) {
    std::cerr << "INPUT ILLEGAL: NoFusion must be true|false" << std::endl;
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

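// Build a ConverterPara from one [model_param] config section, e.g. (values are
// illustrative):
//   [model_param]
//   fmk=TFLITE
//   modelFile=model.tflite
// "fmk" and "modelFile" are mandatory; every other key must have a parse function
// registered in parse_funcs.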
std::shared_ptr<ConverterPara> CreateConvertParam(const std::map<std::string, std::string> &model_params) {
  auto param = std::make_shared<ConverterPara>();
  std::map<std::string, std::function<int(const std::string &, const std::shared_ptr<ConverterPara> &)>> parse_funcs = {
    {"fmk", InitModelFmk},
    {"modelFile", InitModelFile},
    {"inputDataType", InitModelInputDataType},
    {"outputDataType", InitModelOutputDataType},
    {"fp16", InitModelSaveFP16},
    {"trainModel", InitModelTrainMode},
    {"inputShape", InitModelInputShape},
    {"inputDataFormat", InitModelInputDataFormat},
    {"infer", InitModelInfer},
    {"NoFusion", InitModelNoFusion}};
  if (model_params.find("fmk") == model_params.end() || model_params.find("modelFile") == model_params.end()) {
    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk and modelFile must be set in [model_param].";
    return nullptr;
  }
  for (auto &pair : model_params) {
    if (parse_funcs.find(pair.first) == parse_funcs.end()) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: `" << pair.first << "` is not supported in [model_param]";
      return nullptr;
    }
    if (parse_funcs[pair.first](pair.second, param) != RET_OK) {
      MS_LOG(ERROR) << pair.first << "'s value is invalid";
      return nullptr;
    }
  }
  return param;
}
}  // namespace

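// Record the converter flags (fmk, file paths, input shapes, data types, ...) in the
// inner context so they can later be queried as external config infos.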
STATUS StoreConverterParameters(const std::shared_ptr<ConverterPara> &param) {
  if (param == nullptr) {
    MS_LOG(ERROR) << "Input param is nullptr";
    return RET_INPUT_PARAM_INVALID;
  }
  std::string param_input_shape;
  for (auto i = param->input_shape.cbegin(); i != param->input_shape.cend(); ++i) {
    std::stringstream input_shape_ss;
    (void)copy(i->second.begin(), i->second.end(), std::ostream_iterator<int64_t>(input_shape_ss, ","));
    std::string input_shape_str = input_shape_ss.str();
    if (!input_shape_str.empty()) {
      input_shape_str.erase(input_shape_str.end() - 1);  // drop the trailing ','
    }
    param_input_shape += i->first + ":" + input_shape_str + ";";
  }
  std::map<std::string, std::map<std::string, std::string>> conver_param_maps;
  conver_param_maps[mindspore::converter::KConverterParam][kFmk] = std::to_string(param->fmk_type);
  conver_param_maps[mindspore::converter::KConverterParam][kModelFile] = param->model_file;
  conver_param_maps[mindspore::converter::KConverterParam][kOutputFile] = param->output_file;
  conver_param_maps[mindspore::converter::KConverterParam][kWeightFile] = param->weight_file;
  std::stringstream weight_fp16_ss;
  weight_fp16_ss << std::boolalpha << param->weight_fp16;
  conver_param_maps[mindspore::converter::KConverterParam][kFp16] = weight_fp16_ss.str();
  conver_param_maps[mindspore::converter::KConverterParam][kInputshape] = param_input_shape;
  conver_param_maps[mindspore::converter::KConverterParam][kInputDataFormat] = std::to_string(param->input_format);
  conver_param_maps[mindspore::converter::KConverterParam][kEncryptKey] = param->encrypt_key;
  std::stringstream encryption_ss;
  encryption_ss << std::boolalpha << param->enable_encryption;
  conver_param_maps[mindspore::converter::KConverterParam][kEncryption] = encryption_ss.str();
  conver_param_maps[mindspore::converter::KConverterParam][kInputDataType] =
    std::to_string(static_cast<int>(param->input_data_type));
  conver_param_maps[mindspore::converter::KConverterParam][kOutputDataType] =
    std::to_string(static_cast<int>(param->output_data_type));
  std::stringstream pre_infer_ss;
  pre_infer_ss << std::boolalpha << param->pre_infer;
  conver_param_maps[mindspore::converter::KConverterParam][kInfer] = pre_infer_ss.str();
  for (const auto &config_info : conver_param_maps) {
    ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
  }
  return RET_OK;
}

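// Re-pack each node's PrimitiveT into flatbuffer form to check whether the graph
// contains any custom (user-registered) node.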
int CheckExistCustomOps(const schema::MetaGraphT *meta_graph, bool *exist_custom_nodes) {
  MS_CHECK_TRUE_MSG(meta_graph != nullptr && exist_custom_nodes != nullptr, RET_ERROR, "input params contain nullptr.");
  flatbuffers::FlatBufferBuilder fbb(kMaxNum1024);
  for (const auto &node : meta_graph->nodes) {
    MS_CHECK_TRUE_RET(node != nullptr, RET_ERROR);
    auto prim = ConvertToPrimitive(node->primitive.get(), &fbb);
    if (prim == nullptr) {
      MS_LOG(ERROR) << "get primitive failed.";
      fbb.Clear();
      return RET_ERROR;
    }
    if (IsCustomNode(prim, static_cast<int>(SCHEMA_CUR))) {
      *exist_custom_nodes = true;
      break;
    }
  }
  fbb.Clear();
  return RET_OK;
}

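// Smoke-test the converted model: pack the MetaGraphT into a flatbuffer, build it with
// the lite runtime on CPU, fill the inputs with random data and run one inference.
// Training models, models with custom nodes and dynamic-shape inputs are skipped.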
int PreInference(const schema::MetaGraphT &meta_graph, bool train_model) {
  if (train_model) {
    MS_LOG(WARNING) << "train model doesn't support pre-infer.";
    return RET_OK;
  }

  bool exist_custom_nodes = false;
  auto check_ret = CheckExistCustomOps(&meta_graph, &exist_custom_nodes);
  if (check_ret == RET_ERROR) {
    MS_LOG(ERROR) << "CheckExistCustomOps failed.";
    return RET_ERROR;
  }
  if (exist_custom_nodes) {
    MS_LOG(WARNING) << "custom nodes exist, model will not be pre-inferred.";
    return RET_OK;
  }
  mindspore::Model model;
  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
  auto offset = schema::MetaGraph::Pack(builder, &meta_graph);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  int size = builder.GetSize();
  auto content = builder.GetBufferPointer();
  if (content == nullptr) {
    MS_LOG(ERROR) << "GetBufferPointer nullptr";
    return RET_ERROR;
  }
  auto context = std::make_shared<mindspore::Context>();
  if (context == nullptr) {
    MS_LOG(ERROR) << "New context failed while running";
    std::cerr << "New context failed while running" << std::endl;
    return RET_ERROR;
  }
  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
  auto &device_list = context->MutableDeviceInfo();
  device_list.push_back(device_info);

  auto ret = model.Build(content, size, kMindIR_Lite, context);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Build error";
    std::cerr << "Build error" << std::endl;
    return RET_ERROR;
  }
  for (auto &tensor : model.GetInputs()) {
    if (tensor.Shape().empty() || tensor.DataSize() == 0 ||
        std::find(tensor.Shape().begin(), tensor.Shape().end(), -1) != tensor.Shape().end()) {
      MS_LOG(WARNING) << tensor.Name() << " has a dynamic shape, model will not be pre-inferred.";
      return RET_OK;
    }
    auto status = GenerateRandomData(&tensor);
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor.Name() << " GenerateRandomData failed.";
      return status;
    }
  }
  std::vector<MSTensor> outputs;
  ret = model.Predict(model.GetInputs(), &outputs);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Inference error";
    std::cerr << "Inference error" << std::endl;
    return RET_ERROR;
  }
  return RET_OK;
}

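// Parse the config file (or the in-memory config_param) into `param`. When the config
// describes multiple models, their per-model key/value pairs are returned through
// `model_param_infos`, and only the micro_param and model_param sections take effect.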
int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> &param,
                                   std::map<int, std::map<std::string, std::string>> *model_param_infos) {
  model_param_infos->clear();
  lite::ConfigFileParser config_parser;
  std::map<std::string, std::map<std::string, std::string>> maps;
  auto ret = RET_OK;
  auto parse_map_ret = RET_OK;
  if (!param->config_file.empty()) {
    ret = config_parser.ParseConfigFile(param->config_file, nullptr);
    parse_map_ret = mindspore::lite::ParseConfigFile(param->config_file, &maps, model_param_infos);
  } else {
    ret = config_parser.ParseConfigParam(&param->config_param);
  }
  if (ret != RET_OK || parse_map_ret != RET_OK) {
    MS_LOG(ERROR) << "Parse config param failed.";
    return ret;
  }
  if (model_param_infos->empty()) {
    ret = lite::PreprocessParser::ParsePreprocess(config_parser.GetDataPreProcessString(), &param->dataPreProcessParam);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse preprocess failed.";
      return ret;
    }
    ret = lite::QuantParamParser::ParseCommonQuant(config_parser.GetCommonQuantString(), &param->commonQuantParam);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse common quant param failed.";
      return ret;
    }
    ret = lite::QuantParamParser::ParseFullQuant(config_parser.GetFullQuantString(), &param->fullQuantParam);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse full quant param failed.";
      return ret;
    }
    ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(),
                                                           &param->mixedBitWeightQuantParam);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse mixed bit weight quant param failed.";
      return ret;
    }
    if (param->fmk_type == converter::kFmkTypeThirdParty) {
      ret = lite::ThirdPartyParamParser::Parse(config_parser.GetThirdPartyModelString(), &param->thirdPartyModelParam);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Parse third party param failed.";
        return ret;
      }
    }
    ret = InitExtendedIntegrationInfo(param, config_parser);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse extended integration info failed.";
      return ret;
    }
    std::string output_file = param->output_file;
    param->aclModelOptionCfgParam.om_file_path = output_file;
    auto dir_pos = output_file.find_last_of('/');
    param->aclModelOptionCfgParam.dump_model_name =
      dir_pos != std::string::npos ? output_file.substr(dir_pos + 1) : output_file;
    lite::AclOptionParamParser acl_param_parser;
    ret = acl_param_parser.ParseAclOptionCfg(config_parser.GetAclOptionCfgString(), &param->aclModelOptionCfgParam);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse acl option param failed.";
      return ret;
    }
    // parse ascend_context in the config file; it takes higher priority
    if (maps.find("ascend_context") != maps.end()) {
      auto map = maps.at("ascend_context");
      if (!config_parser.SetParamByConfigfile(param, map)) {
        MS_LOG(ERROR) << "Failed to parse config item ascend_context";
        return RET_ERROR;
      }
    }
    if (!param->config_file.empty()) {
      (void)CheckOfflineParallelConfig(param->config_file, &param->parallel_split_config);
    }

    lite::CpuOptionParamParser cpu_param_parser;
    ret = cpu_param_parser.ParseCpuOptionCfg(config_parser.GetCpuOptionCfgString(), &param->cpuOptionCfgParam);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Parse cpu option param failed.";
      return ret;
    }
  }
  MS_LOG(INFO)
    << "If there are multiple models, only micro_param and model_param are supported; other configurations take no effect";

  lite::MicroParamParser micro_param_parser;
  ret = micro_param_parser.ParseMicroParam(config_parser.GetMicroParamString(), &param->microParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse micro param failed.";
    return ret;
  }
  ret =
    lite::QuantParamParser::ParseTransformQuant(config_parser.GetTransformQuantString(), &param->transformQuantParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse transform quant param failed.";
    return ret;
  }
  ret = lite::QuantParamParser::ParseDynamicQuant(config_parser.GetDynamicQuantString(), &param->dynamicQuantParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse dynamic quant param failed.";
    return ret;
  }
  lite::GraphKernelParamParser graph_kernel_parser;
  ret = graph_kernel_parser.ParseGraphKernelCfg(config_parser.GetGraphKernelString(), &param->graphKernelParam);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Parse graph kernel param failed.";
    return ret;
  }
  return RET_OK;
}

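// Read the registry section of the config: third-party plugin paths (';'-separated,
// at most kPluginPathMaxNum), the disable_fusion switch and the fusion blacklist.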
int ConverterImpl::InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPara> &param,
                                               const lite::ConfigFileParser &config_parser) {
  auto extended_info = config_parser.GetRegistryInfoString();
  if (!extended_info.plugin_path.empty()) {
    const char delimiter = ';';
    auto relative_path = lite::SplitStringToVector(extended_info.plugin_path, delimiter);
    if (relative_path.size() > kPluginPathMaxNum) {
      MS_LOG(ERROR) << "the number of extended plugin libraries shouldn't be larger than " << kPluginPathMaxNum;
      return RET_INPUT_PARAM_INVALID;
    }
    for (auto &i : relative_path) {
      param->plugins_path.push_back(lite::RealPath(i.c_str()));
    }
  }

  if (!extended_info.disable_fusion.empty()) {
    if (extended_info.disable_fusion == "on") {
      param->no_fusion = true;
    } else if (extended_info.disable_fusion == "off") {
      param->no_fusion = false;
    } else {
      std::cerr << "CONFIG SETTING ILLEGAL: disable_fusion should be on/off" << std::endl;
      return RET_INPUT_PARAM_INVALID;
    }
  }

  if (!extended_info.fusion_blacklists.empty()) {
    std::vector<std::string> fusions = SplitStringToVector(extended_info.fusion_blacklists, ",");
    for (const auto &fusion : fusions) {
      bool inserted = false;
      std::tie(std::ignore, inserted) = param->fusion_blacklists.insert(fusion);
      if (inserted) {
        MS_LOG(DEBUG) << "Value was inserted successfully.";
      }
    }
  }
  return RET_OK;
}

bool ConverterImpl::CheckOfflineParallelConfig(const std::string &file, ParallelSplitConfig *parallel_split_config) {
  // device: [device0 device1] ---> {cpu, gpu}
  // computeRate: [x: y] x >= 0 && y >= 0 && x / y < 10
  MS_ASSERT(parallel_split_config != nullptr);
  std::vector<std::string> config_devices = {"cpu", "gpu", "npu"};
  auto compute_rate_result = GetStrFromConfigFile(file, kComputeRate);
  if (compute_rate_result.empty()) {
    return false;
  }
  std::string device0_result = GetStrFromConfigFile(file, kSplitDevice0);
  if (device0_result.empty()) {
    return false;
  }
  std::string device1_result = GetStrFromConfigFile(file, kSplitDevice1);
  if (device1_result.empty()) {
    return false;
  }
  bool device0_flag = false;
  bool device1_flag = false;
  for (const auto &device : config_devices) {
    if (device == device0_result) {
      device0_flag = true;
    }
    if (device == device1_result) {
      device1_flag = true;
    }
  }
  if (!device0_flag || !device1_flag) {
    return false;
  }
  const char delimiter = ';';
  std::vector<std::string> device_rates = lite::SplitStringToVector(compute_rate_result, delimiter);
  const char colon = ':';
  for (const auto &device : device_rates) {
    std::vector<std::string> rate = lite::SplitStringToVector(device, colon);
    int64_t compute_rate = 0;
    try {
      compute_rate = std::stoi(rate.back());
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "Get compute rate failed: " << e.what();
      return false;
    }
    parallel_split_config->parallel_compute_rates_.push_back(compute_rate);
  }
  const size_t support_rates_num = 2;
  if (parallel_split_config->parallel_compute_rates_.size() != support_rates_num) {
    return false;
  }
  int64_t bigger_rate = INT32_MIN;
  int64_t smaller_rate = INT32_MAX;
  for (const auto &rate : parallel_split_config->parallel_compute_rates_) {
    if (rate <= 0 || rate > INT32_MAX) {
      return false;
    }
    bigger_rate = std::max(rate, bigger_rate);
    smaller_rate = std::min(rate, smaller_rate);
  }
  parallel_split_config->parallel_devices_.push_back(device0_result);
  parallel_split_config->parallel_devices_.push_back(device1_result);
  // parallel_split_type_ may be extended by other user attrs
  parallel_split_config->parallel_split_type_ = SplitByUserRatio;
  if (smaller_rate == 0) {
    MS_LOG(ERROR) << "smaller_rate is zero";
    return false;
  }
  return bigger_rate / smaller_rate <= kMaxSplitRatio;
}

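// Scan a "key=value" style config file and return the value of `target_key`, skipping
// blank lines, comments ('#') and section headers ('[').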
std::string ConverterImpl::GetStrFromConfigFile(const std::string &file, const std::string &target_key) {
  std::string res;
  if (file.empty()) {
    MS_LOG(ERROR) << "file path is empty";
    return res;
  }
  auto resolved_path = std::make_unique<char[]>(PATH_MAX);
  if (resolved_path == nullptr) {
    MS_LOG(ERROR) << "new resolved_path failed";
    return "";
  }

#ifdef _WIN32
  auto *real_path = _fullpath(resolved_path.get(), file.c_str(), kPathLengthUpperLimit);
#else
  char *real_path = realpath(file.c_str(), resolved_path.get());
#endif
  if (real_path == nullptr || strlen(real_path) == 0) {
    MS_LOG(ERROR) << "file path is not valid: " << file;
    return "";
  }
  std::ifstream ifs(resolved_path.get());
  if (!ifs.good()) {
    MS_LOG(ERROR) << "file: " << real_path << " does not exist";
    return res;
  }
  if (!ifs.is_open()) {
    MS_LOG(ERROR) << "file: " << real_path << " open failed";
    return res;
  }
  std::string line;
  while (std::getline(ifs, line)) {
    lite::Trim(&line);
    if (line.empty() || line.at(0) == '#' || line.at(0) == '[') {
      continue;
    }
    auto index = line.find('=');
    if (index == std::string::npos) {
      MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
      return "";
    }
    auto key = line.substr(0, index);
    auto value = line.substr(index + 1);
    lite::Trim(&key);
    lite::Trim(&value);
    if (key == target_key) {
      return value;
    }
  }
  return res;
}

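// Validate fmk_type; weightFile is only meaningful for CAFFE models.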
int CheckFmkType(const std::shared_ptr<ConverterPara> &param) {
  if (param == nullptr) {
    return RET_OK;
  }
  std::set kValidFmkTypes = {FmkType::kFmkTypeTf,     FmkType::kFmkTypeCaffe,     FmkType::kFmkTypeOnnx,
                             FmkType::kFmkTypeMs,     FmkType::kFmkTypeTflite,    FmkType::kFmkTypePytorch,
                             FmkType::kFmkTypeMsLite, FmkType::kFmkTypeThirdParty};
  if (kValidFmkTypes.find(param->fmk_type) == kValidFmkTypes.end()) {
    MS_LOG(ERROR) << "INPUT ILLEGAL: fmk_type must be "
                     "TF|CAFFE|ONNX|MS|TFLITE|PYTORCH|MSLITE|THIRDPARTY"
                  << ", but got " << param->fmk_type;
    return RET_INPUT_PARAM_INVALID;
  }
  if (param->fmk_type != converter::kFmkTypeCaffe && !param->weight_file.empty()) {
    MS_LOG(ERROR) << "INPUT ILLEGAL: weight_file is only a valid flag for CAFFE models";
    return RET_INPUT_PARAM_INVALID;
  }
  return RET_OK;
}

int CheckModelFile(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    if (param->model_file.empty()) {
      MS_LOG(ERROR) << "INPUT MISSING: model file path is necessary";
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

int CheckOutputFile(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr && param->aclModelOptionCfgParam.offline) {
    if (param->output_file.empty()) {
      MS_LOG(ERROR) << "INPUT MISSING: output file path is necessary";
      return RET_INPUT_PARAM_INVALID;
    }

#ifdef _WIN32
    replace(param->output_file.begin(), param->output_file.end(), '/', '\\');
#endif

    if (param->output_file.rfind('/') == param->output_file.length() - 1 ||
        param->output_file.rfind('\\') == param->output_file.length() - 1) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: output file must be a valid file path";
      return RET_INPUT_PARAM_INVALID;
    }
    param->aclModelOptionCfgParam.om_file_path = param->output_file;
  }
  return RET_OK;
}

int CheckInputShape(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    if (param->input_shape.empty()) {
      return RET_OK;
    }
    for (const auto &elem : param->input_shape) {
      std::vector<int64_t> dims = elem.second;
      if (dims.empty()) {
        MS_LOG(ERROR) << "INPUT MISSING: input tensor dim is empty";
        return lite::RET_INPUT_PARAM_INVALID;
      }
      bool has_negative_dim = std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim < 0; });
      if (has_negative_dim) {
        MS_LOG(ERROR) << "INPUT ILLEGAL: Unsupported dim < 0.";
        return lite::RET_INPUT_PARAM_INVALID;
      }
    }
  }
  return RET_OK;
}

int CheckInputFormat(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    std::set valid_values = {NHWC, NCHW};
    if (std::find(valid_values.begin(), valid_values.end(), param->input_format) == valid_values.end()) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: input_format is not in {NHWC, NCHW}, but got " << param->input_format;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

int CheckInputOutputDataType(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    std::set input_valid_values = {
      DataType::kNumberTypeFloat16, DataType::kNumberTypeFloat32, DataType::kNumberTypeInt8, DataType::kNumberTypeUInt8,
      DataType::kNumberTypeInt32,   DataType::kNumberTypeInt64,   DataType::kTypeUnknown};
    if (std::find(input_valid_values.begin(), input_valid_values.end(), param->input_data_type) ==
        input_valid_values.end()) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: input_data_type is not in {kNumberTypeFloat16, kNumberTypeFloat32, "
                    << "kNumberTypeInt8, kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeInt64, kTypeUnknown}, "
                    << "but got " << param->input_data_type;
      return RET_INPUT_PARAM_INVALID;
    }

    std::set output_valid_values = {DataType::kNumberTypeFloat16, DataType::kNumberTypeFloat32,
                                    DataType::kNumberTypeInt8, DataType::kNumberTypeUInt8, DataType::kTypeUnknown};
    if (std::find(output_valid_values.begin(), output_valid_values.end(), param->output_data_type) ==
        output_valid_values.end()) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: output_data_type is not in {kNumberTypeFloat16, kNumberTypeFloat32, "
                    << "kNumberTypeInt8, kNumberTypeUInt8, kTypeUnknown}, but got " << param->output_data_type;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

int CheckSaveType(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    std::set valid_values = {kMindIR, kMindIR_Lite};
    if (std::find(valid_values.begin(), valid_values.end(), param->save_type) == valid_values.end()) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: save_type is not in {kMindIR, kMindIR_Lite}, but got " << param->save_type;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

int CheckEncrypt(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    if (param->enable_encryption) {
      if (param->encrypt_key.empty()) {
        MS_LOG(ERROR) << "encryption param is true, so the encrypt_key param must be set. If you don't "
                         "need model encryption, set the encryption param to false.";
        return RET_INPUT_PARAM_INVALID;
      }
    }
  }
  return RET_OK;
}

int CheckTrainModel(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr) {
    if (param->train_model) {
      if (param->fmk_type != converter::FmkType::kFmkTypeMs) {
        MS_LOG(ERROR) << "INPUT ILLEGAL: train model converter supports only the MINDIR format";
        return RET_INPUT_PARAM_INVALID;
      }
      if ((param->input_data_type != DataType::kNumberTypeFloat32) &&
          (param->input_data_type != DataType::kTypeUnknown)) {
        MS_LOG(ERROR) << "INPUT ILLEGAL: train model converter supports only FP32 input tensors";
        return RET_INPUT_PARAM_INVALID;
      }
      if ((param->output_data_type != DataType::kNumberTypeFloat32) &&
          (param->output_data_type != DataType::kTypeUnknown)) {
        MS_LOG(ERROR) << "INPUT ILLEGAL: train model converter supports only FP32 output tensors";
        return RET_INPUT_PARAM_INVALID;
      }
    }
  }
  return RET_OK;
}

int CheckDevice(const std::shared_ptr<ConverterPara> &param) {
  if (param != nullptr && !(param->device.empty())) {
    std::set valid_values = {"Ascend310", "Ascend310P", "Ascend", "GPU"};
    if (std::find(valid_values.begin(), valid_values.end(), param->device) == valid_values.end()) {
      MS_LOG(ERROR) << "INPUT ILLEGAL: device is not in {GPU, Ascend, Ascend310, Ascend310P}, got " << param->device;
      return RET_INPUT_PARAM_INVALID;
    }
  }
  return RET_OK;
}

int CheckValueParam(const std::shared_ptr<ConverterPara> &param, bool is_multi_model) {
  if (param == nullptr) {
    MS_LOG(ERROR) << "INPUT MISSING: param is nullptr.";
    return RET_INPUT_PARAM_INVALID;
  }

  auto ret = CheckFmkType(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of fmk_type failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckModelFile(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of model_file failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  if (!is_multi_model) {
    ret = CheckOutputFile(param);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Check value of output_file failed.";
      return RET_INPUT_PARAM_INVALID;
    }
  }

  ret = CheckInputShape(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of input_shape failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckInputFormat(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of input_format failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckInputOutputDataType(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of input_data_type or output_data_type failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckSaveType(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of save_type failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckEncrypt(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of encrypt failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckTrainModel(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check value of train model failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  ret = CheckDevice(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Check device failed.";
    return RET_INPUT_PARAM_INVALID;
  }

  return RET_OK;
}

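// Open every converter plugin library and keep the loaders alive in `dl_loaders` so
// the shared objects are not unloaded while still in use.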
int ConverterImpl::LoadPluginLib(const std::shared_ptr<ConverterPara> &param) {
  if (!param->plugins_path.empty()) {
    for (auto &path : param->plugins_path) {
      auto dl_loader = std::make_shared<DynamicLibraryLoader>();
      MS_CHECK_TRUE_RET(dl_loader != nullptr, RET_ERROR);
      auto status = dl_loader->Open(path);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "open dynamic library failed. " << path;
        return RET_ERROR;
      }
      dl_loaders.emplace_back(dl_loader);
    }
  }
  return RET_OK;
}

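// Top-level conversion pipeline: parse configs, store the flags, load plugins, then
// convert either the single model described by `param` or each model listed in the
// config's per-model sections.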
int ConverterImpl::Convert(const std::shared_ptr<ConverterPara> &param, void **model_data, size_t *data_size,
                           bool not_save) {
  if (param == nullptr) {
    MS_LOG(ERROR) << "Input param is nullptr";
    return RET_ERROR;
  }
  std::map<int, std::map<std::string, std::string>> model_param_infos;  // {model_index, {param_key: param_value}}
  auto ret = InitConfigParam(param, &model_param_infos);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Init config file failed: " << ret << " " << GetErrorInfo(ret);
    return ret;
  }
  ret = StoreConverterParameters(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Get converter parameter failed: " << ret << " " << GetErrorInfo(ret);
    return ret;
  }

  ret = LoadPluginLib(param);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Load plugin lib failed: " << ret << " " << GetErrorInfo(ret);
    return ret;
  }

  if (model_param_infos.empty()) {
    ret = CheckValueParam(param, false);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Converter input parameters check failed";
      return ret;
    }
    ret = HandleGraphCommon(param, model_data, data_size, not_save, false);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Handle graph failed: " << ret << " " << GetErrorInfo(ret);
      return ret;
    }
  } else {
    size_t model_index = 0;
    for (const auto &pair : model_param_infos) {
      auto convert_param = CreateConvertParam(pair.second);
      if (convert_param == nullptr) {
        MS_LOG(ERROR) << "For model " << pair.first << ", create converter param failed";
        return RET_INPUT_PARAM_INVALID;
      }
      convert_param->microParam = param->microParam;
      ret = CheckValueParam(convert_param, true);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "For model " << pair.first << ", converter input parameters check failed";
        return ret;
      }
      if (model_index == model_param_infos.size() - 1) {
        convert_param->microParam.is_last_model = true;
      }
      ret = HandleGraphCommon(convert_param, model_data, data_size, not_save, true);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Handle graph failed: " << ret << " " << GetErrorInfo(ret);
        return ret;
      }
      model_index++;
    }
  }
  return RET_OK;
}

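// Common per-graph flow: build the FuncGraph from the input model, run the optimizer,
// then save the result. MSLITE inputs skip building and go straight to micro code
// generation.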
int ConverterImpl::HandleGraphCommon(const std::shared_ptr<ConverterPara> &param, void **model_data, size_t *data_size,
                                     bool not_save, bool is_multi_model) {
  if (param->fmk_type == converter::kFmkTypeMsLite) {
    if (!param->microParam.enable_micro) {
      MS_LOG(ERROR) << "When fmk is set to MSLITE, only micro code generation is supported.";
      return RET_NOT_SUPPORT;
    }
    auto ret = ExecuteMicro(nullptr, param, is_multi_model);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Micronize MSLite model failed.";
      return ret;
    }
    return ret;
  }
  auto graph = ConverterFuncGraph::Build(param);
  if (graph == nullptr) {
    MS_LOG(ERROR) << "Build func graph failed";
    return RET_ERROR;
  }

  int ret = ConverterFuncGraph::Optimize(param, graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Optimize func graph failed: " << ret << " " << GetErrorInfo(ret);
    return ret;
  }

  ret = SaveGraph(graph, param, model_data, data_size, not_save, is_multi_model);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Save graph failed: " << ret << " " << GetErrorInfo(ret);
    return ret;
  }
  return RET_OK;
}

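// Generate micro (source code) output. For multi-model conversion the output path is
// assembled as save_path + '/' + project_name.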
int ConverterImpl::ExecuteMicro(const schema::MetaGraphT *meta_graph, const std::shared_ptr<ConverterPara> &param,
                                bool is_multi_model) {
  std::string output_path = param->output_file;
  if (!is_multi_model) {
    param->microParam.is_last_model = true;
  } else {
    if (param->microParam.save_path.empty() || param->microParam.project_name.empty()) {
      MS_LOG(ERROR) << "Micro param invalid: save_path and project_name are needed";
      return RET_ERROR;
    }
    output_path = param->microParam.save_path + param->microParam.project_name;
    if (param->microParam.save_path.back() != '/' && param->microParam.save_path.back() != '\\') {
      output_path = param->microParam.save_path + kSlash + param->microParam.project_name;
    }
  }
  auto status =
    meta_graph != nullptr
      ? micro::Coder::MicroSourceCodeGeneration(*meta_graph, output_path, &param->microParam, param->weight_fp16)
      : micro::Coder::MicroSourceCodeGeneration(param->model_file, output_path, &param->microParam, param->weight_fp16);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Execute Micro failed.";
  }
  return status;
}

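// Save the optimized graph: kMindIR goes through the MindIR serializer; otherwise the
// graph is lowered to a MetaGraphT, optionally packed for a target CPU architecture,
// pre-inferred, and written out (or handed to the micro coder).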
int ConverterImpl::SaveGraph(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> &param, void **model_data,
                             size_t *data_size, bool not_save, bool is_multi_model) {
  int status = RET_ERROR;
  if (param->save_type == kMindIR) {
    status = SaveMindIRModel(graph, param, model_data, data_size);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Save mindir model failed: " << status << " " << GetErrorInfo(status);
      return RET_ERROR;
    }
    return RET_OK;
  }

  auto meta_graph = ConverterToMetaGraph::Build(param, graph);
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Convert to meta graph failed";
    return RET_ERROR;
  }

  if (!param->cpuOptionCfgParam.architecture.empty()) {
    std::string cpu_option = param->cpuOptionCfgParam.architecture + param->cpuOptionCfgParam.instruction;
    status = ConverterPackedNode(meta_graph, cpu_option);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Save pack info failed.";
      delete meta_graph;
      return status;
    }
  }

  meta_graph->version = Version();

  if (param->pre_infer) {
    status = PreInference(*meta_graph, param->train_model);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Pre-inference failed: " << status << " " << GetErrorInfo(status);
      delete meta_graph;
      return status;
    }
  }

  if (param->microParam.enable_micro) {
    status = ExecuteMicro(meta_graph, param, is_multi_model);
  } else {
    status = ConverterToMetaGraph::Save(meta_graph, param, model_data, data_size, not_save);
  }
  delete meta_graph;
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Save failed: " << status << " " << GetErrorInfo(status);
    return status;
  }
  return RET_OK;
}

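// Export as MindIR. When pre_infer is enabled, the graph is first cloned and lowered
// to a MetaGraphT so the lite runtime can pre-infer it before the export.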
int ConverterImpl::SaveMindIRModel(FuncGraphPtr graph, const std::shared_ptr<ConverterPara> &param, void **model_data,
                                   size_t *data_size) {
  int status = RET_OK;
  if (param->pre_infer) {
    schema::MetaGraphT *meta_graph = nullptr;
    auto new_param = std::make_shared<ConverterPara>();
    new_param->fmk_type = converter::kFmkTypeMs;
    new_param->save_type = kMindIR;

    std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
    auto mirror_graph = lite::CloneFuncGraph(graph, new_param, &cloned_func_graph);
    if (mirror_graph == nullptr) {
      MS_LOG(ERROR) << "Mirror graph is nullptr";
      return RET_ERROR;
    }
    meta_graph = lite::ConverterToMetaGraph::Build(new_param, mirror_graph);
    if (meta_graph == nullptr) {
      MS_LOG(ERROR) << "FuncGraph convert to meta graph failed";
      return RET_ERROR;
    }
    status = PreInference(*meta_graph, param->train_model);
    delete meta_graph;  // the temporary meta graph is only needed for pre-inference
    if (status != RET_OK) {
      MS_LOG(ERROR) << "PreInference for MindIR failed: " << status << " " << GetErrorInfo(status);
      return status;
    }
  }
  status = ConverterFuncGraph::Save(param, graph, model_data, data_size);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Export to mindir failed: " << status << " " << GetErrorInfo(status);
    return RET_ERROR;
  }
  return RET_OK;
}

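// Public entry point: initializes logging, runs the conversion, and prints the
// unsupported ops on failure.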
int RunConverter(const std::shared_ptr<ConverterPara> &param, void **model_data, size_t *data_size, bool not_save) {
  mindspore::mindspore_log_init();

  param->aclModelOptionCfgParam.offline = !not_save;
  int status = RET_OK;
  ConverterImpl converter_impl;
  try {
    status = converter_impl.Convert(param, model_data, data_size, not_save);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Exception occurred: " << e.what();
    status = RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Convert model failed";
    NotSupportOp::GetInstance()->PrintOps();
    return status;
  }

  MS_LOG(INFO) << "CONVERT RESULT SUCCESS:" << status;
  std::cout << "CONVERT RESULT SUCCESS:" << status << std::endl;
  return status;
}
}  // namespace lite
}  // namespace mindspore