1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PARAM_HOLDER_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PARAM_HOLDER_H_ 19 #define USE_DEPRECATED_API 20 #include <vector> 21 #include <memory> 22 #include <map> 23 #include "ir/anf.h" 24 #include "schema/inner/model_generated.h" 25 #include "nnacl/op_base.h" 26 #include "src/common/log_util.h" 27 #include "tools/converter/quantizer/quant_params.h" 28 29 namespace mindspore { 30 namespace lite { 31 using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>; 32 class QuantParamHolder : public Value { 33 public: QuantParamHolder(size_t input_size,size_t output_size)34 QuantParamHolder(size_t input_size, size_t output_size) { 35 input_quant_params_.resize(input_size); 36 output_quant_params_.resize(output_size); 37 for (size_t i = 0; i < input_size; i++) { 38 std::vector<schema::QuantParamT> notinited_quant_params(1); 39 set_input_quant_param(i, notinited_quant_params); 40 } 41 42 for (size_t i = 0; i < output_size; i++) { 43 std::vector<schema::QuantParamT> notinited_quant_params(1); 44 set_output_quant_param(i, notinited_quant_params); 45 } 46 } 47 QuantParamHolder(const QuantParamsVector & input_quant_params,const QuantParamsVector & output_quant_params)48 QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) { 49 input_quant_params_ = input_quant_params; 50 output_quant_params_ = output_quant_params; 51 } 52 QuantParamHolder(const QuantParamHolder & obj)53 QuantParamHolder(const QuantParamHolder &obj) { 54 input_quant_params_ = obj.input_quant_params_; 55 output_quant_params_ = obj.output_quant_params_; 56 quant_type_ = obj.quant_type_; 57 enable_huffman_code_ = obj.enable_huffman_code_; 58 quant_clusters = obj.quant_clusters; 59 } 60 61 ~QuantParamHolder() override = default; 62 63 MS_DECLARE_PARENT(QuantParamHolder, Value); 64 65 bool operator==(const Value &rhs) const override { 66 if (rhs.isa<QuantParamHolder>()) { 67 auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs); 68 auto input_quant_params_rhs = other_holder.get_input_quant_params(); 69 auto output_quant_params_rhs = other_holder.get_output_quant_params(); 70 if (input_quant_params_rhs.size() != this->input_quant_params_.size() || 71 output_quant_params_rhs.size() != this->output_quant_params_.size()) { 72 return false; 73 } 74 for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) { 75 if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) { 76 return false; 77 } 78 auto *params = reinterpret_cast<const int8_t *>(this->input_quant_params_.at(i).data()); 79 auto *params_rhs = reinterpret_cast<const int8_t *>(input_quant_params_rhs.at(i).data()); 80 MS_CHECK_TRUE_RET(params != nullptr && params_rhs != nullptr, false); 81 for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) { 82 if (params[j] != params_rhs[j]) { 83 return false; 84 } 85 } 86 } 87 for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) { 88 if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) { 89 return false; 90 } 91 auto *params = reinterpret_cast<const int8_t *>(this->output_quant_params_.at(i).data()); 92 auto *params_rhs = reinterpret_cast<const int8_t *>(output_quant_params_rhs.at(i).data()); 93 MS_CHECK_TRUE_RET(params != nullptr && params_rhs != nullptr, false); 94 for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) { 95 if (params[j] != params_rhs[j]) { 96 return false; 97 } 98 } 99 } 100 } else { 101 return false; 102 } 103 return true; 104 } 105 set_quant_type(const quant::QuantType & quant_type)106 void set_quant_type(const quant::QuantType &quant_type) { quant_type_ = quant_type; } 107 quant_type()108 quant::QuantType quant_type() const { return quant_type_; } 109 set_enable_huffman_code(bool enable_huffman_code)110 void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; } 111 enable_huffman_code()112 bool enable_huffman_code() const { return enable_huffman_code_; } 113 get_input_quant_params()114 std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; } 115 get_output_quant_params()116 std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; } 117 118 void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param); 119 120 void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param); 121 122 bool IsInputQuantParamsInited(); 123 124 bool IsOutputQuantParamsInited(); 125 126 bool IsInputExistInited(); 127 128 bool IsOutputExistInited(); 129 130 void ClearQuantParams(); 131 132 bool CheckInit(size_t index, bool is_input); 133 134 void SetQuantClusters(size_t index, const std::vector<float> &quant_cluster); 135 136 std::vector<float> GetQuantClusters(size_t index); 137 138 private: 139 quant::QuantType quant_type_{quant::QUANT_NONE}; 140 QuantParamsVector input_quant_params_; 141 QuantParamsVector output_quant_params_; 142 bool enable_huffman_code_ = false; 143 std::map<size_t, std::vector<float>> quant_clusters; 144 }; 145 using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>; 146 147 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive); 148 149 QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode); 150 151 bool TensorQuantParamsInited(const schema::TensorT &tensor); 152 } // namespace lite 153 } // namespace mindspore 154 155 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_PARAM_HOLDER_H_ 156