1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_ 19 20 #include <future> 21 #include <memory> 22 #include <map> 23 #include <list> 24 #include <string> 25 #include <utility> 26 #include <vector> 27 #include <set> 28 #include <algorithm> 29 #include "tools/converter/quantizer/quantizer.h" 30 #include "tools/converter/quantizer/quantize_util.h" 31 #include "tools/converter/quantizer/quant_params.h" 32 #include "tools/converter/quantizer/quant_strategy.h" 33 #include "tools/converter/preprocess/preprocess_param.h" 34 #include "ir/func_graph.h" 35 #include "ir/anf.h" 36 #include "include/model.h" 37 #include "base/base.h" 38 #include "abstract/dshape.h" 39 #include "src/litert/lite_session.h" 40 #include "src/common/quant_utils.h" 41 42 namespace mindspore::lite::quant { 43 class WeightQuantizer : public Quantizer { 44 public: Quantizer(param)45 explicit WeightQuantizer(const std::shared_ptr<ConverterPara> ¶m, double init_scale = 0) : Quantizer(param) { 46 bit_num_ = param_->commonQuantParam.bit_num; 47 enable_encode_ = param_->commonQuantParam.enable_encode; 48 if (bit_num_ == 0) { 49 type_id_ = kNumberTypeInt16; 50 is_mixed_bit_ = true; 51 mixed_bit_init_scale_ = param_->mixedBitWeightQuantParam.init_scale; 52 is_auto_tune_ = param_->mixedBitWeightQuantParam.auto_tune; 53 } 54 // parse param for fixed bit quant. 55 if (!is_mixed_bit_) { 56 quant_min_ = QuantMin(bit_num_, false, false); 57 quant_max_ = QuantMax(bit_num_, false); 58 symmetric_quant_min_ = QuantMin(bit_num_, false, true); 59 symmetric_quant_max_ = QuantMax(bit_num_, false); 60 // parse type_id_ 61 MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit); 62 if (bit_num_ > 0 && bit_num_ <= k8Bit) { 63 type_id_ = kNumberTypeInt8; 64 } else if (bit_num_ <= k16Bit) { 65 type_id_ = kNumberTypeInt16; 66 } 67 } 68 quant_strategy_ = std::make_unique<QuantStrategy>(param_->commonQuantParam.min_quant_weight_size, 69 param_->commonQuantParam.min_quant_weight_channel, 70 param_->commonQuantParam.skip_quant_node); 71 if (init_scale > 0) { 72 mixed_bit_init_scale_ = init_scale; 73 } 74 if (!param_->commonQuantParam.skip_quant_node.empty()) { 75 std::copy(param_->commonQuantParam.skip_quant_node.cbegin(), param_->commonQuantParam.skip_quant_node.cend(), 76 std::inserter(skip_quant_node_, skip_quant_node_.begin())); 77 } 78 quant_type_ = param_->commonQuantParam.quant_type; 79 dequant_strategy_ = param_->weightQuantParam.dequant_strategy; 80 max_segments_ = param_->weightQuantParam.max_segments; 81 ascend_backend_ = param_->device.find("Ascend") != std::string::npos; 82 per_channel_ = param_->weightQuantParam.per_channel; 83 bias_correction_ = param_->weightQuantParam.bias_correction; 84 if (per_channel_) { 85 weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_CHANNEL; 86 } else { 87 weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_LAYER; 88 } 89 } 90 91 ~WeightQuantizer() override = default; 92 93 int DoQuantize(FuncGraphPtr func_graph) override; 94 95 int WeightQuant(const FuncGraphPtr &func_graph, const std::set<PrimitivePtr> &support_weight_quant_types, 96 const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &symmetric_types, 97 bool compression = true); 98 GetWeightQuantizedTensors()99 std::set<tensor::TensorPtr> GetWeightQuantizedTensors() { return this->weight_quantized_tensors_; } 100 101 private: 102 int WeightQuantPerCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, 103 const std::set<PrimitivePtr> &support_weight_quant_types, 104 const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &symmetric_types, 105 bool compression = true); 106 int PreLinearQuant(const CNodePtr &cnode, int idx, const AnfNodePtr &input, ParameterPtr *parameter, 107 tensor::TensorPtr *tensor_info); 108 int LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::set<PrimitivePtr> &per_layer_types, 109 const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices, 110 bool compression = true); 111 int MarkGraphWeightQuantType(const FuncGraphPtr &func_graph); 112 int MarkCNodeWeightQuantType(const CNodePtr &cnode); 113 int DoCompression(const CNodePtr &cnode, const ParameterPtr ¶meter, const tensor::TensorPtr &tensor); 114 int DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr ¶meter, int idx, const tensor::TensorPtr &tensor_info, 115 int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true); 116 int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr ¶meter, int idx, 117 const tensor::TensorPtr &tensor_info); 118 int InsertAscendDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr ¶meter, 119 int idx, const tensor::TensorPtr &tensor_info); 120 121 private: 122 bool is_auto_tune_{false}; 123 bool is_mixed_bit_{false}; 124 bool linear_quant_{true}; 125 size_t bit_num_{8}; 126 double mixed_bit_init_scale_ = 0.02; 127 int quant_min_{-128}; 128 int quant_max_{127}; 129 int symmetric_quant_min_{-127}; 130 int symmetric_quant_max_{127}; 131 TypeId type_id_{kNumberTypeInt8}; 132 std::set<std::string> skip_quant_node_; 133 std::unique_ptr<QuantStrategy> quant_strategy_; 134 quant::QuantType quant_type_{quant::QUANT_WEIGHT}; 135 bool enable_encode_{true}; 136 WeightQuantType weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_CHANNEL; 137 DequantStrategy dequant_strategy_ = DEFAULT; 138 int max_segments_{1}; 139 bool per_channel_{true}; 140 bool bias_correction_{true}; 141 // Support for mark shared weight node. 142 std::set<tensor::TensorPtr> weight_quantized_tensors_; 143 bool ascend_backend_ = false; 144 }; 145 } // namespace mindspore::lite::quant 146 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_ 147