• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/config_parser/quant_param_parser.h"
18 #include <vector>
19 #include "src/common/log_adapter.h"
20 #include "mindspore/lite/tools/common/string_util.h"
21 #include "include/errorcode.h"
22 
23 namespace mindspore {
24 namespace lite {
namespace {
// Bit widths that the bit_num option is compared against.
constexpr int kQuantBitNumInt8{8};
constexpr int kQuantBitNumInt16{16};
// Inclusive lower/upper bounds used when range-checking numeric options.
constexpr int kMinSize{0};
constexpr int kMaxSize{65535};
}  // namespace
ParseFilter(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)31 int QuantParamParser::ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant) {
32   MS_ASSERT(common_quant != nullptr);
33   if (!common_quant_string.min_quant_weight_size.empty()) {
34     if (!ConvertIntNum(common_quant_string.min_quant_weight_size, &common_quant->min_quant_weight_size)) {
35       MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be a valid number.";
36       return RET_INPUT_PARAM_INVALID;
37     }
38     if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
39       MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
40       return RET_INPUT_PARAM_INVALID;
41     }
42   }
43   if (!common_quant_string.min_quant_weight_channel.empty()) {
44     if (!ConvertIntNum(common_quant_string.min_quant_weight_channel, &common_quant->min_quant_weight_channel)) {
45       MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
46       return RET_INPUT_PARAM_INVALID;
47     }
48 
49     if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
50       MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be in the range [0,65535]." << std::endl;
51       return RET_INPUT_PARAM_INVALID;
52     }
53   }
54   if (!common_quant_string.skip_quant_node.empty()) {
55     std::vector<std::string> nodes = SplitStringToVector(common_quant_string.skip_quant_node, ',');
56     for (const auto &node : nodes) {
57       common_quant->skip_quant_node.insert(node);
58     }
59   }
60   return RET_OK;
61 }
62 
ParseBitNum(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)63 int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant) {
64   if (!common_quant_string.bit_num.empty() && !ConvertIntNum(common_quant_string.bit_num, &common_quant->bit_num)) {
65     MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be a valid number.";
66     return RET_INPUT_PARAM_INVALID;
67   }
68   if (common_quant->quant_type == quant::QUANT_WEIGHT) {
69     if (common_quant->bit_num < 0 || common_quant->bit_num > kQuantBitNumInt16) {
70       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [0,16].";
71       return RET_INPUT_PARAM_INVALID;
72     }
73   } else if (common_quant->quant_type == quant::QUANT_ALL) {
74     if (common_quant->bit_num != kQuantBitNumInt8) {
75       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
76       return RET_INPUT_PARAM_INVALID;
77     }
78   } else if (common_quant->quant_type == quant::QUANT_DYNAMIC) {
79     if (common_quant->bit_num != kQuantBitNumInt8) {
80       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
81       return RET_INPUT_PARAM_INVALID;
82     }
83   }
84   return RET_OK;
85 }
86 
ParseEnableEncode(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)87 int QuantParamParser::ParseEnableEncode(const CommonQuantString &common_quant_string,
88                                         quant::CommonQuantParam *common_quant) {
89   if (!common_quant_string.enable_encode.empty() &&
90       !ConvertBool(common_quant_string.enable_encode, &common_quant->enable_encode)) {
91     MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true or false.";
92     return RET_INPUT_PARAM_INVALID;
93   }
94   if (common_quant->quant_type == quant::QUANT_WEIGHT &&
95       (common_quant->bit_num != kQuantBitNumInt8 && common_quant->bit_num != kQuantBitNumInt16)) {
96     if (!common_quant->enable_encode) {
97       MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true when parameter bit_num belongs to [0,7] or [9,15].";
98       return RET_INPUT_PARAM_INVALID;
99     }
100   }
101   return RET_OK;
102 }
103 
ParseCommonQuant(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)104 int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
105                                        quant::CommonQuantParam *common_quant) {
106   MS_ASSERT(common_quant != nullptr);
107   if (!common_quant_string.quant_type.empty()) {
108     auto ret = ParseQuantType(common_quant_string.quant_type, &common_quant->quant_type);
109     if (ret != RET_OK) {
110       MS_LOG(ERROR) << "Parse quant_type failed.";
111       return ret;
112     }
113   }
114 
115   auto ret = ParseBitNum(common_quant_string, common_quant);
116   if (ret != RET_OK) {
117     MS_LOG(ERROR) << "Parse bit num failed.";
118     return ret;
119   }
120 
121   ret = ParseEnableEncode(common_quant_string, common_quant);
122   if (ret != RET_OK) {
123     MS_LOG(ERROR) << "Parse enable encode failed.";
124     return ret;
125   }
126 
127   ret = ParseFilter(common_quant_string, common_quant);
128   if (ret != RET_OK) {
129     MS_LOG(ERROR) << "Parse filter failed.";
130     return ret;
131   }
132 
133   common_quant->debug_info_save_path = common_quant_string.debug_info_save_path;
134   if (!common_quant->debug_info_save_path.empty()) {
135     common_quant->is_debug = true;
136   }
137 
138   // this is required only for model larger than 2G
139   common_quant->workspace = common_quant_string.workspace;
140   return RET_OK;
141 }
142 
ParseMixedBitWeightQuant(const MixedBitWeightQuantString & mixed_bit_weight_quant_string,quant::MixedBitWeightQuantParam * mixed_bit_weight_quant)143 int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string,
144                                                quant::MixedBitWeightQuantParam *mixed_bit_weight_quant) {
145   if (!mixed_bit_weight_quant_string.init_scale.empty() &&
146       !ConvertDoubleNum(mixed_bit_weight_quant_string.init_scale, &mixed_bit_weight_quant->init_scale)) {
147     MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should be a valid number.";
148     return RET_INPUT_PARAM_INVALID;
149   }
150   if (mixed_bit_weight_quant->init_scale <= 0 || mixed_bit_weight_quant->init_scale >= 1) {
151     MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should at (0,1)";
152     return RET_INPUT_PARAM_INVALID;
153   }
154   if (!mixed_bit_weight_quant_string.use_cv_data.empty() &&
155       !ConvertBool(mixed_bit_weight_quant_string.use_cv_data, &mixed_bit_weight_quant->use_cv_data)) {
156     MS_LOG(ERROR) << "INPUT ILLEGAL: use_cv_data should be true or false.";
157     return RET_INPUT_PARAM_INVALID;
158   }
159   if (!mixed_bit_weight_quant_string.max_iterations.empty()) {
160     if (!ConvertIntNum(mixed_bit_weight_quant_string.max_iterations, &mixed_bit_weight_quant->max_iterations)) {
161       MS_LOG(ERROR) << "INPUT ILLEGAL: max_iterations should be a valid number.";
162       return RET_INPUT_PARAM_INVALID;
163     }
164     if (mixed_bit_weight_quant->max_iterations < quant::kMinIterations ||
165         mixed_bit_weight_quant->max_iterations > kMaxSize) {
166       MS_LOG(ERROR) << "INPUT ILLEGAL: max_iterations should be in the range [40,65535]." << std::endl;
167       return RET_INPUT_PARAM_INVALID;
168     }
169   }
170 
171   if (!mixed_bit_weight_quant_string.auto_tune.empty() &&
172       !ConvertBool(mixed_bit_weight_quant_string.auto_tune, &mixed_bit_weight_quant->auto_tune)) {
173     MS_LOG(ERROR) << "INPUT ILLEGAL: auto_tune should be true or false.";
174     return RET_INPUT_PARAM_INVALID;
175   }
176   return RET_OK;
177 }
178 
ParseFullQuant(const FullQuantString & full_quant_string,quant::FullQuantParam * full_quant)179 int QuantParamParser::ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant) {
180   if (!full_quant_string.activation_quant_method.empty() &&
181       ParseActivationQuantizedMethod(full_quant_string.activation_quant_method, &full_quant->activation_quant_method) !=
182         RET_OK) {
183     MS_LOG(ERROR) << "INPUT ILLEGAL: Parse activation_quant_method failed.";
184     return RET_INPUT_PARAM_INVALID;
185   }
186   if (!full_quant_string.bias_correction.empty() &&
187       !ConvertBool(full_quant_string.bias_correction, &full_quant->bias_correction)) {
188     MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
189     return RET_INPUT_PARAM_INVALID;
190   }
191   if (!full_quant_string.target_device.empty()) {
192     auto ret = ParseTargetDevice(full_quant_string.target_device, &full_quant->target_device);
193     if (ret != RET_OK) {
194       MS_LOG(ERROR) << "Parse device failed.";
195       return ret;
196     }
197   }
198   if (!full_quant_string.per_channel.empty() && !ConvertBool(full_quant_string.per_channel, &full_quant->per_channel)) {
199     MS_LOG(ERROR) << "INPUT ILLEGAL: per_channel should be true or false.";
200     return RET_INPUT_PARAM_INVALID;
201   }
202   if (!full_quant_string.smooth_alpha.empty() &&
203       !ConvertDoubleNum(full_quant_string.smooth_alpha, &full_quant->smooth_alpha)) {
204     MS_LOG(ERROR) << "INPUT ILLEGAL: smooth_alpha should be a valid number.";
205     return RET_INPUT_PARAM_INVALID;
206   }
207   if (!full_quant_string.enable_smooth_shift.empty() &&
208       !ConvertBool(full_quant_string.enable_smooth_shift, &full_quant->enable_smooth_shift)) {
209     MS_LOG(ERROR) << "INPUT ILLEGAL: enable_smooth_shift should be true or false.";
210     return RET_INPUT_PARAM_INVALID;
211   }
212   return RET_OK;
213 }
214 
ParseQuantType(const std::string & quant_type_str,quant::QuantType * quant_type)215 int QuantParamParser::ParseQuantType(const std::string &quant_type_str, quant::QuantType *quant_type) {
216   if (quant_type_str == "WEIGHT_QUANT") {
217     (*quant_type) = quant::QUANT_WEIGHT;
218     return RET_OK;
219   } else if (quant_type_str == "FULL_QUANT") {
220     (*quant_type) = quant::QUANT_ALL;
221     return RET_OK;
222   } else if (quant_type_str == "DYNAMIC_QUANT") {
223     (*quant_type) = quant::QUANT_DYNAMIC;
224     return RET_OK;
225   } else if (quant_type_str.empty()) {
226     (*quant_type) = quant::QUANT_NONE;
227     return RET_OK;
228   } else {
229     MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT|DYNAMIC_QUANT.";
230     return RET_INPUT_PARAM_INVALID;
231   }
232 }
233 
ParseTargetDevice(const std::string & target_device_str,quant::TargetDevice * target_device)234 int QuantParamParser::ParseTargetDevice(const std::string &target_device_str, quant::TargetDevice *target_device) {
235   if (target_device_str == "KIRIN") {
236     (*target_device) = quant::KIRIN;
237     return RET_OK;
238   } else if (target_device_str == "NVGPU") {
239     (*target_device) = quant::NVGPU;
240     return RET_OK;
241   } else if (target_device_str == "DSP") {
242     (*target_device) = quant::DSP;
243     return RET_OK;
244   } else if (target_device_str == "ASCEND") {
245     (*target_device) = quant::ASCEND;
246     return RET_OK;
247   } else {
248     MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN|NVGPU|DSP|ASCEND.";
249     return RET_INPUT_PARAM_INVALID;
250   }
251 }
252 
ParseActivationQuantizedMethod(const std::string & activation_quant_method_str,quant::ActivationQuantizedMethod * activation_quant_method)253 int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activation_quant_method_str,
254                                                      quant::ActivationQuantizedMethod *activation_quant_method) {
255   if (activation_quant_method_str == "MAX_MIN") {
256     (*activation_quant_method) = quant::MAX_MIN;
257     return RET_OK;
258   } else if (activation_quant_method_str == "KL") {
259     (*activation_quant_method) = quant::KL;
260     return RET_OK;
261   } else if (activation_quant_method_str == "REMOVAL_OUTLIER") {
262     (*activation_quant_method) = quant::REMOVAL_OUTLIER;
263     return RET_OK;
264   } else {
265     MS_LOG(ERROR) << "INPUT ILLEGAL: activation_quant_method must be MAX_MIN|KL|REMOVAL_OUTLIER.";
266     return RET_INPUT_PARAM_INVALID;
267   }
268 }
269 
ParseWeightQuant(const WeightQuantString & weight_quant_string,quant::WeightQuantParam * weight_quant)270 int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_string,
271                                        quant::WeightQuantParam *weight_quant) {
272   if (!weight_quant_string.dequant_strategy.empty()) {
273     if (weight_quant_string.dequant_strategy == "ON_THE_FLY") {
274       weight_quant->dequant_strategy = quant::ON_THE_FLY;
275     } else {
276       MS_LOG(ERROR) << "INPUT ILLEGAL: dequant_strategy must be ON_THE_FLY.";
277       return RET_INPUT_PARAM_INVALID;
278     }
279   }
280   if (!weight_quant_string.quant_strategy.empty()) {
281     if (weight_quant_string.quant_strategy == "GPTQ") {
282       weight_quant->quant_strategy = quant::GPTQ_ALGORITHM;
283     } else if (weight_quant_string.quant_strategy == "MAX_MIN") {
284       weight_quant->quant_strategy = quant::MAX_MIN_ALGORITHM;
285     } else {
286       MS_LOG(ERROR) << "INPUT ILLEGAL: quant_strategy must be GPTQ or MAX_MIN.";
287       return RET_INPUT_PARAM_INVALID;
288     }
289   }
290   if (!weight_quant_string.update_mindir.empty() &&
291       !ConvertBool(weight_quant_string.update_mindir, &weight_quant->update_mindir)) {
292     MS_LOG(ERROR) << "INPUT ILLEGAL: update_mindir should be true or false.";
293     return RET_INPUT_PARAM_INVALID;
294   }
295   if (!weight_quant_string.max_segments.empty() &&
296       !ConvertIntNum(weight_quant_string.max_segments, &weight_quant->max_segments)) {
297     MS_LOG(ERROR) << "INPUT ILLEGAL: decode_threads should be a number.";
298     return RET_INPUT_PARAM_INVALID;
299   }
300   if (!weight_quant_string.per_channel.empty() &&
301       !ConvertBool(weight_quant_string.per_channel, &weight_quant->per_channel)) {
302     MS_LOG(ERROR) << "INPUT ILLEGAL: per_channel should be true or false.";
303     return RET_INPUT_PARAM_INVALID;
304   }
305   if (!weight_quant_string.bias_correction.empty() &&
306       !ConvertBool(weight_quant_string.bias_correction, &weight_quant->bias_correction)) {
307     MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
308     return RET_INPUT_PARAM_INVALID;
309   }
310   return RET_OK;
311 }
312 
ParseExportPrecisionMode(const std::string & precision_modeL_str,quant::PrecisionMode * precision_mode)313 int QuantParamParser::ParseExportPrecisionMode(const std::string &precision_modeL_str,
314                                                quant::PrecisionMode *precision_mode) {
315   if (precision_modeL_str == "QUANT") {
316     (*precision_mode) = quant::QUANT;
317     return RET_OK;
318   } else if (precision_modeL_str == "FLOAT32") {
319     (*precision_mode) = quant::FLOAT32;
320     return RET_OK;
321   } else {
322     MS_LOG(ERROR) << "INPUT ILLEGAL: export_precision_mode must be QUANT or FLOAT32.";
323     return RET_INPUT_PARAM_INVALID;
324   }
325 }
326 
ParseTransformQuant(const TransformQuantString & transform_quant_string,quant::TransformQuantParam * transform_quant)327 int QuantParamParser::ParseTransformQuant(const TransformQuantString &transform_quant_string,
328                                           quant::TransformQuantParam *transform_quant) {
329   if (!transform_quant_string.export_precision_mode.empty()) {
330     auto ret = ParseExportPrecisionMode(transform_quant_string.export_precision_mode, &transform_quant->precision_mode);
331     if (ret != RET_OK) {
332       MS_LOG(ERROR) << "Parse precision mode failed.";
333       return ret;
334     }
335   }
336   return RET_OK;
337 }
338 
ParseDynamicQuant(const DynamicQuantString & dynamic_quant_string,quant::DynamicQuantParam * dynamic_quant)339 int QuantParamParser::ParseDynamicQuant(const DynamicQuantString &dynamic_quant_string,
340                                         quant::DynamicQuantParam *dynamic_quant) {
341   if (!dynamic_quant_string.quant_strategy.empty()) {
342     auto ret = ParseDynamicQuantStrategy(dynamic_quant_string.quant_strategy, &dynamic_quant->quant_strategy);
343     if (ret != RET_OK) {
344       MS_LOG(ERROR) << "Parse dynamic quant strategy failed.";
345       return ret;
346     }
347   }
348   return RET_OK;
349 }
350 
ParseDynamicQuantStrategy(const std::string & dynamic_quant_strategy_str,quant::DynamicQuantStrategy * dynamic_strategy)351 int QuantParamParser::ParseDynamicQuantStrategy(const std::string &dynamic_quant_strategy_str,
352                                                 quant::DynamicQuantStrategy *dynamic_strategy) {
353   if (dynamic_quant_strategy_str == "ALWC") {
354     (*dynamic_strategy) = quant::ACTIVATION_LAYER_WEIGHT_CHANNEL;
355   } else if (dynamic_quant_strategy_str == "ACWL") {
356     (*dynamic_strategy) = quant::ACTIVATION_CHANNEL_WEIGHT_LAYER;
357   } else {
358     MS_LOG(ERROR) << "INPUT ILLEGAL: dynamic_quant_strategy must be ALWC or ACWL.";
359     return RET_INPUT_PARAM_INVALID;
360   }
361   return RET_OK;
362 }
363 }  // namespace lite
364 }  // namespace mindspore
365