1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/config_parser/micro_param_parser.h"
18 #include "tools/converter/micro/coder/config.h"
19 #include "tools/common/string_util.h"
20 #include "src/common/log_adapter.h"
21 #include "src/common/log_util.h"
22 #include "nnacl/op_base.h"
23
24 namespace mindspore {
25 namespace lite {
ParseTarget(const std::string & target,micro::MicroParam * micro_param)26 STATUS MicroParamParser::ParseTarget(const std::string &target, micro::MicroParam *micro_param) {
27 MS_LOG(DEBUG) << "Micro HW target: " << target;
28 if (!target.empty()) {
29 micro_param->target = target;
30 }
31 return RET_OK;
32 }
ParseCodeGenMode(const std::string & codegen_mode,micro::MicroParam * micro_param)33 STATUS MicroParamParser::ParseCodeGenMode(const std::string &codegen_mode, micro::MicroParam *micro_param) {
34 MS_LOG(DEBUG) << "Micro codegen mode: " << codegen_mode;
35 if (!codegen_mode.empty()) {
36 micro_param->codegen_mode = codegen_mode;
37 }
38 return RET_OK;
39 }
ParseSupportParallel(const std::string & support_parallel,micro::MicroParam * micro_param)40 STATUS MicroParamParser::ParseSupportParallel(const std::string &support_parallel, micro::MicroParam *micro_param) {
41 MS_LOG(DEBUG) << "Micro supports parallel: " << support_parallel;
42 if (support_parallel.empty()) {
43 return RET_OK;
44 }
45 micro_param->support_parallel = false; // default
46 bool is_parallel;
47 if (ConvertBool(support_parallel, &is_parallel)) {
48 micro_param->support_parallel = is_parallel;
49 }
50 return RET_OK;
51 }
ParseDebugMode(const std::string & debug_mode,micro::MicroParam * micro_param)52 STATUS MicroParamParser::ParseDebugMode(const std::string &debug_mode, micro::MicroParam *micro_param) {
53 MS_LOG(DEBUG) << "Micro enables debug mode: " << debug_mode;
54 if (debug_mode.empty()) {
55 return RET_OK;
56 }
57 micro_param->debug_mode = false; // default
58 bool is_debug_mode;
59 if (ConvertBool(debug_mode, &is_debug_mode)) {
60 micro_param->debug_mode = is_debug_mode;
61 }
62 return RET_OK;
63 }
64
ParseEnableMicro(const std::string & enable_micro,micro::MicroParam * micro_param)65 STATUS MicroParamParser::ParseEnableMicro(const std::string &enable_micro, micro::MicroParam *micro_param) {
66 MS_LOG(DEBUG) << "Micro enables : " << enable_micro;
67 if (enable_micro.empty()) {
68 return RET_OK;
69 }
70 micro_param->enable_micro = false; // default
71 bool is_enable_micro;
72 if (ConvertBool(enable_micro, &is_enable_micro)) {
73 micro_param->enable_micro = is_enable_micro;
74 }
75 return RET_OK;
76 }
77
ParseSavePath(const std::string & save_path,micro::MicroParam * micro_param)78 STATUS MicroParamParser::ParseSavePath(const std::string &save_path, micro::MicroParam *micro_param) {
79 MS_LOG(DEBUG) << "Micro save path : " << save_path;
80 if (!save_path.empty()) {
81 micro_param->save_path = save_path;
82 }
83 return RET_OK;
84 }
85
ParseProjName(const std::string & project_name,micro::MicroParam * micro_param)86 STATUS MicroParamParser::ParseProjName(const std::string &project_name, micro::MicroParam *micro_param) {
87 MS_LOG(DEBUG) << "Micro project name : " << project_name;
88 if (!project_name.empty()) {
89 micro_param->project_name = project_name;
90 }
91 return RET_OK;
92 }
93
ParseKeepOriginalWeight(const std::string & keep_weight,micro::MicroParam * micro_param)94 STATUS MicroParamParser::ParseKeepOriginalWeight(const std::string &keep_weight, micro::MicroParam *micro_param) {
95 MS_LOG(DEBUG) << "Micro enables : " << keep_weight;
96 if (keep_weight.empty()) {
97 return RET_OK;
98 }
99 micro_param->keep_original_weight = false; // default
100 bool is_keep_original_weight;
101 if (ConvertBool(keep_weight, &is_keep_original_weight)) {
102 micro_param->keep_original_weight = is_keep_original_weight;
103 } else {
104 MS_LOG(ERROR) << "Micro param invalid, keep_original_weight can only be set as true or false.";
105 return RET_INPUT_PARAM_INVALID;
106 }
107 return RET_OK;
108 }
109
ParseChangeableWeightsName(const std::string & changeable_weights_name,micro::MicroParam * micro_param)110 STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeable_weights_name,
111 micro::MicroParam *micro_param) {
112 MS_LOG(DEBUG) << "Micro record changeable weights name: " << changeable_weights_name;
113 if (!changeable_weights_name.empty()) {
114 micro_param->changeable_weights_name = changeable_weights_name;
115 }
116 return RET_OK;
117 }
118
ParseGraphInputsShapeTemplate(const std::string & graph_inputs_shape_template,const std::map<std::string,std::vector<int>> & dynamic_symbols_map,micro::MicroParam * micro_param)119 STATUS MicroParamParser::ParseGraphInputsShapeTemplate(
120 const std::string &graph_inputs_shape_template, const std::map<std::string, std::vector<int>> &dynamic_symbols_map,
121 micro::MicroParam *micro_param) {
122 MS_LOG(DEBUG) << "Micro record inputs shape: " << graph_inputs_shape_template;
123 if (!graph_inputs_shape_template.empty()) {
124 auto graph_inputs_shape_vec = SplitStringToVector(graph_inputs_shape_template, ';');
125 std::map<std::string, std::vector<std::string>> graph_inputs_info;
126 std::vector<std::vector<std::string>> graph_inputs_shape;
127 std::vector<std::string> inputs_name;
128 for (const auto &graph_input_shape : graph_inputs_shape_vec) {
129 auto input_shape_info = SplitStringToVector(graph_input_shape, ':');
130 std::string input_name = input_shape_info[0];
131 std::string input_shape = input_shape_info[1].substr(1, input_shape_info[1].size() - C2NUM);
132 auto input_shape_vec = SplitStringToVector(input_shape, ',');
133 graph_inputs_info[input_name] = input_shape_vec;
134 graph_inputs_shape.push_back(input_shape_vec);
135 inputs_name.push_back(input_name);
136 }
137 micro_param->graph_inputs_origin_info = graph_inputs_info;
138 micro_param->inputs_shape_by_scenes.clear();
139 std::map<std::string, std::vector<int>> symbols_to_num;
140 std::map<std::string, int> symbols_index;
141 std::vector<std::string> symbols;
142 std::vector<size_t> scene_num_by_symbol;
143 int index = 0;
144 size_t scene_num = 1;
145 for (const auto &item : dynamic_symbols_map) {
146 symbols_index[item.first] = index++;
147 symbols.push_back(item.first);
148 for (const auto &num : item.second) {
149 symbols_to_num[item.first].push_back(num);
150 }
151 if (symbols_to_num[item.first].empty()) {
152 MS_LOG(ERROR) << "Micro param invalid, dynamic symbol must have value.";
153 return RET_INPUT_PARAM_INVALID;
154 }
155 scene_num_by_symbol.push_back(symbols_to_num[item.first].size());
156 scene_num *= symbols_to_num[item.first].size();
157 }
158 micro_param->dynamic_symbols = symbols;
159 micro_param->dynamic_symbols_num = scene_num_by_symbol;
160 micro_param->dynamic_symbols_map = dynamic_symbols_map;
161 std::vector<size_t> post_multi(symbols.size(), 1);
162 for (int i = static_cast<int>(post_multi.size()) - 2; i >= 0; --i) {
163 post_multi[i] = post_multi[i + 1] * scene_num_by_symbol[i + 1];
164 }
165 std::vector<int> real_num(symbols.size());
166 for (size_t i = 0; i < scene_num; ++i) {
167 size_t remain = i;
168 for (size_t j = 0; j < symbols.size(); ++j) {
169 real_num[j] = remain / post_multi[j];
170 remain %= post_multi[j];
171 }
172 for (size_t j = 0; j < graph_inputs_shape.size(); ++j) {
173 const auto &input_template = graph_inputs_shape[j];
174 std::vector<int> input_shape;
175 for (const auto &dim : input_template) {
176 if (IsNumber(dim)) {
177 input_shape.push_back(std::stoi(dim));
178 continue;
179 }
180 if (symbols_index.find(dim) == symbols_index.end()) {
181 MS_LOG(ERROR) << "Dynamic symbol cannot find real num.";
182 return RET_INPUT_PARAM_INVALID;
183 }
184 input_shape.push_back(symbols_to_num[dim][real_num[symbols_index[dim]]]);
185 }
186 micro_param->inputs_shape_by_scenes[inputs_name[j]].push_back(input_shape);
187 }
188 }
189 }
190 return RET_OK;
191 }
192
ParseMicroParam(const MicroParamString & micro_param_string,micro::MicroParam * micro_param)193 STATUS MicroParamParser::ParseMicroParam(const MicroParamString µ_param_string, micro::MicroParam *micro_param) {
194 CHECK_NULL_RETURN(micro_param);
195 if (ParseTarget(micro_param_string.target, micro_param) != RET_OK) {
196 MS_LOG(ERROR) << "Parse HW target val: " << micro_param_string.target;
197 return RET_INPUT_PARAM_INVALID;
198 }
199 if (ParseCodeGenMode(micro_param_string.codegen_mode, micro_param) != RET_OK) {
200 MS_LOG(ERROR) << "Parse codegen_mode val; " << micro_param_string.codegen_mode;
201 return RET_INPUT_PARAM_INVALID;
202 }
203 if (ParseSupportParallel(micro_param_string.support_parallel, micro_param) != RET_OK) {
204 MS_LOG(ERROR) << "Parse support_parallel val; " << micro_param_string.support_parallel;
205 return RET_INPUT_PARAM_INVALID;
206 }
207 if (ParseDebugMode(micro_param_string.debug_mode, micro_param) != RET_OK) {
208 MS_LOG(ERROR) << "Parse debug mode val; " << micro_param_string.debug_mode;
209 return RET_INPUT_PARAM_INVALID;
210 }
211 if (ParseEnableMicro(micro_param_string.enable_micro, micro_param) != RET_OK) {
212 MS_LOG(ERROR) << "Parse enable micro val; " << micro_param_string.enable_micro;
213 return RET_INPUT_PARAM_INVALID;
214 }
215 if (ParseSavePath(micro_param_string.save_path, micro_param) != RET_OK) {
216 MS_LOG(ERROR) << "Parse save path val failed: " << micro_param_string.save_path;
217 return RET_INPUT_PARAM_INVALID;
218 }
219 if (ParseProjName(micro_param_string.project_name, micro_param) != RET_OK) {
220 MS_LOG(ERROR) << "Parse project name val failed: " << micro_param_string.project_name;
221 return RET_INPUT_PARAM_INVALID;
222 }
223 if (!micro_param_string.keep_original_weight.empty()) {
224 if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) {
225 MS_LOG(ERROR) << "Parse keep_original_weight failed, the val: " << micro_param_string.keep_original_weight;
226 return RET_INPUT_PARAM_INVALID;
227 }
228 }
229 if (!micro_param_string.changeable_weights_name.empty() && !micro_param->keep_original_weight) {
230 MS_LOG(ERROR) << "When changeable_weights_name is set, the keep_original_weight must be true.";
231 return RET_INPUT_PARAM_INVALID;
232 }
233 if (ParseChangeableWeightsName(micro_param_string.changeable_weights_name, micro_param) != RET_OK) {
234 MS_LOG(ERROR) << "Parse changeable_weights_name failed, the val: " << micro_param_string.changeable_weights_name;
235 return RET_INPUT_PARAM_INVALID;
236 }
237 if (ParseGraphInputsShapeTemplate(micro_param_string.inputs_shape, micro_param_string.dynamic_symbols_map,
238 micro_param) != RET_OK) {
239 MS_LOG(ERROR) << "Parse inputs_shape & dynamic_dim_params failed, the inputs_shape val: "
240 << micro_param_string.inputs_shape;
241 return RET_INPUT_PARAM_INVALID;
242 }
243 return RET_OK;
244 }
245 } // namespace lite
246 } // namespace mindspore
247