/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API

#include "tools/converter/quantizer/calibrator.h"
#include <utility>
#include "mindspore/core/ops/sequence_ops.h"
#include "tools/converter/preprocess/image_preprocess.h"
#include "ops/tuple_get_item.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"

namespace mindspore::lite::quant {
namespace {
constexpr int kDefaultBinNumber = 2048;
}  // namespace
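// Records the min/max values observed in one batch of float data into the given DataDistribution.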
int Calibrator::RecordMaxMinValue(const std::vector<float> &data,
                                  const std::unique_ptr<DataDistribution> &diverg_info) {
  auto ret = diverg_info->RecordMaxMinValueArray(data);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Record max min value array failed.";
    return ret;
  }
  return RET_OK;
}

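// Computes activation quantization thresholds for all recorded outputs first, then for inputs. When an input tensor
// is another node's output (and that node is not a TupleGetItem), the already computed output distribution is reused
// instead of being recomputed.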
int Calibrator::ComputeThreshold() {
  for (auto &kv : this->outputs_diverg_info_) {
    auto &outputs_diverg_info = kv.second;
    for (auto &diverg_info : outputs_diverg_info) {
      MS_CHECK_TRUE_RET(diverg_info.second != nullptr, RET_ERROR);
      auto ret = diverg_info.second->ComputeThreshold();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Compute threshold failed.";
        return ret;
      }
    }
  }
  // Node A's input may be node B's output; there is no need to re-compute node A's input quant param, which is the
  // same as node B's output quant param.
  for (auto &kv : this->inputs_diverg_info_) {
    auto &input_infos = kv.second;
    for (size_t i = 0; i < input_infos.size(); i++) {
      auto cnode = input_infos[i]->GetCNode();
      MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr.");
      bool already_computed = false;
      MS_CHECK_GT(cnode->size(), i + 1, RET_ERROR);
      auto input = cnode->input(i + 1);
      if (input->isa<mindspore::CNode>()) {
        auto input_cnode = input->cast<CNodePtr>();
        for (const auto &outputs_diverg_info : outputs_diverg_info_) {
          if (already_computed) {
            break;
          }
          for (const auto &output_diverg_info : outputs_diverg_info.second) {
            MS_CHECK_TRUE_RET(output_diverg_info.second != nullptr, RET_ERROR);
            auto output_diverg_cnode = output_diverg_info.second->GetCNode();
            if (output_diverg_cnode == input_cnode) {
              if (NodePrimitiveType(input_cnode) != ops::kNameTupleGetItem) {
                *(input_infos[i]) = *output_diverg_info.second;
                input_infos[i]->GetCNode() = cnode;
                already_computed = true;
                break;
              }
            }
          }
        }
      }
      if (!already_computed) {
        auto ret = input_infos[i]->ComputeThreshold();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "ComputeThreshold failed.";
          return ret;
        }
      }
    }
  }
  return RET_OK;
}

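// Refreshes the histogram interval of every recorded input and output DataDistribution.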
int Calibrator::UpdateDivergInterval() {
  for (auto &kv : inputs_diverg_info_) {
    for (auto &info : kv.second) {
      info.second->UpdateInterval();
    }
  }
  for (auto &kv : outputs_diverg_info_) {
    for (auto &info : kv.second) {
      info.second->UpdateInterval();
    }
  }
  return RET_OK;
}

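// Adds one batch of float data to the histogram of the given DataDistribution (used by the KL_BIN collection pass).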
int Calibrator::UpdateDataFrequency(const std::vector<float> &data,
                                    const std::unique_ptr<DataDistribution> &diverg_info) {
  MS_ASSERT(diverg_info != nullptr);
  return diverg_info->UpdateHistogram(data);
}

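// Registers a cnode for activation quantization by creating one DataDistribution per input and per output tensor.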
int Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "To be quantized cnode is null";
    return RET_ERROR;
  }
  auto node_name = cnode->fullname_with_scope();
  auto input_size = cnode->size();
  int index = 0;
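  // Inputs start at 1 because input(0) of a CNode is the primitive. A MakeTuple input is expanded so that every
  // tuple element gets its own DataDistribution.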
  for (size_t i = 1; i < input_size; i++) {
    if (opt::CheckPrimitiveType(cnode->input(i), prim::kPrimMakeTuple)) {
      auto input_cnode = cnode->input(i)->cast<CNodePtr>();
      MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "input_cnode is nullptr.");
      auto make_tuple_size = input_cnode->size() - 1;
      for (size_t j = 0; j < make_tuple_size; j++) {
        std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
          cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
        MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
        inputs_diverg_info_[node_name].insert({index++, std::move(input_diverg)});
      }
    } else {
      std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
      MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
      inputs_diverg_info_[node_name].insert({index++, std::move(input_diverg)});
    }
  }

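  // Outputs: a multi-output node (AbstractTuple) gets one DataDistribution per element; a single-output node keeps
  // one distribution at index 0.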
  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
    auto elements = tuple->elements();
    for (size_t i = 0; i < elements.size(); i++) {
      std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
      MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
      outputs_diverg_info_[node_name].insert({i, std::move(output_diverg)});
    }
  } else {
    std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_, symmetric_);
    MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
    outputs_diverg_info_[node_name].insert({0, std::move(output_diverg)});
  }
  return RET_OK;
}

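// Fills `tensor` for `input_name` with the sample at `image_index` by delegating to the configured preprocessing
// pipeline.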
int Calibrator::GenerateInputData(const std::string &input_name, size_t image_index,
                                  mindspore::MSTensor *tensor) const {
  return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
}

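// Collects calibration statistics for the tensors of `node_name`: min/max values in the MIN_MAX pass and histogram
// updates in the KL_BIN pass. Nodes that are not present in `diverg_info_map` are skipped.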
int Calibrator::CollectDataDistribution(
  const std::string &node_name, const std::vector<mindspore::MSTensor> &tensors,
  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
  CollectType collect_type) {
  MS_CHECK_TRUE_MSG(diverg_info_map != nullptr, RET_ERROR, "diverg_info_map is nullptr.");
  if (diverg_info_map->find(node_name) == diverg_info_map->end()) {
    return RET_OK;
  }
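  // Only non-const float32 tensors contribute to the statistics; constants (weights) and non-float data are skipped.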
  for (size_t i = 0; i < tensors.size(); i++) {
    auto tensor = tensors[i];
    if (tensor.IsConst() || tensor.DataType() != DataType::kNumberTypeFloat32) {
      continue;
    }
    const auto *tensor_data = static_cast<const float *>(tensor.Data().get());
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << tensor.Name() << " tensor_data is nullptr.";
      return RET_ERROR;
    }
    size_t elem_count = static_cast<size_t>(tensor.ElementNum());
    MS_CHECK_GT(elem_count, 0, RET_ERROR);
    std::vector<float> data(tensor_data, tensor_data + elem_count);
    if (collect_type == MIN_MAX) {
      MS_CHECK_LT(i, (*diverg_info_map)[node_name].size(), RET_ERROR);
      auto ret = RecordMaxMinValue(data, (*diverg_info_map)[node_name][i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << tensor.Name() << " record max min value failed.";
        return RET_ERROR;
      }
    } else if (collect_type == KL_BIN) {
      MS_CHECK_LT(i, (*diverg_info_map)[node_name].size(), RET_ERROR);
      auto ret = UpdateDataFrequency(data, (*diverg_info_map)[node_name][i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << tensor.Name() << " update data frequency failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant