1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/config_parser/quant_param_parser.h"
18 #include <vector>
19 #include "src/common/log_adapter.h"
20 #include "mindspore/lite/tools/common/string_util.h"
21 #include "include/errorcode.h"
22
23 namespace mindspore {
24 namespace lite {
namespace {
// Inclusive bounds used when validating min_quant_weight_size and min_quant_weight_channel;
// kMaxSize is also the upper bound accepted for max_iterations.
constexpr int kMinSize = 0;
constexpr int kMaxSize = 65535;

// Bit widths referenced by the bit_num checks: 16 is the largest width accepted for weight
// quantization, 8 is the only width accepted for full and dynamic quantization.
constexpr int kQuantBitNumInt8 = 8;
constexpr int kQuantBitNumInt16 = 16;
}  // namespace
ParseFilter(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)31 int QuantParamParser::ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant) {
32 MS_ASSERT(common_quant != nullptr);
33 if (!common_quant_string.min_quant_weight_size.empty()) {
34 if (!ConvertIntNum(common_quant_string.min_quant_weight_size, &common_quant->min_quant_weight_size)) {
35 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be a valid number.";
36 return RET_INPUT_PARAM_INVALID;
37 }
38 if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
39 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
40 return RET_INPUT_PARAM_INVALID;
41 }
42 }
43 if (!common_quant_string.min_quant_weight_channel.empty()) {
44 if (!ConvertIntNum(common_quant_string.min_quant_weight_channel, &common_quant->min_quant_weight_channel)) {
45 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
46 return RET_INPUT_PARAM_INVALID;
47 }
48
49 if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
50 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be in the range [0,65535]." << std::endl;
51 return RET_INPUT_PARAM_INVALID;
52 }
53 }
54 if (!common_quant_string.skip_quant_node.empty()) {
55 std::vector<std::string> nodes = SplitStringToVector(common_quant_string.skip_quant_node, ',');
56 for (const auto &node : nodes) {
57 common_quant->skip_quant_node.insert(node);
58 }
59 }
60 return RET_OK;
61 }
62
ParseBitNum(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)63 int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant) {
64 if (!common_quant_string.bit_num.empty() && !ConvertIntNum(common_quant_string.bit_num, &common_quant->bit_num)) {
65 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be a valid number.";
66 return RET_INPUT_PARAM_INVALID;
67 }
68 if (common_quant->quant_type == quant::QUANT_WEIGHT) {
69 if (common_quant->bit_num < 0 || common_quant->bit_num > kQuantBitNumInt16) {
70 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [0,16].";
71 return RET_INPUT_PARAM_INVALID;
72 }
73 } else if (common_quant->quant_type == quant::QUANT_ALL) {
74 if (common_quant->bit_num != kQuantBitNumInt8) {
75 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
76 return RET_INPUT_PARAM_INVALID;
77 }
78 } else if (common_quant->quant_type == quant::QUANT_DYNAMIC) {
79 if (common_quant->bit_num != kQuantBitNumInt8) {
80 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be 8.";
81 return RET_INPUT_PARAM_INVALID;
82 }
83 }
84 return RET_OK;
85 }
86
ParseEnableEncode(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)87 int QuantParamParser::ParseEnableEncode(const CommonQuantString &common_quant_string,
88 quant::CommonQuantParam *common_quant) {
89 if (!common_quant_string.enable_encode.empty() &&
90 !ConvertBool(common_quant_string.enable_encode, &common_quant->enable_encode)) {
91 MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true or false.";
92 return RET_INPUT_PARAM_INVALID;
93 }
94 if (common_quant->quant_type == quant::QUANT_WEIGHT &&
95 (common_quant->bit_num != kQuantBitNumInt8 && common_quant->bit_num != kQuantBitNumInt16)) {
96 if (!common_quant->enable_encode) {
97 MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true when parameter bit_num belongs to [0,7] or [9,15].";
98 return RET_INPUT_PARAM_INVALID;
99 }
100 }
101 return RET_OK;
102 }
103
ParseCommonQuant(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)104 int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
105 quant::CommonQuantParam *common_quant) {
106 MS_ASSERT(common_quant != nullptr);
107 if (!common_quant_string.quant_type.empty()) {
108 auto ret = ParseQuantType(common_quant_string.quant_type, &common_quant->quant_type);
109 if (ret != RET_OK) {
110 MS_LOG(ERROR) << "Parse quant_type failed.";
111 return ret;
112 }
113 }
114
115 auto ret = ParseBitNum(common_quant_string, common_quant);
116 if (ret != RET_OK) {
117 MS_LOG(ERROR) << "Parse bit num failed.";
118 return ret;
119 }
120
121 ret = ParseEnableEncode(common_quant_string, common_quant);
122 if (ret != RET_OK) {
123 MS_LOG(ERROR) << "Parse enable encode failed.";
124 return ret;
125 }
126
127 ret = ParseFilter(common_quant_string, common_quant);
128 if (ret != RET_OK) {
129 MS_LOG(ERROR) << "Parse filter failed.";
130 return ret;
131 }
132
133 common_quant->debug_info_save_path = common_quant_string.debug_info_save_path;
134 if (!common_quant->debug_info_save_path.empty()) {
135 common_quant->is_debug = true;
136 }
137
138 // this is required only for model larger than 2G
139 common_quant->workspace = common_quant_string.workspace;
140 return RET_OK;
141 }
142
ParseMixedBitWeightQuant(const MixedBitWeightQuantString & mixed_bit_weight_quant_string,quant::MixedBitWeightQuantParam * mixed_bit_weight_quant)143 int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string,
144 quant::MixedBitWeightQuantParam *mixed_bit_weight_quant) {
145 if (!mixed_bit_weight_quant_string.init_scale.empty() &&
146 !ConvertDoubleNum(mixed_bit_weight_quant_string.init_scale, &mixed_bit_weight_quant->init_scale)) {
147 MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should be a valid number.";
148 return RET_INPUT_PARAM_INVALID;
149 }
150 if (mixed_bit_weight_quant->init_scale <= 0 || mixed_bit_weight_quant->init_scale >= 1) {
151 MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should at (0,1)";
152 return RET_INPUT_PARAM_INVALID;
153 }
154 if (!mixed_bit_weight_quant_string.use_cv_data.empty() &&
155 !ConvertBool(mixed_bit_weight_quant_string.use_cv_data, &mixed_bit_weight_quant->use_cv_data)) {
156 MS_LOG(ERROR) << "INPUT ILLEGAL: use_cv_data should be true or false.";
157 return RET_INPUT_PARAM_INVALID;
158 }
159 if (!mixed_bit_weight_quant_string.max_iterations.empty()) {
160 if (!ConvertIntNum(mixed_bit_weight_quant_string.max_iterations, &mixed_bit_weight_quant->max_iterations)) {
161 MS_LOG(ERROR) << "INPUT ILLEGAL: max_iterations should be a valid number.";
162 return RET_INPUT_PARAM_INVALID;
163 }
164 if (mixed_bit_weight_quant->max_iterations < quant::kMinIterations ||
165 mixed_bit_weight_quant->max_iterations > kMaxSize) {
166 MS_LOG(ERROR) << "INPUT ILLEGAL: max_iterations should be in the range [40,65535]." << std::endl;
167 return RET_INPUT_PARAM_INVALID;
168 }
169 }
170
171 if (!mixed_bit_weight_quant_string.auto_tune.empty() &&
172 !ConvertBool(mixed_bit_weight_quant_string.auto_tune, &mixed_bit_weight_quant->auto_tune)) {
173 MS_LOG(ERROR) << "INPUT ILLEGAL: auto_tune should be true or false.";
174 return RET_INPUT_PARAM_INVALID;
175 }
176 return RET_OK;
177 }
178
ParseFullQuant(const FullQuantString & full_quant_string,quant::FullQuantParam * full_quant)179 int QuantParamParser::ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant) {
180 if (!full_quant_string.activation_quant_method.empty() &&
181 ParseActivationQuantizedMethod(full_quant_string.activation_quant_method, &full_quant->activation_quant_method) !=
182 RET_OK) {
183 MS_LOG(ERROR) << "INPUT ILLEGAL: Parse activation_quant_method failed.";
184 return RET_INPUT_PARAM_INVALID;
185 }
186 if (!full_quant_string.bias_correction.empty() &&
187 !ConvertBool(full_quant_string.bias_correction, &full_quant->bias_correction)) {
188 MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
189 return RET_INPUT_PARAM_INVALID;
190 }
191 if (!full_quant_string.target_device.empty()) {
192 auto ret = ParseTargetDevice(full_quant_string.target_device, &full_quant->target_device);
193 if (ret != RET_OK) {
194 MS_LOG(ERROR) << "Parse device failed.";
195 return ret;
196 }
197 }
198 if (!full_quant_string.per_channel.empty() && !ConvertBool(full_quant_string.per_channel, &full_quant->per_channel)) {
199 MS_LOG(ERROR) << "INPUT ILLEGAL: per_channel should be true or false.";
200 return RET_INPUT_PARAM_INVALID;
201 }
202 if (!full_quant_string.smooth_alpha.empty() &&
203 !ConvertDoubleNum(full_quant_string.smooth_alpha, &full_quant->smooth_alpha)) {
204 MS_LOG(ERROR) << "INPUT ILLEGAL: smooth_alpha should be a valid number.";
205 return RET_INPUT_PARAM_INVALID;
206 }
207 if (!full_quant_string.enable_smooth_shift.empty() &&
208 !ConvertBool(full_quant_string.enable_smooth_shift, &full_quant->enable_smooth_shift)) {
209 MS_LOG(ERROR) << "INPUT ILLEGAL: enable_smooth_shift should be true or false.";
210 return RET_INPUT_PARAM_INVALID;
211 }
212 return RET_OK;
213 }
214
ParseQuantType(const std::string & quant_type_str,quant::QuantType * quant_type)215 int QuantParamParser::ParseQuantType(const std::string &quant_type_str, quant::QuantType *quant_type) {
216 if (quant_type_str == "WEIGHT_QUANT") {
217 (*quant_type) = quant::QUANT_WEIGHT;
218 return RET_OK;
219 } else if (quant_type_str == "FULL_QUANT") {
220 (*quant_type) = quant::QUANT_ALL;
221 return RET_OK;
222 } else if (quant_type_str == "DYNAMIC_QUANT") {
223 (*quant_type) = quant::QUANT_DYNAMIC;
224 return RET_OK;
225 } else if (quant_type_str.empty()) {
226 (*quant_type) = quant::QUANT_NONE;
227 return RET_OK;
228 } else {
229 MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT|DYNAMIC_QUANT.";
230 return RET_INPUT_PARAM_INVALID;
231 }
232 }
233
ParseTargetDevice(const std::string & target_device_str,quant::TargetDevice * target_device)234 int QuantParamParser::ParseTargetDevice(const std::string &target_device_str, quant::TargetDevice *target_device) {
235 if (target_device_str == "KIRIN") {
236 (*target_device) = quant::KIRIN;
237 return RET_OK;
238 } else if (target_device_str == "NVGPU") {
239 (*target_device) = quant::NVGPU;
240 return RET_OK;
241 } else if (target_device_str == "DSP") {
242 (*target_device) = quant::DSP;
243 return RET_OK;
244 } else if (target_device_str == "ASCEND") {
245 (*target_device) = quant::ASCEND;
246 return RET_OK;
247 } else {
248 MS_LOG(ERROR) << "INPUT ILLEGAL: target_device must be KIRIN|NVGPU|DSP|ASCEND.";
249 return RET_INPUT_PARAM_INVALID;
250 }
251 }
252
ParseActivationQuantizedMethod(const std::string & activation_quant_method_str,quant::ActivationQuantizedMethod * activation_quant_method)253 int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activation_quant_method_str,
254 quant::ActivationQuantizedMethod *activation_quant_method) {
255 if (activation_quant_method_str == "MAX_MIN") {
256 (*activation_quant_method) = quant::MAX_MIN;
257 return RET_OK;
258 } else if (activation_quant_method_str == "KL") {
259 (*activation_quant_method) = quant::KL;
260 return RET_OK;
261 } else if (activation_quant_method_str == "REMOVAL_OUTLIER") {
262 (*activation_quant_method) = quant::REMOVAL_OUTLIER;
263 return RET_OK;
264 } else {
265 MS_LOG(ERROR) << "INPUT ILLEGAL: activation_quant_method must be MAX_MIN|KL|REMOVAL_OUTLIER.";
266 return RET_INPUT_PARAM_INVALID;
267 }
268 }
269
ParseWeightQuant(const WeightQuantString & weight_quant_string,quant::WeightQuantParam * weight_quant)270 int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_string,
271 quant::WeightQuantParam *weight_quant) {
272 if (!weight_quant_string.dequant_strategy.empty()) {
273 if (weight_quant_string.dequant_strategy == "ON_THE_FLY") {
274 weight_quant->dequant_strategy = quant::ON_THE_FLY;
275 } else {
276 MS_LOG(ERROR) << "INPUT ILLEGAL: dequant_strategy must be ON_THE_FLY.";
277 return RET_INPUT_PARAM_INVALID;
278 }
279 }
280 if (!weight_quant_string.quant_strategy.empty()) {
281 if (weight_quant_string.quant_strategy == "GPTQ") {
282 weight_quant->quant_strategy = quant::GPTQ_ALGORITHM;
283 } else if (weight_quant_string.quant_strategy == "MAX_MIN") {
284 weight_quant->quant_strategy = quant::MAX_MIN_ALGORITHM;
285 } else {
286 MS_LOG(ERROR) << "INPUT ILLEGAL: quant_strategy must be GPTQ or MAX_MIN.";
287 return RET_INPUT_PARAM_INVALID;
288 }
289 }
290 if (!weight_quant_string.update_mindir.empty() &&
291 !ConvertBool(weight_quant_string.update_mindir, &weight_quant->update_mindir)) {
292 MS_LOG(ERROR) << "INPUT ILLEGAL: update_mindir should be true or false.";
293 return RET_INPUT_PARAM_INVALID;
294 }
295 if (!weight_quant_string.max_segments.empty() &&
296 !ConvertIntNum(weight_quant_string.max_segments, &weight_quant->max_segments)) {
297 MS_LOG(ERROR) << "INPUT ILLEGAL: decode_threads should be a number.";
298 return RET_INPUT_PARAM_INVALID;
299 }
300 if (!weight_quant_string.per_channel.empty() &&
301 !ConvertBool(weight_quant_string.per_channel, &weight_quant->per_channel)) {
302 MS_LOG(ERROR) << "INPUT ILLEGAL: per_channel should be true or false.";
303 return RET_INPUT_PARAM_INVALID;
304 }
305 if (!weight_quant_string.bias_correction.empty() &&
306 !ConvertBool(weight_quant_string.bias_correction, &weight_quant->bias_correction)) {
307 MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
308 return RET_INPUT_PARAM_INVALID;
309 }
310 return RET_OK;
311 }
312
ParseExportPrecisionMode(const std::string & precision_modeL_str,quant::PrecisionMode * precision_mode)313 int QuantParamParser::ParseExportPrecisionMode(const std::string &precision_modeL_str,
314 quant::PrecisionMode *precision_mode) {
315 if (precision_modeL_str == "QUANT") {
316 (*precision_mode) = quant::QUANT;
317 return RET_OK;
318 } else if (precision_modeL_str == "FLOAT32") {
319 (*precision_mode) = quant::FLOAT32;
320 return RET_OK;
321 } else {
322 MS_LOG(ERROR) << "INPUT ILLEGAL: export_precision_mode must be QUANT or FLOAT32.";
323 return RET_INPUT_PARAM_INVALID;
324 }
325 }
326
ParseTransformQuant(const TransformQuantString & transform_quant_string,quant::TransformQuantParam * transform_quant)327 int QuantParamParser::ParseTransformQuant(const TransformQuantString &transform_quant_string,
328 quant::TransformQuantParam *transform_quant) {
329 if (!transform_quant_string.export_precision_mode.empty()) {
330 auto ret = ParseExportPrecisionMode(transform_quant_string.export_precision_mode, &transform_quant->precision_mode);
331 if (ret != RET_OK) {
332 MS_LOG(ERROR) << "Parse precision mode failed.";
333 return ret;
334 }
335 }
336 return RET_OK;
337 }
338
ParseDynamicQuant(const DynamicQuantString & dynamic_quant_string,quant::DynamicQuantParam * dynamic_quant)339 int QuantParamParser::ParseDynamicQuant(const DynamicQuantString &dynamic_quant_string,
340 quant::DynamicQuantParam *dynamic_quant) {
341 if (!dynamic_quant_string.quant_strategy.empty()) {
342 auto ret = ParseDynamicQuantStrategy(dynamic_quant_string.quant_strategy, &dynamic_quant->quant_strategy);
343 if (ret != RET_OK) {
344 MS_LOG(ERROR) << "Parse dynamic quant strategy failed.";
345 return ret;
346 }
347 }
348 return RET_OK;
349 }
350
ParseDynamicQuantStrategy(const std::string & dynamic_quant_strategy_str,quant::DynamicQuantStrategy * dynamic_strategy)351 int QuantParamParser::ParseDynamicQuantStrategy(const std::string &dynamic_quant_strategy_str,
352 quant::DynamicQuantStrategy *dynamic_strategy) {
353 if (dynamic_quant_strategy_str == "ALWC") {
354 (*dynamic_strategy) = quant::ACTIVATION_LAYER_WEIGHT_CHANNEL;
355 } else if (dynamic_quant_strategy_str == "ACWL") {
356 (*dynamic_strategy) = quant::ACTIVATION_CHANNEL_WEIGHT_LAYER;
357 } else {
358 MS_LOG(ERROR) << "INPUT ILLEGAL: dynamic_quant_strategy must be ALWC or ACWL.";
359 return RET_INPUT_PARAM_INVALID;
360 }
361 return RET_OK;
362 }
363 } // namespace lite
364 } // namespace mindspore
365