/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIXED_BIT_WEIGHT_QUANTIZATION_H
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIXED_BIT_WEIGHT_QUANTIZATION_H
19 
20 #include <vector>
21 #include <functional>
22 #include <map>
23 #include <memory>
24 #include "ir/tensor.h"
25 #include "schema/inner/model_generated.h"
26 #include "src/common/log_adapter.h"
27 #include "src/common/quant_utils.h"
28 #include "tools/converter/quantizer/quant_params.h"
29 #include "tools/converter/quantizer/quantize_util.h"
30 #include "mindspore/core/ir/quantization_param.h"
31 
32 namespace mindspore::lite::quant {
33 class FixedBitWeightQuantization {
34  public:
35   FixedBitWeightQuantization() = default;
36 
37   ~FixedBitWeightQuantization() = default;
38 
39   int QuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
40                   quant::QuantType quant_type, int quant_max, int quant_min, size_t bit_num,
41                   WeightQuantType weight_quant_type, TypeId quant_data_type, int preferred_dim, bool symmetric = false,
42                   bool narrow_range = false, bool bias_correction = true);
43 
44   int QuantBias(const ParameterPtr &weight, const ParameterPtr &bias,
45                 const std::vector<schema::QuantParamT> &active_quant_params);
46 
47  private:
48   int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
49                                    const float *raw_datas, std::vector<schema::QuantParamT> *weight_quant_params,
50                                    const tensor::TensorPtr &weight, std::vector<schema::QuantParamT> *bias_quant_params,
51                                    std::vector<int32_t> *quant_datas);
52 
53   template <typename T>
54   int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
55                           const PrimitivePtr &primitive, quant::QuantType quant_type, int quant_max, int quant_min,
56                           size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type, int preferred_dim,
57                           bool symmetric = false, bool narrow_range = false, bool bias_correction = true) {
58     size_t elem_count = weight->DataSize();
59     auto *raw_data = static_cast<float *>(weight->data_c());
60     if (raw_data == nullptr) {
61       MS_LOG(ERROR) << "rawDatas is nullptr";
62       return RET_ERROR;
63     }
64     std::vector<T> quant_data(elem_count);
65     auto status = FixedBitStatisticsFilter<T>(weight, quant_type, quant_max, quant_min, bit_num, weight_quant_type,
66                                               preferred_dim, &quant_data, symmetric, narrow_range, bias_correction);
67     if (status == RET_NO_CHANGE) {
68       return status;
69     } else if (status != RET_OK) {
70       MS_LOG(ERROR) << "FixedBitStatisticsFilter failed : " << status;
71       return status;
72     }
73     status = UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(T),
74                                      quant_data_type);
75     if (status != RET_OK) {
76       MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
77       return RET_ERROR;
78     }
79     auto quant_type_value = MakeValue(static_cast<int>(quant_type));
80     MS_CHECK_TRUE_MSG(quant_type_value != nullptr, RET_ERROR, "quant_type is nullptr.");
81     primitive->AddAttr(quant::kQuantType, quant_type_value);
82     return RET_OK;
83   }
84 
85   template <typename T>
86   int FixedBitStatisticsFilter(const tensor::TensorPtr &weight, quant::QuantType quant_type, int quant_max,
87                                int quant_min, size_t bit_num, WeightQuantType weight_quant_type, int preferred_dim,
88                                std::vector<T> *quant_data, bool symmetric = false, bool narrow_range = false,
89                                bool bias_correction = true) {
90     MS_ASSERT(weight != nullptr);
91     auto dims = weight->shape();
92     if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
93       if (dims.size() <= 1) {
94         MS_LOG(WARNING) << "dims is " << dims.size() << " can not per_channel";
95         weight_quant_type = FIXED_BIT_PER_LAYER;
96       }
97     }
98     if (weight->data_type_c() != kNumberTypeFloat32) {
99       MS_LOG(ERROR) << "data type is not Float32.";
100       return RET_ERROR;
101     }
102 
103     std::vector<schema::QuantParamT> quant_params;
104     int ret = RET_OK;
105     bool cal_gain = (quant_type == QUANT_WEIGHT) && bias_correction ? true : false;
106     if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
107       ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
108                                  quant_min, bit_num, quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
109                                  cal_gain, symmetric, narrow_range);
110       if (ret == RET_NO_CHANGE) {
111         return ret;
112       } else if (ret != RET_OK) {
113         MS_LOG(ERROR) << "Do per channel quant failed.";
114         return ret;
115       }
116     } else if (weight_quant_type == FIXED_BIT_PER_LAYER) {
117       ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
118                                quant_min, bit_num, quant_data, symmetric, narrow_range, cal_gain);
119       if (ret != RET_OK) {
120         MS_LOG(ERROR) << "Do per layer quant failed.";
121         return ret;
122       }
123     } else {
124       MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
125       return RET_ERROR;
126     }
127     auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
128     CHECK_NULL_RETURN(quantization_ptr);
129     weight->set_quant_param(std::vector<std::shared_ptr<mindspore::QuantizationParam>>{quantization_ptr});
130     return ret;
131   }
132 };
133 }  // namespace mindspore::lite::quant
134 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIXED_BIT_WEIGHT_QUANTIZATION_H
135