1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define USE_DEPRECATED_API
#include "tools/converter/quantizer/quant_param_holder.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "schema/inner/model_generated.h"
#include "ir/anf.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/utils.h"
25
26 namespace mindspore {
27 namespace lite {
TensorQuantParamsInited(const schema::TensorT & tensor)28 bool TensorQuantParamsInited(const schema::TensorT &tensor) {
29 if (tensor.quantParams.empty()) {
30 return false;
31 }
32
33 bool is_quant_params_inited =
34 std::all_of(tensor.quantParams.cbegin(), tensor.quantParams.cend(),
35 [](const std::unique_ptr<mindspore::schema::QuantParamT> &quant_param) { return quant_param->inited; });
36 return is_quant_params_inited;
37 }
38
GetCNodeQuantHolder(const CNodePtr & cnode)39 QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode) {
40 MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
41 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
42 if (primitive == nullptr) {
43 MS_LOG(INFO) << "primitive is nullptr";
44 return nullptr;
45 }
46 return GetCNodeQuantHolder(primitive);
47 }
48
GetCNodeQuantHolder(const PrimitivePtr & primitive)49 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
50 MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
51 QuantParamHolderPtr quant_params_holder = nullptr;
52 auto quant_params_valueptr = primitive->GetAttr("quant_params");
53 if (quant_params_valueptr == nullptr) {
54 quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
55 MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
56 primitive->AddAttr("quant_params", quant_params_holder);
57 } else {
58 quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>();
59 if (quant_params_holder == nullptr) {
60 quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
61 MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
62 primitive->AddAttr("quant_params", quant_params_holder);
63 }
64 }
65 return quant_params_holder;
66 }
67
set_input_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & input_quant_param)68 void QuantParamHolder::set_input_quant_param(const size_t &index,
69 const std::vector<schema::QuantParamT> &input_quant_param) {
70 if (index >= this->input_quant_params_.size()) {
71 std::vector<schema::QuantParamT> place_quant(1);
72 this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(),
73 place_quant);
74 }
75
76 this->input_quant_params_.at(index) = input_quant_param;
77 }
78
set_output_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & output_quant_param)79 void QuantParamHolder::set_output_quant_param(const size_t &index,
80 const std::vector<schema::QuantParamT> &output_quant_param) {
81 if (index >= this->output_quant_params_.size()) {
82 std::vector<schema::QuantParamT> place_quant(1);
83 this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(),
84 place_quant);
85 }
86 this->output_quant_params_.at(index) = output_quant_param;
87 }
88
IsInputQuantParamsInited()89 bool QuantParamHolder::IsInputQuantParamsInited() {
90 if (this->input_quant_params_.empty()) {
91 return false;
92 }
93 bool is_quant_params_inited =
94 std::all_of(this->input_quant_params_.begin(), this->input_quant_params_.end(),
95 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
96 return is_quant_params_inited;
97 }
98
IsOutputQuantParamsInited()99 bool QuantParamHolder::IsOutputQuantParamsInited() {
100 if (this->output_quant_params_.empty()) {
101 return false;
102 }
103 bool is_quant_params_inited =
104 std::all_of(this->output_quant_params_.begin(), this->output_quant_params_.end(),
105 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
106 return is_quant_params_inited;
107 }
108
IsInputExistInited()109 bool QuantParamHolder::IsInputExistInited() {
110 if (this->input_quant_params_.empty()) {
111 return false;
112 }
113 bool is_exist_param_inited =
114 std::any_of(this->input_quant_params_.begin(), this->input_quant_params_.end(),
115 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
116 return is_exist_param_inited;
117 }
118
IsOutputExistInited()119 bool QuantParamHolder::IsOutputExistInited() {
120 if (this->output_quant_params_.empty()) {
121 return false;
122 }
123 bool is_exist_param_inited =
124 std::any_of(this->output_quant_params_.begin(), this->output_quant_params_.end(),
125 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
126 return is_exist_param_inited;
127 }
128
ClearQuantParams()129 void QuantParamHolder::ClearQuantParams() {
130 quant_type_ = quant::QUANT_NONE;
131 input_quant_params_.clear();
132 output_quant_params_.clear();
133 }
134
CheckInit(size_t index,bool is_input)135 bool QuantParamHolder::CheckInit(size_t index, bool is_input) {
136 std::vector<schema::QuantParamT> param;
137 if (is_input) {
138 if (input_quant_params_.size() <= index) {
139 return false;
140 }
141 param = input_quant_params_.at(index);
142 } else {
143 if (output_quant_params_.size() <= index) {
144 return false;
145 }
146 param = output_quant_params_.at(index);
147 }
148 return (!param.empty() && param.front().inited);
149 }
150
SetQuantClusters(size_t index,const std::vector<float> & quant_cluster)151 void QuantParamHolder::SetQuantClusters(size_t index, const std::vector<float> &quant_cluster) {
152 quant_clusters.insert({index, quant_cluster});
153 }
154
GetQuantClusters(size_t index)155 std::vector<float> QuantParamHolder::GetQuantClusters(size_t index) {
156 auto iter = quant_clusters.find(index);
157 if (iter == quant_clusters.end()) {
158 return {};
159 } else {
160 return iter->second;
161 }
162 }
163 } // namespace lite
164 } // namespace mindspore
165