1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/config_parser/config_file_parser.h"
18 #include <set>
19 #include "tools/common/parse_config_utils.h"
20 #include "include/errorcode.h"
21 #include "src/common/log_adapter.h"
22 #include "tools/converter/converter_context.h"
23 #include "tools/common/string_util.h"
24 #include "src/common/config_infos.h"
25 #include "src/common/common.h"
26 #include "nnacl/op_base.h"
27
28 namespace mindspore {
29 namespace lite {
namespace {
// String keys used when reading the converter config file; most of them are section names that
// ConfigFileParser::ParseConfigParam looks up and then erases from the parsed maps.
constexpr const char *kCommonQuantParam = "common_quant_param";
constexpr const char *kFullQuantParam = "full_quant_param";
constexpr const char *kWeightQuantParam = "weight_quant_param";
constexpr const char *kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
constexpr const char *kDataPreprocessParam = "data_preprocess_param";
constexpr const char *kRegistry = "registry";
constexpr const char *kMicroParam = "micro_param";
constexpr const char *kThirdPartyModelParam = "third_party_model";
constexpr const char *kCpuOptionParam = "cpu_option_cfg_param";
constexpr const char *kCustomOppPath = "custom_opp_path";
constexpr const char *kTransformQuantParam = "transform_quant_param";
constexpr const char *kDynamicQuantParam = "dynamic_quant_param";
constexpr const char *kGraphKernelParam = "graph_kernel_param";

// kNumSize3 is the field count expected from ':'-separated triples ("a:b:c");
// kNumSize2 is also used as the index of the last field of such a triple.
constexpr int kNumSize3 = 3;
constexpr int kNumSize2 = 2;

// Field positions inside an "op:attr:value" entry.
constexpr size_t kNumIndex0 = 0;
constexpr size_t kNumIndex1 = 1;
constexpr size_t kNumIndex2 = 2;
}  // namespace
// Shape stored as a list of 64-bit dimension values.
using ShapeVector = std::vector<int64_t>;
constexpr int kBatchDim = 0;
// Result codes of DynBatchOrDynImage: which of the two supported dynamic modes was detected.
constexpr int kDynImgSize = 0;
constexpr int kDynBatchSize = 1;
// Returns true only when `batch_str_vec` is non-empty and every entry is equal to the first one.
bool CheckBatchStringSupport(const std::vector<std::string> &batch_str_vec) {
  if (batch_str_vec.empty()) {
    return false;
  }
  const std::string &reference = batch_str_vec.front();
  for (auto it = batch_str_vec.begin() + 1; it != batch_str_vec.end(); ++it) {
    if (*it != reference) {
      return false;
    }
  }
  return true;
}
66
// Small index constants; kIndex1/kIndex2 are also compared against element counts in this file.
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
// Value that marks a dimension of an input shape as dynamic.
constexpr int64_t kdynDim = -1;
DynBatchOrDynImage(const mindspore::ProfileConfigs & profile,size_t dynamic_input_index)72 int DynBatchOrDynImage(const mindspore::ProfileConfigs &profile, size_t dynamic_input_index) {
73 int dynamic_type = -1;
74 for (auto &info : profile.input_infos) {
75 if (!info.is_dynamic_shape) {
76 continue;
77 }
78 const auto &shape = info.input_shape;
79 if (shape.size() != kNCHWDimNumber) {
80 MS_LOG(ERROR) << "Dynamic input whose shape is not 4-dimensional is not supported, input shape: " << shape;
81 return -1;
82 }
83 auto dynamic_dims_count = std::count_if(shape.begin(), shape.end(), [](int64_t dim) { return dim == kdynDim; });
84 if (shape[kIndex0] != kdynDim && dynamic_dims_count == kHWDimNumber) {
85 if (dynamic_type != -1 && dynamic_type != kDynImgSize) {
86 MS_LOG(ERROR) << "Only dynamic batch or dynamic image size is supported, hybrid scenarios are not supported";
87 return -1;
88 }
89 dynamic_type = kDynImgSize;
90 } else if (shape[kIndex0] == kdynDim && dynamic_dims_count == 1) {
91 if (dynamic_type != -1 && dynamic_type != kDynBatchSize) {
92 MS_LOG(ERROR) << "Only dynamic batch or dynamic image size is supported, hybrid scenarios are not supported";
93 return -1;
94 }
95 dynamic_type = kDynBatchSize;
96 } else {
97 MS_LOG(ERROR) << "Only dynamic batch or dynamic image size is supported, input shape: " << shape;
98 return -1;
99 }
100 }
101 return dynamic_type;
102 }
103
CombineDynamicImageString(const struct mindspore::ProfileConfigs & profile,size_t dynamic_input)104 std::string CombineDynamicImageString(const struct mindspore::ProfileConfigs &profile, size_t dynamic_input) {
105 ShapeVector shape = profile.input_infos[dynamic_input].input_shape;
106 std::string ret = "";
107 size_t first_dim = kIndex0, second_dim = kIndex0;
108 if (shape[kIndex1] == kdynDim && shape[kIndex2] == kdynDim) {
109 first_dim = kIndex1;
110 second_dim = kIndex2;
111 } else if (shape[kIndex1] == kdynDim && shape[kIndex3] == kdynDim) {
112 first_dim = kIndex1;
113 second_dim = kIndex3;
114 } else if (shape[kIndex2] == kdynDim && shape[kIndex3] == kdynDim) {
115 first_dim = kIndex2;
116 second_dim = kIndex3;
117 }
118 for (size_t dim_idx = 0; dim_idx < profile.profiles.size(); ++dim_idx) {
119 auto &dynamic_item = profile.profiles[dim_idx].inputs[dynamic_input];
120 int64_t min_first = dynamic_item.min_dims[first_dim];
121 int64_t max_first = dynamic_item.max_dims[first_dim];
122 int64_t min_second = dynamic_item.min_dims[second_dim];
123 int64_t max_second = dynamic_item.max_dims[second_dim];
124 for (int64_t i = min_first; i <= max_first; ++i) {
125 for (int64_t j = min_second; j <= max_second; ++j) {
126 ret += std::to_string(i) + "," + std::to_string(j) + ";";
127 }
128 }
129 }
130 ret = ret.substr(0, ret.size() - 1); // discard the final ";"
131 return ret;
132 }
133
CombineDynamicBatchList(const struct mindspore::ProfileConfigs & profile,size_t dynamic_input)134 std::vector<size_t> CombineDynamicBatchList(const struct mindspore::ProfileConfigs &profile, size_t dynamic_input) {
135 std::vector<size_t> ret;
136 size_t batch_dim = 0;
137 for (size_t dim_idx = 0; dim_idx < profile.profiles.size(); ++dim_idx) {
138 auto &dynamic_item = profile.profiles[dim_idx].inputs[dynamic_input];
139 int64_t min = dynamic_item.min_dims[batch_dim];
140 int64_t max = dynamic_item.max_dims[batch_dim];
141 for (int64_t i = min; i <= max; ++i) {
142 ret.push_back(LongToSize(i));
143 }
144 }
145 return ret;
146 }
147
// Returns a copy of `input_shape_str` with every '[' and ']' character removed.
std::string RemoveInputShapeBrackets(const std::string &input_shape_str) {
  std::string stripped;
  for (const char ch : input_shape_str) {
    if (ch != '[' && ch != ']') {
      stripped.push_back(ch);
    }
  }
  return stripped;
}
158
// Looks up `key` in `ascend_map`; returns the mapped value, or an empty string when the key is absent.
std::string FindInAscendMap(const std::string &key, const std::map<std::string, std::string> &ascend_map) {
  const auto found = ascend_map.find(key);
  return found == ascend_map.end() ? std::string() : found->second;
}
166
SetDynParams(const std::shared_ptr<mindspore::ConverterPara> & param,const std::map<std::string,std::string> & ascend_map)167 bool SetDynParams(const std::shared_ptr<mindspore::ConverterPara> ¶m,
168 const std::map<std::string, std::string> &ascend_map) {
169 struct mindspore::ProfileConfigs profile_configs;
170 if (!mindspore::ProfileParser::Parse(ascend_map, false, &profile_configs)) {
171 MS_LOG(ERROR) << "Parse input_shape and dynamic_dims failed";
172 return false;
173 }
174 const auto &input_infos = profile_configs.input_infos;
175 auto it = ascend_map.find("dynamic_dims");
176 if (it == ascend_map.end()) {
177 MS_LOG(INFO) << "Inputs are not dynamic";
178 return true;
179 }
180 std::vector<std::string> dynamic_dims_strs = mindspore::lite::SplitStringToVector(it->second, ';');
181 if (dynamic_dims_strs.size() != input_infos.size()) {
182 MS_LOG(ERROR) << "Invalid dynamic_dims, size " << dynamic_dims_strs.size() << " != input size "
183 << input_infos.size();
184 return false;
185 }
186 std::string one_dym_dims;
187 size_t dynamic_input_index = 0;
188 for (size_t i = 0; i < input_infos.size(); i++) {
189 auto &info = input_infos[i];
190 if (!info.is_dynamic_shape) {
191 continue;
192 }
193 if (one_dym_dims.empty()) {
194 one_dym_dims = dynamic_dims_strs[i];
195 dynamic_input_index = i;
196 } else if (one_dym_dims != dynamic_dims_strs[i]) {
197 MS_LOG(ERROR) << "Do not support different dynamic_dims, one " << one_dym_dims << ", other "
198 << dynamic_dims_strs[i];
199 return false;
200 }
201 }
202 int dynamic_type = DynBatchOrDynImage(profile_configs, dynamic_input_index);
203 switch (dynamic_type) {
204 case kDynImgSize:
205 param->aclModelOptionCfgParam.dynamic_image_size =
206 CombineDynamicImageString(profile_configs, dynamic_input_index);
207 break;
208 case kDynBatchSize:
209 param->aclModelOptionCfgParam.dynamic_batch_size = CombineDynamicBatchList(profile_configs, dynamic_input_index);
210 break;
211 default:
212 MS_LOG(ERROR) << "Do not support input shape";
213 return false;
214 }
215 return true;
216 }
217
ParseInputShapeTemplate(const std::string & shape_template,std::set<std::string> * dynamic_symbols)218 int ParseInputShapeTemplate(const std::string &shape_template, std::set<std::string> *dynamic_symbols) {
219 // the inputs_shape config is like: input1:[d0,d1,3];input2:[4,d0]
220 auto graph_inputs_shape_vec = SplitStringToVector(shape_template, ';');
221 for (const auto &graph_input_shape : graph_inputs_shape_vec) {
222 auto graph_input_shape_info = SplitStringToVector(graph_input_shape, ':');
223 MS_CHECK_TRUE_MSG(graph_input_shape_info.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the inputs_shape is invalid");
224 auto input_shape = graph_input_shape_info[1];
225 if (input_shape[0] != '[' || input_shape[input_shape.size() - 1] != ']') {
226 MS_LOG(ERROR) << "the inputs_shape is invalid";
227 return RET_INPUT_PARAM_INVALID;
228 }
229 input_shape = input_shape.substr(1, input_shape.size() - kIndex2);
230 auto input_shape_vec = SplitStringToVector(input_shape, ',');
231 for (const auto &shape : input_shape_vec) {
232 if (!IsNumber(shape)) {
233 dynamic_symbols->insert(shape);
234 }
235 }
236 }
237 return RET_OK;
238 }
239
ParseDynamicDimTemplate(const std::string & dims_template,std::set<std::string> * dynamic_symbols,MicroParamString * micro_param_string)240 int ParseDynamicDimTemplate(const std::string &dims_template, std::set<std::string> *dynamic_symbols,
241 MicroParamString *micro_param_string) {
242 // the dynamic_dim_params config is like: d0:[1,3~6];d1:[1~8]
243 auto dim_info_vec = SplitStringToVector(dims_template, ';');
244 MS_CHECK_TRUE_MSG(dim_info_vec.size() <= kIndex2, RET_NOT_SUPPORT, "currently, only support to set two dynamic dims");
245 for (const auto &dim_info : dim_info_vec) {
246 auto dim_vec = SplitStringToVector(dim_info, ':');
247 MS_CHECK_TRUE_MSG(dim_vec.size() == kIndex2, RET_INPUT_PARAM_INVALID, "the dynamic_dim_params is invalid");
248 std::string symbol = dim_vec[0];
249 if (dynamic_symbols->find(symbol) == dynamic_symbols->end()) {
250 MS_LOG(ERROR) << symbol << "is invalid, because it's not set in the inputs_shape.";
251 return RET_INPUT_PARAM_INVALID;
252 }
253 std::string dim_range = dim_vec[1];
254 if (dim_range[0] != '[' || dim_range[dim_range.size() - 1] != ']') {
255 MS_LOG(ERROR) << "the dynamic_dim_params is invalid";
256 return RET_INPUT_PARAM_INVALID;
257 }
258 dim_range = dim_range.substr(1, dim_range.size() - kIndex2);
259 auto discrete_vec = SplitStringToVector(dim_range, ',');
260 for (const auto &dim : discrete_vec) {
261 auto continuous_dim = SplitStringToVector(dim, '~');
262 MS_CHECK_TRUE_MSG(continuous_dim.size() == kIndex1 || continuous_dim.size() == kIndex2, RET_INPUT_PARAM_INVALID,
263 "the dynamic_dim_params is invalid");
264 if (continuous_dim.size() == kIndex1) {
265 if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0) {
266 MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0";
267 return RET_INPUT_PARAM_INVALID;
268 }
269 micro_param_string->dynamic_symbols_map[symbol].emplace_back(std::stoi(continuous_dim[0]));
270 continue;
271 }
272 if (!IsNumber(continuous_dim[0]) || std::stoi(continuous_dim[0]) <= 0 || !IsNumber(continuous_dim[1]) ||
273 std::stoi(continuous_dim[1]) <= 0) {
274 MS_LOG(ERROR) << "the dynamic_dim_params range value must be greater than 0";
275 return RET_INPUT_PARAM_INVALID;
276 }
277 auto start = std::stoi(continuous_dim[0]);
278 auto end = std::stoi(continuous_dim[1]);
279 for (auto i = start; i <= end; ++i) {
280 micro_param_string->dynamic_symbols_map[symbol].emplace_back(i);
281 }
282 }
283 }
284 return RET_OK;
285 }
286
SetVariableParams(const std::shared_ptr<mindspore::ConverterPara> & param,const std::map<std::string,std::string> & ascend_map)287 void ConfigFileParser::SetVariableParams(const std::shared_ptr<mindspore::ConverterPara> ¶m,
288 const std::map<std::string, std::string> &ascend_map) {
289 auto it = ascend_map.find("inputs_to_variable");
290 if (it != ascend_map.end()) {
291 std::vector<std::string> inputs_to_variables = mindspore::lite::SplitStringToVector(it->second, ',');
292 ProcessVariableParam(inputs_to_variables, inputs_variable_index_);
293 if (CheckVariableParm(inputs_variable_index_) != RET_OK) {
294 MS_LOG(ERROR) << "Check input variable param failed";
295 return;
296 }
297 }
298 auto output_it = ascend_map.find("outputs_to_variable");
299 if (output_it != ascend_map.end()) {
300 std::vector<std::string> outputs_to_variables = mindspore::lite::SplitStringToVector(output_it->second, ',');
301 ProcessVariableParam(outputs_to_variables, outputs_variable_index_);
302 if (CheckVariableParm(outputs_variable_index_) != RET_OK) {
303 MS_LOG(ERROR) << "Check output variable param failed";
304 return;
305 }
306 }
307 if (!inputs_variable_index_.empty() && !outputs_variable_index_.empty() &&
308 inputs_variable_index_.size() != outputs_variable_index_.size()) {
309 MS_LOG(ERROR) << "Input variable number is not equal output variable number";
310 return;
311 }
312 param->ascendGeOptionCfg.inputs_to_variable = inputs_variable_index_;
313 param->ascendGeOptionCfg.outputs_to_variable = outputs_variable_index_;
314 }
315
ProcessVariableParam(const std::vector<std::string> & variable_param,std::vector<int64_t> & variable_index)316 int ConfigFileParser::ProcessVariableParam(const std::vector<std::string> &variable_param,
317 std::vector<int64_t> &variable_index) {
318 for (auto &it : variable_param) {
319 auto remove_str = RemoveInputShapeBrackets(it);
320 int64_t min_index;
321 int64_t max_index;
322 if (!ProfileParser::ParseRangeStr(remove_str, &min_index, &max_index)) {
323 MS_LOG(ERROR) << "Parser range string " << remove_str << " failed";
324 return RET_ERROR;
325 }
326 if (max_index < min_index) {
327 MS_LOG(ERROR) << "The variable param in not valid" << max_index << "is not larger than" << min_index;
328 return RET_ERROR;
329 }
330 for (int64_t i = min_index; i <= max_index; ++i) {
331 variable_index.emplace_back(i);
332 }
333 }
334 return RET_OK;
335 }
336
CheckVariableParm(const std::vector<int64_t> & variable_index)337 int ConfigFileParser::CheckVariableParm(const std::vector<int64_t> &variable_index) {
338 for (size_t i = 1; i < variable_index.size(); ++i) {
339 if (variable_index[i] < variable_index[i - 1]) {
340 MS_LOG(ERROR) << "variable index is not valid" << variable_index[i] << " is less than " << variable_index[i - 1];
341 return RET_ERROR;
342 }
343 }
344 return RET_OK;
345 }
346
CheckPluginCustomOps(const std::vector<std::string> & plugin_custom_ops)347 bool ConfigFileParser::CheckPluginCustomOps(const std::vector<std::string> &plugin_custom_ops) {
348 if (find(plugin_custom_ops.begin(), plugin_custom_ops.end(), "None") != plugin_custom_ops.end() &&
349 plugin_custom_ops.size() != 1) {
350 MS_LOG(ERROR) << "plugin_custom_ops include None, can not include other param.";
351 return false;
352 }
353 return true;
354 }
355
ParseCustomPattern(const std::shared_ptr<mindspore::ConverterPara> & param,std::string custom_pattern_str)356 STATUS ConfigFileParser::ParseCustomPattern(const std::shared_ptr<mindspore::ConverterPara> ¶m,
357 std::string custom_pattern_str) {
358 std::vector<std::string> custom_pattern_strs = mindspore::lite::SplitStringToVector(custom_pattern_str, ";");
359 for (auto custom_pattern : custom_pattern_strs) {
360 std::vector<std::string> item = mindspore::lite::SplitStringToVector(custom_pattern, ":");
361 if (item.size() != kNumSize3) {
362 return RET_ERROR;
363 }
364 std::string op_type = item[0];
365 auto names_list = mindspore::lite::SplitStringToVector(item[1], ",");
366 std::string status = item[kNumSize2];
367 if (status == "enable") {
368 if (param->aclModelOptionCfgParam.enable_custom_fusion_pattern.find(op_type) !=
369 param->aclModelOptionCfgParam.enable_custom_fusion_pattern.end()) {
370 MS_LOG(ERROR) << op_type << " has define, can not defined repeat.";
371 return RET_ERROR;
372 }
373 param->aclModelOptionCfgParam.enable_custom_fusion_pattern[op_type] = names_list;
374 } else if (status == "disable") {
375 if (param->aclModelOptionCfgParam.disable_custom_fusion_pattern.find(op_type) !=
376 param->aclModelOptionCfgParam.disable_custom_fusion_pattern.end()) {
377 MS_LOG(ERROR) << op_type << " has define, can not defined repeat.";
378 return RET_ERROR;
379 }
380 param->aclModelOptionCfgParam.disable_custom_fusion_pattern[op_type] = names_list;
381 } else {
382 MS_LOG(ERROR) << "status only support enable or disable";
383 return RET_ERROR;
384 }
385 }
386 return RET_OK;
387 }
388
SetParamByConfigfile(const std::shared_ptr<mindspore::ConverterPara> & param,const std::map<std::string,std::string> & ascend_map)389 bool ConfigFileParser::SetParamByConfigfile(const std::shared_ptr<mindspore::ConverterPara> ¶m,
390 const std::map<std::string, std::string> &ascend_map) {
391 std::string ascend_string = "";
392 auto set_option = [&ascend_map](const std::string &key, std::string *option) {
393 auto it = ascend_map.find(key);
394 if (it != ascend_map.end() && !it->second.empty()) {
395 *option = it->second;
396 }
397 };
398 set_option("input_format", ¶m->aclModelOptionCfgParam.input_format);
399 set_option("precision_mode", ¶m->aclModelOptionCfgParam.precision_mode);
400 set_option("op_select_impl_mode", ¶m->aclModelOptionCfgParam.op_select_impl_mode);
401 set_option("fusion_switch_config_file_path", ¶m->aclModelOptionCfgParam.fusion_switch_config_file_path);
402 set_option("buffer_optimize", ¶m->aclModelOptionCfgParam.buffer_optimize);
403 set_option("insert_op_config_file_path", ¶m->aclModelOptionCfgParam.insert_op_config_file_path);
404 set_option("om_file_path", ¶m->aclModelOptionCfgParam.om_file_path);
405 set_option("aoe_mode", ¶m->aclModelOptionCfgParam.aoe_mode);
406 set_option(kDumpModelNameKey, ¶m->aclModelOptionCfgParam.dump_model_name);
407 set_option("provider", ¶m->provider);
408
409 auto plugin_custom_ops_str = FindInAscendMap(kPluginCustomOps, ascend_map);
410 std::vector<std::string> plugin_custom_ops_vec = {};
411 if (!plugin_custom_ops_str.empty()) {
412 MS_LOG(INFO) << "plugin_custom_ops: " << plugin_custom_ops_str;
413 plugin_custom_ops_vec = mindspore::lite::SplitStringToVector(plugin_custom_ops_str, ",");
414 if (!CheckPluginCustomOps(plugin_custom_ops_vec)) {
415 return false;
416 }
417 }
418 if (!plugin_custom_ops_vec.empty()) {
419 param->ascendGeOptionCfg.plugin_custom_ops = plugin_custom_ops_vec;
420 } else if (!(ascend_string = FindInAscendMap(kEnableCustomOp, ascend_map)).empty()) {
421 param->ascendGeOptionCfg.plugin_custom_ops = {"All"};
422 }
423 // parse for ascend custom fusion op
424 if (!plugin_custom_ops_vec.empty()) {
425 param->aclModelOptionCfgParam.plugin_custom_ops = plugin_custom_ops_vec;
426 }
427 auto custom_fusion_pattern_str = FindInAscendMap("custom_fusion_pattern", ascend_map);
428 if (!custom_fusion_pattern_str.empty()) {
429 auto status = ParseCustomPattern(param, custom_fusion_pattern_str);
430 if (status != RET_OK) {
431 MS_LOG(ERROR) << "custom fusion pattern wrong, eg:\n"
432 "custom_fusion_pattern=Fusion_op_type:node_name_1,node_name_2:enable\n"
433 "or: "
434 "custom_fusion_pattern=Fusion_op_type:node_name_1,node_name_2:disable";
435 return false;
436 }
437 }
438
439 auto op_attrs_str = FindInAscendMap(kOpAttrs, ascend_map);
440 std::vector<std::string> op_attrs_vec = {};
441 if (!op_attrs_str.empty()) {
442 MS_LOG(INFO) << "op_attrs_str: " << op_attrs_str;
443 op_attrs_vec = mindspore::lite::SplitStringToVector(op_attrs_str, ";");
444 std::map<std::string, std::string> attr;
445 for (auto op_attr_str : op_attrs_vec) {
446 MS_LOG(INFO) << "op_attr: " << op_attr_str;
447 auto op_attr = mindspore::lite::SplitStringToVector(op_attr_str, ":");
448 if (op_attr.size() != kNumSize3) {
449 MS_LOG(ERROR) << "Only support ops:attr:value, but get " << op_attr_str;
450 return false;
451 }
452 auto op_type = op_attr[kNumIndex0];
453 auto attr_key = op_attr[kNumIndex1];
454 auto attr_value = op_attr[kNumIndex2];
455 param->aclModelOptionCfgParam.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
456 param->ascendGeOptionCfg.op_attrs_map[op_type].insert(std::make_pair(attr_key, attr_value));
457 }
458 }
459 for (auto item : param->aclModelOptionCfgParam.op_attrs_map) {
460 for (auto attr : item.second) {
461 MS_LOG(INFO) << "op type: " << item.first << ", key: " << attr.first << ", value: " << attr.second;
462 }
463 }
464
465 auto it = ascend_map.find("input_shape");
466 if (it != ascend_map.end()) {
467 param->aclModelOptionCfgParam.input_shape = RemoveInputShapeBrackets(it->second);
468 }
469
470 it = ascend_map.find("device_id");
471 if (it != ascend_map.end()) {
472 int32_t val;
473 if (mindspore::lite::ConvertIntNum(it->second, &val)) {
474 param->aclModelOptionCfgParam.device_id = val;
475 } else {
476 MS_LOG(ERROR) << "Convert device id failed";
477 return false;
478 }
479 }
480
481 it = ascend_map.find("output_type");
482 if (it != ascend_map.end()) {
483 auto dtype_str = it->second;
484 if (dtype_str == "FP16") {
485 param->aclModelOptionCfgParam.output_type = DataType::kNumberTypeFloat16;
486 } else if (dtype_str == "FP32") {
487 param->aclModelOptionCfgParam.output_type = DataType::kNumberTypeFloat32;
488 } else if (dtype_str == "UINT8") {
489 param->aclModelOptionCfgParam.output_type = DataType::kNumberTypeUInt8;
490 } else {
491 MS_LOG(WARNING) << "Unsupported or invalid output_type, using default type";
492 }
493 }
494 SetVariableParams(param, ascend_map);
495 return SetDynParams(param, ascend_map);
496 }
497
ParseConfigFile(const std::string & config_file_path,std::map<int,std::map<std::string,std::string>> * model_param_infos)498 int ConfigFileParser::ParseConfigFile(const std::string &config_file_path,
499 std::map<int, std::map<std::string, std::string>> *model_param_infos) {
500 std::map<std::string, std::map<std::string, std::string>> maps;
501 auto ret = mindspore::lite::ParseConfigFile(config_file_path, &maps, model_param_infos);
502 if (ret != RET_OK) {
503 MS_LOG(ERROR) << "Parse config file failed.";
504 return ret;
505 }
506 ret = ParseConfigParam(&maps);
507 if (ret != RET_OK) {
508 MS_LOG(ERROR) << "Parse config param failed.";
509 return ret;
510 }
511 return RET_OK;
512 }
513
ParseConfigParam(std::map<std::string,std::map<std::string,std::string>> * maps)514 int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::string, std::string>> *maps) {
515 if (maps == nullptr) {
516 MS_LOG(ERROR) << "Maps is nullptr.";
517 return RET_ERROR;
518 }
519 for (const auto &config_info : *maps) {
520 ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
521 }
522 auto ret = ParseDataPreProcessString(*maps);
523 (void)maps->erase(kDataPreprocessParam);
524 if (ret != RET_OK) {
525 MS_LOG(ERROR) << "ParseDataPreProcessString failed.";
526 return ret;
527 }
528 ret = ParseCommonQuantString(*maps);
529 (void)maps->erase(kCommonQuantParam);
530 if (ret != RET_OK) {
531 MS_LOG(ERROR) << "ParseCommonQuantString failed.";
532 return ret;
533 }
534 ret = ParseMixedBitQuantString(*maps);
535 (void)maps->erase(kMixedBitWeightQuantParam);
536 if (ret != RET_OK) {
537 MS_LOG(ERROR) << "ParseMixedBitQuantString failed.";
538 return ret;
539 }
540 ret = ParseFullQuantString(*maps);
541 (void)maps->erase(kFullQuantParam);
542 if (ret != RET_OK) {
543 MS_LOG(ERROR) << "ParseFullQuantString failed.";
544 return ret;
545 }
546 ret = ParseRegistryInfoString(*maps);
547 (void)maps->erase(kRegistry);
548 if (ret != RET_OK) {
549 MS_LOG(ERROR) << "ParseExtendedintegrationString failed.";
550 return ret;
551 }
552 ret = ParseAclOptionCfgString(*maps);
553 (void)maps->erase(kAclOptionParam);
554 if (ret != RET_OK) {
555 MS_LOG(ERROR) << "ParseAclOptionCfgString failed.";
556 return ret;
557 }
558 ret = ParseMicroParamString(*maps);
559 (void)maps->erase(kMicroParam);
560 if (ret != RET_OK) {
561 MS_LOG(ERROR) << "ParseMicroParamString failed.";
562 return ret;
563 }
564 ret = ParseThirdPartyParamString(*maps);
565 (void)maps->erase(kThirdPartyModelParam);
566 if (ret != RET_OK) {
567 MS_LOG(ERROR) << "ParseTransformQuantString failed.";
568 return ret;
569 }
570 ret = ParseWeightQuantString(*maps);
571 (void)maps->erase(kWeightQuantParam);
572 if (ret != RET_OK) {
573 MS_LOG(ERROR) << "ParseWeightQuantString failed.";
574 return ret;
575 }
576 ret = ParseCpuOptionCfgString(*maps);
577 (void)maps->erase(kCpuOptionParam);
578 if (ret != RET_OK) {
579 MS_LOG(ERROR) << "ParseCpuOptionCfgString failed.";
580 return ret;
581 }
582 ret = ParseTransformQuantString(*maps);
583 (void)maps->erase(kTransformQuantParam);
584 if (ret != RET_OK) {
585 MS_LOG(ERROR) << "ParseTransformQuantString failed.";
586 return ret;
587 }
588 ret = ParseDynamicQuantString(*maps);
589 (void)maps->erase(kDynamicQuantParam);
590 if (ret != RET_OK) {
591 MS_LOG(ERROR) << "ParseDynamicQuantString failed.";
592 return ret;
593 }
594 (void)ParseGraphKernelString(*maps);
595 (void)maps->erase(kGraphKernelParam);
596 ret = ParseOMConverterString(*maps);
597 (void)maps->erase(kOMConverterOptionsSection);
598 if (ret != RET_OK) {
599 MS_LOG(ERROR) << "ParseOMConverterString failed.";
600 return ret;
601 }
602 return RET_OK;
603 }
604
SetMapData(const std::map<std::string,std::string> & input_map,const std::map<std::string,std::string &> & parse_map,const std::string & section,const std::set<std::string> & dynamic_key)605 int ConfigFileParser::SetMapData(const std::map<std::string, std::string> &input_map,
606 const std::map<std::string, std::string &> &parse_map, const std::string §ion,
607 const std::set<std::string> &dynamic_key) {
608 for (const auto &map : input_map) {
609 if (dynamic_key.find(map.first) != dynamic_key.end()) {
610 continue;
611 }
612 if (parse_map.find(map.first) == parse_map.end()) {
613 MS_LOG(ERROR) << "INPUT ILLEGAL: `" << map.first << "` is not supported in "
614 << "[" << section << "]";
615 return RET_INPUT_PARAM_INVALID;
616 } else {
617 parse_map.at(map.first) = map.second;
618 }
619 }
620 return RET_OK;
621 }
622
ParseDataPreProcessString(const std::map<std::string,std::map<std::string,std::string>> & maps)623 int ConfigFileParser::ParseDataPreProcessString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
624 if (maps.find(kDataPreprocessParam) != maps.end()) {
625 const auto &map = maps.at(kDataPreprocessParam);
626 std::map<std::string, std::string &> parse_map{
627 {"calibrate_path", data_pre_process_string_.calibrate_path},
628 {"calibrate_size", data_pre_process_string_.calibrate_size},
629 {"input_type", data_pre_process_string_.input_type},
630 {"image_to_format", data_pre_process_string_.image_to_format},
631 {"normalize_mean", data_pre_process_string_.normalize_mean},
632 {"normalize_std", data_pre_process_string_.normalize_std},
633 {"resize_width", data_pre_process_string_.resize_width},
634 {"resize_height", data_pre_process_string_.resize_height},
635 {"resize_method", data_pre_process_string_.resize_method},
636 {"center_crop_width", data_pre_process_string_.center_crop_width},
637 {"center_crop_height", data_pre_process_string_.center_crop_height},
638 };
639 return SetMapData(map, parse_map, kDataPreprocessParam);
640 }
641 return RET_OK;
642 }
643
ParseCommonQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)644 int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
645 if (maps.find(kCommonQuantParam) != maps.end()) {
646 const auto &map = maps.at(kCommonQuantParam);
647 std::map<std::string, std::string &> parse_map{
648 {"quant_type", common_quant_string_.quant_type},
649 {"bit_num", common_quant_string_.bit_num},
650 {"min_quant_weight_size", common_quant_string_.min_quant_weight_size},
651 {"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
652 {"skip_quant_node", common_quant_string_.skip_quant_node},
653 {"debug_info_save_path", common_quant_string_.debug_info_save_path},
654 {"enable_encode", common_quant_string_.enable_encode},
655 {"workspace", common_quant_string_.workspace},
656 };
657 return SetMapData(map, parse_map, kCommonQuantParam);
658 }
659 return RET_OK;
660 }
661
ParseMixedBitQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)662 int ConfigFileParser::ParseMixedBitQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
663 if (maps.find(kMixedBitWeightQuantParam) != maps.end()) {
664 const auto &map = maps.at(kMixedBitWeightQuantParam);
665 std::map<std::string, std::string &> parse_map{
666 {"init_scale", mixed_bit_quant_string_.init_scale},
667 {"auto_tune", mixed_bit_quant_string_.auto_tune},
668 {"use_cv_data", mixed_bit_quant_string_.use_cv_data},
669 {"max_iterations", mixed_bit_quant_string_.max_iterations},
670 };
671 return SetMapData(map, parse_map, kMixedBitWeightQuantParam);
672 }
673 return RET_OK;
674 }
675
ParseFullQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)676 int ConfigFileParser::ParseFullQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
677 if (maps.find(kFullQuantParam) != maps.end()) {
678 const auto &map = maps.at(kFullQuantParam);
679 std::map<std::string, std::string &> parse_map{
680 {"activation_quant_method", full_quant_string_.activation_quant_method},
681 {"bias_correction", full_quant_string_.bias_correction},
682 {"target_device", full_quant_string_.target_device},
683 {"per_channel", full_quant_string_.per_channel},
684 {"smooth_alpha", full_quant_string_.smooth_alpha},
685 {"enable_smooth_shift", full_quant_string_.enable_smooth_shift},
686 };
687 return SetMapData(map, parse_map, kFullQuantParam);
688 }
689 return RET_OK;
690 }
691
ParseRegistryInfoString(const std::map<std::string,std::map<std::string,std::string>> & maps)692 int ConfigFileParser::ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
693 if (maps.find(kRegistry) != maps.end()) {
694 const auto &map = maps.at(kRegistry);
695 std::map<std::string, std::string &> parse_map{
696 {"plugin_path", registry_info_string_.plugin_path},
697 {"disable_fusion", registry_info_string_.disable_fusion},
698 {"fusion_blacklists", registry_info_string_.fusion_blacklists},
699 };
700 return SetMapData(map, parse_map, kRegistry);
701 }
702 return RET_OK;
703 }
704
ParseAclOptionCfgString(const std::map<std::string,std::map<std::string,std::string>> & maps)705 int ConfigFileParser::ParseAclOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
706 if (maps.find(kAclOptionParam) != maps.end()) {
707 const auto &map = maps.at(kAclOptionParam);
708 std::map<std::string, std::string &> parse_map{
709 {"device_id", acl_option_cfg_string_.device_id},
710 {"input_format", acl_option_cfg_string_.input_format},
711 {"input_shape_vector", acl_option_cfg_string_.input_shape_vector},
712 {"input_shape", acl_option_cfg_string_.input_shape},
713 {"output_type", acl_option_cfg_string_.output_type},
714 {"precision_mode", acl_option_cfg_string_.precision_mode},
715 {"op_select_impl_mode", acl_option_cfg_string_.op_select_impl_mode},
716 {"fusion_switch_config_file_path", acl_option_cfg_string_.fusion_switch_config_file_path},
717 {"dynamic_batch_size", acl_option_cfg_string_.dynamic_batch_size},
718 {"buffer_optimize", acl_option_cfg_string_.buffer_optimize},
719 {"insert_op_config_file_path", acl_option_cfg_string_.insert_op_config_file_path},
720 {"dynamic_image_size", acl_option_cfg_string_.dynamic_image_size},
721 {"dynamic_dims", acl_option_cfg_string_.dynamic_dims},
722 {"aoe_mode", acl_option_cfg_string_.aoe_mode},
723 {"custom_opp_path", acl_option_cfg_string_.custom_opp_path}};
724 auto ret = SetMapData(map, parse_map, kAclOptionParam);
725 if (ret != RET_OK) {
726 MS_LOG(ERROR) << "set map data failed.";
727 return ret;
728 }
729 }
730 if (maps.find(kAclInitOptionParam) != maps.end()) {
731 const auto &map = maps.at(kAclInitOptionParam);
732 for (const auto &item : map) {
733 (void)acl_option_cfg_string_.init_options_map.emplace(item.first, item.second);
734 }
735 }
736 if (maps.find(kAclBuildOptionParam) != maps.end()) {
737 const auto &map = maps.at(kAclBuildOptionParam);
738 for (const auto &item : map) {
739 (void)acl_option_cfg_string_.build_options_map.emplace(item.first, item.second);
740 }
741 }
742 if (maps.find(kAoeGlobalOptionsSection) != maps.end()) {
743 const auto &map = maps.at(kAoeGlobalOptionsSection);
744 for (const auto &item : map) {
745 (void)acl_option_cfg_string_.aoe_global_options_map.emplace(item.first, item.second);
746 }
747 }
748 if (maps.find(kAoeTuningOptionsSection) != maps.end()) {
749 const auto &map = maps.at(kAoeTuningOptionsSection);
750 for (const auto &item : map) {
751 (void)acl_option_cfg_string_.aoe_tuning_options_map.emplace(item.first, item.second);
752 }
753 }
754 return RET_OK;
755 }
756
ParseMicroParamString(const std::map<std::string,std::map<std::string,std::string>> & maps)757 int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
758 if (maps.find(kMicroParam) == maps.end()) {
759 return RET_OK;
760 }
761 const auto &map = maps.at(kMicroParam);
762 const std::string graph_inputs_shape_template = "inputs_shape";
763 std::set<std::string> dynamic_symbols;
764 if (map.find(graph_inputs_shape_template) != map.end()) {
765 const auto &shape_template = map.at(graph_inputs_shape_template);
766 MS_CHECK_TRUE_MSG(ParseInputShapeTemplate(shape_template, &dynamic_symbols) == RET_OK, RET_ERROR,
767 "ParseInputShapeTemplate failed");
768 }
769 const std::string dynamic_dims = "dynamic_dim_params";
770 if (!dynamic_symbols.empty() && map.find(dynamic_dims) != map.end()) {
771 const auto &dims_template = map.at(dynamic_dims);
772 MS_CHECK_TRUE_MSG(ParseDynamicDimTemplate(dims_template, &dynamic_symbols, µ_param_string_) == RET_OK,
773 RET_ERROR, "ParseDynamicDimTemplate failed");
774 }
775 std::map<std::string, std::string &> parse_map{
776 {"target", micro_param_string_.target},
777 {"codegen_mode", micro_param_string_.codegen_mode},
778 {"debug_mode", micro_param_string_.debug_mode},
779 {"support_parallel", micro_param_string_.support_parallel},
780 {"enable_micro", micro_param_string_.enable_micro},
781 {"save_path", micro_param_string_.save_path},
782 {"project_name", micro_param_string_.project_name},
783 {"keep_original_weight", micro_param_string_.keep_original_weight},
784 {"changeable_weights_name", micro_param_string_.changeable_weights_name},
785 {"inputs_shape", micro_param_string_.inputs_shape},
786 {"dynamic_dim_params", micro_param_string_.dynamic_dim_params}};
787 return SetMapData(map, parse_map, kMicroParam);
788 }
789
ParseWeightQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)790 int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
791 if (maps.find(kWeightQuantParam) != maps.end()) {
792 const auto &map = maps.at(kWeightQuantParam);
793 std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy},
794 {"update_mindir", weight_quant_string_.update_mindir},
795 {"max_segments", weight_quant_string_.max_segments},
796 {"per_channel", weight_quant_string_.per_channel},
797 {"bias_correction", weight_quant_string_.bias_correction},
798 {"quant_strategy", weight_quant_string_.quant_strategy}};
799 return SetMapData(map, parse_map, kWeightQuantParam);
800 }
801 return RET_OK;
802 }
803
ParseCpuOptionCfgString(const std::map<std::string,std::map<std::string,std::string>> & maps)804 int ConfigFileParser::ParseCpuOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
805 if (maps.find(kCpuOptionParam) != maps.end()) {
806 const auto &map = maps.at(kCpuOptionParam);
807 std::map<std::string, std::string &> parse_map{{"architecture", cpu_option_cfg_string_.architecture},
808 {"instruction", cpu_option_cfg_string_.instruction}};
809 auto ret = SetMapData(map, parse_map, kCpuOptionParam);
810 if (cpu_option_cfg_string_.architecture.empty() || cpu_option_cfg_string_.instruction.empty()) {
811 MS_LOG(WARNING) << "[cpu_option_cfg_param] set incompletely, the model won't do optimize for cpu, please "
812 "check the parameter architecture and instruction are correct.";
813 }
814 return ret;
815 }
816 return RET_OK;
817 }
818
ParseTransformQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)819 int ConfigFileParser::ParseTransformQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
820 if (maps.find(kTransformQuantParam) != maps.end()) {
821 const auto &map = maps.at(kTransformQuantParam);
822 std::map<std::string, std::string &> parse_map{
823 {"export_precision_mode", transform_quant_string_.export_precision_mode},
824 };
825 return SetMapData(map, parse_map, kTransformQuantParam);
826 }
827 return RET_OK;
828 }
829
ParseDynamicQuantString(const std::map<std::string,std::map<std::string,std::string>> & maps)830 int ConfigFileParser::ParseDynamicQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
831 if (maps.find(kDynamicQuantParam) != maps.end()) {
832 const auto &map = maps.at(kDynamicQuantParam);
833 std::map<std::string, std::string &> parse_map{
834 {"quant_strategy", dynamic_quant_string_.quant_strategy},
835 };
836 return SetMapData(map, parse_map, kDynamicQuantParam);
837 }
838 return RET_OK;
839 }
840
ParseGraphKernelString(const std::map<std::string,std::map<std::string,std::string>> & maps)841 int ConfigFileParser::ParseGraphKernelString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
842 if (maps.find(kGraphKernelParam) != maps.end()) {
843 const auto &map = maps.at(kGraphKernelParam);
844 for (const auto &item : map) {
845 std::stringstream oss;
846 oss << "--" << item.first << "=" << item.second;
847 (void)graph_kernel_string_.emplace_back(oss.str());
848 }
849 }
850 return RET_OK;
851 }
852
ParseOMConverterString(const std::map<std::string,std::map<std::string,std::string>> & maps)853 int ConfigFileParser::ParseOMConverterString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
854 if (maps.find(kOMConverterOptionsSection) != maps.end()) {
855 const auto &map = maps.at(kOMConverterOptionsSection);
856 std::map<std::string, std::string &> parse_map{
857 {"input_name_vector", om_converter_string_.input_name_vector},
858 {"input_shape_vector", om_converter_string_.input_shape_vector},
859 {"input_data_type_vector", om_converter_string_.input_data_type_vector},
860 {"output_name_vector", om_converter_string_.output_name_vector},
861 {"output_shape_vector", om_converter_string_.output_shape_vector},
862 {"output_data_type_vector", om_converter_string_.output_data_type_vector}};
863 auto ret = SetMapData(map, parse_map, kOMConverterOptionsSection);
864 if (ret != RET_OK) {
865 MS_LOG(ERROR) << "Set map data failed.";
866 return ret;
867 }
868 }
869 return RET_OK;
870 }
871
ParseThirdPartyParamString(const std::map<std::string,std::map<std::string,std::string>> & sections)872 int ConfigFileParser::ParseThirdPartyParamString(
873 const std::map<std::string, std::map<std::string, std::string>> §ions) {
874 if (sections.find(kThirdPartyModelParam) == sections.end()) {
875 return RET_OK;
876 }
877 const auto &input_args = sections.at(kThirdPartyModelParam);
878 const std::map<std::string, std::string &> kValidArgs = {
879 {"input_shapes", third_party_model_string_.input_shapes},
880 {"input_dtypes", third_party_model_string_.input_dtypes},
881 {"input_names", third_party_model_string_.input_names},
882 {"input_formats", third_party_model_string_.input_formats},
883 {"output_shapes", third_party_model_string_.output_shapes},
884 {"output_dtypes", third_party_model_string_.output_dtypes},
885 {"output_names", third_party_model_string_.output_names},
886 {"output_formats", third_party_model_string_.output_formats},
887 {"extended_parameters", third_party_model_string_.extended_parameters},
888 };
889 return SetMapData(input_args, kValidArgs, kThirdPartyModelParam);
890 }
891 } // namespace lite
892 } // namespace mindspore
893