/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API

#include "tools/converter/quantizer/calibrator.h"
#include <utility>
#include "mindspore/core/ops/sequence_ops.h"
#include "tools/converter/preprocess/image_preprocess.h"
#include "ops/tuple_get_item.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"

namespace mindspore::lite::quant {
namespace {
constexpr int kDefaultBinNumber = 2048;
}  // namespace
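
// Record the minimum and maximum values observed in one batch of tensor data into the tensor's
// DataDistribution entry.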
int Calibrator::RecordMaxMinValue(const std::vector<float> &data,
                                  const std::unique_ptr<DataDistribution> &diverg_info) {
  auto ret = diverg_info->RecordMaxMinValueArray(data);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Record max min value array failed.";
    return ret;
  }
  return RET_OK;
}

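// Compute the quantization threshold for every recorded distribution. Output distributions are computed first;
// an input distribution is then either copied from the producing node's output distribution or computed on its own.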
int Calibrator::ComputeThreshold() {
  for (auto &kv : this->outputs_diverg_info_) {
    auto &outputs_diverg_info = kv.second;
    for (auto &diverg_info : outputs_diverg_info) {
      MS_CHECK_TRUE_RET(diverg_info.second != nullptr, RET_ERROR);
      auto ret = diverg_info.second->ComputeThreshold();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Compute threshold failed.";
        return ret;
      }
    }
  }
  // Node A's input may be node B's output; there is no need to re-compute node A's input quant param,
  // which is the same as node B's output quant param.
  for (auto &kv : this->inputs_diverg_info_) {
    auto &input_infos = kv.second;
    for (size_t i = 0; i < input_infos.size(); i++) {
      auto cnode = input_infos[i]->GetCNode();
      MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr.");
      bool already_computed = false;
      MS_CHECK_GT(cnode->size(), i + 1, RET_ERROR);
      auto input = cnode->input(i + 1);
      if (input->isa<mindspore::CNode>()) {
        auto input_cnode = input->cast<CNodePtr>();
        for (const auto &outputs_diverg_info : outputs_diverg_info_) {
          if (already_computed) {
            break;
          }
          for (const auto &output_diverg_info : outputs_diverg_info.second) {
            MS_CHECK_TRUE_RET(output_diverg_info.second != nullptr, RET_ERROR);
            auto output_diverg_cnode = output_diverg_info.second->GetCNode();
            if (output_diverg_cnode == input_cnode) {
              if (NodePrimitiveType(input_cnode) != ops::kNameTupleGetItem) {
                *(input_infos[i]) = *output_diverg_info.second;
                input_infos[i]->GetCNode() = cnode;
                already_computed = true;
                break;
              }
            }
          }
        }
      }
      if (!already_computed) {
        auto ret = input_infos[i]->ComputeThreshold();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "ComputeThreshold failed.";
          return ret;
        }
      }
    }
  }
  return RET_OK;
}

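// Refresh the histogram interval of every recorded input and output distribution.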
int Calibrator::UpdateDivergInterval() {
  for (auto &kv : inputs_diverg_info_) {
    for (auto &info : kv.second) {
      info.second->UpdateInterval();
    }
  }
  for (auto &kv : outputs_diverg_info_) {
    for (auto &info : kv.second) {
      info.second->UpdateInterval();
    }
  }
  return RET_OK;
}

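// Accumulate one batch of tensor data into the histogram of the given distribution.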
int Calibrator::UpdateDataFrequency(const std::vector<float> &data,
                                    const std::unique_ptr<DataDistribution> &diverg_info) {
  MS_ASSERT(diverg_info != nullptr);
  return diverg_info->UpdateHistogram(data);
}

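// Register a cnode for activation calibration: create one DataDistribution per input (expanding MakeTuple
// inputs) and one per output (expanding tuple outputs).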
int Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "To be quantized cnode is null";
    return RET_ERROR;
  }
  auto node_name = cnode->fullname_with_scope();
  auto input_size = cnode->size();
  int index = 0;
  for (size_t i = 1; i < input_size; i++) {
    if (opt::CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple)) {
      auto input_cnode = cnode->input(i)->cast<CNodePtr>();
      MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "input_cnode is nullptr.");
      auto make_tuple_size = input_cnode->size() - 1;
      for (size_t j = 0; j < make_tuple_size; j++) {
        std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
          cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
        MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
        inputs_diverg_info_[node_name].insert({index++, std::move(input_diverg)});
      }
    } else {
      std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
      MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
      inputs_diverg_info_[node_name].insert({index++, std::move(input_diverg)});
    }
  }

  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
    auto elements = tuple->elements();
    for (size_t i = 0; i < elements.size(); i++) {
      std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
      MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
      outputs_diverg_info_[node_name].insert({i, std::move(output_diverg)});
    }
  } else {
    std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
    MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
    outputs_diverg_info_[node_name].insert({0, std::move(output_diverg)});
  }
  return RET_OK;
}

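// Generate calibration data for the given model input: run the configured data pre-processing on the sample
// at image_index and write the result into tensor.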
int Calibrator::GenerateInputData(const std::string &input_name, size_t image_index,
                                  mindspore::MSTensor *tensor) const {
  return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
}

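// Collect statistics for the float32 activation tensors of node_name: record min/max values when collect_type
// is MIN_MAX, or update the histograms when collect_type is KL_BIN.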
int Calibrator::CollectDataDistribution(
  const std::string &node_name, const std::vector<mindspore::MSTensor> &tensors,
  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
  CollectType collect_type) {
  MS_CHECK_TRUE_MSG(diverg_info_map != nullptr, RET_ERROR, "diverg_info_map is nullptr.");
  if (diverg_info_map->find(node_name) == diverg_info_map->end()) {
    return RET_OK;
  }
  for (size_t i = 0; i < tensors.size(); i++) {
    auto tensor = tensors[i];
    if (tensor.IsConst() || tensor.DataType() != DataType::kNumberTypeFloat32) {
      continue;
    }
    const auto *tensor_data = static_cast<const float *>(tensor.Data().get());
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << tensor.Name() << " tensor_data is nullptr.";
      return RET_ERROR;
    }
    size_t elem_count = static_cast<size_t>(tensor.ElementNum());
    MS_CHECK_GT(elem_count, 0, RET_ERROR);
    std::vector<float> data(tensor_data, tensor_data + elem_count);
    if (collect_type == MIN_MAX) {
      MS_CHECK_LT(i, (*diverg_info_map)[node_name].size(), RET_ERROR);
      auto ret = RecordMaxMinValue(data, (*diverg_info_map)[node_name][i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << tensor.Name() << " record max min value failed.";
        return RET_ERROR;
      }
    } else if (collect_type == KL_BIN) {
      MS_CHECK_LT(i, (*diverg_info_map)[node_name].size(), RET_ERROR);
      auto ret = UpdateDataFrequency(data, (*diverg_info_map)[node_name][i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << tensor.Name() << " update data frequency failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant