/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API

#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <cmath>
#include <string>
#include <memory>
#include <vector>
#include <set>
#include <functional>
#include <deque>
#include "include/common/utils/convert_utils.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "abstract/abstract_value.h"
#include "tools/common/graph_util.h"
#include "tools/lite_exporter/anf_exporter.h"
#include "tools/converter/graphdef_transform.h"
#include "tools/common/tensor_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/ops_func_impl/gather.h"
#include "ops/op_utils.h"
#include "src/common/utils.h"
#include "src/common/file_utils.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "ir/anf.h"
#include "tools/converter/export_model.h"
#include "tools/converter/parser/parser_utils.h"
#include "ops/other_ops.h"
#include "utils/anf_utils.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/op_base.h"

using std::string;
using std::vector;

namespace mindspore::lite::quant {
namespace {
constexpr size_t kGatherAxisIndex = 3;
constexpr int kDefaultThreadNum = 4;
constexpr size_t kEncMaxLen = 16;
constexpr size_t kModelSizeLimit = static_cast<size_t>(2) * 1024 * 1024 * 1024;
constexpr int kFakeQuantMinIndex = 1;
constexpr int kFakeQuantMaxIndex = 2;
}  // namespace

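// Reads the quant type recorded on a cnode's QuantParamHolder, falling back to
// QUANT_NONE when the node has no holder. GetQuantTypeNew below serves graphs
// that store the quant type as a primitive attribute instead.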
int GetQuantType(const CNodePtr &cnode, quant::QuantType *quant_type) {
  CHECK_NULL_RETURN(cnode);
  CHECK_NULL_RETURN(quant_type);
  auto quant_param_holder = GetCNodeQuantHolder(cnode);
  if (quant_param_holder == nullptr) {
    *quant_type = quant::QUANT_NONE;
    return RET_OK;
  }
  *quant_type = quant_param_holder->quant_type();
  return RET_OK;
}

int GetQuantTypeNew(const CNodePtr &cnode, quant::QuantType *quant_type) {
  CHECK_NULL_RETURN(cnode);
  CHECK_NULL_RETURN(quant_type);
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr.";
    return RET_NULL_PTR;
  }
  auto quant_type_attr = primitive->GetAttr(quant::kQuantType);
  if (quant_type_attr == nullptr) {
    *quant_type = quant::QUANT_NONE;
    return RET_OK;
  }
  *quant_type = static_cast<quant::QuantType>(GetValue<int32_t>(quant_type_attr));
  return RET_OK;
}

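// Collects `func_graph` plus every FuncGraph reachable through ValueNode inputs
// (e.g. subgraphs referenced by partial/call nodes) into `all_func_graphs`,
// traversing breadth-first with a work queue.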
void GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(all_func_graphs != nullptr);
  all_func_graphs->insert(func_graph);
  auto nodes = func_graph->GetOrderedCnodes();
  std::deque<CNodePtr> to_process{};
  to_process.insert(to_process.end(), nodes.begin(), nodes.end());
  while (!to_process.empty()) {
    auto &cur_cnode = to_process.front();
    for (auto &input : cur_cnode->inputs()) {
      if (!IsValueNode<FuncGraph>(input)) {
        continue;
      }
      auto new_fg = GetValueNode<FuncGraphPtr>(input);
      if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
        continue;
      }
      all_func_graphs->insert(new_fg);
      auto new_nodes = new_fg->GetOrderedCnodes();
      to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
    }
    to_process.pop_front();
  }
}

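// Rewrites the element type of a node's abstract to `new_data_type`; for an
// AbstractTuple every element is retyped. Only tensor abstracts are supported.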
int UpdateDataType(const AnfNodePtr &node, TypeId new_data_type) {
  auto abstract_base = node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of node is nullptr, " << node->fullname_with_scope();
    return RET_NULL_PTR;
  }

  std::vector<AbstractBasePtr> abstracts;
  if (utils::isa<abstract::AbstractTuple>(abstract_base)) {
    auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract_base);
    abstracts = abstract_tuple->elements();
  } else {
    abstracts.push_back(abstract_base);
  }
  for (auto &abstract : abstracts) {
    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
    CHECK_NULL_RETURN(abstract_tensor);
    CHECK_NULL_RETURN(abstract_tensor->element());
    abstract_tensor->element()->set_type(TypeIdToType(new_data_type));
  }
  return RET_OK;
}

bool IsGraphInDTypeCast(const CNodePtr &cnode) {
  if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
    return false;
  }
  auto input_node = cnode->input(1);
  MS_CHECK_FALSE(input_node == nullptr, false);
  return IsGraphInput(input_node);
}

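// True when `cnode` is a QuantDTypeCast whose every user is the graph's Return
// node, i.e. a cast sitting directly in front of a graph output.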
bool IsGraphOutDTypeCast(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
    return false;
  }
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr.");
  auto node_users = manager->node_users()[cnode];
  MS_CHECK_TRUE_MSG(!node_users.empty(), false, "node_users is empty.");
  for (auto &node_user : node_users) {
    auto output_cnode = node_user.first->cast<CNodePtr>();
    MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "output_cnode is nullptr.");
    if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
      return false;
    }
  }
  return true;
}

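// Classifies a QuantDTypeCast node as kQuant (fp32 -> int8/uint8) or kDeQuant
// (int8/uint8 -> fp32) from the data types of its input and of its first user;
// casts that feed Return are classified from the input type alone.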
int GetCastNodeType(const FuncGraphPtr &func_graph, const CNodePtr &cnode, CastNodeType *cast_node_type) {
  CHECK_NULL_RETURN(cast_node_type);
  if (!opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
    MS_LOG(DEBUG) << "Not QuantDtypeCastNode, cnode name: " << cnode->fullname_with_scope();
    return RET_NOT_SUPPORT;
  }
  auto input_node = cnode->input(1);
  MS_CHECK_FALSE(input_node == nullptr, RET_ERROR);

  // input node
  TypeId pre_node_dtype = kTypeUnknown;
  if (opt::GetDataTypeFromAnfNode(input_node, &pre_node_dtype) != RET_OK) {
    MS_LOG(ERROR) << "Get data type failed, cnode name: " << input_node->fullname_with_scope();
    return RET_ERROR;
  }

  // output node
  TypeId post_node_dtype = kTypeUnknown;
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  CHECK_NULL_RETURN(manager);
  auto node_users = manager->node_users()[cnode];
  MS_CHECK_TRUE_RET(!node_users.empty(), RET_NULL_PTR);
  auto output_cnode = node_users.begin()->first->cast<CNodePtr>();
  CHECK_NULL_RETURN(output_cnode);

  if (!opt::CheckPrimitiveType(output_cnode, prim::kPrimReturn)) {
    if (opt::GetDataTypeFromAnfNode(output_cnode, &post_node_dtype) != RET_OK) {
      MS_LOG(ERROR) << "Get data type failed, cnode name: " << output_cnode->fullname_with_scope();
      return RET_ERROR;
    }
    if (pre_node_dtype == kNumberTypeFloat32 &&
        (post_node_dtype == kNumberTypeInt8 || post_node_dtype == kNumberTypeUInt8)) {
      *cast_node_type = kQuant;
    } else if ((pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) &&
               post_node_dtype == kNumberTypeFloat32) {
      *cast_node_type = kDeQuant;
    } else {
      MS_LOG(ERROR) << "Unsupported QuantDTypeCast node, cnode name: " << cnode->fullname_with_scope();
    }
  } else {
    if (pre_node_dtype == kNumberTypeFloat32) {
      *cast_node_type = kQuant;
    } else if (pre_node_dtype == kNumberTypeInt8 || pre_node_dtype == kNumberTypeUInt8) {
      *cast_node_type = kDeQuant;
    } else {
      MS_LOG(ERROR) << "Unsupported QuantDTypeCast node, cnode name: " << cnode->fullname_with_scope();
    }
  }
  return RET_OK;
}

std::string NodePrimitiveType(const CNodePtr &cnode) {
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "cnode is null";
    return "";
  }
  auto primitive_c = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
  if (primitive_c == nullptr) {
    MS_LOG(ERROR) << "primitive_c is null";
    return "";
  }
  return primitive_c->name();
}

Status LargeModelBuildModel(const schema::MetaGraphT &meta_graph, const std::shared_ptr<ConverterPara> &param,
                            const std::shared_ptr<mindspore::Model> &model, const std::shared_ptr<Context> &context,
                            size_t *size) {
  if (size == nullptr) {
    return kLiteError;
  }
  if (param->commonQuantParam.workspace.empty()) {
    MS_LOG(ERROR) << "The model is larger than 2G, mixedBitWeightQuant config needs to set workspace to save tmp model";
    return kLiteError;
  }
  std::string tmp_save_file_path = param->commonQuantParam.workspace + "/tmp.ms";
  tmp_save_file_path = lite::RealPath(tmp_save_file_path.c_str());
  if (tmp_save_file_path.empty()) {
    MS_LOG(ERROR) << param->commonQuantParam.workspace << " is an invalid path. Please check it again.";
    return kLiteError;
  }
  unsigned char encKey[kEncMaxLen] = {0};
  size_t keyLen = 0;
  auto status = MetaGraphSerializer::Save(meta_graph, tmp_save_file_path, size, encKey, keyLen, param->encrypt_mode);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Save large model failed: " << status << " " << GetErrorInfo(status);
    return kLiteError;
  }

  mindspore::ModelType model_type = kMindIR_Lite;
  auto ret = model->Build(tmp_save_file_path, model_type, context);
  return ret;
}

int DumpGraph(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
              const std::string &save_path) {
  FuncGraphPtr func_graph_clone;
  if (CloneFuncGraph(func_graph, param, &func_graph_clone) != RET_OK) {
    MS_LOG(ERROR) << "Clone func_graph failed";
    return RET_ERROR;
  }
  auto meta_graph = Export(func_graph_clone, true, true);
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Export to meta_graph failed";
    return RET_ERROR;
  }

  // transform
  GraphDefTransform fb_transform;
  fb_transform.SetGraphDef(meta_graph);
  auto status = fb_transform.Transform(param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "FBTransform model failed";
    delete meta_graph;
    return RET_ERROR;
  }
  meta_graph->version = Version();

  status = UpdateGraphOutputName(meta_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    delete meta_graph;
    return RET_ERROR;
  }

  unsigned char encKey[kEncMaxLen] = {0};
  size_t keyLen = 0;
  size_t size;
  status = MetaGraphSerializer::Save(*meta_graph, save_path, &size, encKey, keyLen, param->encrypt_mode);
  delete meta_graph;
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Save model failed: " << status << " " << GetErrorInfo(status);
    return RET_ERROR;
  }
  return RET_OK;
}

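// Clones and exports `func_graph` to a MetaGraphT, runs the flatbuffer
// transform passes, then builds a CPU-backed `model` from the serialized
// buffer. Graphs whose packed tensors reach kModelSizeLimit (2 GiB) are routed
// through LargeModelBuildModel, which saves to a workspace file and builds
// from disk. On success *size holds the serialized model size in bytes.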
Status BuildModelByFuncGraph(const std::shared_ptr<mindspore::Model> &model, const FuncGraphPtr &func_graph,
                             const std::shared_ptr<ConverterPara> &param, size_t *size) {
  if (size == nullptr) {
    return kLiteError;
  }
  FuncGraphPtr func_graph_clone;
  if (CloneFuncGraph(func_graph, param, &func_graph_clone) != RET_OK) {
    MS_LOG(ERROR) << "Clone func_graph failed";
    return kLiteNullptr;
  }
  auto meta_graph = Export(func_graph_clone, true, true);
  if (meta_graph == nullptr) {
    MS_LOG(ERROR) << "Export to meta_graph failed";
    return kLiteNullptr;
  }

  // transform
  GraphDefTransform fb_transform;
  fb_transform.SetGraphDef(meta_graph);
  auto status = fb_transform.Transform(param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "FBTransform model failed";
    delete meta_graph;
    return kLiteError;
  }
  meta_graph->version = Version();

  status = UpdateGraphOutputName(meta_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "UpdateGraphOutputName failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
    delete meta_graph;
    return kLiteError;
  }

  auto context = std::make_shared<mindspore::Context>();
  if (context == nullptr) {
    MS_LOG(ERROR) << "New context failed while running.";
    delete meta_graph;
    return kLiteNullptr;
  }
  context->SetThreadNum(kDefaultThreadNum);
  context->SetThreadAffinity(kCpuBindMode);

  std::shared_ptr<CPUDeviceInfo> device_info = std::make_shared<CPUDeviceInfo>();
  if (device_info == nullptr) {
    MS_LOG(ERROR) << "New device_info failed while running.";
    delete meta_graph;
    return kLiteNullptr;
  }
  auto &device_list = context->MutableDeviceInfo();
  device_list.push_back(device_info);

  size_t tensors_size = 0;
  for (auto &tensor : meta_graph->allTensors) {
    tensors_size += tensor->data.size();
  }

  if (tensors_size >= kModelSizeLimit) {
    auto ret = LargeModelBuildModel(*meta_graph, param, model, context, size);
    delete meta_graph;
    return ret;
  }

  flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
  auto offset = schema::MetaGraph::Pack(builder, meta_graph);
  builder.Finish(offset);
  schema::FinishMetaGraphBuffer(builder, offset);
  *size = builder.GetSize();
  auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
  if (content == nullptr) {
    MS_LOG(ERROR) << "GetBufferPointer return null";
    delete meta_graph;
    return kLiteNullptr;
  }

  auto ret = model->Build(content, *size, kMindIR, context);
  delete meta_graph;
  return ret;
}

mindspore::lite::Tensor *MSTensorToLiteTensor(const MSTensor &tensor) {
  if (tensor.impl() == nullptr) {
    MS_LOG(ERROR) << "Tensor " << tensor.Name() << " is nullptr.";
    return static_cast<lite::Tensor *>(nullptr);
  }
  auto lite_impl = std::static_pointer_cast<LiteTensorImpl>(tensor.impl());
  return static_cast<mindspore::lite::Tensor *>(lite_impl->lite_tensor());
}

std::vector<mindspore::lite::Tensor *> MSTensorToLiteTensors(const std::vector<mindspore::MSTensor> &src_tensors) {
  std::vector<mindspore::lite::Tensor *> dst_tensors;
  // reserve, not resize: constructing with a size would leave leading nullptrs
  // before the emplace_back calls below.
  dst_tensors.reserve(src_tensors.size());
  for (const auto &src_tensor : src_tensors) {
    auto tensor = MSTensorToLiteTensor(src_tensor);
    if (tensor == nullptr) {
      return {};
    }
    dst_tensors.emplace_back(tensor);
  }
  return dst_tensors;
}

void GetParameterAndTensor(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
  CHECK_NULL_RETURN_VOID(param_node);
  CHECK_NULL_RETURN_VOID(tensor_info);
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is nullptr";
    return;
  }
  auto op_name = node->fullname_with_scope();

  *param_node = node->cast<ParameterPtr>();
  if (*param_node == nullptr) {
    MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
    return;
  }
  if (!(*param_node)->has_default()) {
    MS_LOG(INFO) << op_name << " does not have a default value";
    return;
  }

  *tensor_info = std::static_pointer_cast<tensor::Tensor>((*param_node)->default_param());
  if (*tensor_info == nullptr) {
    MS_LOG(INFO) << "default_param can not cast to tensor::Tensor";
    return;
  }
}

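// Overwrites `weight`'s payload with `quant_datas`, which must match the
// tensor's byte size exactly, then retypes both the tensor and the node's
// abstract to `new_data_type`.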
int UpdateTensorDataAndSize(const AnfNodePtr &node, const tensor::TensorPtr &weight, const void *quant_datas,
                            size_t new_size, TypeId new_data_type) {
  CHECK_NULL_RETURN(quant_datas);
  MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
  MS_CHECK_TRUE_RET(new_size > 0, RET_NULL_PTR);
  weight->set_data_type(new_data_type);
  if (new_size != static_cast<size_t>(weight->data().nbytes())) {
    MS_LOG(ERROR) << "Data size of tensor info is error.";
    return RET_ERROR;
  }
  if (memcpy_s(weight->data_c(), weight->data().nbytes(), quant_datas, new_size) != EOK) {
    MS_LOG(ERROR) << "memcpy data failed.";
    return RET_ERROR;
  }
  // set dtype
  auto ret = UpdateDataType(node, new_data_type);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << node->fullname_with_scope() << " set new dtype failed.";
    return ret;
  }
  return RET_OK;
}

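// Channel axis for per-channel quantization of a MatMul input. Input A (index
// 0) quantizes along its row axis and input B (index 1) along its column axis,
// so the transpose attributes decide which of the two trailing dims is used.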
int GetMatMulPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims) {
  size_t last_first_index = dims.size() - 1;
  size_t last_second_index = dims.size() - 2;
  auto matmul_prim = api::MakeShared<ops::MatMul>(primitive);
  MS_ASSERT(matmul_prim != nullptr);
  // For MatMul A
  if (input_index == 0) {
    if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
      return last_first_index;
    } else {
      return last_second_index;
    }
  }
  // For MatMul B
  if (input_index == 1) {
    if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
      return last_second_index;
    } else {
      return last_first_index;
    }
  }
  return 0;
}

int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int> &dims) {
  auto prim = api::MakeShared<ops::Conv2DTranspose>(primitive);
  MS_ASSERT(prim != nullptr);
  if (prim->get_in_channel() == prim->get_group() && prim->get_out_channel() == prim->get_group()) {
    // DepthWise-DeConv (CO\CI) KH KW 1
    return 0;
  }
  // DeConv:CI KH KW CO
  return dims.size() - 1;
}

int GetGatherPreferredDim(const CNodePtr &cnode) {
  if (cnode->size() < kGatherAxisIndex + kPrimOffset) {
    MS_LOG(WARNING) << "gather cnode size < 4.";
    return 0;
  }
  DataInfo data_info;
  auto output_type_node = cnode->input(kGatherAxisIndex);
  if (utils::isa<ParameterPtr>(output_type_node)) {
    if (FetchDataFromParameterNode(cnode, kGatherAxisIndex, converter::kFmkTypeMs, &data_info, true) != lite::RET_OK) {
      MS_LOG(WARNING) << "Fetch data from parameter node failed.";
      return 0;
    }
  } else if (utils::isa<ValueNodePtr>(output_type_node)) {
    if (FetchDataFromValueNode(cnode, kGatherAxisIndex, converter::kFmkTypeMs, false, &data_info, true) !=
        lite::RET_OK) {
      MS_LOG(WARNING) << "Fetch data from value node failed.";
      return 0;
    }
  } else {
    MS_LOG(WARNING) << "The axis input is not a constant.";
    return 0;
  }

  auto axis_data = reinterpret_cast<const int *>(data_info.data_.data());
  CHECK_NULL_RETURN(axis_data);
  return axis_data[0];
}

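// Resolves the channel axis along which input `input_index` of `cnode` should
// be quantized. An axis recorded in the weight tensor's own quantization param
// wins; otherwise the axis is derived from the operator type (MatMul /
// BatchMatMul, Conv2dTransposeFusion, Gather, FFN), defaulting to axis 0.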
int GetPreferredDim(const CNodePtr &cnode, int input_index, const std::vector<int> &dims) {
  auto input_node = cnode->input(input_index + kPrimOffset);
  if (input_node->isa<mindspore::Parameter>()) {
    tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
    if (input_tensor != nullptr) {
      auto quantization_params = input_tensor->quant_params();
      if (!quantization_params.empty()) {
        auto quantization_param = quantization_params.front();
        auto axis_attr = quantization_param->GetAttr(kChannelAxis);
        if (axis_attr != nullptr) {
          if (axis_attr->isa<Int64Imm>()) {
            auto axis = axis_attr->cast<Int64ImmPtr>()->value();
            MS_LOG(INFO) << "Quantization param axis is " << axis;
            return axis;
          }
          MS_LOG(WARNING) << "Quantization param axis_attr is not int64";
        }
      }
    }
  }
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul ||
      primitive->name() == ops::kNameBatchMatMul) {
    return GetMatMulPreferredDim(primitive, input_index, dims);
  } else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
    return GetDeConvPreferredDim(primitive, dims);
  } else if (primitive->name() == ops::kNameGather) {
    return GetGatherPreferredDim(cnode);
  } else if (primitive->name() == "FFN") {
    // For FFN MatMul, transpose is false
    return dims.size() - 1;
  }
  // The first index.
  return 0;
}

int GetFollowedNodePreferredDim(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &dims) {
  auto manager = mindspore::Manage(func_graph, true);
  auto node_users = manager->node_users()[cnode];
  if (node_users.empty()) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " cnode is isolated.";
    return 0;
  }
  if (node_users.size() > 1) {
    MS_LOG(WARNING) << "The cnode does not have exactly one user node";
    return 0;
  }
  auto node_user = node_users.begin();
  if (!utils::isa<CNodePtr>(node_user->first)) {
    MS_LOG(WARNING) << "The followed op: " << node_user->first->fullname_with_scope() << " is not a cnode";
    return 0;
  }
  auto node_user_cnode = utils::cast<CNodePtr>(node_user->first);
  return GetPreferredDim(node_user_cnode, node_user->second - 1, dims);
}

std::vector<int> ConvertShapeVectorToInt32(const ShapeVector &dims) {
  std::vector<int> shape;
  for (auto dim : dims) {
    if (dim > INT32_MAX || dim < INT32_MIN) {
      MS_LOG(ERROR) << dim << " is out of the int32 range.";
      shape.push_back(-1);
    } else {
      shape.push_back(static_cast<int>(dim));
    }
  }
  return shape;
}

bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types) {
  for (const auto &type : support_primitive_types) {
    if (opt::CheckPrimitiveType(cnode, type)) {
      return true;
    }
  }
  return false;
}

bool CheckFollowedNodeInSet(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                            const std::set<PrimitivePtr> &support_primitive_types) {
  auto manager = mindspore::Manage(func_graph, true);
  auto node_users = manager->node_users()[cnode];
  if (node_users.empty()) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " cnode is isolated.";
    return false;
  }
  for (auto &node_user : node_users) {
    if (!utils::isa<CNodePtr>(node_user.first)) {
      MS_LOG(INFO) << "The followed op: " << node_user.first->fullname_with_scope() << " is not a cnode";
      return false;
    }
    auto node_user_cnode = utils::cast<CNodePtr>(node_user.first);
    if (!CheckNodeInSet(node_user_cnode, support_primitive_types)) {
      return false;
    }
  }
  return true;
}

int DeQuantData(const mindspore::MSTensor *tensor, std::vector<double> *dequant_data) {
  return DeQuantData(reinterpret_cast<const int8_t *>(tensor->Data().get()), tensor->ElementNum(),
                     tensor->QuantParams(), dequant_data);
}

int GetElementNumFromShape(const std::vector<int> &dims, int *total_size) {
  CHECK_NULL_RETURN(total_size);
  *total_size = 1;
  for (auto dim : dims) {
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(*total_size, dim), RET_ERROR, "Int mul overflow.");
    *total_size *= dim;
  }
  return RET_OK;
}

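// Groups the flat element indices of a tensor with shape `dims` into one
// bucket per slice along `preferred_dim`. For dims = {2, 3} and
// preferred_dim = 1 (outer = 2, inner = 1, bucket_count = 3) the buckets are
// {0, 3}, {1, 4} and {2, 5}, i.e. each bucket holds the indices of one channel.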
int GetBucketAllIndex(const std::vector<int> &dims, int preferred_dim,
                      std::vector<std::vector<size_t>> *buckets_data_index) {
  CHECK_NULL_RETURN(buckets_data_index);
  int outer = 1;
  for (int i = 0; i < preferred_dim; i++) {
    outer *= dims[i];
  }
  int bucket_count = dims[preferred_dim];
  int inner = 1;
  for (size_t i = preferred_dim + 1; i < dims.size(); i++) {
    inner *= dims[i];
  }
  if (inner <= 0 || outer <= 0 || bucket_count <= 0) {
    return RET_ERROR;
  }
  for (int i = 0; i < bucket_count; i++) {
    auto index = i * inner;
    std::vector<size_t> bucket_index(inner * outer);
    for (int j = 0; j < outer; j++) {
      for (int k = 0; k < inner; k++) {
        bucket_index[j * inner + k] = index + k;
      }
      index += bucket_count * inner;
    }
    buckets_data_index->push_back(bucket_index);
  }
  return RET_OK;
}

bool CheckControlFlowType(const AnfNodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  std::map<std::string, PrimitivePtr> control_flow_ops = {{"PartialFusion", prim::kPrimPartialFusion},
                                                          {"Switch", prim::kPrimSwitch},
                                                          {"switch_layer", prim::kPrimSwitchLayer},
                                                          {"call", prim::kPrimCall}};

  if (node->isa<mindspore::CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    // control flow call
    if (!IsValueNode<mindspore::Primitive>(cnode->input(kPrimIndex))) {
      return true;
    }
    auto prim = GetValuePtr<mindspore::Primitive>(cnode->input(kPrimIndex));
    if (control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
      return true;
    }
  } else if (node->isa<ValueNode>()) {
    // a ValueNode may hold a non-Primitive value, so guard against nullptr
    auto prim = GetValuePtr<mindspore::Primitive>(node);
    if (prim != nullptr && control_flow_ops.find(prim->name()) != control_flow_ops.end()) {
      return true;
    }
  }
  return false;
}

int CloneFuncGraph(const FuncGraphPtr &func_graph, const std::shared_ptr<ConverterPara> &param,
                   FuncGraphPtr *func_graph_bak) {
  CHECK_NULL_RETURN(func_graph_bak);
  CHECK_NULL_RETURN(param);
  std::map<FuncGraphPtr, FuncGraphPtr> cloned_func_graph;
  *func_graph_bak = lite::CloneFuncGraph(func_graph, param, &cloned_func_graph);
  CHECK_NULL_RETURN(*func_graph_bak);
  // the manager must be rebuilt per clone; a static manager would keep
  // managing the graph from the first invocation only
  auto root_func_manager = Manage(*func_graph_bak);
  std::set<FuncGraphPtr> all_func_graphs = {};
  lite::GetAllFuncGraph(*func_graph_bak, &all_func_graphs);
  for (const auto &graph : all_func_graphs) {
    graph->set_manager(root_func_manager);
  }
  return RET_OK;
}

int MarkOriginDataType(const FuncGraphPtr &func_graph) {
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    TypeId type_id = kTypeUnknown;
    if (opt::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
      continue;
    }
    auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
    if (ret != RET_OK) {
      MS_LOG(INFO) << "CNode data type is unknown.";
      return RET_OK;
    }
    if (type_id != kTypeUnknown) {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " origin type is " << type_id;
      cnode->AddAttr("origin_type", MakeValue(static_cast<int>(type_id)));
    }
  }
  return RET_OK;
}

int ConvertFp16ToFp32(const FuncGraphPtr &func_graph) {
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto ret = ConvertCNodeFp16ToFp32(cnode);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " convert fp16 to fp32 failed.";
      return ret;
    }
  }
  return RET_OK;
}

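// Converts every fp32 constant feeding `cnode` (a Parameter with a default
// value or a ValueNode) to fp16 in place, rebuilding the tensor and the node's
// abstract. ConvertCNodeFp16ToFp32 below performs the reverse direction.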
int ConvertCNodeFp32ToFp16(const CNodePtr &cnode) {
  for (size_t i = kPrimOffset; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    if (input->isa<Parameter>() && input->cast<ParameterPtr>()->has_default()) {
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " input is a Parameter.";
      ParameterPtr param_node;
      tensor::TensorPtr tensor_info;
      GetParameterAndTensor(input, &param_node, &tensor_info);
      CHECK_NULL_RETURN(tensor_info);
      CHECK_NULL_RETURN(param_node);
      if (tensor_info->data_type() == kNumberTypeFloat32) {
        MS_LOG(INFO) << "convert " << input->fullname_with_scope() << " from fp32 to fp16.";
        auto data = static_cast<float *>(tensor_info->data_c());
        std::vector<float16> fp16_data(tensor_info->DataSize());
        for (size_t j = 0; j < tensor_info->DataSize(); j++) {
          fp16_data[j] = mindspore::Float16(data[j]);
        }
        mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
          kNumberTypeFloat16, tensor_info->shape_c(), fp16_data.data(), fp16_data.size() * sizeof(float16));
        param_node->set_default_param(tensor_ptr);
        param_node->set_abstract(tensor_ptr->ToAbstract());
      }
    } else if (input->isa<ValueNode>()) {
      auto value_node = input->cast<ValueNodePtr>();
      DataInfo data_info;
      auto ret = FetchDataFromValueNode(cnode, i, converter::kFmkTypeMs, false, &data_info, false);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Fetch data from value node failed.";
        return ret;
      }
      std::vector<int64_t> shapes;
      for (size_t j = 0; j < data_info.shape_.size(); ++j) {
        shapes.push_back(data_info.shape_.at(j));
      }
      int total_size = 0;
      ret = GetElementNumFromShape(data_info.shape_, &total_size);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "GetElementNumFromShape failed.";
        return ret;
      }
      if (data_info.data_type_ == kNumberTypeFloat32) {
        MS_LOG(INFO) << "convert " << input->fullname_with_scope() << " from fp32 to fp16.";
        auto data = static_cast<float *>(data_info.data_ptr_);
        std::vector<float16> fp16_data(total_size);
        for (int j = 0; j < total_size; j++) {
          fp16_data[j] = mindspore::Float16(data[j]);
        }
        mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
          kNumberTypeFloat16, shapes, fp16_data.data(), fp16_data.size() * sizeof(float16));
        auto values = MakeValue(tensor_ptr);
        value_node->set_value(values);
        value_node->set_abstract(tensor_ptr->ToAbstract());
      }
    }
  }
  return RET_OK;
}

int ConvertFp32ToFp16(const FuncGraphPtr &func_graph) {
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto ret = ConvertCNodeFp32ToFp16(cnode);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " convert cnode fp32 to fp16 failed.";
      return ret;
    }
  }
  return RET_OK;
}

int ConvertCNodeFp16ToFp32(const CNodePtr &cnode) {
  for (size_t i = kPrimOffset; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    if (!input->isa<Parameter>() || !input->cast<ParameterPtr>()->has_default()) {
      continue;
    }
    ParameterPtr param_node;
    tensor::TensorPtr tensor_info;
    GetParameterAndTensor(input, &param_node, &tensor_info);
    CHECK_NULL_RETURN(tensor_info);
    CHECK_NULL_RETURN(param_node);
    if (tensor_info->data_type() == kNumberTypeFloat16) {
      MS_LOG(INFO) << "convert " << input->fullname_with_scope() << " from fp16 to fp32.";
      auto data = static_cast<float16 *>(tensor_info->data_c());
      std::vector<float> fp32_data(tensor_info->DataSize());
      for (size_t j = 0; j < tensor_info->DataSize(); j++) {
        fp32_data[j] = mindspore::Float16::ToFloat32(data[j]);
      }
      mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
        kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float));

      tensor::TensorPtr input_tensor = quant::GetNodeTensor(input);
      MS_CHECK_TRUE_MSG(input_tensor != nullptr, RET_NULL_PTR, "Get node tensor failed.");
      auto quant_params = input_tensor->quant_params();
      tensor_ptr->set_quant_param(quant_params);

      param_node->set_default_param(tensor_ptr);
      param_node->set_abstract(tensor_ptr->ToAbstract());
    }
  }
  return RET_OK;
}

bool IsPerchannelWeight(const std::vector<schema::QuantParamT> &quant_params, const tensor::TensorPtr &weight,
                        int preferred_dim) {
  auto dims = weight->shape();
  return (static_cast<int>(quant_params.size()) == dims[preferred_dim]);
}

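// Packs a list of flatbuffer-style QuantParamT into a single linear
// QuantizationParam whose attributes hold one ValueList per field (scales,
// zero points, min/max, correction terms, bit widths, and so on). Returns
// nullptr for empty input; ConvertQuantizationParamToQuantParamT below is the
// inverse.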
QuantizationParamPtr ConvertQuantParamTToQuantizationParam(const std::vector<schema::QuantParamT> &quant_params) {
  if (quant_params.empty()) {
    return nullptr;
  }
  QuantizationParam quantization(quant::kLinearQuant);
  std::vector<ValuePtr> scale_list;
  std::vector<ValuePtr> zeroPoint_list;
  std::vector<ValuePtr> min_list;
  std::vector<ValuePtr> max_list;
  std::vector<ValuePtr> varCorr_list;
  std::vector<ValuePtr> meanCorr_list;
  std::vector<ValuePtr> numBits_list;
  std::vector<ValuePtr> narrowRange_list;
  std::vector<ValuePtr> dstDtype_list;
  std::vector<ValuePtr> roundType_list;
  std::vector<ValuePtr> multiplier_list;
  for (auto quant_param : quant_params) {
    scale_list.push_back(MakeValue(quant_param.scale));
    zeroPoint_list.push_back(MakeValue(quant_param.zeroPoint));
    min_list.push_back(MakeValue(quant_param.min));
    max_list.push_back(MakeValue(quant_param.max));
    varCorr_list.push_back(MakeValue(quant_param.varCorr));
    meanCorr_list.push_back(MakeValue(quant_param.meanCorr));
    numBits_list.push_back(MakeValue(quant_param.numBits));
    narrowRange_list.push_back(MakeValue(quant_param.narrowRange));
    dstDtype_list.push_back(MakeValue(quant_param.dstDtype));
    roundType_list.push_back(MakeValue(quant_param.roundType));
    multiplier_list.push_back(MakeValue(quant_param.multiplier));
  }
  quantization.AddAttr(quant::kScaleList, std::make_shared<ValueList>(scale_list));
  quantization.AddAttr(quant::kZeroPointList, std::make_shared<ValueList>(zeroPoint_list));
  quantization.AddAttr(quant::kMinList, std::make_shared<ValueList>(min_list));
  quantization.AddAttr(quant::kMaxList, std::make_shared<ValueList>(max_list));
  quantization.AddAttr(quant::kVarCorrList, std::make_shared<ValueList>(varCorr_list));
  quantization.AddAttr(quant::kMeanCorrList, std::make_shared<ValueList>(meanCorr_list));
  quantization.AddAttr(quant::kNumBitList, std::make_shared<ValueList>(numBits_list));
  quantization.AddAttr(quant::kNarrowRangeList, std::make_shared<ValueList>(narrowRange_list));
  quantization.AddAttr(quant::kDstDtypeList, std::make_shared<ValueList>(dstDtype_list));
  quantization.AddAttr(quant::kRoundTypeList, std::make_shared<ValueList>(roundType_list));
  quantization.AddAttr(quant::kMultiplierList, std::make_shared<ValueList>(multiplier_list));
  return std::make_shared<mindspore::QuantizationParam>(quantization);
}

std::vector<schema::QuantParamT> ConvertQuantizationParamToQuantParamT(
  const QuantizationParamPtr &quantization_param) {
  std::vector<schema::QuantParamT> quant_params;
  if (quantization_param == nullptr) {
    return quant_params;
  }
  auto scale_list_attr = quantization_param->GetAttr(quant::kScaleList);
  auto zero_point_list_attr = quantization_param->GetAttr(quant::kZeroPointList);
  auto min_list_attr = quantization_param->GetAttr(quant::kMinList);
  auto max_list_attr = quantization_param->GetAttr(quant::kMaxList);
  auto var_corr_list_attr = quantization_param->GetAttr(quant::kVarCorrList);
  auto mean_corr_list_attr = quantization_param->GetAttr(quant::kMeanCorrList);
  auto num_bits_list_attr = quantization_param->GetAttr(quant::kNumBitList);
  auto narrow_range_list_attr = quantization_param->GetAttr(quant::kNarrowRangeList);
  auto dst_dtype_list_attr = quantization_param->GetAttr(quant::kDstDtypeList);
  auto round_type_list_attr = quantization_param->GetAttr(quant::kRoundTypeList);
  auto multiplier_list_attr = quantization_param->GetAttr(quant::kMultiplierList);
  if (scale_list_attr != nullptr && zero_point_list_attr != nullptr && min_list_attr != nullptr &&
      max_list_attr != nullptr && var_corr_list_attr != nullptr && mean_corr_list_attr != nullptr &&
      num_bits_list_attr != nullptr && narrow_range_list_attr != nullptr && dst_dtype_list_attr != nullptr &&
      round_type_list_attr != nullptr && multiplier_list_attr != nullptr) {
    auto scales = GetValue<std::vector<double>>(scale_list_attr);
    auto zero_points = GetValue<std::vector<int32_t>>(zero_point_list_attr);
    auto mins = GetValue<std::vector<double>>(min_list_attr);
    auto maxs = GetValue<std::vector<double>>(max_list_attr);
    auto var_corrs = GetValue<std::vector<float>>(var_corr_list_attr);
    auto mean_corrs = GetValue<std::vector<float>>(mean_corr_list_attr);
    auto num_bits_list = GetValue<std::vector<int32_t>>(num_bits_list_attr);
    auto narrow_range_list = GetValue<std::vector<bool>>(narrow_range_list_attr);
    auto dst_dtype_list = GetValue<std::vector<int32_t>>(dst_dtype_list_attr);
    auto round_type_list = GetValue<std::vector<int32_t>>(round_type_list_attr);
    auto multiplier_list = GetValue<std::vector<int32_t>>(multiplier_list_attr);
    for (size_t index = 0; index < scales.size(); ++index) {
      schema::QuantParamT quant_param;
      quant_param.scale = scales.at(index);
      quant_param.zeroPoint = zero_points.at(index);
      quant_param.min = mins.at(index);
      quant_param.max = maxs.at(index);
      quant_param.varCorr = var_corrs.at(index);
      quant_param.meanCorr = mean_corrs.at(index);
      quant_param.numBits = num_bits_list.at(index);
      quant_param.narrowRange = narrow_range_list.at(index);
      quant_param.dstDtype = dst_dtype_list.at(index);
      quant_param.roundType = round_type_list.at(index);
      quant_param.multiplier = multiplier_list.at(index);
      quant_param.inited = true;
      quant_params.push_back(quant_param);
    }
  }
  return quant_params;
}

int RemoveInputNodeQuantParam(const CNodePtr &cnode, size_t index) {
  if (cnode->size() <= index) {
    MS_LOG(ERROR) << "index out of range, cnode input size is: " << cnode->size() << ", but index: " << index;
    return RET_ERROR;
  }
  auto input_node = cnode->input(index);
  CHECK_NULL_RETURN(input_node);
  auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(cnode_primitive);
  if (IsGraphInput(input_node)) {
    if (cnode_primitive->HasAttr(quant::kGraphInputQuantParam)) {
      cnode_primitive->EraseAttr(quant::kGraphInputQuantParam);
    }
  } else if (input_node->isa<mindspore::CNode>()) {
    if (cnode_primitive->HasAttr(quant::kQuantParam)) {
      cnode_primitive->EraseAttr(quant::kQuantParam);
    }
  } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
    auto input_tensor = quant::GetNodeTensor(input_node);
    CHECK_NULL_RETURN(input_tensor);
    input_tensor->set_quant_param({});
  } else {
    MS_LOG(ERROR) << input_node->fullname_with_scope() << " index: " << index << " is not a supported "
                  << input_node->type_name() << " type.";
    return RET_ERROR;
  }
  return RET_OK;
}

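// Fetches the quant params attached to input `index` of `cnode`, looking where
// the producer stores them: the kGraphInputQuantParam attribute for graph
// inputs, the producing cnode's kQuantParam list (indexed by
// `multi_output_index` for multi-output producers), or the tensor itself for
// Parameter/ValueNode inputs. Returns an empty vector when nothing is found.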
std::vector<schema::QuantParamT> GetInputNodeQuantParam(const CNodePtr &cnode, size_t index,
                                                        size_t multi_output_index) {
  if (cnode->size() <= index) {
    MS_LOG(WARNING) << "index out of range, cnode input size is: " << cnode->size() << ", but index: " << index;
    return {};
  }
  auto input_node = cnode->input(index);
  MS_CHECK_TRUE_MSG(input_node != nullptr, {}, "Anf node nullptr.");
  auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_CHECK_TRUE_MSG(cnode_primitive != nullptr, {}, "Primitive is nullptr.");
  if (IsGraphInput(input_node)) {
    auto quantization_param_value = cnode_primitive->GetAttr(quant::kGraphInputQuantParam);
    if (quantization_param_value == nullptr) {
      MS_LOG(WARNING) << input_node->fullname_with_scope() << " quant param does not exist.";
      return {};
    }
    auto quantization_param = quantization_param_value->cast<mindspore::QuantizationParamPtr>();
    MS_CHECK_TRUE_MSG(quantization_param != nullptr, {}, "Graph input quant param does not exist.");
    return quant::ConvertQuantizationParamToQuantParamT(quantization_param);
  } else if (input_node->isa<mindspore::CNode>()) {
    auto input_cnode = input_node->cast<mindspore::CNodePtr>();
    auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
    MS_CHECK_TRUE_MSG(input_cnode_primitive != nullptr, {}, "Primitive is nullptr.");
    if (!input_cnode_primitive->HasAttr(quant::kQuantParam)) {
      MS_LOG(WARNING) << input_node->fullname_with_scope() << " has no quant param.";
      return {};
    }
    auto quantization_param_value = input_cnode_primitive->GetAttr(quant::kQuantParam);
    MS_CHECK_TRUE_MSG(quantization_param_value != nullptr, {}, "quantization_param_value is nullptr.");
    auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
    if (quantization_param_list.size() <= multi_output_index) {
      MS_LOG(WARNING) << "This node's input node: " << input_cnode->fullname_with_scope()
                      << "'s output quant_params size: " << quantization_param_list.size()
                      << ", but index: " << multi_output_index;
      return {};
    }
    // multi-output
    return quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.at(multi_output_index));
  } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
    tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
    MS_CHECK_TRUE_MSG(input_tensor != nullptr, {}, "Get node tensor failed.");
    auto quantization_params = input_tensor->quant_params();
    if (quantization_params.empty()) {
      MS_LOG(WARNING) << input_node->fullname_with_scope() << " quantization param is empty.";
      return {};
    }
    auto quantization_param = quantization_params.front();
    return quant::ConvertQuantizationParamToQuantParamT(quantization_param);
  } else {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " input node with index: " << index
                  << " is not supported for quant param";
  }
  return {};
}

STATUS SetInputNodeQuantParam(const CNodePtr &cnode, size_t index,
                              const std::vector<schema::QuantParamT> &quant_param) {
  auto input_node = cnode->input(index);
  MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "Anf node nullptr.");
  if (IsGraphInput(input_node)) {
    auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(kPrimIndex));
    MS_CHECK_TRUE_MSG(cnode_primitive != nullptr, RET_NULL_PTR, "Primitive is nullptr.");
    auto quantization_param = quant::ConvertQuantParamTToQuantizationParam(quant_param);
    cnode_primitive->AddAttr(quant::kGraphInputQuantParam, quantization_param);
  } else if (input_node->isa<mindspore::CNode>()) {
    auto input_cnode = input_node->cast<mindspore::CNodePtr>();
    auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
    MS_CHECK_TRUE_MSG(input_cnode_primitive != nullptr, RET_NULL_PTR, "Primitive is nullptr.");
    auto quantization_param = ConvertQuantParamTToQuantizationParam(quant_param);
    std::vector<ValuePtr> quantization_list{quantization_param};
    input_cnode_primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
  } else if (input_node->isa<mindspore::Parameter>() || input_node->isa<mindspore::ValueNode>()) {
    tensor::TensorPtr input_tensor = quant::GetNodeTensor(input_node);
    MS_CHECK_TRUE_MSG(input_tensor != nullptr, RET_NULL_PTR, "Get node tensor failed.");
    auto quantization_param = quant::ConvertQuantParamTToQuantizationParam(quant_param);
    CHECK_NULL_RETURN(quantization_param);
    input_tensor->set_quant_param(std::vector<std::shared_ptr<mindspore::QuantizationParam>>{quantization_param});
  } else {
    MS_LOG(WARNING) << input_node->fullname_with_scope() << " is not a supported node type.";
    return RET_ERROR;
  }
  return RET_OK;
}

tensor::TensorPtr GetNodeTensor(const AnfNodePtr &node) {
  // Only Parameter or ValueNode nodes hold a tensor
  if (node->isa<Parameter>()) {
    auto parameter = node->cast<ParameterPtr>();
    if (parameter->default_param() != nullptr) {
      return parameter->default_param()->cast<tensor::TensorPtr>();
    }
  } else if (node->isa<ValueNode>()) {
    return node->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
  }
  return nullptr;
}

std::vector<schema::QuantParamT> CloneQuantParam(const std::vector<schema::QuantParamT> &src) {
  MS_CHECK_TRUE_MSG(!src.empty(), {}, "Src is empty.");
  std::vector<schema::QuantParamT> dst;
  for (auto &quant_param : src) {
    schema::QuantParamT quant_param_clone;
    quant_param_clone.scale = quant_param.scale;
    quant_param_clone.zeroPoint = quant_param.zeroPoint;
    quant_param_clone.numBits = quant_param.numBits;
    quant_param_clone.narrowRange = quant_param.narrowRange;
    quant_param_clone.meanCorr = quant_param.meanCorr;
    quant_param_clone.varCorr = quant_param.varCorr;
    quant_param_clone.dstDtype = quant_param.dstDtype;
    quant_param_clone.min = quant_param.min;
    quant_param_clone.max = quant_param.max;
    quant_param_clone.roundType = quant_param.roundType;
    quant_param_clone.multiplier = quant_param.multiplier;
    dst.push_back(quant_param_clone);
  }
  return dst;
}

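// Derives bias quant params as scale_bias = scale_activation * scale_weight,
// broadcasting whichever side is per-tensor against the per-channel side. Bias
// stays in 32 bits with a zero point of 0, and a zero scale is clamped to 1.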
int CalBiasQuantParams(const std::vector<schema::QuantParamT> &active_params,
                       const std::vector<schema::QuantParamT> &weight_params,
                       std::vector<schema::QuantParamT> *bias_params) {
  std::vector<double> input_scales;
  std::vector<double> filter_scales;
  std::vector<double> bias_scales;
  size_t sizeX = active_params.size();
  for (size_t i = 0; i < sizeX; i++) {
    input_scales.emplace_back(active_params[i].scale);
  }
  size_t sizeY = weight_params.size();
  if (sizeX != sizeY) {
    if (sizeX > 1 && sizeY > 1) {
      MS_LOG(ERROR) << "input and filter scale counts do not match!";
      return RET_ERROR;
    }
  }
  for (size_t i = 0; i < sizeY; i++) {
    filter_scales.emplace_back(weight_params[i].scale);
  }
  size_t size = std::max(sizeX, sizeY);
  for (size_t i = 0; i < size; i++) {
    auto scaleX = sizeX > 1 ? input_scales[i] : input_scales[0];
    auto scaleY = sizeY > 1 ? filter_scales[i] : filter_scales[0];
    bias_scales.push_back(scaleX * scaleY);
  }
  MS_ASSERT(!bias_scales.empty());

  // set bias quant param
  for (double bias_scale : bias_scales) {
    schema::QuantParamT quant_param;
    if (bias_scale == 0) {
      MS_LOG(WARNING) << "bias_scale is 0, set bias_scale to 1.";
      quant_param.scale = 1;
    } else {
      quant_param.scale = bias_scale;
    }
    quant_param.numBits = k32Bit;
    quant_param.zeroPoint = 0;
    quant_param.inited = true;
    bias_params->push_back(quant_param);
  }
  return RET_OK;
}

bool IsAntiQuantModeNodes(const AnfNodePtr &node) {
  MS_CHECK_TRUE_RET(node != nullptr, false);
  if (!utils::isa<CNodePtr>(node) || !opt::CheckPrimitiveType(node, prim::kPrimMul)) {
    MS_LOG(INFO) << "The node is not a Mul node";
    return false;
  }
  auto add_node = node->cast<CNodePtr>()->input(kIndexOne);
  if (!utils::isa<CNodePtr>(add_node) || !opt::CheckPrimitiveType(add_node, prim::kPrimAdd)) {
    MS_LOG(INFO) << "The node is not an Add node.";
    return false;
  }
  auto ascend_antiquant_node = add_node->cast<CNodePtr>()->input(kIndexOne);
  if (!utils::isa<CNodePtr>(ascend_antiquant_node)) {
    MS_LOG(INFO) << "The node is not an AscendAntiQuant node";
    return false;
  }
  auto antiquant_prim = GetCNodePrimitive(ascend_antiquant_node);
  if (!(opt::CheckPrimitiveType(ascend_antiquant_node, prim::kPrimAntiQuant) ||
        (antiquant_prim != nullptr && antiquant_prim->name() == "AscendAntiQuant"))) {
    MS_LOG(INFO) << "The node is not an AscendAntiQuant node";
    return false;
  }
  return true;
}

STATUS GetScaleZpFromAntiQuantModeNodes(const AnfNodePtr &node, ParameterPtr *scale_param_node,
                                        ParameterPtr *zp_param_node) {
  CHECK_NULL_RETURN(node);
  CHECK_NULL_RETURN(scale_param_node);
  CHECK_NULL_RETURN(zp_param_node);

  if (!utils::isa<CNodePtr>(node) || !opt::CheckPrimitiveType(node, prim::kPrimMul)) {
    return RET_ERROR;
  }
  auto add_node = node->cast<CNodePtr>()->input(kIndexOne);
  auto scale_param = node->cast<CNodePtr>()->input(kIndexTwo);
  if (opt::CheckPrimitiveType(scale_param, prim::kPrimLoad)) {
    scale_param = scale_param->cast<CNodePtr>()->input(kIndexOne);
  }
  *scale_param_node = scale_param->cast<ParameterPtr>();
  CHECK_NULL_RETURN(*scale_param_node);
  if (!utils::isa<CNodePtr>(add_node) || !opt::CheckPrimitiveType(add_node, prim::kPrimAdd)) {
    return RET_ERROR;
  }
  auto zp_param = add_node->cast<CNodePtr>()->input(kIndexTwo);
  if (opt::CheckPrimitiveType(zp_param, prim::kPrimLoad)) {
    zp_param = zp_param->cast<CNodePtr>()->input(kIndexOne);
  }
  *zp_param_node = zp_param->cast<ParameterPtr>();
  CHECK_NULL_RETURN(*zp_param_node);
  return RET_OK;
}

STATUS RemoveAntiQuantModeNodes(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int index) {
  CHECK_NULL_RETURN(func_graph);
  CHECK_NULL_RETURN(node);

  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
  }
  CHECK_NULL_RETURN(manager);

  if (!utils::isa<CNodePtr>(node)) {
    MS_LOG(ERROR) << "The node : " << node->fullname_with_scope() << " is not a cnode";
    return lite::RET_ERROR;
  }
  auto cnode = node->cast<CNodePtr>();
  CHECK_NULL_RETURN(cnode);

  auto mul_node = cnode->input(index);

  if (!utils::isa<CNodePtr>(mul_node) || !opt::CheckPrimitiveType(mul_node, prim::kPrimMul)) {
    MS_LOG(WARNING) << "In AntiQuant mode, the node : " << cnode->fullname_with_scope() << " is not a mul node";
    return RET_OK;
  }
  auto add_node = mul_node->cast<CNodePtr>()->input(kIndexOne);
  if (!opt::CheckPrimitiveType(add_node, prim::kPrimAdd)) {
    MS_LOG(WARNING) << "In AntiQuant mode, the node : " << add_node->fullname_with_scope() << " is not an add node";
    return RET_OK;
  }
  auto ascend_antiquant_node = add_node->cast<CNodePtr>()->input(kIndexOne);
  auto antiquant_prim = GetCNodePrimitive(ascend_antiquant_node);
  if (!(opt::CheckPrimitiveType(ascend_antiquant_node, prim::kPrimAntiQuant) ||
        (antiquant_prim != nullptr && antiquant_prim->name() == "AscendAntiQuant"))) {
    MS_LOG(WARNING) << "In AntiQuant mode, the node : " << ascend_antiquant_node->fullname_with_scope()
                    << " is not an antiquant node";
    return RET_OK;
  }

  (void)manager->Replace(mul_node, ascend_antiquant_node->cast<CNodePtr>()->input(1));
  return lite::RET_OK;
}

std::vector<std::vector<int64_t>> ExtractStrategy(const ValuePtr &stra) {
  if (stra == nullptr) {
    return {};
  }

  auto var = stra->cast<ValueTuplePtr>();
  if (var == nullptr) {
    return {};
  }
  std::vector<std::vector<int64_t>> strategy;
  MS_LOG(INFO) << "Extract information: strategy " << stra->ToString();
  if (var->size() > 0) {
    std::vector<ValuePtr> elements = var->value();
    for (uint64_t index = 0; index < elements.size(); ++index) {
      std::vector<int64_t> dim;
      if (elements[index]->isa<ValueSequence>()) {
        auto value_tuple = elements[index]->cast<ValueTuplePtr>();
        std::vector<ValuePtr> value_vector = value_tuple->value();
        (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
                             [](const ValuePtr &value) { return static_cast<int64_t>(GetValue<int64_t>(value)); });
        strategy.push_back(dim);
      } else {
        MS_LOG(EXCEPTION) << "Failure: Strategy's format is wrong! Need ValueSequence";
      }
    }
    if (strategy.empty()) {
      MS_LOG(EXCEPTION) << "ExtractStrategy: failed to extract strategy";
    }
  }

  return strategy;
}

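// Builds one 8-bit quant param per element of the min/max tensors: a single
// pair yields a per-layer param, N pairs yield N per-channel params.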
std::vector<schema::QuantParamT> CalQuantParamWithMinMax(const tensor::TensorPtr &min_value,
                                                         const tensor::TensorPtr &max_value, bool symmetric) {
  std::vector<schema::QuantParamT> quant_params;
  // Ascend fake quant transform supports per-layer and per-channel quant params
  if (min_value->ElementsNum() != max_value->ElementsNum()) {
    MS_LOG(ERROR) << "min value size not equal to max value size";
    return {};
  }
  int size = static_cast<int>(min_value->ElementsNum());
  auto min_data = reinterpret_cast<float *>(min_value->data_c());
  auto max_data = reinterpret_cast<float *>(max_value->data_c());
  for (int i = 0; i < size; i++) {
    float real_min = *(min_data + i);
    float real_max = *(max_data + i);
    schema::QuantParamT quant_param;
    int bit_num = k8Bit;

    MS_LOG(DEBUG) << "min: " << real_min << " max: " << real_max << " bit_num: " << bit_num
                  << " symmetric: " << symmetric;
    auto ret = CalQuantizationParams(&quant_param, real_min, real_max, bit_num, symmetric);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Failed to calculate quant params";
      return {};
    }
    MS_LOG(INFO) << "quant param scale: " << quant_param.scale << " zp: " << quant_param.zeroPoint;
    quant_params.push_back(quant_param);
  }
  return quant_params;
}

std::vector<schema::QuantParamT> GetQuantParamWithFakeQuantNode(const CNodePtr &fake_quant_node, bool symmetric) {
  tensor::TensorPtr min_value;
  tensor::TensorPtr max_value;
  auto min_input = fake_quant_node->input(kFakeQuantMinIndex + kPrimOffset);
  if (utils::isa<ParameterPtr>(min_input) && min_input->cast<ParameterPtr>()->has_default() &&
      min_input->cast<ParameterPtr>()->default_param() != nullptr) {
    min_value = min_input->cast<ParameterPtr>()->default_param()->cast<tensor::TensorPtr>();
  } else {
    MS_LOG(ERROR) << "Quant param get min value failed";
    return {};
  }
  auto max_input = fake_quant_node->input(kFakeQuantMaxIndex + kPrimOffset);
  if (utils::isa<ParameterPtr>(max_input) && max_input->cast<ParameterPtr>()->has_default() &&
      max_input->cast<ParameterPtr>()->default_param() != nullptr) {
    max_value = max_input->cast<ParameterPtr>()->default_param()->cast<tensor::TensorPtr>();
  } else {
    MS_LOG(ERROR) << "Quant param get max value failed";
    return {};
  }
  auto quant_params = CalQuantParamWithMinMax(min_value, max_value, symmetric);
  return quant_params;
}

}  // namespace mindspore::lite::quant