1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/config_parser/quant_param_parser.h"
18 #include "src/common/log_adapter.h"
19 #include "mindspore/lite/tools/common/string_util.h"
20 #include "include/errorcode.h"
21 namespace mindspore {
22 namespace lite {
namespace {
// Largest bit_num accepted by ParseCommonQuant for full quantization (QuantType_QUANT_ALL).
constexpr int kQuantBitNumInt8{8};
// Largest bit_num accepted by ParseCommonQuant for weight-only quantization (QuantType_QUANT_WEIGHT).
constexpr int kQuantBitNumInt16{16};
// Inclusive lower/upper bounds that ParseCommonQuant enforces on both
// min_quant_weight_size and min_quant_weight_channel.
constexpr int kMinSize{0};
constexpr int kMaxSize{65535};
}  // namespace
ParseCommonQuant(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)29 int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
30 quant::CommonQuantParam *common_quant) {
31 if (!common_quant_string.quant_type.empty()) {
32 auto ret = ParseQuantType(common_quant_string.quant_type, &common_quant->quant_type);
33 if (ret != RET_OK) {
34 MS_LOG(ERROR) << "Parse quant_type failed.";
35 return ret;
36 }
37 }
38
39 if (!common_quant_string.bit_num.empty() && !ConvertIntNum(common_quant_string.bit_num, &common_quant->bit_num)) {
40 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be a valid number.";
41 return RET_INPUT_PARAM_INVALID;
42 }
43 if (common_quant->quant_type == schema::QuantType_QUANT_WEIGHT) {
44 if (common_quant->bit_num < 0 || common_quant->bit_num > kQuantBitNumInt16) {
45 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [0,16].";
46 return RET_INPUT_PARAM_INVALID;
47 }
48 } else if (common_quant->quant_type == schema::QuantType_QUANT_ALL) {
49 if (common_quant->bit_num <= 0 || common_quant->bit_num > kQuantBitNumInt8) {
50 MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
51 return RET_INPUT_PARAM_INVALID;
52 }
53 }
54 if (!common_quant_string.min_quant_weight_size.empty() &&
55 !ConvertIntNum(common_quant_string.min_quant_weight_size, &common_quant->min_quant_weight_size)) {
56 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be a valid number.";
57 return RET_INPUT_PARAM_INVALID;
58 }
59 if (!common_quant_string.min_quant_weight_channel.empty() &&
60 !ConvertIntNum(common_quant_string.min_quant_weight_channel, &common_quant->min_quant_weight_channel)) {
61 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
62 return RET_INPUT_PARAM_INVALID;
63 }
64 if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
65 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
66 return RET_INPUT_PARAM_INVALID;
67 }
68
69 if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
70 MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl;
71 return RET_INPUT_PARAM_INVALID;
72 }
73 return RET_OK;
74 }
75
ParseMixedBitWeightQuant(const MixedBitWeightQuantString & mixed_bit_weight_quant_string,quant::MixedBitWeightQuantParam * mixed_bit_weight_quant)76 int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string,
77 quant::MixedBitWeightQuantParam *mixed_bit_weight_quant) {
78 if (mixed_bit_weight_quant_string.init_scale.empty()) {
79 return RET_OK;
80 }
81 if (!ConvertDoubleNum(mixed_bit_weight_quant_string.init_scale, &mixed_bit_weight_quant->init_scale)) {
82 MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should be a valid number.";
83 return RET_INPUT_PARAM_INVALID;
84 }
85 if (mixed_bit_weight_quant->init_scale <= 0 || mixed_bit_weight_quant->init_scale >= 1) {
86 MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should at (0,1)";
87 return RET_INPUT_PARAM_INVALID;
88 }
89 return RET_OK;
90 }
91
ParseFullQuant(const FullQuantString & full_quant_string,quant::FullQuantParam * full_quant)92 int QuantParamParser::ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant) {
93 if (!full_quant_string.activation_quant_method.empty() &&
94 ParseActivationQuantizedMethod(full_quant_string.activation_quant_method, &full_quant->activation_quant_method) !=
95 RET_OK) {
96 MS_LOG(ERROR) << "INPUT ILLEGAL: Parse activation_quant_method failed.";
97 return RET_INPUT_PARAM_INVALID;
98 }
99 if (!full_quant_string.bias_correction.empty() &&
100 !ConvertBool(full_quant_string.bias_correction, &full_quant->bias_correction)) {
101 MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
102 return RET_INPUT_PARAM_INVALID;
103 }
104 return RET_OK;
105 }
106
ParseQuantType(const std::string & quant_type_str,schema::QuantType * quant_type)107 int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::QuantType *quant_type) {
108 if (quant_type_str == "WEIGHT_QUANT") {
109 (*quant_type) = schema::QuantType_QUANT_WEIGHT;
110 return RET_OK;
111 } else if (quant_type_str == "FULL_QUANT") {
112 (*quant_type) = schema::QuantType_QUANT_ALL;
113 return RET_OK;
114 } else if (quant_type_str.empty()) {
115 (*quant_type) = schema::QuantType_QUANT_NONE;
116 return RET_OK;
117 } else {
118 MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT.";
119 return RET_INPUT_PARAM_INVALID;
120 }
121 }
122
ParseActivationQuantizedMethod(const std::string & activation_quant_method_str,quant::ActivationQuantizedMethod * activation_quant_method)123 int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activation_quant_method_str,
124 quant::ActivationQuantizedMethod *activation_quant_method) {
125 if (activation_quant_method_str == "MAX_MIN") {
126 (*activation_quant_method) = quant::MAX_MIN;
127 return RET_OK;
128 } else if (activation_quant_method_str == "KL") {
129 (*activation_quant_method) = quant::KL;
130 return RET_OK;
131 } else if (activation_quant_method_str == "REMOVAL_OUTLIER") {
132 (*activation_quant_method) = quant::REMOVAL_OUTLIER;
133 return RET_OK;
134 } else {
135 MS_LOG(ERROR) << "INPUT ILLEGAL: activation_quant_method must be MAX_MIN|KL|REMOVAL_OUTLIER.";
136 return RET_INPUT_PARAM_INVALID;
137 }
138 }
139 } // namespace lite
140 } // namespace mindspore
141