1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18
19 #include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
20 #include <cmath>
21 #include <string>
22 #include <memory>
23 #include <vector>
24 #include <set>
25 #include <functional>
26 #include <deque>
27 #include "include/common/utils/convert_utils.h"
28 #include "mindspore/core/ops/lite_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "abstract/abstract_value.h"
31 #include "tools/common/graph_util.h"
32 #include "tools/lite_exporter/anf_exporter.h"
33 #include "tools/converter/graphdef_transform.h"
34 #include "tools/common/tensor_util.h"
35 #include "tools/optimizer/common/gllo_utils.h"
36 #include "ops/fusion/mat_mul_fusion.h"
37 #include "ops/auto_generate/gen_lite_ops.h"
38 #include "ops/fusion/conv2d_transpose_fusion.h"
39 #include "ops/ops_func_impl/gather.h"
40 #include "ops/op_utils.h"
41 #include "src/common/utils.h"
42 #include "src/common/file_utils.h"
43 #include "src/litert/cxx_api/tensor/tensor_impl.h"
44 #include "ir/anf.h"
45 #include "tools/converter/export_model.h"
46 #include "tools/converter/parser/parser_utils.h"
47 #include "ops/other_ops.h"
48 #include "utils/anf_utils.h"
49 #include "mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h"
50
51 using std::string;
52 using std::vector;
53
54 namespace mindspore::lite::quant {
55 namespace {
56 constexpr size_t kGatherAxisIndex = 3;
57 constexpr int kDefaultThreadNum = 4;
58 constexpr size_t kEncMaxLen = 16;
59 constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
60 constexpr int kFakeQuantMinIndex = 1;
61 constexpr int kFakeQuantMaxIndex = 2;
62 } // namespace
63
GetQuantType(const CNodePtr & cnode,quant::QuantType * quant_type)64 int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
65 CHECK_NULL_RETURN(cnode);
66 CHECK_NULL_RETURN(quant_type);
67 auto quant_param_holder = GetCNodeQuantHolder(cnode);
68 if (quant_param_holder == nullptr) {
69 *quant_type = quant::QUANT_NONE;
70 return RET_OK;
71 }
72 *quant_type = quant_param_holder->quant_type();
73 return RET_OK;
74 }
75
GetQuantTypeNew(const CNodePtr & cnode,quant::QuantType * quant_type)76 int GetQuantTypeNew(const CNodePtr &cnode, quant::QuantType *quant_type) {
77 CHECK_NULL_RETURN(cnode);
78 CHECK_NULL_RETURN(quant_type);
79 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
80 if (primitive == nullptr) {
81 MS_LOG(ERROR) << "primitive is nullptr.";
82 return RET_NULL_PTR;
83 }
84 auto quant_type_attr = primitive->GetAttr(quant::kQuantType);
85 if (quant_type_attr == nullptr) {
86 *quant_type = quant::QUANT_NONE;
87 return RET_OK;
88 }
89 *quant_type = static_cast<quant::QuantType>(GetValue<int32_t>(quant_type_attr));
90 return RET_OK;
91 }
92
GetFuncGraphs(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * all_func_graphs)93 void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
94 MS_ASSERT(func_graph != nullptr);
95 MS_ASSERT(all_func_graphs != nullptr);
96 all_func_graphs->insert(func_graph);
97 auto nodes = func_graph->GetOrderedCnodes();
98 std::deque<CNodePtr> to_process{};
99 to_process.insert(to_process.end(), nodes.begin(), nodes.end());
100 while (!to_process.empty()) {
101 auto &cur_cnode = to_process.front();
102 for (auto &input : cur_cnode->inputs()) {
103 if (!IsValueNode<FuncGraph>(input)) {
104 continue;
105 }
106 auto new_fg = GetValueNode<FuncGraphPtr>(input);
107 if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
108 continue;
109 }
110 all_func_graphs->insert(new_fg);
111 auto new_nodes = new_fg->GetOrderedCnodes();
112 to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
113 }
114 to_process.pop_front();
115 }
116 }
117
UpdateDataType(const AnfNodePtr & node,TypeId new_data_type)118 int UpdateDataType(const AnfNodePtr &node, TypeId new_data_type) {
119 auto abstract_base = node->abstract();
120 if (abstract_base == nullptr) {
121 MS_LOG(ERROR) << "Abstract of node is nullptr, " << node->fullname_with_scope();
122 return RET_NULL_PTR;
123 }
124
125 std::vector<AbstractBasePtr> abstracts;
126 if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
127 auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
128 abstracts = abstract_tuple->elements();
129 } else {
130 abstracts.push_back(abstract_base);
131 }
132 for (auto &abstract : abstracts) {
133 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
134 CHECK_NULL_RETURN(abstract_tensor);
135 CHECK_NULL_RETURN(abstract_tensor->element());
136 abstract_tensor->element()->set_type(TypeIdToType(new_data_type));
137 }
138 return RET_OK;
139 }
140
IsGraphInDTypeCast(const CNodePtr & cnode)141 bool IsGraphInDTypeCast(const CNodePtr &cnode) {
142 if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
143 return false;
144 }
145 auto input_node = cnode->input(1);
146 MS_CHECK_FALSE(input_node == nullptr, false);
147 return IsGraphInput(input_node);
148 }
149
IsGraphOutDTypeCast(const FuncGraphPtr & func_graph,const CNodePtr & cnode)150 bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
151 if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
152 return false;
153 }
154 auto manager = func_graph->manager();
155 if (manager == nullptr) {
156 manager = Manage(func_graph, true);
157 }
158 MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr.");
159 auto node_users = manager->node_users()[cnode];
160 MS_CHECK_TRUE_MSG(!node_users.empty(), false, "node_users is empty.");
161 for (auto &node_user : node_users) {
162 auto output_cnode = node_user.first->cast<CNodePtr>();
163 MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "output_cnode is nullptr.");
164 if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
165 return false;
166 }
167 }
168 return true;
169 }
170
// Classifies a QuantDTypeCast node as a quantize cast (fp32 -> int8/uint8)
// or a dequantize cast (int8/uint8 -> fp32) by inspecting the dtype of its
// input node and of its first consumer. A cast that feeds Return is
// classified from the input dtype alone (the output dtype is not readable
// on Return). Returns RET_NOT_SUPPORT for non-QuantDTypeCast nodes.
// NOTE(review): when neither pattern matches, only an error is logged and
// RET_OK is still returned with *cast_node_type untouched — presumably the
// caller pre-initializes it; confirm before relying on the output.
int GetCastNodeType(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CastNodeType *cast_node_type) {
  CHECK_NULL_RETURN(cast_node_type);
  if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
    MS_LOG(DEBUG) << "Not QuantDtypeCastNode, cnode name: " << cnode->fullname_with_scope();
    return RET_NOT_SUPPORT;
  }
  auto input_node = cnode->input(1);
  MS_CHECK_FALSE(input_node == nullptr, RET_ERROR);

  // input node
  TypeId pre_node_dtype = kTypeUnknown;
  if (opt::GetDataTypeFromAnfNode(input_node, &pre_node_dtype) != RET_OK) {
    MS_LOG(ERROR) << "Get data type failed, cnode name: " << input_node->fullname_with_scope();
    return RET_ERROR;
  }

  // output node
  TypeId post_node_dtype = kTypeUnknown;
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    // Lazily build a manager when the graph has none yet.
    manager = Manage(func_graph, true);
  }
  CHECK_NULL_RETURN(manager);
  auto node_users = manager->node_users()[cnode];
  MS_CHECK_TRUE_RET(!node_users.empty(), RET_NULL_PTR);
  // Only the first consumer is inspected; multi-consumer casts are assumed
  // to have a consistent dtype across consumers.
  auto output_cnode = node_users.begin()->first->cast<CNodePtr>();
  CHECK_NULL_RETURN(output_cnode);

  if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
    if (opt::GetDataTypeFromAnfNode(output_cnode, &post_node_dtype) != RET_OK) {
      MS_LOG(ERROR) << "Get data type failed, cnode name: " << output_cnode->fullname_with_scope();
      return RET_ERROR;
    }
    if (pre_node_dtype == kNumberTypeFloat32 &&
        (post_node_dtype == kNumberTypeInt8 || post_node_dtype == kNumberTypeUInt8)) {
      *cast_node_type = kQuant;
    } else if ((pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) &&
               post_node_dtype == kNumberTypeFloat32) {
      *cast_node_type = kDeQuant;
    } else {
      MS_LOG(ERROR) << "Not support QuantDTypeCastNode, cnode name: " << cnode->fullname_with_scope();
    }
  } else {
    // Cast feeds Return: classify from the input side only.
    if (pre_node_dtype == kNumberTypeFloat32) {
      *cast_node_type = kQuant;
    } else if (pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) {
      *cast_node_type = kDeQuant;
    } else {
      MS_LOG(ERROR) << "Not support QuantDTypeCastNode, cnode name: " << cnode->fullname_with_scope();
    }
  }
  return RET_OK;
}
224
NodePrimitiveType(const CNodePtr & cnode)225 std::string NodePrimitiveType(const CNodePtr &cnode) {
226 if (cnode == nullptr) {
227 MS_LOG(ERROR) << "cnode is null";
228 return "";
229 }
230 auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
231 if (primitive_c == nullptr) {
232 MS_LOG(ERROR) << "primitive_c is null";
233 return "";
234 }
235 return primitive_c->name();
236 }
237
LargeModelBuildModel(const schema::MetaGraphT & meta_graph,const std::shared_ptr<ConverterPara> & param,const std::shared_ptr<mindspore::Model> & model,const std::shared_ptr<Context> & context,size_t * size)238 Status LargeModelBuildModel(const schema::MetaGraphT &meta_graph, const std::shared_ptr<ConverterPara> ¶m,
239 const std::shared_ptr<mindspore::Model> &model, const std::shared_ptr<Context> &context,
240 size_t *size) {
241 if (size == nullptr) {
242 return kLiteError;
243 }
244 if (param->commonQuantParam.workspace.empty()) {
245 MS_LOG(ERROR) << "The model is larger than 2G, mixedBitWeightQuant config needs to set workspace to save tmp model";
246 return kLiteError;
247 }
248 std::string tmp_save_file_path = param->commonQuantParam.workspace + "/tmp.ms";
249 tmp_save_file_path = lite::RealPath(tmp_save_file_path.c_str());
250 if (tmp_save_file_path.empty()) {
251 MS_LOG(ERROR) << param->commonQuantParam.workspace << " is invalid path. Please check it again.";
252 return kLiteError;
253 }
254 unsigned char encKey[kEncMaxLen] = {0};
255 size_t keyLen = 0;
256 auto status = MetaGraphSerializer::Save(meta_graph, tmp_save_file_path, size, encKey, keyLen, param->encrypt_mode);
257 if (status != RET_OK) {
258 MS_LOG(ERROR) << "Save Large Model Failed: " << status << " " << GetErrorInfo(status);
259 return kLiteError;
260 }
261
262 mindspore::ModelType model_type = kMindIR_Lite;
263 auto ret = model->Build(tmp_save_file_path, model_type, context);
264 return ret;
265 }
266
DumpGraph(const FuncGraphPtr & func_graph,const std::shared_ptr<ConverterPara> & param,const std::string & save_path)267 int DumpGraph(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> ¶m,
268 const std::string &save_path) {
269 FuncGraphPtr func_graph_clone;
270 if (CloneFuncGraph(func_graph, param, &func_graph_clone) != RET_OK) {
271 MS_LOG(ERROR) << "Clone func_graph failed";
272 return RET_ERROR;
273 }
274 auto meta_graph = Export(func_graph_clone, true, true);
275 if (meta_graph == nullptr) {
276 MS_LOG(ERROR) << "Export to meta_graph failed";
277 return RET_ERROR;
278 }
279
280 // transform
281 GraphDefTransform fb_transform;
282 fb_transform.SetGraphDef(meta_graph);
283 auto status = fb_transform.Transform(param);
284 if (status != RET_OK) {
285 MS_LOG(ERROR) << "FBTransform model failed";
286 delete meta_graph;
287 return RET_ERROR;
288 }
289 meta_graph->version = Version();
290
291 status = UpdateGraphOutputName(meta_graph);
292 if (status != RET_OK) {
293 MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
294 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
295 delete meta_graph;
296 return RET_ERROR;
297 }
298
299 unsigned char encKey[kEncMaxLen] = {0};
300 size_t keyLen = 0;
301 size_t size;
302 status = MetaGraphSerializer::Save(*meta_graph, save_path, &size, encKey, keyLen, param->encrypt_mode);
303 if (status != RET_OK) {
304 MS_LOG(ERROR) << "Save Large Model Failed: " << status << " " << GetErrorInfo(status);
305 return RET_ERROR;
306 }
307 return RET_OK;
308 }
309
// Builds an executable Model from `func_graph`: clones the graph, exports it
// to a MetaGraphT, applies flatbuffer transforms, then packs it to a buffer
// and builds the model on a default CPU context. Graphs whose total tensor
// payload reaches kModelSizeLimit (2 GB) are routed through
// LargeModelBuildModel (file-based build) instead, since a flatbuffer cannot
// exceed 2 GB. *size receives the serialized model size.
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
                             const std::shared_ptr<ConverterPara> &param, size_t *size) {
  if (size == nullptr) {
    return kLiteError;
  }
  FuncGraphPtr func_graph_clone;
  if (CloneFuncGraph(func_graph, param, &func_graph_clone) != RET_OK) {
    MS_LOG(ERROR) << "Clone func_graph failed";
    return kLiteNullptr;
  }
  // Export() returns a heap-allocated MetaGraphT; every exit below must
  // delete it.
  auto meta_graph = Export(func_graph_clone, true, true);
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Export to meta_graph failed";
    return kLiteNullptr;
  }

  // transform
  GraphDefTransform fb_transform;
  fb_transform.SetGraphDef(meta_graph);
  auto status = fb_transform.Transform(param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "FBTransform model failed";
    delete meta_graph;
    return kLiteError;
  }
  meta_graph->version = Version();

  status = UpdateGraphOutputName(meta_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    delete meta_graph;
    return kLiteError;
  }

  // Default inference context: CPU, fixed thread count and bind mode.
  auto context = std::make_shared<mindspore::Context>();
  if (context == nullptr) {
    MS_LOG(ERROR) << "New context failed while running.";
    delete meta_graph;
    return kLiteNullptr;
  }
  context->SetThreadNum(kDefaultThreadNum);
  context->SetThreadAffinity(kCpuBindMode);

  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
  if (device_info == nullptr) {
    MS_LOG(ERROR) << "New device_info failed while running.";
    delete meta_graph;
    return kLiteNullptr;
  }
  auto &device_list = context->MutableDeviceInfo();
  device_list.push_back(device_info);

  // Sum all tensor payloads to decide between the in-memory flatbuffer path
  // and the file-based large-model path.
  size_t tensors_size = 0;
  for (auto &tensor : meta_graph->allTensors) {
    tensors_size += tensor->data.size();
  }

  if (tensors_size >= kModelSizeLimit) {
    auto ret = LargeModelBuildModel(*meta_graph, param, model, context, size);
    delete meta_graph;
    return ret;
  }

  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
  auto offset = schema::MetaGraph::Pack(builder, meta_graph);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  *size = builder.GetSize();
  // `content` points into `builder`'s buffer; Build() must copy before the
  // builder goes out of scope.
  auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
  if (content == nullptr) {
    MS_LOG(ERROR) << "GetBufferPointer return null";
    delete meta_graph;
    return kLiteNullptr;
  }

  auto ret = model->Build(content, *size, kMindIR, context);
  delete meta_graph;
  return ret;
}
390
MSTensorToLiteTensor(const MSTensor & tensor)391 mindspore::lite::Tensor *MSTensorToLiteTensor(const MSTensor &tensor) {
392 if (tensor.impl() == nullptr) {
393 MS_LOG(ERROR) << "Tensor " << tensor.Name() << " is nullptr.";
394 return static_cast<lite::Tensor *>(nullptr);
395 }
396 auto lite_impl = std::static_pointer_cast<LiteTensorImpl>(tensor.impl());
397 return static_cast<mindspore::lite::Tensor *>(lite_impl->lite_tensor());
398 }
399
// Unwraps each MSTensor into its backing lite::Tensor. Returns an empty
// vector if any conversion fails.
// Fix: the result was constructed with `vector(src_tensors.size())`, which
// pre-filled it with N null pointers, and emplace_back then APPENDED the
// converted tensors — callers received 2N entries, the first N all nullptr.
// reserve() gives the intended N-element result.
std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<mindspore::MSTensor> &src_tensors) {
  std::vector<mindspore::lite::Tensor *> dst_tensors;
  dst_tensors.reserve(src_tensors.size());
  for (const auto &src_tensor : src_tensors) {
    auto tensor = MSTensorToLiteTensor(src_tensor);
    if (tensor == nullptr) {
      return {};
    }
    dst_tensors.emplace_back(tensor);
  }
  return dst_tensors;
}
411
GetParameterAndTensor(const AnfNodePtr & node,ParameterPtr * param_node,tensor::TensorPtr * tensor_info)412 void GetParameterAndTensor(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
413 CHECK_NULL_RETURN_VOID(param_node);
414 CHECK_NULL_RETURN_VOID(tensor_info);
415 if (node == nullptr) {
416 MS_LOG(ERROR) << "node is nullptr";
417 return;
418 }
419 auto op_name = node->fullname_with_scope();
420
421 *param_node = node->cast<ParameterPtr>();
422 if (*param_node == nullptr) {
423 MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
424 return;
425 }
426 if (!(*param_node)->has_default()) {
427 MS_LOG(INFO) << op_name << " not has_default";
428 return;
429 }
430
431 *tensor_info = std::static_pointer_cast<tensor::Tensor>((*param_node)->default_param());
432 if (*tensor_info == nullptr) {
433 MS_LOG(INFO) << "default_param can not cast to tensor::Tensor";
434 return;
435 }
436 }
437
// Overwrites `weight`'s raw data with `quant_datas` (new_size bytes) and
// switches both the tensor's dtype and the node abstract's dtype to
// `new_data_type`. new_size must exactly match the tensor's byte size under
// the new dtype.
// NOTE(review): set_data_type is called BEFORE the size check — presumably
// so data().nbytes() reflects the new element size for the comparison;
// confirm, since a size mismatch leaves the tensor with the new dtype but
// old data.
int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &weight, const void *quant_datas,
                            size_t new_size, TypeId new_data_type) {
  CHECK_NULL_RETURN(quant_datas);
  MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
  MS_CHECK_TRUE_RET(new_size > 0, RET_NULL_PTR);
  weight->set_data_type(new_data_type);
  if (new_size != static_cast<size_t>(weight->data().nbytes())) {
    MS_LOG(ERROR) << "Data size of tensor info is error.";
    return RET_ERROR;
  }
  if (memcpy_s(weight->data_c(), weight->data().nbytes(), quant_datas, new_size) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return RET_ERROR;
  }
  // set dtype
  auto ret = UpdateDataType(node, new_data_type);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << node->fullname_with_scope() << " set new dtype failed.";
    return ret;
  }
  return RET_OK;
}
460
GetMatMulPreferredDim(const PrimitivePtr & primitive,int input_index,const std::vector<int> & dims)461 int GetMatMulPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims) {
462 size_t last_first_index = dims.size() - 1;
463 size_t last_second_index = dims.size() - 2;
464 auto matmul_prim = api::MakeShared<ops::MatMul>(primitive);
465 MS_ASSERT(matmul_prim != nullptr);
466 // For MatMul A
467 if (input_index == 0) {
468 if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
469 return last_first_index;
470 } else {
471 return last_second_index;
472 }
473 }
474 // For MatMul B
475 if (input_index == 1) {
476 if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
477 return last_second_index;
478 } else {
479 return last_first_index;
480 }
481 }
482 return 0;
483 }
484
GetDeConvPreferredDim(const PrimitivePtr & primitive,const std::vector<int> & dims)485 int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int> &dims) {
486 auto prim = api::MakeShared<ops::Conv2DTranspose>(primitive);
487 MS_ASSERT(prim != nullptr);
488 if (prim->get_in_channel() == prim->get_group() && prim->get_out_channel() == prim->get_group()) {
489 // DepthWise-DeConv (CO\CI) KH KW 1
490 return 0;
491 }
492 // DeConv:CI KH KW CO
493 return dims.size() - 1;
494 }
495
// Returns the axis a Gather node gathers along, read from its constant axis
// input (input index 3). Falls back to 0 whenever the axis input is missing,
// non-constant, or cannot be fetched.
// NOTE(review): CHECK_NULL_RETURN returns an error CODE here, which the
// caller then interprets as an axis value — presumably harmless in practice
// since a fetched data buffer is non-null; confirm.
int GetGatherPreferredDim(const CNodePtr &cnode) {
  // Need prim + data + indices + axis inputs for the axis to exist at all.
  if (cnode->size() < kGatherAxisIndex + kPrimOffset) {
    MS_LOG(WARNING) << "gather cnode size < 4.";
    return 0;
  }
  DataInfo data_info;
  auto output_type_node = cnode->input(kGatherAxisIndex);
  if (utils::isa<ParameterPtr>(output_type_node)) {
    if (FetchDataFromParameterNode(cnode, kGatherAxisIndex, converter::kFmkTypeMs, &data_info, true) != lite::RET_OK) {
      MS_LOG(WARNING) << "Fetch data from parameter node failed.";
      return 0;
    }
  } else if (utils::isa<ValueNodePtr>(output_type_node)) {
    if (FetchDataFromValueNode(cnode, kGatherAxisIndex, converter::kFmkTypeMs, false, &data_info, true) !=
        lite::RET_OK) {
      MS_LOG(WARNING) << "Fetch data from value node failed.";
      return 0;
    }
  } else {
    // Axis is produced by another node at runtime; cannot be read statically.
    MS_LOG(WARNING) << "The data type is not a const.";
    return 0;
  }

  // The fetched buffer holds the scalar axis as an int32.
  auto axis_data = reinterpret_cast<const int *>(data_info.data_.data());
  CHECK_NULL_RETURN(axis_data);
  return axis_data[0];
}
523
// Decides along which axis (channel dim) the weight feeding `cnode`'s
// input slot `input_index` should be per-channel quantized.
// Priority: an explicit kChannelAxis attribute stored on the input tensor's
// quantization param wins; otherwise the axis is derived from the op type
// (MatMul variants, Conv2dTransposeFusion, Gather, FFN); everything else
// falls back to axis 0.
int GetPreferredDim(const CNodePtr &cnode, int input_index, const std::vector<int> &dims) {
  // +kPrimOffset: input(0) is the primitive, data inputs start at 1.
  auto input_node = cnode->input(input_index + kPrimOffset);
  if (input_node->isa<mindspore::Parameter>()) {
    tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
    if (input_tensor != nullptr) {
      auto quantization_params = input_tensor->quant_params();
      if (!quantization_params.empty()) {
        auto quantization_param = quantization_params.front();
        auto axis_attr = quantization_param->GetAttr(kChannelAxis);
        if (axis_attr != nullptr) {
          if (axis_attr->isa<Int64Imm>()) {
            auto axis = axis_attr->cast<Int64ImmPtr>()->value();
            MS_LOG(INFO) << "Quantization param axis is " << axis;
            return axis;
          }
          // Malformed attribute: fall through to the op-type rules below.
          MS_LOG(WARNING) << "Quantization param axis_attr is not int64";
        }
      }
    }
  }
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  // Op-specific channel-axis rules.
  if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul ||
      primitive->name() == ops::kNameBatchMatMul) {
    return GetMatMulPreferredDim(primitive, input_index, dims);
  } else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
    return GetDeConvPreferredDim(primitive, dims);
  } else if (primitive->name() == ops::kNameGather) {
    return GetGatherPreferredDim(cnode);
  } else if (primitive->name() == "FFN") {
    // For FFN MatMul, transpose is false
    return dims.size() - 1;
  }
  // The first index.
  return 0;
}
560
GetFollowedNodePreferredDim(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & dims)561 int GetFollowedNodePreferredDim(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &dims) {
562 auto manager = mindspore::Manage(func_graph, true);
563 auto node_users = manager->node_users()[cnode];
564 if (node_users.empty()) {
565 MS_LOG(WARNING) << cnode->fullname_with_scope() << " cnode is isolated.";
566 return 0;
567 }
568 if (node_users.size() > 1) {
569 MS_LOG(WARNING) << "The cnode dont has only one followed node";
570 return 0;
571 }
572 auto node_user = node_users.begin();
573 if (!utils::isa<CNodePtr>(node_user->first)) {
574 MS_LOG(WARNING) << "The followed op: " << node_user->first->fullname_with_scope() << " is not cnode";
575 return 0;
576 }
577 auto node_user_cnode = utils::cast<CNodePtr>(node_user->first);
578 return GetPreferredDim(node_user_cnode, node_user->second - 1, dims);
579 }
580
ConvertShapeVectorToInt32(const ShapeVector & dims)581 std::vector<int> ConvertShapeVectorToInt32(const ShapeVector &dims) {
582 std::vector<int> shape;
583 for (auto dim : dims) {
584 if (dim > INT32_MAX || dim < INT32_MIN) {
585 MS_LOG(ERROR) << dim << " over int32 range.";
586 shape.push_back(-1);
587 } else {
588 shape.push_back(dim);
589 }
590 }
591 return shape;
592 }
593
CheckNodeInSet(const CNodePtr & cnode,const std::set<PrimitivePtr> & support_primitive_types)594 bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types) {
595 for (const auto &type : support_primitive_types) {
596 if (opt::CheckPrimitiveType(cnode, type)) {
597 return true;
598 }
599 }
600 return false;
601 }
602
CheckFollowedNodeInSet(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::set<PrimitivePtr> & support_primitive_types)603 bool CheckFollowedNodeInSet(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
604 const std::set<PrimitivePtr> &support_primitive_types) {
605 auto manager = mindspore::Manage(func_graph, true);
606 auto node_users = manager->node_users()[cnode];
607 if (node_users.empty()) {
608 MS_LOG(WARNING) << cnode->fullname_with_scope() << " cnode is isolated.";
609 return false;
610 }
611 for (auto &node_user : node_users) {
612 if (!utils::isa<CNodePtr>(node_user.first)) {
613 MS_LOG(INFO) << "The followed op: " << node_user.first->fullname_with_scope() << " is not cnode";
614 return false;
615 }
616 auto node_user_cnode = utils::cast<CNodePtr>(node_user.first);
617 if (!CheckNodeInSet(node_user_cnode, support_primitive_types)) {
618 return false;
619 }
620 }
621 return true;
622 }
623
// Dequantizes an int8 MSTensor into doubles using the tensor's own quant
// params; thin convenience wrapper over the raw-pointer DeQuantData overload.
int DeQuantData(const mindspore::MSTensor *tensor, std::vector<double> *dequant_data) {
  return DeQuantData(reinterpret_cast<const int8_t *>(tensor->Data().get()), tensor->ElementNum(),
                     tensor->QuantParams(), dequant_data);
}
628
GetElementNumFromShape(const std::vector<int> & dims,int * total_size)629 int GetElementNumFromShape(const std::vector<int> &dims, int *total_size) {
630 CHECK_NULL_RETURN(total_size);
631 *total_size = 1;
632 for (auto dim : dims) {
633 MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(*total_size, dim), RET_ERROR, "Int mul overflow.");
634 *total_size *= dim;
635 }
636 return RET_OK;
637 }
638
GetBucketAllIndex(const std::vector<int> & dims,int preferred_dim,std::vector<std::vector<size_t>> * buckets_data_index)639 int GetBucketAllIndex(const std::vector<int> &dims, int preferred_dim,
640 std::vector<std::vector<size_t>> *buckets_data_index) {
641 CHECK_NULL_RETURN(buckets_data_index);
642 int outer = 1;
643 for (int i = 0; i < preferred_dim; i++) {
644 outer *= dims[i];
645 }
646 int bucket_count = dims[preferred_dim];
647 int inner = 1;
648 for (size_t i = preferred_dim + 1; i < dims.size(); i++) {
649 inner *= dims[i];
650 }
651 if (inner <= 0 || outer <= 0 || bucket_count <= 0) {
652 return RET_ERROR;
653 }
654 for (int i = 0; i < bucket_count; i++) {
655 auto index = i * inner;
656 std::vector<size_t> bucket_index(inner * outer);
657 for (int j = 0; j < outer; j++) {
658 for (int k = 0; k < inner; k++) {
659 bucket_index[j * inner + k] = index + k;
660 }
661 index += bucket_count * inner;
662 }
663 buckets_data_index->push_back(bucket_index);
664 }
665 return RET_OK;
666 }
667
CheckControlFlowType(const AnfNodePtr & node)668 bool CheckControlFlowType(const AnfNodePtr &node) {
669 if (node == nullptr) {
670 return false;
671 }
672 std::map<std::string, PrimitivePtr> control_flow_ops = {{"PartialFusion", prim::kPrimPartialFusion},
673 {"Switch", prim::kPrimSwitch},
674 {"switch_layer", prim::kPrimSwitchLayer},
675 {"call", prim::kPrimCall}};
676
677 if (node->isa<mindspore::CNode>()) {
678 auto cnode = node->cast<CNodePtr>();
679 // control flow call
680 if (!IsValueNode<mindspore::Primitive>(cnode->input(kPrimIndex))) {
681 return true;
682 }
683 auto prim = GetValuePtr<mindspore::Primitive>(cnode->input(kPrimIndex));
684 if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
685 return true;
686 }
687 } else if (node->isa<ValueNode>()) {
688 auto prim = GetValuePtr<mindspore::Primitive>(node);
689 if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
690 return true;
691 }
692 }
693 return false;
694 }
695
CloneFuncGraph(const FuncGraphPtr & func_graph,const std::shared_ptr<ConverterPara> & param,FuncGraphPtr * func_graph_bak)696 int CloneFuncGraph(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> ¶m,
697 FuncGraphPtr *func_graph_bak) {
698 CHECK_NULL_RETURN(func_graph_bak);
699 CHECK_NULL_RETURN(param);
700 std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
701 *func_graph_bak = lite::CloneFuncGraph(func_graph, param, &cloned_func_graph);
702 CHECK_NULL_RETURN(*func_graph_bak);
703 static auto root_func_manager = Manage(*func_graph_bak);
704 std::set<FuncGraphPtr> all_func_graphs = {};
705 lite::GetAllFuncGraph(*func_graph_bak, &all_func_graphs);
706 for (const auto &graph : all_func_graphs) {
707 graph->set_manager(root_func_manager);
708 }
709 return RET_OK;
710 }
711
MarkOriginDataType(const FuncGraphPtr & func_graph)712 int MarkOriginDataType(const FuncGraphPtr &func_graph) {
713 auto cnodes = func_graph->GetOrderedCnodes();
714 for (auto &cnode : cnodes) {
715 TypeId type_id = kTypeUnknown;
716 if (opt::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
717 continue;
718 }
719 auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
720 if (ret != RET_OK) {
721 MS_LOG(INFO) << "CNode data type is unknown.";
722 return RET_OK;
723 }
724 if (type_id != kTypeUnknown) {
725 MS_LOG(INFO) << cnode->fullname_with_scope() << " origin type is " << type_id;
726 cnode->AddAttr("origin_type", MakeValue(static_cast<int>(type_id)));
727 }
728 }
729 return RET_OK;
730 }
731
ConvertFp16ToFp32(const FuncGraphPtr & func_graph)732 int ConvertFp16ToFp32(const FuncGraphPtr &func_graph) {
733 auto cnodes = func_graph->GetOrderedCnodes();
734 for (auto &cnode : cnodes) {
735 auto ret = ConvertCNodeFp16ToFp32(cnode);
736 if (ret != RET_OK) {
737 MS_LOG(ERROR) << cnode->fullname_with_scope() << " convert fp16 To fp32 failed.";
738 return ret;
739 }
740 }
741 return RET_OK;
742 }
743
ConvertCNodeFp32ToFp16(const CNodePtr & cnode)744 int ConvertCNodeFp32ToFp16(const CNodePtr &cnode) {
745 for (size_t i = kPrimOffset; i < cnode->size(); ++i) {
746 auto input = cnode->input(i);
747 if (input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) {
748 MS_LOG(ERROR) << cnode->fullname_with_scope() << " Parameter.";
749 ParameterPtr param_node;
750 tensor::TensorPtr tensor_info;
751 GetParameterAndTensor(input, ¶m_node, &tensor_info);
752 CHECK_NULL_RETURN(tensor_info);
753 CHECK_NULL_RETURN(param_node);
754 if (tensor_info->data_type() == kNumberTypeFloat32) {
755 MS_LOG(INFO) << "convert " << input->fullname_with_scope() << " from fp32 to fp16.";
756 auto data = static_cast<float *>(tensor_info->data_c());
757 std::vector<float16> fp16_data(tensor_info->DataSize());
758 for (size_t j = 0; j < tensor_info->DataSize(); j++) {
759 fp16_data[j] = mindspore::Float16(data[j]);
760 }
761 mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
762 kNumberTypeFloat16, tensor_info->shape_c(), fp16_data.data(), fp16_data.size() * sizeof(float) / 2);
763 param_node->set_default_param(tensor_ptr);
764 param_node->set_abstract(tensor_ptr->ToAbstract());
765 }
766 } else if (input->isa<ValueNode>()) {
767 auto value_node = input->cast<ValueNodePtr>();
768 DataInfo data_info;
769 auto ret = FetchDataFromValueNode(cnode, i, converter::kFmkTypeMs, false, &data_info, false);
770 if (ret != RET_OK) {
771 MS_LOG(ERROR) << "Fetch data from value node failed.";
772 return ret;
773 }
774 std::vector<int64_t> shapes;
775 for (size_t j = 0; j < data_info.shape_.size(); ++j) {
776 shapes.push_back(data_info.shape_.at(j));
777 }
778 int total_size = 0;
779 ret = GetElementNumFromShape(data_info.shape_, &total_size);
780 if (ret != RET_OK) {
781 MS_LOG(ERROR) << "GetElementNumFromShape failed.";
782 return ret;
783 }
784 if (data_info.data_type_ == kNumberTypeFloat32) {
785 MS_LOG(ERROR) << "convert " << input->fullname_with_scope() << " from fp32 to fp16.";
786 auto data = static_cast<float *>(data_info.data_ptr_);
787 std::vector<float16> fp16_data(total_size);
788 for (int j = 0; j < total_size; j++) {
789 fp16_data[j] = mindspore::Float16(data[j]);
790 }
791 mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
792 kNumberTypeFloat16, shapes, fp16_data.data(), fp16_data.size() * sizeof(float) / 2);
793 auto values = MakeValue(tensor_ptr);
794 value_node->set_value(values);
795 value_node->set_abstract(tensor_ptr->ToAbstract());
796 }
797 }
798 }
799 return RET_OK;
800 }
801
ConvertFp32ToFp16(const FuncGraphPtr & func_graph)802 int ConvertFp32ToFp16(const FuncGraphPtr &func_graph) {
803 auto cnodes = func_graph->GetOrderedCnodes();
804 for (auto &cnode : cnodes) {
805 auto ret = ConvertCNodeFp32ToFp16(cnode);
806 if (ret != RET_OK) {
807 MS_LOG(ERROR) << cnode->fullname_with_scope() << " convert cnode fp32 to fp16.";
808 return ret;
809 }
810 }
811 return RET_OK;
812 }
813
ConvertCNodeFp16ToFp32(const CNodePtr & cnode)814 int ConvertCNodeFp16ToFp32(const CNodePtr &cnode) {
815 for (size_t i = kPrimOffset; i < cnode->size(); ++i) {
816 auto input = cnode->input(i);
817 if (!input->isa<Parameter>() || !input->cast<ParameterPtr>()->has_default()) {
818 continue;
819 }
820 ParameterPtr param_node;
821 tensor::TensorPtr tensor_info;
822 GetParameterAndTensor(input, ¶m_node, &tensor_info);
823 CHECK_NULL_RETURN(tensor_info);
824 CHECK_NULL_RETURN(param_node);
825 if (tensor_info->data_type() == kNumberTypeFloat16) {
826 MS_LOG(INFO) << "convert " << input->fullname_with_scope() << " from fp16 to fp32.";
827 auto data = static_cast<float16 *>(tensor_info->data_c());
828 std::vector<float> fp32_data(tensor_info->DataSize());
829 for (size_t j = 0; j < tensor_info->DataSize(); j++) {
830 fp32_data[j] = mindspore::Float16::ToFloat32(data[j]);
831 }
832 mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
833 kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float));
834
835 tensor::TensorPtr input_tensor = quant::GetNodeTensor(input);
836 MS_CHECK_TRUE_MSG(input_tensor != nullptr, RET_NULL_PTR, "Get node tensor failed.");
837 auto quant_params = input_tensor->quant_params();
838 tensor_ptr->set_quant_param(quant_params);
839
840 param_node->set_default_param(tensor_ptr);
841 param_node->set_abstract(tensor_ptr->ToAbstract());
842 }
843 }
844 return RET_OK;
845 }
846
IsPerchannelWeight(const std::vector<schema::QuantParamT> & quant_params,const tensor::TensorPtr & weight,int preferred_dim)847 bool IsPerchannelWeight(const std::vector<schema::QuantParamT> &quant_params, const tensor::TensorPtr &weight,
848 int preferred_dim) {
849 auto dims = weight->shape();
850 return (static_cast<int>(quant_params.size()) == dims[preferred_dim]);
851 }
852
ConvertQuantParamTToQuantizationParam(const std::vector<schema::QuantParamT> & quant_params)853 QuantizationParamPtr ConvertQuantParamTToQuantizationParam(const std::vector<schema::QuantParamT> &quant_params) {
854 if (quant_params.empty()) {
855 return nullptr;
856 }
857 QuantizationParam quantization(quant::kLinearQuant);
858 std::vector<ValuePtr> scale_list;
859 std::vector<ValuePtr> zeroPoint_list;
860 std::vector<ValuePtr> min_list;
861 std::vector<ValuePtr> max_list;
862 std::vector<ValuePtr> varCorr_list;
863 std::vector<ValuePtr> meanCorr_list;
864 std::vector<ValuePtr> numBits_list;
865 std::vector<ValuePtr> narrowRange_list;
866 std::vector<ValuePtr> dstDtype_list;
867 std::vector<ValuePtr> roundType_list;
868 std::vector<ValuePtr> multiplier_list;
869 for (auto quant_param : quant_params) {
870 scale_list.push_back(MakeValue(quant_param.scale));
871 zeroPoint_list.push_back(MakeValue(quant_param.zeroPoint));
872 min_list.push_back(MakeValue(quant_param.min));
873 max_list.push_back(MakeValue(quant_param.max));
874 varCorr_list.push_back(MakeValue(quant_param.varCorr));
875 meanCorr_list.push_back(MakeValue(quant_param.meanCorr));
876 numBits_list.push_back(MakeValue(quant_param.numBits));
877 narrowRange_list.push_back(MakeValue(quant_param.narrowRange));
878 dstDtype_list.push_back(MakeValue(quant_param.dstDtype));
879 roundType_list.push_back(MakeValue(quant_param.roundType));
880 multiplier_list.push_back(MakeValue(quant_param.multiplier));
881 }
882 quantization.AddAttr(quant::kScaleList, std::make_shared<ValueList>(scale_list));
883 quantization.AddAttr(quant::kZeroPointList, std::make_shared<ValueList>(zeroPoint_list));
884 quantization.AddAttr(quant::kMinList, std::make_shared<ValueList>(min_list));
885 quantization.AddAttr(quant::kMaxList, std::make_shared<ValueList>(max_list));
886 quantization.AddAttr(quant::kVarCorrList, std::make_shared<ValueList>(varCorr_list));
887 quantization.AddAttr(quant::kMeanCorrList, std::make_shared<ValueList>(meanCorr_list));
888 quantization.AddAttr(quant::kNumBitList, std::make_shared<ValueList>(numBits_list));
889 quantization.AddAttr(quant::kNarrowRangeList, std::make_shared<ValueList>(narrowRange_list));
890 quantization.AddAttr(quant::kDstDtypeList, std::make_shared<ValueList>(dstDtype_list));
891 quantization.AddAttr(quant::kRoundTypeList, std::make_shared<ValueList>(roundType_list));
892 quantization.AddAttr(quant::kMultiplierList, std::make_shared<ValueList>(multiplier_list));
893 return std::make_shared<mindspore::QuantizationParam>(quantization);
894 }
895
ConvertQuantizationParamToQuantParamT(const QuantizationParamPtr & quantization_param)896 std::vector<schema::QuantParamT> ConvertQuantizationParamToQuantParamT(const QuantizationParamPtr &quantization_param) {
897 std::vector<schema::QuantParamT> quant_params;
898 if (quantization_param == nullptr) {
899 return quant_params;
900 }
901 auto scale_list_attr = quantization_param->GetAttr(quant::kScaleList);
902 auto zero_point_list_attr = quantization_param->GetAttr(quant::kZeroPointList);
903 auto min_list_attr = quantization_param->GetAttr(quant::kMinList);
904 auto max_list_attr = quantization_param->GetAttr(quant::kMaxList);
905 auto var_corr_list_attr = quantization_param->GetAttr(quant::kVarCorrList);
906 auto mean_corr_list_attr = quantization_param->GetAttr(quant::kMeanCorrList);
907 auto num_bits_list_attr = quantization_param->GetAttr(quant::kNumBitList);
908 auto narrow_range_list_attr = quantization_param->GetAttr(quant::kNarrowRangeList);
909 auto dst_dtype_list_attr = quantization_param->GetAttr(quant::kDstDtypeList);
910 auto round_type_list_attr = quantization_param->GetAttr(quant::kRoundTypeList);
911 auto multiplier_list_attr = quantization_param->GetAttr(quant::kMultiplierList);
912 if (scale_list_attr != nullptr && zero_point_list_attr != nullptr && min_list_attr != nullptr &&
913 max_list_attr != nullptr && var_corr_list_attr != nullptr && mean_corr_list_attr != nullptr &&
914 num_bits_list_attr != nullptr && narrow_range_list_attr != nullptr) {
915 auto scales = GetValue<std::vector<double>>(scale_list_attr);
916 auto zero_points = GetValue<std::vector<int32_t>>(zero_point_list_attr);
917 auto mins = GetValue<std::vector<double>>(min_list_attr);
918 auto maxs = GetValue<std::vector<double>>(max_list_attr);
919 auto var_corrs = GetValue<std::vector<float>>(var_corr_list_attr);
920 auto mean_corrs = GetValue<std::vector<float>>(mean_corr_list_attr);
921 auto num_bits_list = GetValue<std::vector<int32_t>>(num_bits_list_attr);
922 auto narrow_range_list = GetValue<std::vector<bool>>(narrow_range_list_attr);
923 auto dst_dtype_list = GetValue<std::vector<int32_t>>(dst_dtype_list_attr);
924 auto round_type_list = GetValue<std::vector<int32_t>>(round_type_list_attr);
925 auto multiplier_list = GetValue<std::vector<int32_t>>(multiplier_list_attr);
926 for (size_t index = 0; index < scales.size(); ++index) {
927 schema::QuantParamT quant_param;
928 quant_param.scale = scales.at(index);
929 quant_param.zeroPoint = zero_points.at(index);
930 quant_param.min = mins.at(index);
931 quant_param.max = maxs.at(index);
932 quant_param.varCorr = var_corrs.at(index);
933 quant_param.meanCorr = mean_corrs.at(index);
934 quant_param.numBits = num_bits_list.at(index);
935 quant_param.narrowRange = narrow_range_list.at(index);
936 quant_param.dstDtype = dst_dtype_list.at(index);
937 quant_param.roundType = round_type_list.at(index);
938 quant_param.multiplier = multiplier_list.at(index);
939 quant_param.inited = true;
940 quant_params.push_back(quant_param);
941 }
942 }
943 return quant_params;
944 }
945
RemoveInputNodeQuantParam(const CNodePtr & cnode,size_t index)946 int RemoveInputNodeQuantParam(const CNodePtr &cnode, size_t index) {
947 if (cnode->size() <= index) {
948 MS_LOG(ERROR) << "index out of range, cnode input size is: " << cnode->size() << ", but index: " << index;
949 return RET_ERROR;
950 }
951 auto input_node = cnode->input(index);
952 CHECK_NULL_RETURN(input_node);
953 auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
954 if (IsGraphInput(input_node)) {
955 if (cnode_primitive->HasAttr(quant::kGraphInputQuantParam)) {
956 cnode_primitive->EraseAttr(quant::kGraphInputQuantParam);
957 }
958 } else if (input_node->isa<mindspore::CNode>()) {
959 if (cnode_primitive->HasAttr(quant::kQuantParam)) {
960 cnode_primitive->EraseAttr(quant::kQuantParam);
961 }
962 } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
963 auto input_tensor = quant::GetNodeTensor(input_node);
964 CHECK_NULL_RETURN(input_tensor);
965 input_tensor->set_quant_param({});
966 } else {
967 MS_LOG(ERROR) << input_node->fullname_with_scope() << " index: " << index << " is not support "
968 << input_node->type_name() << " type.";
969 return RET_ERROR;
970 }
971 return RET_OK;
972 }
973
GetInputNodeQuantParam(const CNodePtr & cnode,size_t index,size_t multi_ouput_index)974 std::vector<schema::QuantParamT> GetInputNodeQuantParam(const CNodePtr &cnode, size_t index, size_t multi_ouput_index) {
975 if (cnode->size() <= index) {
976 MS_LOG(WARNING) << "index out of range, cnode input size is: " << cnode->size() << ", but index: " << index;
977 return {};
978 }
979 auto input_node = cnode->input(index);
980 MS_CHECK_TRUE_MSG(input_node != nullptr, {}, "Anf node nullptr.");
981 auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
982 MS_CHECK_TRUE_MSG(cnode_primitive != nullptr, {}, "Primitive is nullptr.");
983 if (IsGraphInput(input_node)) {
984 auto quantization_param_value = cnode_primitive->GetAttr(quant::kGraphInputQuantParam);
985 if (quantization_param_value == nullptr) {
986 MS_LOG(WARNING) << input_node->fullname_with_scope() << " quant param Not exist.";
987 return {};
988 }
989 auto quantization_param = quantization_param_value->cast<mindspore::QuantizationParamPtr>();
990 MS_CHECK_TRUE_MSG(quantization_param != nullptr, {}, "Graph input quant param Not exist.");
991 return quant::ConvertQuantizationParamToQuantParamT(quantization_param);
992 } else if (input_node->isa<mindspore::CNode>()) {
993 auto input_cnode = input_node->cast<mindspore::CNodePtr>();
994 auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
995 MS_CHECK_TRUE_MSG(input_cnode_primitive != nullptr, {}, "Primitive is nullptr.");
996 if (!input_cnode_primitive->HasAttr(quant::kQuantParam)) {
997 MS_LOG(WARNING) << input_node->fullname_with_scope() << " dont have quant param.";
998 return {};
999 }
1000 auto quantization_param_value = input_cnode_primitive->GetAttr(quant::kQuantParam);
1001 MS_CHECK_TRUE_MSG(quantization_param_value != nullptr, {}, "quantization_param_value is nullptr.");
1002 auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
1003 if (quantization_param_list.size() <= multi_ouput_index) {
1004 MS_LOG(WARNING) << "This node's input node: " << input_cnode->fullname_with_scope()
1005 << "'s output quant_params size: " << quantization_param_list.size()
1006 << ", but index: " << multi_ouput_index;
1007 return {};
1008 }
1009 // multi-output
1010 return quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.at(multi_ouput_index));
1011 } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
1012 tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
1013 MS_CHECK_TRUE_MSG(input_tensor != nullptr, {}, "Get node tensor failed.");
1014 auto quantization_params = input_tensor->quant_params();
1015 if (quantization_params.empty()) {
1016 MS_LOG(WARNING) << input_node->fullname_with_scope() << " quantization param is empty.";
1017 return {};
1018 }
1019 auto quantization_param = quantization_params.front();
1020 return quant::ConvertQuantizationParamToQuantParamT(quantization_param);
1021 } else {
1022 MS_LOG(ERROR) << cnode->fullname_with_scope() << " input node with index: " << index
1023 << " Not supported for quant param";
1024 }
1025 return {};
1026 }
1027
SetInputNodeQuantParam(const CNodePtr & cnode,size_t index,const std::vector<schema::QuantParamT> & quant_param)1028 STATUS SetInputNodeQuantParam(const CNodePtr &cnode, size_t index,
1029 const std::vector<schema::QuantParamT> &quant_param) {
1030 auto input_node = cnode->input(index);
1031 MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "Anf node nullptr.");
1032 if (IsGraphInput(input_node)) {
1033 auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(kPrimIndex));
1034 MS_CHECK_TRUE_MSG(cnode_primitive != nullptr, RET_NULL_PTR, "Primitive is nullptr.");
1035 auto quantization_param = quant::ConvertQuantParamTToQuantizationParam(quant_param);
1036 cnode_primitive->AddAttr(quant::kGraphInputQuantParam, quantization_param);
1037 } else if (input_node->isa<mindspore::CNode>()) {
1038 auto input_cnode = input_node->cast<mindspore::CNodePtr>();
1039 auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
1040 MS_CHECK_TRUE_MSG(input_cnode_primitive != nullptr, RET_NULL_PTR, "Primitive is nullptr.");
1041 auto quantization_param = ConvertQuantParamTToQuantizationParam(quant_param);
1042 std::vector<ValuePtr> quantization_list{quantization_param};
1043 input_cnode_primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
1044 } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
1045 tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
1046 MS_CHECK_TRUE_MSG(input_tensor != nullptr, RET_NULL_PTR, "Get node tensor failed.");
1047 auto quantization_param = quant::ConvertQuantParamTToQuantizationParam(quant_param);
1048 CHECK_NULL_RETURN(quantization_param);
1049 input_tensor->set_quant_param(std::vector<std::shared_ptr<mindspore::QuantizationParam>>{quantization_param});
1050 } else {
1051 MS_LOG(WARNING) << input_node->fullname_with_scope() << " Not supported type.";
1052 return RET_ERROR;
1053 }
1054 return RET_OK;
1055 }
1056
GetNodeTensor(const AnfNodePtr & node)1057 tensor::TensorPtr GetNodeTensor(const AnfNodePtr &node) {
1058 // Only Parameter or ValueNode Node has tensor
1059 if (node->isa<Parameter>()) {
1060 auto parameter = node->cast<ParameterPtr>();
1061 if (parameter->default_param() != nullptr) {
1062 return parameter->default_param()->cast<tensor::TensorPtr>();
1063 }
1064 } else if (node->isa<ValueNode>()) {
1065 return node->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
1066 }
1067 return nullptr;
1068 }
1069
CloneQuantParam(const std::vector<schema::QuantParamT> & src)1070 std::vector<schema::QuantParamT> CloneQuantParam(const std::vector<schema::QuantParamT> &src) {
1071 MS_CHECK_TRUE_MSG(!src.empty(), {}, "Src is empty.");
1072 std::vector<schema::QuantParamT> dst;
1073 for (auto &quant_param : src) {
1074 schema::QuantParamT quant_param_clone;
1075 quant_param_clone.scale = quant_param.scale;
1076 quant_param_clone.zeroPoint = quant_param.zeroPoint;
1077 quant_param_clone.numBits = quant_param.numBits;
1078 quant_param_clone.narrowRange = quant_param.narrowRange;
1079 quant_param_clone.meanCorr = quant_param.meanCorr;
1080 quant_param_clone.varCorr = quant_param.varCorr;
1081 quant_param_clone.dstDtype = quant_param.dstDtype;
1082 quant_param_clone.min = quant_param.min;
1083 quant_param_clone.max = quant_param.max;
1084 quant_param_clone.roundType = quant_param.roundType;
1085 quant_param_clone.multiplier = quant_param.multiplier;
1086 dst.push_back(quant_param_clone);
1087 }
1088 return dst;
1089 }
1090
CalBiasQuantParams(const std::vector<schema::QuantParamT> & active_params,const std::vector<schema::QuantParamT> & weight_params,std::vector<schema::QuantParamT> * bias_params)1091 int CalBiasQuantParams(const std::vector<schema::QuantParamT> &active_params,
1092 const std::vector<schema::QuantParamT> &weight_params,
1093 std::vector<schema::QuantParamT> *bias_params) {
1094 std::vector<double> input_scales;
1095 std::vector<double> filter_scales;
1096 std::vector<double> bias_scales;
1097 size_t sizeX = active_params.size();
1098 for (size_t i = 0; i < sizeX; i++) {
1099 input_scales.emplace_back(active_params[i].scale);
1100 }
1101 size_t sizeY = weight_params.size();
1102 if (sizeX != sizeY) {
1103 if (sizeX > 1 && sizeY > 1) {
1104 MS_LOG(ERROR) << "input and filter's scale count cannot match!";
1105 return RET_ERROR;
1106 }
1107 }
1108 for (size_t i = 0; i < sizeY; i++) {
1109 filter_scales.emplace_back(weight_params[i].scale);
1110 }
1111 size_t size = std::max(sizeX, sizeY);
1112 for (size_t i = 0; i < size; i++) {
1113 auto scaleX = sizeX > 1 ? input_scales[i] : input_scales[0];
1114 auto scaleY = sizeY > 1 ? filter_scales[i] : filter_scales[0];
1115 bias_scales.push_back(scaleX * scaleY);
1116 }
1117 MS_ASSERT(!bias_scales.empty());
1118
1119 // set bias quant param
1120 for (double bias_scale : bias_scales) {
1121 schema::QuantParamT quant_param;
1122 if (bias_scale == 0) {
1123 MS_LOG(WARNING) << "bias_scale is 0, and set bias_scale to 1.";
1124 quant_param.scale = 1;
1125 } else {
1126 quant_param.scale = bias_scale;
1127 }
1128 quant_param.numBits = k32Bit;
1129 quant_param.zeroPoint = 0;
1130 quant_param.inited = true;
1131 bias_params->push_back(quant_param);
1132 }
1133 return RET_OK;
1134 }
1135
IsAntiQuantModeNodes(const AnfNodePtr & node)1136 bool IsAntiQuantModeNodes(const AnfNodePtr &node) {
1137 CHECK_NULL_RETURN(node);
1138 if (!utils::isa<CNodePtr>(node) || !opt::CheckPrimitiveType(node, prim::kPrimMul)) {
1139 MS_LOG(INFO) << "The node is not Mul node";
1140 return false;
1141 }
1142 auto add_node = node->cast<CNodePtr>()->input(kIndexOne);
1143 if (!utils::isa<CNodePtr>(add_node) || !opt::CheckPrimitiveType(add_node, prim::kPrimAdd)) {
1144 MS_LOG(INFO) << "The node is not Add node. ";
1145 return false;
1146 }
1147 auto ascend_antiquant_node = add_node->cast<CNodePtr>()->input(kIndexOne);
1148 if (!utils::isa<CNodePtr>(ascend_antiquant_node) ||
1149 !(opt::CheckPrimitiveType(ascend_antiquant_node, prim::kPrimAntiQuant) ||
1150 GetCNodePrimitive(ascend_antiquant_node)->name() == "AscendAntiQuant")) {
1151 MS_LOG(INFO) << "The node is not AscendAntiquant node";
1152 return false;
1153 }
1154 return true;
1155 }
1156
GetScaleZpFromAntiQuantModeNodes(const AnfNodePtr & node,ParameterPtr * scale_param_node,ParameterPtr * zp_param_node)1157 STATUS GetScaleZpFromAntiQuantModeNodes(const AnfNodePtr &node, ParameterPtr *scale_param_node,
1158 ParameterPtr *zp_param_node) {
1159 CHECK_NULL_RETURN(node);
1160 CHECK_NULL_RETURN(scale_param_node);
1161 CHECK_NULL_RETURN(zp_param_node);
1162
1163 if (!utils::isa<CNodePtr>(node) || !opt::CheckPrimitiveType(node, prim::kPrimMul)) {
1164 return RET_ERROR;
1165 }
1166 auto add_node = node->cast<CNodePtr>()->input(kIndexOne);
1167 auto scale_param = node->cast<CNodePtr>()->input(kIndexTwo);
1168 if (opt::CheckPrimitiveType(scale_param, prim::kPrimLoad)) {
1169 scale_param = scale_param->cast<CNodePtr>()->input(kIndexOne);
1170 }
1171 *scale_param_node = scale_param->cast<ParameterPtr>();
1172 CHECK_NULL_RETURN(*scale_param_node);
1173 if (!utils::isa<CNodePtr>(add_node) || !opt::CheckPrimitiveType(add_node, prim::kPrimAdd)) {
1174 return RET_ERROR;
1175 }
1176 auto zp_param = add_node->cast<CNodePtr>()->input(kIndexTwo);
1177 if (opt::CheckPrimitiveType(zp_param, prim::kPrimLoad)) {
1178 zp_param = zp_param->cast<CNodePtr>()->input(kIndexOne);
1179 }
1180 *zp_param_node = zp_param->cast<ParameterPtr>();
1181 CHECK_NULL_RETURN(*zp_param_node);
1182 return RET_OK;
1183 }
1184
RemoveAntiQuantModeNodes(const FuncGraphPtr & func_graph,const AnfNodePtr & node,int index)1185 STATUS RemoveAntiQuantModeNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int index) {
1186 CHECK_NULL_RETURN(func_graph);
1187 CHECK_NULL_RETURN(node);
1188
1189 auto manager = func_graph->manager();
1190 if (manager == nullptr) {
1191 manager = Manage(func_graph, true);
1192 }
1193 CHECK_NULL_RETURN(manager);
1194
1195 if (!utils::isa<CNodePtr>(node)) {
1196 MS_LOG(ERROR) << "The node : " << node->fullname_with_scope() << ", it is not cnode";
1197 return lite::RET_ERROR;
1198 }
1199 auto cnode = node->cast<CNodePtr>();
1200 CHECK_NULL_RETURN(cnode);
1201
1202 auto mul_node = cnode->input(index);
1203
1204 if (!utils::isa<CNodePtr>(mul_node) || !opt::CheckPrimitiveType(mul_node, prim::kPrimMul)) {
1205 MS_LOG(WARNING) << "In AntiQuant mode, the node : " << cnode->fullname_with_scope() << " is not mul node";
1206 return RET_OK;
1207 }
1208 auto add_node = mul_node->cast<CNodePtr>()->input(kIndexOne);
1209 if (!opt::CheckPrimitiveType(add_node, prim::kPrimAdd)) {
1210 MS_LOG(WARNING) << "In AntiQuant mode, the node : " << add_node->fullname_with_scope() << " is not add node";
1211 return RET_OK;
1212 }
1213 auto ascend_antiquant_node = add_node->cast<CNodePtr>()->input(kIndexOne);
1214 if (!(opt::CheckPrimitiveType(ascend_antiquant_node, prim::kPrimAntiQuant) ||
1215 GetCNodePrimitive(ascend_antiquant_node)->name() == "AscendAntiQuant")) {
1216 MS_LOG(WARNING) << "In AntiQuant mode, the node : " << ascend_antiquant_node->fullname_with_scope()
1217 << " is not antiquant node";
1218 return RET_OK;
1219 }
1220
1221 manager->Replace(mul_node, ascend_antiquant_node->cast<CNodePtr>()->input(1));
1222 return lite::RET_OK;
1223 }
1224
ExtractStrategy(const ValuePtr & stra)1225 std::vector<std::vector<int64_t>> ExtractStrategy(const ValuePtr &stra) {
1226 if (stra == nullptr) {
1227 return {};
1228 }
1229
1230 auto var = stra->cast<ValueTuplePtr>();
1231 if (var == nullptr) {
1232 return {};
1233 }
1234 std::vector<std::vector<int64_t>> strategy;
1235 MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
1236 if (var->size() > 0) {
1237 std::vector<ValuePtr> elements = var->value();
1238 for (uint64_t index = 0; index < elements.size(); ++index) {
1239 std::vector<int64_t> dim;
1240 if (elements[index]->isa<ValueSequence>()) {
1241 auto value_tuple = elements[index]->cast<ValueTuplePtr>();
1242 std::vector<ValuePtr> value_vector = value_tuple->value();
1243 (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
1244 [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
1245 strategy.push_back(dim);
1246 } else {
1247 MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
1248 }
1249 }
1250 if (strategy.empty()) {
1251 MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
1252 }
1253 }
1254
1255 return strategy;
1256 }
1257
CalQuantParamWithMinMax(const tensor::TensorPtr & min_value,const tensor::TensorPtr & max_value,bool symmetric)1258 std::vector<schema::QuantParamT> CalQuantParamWithMinMax(const tensor::TensorPtr &min_value,
1259 const tensor::TensorPtr &max_value, bool symmetric) {
1260 std::vector<schema::QuantParamT> quant_params;
1261 // Ascend fake quant transform support PerLayer && PerChannel quant param
1262 if (min_value->ElementsNum() != max_value->ElementsNum()) {
1263 MS_LOG(ERROR) << "min value size not equal max value size";
1264 return {};
1265 }
1266 int size = min_value->ElementsNum();
1267 auto min_data = reinterpret_cast<float *>(min_value->data_c());
1268 auto max_data = reinterpret_cast<float *>(max_value->data_c());
1269 for (int i = 0; i < size; i++) {
1270 float real_min = *(min_data + i);
1271 float real_max = *(max_data + i);
1272 schema::QuantParamT quant_param;
1273 int bit_num = k8Bit;
1274
1275 MS_LOG(DEBUG) << "min: " << real_min << " max: " << real_max << " bit_num: " << bit_num << " symmetric"
1276 << symmetric;
1277 auto ret = CalQuantizationParams(&quant_param, real_min, real_max, bit_num, symmetric);
1278 if (ret != RET_OK) {
1279 MS_LOG(ERROR) << "Failed to calculate quant params";
1280 return {};
1281 }
1282 MS_LOG(INFO) << "quant param scale: " << quant_param.scale << " zp: " << quant_param.zeroPoint;
1283 quant_params.push_back(quant_param);
1284 }
1285 return quant_params;
1286 }
1287
GetQuantParamWithFakeQuantNode(const CNodePtr & fake_quant_node,bool symmetric)1288 std::vector<schema::QuantParamT> GetQuantParamWithFakeQuantNode(const CNodePtr &fake_quant_node, bool symmetric) {
1289 tensor::TensorPtr min_value;
1290 tensor::TensorPtr max_value;
1291 auto min_input = fake_quant_node->input(kFakeQuantMinIndex + kPrimOffset);
1292 if (utils::isa<ParameterPtr>(min_input) && min_input->cast<ParameterPtr>()->has_default() &&
1293 min_input->cast<ParameterPtr>()->default_param() != nullptr) {
1294 min_value = min_input->cast<ParameterPtr>()->default_param()->cast<tensor::TensorPtr>();
1295 } else {
1296 MS_LOG(ERROR) << "Quant param get min value failed";
1297 return {};
1298 }
1299 auto max_input = fake_quant_node->input(kFakeQuantMaxIndex + kPrimOffset);
1300 if (utils::isa<ParameterPtr>(max_input) && max_input->cast<ParameterPtr>()->has_default() &&
1301 max_input->cast<ParameterPtr>()->default_param() != nullptr) {
1302 max_value = max_input->cast<ParameterPtr>()->default_param()->cast<tensor::TensorPtr>();
1303 } else {
1304 MS_LOG(ERROR) << "Quant param get max value failed";
1305 return {};
1306 }
1307 auto quant_params = CalQuantParamWithMinMax(min_value, max_value, symmetric);
1308 return quant_params;
1309 }
1310
1311 } // namespace mindspore::lite::quant
1312