• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define USE_DEPRECATED_API
#include "tools/converter/quantizer/quant_param_holder.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "schema/inner/model_generated.h"
#include "ir/anf.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/utils.h"
25 
26 namespace mindspore {
27 namespace lite {
TensorQuantParamsInited(const schema::TensorT & tensor)28 bool TensorQuantParamsInited(const schema::TensorT &tensor) {
29   if (tensor.quantParams.empty()) {
30     return false;
31   }
32 
33   bool is_quant_params_inited =
34     std::all_of(tensor.quantParams.cbegin(), tensor.quantParams.cend(),
35                 [](const std::unique_ptr<mindspore::schema::QuantParamT> &quant_param) { return quant_param->inited; });
36   return is_quant_params_inited;
37 }
38 
GetCNodeQuantHolder(const CNodePtr & cnode)39 QuantParamHolderPtr GetCNodeQuantHolder(const CNodePtr &cnode) {
40   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
41   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
42   if (primitive == nullptr) {
43     MS_LOG(INFO) << "primitive is nullptr";
44     return nullptr;
45   }
46   return GetCNodeQuantHolder(primitive);
47 }
48 
GetCNodeQuantHolder(const PrimitivePtr & primitive)49 QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
50   MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
51   QuantParamHolderPtr quant_params_holder = nullptr;
52   auto quant_params_valueptr = primitive->GetAttr("quant_params");
53   if (quant_params_valueptr == nullptr) {
54     quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
55     MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
56     primitive->AddAttr("quant_params", quant_params_holder);
57   } else {
58     quant_params_holder = quant_params_valueptr->cast<QuantParamHolderPtr>();
59     if (quant_params_holder == nullptr) {
60       quant_params_holder = std::make_shared<QuantParamHolder>(0, 0);
61       MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
62       primitive->AddAttr("quant_params", quant_params_holder);
63     }
64   }
65   return quant_params_holder;
66 }
67 
set_input_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & input_quant_param)68 void QuantParamHolder::set_input_quant_param(const size_t &index,
69                                              const std::vector<schema::QuantParamT> &input_quant_param) {
70   if (index >= this->input_quant_params_.size()) {
71     std::vector<schema::QuantParamT> place_quant(1);
72     this->input_quant_params_.insert(this->input_quant_params_.end(), index + 1 - input_quant_params_.size(),
73                                      place_quant);
74   }
75 
76   this->input_quant_params_.at(index) = input_quant_param;
77 }
78 
set_output_quant_param(const size_t & index,const std::vector<schema::QuantParamT> & output_quant_param)79 void QuantParamHolder::set_output_quant_param(const size_t &index,
80                                               const std::vector<schema::QuantParamT> &output_quant_param) {
81   if (index >= this->output_quant_params_.size()) {
82     std::vector<schema::QuantParamT> place_quant(1);
83     this->output_quant_params_.insert(this->output_quant_params_.end(), index + 1 - output_quant_params_.size(),
84                                       place_quant);
85   }
86   this->output_quant_params_.at(index) = output_quant_param;
87 }
88 
IsInputQuantParamsInited()89 bool QuantParamHolder::IsInputQuantParamsInited() {
90   if (this->input_quant_params_.empty()) {
91     return false;
92   }
93   bool is_quant_params_inited =
94     std::all_of(this->input_quant_params_.begin(), this->input_quant_params_.end(),
95                 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
96   return is_quant_params_inited;
97 }
98 
IsOutputQuantParamsInited()99 bool QuantParamHolder::IsOutputQuantParamsInited() {
100   if (this->output_quant_params_.empty()) {
101     return false;
102   }
103   bool is_quant_params_inited =
104     std::all_of(this->output_quant_params_.begin(), this->output_quant_params_.end(),
105                 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
106   return is_quant_params_inited;
107 }
108 
IsInputExistInited()109 bool QuantParamHolder::IsInputExistInited() {
110   if (this->input_quant_params_.empty()) {
111     return false;
112   }
113   bool is_exist_param_inited =
114     std::any_of(this->input_quant_params_.begin(), this->input_quant_params_.end(),
115                 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
116   return is_exist_param_inited;
117 }
118 
IsOutputExistInited()119 bool QuantParamHolder::IsOutputExistInited() {
120   if (this->output_quant_params_.empty()) {
121     return false;
122   }
123   bool is_exist_param_inited =
124     std::any_of(this->output_quant_params_.begin(), this->output_quant_params_.end(),
125                 [](const std::vector<schema::QuantParamT> &quant_params) { return quant_params.front().inited; });
126   return is_exist_param_inited;
127 }
128 
ClearQuantParams()129 void QuantParamHolder::ClearQuantParams() {
130   quant_type_ = quant::QUANT_NONE;
131   input_quant_params_.clear();
132   output_quant_params_.clear();
133 }
134 
CheckInit(size_t index,bool is_input)135 bool QuantParamHolder::CheckInit(size_t index, bool is_input) {
136   std::vector<schema::QuantParamT> param;
137   if (is_input) {
138     if (input_quant_params_.size() <= index) {
139       return false;
140     }
141     param = input_quant_params_.at(index);
142   } else {
143     if (output_quant_params_.size() <= index) {
144       return false;
145     }
146     param = output_quant_params_.at(index);
147   }
148   return (!param.empty() && param.front().inited);
149 }
150 
SetQuantClusters(size_t index,const std::vector<float> & quant_cluster)151 void QuantParamHolder::SetQuantClusters(size_t index, const std::vector<float> &quant_cluster) {
152   quant_clusters.insert({index, quant_cluster});
153 }
154 
GetQuantClusters(size_t index)155 std::vector<float> QuantParamHolder::GetQuantClusters(size_t index) {
156   auto iter = quant_clusters.find(index);
157   if (iter == quant_clusters.end()) {
158     return {};
159   } else {
160     return iter->second;
161   }
162 }
163 }  // namespace lite
164 }  // namespace mindspore
165