/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_

#include <future>
#include <memory>
#include <map>
#include <list>
#include <string>
#include <utility>
#include <vector>
#include <set>
#include <algorithm>
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quant_strategy.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
#include "src/litert/lite_session.h"
#include "src/common/quant_utils.h"

namespace mindspore::lite::quant {
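// Post-training, weight-only quantizer: converts the weight tensors of
// supported operators to fixed-bit int8/int16 or, when bit_num is 0,
// mixed-bit precision, with optional per-channel granularity, encoding
// (compression), and bias correction.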
class WeightQuantizer : public Quantizer {
 public:
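  // A bit_num of 0 selects mixed-bit quantization (weights stored as int16);
  // 1-8 bits quantize to int8 and 9-16 bits to int16. A positive init_scale
  // overrides the configured mixed-bit initial scale.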
  explicit WeightQuantizer(const std::shared_ptr<ConverterPara> &param, double init_scale = 0) : Quantizer(param) {
    bit_num_ = param_->commonQuantParam.bit_num;
    enable_encode_ = param_->commonQuantParam.enable_encode;
    if (bit_num_ == 0) {
      type_id_ = kNumberTypeInt16;
      is_mixed_bit_ = true;
      mixed_bit_init_scale_ = param_->mixedBitWeightQuantParam.init_scale;
      is_auto_tune_ = param_->mixedBitWeightQuantParam.auto_tune;
    }
    // parse param for fixed bit quant.
    if (!is_mixed_bit_) {
      quant_min_ = QuantMin(bit_num_, false, false);
      quant_max_ = QuantMax(bit_num_, false);
      symmetric_quant_min_ = QuantMin(bit_num_, false, true);
      symmetric_quant_max_ = QuantMax(bit_num_, false);
      // parse type_id_
      MS_ASSERT(bit_num_ > 0 && bit_num_ <= k16Bit);
      if (bit_num_ > 0 && bit_num_ <= k8Bit) {
        type_id_ = kNumberTypeInt8;
      } else if (bit_num_ <= k16Bit) {
        type_id_ = kNumberTypeInt16;
      }
    }
    quant_strategy_ = std::make_unique<QuantStrategy>(param_->commonQuantParam.min_quant_weight_size,
                                                      param_->commonQuantParam.min_quant_weight_channel,
                                                      param_->commonQuantParam.skip_quant_node);
    if (init_scale > 0) {
      mixed_bit_init_scale_ = init_scale;
    }
    if (!param_->commonQuantParam.skip_quant_node.empty()) {
      std::copy(param_->commonQuantParam.skip_quant_node.cbegin(), param_->commonQuantParam.skip_quant_node.cend(),
                std::inserter(skip_quant_node_, skip_quant_node_.begin()));
    }
    quant_type_ = param_->commonQuantParam.quant_type;
    dequant_strategy_ = param_->weightQuantParam.dequant_strategy;
    max_segments_ = param_->weightQuantParam.max_segments;
    ascend_backend_ = param_->device.find("Ascend") != std::string::npos;
    per_channel_ = param_->weightQuantParam.per_channel;
    bias_correction_ = param_->weightQuantParam.bias_correction;
    if (per_channel_) {
      weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_CHANNEL;
    } else {
      weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_LAYER;
    }
  }

  ~WeightQuantizer() override = default;

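  // Quantizer entry point: runs weight quantization over the whole graph.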
  int DoQuantize(FuncGraphPtr func_graph) override;

  int WeightQuant(const FuncGraphPtr &func_graph, const std::set<PrimitivePtr> &support_weight_quant_types,
                  const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &symmetric_types,
                  bool compression = true);

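  // Returns the weight tensors quantized so far; lets callers detect weights
  // shared by multiple nodes.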
  std::set<tensor::TensorPtr> GetWeightQuantizedTensors() { return this->weight_quantized_tensors_; }

 private:
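  // Per-node worker invoked by WeightQuant for each CNode.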
  int WeightQuantPerCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                          const std::set<PrimitivePtr> &support_weight_quant_types,
                          const std::set<PrimitivePtr> &per_layer_types, const std::set<PrimitivePtr> &symmetric_types,
                          bool compression = true);
  int PreLinearQuant(const CNodePtr &cnode, int idx, const AnfNodePtr &input, ParameterPtr *parameter,
                     tensor::TensorPtr *tensor_info);
  int LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::set<PrimitivePtr> &per_layer_types,
                  const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
                  bool compression = true);
  int MarkGraphWeightQuantType(const FuncGraphPtr &func_graph);
  int MarkCNodeWeightQuantType(const CNodePtr &cnode);
  int DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter, const tensor::TensorPtr &tensor);
  int DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info,
                    int preferred_dim, WeightQuantType weight_quant_type, bool symmetric = true);
  int InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
                        const tensor::TensorPtr &tensor_info);
  int InsertAscendDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const ParameterPtr &parameter,
                              int idx, const tensor::TensorPtr &tensor_info);

 private:
  bool is_auto_tune_{false};
  bool is_mixed_bit_{false};
  bool linear_quant_{true};
  size_t bit_num_{8};
  // Initial scale for the mixed-bit search; overridden by the constructor's init_scale when positive.
  double mixed_bit_init_scale_ = 0.02;
  int quant_min_{-128};
  int quant_max_{127};
  int symmetric_quant_min_{-127};
  int symmetric_quant_max_{127};
  TypeId type_id_{kNumberTypeInt8};
  std::set<std::string> skip_quant_node_;
  std::unique_ptr<QuantStrategy> quant_strategy_;
  quant::QuantType quant_type_{quant::QUANT_WEIGHT};
  bool enable_encode_{true};
  WeightQuantType weight_quant_type_ = WeightQuantType::FIXED_BIT_PER_CHANNEL;
  DequantStrategy dequant_strategy_ = DEFAULT;
  int max_segments_{1};
  bool per_channel_{true};
  bool bias_correction_{true};
  // Supports marking shared weight nodes.
  std::set<tensor::TensorPtr> weight_quantized_tensors_;
  bool ascend_backend_ = false;
};
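
// A minimal usage sketch (hypothetical driver code; the ConverterPara setup
// and the surrounding pipeline wiring are assumptions, not defined in this
// header):
//
//   auto param = std::make_shared<ConverterPara>();
//   param->commonQuantParam.bit_num = 8;   // fixed 8-bit -> kNumberTypeInt8
//   param->weightQuantParam.per_channel = true;
//   WeightQuantizer quantizer(param);
//   if (quantizer.DoQuantize(func_graph) != RET_OK) {
//     MS_LOG(ERROR) << "Weight quantization failed.";
//   }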
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_