• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/config_parser/quant_param_parser.h"
18 #include "src/common/log_adapter.h"
19 #include "mindspore/lite/tools/common/string_util.h"
20 #include "include/errorcode.h"
21 namespace mindspore {
22 namespace lite {
namespace {
// Bit-width upper bounds used to validate bit_num:
// weight-only quantization accepts up to 16 bits, full quantization up to 8 bits.
constexpr int kQuantBitNumInt16{16};
constexpr int kQuantBitNumInt8{8};

// Inclusive bounds applied to both min_quant_weight_size and min_quant_weight_channel.
constexpr int kMinSize{0};
constexpr int kMaxSize{65535};
}  // namespace
ParseCommonQuant(const CommonQuantString & common_quant_string,quant::CommonQuantParam * common_quant)29 int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
30                                        quant::CommonQuantParam *common_quant) {
31   if (!common_quant_string.quant_type.empty()) {
32     auto ret = ParseQuantType(common_quant_string.quant_type, &common_quant->quant_type);
33     if (ret != RET_OK) {
34       MS_LOG(ERROR) << "Parse quant_type failed.";
35       return ret;
36     }
37   }
38 
39   if (!common_quant_string.bit_num.empty() && !ConvertIntNum(common_quant_string.bit_num, &common_quant->bit_num)) {
40     MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be a valid number.";
41     return RET_INPUT_PARAM_INVALID;
42   }
43   if (common_quant->quant_type == schema::QuantType_QUANT_WEIGHT) {
44     if (common_quant->bit_num < 0 || common_quant->bit_num > kQuantBitNumInt16) {
45       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [0,16].";
46       return RET_INPUT_PARAM_INVALID;
47     }
48   } else if (common_quant->quant_type == schema::QuantType_QUANT_ALL) {
49     if (common_quant->bit_num <= 0 || common_quant->bit_num > kQuantBitNumInt8) {
50       MS_LOG(ERROR) << "INPUT ILLEGAL: bit_num should be [1,8].";
51       return RET_INPUT_PARAM_INVALID;
52     }
53   }
54   if (!common_quant_string.min_quant_weight_size.empty() &&
55       !ConvertIntNum(common_quant_string.min_quant_weight_size, &common_quant->min_quant_weight_size)) {
56     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be a valid number.";
57     return RET_INPUT_PARAM_INVALID;
58   }
59   if (!common_quant_string.min_quant_weight_channel.empty() &&
60       !ConvertIntNum(common_quant_string.min_quant_weight_channel, &common_quant->min_quant_weight_channel)) {
61     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
62     return RET_INPUT_PARAM_INVALID;
63   }
64   if (common_quant->min_quant_weight_size < kMinSize || common_quant->min_quant_weight_size > kMaxSize) {
65     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
66     return RET_INPUT_PARAM_INVALID;
67   }
68 
69   if (common_quant->min_quant_weight_channel < kMinSize || common_quant->min_quant_weight_channel > kMaxSize) {
70     MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should in [0,65535]." << std::endl;
71     return RET_INPUT_PARAM_INVALID;
72   }
73   return RET_OK;
74 }
75 
ParseMixedBitWeightQuant(const MixedBitWeightQuantString & mixed_bit_weight_quant_string,quant::MixedBitWeightQuantParam * mixed_bit_weight_quant)76 int QuantParamParser::ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string,
77                                                quant::MixedBitWeightQuantParam *mixed_bit_weight_quant) {
78   if (mixed_bit_weight_quant_string.init_scale.empty()) {
79     return RET_OK;
80   }
81   if (!ConvertDoubleNum(mixed_bit_weight_quant_string.init_scale, &mixed_bit_weight_quant->init_scale)) {
82     MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should be a valid number.";
83     return RET_INPUT_PARAM_INVALID;
84   }
85   if (mixed_bit_weight_quant->init_scale <= 0 || mixed_bit_weight_quant->init_scale >= 1) {
86     MS_LOG(ERROR) << "INPUT ILLEGAL: init_scale should at (0,1)";
87     return RET_INPUT_PARAM_INVALID;
88   }
89   return RET_OK;
90 }
91 
ParseFullQuant(const FullQuantString & full_quant_string,quant::FullQuantParam * full_quant)92 int QuantParamParser::ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant) {
93   if (!full_quant_string.activation_quant_method.empty() &&
94       ParseActivationQuantizedMethod(full_quant_string.activation_quant_method, &full_quant->activation_quant_method) !=
95         RET_OK) {
96     MS_LOG(ERROR) << "INPUT ILLEGAL: Parse activation_quant_method failed.";
97     return RET_INPUT_PARAM_INVALID;
98   }
99   if (!full_quant_string.bias_correction.empty() &&
100       !ConvertBool(full_quant_string.bias_correction, &full_quant->bias_correction)) {
101     MS_LOG(ERROR) << "INPUT ILLEGAL: bias_correction should be true or false.";
102     return RET_INPUT_PARAM_INVALID;
103   }
104   return RET_OK;
105 }
106 
ParseQuantType(const std::string & quant_type_str,schema::QuantType * quant_type)107 int QuantParamParser::ParseQuantType(const std::string &quant_type_str, schema::QuantType *quant_type) {
108   if (quant_type_str == "WEIGHT_QUANT") {
109     (*quant_type) = schema::QuantType_QUANT_WEIGHT;
110     return RET_OK;
111   } else if (quant_type_str == "FULL_QUANT") {
112     (*quant_type) = schema::QuantType_QUANT_ALL;
113     return RET_OK;
114   } else if (quant_type_str.empty()) {
115     (*quant_type) = schema::QuantType_QUANT_NONE;
116     return RET_OK;
117   } else {
118     MS_LOG(ERROR) << "INPUT ILLEGAL: quant_type must be WEIGHT_QUANT|FULL_QUANT.";
119     return RET_INPUT_PARAM_INVALID;
120   }
121 }
122 
ParseActivationQuantizedMethod(const std::string & activation_quant_method_str,quant::ActivationQuantizedMethod * activation_quant_method)123 int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activation_quant_method_str,
124                                                      quant::ActivationQuantizedMethod *activation_quant_method) {
125   if (activation_quant_method_str == "MAX_MIN") {
126     (*activation_quant_method) = quant::MAX_MIN;
127     return RET_OK;
128   } else if (activation_quant_method_str == "KL") {
129     (*activation_quant_method) = quant::KL;
130     return RET_OK;
131   } else if (activation_quant_method_str == "REMOVAL_OUTLIER") {
132     (*activation_quant_method) = quant::REMOVAL_OUTLIER;
133     return RET_OK;
134   } else {
135     MS_LOG(ERROR) << "INPUT ILLEGAL: activation_quant_method must be MAX_MIN|KL|REMOVAL_OUTLIER.";
136     return RET_INPUT_PARAM_INVALID;
137   }
138 }
139 }  // namespace lite
140 }  // namespace mindspore
141