• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/config_parser/config_file_parser.h"
18 #include <set>
19 #include "tools/common/parse_config_utils.h"
20 #include "include/errorcode.h"
21 #include "src/common/log_adapter.h"
22 #include "tools/converter/converter_context.h"
23 #include "tools/common/string_util.h"
24 #include "src/common/config_infos.h"
25 #include "src/common/common.h"
26 #include "nnacl/op_base.h"
27 
28 namespace mindspore {
29 namespace lite {
namespace {
// Names of the top-level sections of the converter config file. ParseConfigParam looks each
// section up in the parsed maps and erases it once it has been handled.
constexpr auto kCommonQuantParam = "common_quant_param";
constexpr auto kFullQuantParam = "full_quant_param";
constexpr auto kWeightQuantParam = "weight_quant_param";
constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
constexpr auto kDataPreprocessParam = "data_preprocess_param";
constexpr auto kRegistry = "registry";
constexpr auto kMicroParam = "micro_param";
constexpr auto kThirdPartyModelParam = "third_party_model";
constexpr auto kCpuOptionParam = "cpu_option_cfg_param";
// NOTE(review): not referenced in the visible part of this file - presumably a config key; confirm.
constexpr auto kCustomOppPath = "custom_opp_path";
constexpr auto kTransformQuantParam = "transform_quant_param";
constexpr auto kDynamicQuantParam = "dynamic_quant_param";
constexpr auto kGraphKernelParam = "graph_kernel_param";
// Field count and field positions used when splitting ':'-separated config items
// such as "op_type:attr:value" or "op_type:node_names:status".
constexpr int kNumSize3 = 3;
constexpr int kNumSize2 = 2;
constexpr size_t kNumIndex0 = 0;
constexpr size_t kNumIndex1 = 1;
constexpr size_t kNumIndex2 = 2;
}  // namespace
using ShapeVector = std::vector<int64_t>;
// NOTE(review): kBatchDim is not referenced in the visible part of this file.
const int kBatchDim = 0;
// Values returned by DynBatchOrDynImage to tell dynamic image size from dynamic batch size.
const int kDynImgSize = 0;
const int kDynBatchSize = 1;
// Tells whether every dynamic-batch string in the list is the same.
// Returns false for an empty list, true when all entries equal the first one.
bool CheckBatchStringSupport(const std::vector<std::string> &batch_str_vec) {
  if (batch_str_vec.empty()) {
    return false;
  }
  const std::string &reference = batch_str_vec.front();
  for (const auto &candidate : batch_str_vec) {
    if (candidate != reference) {
      return false;
    }
  }
  return true;
}
66 
// Small index/size constants used for indexing shape dimensions and for checking how many
// fields a split config item contains.
const size_t kIndex0 = 0;
const size_t kIndex1 = 1;
const size_t kIndex2 = 2;
const size_t kIndex3 = 3;
// Value that marks a dimension of an input shape as dynamic.
const int64_t kdynDim = -1;
DynBatchOrDynImage(const mindspore::ProfileConfigs & profile,size_t dynamic_input_index)72 int DynBatchOrDynImage(const mindspore::ProfileConfigs &profile, size_t dynamic_input_index) {
73   int dynamic_type = -1;
74   for (auto &info : profile.input_infos) {
75     if (!info.is_dynamic_shape) {
76       continue;
77     }
78     const auto &shape = info.input_shape;
79     if (shape.size() != kNCHWDimNumber) {
80       MS_LOG(ERROR) << "Dynamic input whose shape is not 4-dimensional is not supported, input shape: " << shape;
81       return -1;
82     }
83     auto dynamic_dims_count = std::count_if(shape.begin(), shape.end(), [](int64_t dim) { return dim == kdynDim; });
84     if (shape[kIndex0] != kdynDim && dynamic_dims_count == kHWDimNumber) {
85       if (dynamic_type != -1 && dynamic_type != kDynImgSize) {
86         MS_LOG(ERROR) << "Only dynamic batch or dynamic image size is supported, hybrid scenarios are not supported";
87         return -1;
88       }
89       dynamic_type = kDynImgSize;
90     } else if (shape[kIndex0] == kdynDim && dynamic_dims_count == 1) {
91       if (dynamic_type != -1 && dynamic_type != kDynBatchSize) {
92         MS_LOG(ERROR) << "Only dynamic batch or dynamic image size is supported, hybrid scenarios are not supported";
93         return -1;
94       }
95       dynamic_type = kDynBatchSize;
96     } else {
97       MS_LOG(ERROR) << "Only dynamic batch or dynamic image size is supported, input shape: " << shape;
98       return -1;
99     }
100   }
101   return dynamic_type;
102 }
103 
CombineDynamicImageString(const struct mindspore::ProfileConfigs & profile,size_t dynamic_input)104 std::string CombineDynamicImageString(const struct mindspore::ProfileConfigs &profile, size_t dynamic_input) {
105   ShapeVector shape = profile.input_infos[dynamic_input].input_shape;
106   std::string ret = "";
107   size_t first_dim = kIndex0, second_dim = kIndex0;
108   if (shape[kIndex1] == kdynDim && shape[kIndex2] == kdynDim) {
109     first_dim = kIndex1;
110     second_dim = kIndex2;
111   } else if (shape[kIndex1] == kdynDim && shape[kIndex3] == kdynDim) {
112     first_dim = kIndex1;
113     second_dim = kIndex3;
114   } else if (shape[kIndex2] == kdynDim && shape[kIndex3] == kdynDim) {
115     first_dim = kIndex2;
116     second_dim = kIndex3;
117   }
118   for (size_t dim_idx = 0; dim_idx < profile.profiles.size(); ++dim_idx) {
119     auto &dynamic_item = profile.profiles[dim_idx].inputs[dynamic_input];
120     int64_t min_first = dynamic_item.min_dims[first_dim];
121     int64_t max_first = dynamic_item.max_dims[first_dim];
122     int64_t min_second = dynamic_item.min_dims[second_dim];
123     int64_t max_second = dynamic_item.max_dims[second_dim];
124     for (int64_t i = min_first; i <= max_first; ++i) {
125       for (int64_t j = min_second; j <= max_second; ++j) {
126         ret += std::to_string(i) + "," + std::to_string(j) + ";";
127       }
128     }
129   }
130   ret = ret.substr(0, ret.size() - 1);  // discard the final ";"
131   return ret;
132 }
133 
CombineDynamicBatchList(const struct mindspore::ProfileConfigs & profile,size_t dynamic_input)134 std::vector<size_t> CombineDynamicBatchList(const struct mindspore::ProfileConfigs &profile, size_t dynamic_input) {
135   std::vector<size_t> ret;
136   size_t batch_dim = 0;
137   for (size_t dim_idx = 0; dim_idx < profile.profiles.size(); ++dim_idx) {
138     auto &dynamic_item = profile.profiles[dim_idx].inputs[dynamic_input];
139     int64_t min = dynamic_item.min_dims[batch_dim];
140     int64_t max = dynamic_item.max_dims[batch_dim];
141     for (int64_t i = min; i <= max; ++i) {
142       ret.push_back(LongToSize(i));
143     }
144   }
145   return ret;
146 }
147 
// Returns a copy of `input_shape_str` with every '[' and ']' character removed.
std::string RemoveInputShapeBrackets(const std::string &input_shape_str) {
  std::string stripped;
  for (const char ch : input_shape_str) {
    const bool is_bracket = (ch == '[') || (ch == ']');
    if (!is_bracket) {
      stripped.push_back(ch);
    }
  }
  return stripped;
}
158 
// Looks `key` up in `ascend_map`; returns the mapped value, or an empty string when absent.
std::string FindInAscendMap(const std::string &key, const std::map<std::string, std::string> &ascend_map) {
  const auto found = ascend_map.find(key);
  return found == ascend_map.end() ? std::string() : found->second;
}
166 
SetDynParams(const std::shared_ptr<mindspore::ConverterPara> & param,const std::map<std::string,std::string> & ascend_map)167 bool SetDynParams(const std::shared_ptr<mindspore::ConverterPara> &param,
168                   const std::map<std::string, std::string> &ascend_map) {
169   struct mindspore::ProfileConfigs profile_configs;
170   if (!mindspore::ProfileParser::Parse(ascend_map, false, &profile_configs)) {
171     MS_LOG(ERROR) << "Parse input_shape and dynamic_dims failed";
172     return false;
173   }
174   const auto &input_infos = profile_configs.input_infos;
175   auto it = ascend_map.find("dynamic_dims");
176   if (it == ascend_map.end()) {
177     MS_LOG(INFO) << "Inputs are not dynamic";
178     return true;
179   }
180   std::vector<std::string> dynamic_dims_strs = mindspore::lite::SplitStringToVector(it->second, ';');
181   if (dynamic_dims_strs.size() != input_infos.size()) {
182     MS_LOG(ERROR) << "Invalid dynamic_dims, size " << dynamic_dims_strs.size() << " != input size "
183                   << input_infos.size();
184     return false;
185   }
186   std::string one_dym_dims;
187   size_t dynamic_input_index = 0;
188   for (size_t i = 0; i < input_infos.size(); i++) {
189     auto &info = input_infos[i];
190     if (!info.is_dynamic_shape) {
191       continue;
192     }
193     if (one_dym_dims.empty()) {
194       one_dym_dims = dynamic_dims_strs[i];
195       dynamic_input_index = i;
196     } else if (one_dym_dims != dynamic_dims_strs[i]) {
197       MS_LOG(ERROR) << "Do not support different dynamic_dims, one " << one_dym_dims << ", other "
198                     << dynamic_dims_strs[i];
199       return false;
200     }
201   }
202   int dynamic_type = DynBatchOrDynImage(profile_configs, dynamic_input_index);
203   switch (dynamic_type) {
204     case kDynImgSize:
205       param->aclModelOptionCfgParam.dynamic_image_size =
206         CombineDynamicImageString(profile_configs, dynamic_input_index);
207       break;
208     case kDynBatchSize:
209       param->aclModelOptionCfgParam.dynamic_batch_size = CombineDynamicBatchList(profile_configs, dynamic_input_index);
210       break;
211     default:
212       MS_LOG(ERROR) << "Do not support input shape";
213       return false;
214   }
215   return true;
216 }
217 
ParseInputShapeTemplate(const std::string & shape_template,std::set<std::string> * dynamic_symbols)218 int ParseInputShapeTemplate(const std::string &shape_template, std::set<std::string> *dynamic_symbols) {
219   // the inputs_shape config is like: input1:[d0,d1,3];input2:[4,d0]
220   auto graph_inputs_shape_vec = SplitStringToVector(shape_template, ';');
221   for (const auto &graph_input_shape : graph_inputs_shape_vec) {
222     auto graph_input_shape_info = SplitStringToVector(graph_input_shape, ':');
223     MS_CHECK_TRUE_MSG(graph_input_shape_info.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the inputs_shape is invalid");
224     auto input_shape = graph_input_shape_info[1];
225     if (input_shape[0] != '[' || input_shape[input_shape.size() - 1] != ']') {
226       MS_LOG(ERROR) << "the inputs_shape is invalid";
227       return RET_INPUT_PARAM_INVALID;
228     }
229     input_shape = input_shape.substr(1, input_shape.size() - kIndex2);
230     auto input_shape_vec = SplitStringToVector(input_shape, ',');
231     for (const auto &shape : input_shape_vec) {
232       if (!IsNumber(shape)) {
233         dynamic_symbols->insert(shape);
234       }
235     }
236   }
237   return RET_OK;
238 }
239 
ParseDynamicDimTemplate(const std::string & dims_template,std::set<std::string> * dynamic_symbols,MicroParamString * micro_param_string)240 int ParseDynamicDimTemplate(const std::string &dims_template, std::set<std::string> *dynamic_symbols,
241                             MicroParamString *micro_param_string) {
242   // the dynamic_dim_params config is like: d0:[1,3~6];d1:[1~8]
243   auto dim_info_vec = SplitStringToVector(dims_template, ';');
244   MS_CHECK_TRUE_MSG(dim_info_vec.size() <= kIndex2, RET_NOT_SUPPORT, "currently, only support to set two dynamic dims");
245   for (const auto &dim_info : dim_info_vec) {
246     auto dim_vec = SplitStringToVector(dim_info, ':');
247     MS_CHECK_TRUE_MSG(dim_vec.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the dynamic_dim_params is invalid");
248     std::string symbol = dim_vec[0];
249     if (dynamic_symbols->find(symbol) == dynamic_symbols->end()) {
250       MS_LOG(ERROR) << symbol << "is invalid, because it's not set in the inputs_shape.";
251       return RET_INPUT_PARAM_INVALID;
252     }
253     std::string dim_range = dim_vec[1];
254     if (dim_range[0] != '[' || dim_range[dim_range.size() - 1] != ']') {
255       MS_LOG(ERROR) << "the dynamic_dim_params is invalid";
256       return RET_INPUT_PARAM_INVALID;
257     }
258     dim_range = dim_range.substr(1, dim_range.size() - kIndex2);
259     auto discrete_vec = SplitStringToVector(dim_range, ',');
260     for (const auto &dim : discrete_vec) {
261       auto continuous_dim = SplitStringToVector(dim, '~');
262       MS_CHECK_TRUE_MSG(continuous_dim.size() == kIndex1 || continuous_dim.size() == kIndex2, RET_INPUT_PARAM_INVALID,
263                         "the dynamic_dim_params is invalid");
264       if (continuous_dim.size() == kIndex1) {
265         if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0) {
266           MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0";
267           return RET_INPUT_PARAM_INVALID;
268         }
269         micro_param_string->dynamic_symbols_map[symbol].emplace_back(std::stoi(continuous_dim[0]));
270         continue;
271       }
272       if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0 || !IsNumber(continuous_dim[1]) ||
273           std::stoi(continuous_dim[1]) <= 0) {
274         MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0";
275         return RET_INPUT_PARAM_INVALID;
276       }
277       auto start = std::stoi(continuous_dim[0]);
278       auto end = std::stoi(continuous_dim[1]);
279       for (auto i = start; i <= end; ++i) {
280         micro_param_string->dynamic_symbols_map[symbol].emplace_back(i);
281       }
282     }
283   }
284   return RET_OK;
285 }
286 
SetVariableParams(const std::shared_ptr<mindspore::ConverterPara> & param,const std::map<std::string,std::string> & ascend_map)287 void ConfigFileParser::SetVariableParams(const std::shared_ptr<mindspore::ConverterPara> &param,
288                                          const std::map<std::string, std::string> &ascend_map) {
289   auto it = ascend_map.find("inputs_to_variable");
290   if (it != ascend_map.end()) {
291     std::vector<std::string> inputs_to_variables = mindspore::lite::SplitStringToVector(it->second, ',');
292     ProcessVariableParam(inputs_to_variables, inputs_variable_index_);
293     if (CheckVariableParm(inputs_variable_index_) != RET_OK) {
294       MS_LOG(ERROR) << "Check input variable param failed";
295       return;
296     }
297   }
298   auto output_it = ascend_map.find("outputs_to_variable");
299   if (output_it != ascend_map.end()) {
300     std::vector<std::string> outputs_to_variables = mindspore::lite::SplitStringToVector(output_it->second, ',');
301     ProcessVariableParam(outputs_to_variables, outputs_variable_index_);
302     if (CheckVariableParm(outputs_variable_index_) != RET_OK) {
303       MS_LOG(ERROR) << "Check output variable param failed";
304       return;
305     }
306   }
307   if (!inputs_variable_index_.empty() && !outputs_variable_index_.empty() &&
308       inputs_variable_index_.size() != outputs_variable_index_.size()) {
309     MS_LOG(ERROR) << "Input variable number is not equal output variable number";
310     return;
311   }
312   param->ascendGeOptionCfg.inputs_to_variable = inputs_variable_index_;
313   param->ascendGeOptionCfg.outputs_to_variable = outputs_variable_index_;
314 }
315 
ProcessVariableParam(const std::vector<std::string> & variable_param,std::vector<int64_t> & variable_index)316 int ConfigFileParser::ProcessVariableParam(const std::vector<std::string> &variable_param,
317                                            std::vector<int64_t> &variable_index) {
318   for (auto &it : variable_param) {
319     auto remove_str = RemoveInputShapeBrackets(it);
320     int64_t min_index;
321     int64_t max_index;
322     if (!ProfileParser::ParseRangeStr(remove_str, &min_index, &max_index)) {
323       MS_LOG(ERROR) << "Parser range string " << remove_str << " failed";
324       return RET_ERROR;
325     }
326     if (max_index < min_index) {
327       MS_LOG(ERROR) << "The variable param in not valid" << max_index << "is not larger than" << min_index;
328       return RET_ERROR;
329     }
330     for (int64_t i = min_index; i <= max_index; ++i) {
331       variable_index.emplace_back(i);
332     }
333   }
334   return RET_OK;
335 }
336 
CheckVariableParm(const std::vector<int64_t> & variable_index)337 int ConfigFileParser::CheckVariableParm(const std::vector<int64_t> &variable_index) {
338   for (size_t i = 1; i < variable_index.size(); ++i) {
339     if (variable_index[i] < variable_index[i - 1]) {
340       MS_LOG(ERROR) << "variable index is not valid" << variable_index[i] << " is less than " << variable_index[i - 1];
341       return RET_ERROR;
342     }
343   }
344   return RET_OK;
345 }
346 
CheckPluginCustomOps(const std::vector<std::string> & plugin_custom_ops)347 bool ConfigFileParser::CheckPluginCustomOps(const std::vector<std::string> &plugin_custom_ops) {
348   if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), "None") != plugin_custom_ops.end() &&
349       plugin_custom_ops.size() != 1) {
350     MS_LOG(ERROR) << "plugin_custom_ops include None, can not include other param.";
351     return false;
352   }
353   return true;
354 }
355 
ParseCustomPattern(const std::shared_ptr<mindspore::ConverterPara> & param,std::string custom_pattern_str)356 STATUS ConfigFileParser::ParseCustomPattern(const std::shared_ptr<mindspore::ConverterPara> &param,
357                                             std::string custom_pattern_str) {
358   std::vector<std::string> custom_pattern_strs = mindspore::lite::SplitStringToVector(custom_pattern_str, ";");
359   for (auto custom_pattern : custom_pattern_strs) {
360     std::vector<std::string> item = mindspore::lite::SplitStringToVector(custom_pattern, ":");
361     if (item.size() != kNumSize3) {
362       return RET_ERROR;
363     }
364     std::string op_type = item[0];
365     auto names_list = mindspore::lite::SplitStringToVector(item[1], ",");
366     std::string status = item[kNumSize2];
367     if (status == "enable") {
368       if (param->aclModelOptionCfgParam.enable_custom_fusion_pattern.find(op_type) !=
369           param->aclModelOptionCfgParam.enable_custom_fusion_pattern.end()) {
370         MS_LOG(ERROR) << op_type << " has define, can not defined repeat.";
371         return RET_ERROR;
372       }
373       param->aclModelOptionCfgParam.enable_custom_fusion_pattern[op_type] = names_list;
374     } else if (status == "disable") {
375       if (param->aclModelOptionCfgParam.disable_custom_fusion_pattern.find(op_type) !=
376           param->aclModelOptionCfgParam.disable_custom_fusion_pattern.end()) {
377         MS_LOG(ERROR) << op_type << " has define, can not defined repeat.";
378         return RET_ERROR;
379       }
380       param->aclModelOptionCfgParam.disable_custom_fusion_pattern[op_type] = names_list;
381     } else {
382       MS_LOG(ERROR) << "status only support enable or disable";
383       return RET_ERROR;
384     }
385   }
386   return RET_OK;
387 }
388 
SetParamByConfigfile(const std::shared_ptr<mindspore::ConverterPara> & param,const std::map<std::string,std::string> & ascend_map)389 bool ConfigFileParser::SetParamByConfigfile(const std::shared_ptr<mindspore::ConverterPara> &param,
390                                             const std::map<std::string, std::string> &ascend_map) {
391   std::string ascend_string = "";
392   auto set_option = [&ascend_map](const std::string &key, std::string *option) {
393     auto it = ascend_map.find(key);
394     if (it != ascend_map.end() && !it->second.empty()) {
395       *option = it->second;
396     }
397   };
398   set_option("input_format", &param->aclModelOptionCfgParam.input_format);
399   set_option("precision_mode", &param->aclModelOptionCfgParam.precision_mode);
400   set_option("op_select_impl_mode", &param->aclModelOptionCfgParam.op_select_impl_mode);
401   set_option("fusion_switch_config_file_path", &param->aclModelOptionCfgParam.fusion_switch_config_file_path);
402   set_option("buffer_optimize", &param->aclModelOptionCfgParam.buffer_optimize);
403   set_option("insert_op_config_file_path", &param->aclModelOptionCfgParam.insert_op_config_file_path);
404   set_option("om_file_path", &param->aclModelOptionCfgParam.om_file_path);
405   set_option("aoe_mode", &param->aclModelOptionCfgParam.aoe_mode);
406   set_option(kDumpModelNameKey, &param->aclModelOptionCfgParam.dump_model_name);
407   set_option("provider", &param->provider);
408 
409   auto plugin_custom_ops_str = FindInAscendMap(kPluginCustomOps, ascend_map);
410   std::vector<std::string> plugin_custom_ops_vec = {};
411   if (!plugin_custom_ops_str.empty()) {
412     MS_LOG(INFO) << "plugin_custom_ops: " << plugin_custom_ops_str;
413     plugin_custom_ops_vec = mindspore::lite::SplitStringToVector(plugin_custom_ops_str, ",");
414     if (!CheckPluginCustomOps(plugin_custom_ops_vec)) {
415       return false;
416     }
417   }
418   if (!plugin_custom_ops_vec.empty()) {
419     param->ascendGeOptionCfg.plugin_custom_ops = plugin_custom_ops_vec;
420   } else if (!(ascend_string = FindInAscendMap(kEnableCustomOp, ascend_map)).empty()) {
421     param->ascendGeOptionCfg.plugin_custom_ops = {"All"};
422   }
423   // parse for ascend custom fusion op
424   if (!plugin_custom_ops_vec.empty()) {
425     param->aclModelOptionCfgParam.plugin_custom_ops = plugin_custom_ops_vec;
426   }
427   auto custom_fusion_pattern_str = FindInAscendMap("custom_fusion_pattern", ascend_map);
428   if (!custom_fusion_pattern_str.empty()) {
429     auto status = ParseCustomPattern(param, custom_fusion_pattern_str);
430     if (status != RET_OK) {
431       MS_LOG(ERROR) << "custom fusion pattern wrong, eg:\n"
432                        "custom_fusion_pattern=Fusion_op_type:node_name_1,node_name_2:enable\n"
433                        "or: "
434                        "custom_fusion_pattern=Fusion_op_type:node_name_1,node_name_2:disable";
435       return false;
436     }
437   }
438 
439   auto op_attrs_str = FindInAscendMap(kOpAttrs, ascend_map);
440   std::vector<std::string> op_attrs_vec = {};
441   if (!op_attrs_str.empty()) {
442     MS_LOG(INFO) << "op_attrs_str: " << op_attrs_str;
443     op_attrs_vec = mindspore::lite::SplitStringToVector(op_attrs_str, ";");
444     std::map<std::string, std::string> attr;
445     for (auto op_attr_str : op_attrs_vec) {
446       MS_LOG(INFO) << "op_attr: " << op_attr_str;
447       auto op_attr = mindspore::lite::SplitStringToVector(op_attr_str, ":");
448       if (op_attr.size() != kNumSize3) {
449         MS_LOG(ERROR) << "Only support ops:attr:value, but get " << op_attr_str;
450         return false;
451       }
452       auto op_type = op_attr[kNumIndex0];
453       auto attr_key = op_attr[kNumIndex1];
454       auto attr_value = op_attr[kNumIndex2];
455       param->aclModelOptionCfgParam.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
456       param->ascendGeOptionCfg.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
457     }
458   }
459   for (auto item : param->aclModelOptionCfgParam.op_attrs_map) {
460     for (auto attr : item.second) {
461       MS_LOG(INFO) << "op type: " << item.first << ", key: " << attr.first << ", value: " << attr.second;
462     }
463   }
464 
465   auto it = ascend_map.find("input_shape");
466   if (it != ascend_map.end()) {
467     param->aclModelOptionCfgParam.input_shape = RemoveInputShapeBrackets(it->second);
468   }
469 
470   it = ascend_map.find("device_id");
471   if (it != ascend_map.end()) {
472     int32_t val;
473     if (mindspore::lite::ConvertIntNum(it->second, &val)) {
474       param->aclModelOptionCfgParam.device_id = val;
475     } else {
476       MS_LOG(ERROR) << "Convert device id failed";
477       return false;
478     }
479   }
480 
481   it = ascend_map.find("output_type");
482   if (it != ascend_map.end()) {
483     auto dtype_str = it->second;
484     if (dtype_str == "FP16") {
485       param->aclModelOptionCfgParam.output_type = DataType::kNumberTypeFloat16;
486     } else if (dtype_str == "FP32") {
487       param->aclModelOptionCfgParam.output_type = DataType::kNumberTypeFloat32;
488     } else if (dtype_str == "UINT8") {
489       param->aclModelOptionCfgParam.output_type = DataType::kNumberTypeUInt8;
490     } else {
491       MS_LOG(WARNING) << "Unsupported or invalid output_type, using default type";
492     }
493   }
494   SetVariableParams(param, ascend_map);
495   return SetDynParams(param, ascend_map);
496 }
497 
ParseConfigFile(const std::string & config_file_path,std::map<int,std::map<std::string,std::string>> * model_param_infos)498 int ConfigFileParser::ParseConfigFile(const std::string &config_file_path,
499                                       std::map<int, std::map<std::string, std::string>> *model_param_infos) {
500   std::map<std::string, std::map<std::string, std::string>> maps;
501   auto ret = mindspore::lite::ParseConfigFile(config_file_path, &maps, model_param_infos);
502   if (ret != RET_OK) {
503     MS_LOG(ERROR) << "Parse config file failed.";
504     return ret;
505   }
506   ret = ParseConfigParam(&maps);
507   if (ret != RET_OK) {
508     MS_LOG(ERROR) << "Parse config param failed.";
509     return ret;
510   }
511   return RET_OK;
512 }
513 
ParseConfigParam(std::map<std::string,std::map<std::string,std::string>> * maps)514 int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::string, std::string>> *maps) {
515   if (maps == nullptr) {
516     MS_LOG(ERROR) << "Maps is nullptr.";
517     return RET_ERROR;
518   }
519   for (const auto &config_info : *maps) {
520     ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
521   }
522   auto ret = ParseDataPreProcessString(*maps);
523   (void)maps->erase(kDataPreprocessParam);
524   if (ret != RET_OK) {
525     MS_LOG(ERROR) << "ParseDataPreProcessString failed.";
526     return ret;
527   }
528   ret = ParseCommonQuantString(*maps);
529   (void)maps->erase(kCommonQuantParam);
530   if (ret != RET_OK) {
531     MS_LOG(ERROR) << "ParseCommonQuantString failed.";
532     return ret;
533   }
534   ret = ParseMixedBitQuantString(*maps);
535   (void)maps->erase(kMixedBitWeightQuantParam);
536   if (ret != RET_OK) {
537     MS_LOG(ERROR) << "ParseMixedBitQuantString failed.";
538     return ret;
539   }
540   ret = ParseFullQuantString(*maps);
541   (void)maps->erase(kFullQuantParam);
542   if (ret != RET_OK) {
543     MS_LOG(ERROR) << "ParseFullQuantString failed.";
544     return ret;
545   }
546   ret = ParseRegistryInfoString(*maps);
547   (void)maps->erase(kRegistry);
548   if (ret != RET_OK) {
549     MS_LOG(ERROR) << "ParseExtendedintegrationString failed.";
550     return ret;
551   }
552   ret = ParseAclOptionCfgString(*maps);
553   (void)maps->erase(kAclOptionParam);
554   if (ret != RET_OK) {
555     MS_LOG(ERROR) << "ParseAclOptionCfgString failed.";
556     return ret;
557   }
558   ret = ParseMicroParamString(*maps);
559   (void)maps->erase(kMicroParam);
560   if (ret != RET_OK) {
561     MS_LOG(ERROR) << "ParseMicroParamString failed.";
562     return ret;
563   }
564   ret = ParseThirdPartyParamString(*maps);
565   (void)maps->erase(kThirdPartyModelParam);
566   if (ret != RET_OK) {
567     MS_LOG(ERROR) << "ParseTransformQuantString failed.";
568     return ret;
569   }
570   ret = ParseWeightQuantString(*maps);
571   (void)maps->erase(kWeightQuantParam);
572   if (ret != RET_OK) {
573     MS_LOG(ERROR) << "ParseWeightQuantString failed.";
574     return ret;
575   }
576   ret = ParseCpuOptionCfgString(*maps);
577   (void)maps->erase(kCpuOptionParam);
578   if (ret != RET_OK) {
579     MS_LOG(ERROR) << "ParseCpuOptionCfgString failed.";
580     return ret;
581   }
582   ret = ParseTransformQuantString(*maps);
583   (void)maps->erase(kTransformQuantParam);
584   if (ret != RET_OK) {
585     MS_LOG(ERROR) << "ParseTransformQuantString failed.";
586     return ret;
587   }
588   ret = ParseDynamicQuantString(*maps);
589   (void)maps->erase(kDynamicQuantParam);
590   if (ret != RET_OK) {
591     MS_LOG(ERROR) << "ParseDynamicQuantString failed.";
592     return ret;
593   }
594   (void)ParseGraphKernelString(*maps);
595   (void)maps->erase(kGraphKernelParam);
596   ret = ParseOMConverterString(*maps);
597   (void)maps->erase(kOMConverterOptionsSection);
598   if (ret != RET_OK) {
599     MS_LOG(ERROR) << "ParseOMConverterString failed.";
600     return ret;
601   }
602   return RET_OK;
603 }
604 
SetMapData(const std::map<std::string,std::string> & input_map,const std::map<std::string,std::string &> & parse_map,const std::string & section,const std::set<std::string> & dynamic_key)605 int ConfigFileParser::SetMapData(const std::map<std::string, std::string> &input_map,
606                                  const std::map<std::string, std::string &> &parse_map, const std::string &section,
607                                  const std::set<std::string> &dynamic_key) {
608   for (const auto &map : input_map) {
609     if (dynamic_key.find(map.first) != dynamic_key.end()) {
610       continue;
611     }
612     if (parse_map.find(map.first) == parse_map.end()) {
613       MS_LOG(ERROR) << "INPUT ILLEGAL: `" << map.first << "` is not supported in "
614                     << "[" << section << "]";
615       return RET_INPUT_PARAM_INVALID;
616     } else {
617       parse_map.at(map.first) = map.second;
618     }
619   }
620   return RET_OK;
621 }
622 
ParseDataPreProcessString(const std::map<std::string,std::map<std::string,std::string>> & maps)623 int ConfigFileParser::ParseDataPreProcessString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
624   if (maps.find(kDataPreprocessParam) != maps.end()) {
625     const auto &map = maps.at(kDataPreprocessParam);
626     std::map<std::string, std::string &> parse_map{
627       {"calibrate_path", data_pre_process_string_.calibrate_path},
628       {"calibrate_size", data_pre_process_string_.calibrate_size},
629       {"input_type", data_pre_process_string_.input_type},
630       {"image_to_format", data_pre_process_string_.image_to_format},
631       {"normalize_mean", data_pre_process_string_.normalize_mean},
632       {"normalize_std", data_pre_process_string_.normalize_std},
633       {"resize_width", data_pre_process_string_.resize_width},
634       {"resize_height", data_pre_process_string_.resize_height},
635       {"resize_method", data_pre_process_string_.resize_method},
636       {"center_crop_width", data_pre_process_string_.center_crop_width},
637       {"center_crop_height", data_pre_process_string_.center_crop_height},
638     };
639     return SetMapData(map, parse_map, kDataPreprocessParam);
640   }
641   return RET_OK;
642 }
643 
ParseCommonQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)644 int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
645   if (maps.find(kCommonQuantParam) != maps.end()) {
646     const auto &map = maps.at(kCommonQuantParam);
647     std::map<std::string, std::string &> parse_map{
648       {"quant_type", common_quant_string_.quant_type},
649       {"bit_num", common_quant_string_.bit_num},
650       {"min_quant_weight_size", common_quant_string_.min_quant_weight_size},
651       {"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
652       {"skip_quant_node", common_quant_string_.skip_quant_node},
653       {"debug_info_save_path", common_quant_string_.debug_info_save_path},
654       {"enable_encode", common_quant_string_.enable_encode},
655       {"workspace", common_quant_string_.workspace},
656     };
657     return SetMapData(map, parse_map, kCommonQuantParam);
658   }
659   return RET_OK;
660 }
661 
ParseMixedBitQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)662 int ConfigFileParser::ParseMixedBitQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
663   if (maps.find(kMixedBitWeightQuantParam) != maps.end()) {
664     const auto &map = maps.at(kMixedBitWeightQuantParam);
665     std::map<std::string, std::string &> parse_map{
666       {"init_scale", mixed_bit_quant_string_.init_scale},
667       {"auto_tune", mixed_bit_quant_string_.auto_tune},
668       {"use_cv_data", mixed_bit_quant_string_.use_cv_data},
669       {"max_iterations", mixed_bit_quant_string_.max_iterations},
670     };
671     return SetMapData(map, parse_map, kMixedBitWeightQuantParam);
672   }
673   return RET_OK;
674 }
675 
ParseFullQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)676 int ConfigFileParser::ParseFullQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
677   if (maps.find(kFullQuantParam) != maps.end()) {
678     const auto &map = maps.at(kFullQuantParam);
679     std::map<std::string, std::string &> parse_map{
680       {"activation_quant_method", full_quant_string_.activation_quant_method},
681       {"bias_correction", full_quant_string_.bias_correction},
682       {"target_device", full_quant_string_.target_device},
683       {"per_channel", full_quant_string_.per_channel},
684       {"smooth_alpha", full_quant_string_.smooth_alpha},
685       {"enable_smooth_shift", full_quant_string_.enable_smooth_shift},
686     };
687     return SetMapData(map, parse_map, kFullQuantParam);
688   }
689   return RET_OK;
690 }
691 
ParseRegistryInfoString(const std::map<std::string,std::map<std::string,std::string>> & maps)692 int ConfigFileParser::ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
693   if (maps.find(kRegistry) != maps.end()) {
694     const auto &map = maps.at(kRegistry);
695     std::map<std::string, std::string &> parse_map{
696       {"plugin_path", registry_info_string_.plugin_path},
697       {"disable_fusion", registry_info_string_.disable_fusion},
698       {"fusion_blacklists", registry_info_string_.fusion_blacklists},
699     };
700     return SetMapData(map, parse_map, kRegistry);
701   }
702   return RET_OK;
703 }
704 
ParseAclOptionCfgString(const std::map<std::string,std::map<std::string,std::string>> & maps)705 int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
706   if (maps.find(kAclOptionParam) != maps.end()) {
707     const auto &map = maps.at(kAclOptionParam);
708     std::map<std::string, std::string &> parse_map{
709       {"device_id", acl_option_cfg_string_.device_id},
710       {"input_format", acl_option_cfg_string_.input_format},
711       {"input_shape_vector", acl_option_cfg_string_.input_shape_vector},
712       {"input_shape", acl_option_cfg_string_.input_shape},
713       {"output_type", acl_option_cfg_string_.output_type},
714       {"precision_mode", acl_option_cfg_string_.precision_mode},
715       {"op_select_impl_mode", acl_option_cfg_string_.op_select_impl_mode},
716       {"fusion_switch_config_file_path", acl_option_cfg_string_.fusion_switch_config_file_path},
717       {"dynamic_batch_size", acl_option_cfg_string_.dynamic_batch_size},
718       {"buffer_optimize", acl_option_cfg_string_.buffer_optimize},
719       {"insert_op_config_file_path", acl_option_cfg_string_.insert_op_config_file_path},
720       {"dynamic_image_size", acl_option_cfg_string_.dynamic_image_size},
721       {"dynamic_dims", acl_option_cfg_string_.dynamic_dims},
722       {"aoe_mode", acl_option_cfg_string_.aoe_mode},
723       {"custom_opp_path", acl_option_cfg_string_.custom_opp_path}};
724     auto ret = SetMapData(map, parse_map, kAclOptionParam);
725     if (ret != RET_OK) {
726       MS_LOG(ERROR) << "set map data failed.";
727       return ret;
728     }
729   }
730   if (maps.find(kAclInitOptionParam) != maps.end()) {
731     const auto &map = maps.at(kAclInitOptionParam);
732     for (const auto &item : map) {
733       (void)acl_option_cfg_string_.init_options_map.emplace(item.first, item.second);
734     }
735   }
736   if (maps.find(kAclBuildOptionParam) != maps.end()) {
737     const auto &map = maps.at(kAclBuildOptionParam);
738     for (const auto &item : map) {
739       (void)acl_option_cfg_string_.build_options_map.emplace(item.first, item.second);
740     }
741   }
742   if (maps.find(kAoeGlobalOptionsSection) != maps.end()) {
743     const auto &map = maps.at(kAoeGlobalOptionsSection);
744     for (const auto &item : map) {
745       (void)acl_option_cfg_string_.aoe_global_options_map.emplace(item.first, item.second);
746     }
747   }
748   if (maps.find(kAoeTuningOptionsSection) != maps.end()) {
749     const auto &map = maps.at(kAoeTuningOptionsSection);
750     for (const auto &item : map) {
751       (void)acl_option_cfg_string_.aoe_tuning_options_map.emplace(item.first, item.second);
752     }
753   }
754   return RET_OK;
755 }
756 
ParseMicroParamString(const std::map<std::string,std::map<std::string,std::string>> & maps)757 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
758   if (maps.find(kMicroParam) == maps.end()) {
759     return RET_OK;
760   }
761   const auto &map = maps.at(kMicroParam);
762   const std::string graph_inputs_shape_template = "inputs_shape";
763   std::set<std::string> dynamic_symbols;
764   if (map.find(graph_inputs_shape_template) != map.end()) {
765     const auto &shape_template = map.at(graph_inputs_shape_template);
766     MS_CHECK_TRUE_MSG(ParseInputShapeTemplate(shape_template, &dynamic_symbols) == RET_OK, RET_ERROR,
767                       "ParseInputShapeTemplate failed");
768   }
769   const std::string dynamic_dims = "dynamic_dim_params";
770   if (!dynamic_symbols.empty() && map.find(dynamic_dims) != map.end()) {
771     const auto &dims_template = map.at(dynamic_dims);
772     MS_CHECK_TRUE_MSG(ParseDynamicDimTemplate(dims_template, &dynamic_symbols, &micro_param_string_) == RET_OK,
773                       RET_ERROR, "ParseDynamicDimTemplate failed");
774   }
775   std::map<std::string, std::string &> parse_map{
776     {"target", micro_param_string_.target},
777     {"codegen_mode", micro_param_string_.codegen_mode},
778     {"debug_mode", micro_param_string_.debug_mode},
779     {"support_parallel", micro_param_string_.support_parallel},
780     {"enable_micro", micro_param_string_.enable_micro},
781     {"save_path", micro_param_string_.save_path},
782     {"project_name", micro_param_string_.project_name},
783     {"keep_original_weight", micro_param_string_.keep_original_weight},
784     {"changeable_weights_name", micro_param_string_.changeable_weights_name},
785     {"inputs_shape", micro_param_string_.inputs_shape},
786     {"dynamic_dim_params", micro_param_string_.dynamic_dim_params}};
787   return SetMapData(map, parse_map, kMicroParam);
788 }
789 
ParseWeightQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)790 int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
791   if (maps.find(kWeightQuantParam) != maps.end()) {
792     const auto &map = maps.at(kWeightQuantParam);
793     std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy},
794                                                    {"update_mindir", weight_quant_string_.update_mindir},
795                                                    {"max_segments", weight_quant_string_.max_segments},
796                                                    {"per_channel", weight_quant_string_.per_channel},
797                                                    {"bias_correction", weight_quant_string_.bias_correction},
798                                                    {"quant_strategy", weight_quant_string_.quant_strategy}};
799     return SetMapData(map, parse_map, kWeightQuantParam);
800   }
801   return RET_OK;
802 }
803 
ParseCpuOptionCfgString(const std::map<std::string,std::map<std::string,std::string>> & maps)804 int ConfigFileParser::ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
805   if (maps.find(kCpuOptionParam) != maps.end()) {
806     const auto &map = maps.at(kCpuOptionParam);
807     std::map<std::string, std::string &> parse_map{{"architecture", cpu_option_cfg_string_.architecture},
808                                                    {"instruction", cpu_option_cfg_string_.instruction}};
809     auto ret = SetMapData(map, parse_map, kCpuOptionParam);
810     if (cpu_option_cfg_string_.architecture.empty() || cpu_option_cfg_string_.instruction.empty()) {
811       MS_LOG(WARNING) << "[cpu_option_cfg_param] set incompletely, the model won't do optimize for cpu, please "
812                          "check the parameter architecture and instruction are correct.";
813     }
814     return ret;
815   }
816   return RET_OK;
817 }
818 
ParseTransformQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)819 int ConfigFileParser::ParseTransformQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
820   if (maps.find(kTransformQuantParam) != maps.end()) {
821     const auto &map = maps.at(kTransformQuantParam);
822     std::map<std::string, std::string &> parse_map{
823       {"export_precision_mode", transform_quant_string_.export_precision_mode},
824     };
825     return SetMapData(map, parse_map, kTransformQuantParam);
826   }
827   return RET_OK;
828 }
829 
ParseDynamicQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)830 int ConfigFileParser::ParseDynamicQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
831   if (maps.find(kDynamicQuantParam) != maps.end()) {
832     const auto &map = maps.at(kDynamicQuantParam);
833     std::map<std::string, std::string &> parse_map{
834       {"quant_strategy", dynamic_quant_string_.quant_strategy},
835     };
836     return SetMapData(map, parse_map, kDynamicQuantParam);
837   }
838   return RET_OK;
839 }
840 
ParseGraphKernelString(const std::map<std::string,std::map<std::string,std::string>> & maps)841 int ConfigFileParser::ParseGraphKernelString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
842   if (maps.find(kGraphKernelParam) != maps.end()) {
843     const auto &map = maps.at(kGraphKernelParam);
844     for (const auto &item : map) {
845       std::stringstream oss;
846       oss << "--" << item.first << "=" << item.second;
847       (void)graph_kernel_string_.emplace_back(oss.str());
848     }
849   }
850   return RET_OK;
851 }
852 
ParseOMConverterString(const std::map<std::string,std::map<std::string,std::string>> & maps)853 int ConfigFileParser::ParseOMConverterString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
854   if (maps.find(kOMConverterOptionsSection) != maps.end()) {
855     const auto &map = maps.at(kOMConverterOptionsSection);
856     std::map<std::string, std::string &> parse_map{
857       {"input_name_vector", om_converter_string_.input_name_vector},
858       {"input_shape_vector", om_converter_string_.input_shape_vector},
859       {"input_data_type_vector", om_converter_string_.input_data_type_vector},
860       {"output_name_vector", om_converter_string_.output_name_vector},
861       {"output_shape_vector", om_converter_string_.output_shape_vector},
862       {"output_data_type_vector", om_converter_string_.output_data_type_vector}};
863     auto ret = SetMapData(map, parse_map, kOMConverterOptionsSection);
864     if (ret != RET_OK) {
865       MS_LOG(ERROR) << "Set map data failed.";
866       return ret;
867     }
868   }
869   return RET_OK;
870 }
871 
ParseThirdPartyParamString(const std::map<std::string,std::map<std::string,std::string>> & sections)872 int ConfigFileParser::ParseThirdPartyParamString(
873   const std::map<std::string, std::map<std::string, std::string>> &sections) {
874   if (sections.find(kThirdPartyModelParam) == sections.end()) {
875     return RET_OK;
876   }
877   const auto &input_args = sections.at(kThirdPartyModelParam);
878   const std::map<std::string, std::string &> kValidArgs = {
879     {"input_shapes", third_party_model_string_.input_shapes},
880     {"input_dtypes", third_party_model_string_.input_dtypes},
881     {"input_names", third_party_model_string_.input_names},
882     {"input_formats", third_party_model_string_.input_formats},
883     {"output_shapes", third_party_model_string_.output_shapes},
884     {"output_dtypes", third_party_model_string_.output_dtypes},
885     {"output_names", third_party_model_string_.output_names},
886     {"output_formats", third_party_model_string_.output_formats},
887     {"extended_parameters", third_party_model_string_.extended_parameters},
888   };
889   return SetMapData(input_args, kValidArgs, kThirdPartyModelParam);
890 }
891 }  // namespace lite
892 }  // namespace mindspore
893