• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/config_parser/micro_param_parser.h"
18 #include "tools/converter/micro/coder/config.h"
19 #include "tools/common/string_util.h"
20 #include "src/common/log_adapter.h"
21 #include "src/common/log_util.h"
22 #include "nnacl/op_base.h"
23 
24 namespace mindspore {
25 namespace lite {
ParseTarget(const std::string & target,micro::MicroParam * micro_param)26 STATUS MicroParamParser::ParseTarget(const std::string &target, micro::MicroParam *micro_param) {
27   MS_LOG(DEBUG) << "Micro HW target: " << target;
28   if (!target.empty()) {
29     micro_param->target = target;
30   }
31   return RET_OK;
32 }
ParseCodeGenMode(const std::string & codegen_mode,micro::MicroParam * micro_param)33 STATUS MicroParamParser::ParseCodeGenMode(const std::string &codegen_mode, micro::MicroParam *micro_param) {
34   MS_LOG(DEBUG) << "Micro codegen mode: " << codegen_mode;
35   if (!codegen_mode.empty()) {
36     micro_param->codegen_mode = codegen_mode;
37   }
38   return RET_OK;
39 }
ParseSupportParallel(const std::string & support_parallel,micro::MicroParam * micro_param)40 STATUS MicroParamParser::ParseSupportParallel(const std::string &support_parallel, micro::MicroParam *micro_param) {
41   MS_LOG(DEBUG) << "Micro supports parallel: " << support_parallel;
42   if (support_parallel.empty()) {
43     return RET_OK;
44   }
45   micro_param->support_parallel = false;  // default
46   bool is_parallel;
47   if (ConvertBool(support_parallel, &is_parallel)) {
48     micro_param->support_parallel = is_parallel;
49   }
50   return RET_OK;
51 }
ParseDebugMode(const std::string & debug_mode,micro::MicroParam * micro_param)52 STATUS MicroParamParser::ParseDebugMode(const std::string &debug_mode, micro::MicroParam *micro_param) {
53   MS_LOG(DEBUG) << "Micro enables debug mode: " << debug_mode;
54   if (debug_mode.empty()) {
55     return RET_OK;
56   }
57   micro_param->debug_mode = false;  // default
58   bool is_debug_mode;
59   if (ConvertBool(debug_mode, &is_debug_mode)) {
60     micro_param->debug_mode = is_debug_mode;
61   }
62   return RET_OK;
63 }
64 
ParseEnableMicro(const std::string & enable_micro,micro::MicroParam * micro_param)65 STATUS MicroParamParser::ParseEnableMicro(const std::string &enable_micro, micro::MicroParam *micro_param) {
66   MS_LOG(DEBUG) << "Micro enables : " << enable_micro;
67   if (enable_micro.empty()) {
68     return RET_OK;
69   }
70   micro_param->enable_micro = false;  // default
71   bool is_enable_micro;
72   if (ConvertBool(enable_micro, &is_enable_micro)) {
73     micro_param->enable_micro = is_enable_micro;
74   }
75   return RET_OK;
76 }
77 
ParseSavePath(const std::string & save_path,micro::MicroParam * micro_param)78 STATUS MicroParamParser::ParseSavePath(const std::string &save_path, micro::MicroParam *micro_param) {
79   MS_LOG(DEBUG) << "Micro save path : " << save_path;
80   if (!save_path.empty()) {
81     micro_param->save_path = save_path;
82   }
83   return RET_OK;
84 }
85 
ParseProjName(const std::string & project_name,micro::MicroParam * micro_param)86 STATUS MicroParamParser::ParseProjName(const std::string &project_name, micro::MicroParam *micro_param) {
87   MS_LOG(DEBUG) << "Micro project name : " << project_name;
88   if (!project_name.empty()) {
89     micro_param->project_name = project_name;
90   }
91   return RET_OK;
92 }
93 
ParseKeepOriginalWeight(const std::string & keep_weight,micro::MicroParam * micro_param)94 STATUS MicroParamParser::ParseKeepOriginalWeight(const std::string &keep_weight, micro::MicroParam *micro_param) {
95   MS_LOG(DEBUG) << "Micro enables : " << keep_weight;
96   if (keep_weight.empty()) {
97     return RET_OK;
98   }
99   micro_param->keep_original_weight = false;  // default
100   bool is_keep_original_weight;
101   if (ConvertBool(keep_weight, &is_keep_original_weight)) {
102     micro_param->keep_original_weight = is_keep_original_weight;
103   } else {
104     MS_LOG(ERROR) << "Micro param invalid, keep_original_weight can only be set as true or false.";
105     return RET_INPUT_PARAM_INVALID;
106   }
107   return RET_OK;
108 }
109 
ParseChangeableWeightsName(const std::string & changeable_weights_name,micro::MicroParam * micro_param)110 STATUS MicroParamParser::ParseChangeableWeightsName(const std::string &changeable_weights_name,
111                                                     micro::MicroParam *micro_param) {
112   MS_LOG(DEBUG) << "Micro record changeable weights name: " << changeable_weights_name;
113   if (!changeable_weights_name.empty()) {
114     micro_param->changeable_weights_name = changeable_weights_name;
115   }
116   return RET_OK;
117 }
118 
ParseGraphInputsShapeTemplate(const std::string & graph_inputs_shape_template,const std::map<std::string,std::vector<int>> & dynamic_symbols_map,micro::MicroParam * micro_param)119 STATUS MicroParamParser::ParseGraphInputsShapeTemplate(
120   const std::string &graph_inputs_shape_template, const std::map<std::string, std::vector<int>> &dynamic_symbols_map,
121   micro::MicroParam *micro_param) {
122   MS_LOG(DEBUG) << "Micro record inputs shape: " << graph_inputs_shape_template;
123   if (!graph_inputs_shape_template.empty()) {
124     auto graph_inputs_shape_vec = SplitStringToVector(graph_inputs_shape_template, ';');
125     std::map<std::string, std::vector<std::string>> graph_inputs_info;
126     std::vector<std::vector<std::string>> graph_inputs_shape;
127     std::vector<std::string> inputs_name;
128     for (const auto &graph_input_shape : graph_inputs_shape_vec) {
129       auto input_shape_info = SplitStringToVector(graph_input_shape, ':');
130       std::string input_name = input_shape_info[0];
131       std::string input_shape = input_shape_info[1].substr(1, input_shape_info[1].size() - C2NUM);
132       auto input_shape_vec = SplitStringToVector(input_shape, ',');
133       graph_inputs_info[input_name] = input_shape_vec;
134       graph_inputs_shape.push_back(input_shape_vec);
135       inputs_name.push_back(input_name);
136     }
137     micro_param->graph_inputs_origin_info = graph_inputs_info;
138     micro_param->inputs_shape_by_scenes.clear();
139     std::map<std::string, std::vector<int>> symbols_to_num;
140     std::map<std::string, int> symbols_index;
141     std::vector<std::string> symbols;
142     std::vector<size_t> scene_num_by_symbol;
143     int index = 0;
144     size_t scene_num = 1;
145     for (const auto &item : dynamic_symbols_map) {
146       symbols_index[item.first] = index++;
147       symbols.push_back(item.first);
148       for (const auto &num : item.second) {
149         symbols_to_num[item.first].push_back(num);
150       }
151       if (symbols_to_num[item.first].empty()) {
152         MS_LOG(ERROR) << "Micro param invalid, dynamic symbol must have value.";
153         return RET_INPUT_PARAM_INVALID;
154       }
155       scene_num_by_symbol.push_back(symbols_to_num[item.first].size());
156       scene_num *= symbols_to_num[item.first].size();
157     }
158     micro_param->dynamic_symbols = symbols;
159     micro_param->dynamic_symbols_num = scene_num_by_symbol;
160     micro_param->dynamic_symbols_map = dynamic_symbols_map;
161     std::vector<size_t> post_multi(symbols.size(), 1);
162     for (int i = static_cast<int>(post_multi.size()) - 2; i >= 0; --i) {
163       post_multi[i] = post_multi[i + 1] * scene_num_by_symbol[i + 1];
164     }
165     std::vector<int> real_num(symbols.size());
166     for (size_t i = 0; i < scene_num; ++i) {
167       size_t remain = i;
168       for (size_t j = 0; j < symbols.size(); ++j) {
169         real_num[j] = remain / post_multi[j];
170         remain %= post_multi[j];
171       }
172       for (size_t j = 0; j < graph_inputs_shape.size(); ++j) {
173         const auto &input_template = graph_inputs_shape[j];
174         std::vector<int> input_shape;
175         for (const auto &dim : input_template) {
176           if (IsNumber(dim)) {
177             input_shape.push_back(std::stoi(dim));
178             continue;
179           }
180           if (symbols_index.find(dim) == symbols_index.end()) {
181             MS_LOG(ERROR) << "Dynamic symbol cannot find real num.";
182             return RET_INPUT_PARAM_INVALID;
183           }
184           input_shape.push_back(symbols_to_num[dim][real_num[symbols_index[dim]]]);
185         }
186         micro_param->inputs_shape_by_scenes[inputs_name[j]].push_back(input_shape);
187       }
188     }
189   }
190   return RET_OK;
191 }
192 
ParseMicroParam(const MicroParamString & micro_param_string,micro::MicroParam * micro_param)193 STATUS MicroParamParser::ParseMicroParam(const MicroParamString &micro_param_string, micro::MicroParam *micro_param) {
194   CHECK_NULL_RETURN(micro_param);
195   if (ParseTarget(micro_param_string.target, micro_param) != RET_OK) {
196     MS_LOG(ERROR) << "Parse HW target val: " << micro_param_string.target;
197     return RET_INPUT_PARAM_INVALID;
198   }
199   if (ParseCodeGenMode(micro_param_string.codegen_mode, micro_param) != RET_OK) {
200     MS_LOG(ERROR) << "Parse codegen_mode val; " << micro_param_string.codegen_mode;
201     return RET_INPUT_PARAM_INVALID;
202   }
203   if (ParseSupportParallel(micro_param_string.support_parallel, micro_param) != RET_OK) {
204     MS_LOG(ERROR) << "Parse support_parallel val; " << micro_param_string.support_parallel;
205     return RET_INPUT_PARAM_INVALID;
206   }
207   if (ParseDebugMode(micro_param_string.debug_mode, micro_param) != RET_OK) {
208     MS_LOG(ERROR) << "Parse debug mode val; " << micro_param_string.debug_mode;
209     return RET_INPUT_PARAM_INVALID;
210   }
211   if (ParseEnableMicro(micro_param_string.enable_micro, micro_param) != RET_OK) {
212     MS_LOG(ERROR) << "Parse enable micro val; " << micro_param_string.enable_micro;
213     return RET_INPUT_PARAM_INVALID;
214   }
215   if (ParseSavePath(micro_param_string.save_path, micro_param) != RET_OK) {
216     MS_LOG(ERROR) << "Parse save path val failed: " << micro_param_string.save_path;
217     return RET_INPUT_PARAM_INVALID;
218   }
219   if (ParseProjName(micro_param_string.project_name, micro_param) != RET_OK) {
220     MS_LOG(ERROR) << "Parse project name val failed: " << micro_param_string.project_name;
221     return RET_INPUT_PARAM_INVALID;
222   }
223   if (!micro_param_string.keep_original_weight.empty()) {
224     if (ParseKeepOriginalWeight(micro_param_string.keep_original_weight, micro_param) != RET_OK) {
225       MS_LOG(ERROR) << "Parse keep_original_weight failed, the val: " << micro_param_string.keep_original_weight;
226       return RET_INPUT_PARAM_INVALID;
227     }
228   }
229   if (!micro_param_string.changeable_weights_name.empty() && !micro_param->keep_original_weight) {
230     MS_LOG(ERROR) << "When changeable_weights_name is set, the keep_original_weight must be true.";
231     return RET_INPUT_PARAM_INVALID;
232   }
233   if (ParseChangeableWeightsName(micro_param_string.changeable_weights_name, micro_param) != RET_OK) {
234     MS_LOG(ERROR) << "Parse changeable_weights_name failed, the val: " << micro_param_string.changeable_weights_name;
235     return RET_INPUT_PARAM_INVALID;
236   }
237   if (ParseGraphInputsShapeTemplate(micro_param_string.inputs_shape, micro_param_string.dynamic_symbols_map,
238                                     micro_param) != RET_OK) {
239     MS_LOG(ERROR) << "Parse inputs_shape & dynamic_dim_params failed, the inputs_shape val: "
240                   << micro_param_string.inputs_shape;
241     return RET_INPUT_PARAM_INVALID;
242   }
243   return RET_OK;
244 }
245 }  // namespace lite
246 }  // namespace mindspore
247