1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H 19 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 #include "ir/anf.h" 24 #include "schema/inner/model_generated.h" 25 26 namespace mindspore { 27 namespace lite { 28 using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>; 29 class QuantParamHolder : public Value { 30 public: QuantParamHolder(size_t input_size,size_t output_size)31 QuantParamHolder(size_t input_size, size_t output_size) { 32 input_quant_params_.resize(input_size); 33 output_quant_params_.resize(output_size); 34 for (size_t i = 0; i < input_size; i++) { 35 std::vector<schema::QuantParamT> notinited_quant_params(1); 36 set_input_quant_param(i, notinited_quant_params); 37 } 38 39 for (size_t i = 0; i < output_size; i++) { 40 std::vector<schema::QuantParamT> notinited_quant_params(1); 41 set_output_quant_param(i, notinited_quant_params); 42 } 43 } 44 QuantParamHolder(const QuantParamsVector & input_quant_params,const QuantParamsVector & output_quant_params)45 QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) { 46 input_quant_params_ = input_quant_params; 47 output_quant_params_ = output_quant_params; 48 } 49 50 ~QuantParamHolder() override = default; 51 52 MS_DECLARE_PARENT(QuantParamHolder, Value); 53 54 bool operator==(const Value &rhs) const override { // unused 55 if (rhs.isa<QuantParamHolder>()) { 56 auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs); 57 auto input_quant_params_rhs = other_holder.get_input_quant_params(); 58 auto output_quant_params_rhs = other_holder.get_output_quant_params(); 59 if (input_quant_params_rhs.size() != this->input_quant_params_.size() || 60 output_quant_params_rhs.size() != this->output_quant_params_.size()) { 61 return false; 62 } 63 for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) { 64 if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) { 65 return false; 66 } 67 auto *params = reinterpret_cast<const char *>(this->input_quant_params_.at(i).data()); 68 auto *params_rhs = reinterpret_cast<const char *>(input_quant_params_rhs.at(i).data()); 69 for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) { 70 if (params[j] != params_rhs[j]) { 71 return false; 72 } 73 } 74 } 75 for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) { 76 if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) { 77 return false; 78 } 79 auto *params = reinterpret_cast<const char *>(this->output_quant_params_.at(i).data()); 80 auto *params_rhs = reinterpret_cast<const char *>(output_quant_params_rhs.at(i).data()); 81 for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) { 82 if (params[j] != params_rhs[j]) { 83 return false; 84 } 85 } 86 } 87 } else { 88 return false; 89 } 90 return true; 91 } 92 set_quant_type(const schema::QuantType & quant_type)93 void set_quant_type(const schema::QuantType &quant_type) { quant_type_ = quant_type; } 94 quant_type()95 schema::QuantType quant_type() const { return quant_type_; } 96 set_input_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & input_quant_param)97 void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) { 98 if (index >= this->input_quant_params_.size()) { 99 std::vector<schema::QuantParamT> place_quant(1); 100 this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(), 101 place_quant); 102 } 103 this->input_quant_params_.at(index) = input_quant_param; 104 } 105 set_output_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & output_quant_param)106 void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) { 107 if (index >= this->output_quant_params_.size()) { 108 std::vector<schema::QuantParamT> place_quant(1); 109 this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(), 110 place_quant); 111 } 112 this->output_quant_params_.at(index) = output_quant_param; 113 } 114 set_enable_huffman_code(bool enable_huffman_code)115 void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; } 116 enable_huffman_code()117 bool enable_huffman_code() const { return enable_huffman_code_; } 118 get_input_quant_params()119 std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; } 120 get_output_quant_params()121 std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; } 122 123 // deprecated ClearInputOutputQuantParam()124 void ClearInputOutputQuantParam() { 125 input_quant_params_.clear(); 126 output_quant_params_.clear(); 127 } 128 IsInputQuantParamsInited()129 bool IsInputQuantParamsInited() { 130 if (this->input_quant_params_.empty()) { 131 return false; 132 } 133 for (auto &quant_param : this->input_quant_params_) { 134 if (!quant_param.front().inited) { 135 return false; 136 } 137 } 138 return true; 139 } 140 IsOutputQuantParamsInited()141 bool IsOutputQuantParamsInited() { 142 if (this->output_quant_params_.empty()) { 143 return false; 144 } 145 for (auto &quant_param : this->output_quant_params_) { 146 if (!quant_param.front().inited) { 147 return false; 148 } 149 } 150 return true; 151 } 152 153 private: 154 schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; 155 QuantParamsVector input_quant_params_; 156 QuantParamsVector output_quant_params_; 157 bool enable_huffman_code_ = false; 158 }; 159 using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>; 160 } // namespace lite 161 } // namespace mindspore 162 163 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H 164