• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H
19 
20 #include <utility>
21 #include <vector>
22 #include <memory>
23 #include "ir/anf.h"
24 #include "schema/inner/model_generated.h"
25 
26 namespace mindspore {
27 namespace lite {
28 using QuantParamsVector = std::vector<std::vector<schema::QuantParamT>>;
29 class QuantParamHolder : public Value {
30  public:
QuantParamHolder(size_t input_size,size_t output_size)31   QuantParamHolder(size_t input_size, size_t output_size) {
32     input_quant_params_.resize(input_size);
33     output_quant_params_.resize(output_size);
34     for (size_t i = 0; i < input_size; i++) {
35       std::vector<schema::QuantParamT> notinited_quant_params(1);
36       set_input_quant_param(i, notinited_quant_params);
37     }
38 
39     for (size_t i = 0; i < output_size; i++) {
40       std::vector<schema::QuantParamT> notinited_quant_params(1);
41       set_output_quant_param(i, notinited_quant_params);
42     }
43   }
44 
QuantParamHolder(const QuantParamsVector & input_quant_params,const QuantParamsVector & output_quant_params)45   QuantParamHolder(const QuantParamsVector &input_quant_params, const QuantParamsVector &output_quant_params) {
46     input_quant_params_ = input_quant_params;
47     output_quant_params_ = output_quant_params;
48   }
49 
50   ~QuantParamHolder() override = default;
51 
52   MS_DECLARE_PARENT(QuantParamHolder, Value);
53 
54   bool operator==(const Value &rhs) const override {  // unused
55     if (rhs.isa<QuantParamHolder>()) {
56       auto other_holder = dynamic_cast<const QuantParamHolder &>(rhs);
57       auto input_quant_params_rhs = other_holder.get_input_quant_params();
58       auto output_quant_params_rhs = other_holder.get_output_quant_params();
59       if (input_quant_params_rhs.size() != this->input_quant_params_.size() ||
60           output_quant_params_rhs.size() != this->output_quant_params_.size()) {
61         return false;
62       }
63       for (size_t i = 0; i < input_quant_params_rhs.size(); ++i) {
64         if (input_quant_params_rhs.at(i).size() != this->input_quant_params_.at(i).size()) {
65           return false;
66         }
67         auto *params = reinterpret_cast<const char *>(this->input_quant_params_.at(i).data());
68         auto *params_rhs = reinterpret_cast<const char *>(input_quant_params_rhs.at(i).data());
69         for (size_t j = 0; j < input_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
70           if (params[j] != params_rhs[j]) {
71             return false;
72           }
73         }
74       }
75       for (size_t i = 0; i < output_quant_params_rhs.size(); ++i) {
76         if (output_quant_params_rhs.at(i).size() != this->output_quant_params_.at(i).size()) {
77           return false;
78         }
79         auto *params = reinterpret_cast<const char *>(this->output_quant_params_.at(i).data());
80         auto *params_rhs = reinterpret_cast<const char *>(output_quant_params_rhs.at(i).data());
81         for (size_t j = 0; j < output_quant_params_rhs.at(i).size() * sizeof(schema::QuantParamT); ++j) {
82           if (params[j] != params_rhs[j]) {
83             return false;
84           }
85         }
86       }
87     } else {
88       return false;
89     }
90     return true;
91   }
92 
set_quant_type(const schema::QuantType & quant_type)93   void set_quant_type(const schema::QuantType &quant_type) { quant_type_ = quant_type; }
94 
quant_type()95   schema::QuantType quant_type() const { return quant_type_; }
96 
set_input_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & input_quant_param)97   void set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
98     if (index >= this->input_quant_params_.size()) {
99       std::vector<schema::QuantParamT> place_quant(1);
100       this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(),
101                                        place_quant);
102     }
103     this->input_quant_params_.at(index) = input_quant_param;
104   }
105 
set_output_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & output_quant_param)106   void set_output_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &output_quant_param) {
107     if (index >= this->output_quant_params_.size()) {
108       std::vector<schema::QuantParamT> place_quant(1);
109       this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(),
110                                         place_quant);
111     }
112     this->output_quant_params_.at(index) = output_quant_param;
113   }
114 
set_enable_huffman_code(bool enable_huffman_code)115   void set_enable_huffman_code(bool enable_huffman_code) { enable_huffman_code_ = enable_huffman_code; }
116 
enable_huffman_code()117   bool enable_huffman_code() const { return enable_huffman_code_; }
118 
get_input_quant_params()119   std::vector<std::vector<schema::QuantParamT>> get_input_quant_params() const { return this->input_quant_params_; }
120 
get_output_quant_params()121   std::vector<std::vector<schema::QuantParamT>> get_output_quant_params() const { return this->output_quant_params_; }
122 
123   // deprecated
ClearInputOutputQuantParam()124   void ClearInputOutputQuantParam() {
125     input_quant_params_.clear();
126     output_quant_params_.clear();
127   }
128 
IsInputQuantParamsInited()129   bool IsInputQuantParamsInited() {
130     if (this->input_quant_params_.empty()) {
131       return false;
132     }
133     for (auto &quant_param : this->input_quant_params_) {
134       if (!quant_param.front().inited) {
135         return false;
136       }
137     }
138     return true;
139   }
140 
IsOutputQuantParamsInited()141   bool IsOutputQuantParamsInited() {
142     if (this->output_quant_params_.empty()) {
143       return false;
144     }
145     for (auto &quant_param : this->output_quant_params_) {
146       if (!quant_param.front().inited) {
147         return false;
148       }
149     }
150     return true;
151   }
152 
153  private:
154   schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
155   QuantParamsVector input_quant_params_;
156   QuantParamsVector output_quant_params_;
157   bool enable_huffman_code_ = false;
158 };
159 using QuantParamHolderPtr = std::shared_ptr<QuantParamHolder>;
160 }  // namespace lite
161 }  // namespace mindspore
162 
163 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANT_PARAM_CONTEXT_H
164