• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PARAM_HOLDER_H_
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PARAM_HOLDER_H_
19 #define USE_DEPRECATED_API
#include <cstring>
#include <map>
#include <memory>
#include <vector>
#include "ir/anf.h"
#include "schema/inner/model_generated.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
#include "tools/converter/quantizer/quant_params.h"
28 
29 namespace mindspore {
30 namespace lite {
31 using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>;
32 class QuantParamHolder : public Value {
33  public:
QuantParamHolder(size_t input_size,size_t output_size)34   QuantParamHolder(size_t input_size, size_t output_size) {
35     input_quant_params_.resize(input_size);
36     output_quant_params_.resize(output_size);
37     for (size_t i = 0; i < input_size; i++) {
38       std::vector<schema::QuantParamT> notinited_quant_params(1);
39       set_input_quant_param(i, notinited_quant_params);
40     }
41 
42     for (size_t i = 0; i < output_size; i++) {
43       std::vector<schema::QuantParamT> notinited_quant_params(1);
44       set_output_quant_param(i, notinited_quant_params);
45     }
46   }
47 
QuantParamHolder(const QuantParamsVector & input_quant_params,const QuantParamsVector & output_quant_params)48   QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) {
49     input_quant_params_ = input_quant_params;
50     output_quant_params_ = output_quant_params;
51   }
52 
QuantParamHolder(const QuantParamHolder & obj)53   QuantParamHolder(const QuantParamHolder &obj) {
54     input_quant_params_ = obj.input_quant_params_;
55     output_quant_params_ = obj.output_quant_params_;
56     quant_type_ = obj.quant_type_;
57     enable_huffman_code_ = obj.enable_huffman_code_;
58     quant_clusters = obj.quant_clusters;
59   }
60 
61   ~QuantParamHolder() override = default;
62 
63   MS_DECLARE_PARENT(QuantParamHolder, Value);
64 
65   bool operator==(const Value &rhs) const override {
66     if (rhs.isa<QuantParamHolder>()) {
67       auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs);
68       auto input_quant_params_rhs = other_holder.get_input_quant_params();
69       auto output_quant_params_rhs = other_holder.get_output_quant_params();
70       if (input_quant_params_rhs.size() != this->input_quant_params_.size() ||
71           output_quant_params_rhs.size() != this->output_quant_params_.size()) {
72         return false;
73       }
74       for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) {
75         if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) {
76           return false;
77         }
78         auto *params = reinterpret_cast<const int8_t *>(this->input_quant_params_.at(i).data());
79         auto *params_rhs = reinterpret_cast<const int8_t *>(input_quant_params_rhs.at(i).data());
80         MS_CHECK_TRUE_RET(params != nullptr && params_rhs != nullptr, false);
81         for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
82           if (params[j] != params_rhs[j]) {
83             return false;
84           }
85         }
86       }
87       for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) {
88         if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) {
89           return false;
90         }
91         auto *params = reinterpret_cast<const int8_t *>(this->output_quant_params_.at(i).data());
92         auto *params_rhs = reinterpret_cast<const int8_t *>(output_quant_params_rhs.at(i).data());
93         MS_CHECK_TRUE_RET(params != nullptr && params_rhs != nullptr, false);
94         for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
95           if (params[j] != params_rhs[j]) {
96             return false;
97           }
98         }
99       }
100     } else {
101       return false;
102     }
103     return true;
104   }
105 
set_quant_type(const quant::QuantType & quant_type)106   void set_quant_type(const quant::QuantType &quant_type) { quant_type_ = quant_type; }
107 
quant_type()108   quant::QuantType quant_type() const { return quant_type_; }
109 
set_enable_huffman_code(bool enable_huffman_code)110   void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; }
111 
enable_huffman_code()112   bool enable_huffman_code() const { return enable_huffman_code_; }
113 
get_input_quant_params()114   std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; }
115 
get_output_quant_params()116   std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; }
117 
118   void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param);
119 
120   void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param);
121 
122   bool IsInputQuantParamsInited();
123 
124   bool IsOutputQuantParamsInited();
125 
126   bool IsInputExistInited();
127 
128   bool IsOutputExistInited();
129 
130   void ClearQuantParams();
131 
132   bool CheckInit(size_t index, bool is_input);
133 
134   void SetQuantClusters(size_t index, const std::vector<float> &quant_cluster);
135 
136   std::vector<float> GetQuantClusters(size_t index);
137 
138  private:
139   quant::QuantType quant_type_{quant::QUANT_NONE};
140   QuantParamsVector input_quant_params_;
141   QuantParamsVector output_quant_params_;
142   bool enable_huffman_code_ = false;
143   std::map<size_t, std::vector<float>> quant_clusters;
144 };
145 using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>;
146 
147 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
148 
149 QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode);
150 
151 bool TensorQuantParamsInited(const schema::TensorT &tensor);
152 }  // namespace lite
153 }  // namespace mindspore
154 
155 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PARAM_HOLDER_H_
156