• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 
19 #include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h"
20 #include <memory>
21 #include <set>
22 #include <vector>
23 #include <string>
24 #include <algorithm>
25 #include "mindspore/core/ops/math_ops.h"
26 #include "mindspore/core/ops/lite_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "tools/optimizer/graph/node_infershape.h"
29 #include "tools/optimizer/common/gllo_utils.h"
30 #include "tools/optimizer/common/format_utils.h"
31 #include "tools/common/node_util.h"
32 #include "tools/common/tensor_util.h"
33 #include "tools/converter/quantizer/fse_decoder.h"
34 #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
35 #include "ops/fse_decode.h"
36 #include "ops/op_name.h"
37 #include "ops/auto_generate/gen_lite_ops.h"
38 #include "ops/fusion/mul_fusion.h"
39 #include "ops/fusion/add_fusion.h"
40 #include "ops/fusion/mat_mul_fusion.h"
41 #include "ops/array_ops.h"
42 #include "ir/dtype.h"
43 
44 namespace mindspore::lite::quant {
namespace {
// NOTE(review): kMinSize2, kTableExtend, kAlignOffset and kInt32Mask are not referenced in the part of the file
// reviewed here; confirm their meaning at their use sites before changing them.
constexpr size_t kMinSize2 = 2;
// Smallest cnode->size() accepted by NewDynamicQuantNode.
constexpr size_t kMinSize3 = 3;
constexpr size_t kTableExtend = 3;
constexpr size_t kAlignOffset = 7;
constexpr size_t kInt32Mask = 31;
// Negative axis indices handed to DynamicQuant::set_prefer_axis by SetPreferAxes.
// The spelling "Fisrt" is kept because the identifier is referenced by other code in this file.
constexpr int kLastFisrtIndex = -1;
constexpr int kLastSecondIndex = -2;
// Attribute-name strings. Both are compile-time constants; the pointer itself is const so the name cannot be
// re-seated at runtime.
constexpr const char *ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding";
constexpr char IN_STRATEGY[] = "in_strategy";
}  // namespace
SetCastNodeAbstract(const CNodePtr & cnode,const AnfNodePtr & input_node,const CNodePtr & cast_cnode)56 int InsertQuantNodeManager::SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node,
57                                                 const CNodePtr &cast_cnode) {
58   CHECK_NULL_RETURN(cnode);
59   CHECK_NULL_RETURN(input_node);
60   CHECK_NULL_RETURN(cast_cnode);
61 
62   AbstractBasePtr abstract;
63   if (cnode->abstract() != nullptr) {
64     abstract = cnode->abstract()->Clone();
65   } else if (input_node->abstract() != nullptr) {
66     abstract = input_node->abstract()->Clone();
67   } else {
68     MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope()
69                   << " input node: " << input_node->fullname_with_scope();
70     return RET_NULL_PTR;
71   }
72   cast_cnode->set_abstract(abstract);
73   return RET_OK;
74 }
75 
76 // If dtype can be fetched, check data type, otherwise return RET_OK
CheckDataType(const AnfNodePtr & input_node,TypeId check_type_id) const77 int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const {
78   bool is_graph_input = IsGraphInput(input_node);
79   if (!input_node->isa<mindspore::CNode>() && !is_graph_input) {
80     return RET_NO_CHANGE;
81   }
82   bool is_special_node =
83     input_node->isa<mindspore::CNode>() && opt::IsSpecialType(input_node->cast<mindspore::CNodePtr>());
84   if (!is_special_node || is_graph_input) {
85     TypeId type_id;
86     auto ret = opt::GetDataTypeFromAnfNode(input_node, &type_id);
87     if (ret != RET_OK) {
88       MS_LOG(WARNING) << "Fetch DataType from cnode failed.";
89       return RET_OK;
90     }
91     if (type_id != check_type_id) {
92       return RET_NO_CHANGE;
93     }
94   }
95   return RET_OK;
96 }
97 
InsertDynamicQuantWithIndex(const FuncGraphPtr & graph,const CNodePtr & cnode,size_t index,bool activation_channel)98 int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index,
99                                                         bool activation_channel) {
100   auto primitive = std::make_shared<ops::DynamicQuant>();
101   CHECK_NULL_RETURN(primitive);
102   auto primitive_c = primitive->GetPrim();
103   primitive->set_dst_type(dst_type_);
104   bool symmetric = activation_channel ? true : false;
105   primitive->set_symmetric(symmetric);
106   primitive->set_activation_channel(activation_channel);
107   if (activation_channel && SetPreferAxes(cnode, index, primitive) != RET_OK) {
108     MS_LOG(ERROR) << "Set prefer axis failed, " << cnode->fullname_with_scope();
109     return RET_ERROR;
110   }
111   auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(index)});
112   CHECK_NULL_RETURN(dynamic_quant_cnode);
113   auto name = cnode->fullname_with_scope() + "_dynamic_cast_node_" + std::to_string(index);
114   dynamic_quant_cnode->set_fullname_with_scope(name);
115   CHECK_NULL_RETURN(cnode->abstract());
116   auto abstract = cnode->abstract()->Clone();
117   if (abstract == nullptr) {
118     MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
119     return RET_NULL_PTR;
120   }
121   dynamic_quant_cnode->set_abstract(abstract);
122   abstract->set_shape(cnode->input(index)->Shape());
123   auto ret = UpdateDataType(dynamic_quant_cnode, dst_type_);
124   if (ret != RET_OK) {
125     MS_LOG(ERROR) << cnode->fullname_with_scope() << " set new dtype failed.";
126     return ret;
127   }
128   ret = MarkDynamicQuantize(dynamic_quant_cnode);
129   if (ret != RET_OK) {
130     MS_LOG(ERROR) << cnode->fullname_with_scope() << " mark quant type failed.";
131     return ret;
132   }
133   cnode->set_input(index, dynamic_quant_cnode);
134   return RET_OK;
135 }
136 
SetPreferAxes(const CNodePtr & cnode,size_t index,const std::shared_ptr<ops::DynamicQuant> & dynamic_primitive)137 int InsertQuantNodeManager::SetPreferAxes(const CNodePtr &cnode, size_t index,
138                                           const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive) {
139   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
140   if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
141     auto matmul_prim = api::MakeShared<ops::MatMul>(primitive);
142     CHECK_NULL_RETURN(matmul_prim);
143     auto shape = opt::GetAnfNodeOutputShape(cnode->input(index), 0);
144     std::vector<int> prefer_axes;
145     for (int i = 0; i < static_cast<int>(shape.size()) - C2NUM; ++i) {
146       prefer_axes.push_back(i);
147     }
148     // For MatMul A
149     if (index == kInputIndex + kPrimOffset) {
150       if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
151         prefer_axes.push_back(kLastFisrtIndex);
152         dynamic_primitive->set_prefer_axis(kLastFisrtIndex);
153         dynamic_primitive->set_transpose(true);
154       } else {
155         prefer_axes.push_back(kLastSecondIndex);
156         dynamic_primitive->set_prefer_axis(kLastSecondIndex);
157         dynamic_primitive->set_transpose(false);
158       }
159     }
160     // For MatMul B
161     if (index == kWeightIndex + kPrimOffset) {
162       if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
163         prefer_axes.push_back(kLastSecondIndex);
164         dynamic_primitive->set_prefer_axis(kLastSecondIndex);
165         dynamic_primitive->set_transpose(true);
166       } else {
167         prefer_axes.push_back(kLastFisrtIndex);
168         dynamic_primitive->set_prefer_axis(kLastFisrtIndex);
169         dynamic_primitive->set_transpose(false);
170       }
171     }
172     dynamic_primitive->set_prefer_axes(prefer_axes);
173   } else {
174     MS_LOG(WARNING) << "cnode don't need prefer axis, cnode name: " << cnode->fullname_with_scope();
175   }
176   return RET_OK;
177 }
178 
NewDynamicQuantNode(const FuncGraphPtr & graph,const CNodePtr & cnode,bool activation_channel)179 int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
180                                                 bool activation_channel) {
181   auto op_name = cnode->fullname_with_scope();
182   if (cnode->size() < kMinSize3) {
183     MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
184     return RET_ERROR;
185   }
186   auto input = cnode->input(kInputIndex + kPrimOffset);
187   auto weight = cnode->input(kWeightIndex + kPrimOffset);
188   if (activation_channel && (input->isa<mindspore::CNode>() || IsGraphInput(input)) &&
189       (weight->isa<mindspore::CNode>() || IsGraphInput(weight))) {
190     return RET_NOT_SUPPORT;
191   }
192   if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
193     auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimOffset, activation_channel);
194     if (ret != RET_OK) {
195       MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
196     }
197   }
198   if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
199     auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimOffset, activation_channel);
200     if (ret != RET_OK) {
201       MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
202     }
203   }
204   return RET_OK;
205 }
206 
MarkDynamicQuantize(const CNodePtr & cnode)207 int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
208   CHECK_NULL_RETURN(cnode);
209   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
210   CHECK_NULL_RETURN(primitive);
211   auto quant_param_holder = GetCNodeQuantHolder(primitive);
212   quant_param_holder->set_quant_type(quant::QUANT_DYNAMIC);
213   return RET_OK;
214 }
215 
InsertDynamicQuantNode(const FuncGraphPtr & graph,const std::set<PrimitivePtr> & support_dynamic_quant_ops,const std::set<std::string> & skip_quant_node,bool activation_channel)216 int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
217                                                    const std::set<PrimitivePtr> &support_dynamic_quant_ops,
218                                                    const std::set<std::string> &skip_quant_node,
219                                                    bool activation_channel) {
220   CHECK_NULL_RETURN(graph);
221   auto cnodes = graph->GetOrderedCnodes();
222   for (auto &cnode : cnodes) {
223     auto op_name = cnode->fullname_with_scope();
224     if (skip_quant_node.find(op_name) != skip_quant_node.end()) {
225       MS_LOG(INFO) << op_name << " is skip dynamic quant.";
226       continue;
227     }
228     auto ret = CheckDataType(cnode, kNumberTypeFloat32);
229     if (ret == RET_NO_CHANGE) {
230       continue;
231     }
232     if (opt::IsSpecialType(cnode)) {
233       continue;
234     }
235     auto is_support_node = CheckNodeInSet(cnode, support_dynamic_quant_ops);
236     if (!is_support_node) {
237       auto type = NodePrimitiveType(cnode);
238       MS_LOG(INFO) << "node:" << op_name << " type:" << type << " will not quantify.";
239       continue;
240     }
241     ret = NewDynamicQuantNode(graph, cnode, activation_channel);
242     if (ret == RET_NOT_SUPPORT) {
243       continue;
244     }
245     if (ret != RET_OK) {
246       MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
247       return ret;
248     }
249     ret = MarkDynamicQuantize(cnode);
250     if (ret != RET_OK) {
251       MS_LOG(ERROR) << "node:" << op_name << " new mark dynamic quant node failed.";
252       return ret;
253     }
254     ret = UpdateDataType(cnode, kNumberTypeFloat32);
255     if (ret != RET_OK) {
256       MS_LOG(ERROR) << "node:" << op_name << " update datatype failed.";
257       return ret;
258     }
259   }
260   return RET_OK;
261 }
262 
InsertDequantNode(const FuncGraphPtr & graph)263 int InsertQuantNodeManager::InsertDequantNode(const FuncGraphPtr &graph) {
264   CHECK_NULL_RETURN(graph);
265   auto cnodes = graph->GetOrderedCnodes();
266   for (auto &cnode : cnodes) {
267     quant::QuantType curr_quant_type;
268     if (GetQuantType(cnode, &curr_quant_type) != RET_OK) {
269       MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
270       return RET_ERROR;
271     }
272     if (curr_quant_type != quant::QUANT_ALL) {
273       MS_LOG(INFO) << "Invalid cnode quant type, cnode name: " << cnode->fullname_with_scope()
274                    << " quant type: " << curr_quant_type;
275       continue;
276     }
277     auto status = InsertForwardCastNode(graph, cnode, kNumberTypeFloat32, curr_quant_type);
278     if (status != RET_OK) {
279       MS_LOG(ERROR) << "InsertForwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
280       return status;
281     }
282     // DetectionPostProcess op(Uint8toFp32, not need backward cast node)
283     if (!CheckNodeInSet(cnode, kUint8toFP32Operator)) {
284       status = InsertBackwardCastNode(graph, cnode, kNumberTypeFloat32, curr_quant_type);
285       if (status != RET_OK) {
286         MS_LOG(ERROR) << "InsertBackwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
287         return status;
288       }
289     }
290   }  // for
291   return RET_OK;
292 }
293 
InsertQuantDtypeCastNodeNew(const FuncGraphPtr & graph,const CNodePtr & cnode,InsertDirection insert_direction,TypeId cast_dtype,CastNodeType cast_node_type,size_t index,const AnfNodePtr & output_node)294 int InsertQuantNodeManager::InsertQuantDtypeCastNodeNew(const FuncGraphPtr &graph, const CNodePtr &cnode,
295                                                         InsertDirection insert_direction, TypeId cast_dtype,
296                                                         CastNodeType cast_node_type, size_t index,
297                                                         const AnfNodePtr &output_node) {
298   CHECK_NULL_RETURN(graph);
299   CHECK_NULL_RETURN(cnode);
300   if (insert_direction == FORWARD) {
301     return InsertForwardQuantNodeNew(graph, cnode, cast_dtype, index, cast_node_type);
302   } else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
303     return InsertBackwardDeQuantNode(graph, cnode, cast_dtype, index, output_node);
304   }
305   MS_LOG(ERROR) << "Invalid insert direction: " << insert_direction;
306   return RET_NOT_SUPPORT;
307 }
308 
InsertQuantDtypeCastNode(const FuncGraphPtr & graph,const CNodePtr & cnode,InsertDirection insert_direction,TypeId cast_dtype,CastNodeType cast_node_type,size_t index,const AnfNodePtr & output_node)309 int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
310                                                      InsertDirection insert_direction, TypeId cast_dtype,
311                                                      CastNodeType cast_node_type, size_t index,
312                                                      const AnfNodePtr &output_node) {
313   CHECK_NULL_RETURN(graph);
314   CHECK_NULL_RETURN(cnode);
315   if (insert_direction == FORWARD) {
316     return InsertForwardQuantNode(graph, cnode, cast_dtype, index, cast_node_type);
317   } else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
318     return InsertBackwardDeQuantNode(graph, cnode, cast_dtype, index, output_node);
319   }
320   MS_LOG(ERROR) << "Invalid insert direction: " << insert_direction;
321   return RET_NOT_SUPPORT;
322 }
323 
InsertForwardQuantNodeNew(const FuncGraphPtr & graph,const CNodePtr & cnode,TypeId cast_dtype,size_t index,CastNodeType cast_node_type)324 int InsertQuantNodeManager::InsertForwardQuantNodeNew(const FuncGraphPtr &graph, const CNodePtr &cnode,
325                                                       TypeId cast_dtype, size_t index, CastNodeType cast_node_type) {
326   if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
327     MS_LOG(ERROR) << "Invalid cast dtype: " << cast_dtype;
328     return RET_NOT_SUPPORT;
329   }
330 
331   auto input_node = cnode->input(index);
332   CHECK_NULL_RETURN(input_node);
333   if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
334     MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope();
335     return RET_ERROR;
336   }
337   if (CheckDataType(input_node, cast_dtype) != RET_OK) {
338     return RET_NO_CHANGE;
339   }
340   // insert forward cast_node
341   TypeId src_dtype;
342   TypeId dst_dtype;
343   std::vector<schema::QuantParamT> cast_input_quant_params;
344   std::vector<schema::QuantParamT> cast_output_quant_params;
345   if (cast_node_type == kQuant) {
346     src_dtype = cast_dtype;
347     dst_dtype = kNumberTypeInt8;
348     cast_output_quant_params = quant::GetInputNodeQuantParam(cnode, index);
349     std::copy(cast_output_quant_params.cbegin(), cast_output_quant_params.cend(),
350               std::back_inserter(cast_input_quant_params));
351     // Uint8toInt8
352     if (src_dtype == kNumberTypeUInt8) {
353       for (auto &quant_param : cast_input_quant_params) {
354         quant_param.zeroPoint += kU8ZeroPointOffset;
355       }
356     }
357   } else {
358     src_dtype = kNumberTypeInt8;
359     dst_dtype = cast_dtype;
360     auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
361     auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<mindspore::Primitive>>(input_cnode->input(0));
362     if (input_cnode_primitive_c == nullptr) {
363       MS_LOG(DEBUG) << "input: " << index << " " << input_cnode->fullname_with_scope() << ": "
364                     << " PrimitiveC is null";
365       return RET_NO_CHANGE;
366     }
367     auto quantization_param_value = input_cnode_primitive_c->GetAttr(quant::kQuantParam);
368     MS_CHECK_TRUE_MSG(quantization_param_value != nullptr, RET_ERROR, "quantization_param_value is nullptr.");
369     auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
370     if (quantization_param_list.empty()) {
371       MS_LOG(ERROR) << input_node->fullname_with_scope() << " quantization param Not exist.";
372       return RET_ERROR;
373     }
374     cast_input_quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.front());
375     std::copy(cast_input_quant_params.cbegin(), cast_input_quant_params.cend(),
376               std::back_inserter(cast_output_quant_params));
377   }
378   ValueNodePtr new_primitive =
379     NewQuantCastPrimitive(src_dtype, dst_dtype, input_node, cast_output_quant_params, 0, true);
380   CHECK_NULL_RETURN(new_primitive);
381   std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
382   auto quant_cast_cnode = graph->NewCNode(op_inputs);
383   CHECK_NULL_RETURN(quant_cast_cnode);
384   quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
385                                             "_pre");
386   // set abstract
387   if (input_node->abstract() != nullptr) {
388     auto abstract = input_node->abstract()->Clone();
389     quant_cast_cnode->set_abstract(abstract);
390     if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
391       MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
392       return RET_ERROR;
393     }
394   } else {
395     MS_LOG(INFO) << "input node abstract nullptr, input node name: " << input_node->fullname_with_scope();
396   }
397   auto manager = graph->manager();
398   if (manager == nullptr) {
399     manager = Manage(graph, true);
400   }
401   CHECK_NULL_RETURN(manager);
402   manager->SetEdge(cnode, index, quant_cast_cnode);
403   MS_LOG(INFO) << "InsertForwardQuantNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
404                << " dst_type: " << dst_dtype;
405   return RET_OK;
406 }
407 
InsertForwardQuantNode(const FuncGraphPtr & graph,const CNodePtr & cnode,TypeId cast_dtype,size_t index,CastNodeType cast_node_type)408 int InsertQuantNodeManager::InsertForwardQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
409                                                    size_t index, CastNodeType cast_node_type) {
410   if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
411     MS_LOG(ERROR) << "Invalid cast dtype: " << cast_dtype;
412     return RET_NOT_SUPPORT;
413   }
414 
415   auto input_node = cnode->input(index);
416   CHECK_NULL_RETURN(input_node);
417   if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
418     MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope();
419     return RET_ERROR;
420   }
421   if (CheckDataType(input_node, cast_dtype) != RET_OK) {
422     return RET_NO_CHANGE;
423   }
424   // insert forward cast_node
425   TypeId src_dtype;
426   TypeId dst_dtype;
427   std::vector<schema::QuantParamT> input_quant_params;
428   std::vector<schema::QuantParamT> output_quant_params;
429   if (cast_node_type == kQuant) {
430     src_dtype = cast_dtype;
431     dst_dtype = kNumberTypeInt8;
432     auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
433     CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
434     if (curr_primitive_quant_param_holder->get_input_quant_params().size() < index) {
435       MS_LOG(ERROR) << "quant param is invalid.";
436       return RET_ERROR;
437     }
438     output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[index - 1];
439     std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params));
440     // Uint8toInt8
441     if (src_dtype == kNumberTypeUInt8) {
442       for (auto &quant_param : input_quant_params) {
443         quant_param.zeroPoint += kU8ZeroPointOffset;
444       }
445     }
446   } else {
447     src_dtype = kNumberTypeInt8;
448     dst_dtype = cast_dtype;
449     auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
450     auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<mindspore::Primitive>>(input_cnode->input(0));
451     if (input_cnode_primitive_c == nullptr) {
452       MS_LOG(DEBUG) << "input: " << index << " " << input_cnode->fullname_with_scope() << ": "
453                     << " PrimitiveC is null";
454       return RET_NO_CHANGE;
455     }
456     auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
457     if (input_primitive_quant_param_holder->get_output_quant_params().empty()) {
458       MS_LOG(ERROR) << "output quant param is empty.";
459       return RET_ERROR;
460     }
461     input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0];
462     std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
463   }
464   ValueNodePtr new_primitive =
465     NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params, 0, false);
466   CHECK_NULL_RETURN(new_primitive);
467   std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
468   auto quant_cast_cnode = graph->NewCNode(op_inputs);
469   CHECK_NULL_RETURN(quant_cast_cnode);
470   quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
471                                             "_pre");
472   // set abstract
473   if (input_node->abstract() != nullptr) {
474     auto abstract = input_node->abstract()->Clone();
475     quant_cast_cnode->set_abstract(abstract);
476     if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
477       MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
478       return RET_ERROR;
479     }
480   } else {
481     MS_LOG(INFO) << "input node abstract nullptr, input node name: " << input_node->fullname_with_scope();
482   }
483   auto manager = graph->manager();
484   if (manager == nullptr) {
485     manager = Manage(graph, true);
486   }
487   CHECK_NULL_RETURN(manager);
488   manager->SetEdge(cnode, index, quant_cast_cnode);
489   MS_LOG(INFO) << "InsertForwardQuantNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
490                << " dst_type: " << dst_dtype;
491   return RET_OK;
492 }
493 
InsertBackwardDeQuantNode(const FuncGraphPtr & graph,const CNodePtr & cnode,TypeId cast_dtype,size_t index,const AnfNodePtr & output_node)494 int InsertQuantNodeManager::InsertBackwardDeQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
495                                                       TypeId cast_dtype, size_t index, const AnfNodePtr &output_node) {
496   if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
497     MS_LOG(ERROR) << "Invalid cast dtype: " << cast_dtype;
498     return RET_NOT_SUPPORT;
499   }
500   CHECK_NULL_RETURN(output_node);
501   // If cnode or outputnode is QuantDTypeCast, do nothing.
502   if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast) ||
503       opt::CheckPrimitiveType(output_node, prim::kPrimQuantDTypeCast)) {
504     return RET_NO_CHANGE;
505   }
506   auto ret = CheckDataType(output_node, cast_dtype);
507   if (ret != RET_OK) {
508     MS_LOG(ERROR) << "Check data type failed, cnode name: " << output_node->fullname_with_scope();
509     return ret;
510   }
511   auto manager = graph->manager();
512   if (manager == nullptr) {
513     manager = Manage(graph, true);
514   }
515   CHECK_NULL_RETURN(manager);
516 
517   // insert backward cast_node
518   TypeId src_dtype = kNumberTypeInt8;
519   TypeId dst_dtype = cast_dtype;
520   std::vector<schema::QuantParamT> input_quant_params;
521   std::vector<schema::QuantParamT> output_quant_params;
522 
523   auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
524   CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
525   if (curr_primitive_quant_param_holder->get_output_quant_params().empty()) {
526     MS_LOG(ERROR) << "quant param is invalid.";
527     return RET_ERROR;
528   }
529   input_quant_params = curr_primitive_quant_param_holder->get_output_quant_params().front();
530   std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
531   // Int8toUint8
532   if (dst_dtype == kNumberTypeUInt8) {
533     for (auto &quant_param : output_quant_params) {
534       quant_param.zeroPoint += kU8ZeroPointOffset;
535     }
536   }
537   ValueNodePtr new_primitive =
538     NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params, 0, false);
539   CHECK_NULL_RETURN(new_primitive);
540   std::vector<AnfNodePtr> op_inputs = {new_primitive, cnode->cast<AnfNodePtr>()};
541   auto quant_cast_cnode = graph->NewCNode(op_inputs);
542   MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
543   quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
544                                             "_post");
545   if (SetCastNodeAbstract(cnode, output_node, quant_cast_cnode) != RET_OK) {
546     MS_LOG(ERROR) << "SetCastNodeAbstract failed.";
547     return RET_ERROR;
548   }
549   if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
550     MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
551     return RET_ERROR;
552   }
553   manager->SetEdge(output_node, index, quant_cast_cnode);
554   MS_LOG(INFO) << "InsertBackwardDeQuantNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
555                << " dst_type: " << dst_dtype;
556   return RET_OK;
557 }
558 
InsertForwardCastNode(const FuncGraphPtr & graph,const CNodePtr & cnode,TypeId cast_dtype,quant::QuantType curr_quant_type)559 int InsertQuantNodeManager::InsertForwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
560                                                   quant::QuantType curr_quant_type) {
561   // inputs
562   for (size_t index = 1; index < cnode->size(); index++) {
563     auto input_node = cnode->input(index);
564     CHECK_NULL_RETURN(input_node);
565     if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
566       MS_LOG(DEBUG) << "Invalid input node, not CNode and graph input.";
567       continue;
568     }
569     quant::QuantType pre_quant_type = quant::QUANT_NONE;
570     if (input_node->isa<mindspore::CNode>()) {
571       if (GetQuantType(input_node->cast<mindspore::CNodePtr>(), &pre_quant_type) != RET_OK) {
572         MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
573         return RET_ERROR;
574       }
575     }
576     if (pre_quant_type == quant::QUANT_NONE && curr_quant_type == quant::QUANT_ALL) {
577       auto status = InsertQuantDtypeCastNode(graph, cnode, FORWARD, cast_dtype, kQuant, index, nullptr);
578       if (status != RET_OK && status != RET_NO_CHANGE) {
579         MS_LOG(ERROR) << "InsertQuantDtypeCastNode kQuant failed, cnode name: " << cnode->fullname_with_scope();
580         return status;
581       }
582     }
583   }
584   return RET_OK;
585 }
586 
InsertCastNodeForFullQuant(const FuncGraphPtr & graph,const CNodePtr & cnode,TypeId cast_dtype,quant::QuantType curr_quant_type)587 int InsertQuantNodeManager::InsertCastNodeForFullQuant(const FuncGraphPtr &graph, const CNodePtr &cnode,
588                                                        TypeId cast_dtype, quant::QuantType curr_quant_type) {
589   // inputs
590   for (size_t index = 1; index < cnode->size(); index++) {
591     auto input_node = cnode->input(index);
592     CHECK_NULL_RETURN(input_node);
593     if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
594       MS_LOG(DEBUG) << "Invalid input node, not CNode and graph input.";
595       continue;
596     }
597     quant::QuantType pre_quant_type = quant::QUANT_NONE;
598     if (input_node->isa<mindspore::CNode>()) {
599       if (GetQuantTypeNew(input_node->cast<mindspore::CNodePtr>(), &pre_quant_type) != RET_OK) {
600         MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
601         return RET_ERROR;
602       }
603     }
604     if (pre_quant_type == quant::QUANT_NONE && curr_quant_type == quant::QUANT_ALL) {
605       auto status = InsertQuantDtypeCastNodeNew(graph, cnode, FORWARD, cast_dtype, kQuant, index, nullptr);
606       if (status != RET_OK && status != RET_NO_CHANGE) {
607         MS_LOG(ERROR) << "InsertQuantDtypeCastNode kQuant failed, cnode name: " << cnode->fullname_with_scope();
608         return status;
609       }
610     } else if (pre_quant_type == quant::QUANT_ALL && curr_quant_type == quant::QUANT_NONE) {
611       auto status = InsertQuantDtypeCastNodeNew(graph, cnode, FORWARD, cast_dtype, kDeQuant, index, nullptr);
612       if (status != RET_OK && status != RET_NO_CHANGE) {
613         MS_LOG(ERROR) << "InsertQuantDtypeCastNode kDeQuant failed, cnode name: " << cnode->fullname_with_scope();
614         return status;
615       }
616     }
617   }
618   return RET_OK;
619 }
620 
InsertBackwardCastNode(const FuncGraphPtr & graph,const CNodePtr & cnode,TypeId cast_dtype,quant::QuantType curr_quant_type)621 int InsertQuantNodeManager::InsertBackwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
622                                                    quant::QuantType curr_quant_type) {
623   // outputs
624   auto manager = graph->manager();
625   if (manager == nullptr) {
626     manager = Manage(graph, true);
627   }
628   CHECK_NULL_RETURN(manager);
629   auto node_users = manager->node_users()[cnode];
630   for (auto &node_user : node_users) {
631     auto output_cnode = node_user.first->cast<CNodePtr>();
632     quant::QuantType post_quant_type;
633     if (GetQuantType(output_cnode, &post_quant_type) != RET_OK) {
634       MS_LOG(ERROR) << "Get quant type failed, cnode name: " << output_cnode->fullname_with_scope();
635       return RET_ERROR;
636     }
637     if (curr_quant_type == quant::QUANT_ALL && post_quant_type == quant::QUANT_NONE) {
638       auto status =
639         InsertQuantDtypeCastNode(graph, cnode, BACKWARD, cast_dtype, kDeQuant, node_user.second, node_user.first);
640       if (status != RET_OK && status != RET_NO_CHANGE) {
641         MS_LOG(ERROR) << "InsertQuantDtypeCastNode dequant failed, cnode name: " << cnode->fullname_with_scope();
642         return status;
643       }
644     }
645   }  // node_users
646   return RET_OK;
647 }
648 
InsertQuantDtypeCastFlyNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,TypeId src_dtype,TypeId dst_dtype,int axis,bool is_quant_attribute)649 int InsertQuantNodeManager::InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
650                                                         size_t input_index, TypeId src_dtype, TypeId dst_dtype,
651                                                         int axis, bool is_quant_attribute) {
652   MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
653   auto cnode_primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
654   if (cnode_primitive == nullptr) {
655     MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
656     return RET_ERROR;
657   }
658   auto input_node = cnode->input(input_index);
659   if (!input_node->isa<mindspore::Parameter>()) {
660     MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
661     return RET_ERROR;
662   }
663   auto input_quant_params = quant::GetInputNodeQuantParam(cnode, input_index);
664 
665   CNodePtr quant_cast_cnode = nullptr;
666   if (is_quant_attribute) {
667     ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, {}, axis, false);
668     MS_CHECK_TRUE_MSG(new_primitive != nullptr, RET_NULL_PTR, "New quant_cast primitive failed!");
669     std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
670     quant_cast_cnode = func_graph->NewCNode(op_inputs);
671   } else {
672     quant_cast_cnode =
673       CreateQuantInputCastNode(func_graph, cnode, input_node, src_dtype, dst_dtype, input_quant_params, axis);
674   }
675   CHECK_NULL_RETURN(quant_cast_cnode);
676   opt::NodeInferShape infer;
677   auto status = infer.InferShape(quant_cast_cnode);
678   if (status != RET_OK) {
679     MS_LOG(ERROR) << quant_cast_cnode->fullname_with_scope() << " InferShape failed.";
680     return RET_ERROR;
681   }
682   auto manager = func_graph->manager();
683   CHECK_NULL_RETURN(manager);
684   auto ret = manager->Replace(input_node, quant_cast_cnode);
685   if (!ret) {
686     MS_LOG(ERROR) << "Replace QuantDtypeCast failed.";
687     return RET_ERROR;
688   }
689   cnode_primitive->DelAttr(quant::kQuantParam);
690   MS_LOG(INFO) << "InsertCastNode cnode name: " << quant_cast_cnode->fullname_with_scope()
691                << " src_dtype: " << src_dtype << " dst_dtype: " << dst_dtype;
692 
693   return RET_OK;
694 }
695 
// Creates a QuantDtypeCast CNode whose quant parameters are passed as extra Parameter inputs.
//
// The created node has the inputs {primitive, input_node, scales, zps, mean_corrs, var_corrs}, where the four
// vectors hold one entry per element of `input_quant_params`. Its name is "<prefix>-QuantDtypeCast-op<N + 10000>",
// with <prefix> and <N> taken from splitting cnode's full name on "-op" (N falls back to 0 when it is not a number).
//
// func_graph: graph in which the parameters and the cast node are created.
// cnode: consumer node; only its full name is used here.
// input_node: node to be cast.
// src_dtype / dst_dtype / axis: forwarded to NewQuantCastPrimitive.
// input_quant_params: quant params converted into the four parameter inputs.
// Returns the new CNode, or nullptr on failure.
CNodePtr InsertQuantNodeManager::CreateQuantInputCastNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                                          const AnfNodePtr input_node, TypeId src_dtype,
                                                          TypeId dst_dtype,
                                                          const std::vector<schema::QuantParamT> &input_quant_params,
                                                          int axis) {
  MS_CHECK_TRUE_MSG(func_graph != nullptr, nullptr, "func_graph is nullptr.");
  MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cnode is nullptr.");
  MS_CHECK_TRUE_MSG(input_node != nullptr, nullptr, "input_node is nullptr.");
  ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_node, {}, axis, false);
  MS_CHECK_TRUE_MSG(new_primitive != nullptr, nullptr, "New quant_cast primitive failed!");

  std::vector<float> scales;
  std::vector<int> zps;
  std::vector<float> mean_corrs;
  std::vector<float> var_corrs;
  scales.reserve(input_quant_params.size());
  zps.reserve(input_quant_params.size());
  mean_corrs.reserve(input_quant_params.size());
  var_corrs.reserve(input_quant_params.size());
  for (const auto &quant_param : input_quant_params) {
    scales.push_back(static_cast<float>(quant_param.scale));
    // zps holds int, so cast to the element type instead of going through int64_t and narrowing implicitly.
    zps.push_back(static_cast<int>(quant_param.zeroPoint));
    mean_corrs.push_back(static_cast<float>(quant_param.meanCorr));
    var_corrs.push_back(static_cast<float>(quant_param.varCorr));
  }
  auto scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, "scales");
  MS_CHECK_TRUE_MSG(scales_node != nullptr, nullptr, "Build scales parameter node failed.");
  auto zps_node = opt::BuildIntVecParameterNode(func_graph, zps, "zps");
  MS_CHECK_TRUE_MSG(zps_node != nullptr, nullptr, "Build zps parameter node failed.");
  auto mean_corrs_node = opt::BuildFloatVecParameterNode(func_graph, mean_corrs, "mean_corrs");
  MS_CHECK_TRUE_MSG(mean_corrs_node != nullptr, nullptr, "Build mean_corrs parameter node failed.");
  auto var_corrs_node = opt::BuildFloatVecParameterNode(func_graph, var_corrs, "var_corrs");
  MS_CHECK_TRUE_MSG(var_corrs_node != nullptr, nullptr, "Build var_corrs parameter node failed.");

  std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node,      scales_node,
                                       zps_node,      mean_corrs_node, var_corrs_node};
  auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
  if (quant_cast_cnode == nullptr) {
    MS_LOG(ERROR) << "New quant cast node failed.";
    return nullptr;
  }

  auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
  int index = 0;
  std::string name_prefix = cnode->fullname_with_scope();
  // Guard the empty case: strings.size() - 1 would wrap around and at() would throw.
  if (!strings.empty()) {
    name_prefix = strings.front();
    if (!ConvertIntNum(strings.back(), &index)) {
      index = 0;
    }
  }
  const int quant_dtype_cast_offset = 10000;
  quant_cast_cnode->set_fullname_with_scope(name_prefix + "-QuantDtypeCast-op" +
                                            std::to_string(index + quant_dtype_cast_offset));
  return quant_cast_cnode;
}
734 
CalculateScaleZPNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,ParameterPtr * scales_node,ParameterPtr * zps_node,TypeId dst_dtype,int axis)735 int InsertQuantNodeManager::CalculateScaleZPNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
736                                                  size_t input_index, ParameterPtr *scales_node, ParameterPtr *zps_node,
737                                                  TypeId dst_dtype, int axis) {
738   CHECK_NULL_RETURN(scales_node);
739   CHECK_NULL_RETURN(zps_node);
740   MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
741   auto input_node = cnode->input(input_index);
742   auto input_quant_params = quant::GetInputNodeQuantParam(cnode, input_index);
743   if (input_quant_params.empty()) {
744     MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << input_index << " quant param is empty.";
745     return RET_ERROR;
746   }
747 
748   if (dst_dtype == kNumberTypeFloat16) {
749     std::vector<float16> scales;
750     std::vector<float16> zps;
751     for (size_t i = 0; i < input_quant_params.size(); ++i) {
752       scales.push_back(static_cast<float16>(input_quant_params.at(i).scale * input_quant_params.at(i).varCorr));
753       zps.push_back(static_cast<float16>(-input_quant_params.at(i).zeroPoint +
754                                          input_quant_params.at(i).meanCorr /
755                                            (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr)));
756     }
757     *scales_node = opt::BuildFloat16VecParameterNode(func_graph, scales, input_node->fullname_with_scope() + "-scales");
758     *zps_node = opt::BuildFloat16VecParameterNode(func_graph, zps, input_node->fullname_with_scope() + "-zps");
759   } else {
760     std::vector<float> scales;
761     std::vector<float> zps;
762     for (size_t i = 0; i < input_quant_params.size(); ++i) {
763       scales.push_back(static_cast<float>(input_quant_params.at(i).scale * input_quant_params.at(i).varCorr));
764       zps.push_back(static_cast<float>(-input_quant_params.at(i).zeroPoint +
765                                        input_quant_params.at(i).meanCorr /
766                                          (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr)));
767     }
768     *scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, input_node->fullname_with_scope() + "-scales");
769     *zps_node = opt::BuildFloatVecParameterNode(func_graph, zps, input_node->fullname_with_scope() + "-zps");
770   }
771   if (*scales_node == nullptr || *zps_node == nullptr) {
772     MS_LOG(ERROR) << "Failed to build scales node, zps node ";
773     return RET_ERROR;
774   }
775   if (input_quant_params.size() > 1) {
776     ShapeVector shape;
777     if (opt::FetchShapeFromAbstract(input_node->abstract(), &shape) != lite::RET_OK) {
778       MS_LOG(ERROR) << "fetch shape failed." << input_node->fullname_with_scope();
779       return lite::RET_ERROR;
780     }
781 
782     std::vector<int64_t> shape_vector = {};
783     for (size_t i = 0; i < shape.size(); i++) {
784       if (i == static_cast<size_t>(axis)) {
785         shape_vector.push_back((int64_t)input_quant_params.size());
786       } else {
787         shape_vector.push_back(1);
788       }
789     }
790     auto scales_abstract = (*scales_node)->abstract();
791     CHECK_NULL_RETURN(scales_abstract);
792     scales_abstract->set_shape(std::make_shared<abstract::Shape>(shape_vector));
793     auto zps_abstract = (*zps_node)->abstract();
794     CHECK_NULL_RETURN(zps_abstract);
795     zps_abstract->set_shape(std::make_shared<abstract::Shape>(shape_vector));
796   }
797   return RET_OK;
798 }
799 
SetParallelStrategy(const CNodePtr & cnode,const std::vector<std::vector<int64_t>> & in_strategy)800 int InsertQuantNodeManager::SetParallelStrategy(const CNodePtr &cnode,
801                                                 const std::vector<std::vector<int64_t>> &in_strategy) {
802   auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
803   CHECK_NULL_RETURN(primitive);
804   primitive->AddAttr(IN_STRATEGY, MakeValue(in_strategy));
805   return RET_OK;
806 }
807 
GetAddMulNodeParallelStrategy(ShapeVector weight_shape,std::vector<int64_t> weight_strategy,int axis,bool per_channel)808 std::vector<std::vector<int64_t>> InsertQuantNodeManager::GetAddMulNodeParallelStrategy(
809   ShapeVector weight_shape, std::vector<int64_t> weight_strategy, int axis, bool per_channel) {
810   std::vector<std::vector<int64_t>> add_mul_in_strategy;
811   std::vector<int64_t> in_strategy_1 = weight_strategy;
812   add_mul_in_strategy.push_back(in_strategy_1);
813   std::vector<int64_t> in_strategy_2;
814 
815   // if perlayer quant, the input2 strategy is set to 1.
816   // if perchannel quant, the input2 strategy is set by axis, the axis dim is set by matmul input strategy,
817   // the other dim is set to 1.
818   if (per_channel) {
819     for (size_t i = 0; i < weight_shape.size(); i++) {
820       if (i == static_cast<size_t>(axis) && i < weight_strategy.size()) {
821         in_strategy_2.push_back(weight_strategy.at(i));
822       } else {
823         in_strategy_2.push_back(1);
824       }
825     }
826   } else {
827     in_strategy_2.push_back(1);
828   }
829 
830   add_mul_in_strategy.push_back(in_strategy_2);
831   return add_mul_in_strategy;
832 }
833 
InsertAscendAntiQuantNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,TypeId src_dtype,TypeId dst_dtype,int axis,const std::string & ascend_backend)834 int InsertQuantNodeManager::InsertAscendAntiQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
835                                                       size_t input_index, TypeId src_dtype, TypeId dst_dtype, int axis,
836                                                       const std::string &ascend_backend) {
837   auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
838   CHECK_NULL_RETURN(primitive);
839   MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
840   auto input_node = cnode->input(input_index);
841   auto manager = func_graph->manager();
842   CHECK_NULL_RETURN(manager);
843   std::vector<std::vector<int64_t>> cnode_in_strategy;
844   if (primitive->HasAttr(IN_STRATEGY)) {
845     cnode_in_strategy = ExtractStrategy(primitive->GetAttr(IN_STRATEGY));
846     CHECK_LESS_RETURN(cnode_in_strategy.size(), input_index);
847     MS_LOG(INFO) << "cnode: " << cnode->fullname_with_scope() << " in strategy is " << cnode_in_strategy;
848   }
849   if (!input_node->isa<mindspore::Parameter>()) {
850     MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
851     return RET_ERROR;
852   }
853 
854   // parameter+cast+add+mul+matmul
855   // parameter+gather+cast+add+mul
856   auto input_quant_params = quant::GetInputNodeQuantParam(cnode, input_index);
857   if (input_quant_params.empty()) {
858     MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << input_index << " quant param is empty.";
859     return RET_ERROR;
860   }
861 
862   // Insert cast node
863   CNodePtr cast_cnode = nullptr;
864   if (ascend_backend == "910b") {
865     MS_LOG(INFO) << "The ascend_backend is 910b, it will insert antiquant node";
866     if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
867       cast_cnode = NewAscendAntiQuantCNode(func_graph, cnode, dst_dtype);
868     } else {
869       cast_cnode = NewAscendAntiQuantCNode(func_graph, input_node, dst_dtype);
870     }
871   } else {
872     if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
873       cast_cnode = NewCastNode(func_graph, cnode, dst_dtype);
874     } else {
875       cast_cnode = NewCastNode(func_graph, input_node, dst_dtype);
876     }
877   }
878 
879   CHECK_NULL_RETURN(cast_cnode);
880   // cast node do not need to set parallel strategy, antiquant node need set parallel strategy
881   if (primitive->HasAttr(IN_STRATEGY) && ascend_backend == "910b") {
882     std::vector<std::vector<int64_t>> cast_in_strategy;
883     std::vector<int64_t> in_strategy_1 = cnode_in_strategy[input_index - kPrimOffset];
884     cast_in_strategy.push_back(in_strategy_1);
885     auto ret = SetParallelStrategy(cast_cnode, cast_in_strategy);
886     if (ret != RET_OK) {
887       MS_LOG(ERROR) << "Fail to set cnode parallel strategy, cnode: " << cast_cnode->fullname_with_scope();
888       return RET_ERROR;
889     }
890   }
891 
892   ParameterPtr scales_node;
893   ParameterPtr zps_node;
894   auto ret = CalculateScaleZPNode(func_graph, cnode, input_index, &scales_node, &zps_node, dst_dtype, axis);
895   if (ret != RET_OK) {
896     MS_LOG(ERROR) << "Fail to calculate scale & zero_point node: " << cnode->fullname_with_scope();
897     return RET_ERROR;
898   }
899 
900   auto add_cnode = NewAddNode(func_graph, cast_cnode, zps_node);
901   CHECK_NULL_RETURN(add_cnode);
902 
903   auto mul_cnode = NewMulNode(func_graph, add_cnode, scales_node);
904   CHECK_NULL_RETURN(mul_cnode);
905 
906   if (primitive->HasAttr(IN_STRATEGY)) {
907     ShapeVector weight_shape;
908     if (opt::FetchShapeFromAbstract(input_node->abstract(), &weight_shape) != lite::RET_OK) {
909       MS_LOG(ERROR) << "fetch shape failed." << input_node->fullname_with_scope();
910       return lite::RET_ERROR;
911     }
912     std::vector<int64_t> weight_strategy = cnode_in_strategy[input_index - kPrimOffset];
913     bool per_channel = input_quant_params.size() > 1;
914     auto add_mul_in_strategy = GetAddMulNodeParallelStrategy(weight_shape, weight_strategy, axis, per_channel);
915 
916     // add_cnode & mul_cnode set parallel strategy
917     ret = SetParallelStrategy(add_cnode, add_mul_in_strategy);
918     if (ret != RET_OK) {
919       MS_LOG(ERROR) << "Fail to set add cnode parallel strategy, cnode: " << add_cnode->fullname_with_scope();
920       return RET_ERROR;
921     }
922     ret = SetParallelStrategy(mul_cnode, add_mul_in_strategy);
923     if (ret != RET_OK) {
924       MS_LOG(ERROR) << "Fail to set mul cnode parallel strategy, cnode: " << mul_cnode->fullname_with_scope();
925       return RET_ERROR;
926     }
927   }
928 
929   auto node_map = manager->node_users();
930 
931   // Remove QuantParam
932   ret = RemoveInputNodeQuantParam(cnode, input_index);
933   if (ret != RET_OK) {
934     MS_LOG(ERROR) << "Fail to Remove node: " << input_node->fullname_with_scope() << " quant param";
935     return RET_ERROR;
936   }
937 
938   AnfNodeIndexSet node_user;
939   if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
940     node_user = node_map[cnode];
941   } else {
942     node_user = node_map[input_node];
943   }
944   for (const auto &user : node_user) {
945     manager->SetEdge(user.first, user.second, mul_cnode);
946   }
947   return RET_OK;
948 }
949 
InsertFSEDecodeNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,TypeId dst_dtype)950 int InsertQuantNodeManager::InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
951                                                 size_t input_index, TypeId dst_dtype) {
952   auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
953   if (primitive == nullptr) {
954     MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
955     return RET_ERROR;
956   }
957   MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
958   auto input_node = cnode->input(input_index);
959   if (!input_node->isa<mindspore::Parameter>()) {
960     MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
961     return RET_ERROR;
962   }
963   auto shape = input_node->Shape();
964   std::vector<AnfNodePtr> op_inputs;
965   int ret = CreateFSEInputs(func_graph, input_node, &op_inputs, dst_dtype);
966   if (ret != RET_OK) {
967     MS_LOG(ERROR) << "CreateFSEInputs failed.";
968     return RET_ERROR;
969   }
970 
971   auto fse_decode_cnode = func_graph->NewCNode(op_inputs);
972   CHECK_NULL_RETURN(fse_decode_cnode);
973   auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
974   int index = 0;
975   if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
976     index = 0;
977   }
978   const int fse_decode_offset = 20000;
979   fse_decode_cnode->set_fullname_with_scope(strings.at(0) + "-FSEDecode-op" +
980                                             std::to_string(index + fse_decode_offset));
981   CHECK_NULL_RETURN(cnode->abstract());
982   auto fse_abstract = cnode->abstract()->Clone();
983   fse_abstract->set_shape(shape);
984   fse_decode_cnode->set_abstract(fse_abstract);
985 
986   auto manager = func_graph->manager();
987   CHECK_NULL_RETURN(manager);
988   auto ret_bool = manager->Replace(input_node, fse_decode_cnode);
989   if (!ret_bool) {
990     MS_LOG(ERROR) << "Replace QuantDtypeCast failed.";
991     return RET_ERROR;
992   }
993 
994   return RET_OK;
995 }
996 
CreateFSEInputs(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,std::vector<AnfNodePtr> * op_inputs,TypeId dst_dtype)997 int InsertQuantNodeManager::CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
998                                             std::vector<AnfNodePtr> *op_inputs, TypeId dst_dtype) {
999   CHECK_NULL_RETURN(op_inputs);
1000   if (!input_node->isa<mindspore::Parameter>()) {
1001     MS_LOG(ERROR) << "FSEDecode input is not parameter node.";
1002     return RET_ERROR;
1003   }
1004   auto parameter_ptr = input_node->cast<ParameterPtr>();
1005   CHECK_NULL_RETURN(parameter_ptr);
1006   if (!parameter_ptr->has_default()) {
1007     MS_LOG(ERROR) << input_node->fullname_with_scope() << " parameter dont have default.";
1008     return RET_ERROR;
1009   }
1010   auto tensor = parameter_ptr->default_param()->cast<tensor::TensorPtr>();
1011   CHECK_NULL_RETURN(tensor);
1012   int8_t *data8 = reinterpret_cast<int8_t *>(tensor->data_c());
1013   size_t data_size = tensor->DataSize();
1014   FSEBuffer fse_buffer;
1015   auto ret = FSEDecoder::DecodeBuffer(data8, data_size, &fse_buffer);
1016   if (ret != RET_OK) {
1017     MS_LOG(ERROR) << input_node->fullname_with_scope() << " buffer decode failed.";
1018     return RET_ERROR;
1019   }
1020   ValueNodePtr new_primitive = NewFSEDecodePrimitive(dst_dtype, fse_buffer.curr_chunk, fse_buffer.curr_chunk_index,
1021                                                      fse_buffer.curr_bit_count, fse_buffer.table_log);
1022   op_inputs->push_back(new_primitive);
1023 
1024   // make shape to (1,chunk_size)
1025   ShapeVector shape_vector;
1026   shape_vector.push_back(1);
1027   shape_vector.push_back(fse_buffer.chunk_size);
1028   auto chunk_tensor_info =
1029     lite::CreateTensorInfo(fse_buffer.chunks, fse_buffer.chunk_size, shape_vector, kNumberTypeInt8);
1030   parameter_ptr->set_default_param(chunk_tensor_info);
1031   parameter_ptr->set_abstract(chunk_tensor_info->ToAbstract());
1032   op_inputs->push_back(input_node);
1033 
1034   size_t table_size = 1u << fse_buffer.table_log;
1035   std::vector<uint16_t> states_table(table_size);
1036   std::vector<uint8_t> bit_count_table(table_size);
1037   std::vector<uint16_t> symbol_table(table_size);
1038 
1039   ret = FSEDecoder::FSECreateStatesForDecoding(fse_buffer.frequency, fse_buffer.frequency_count, fse_buffer.table_log,
1040                                                states_table.data(), bit_count_table.data(), symbol_table.data());
1041   if (ret != RET_OK) {
1042     MS_LOG(ERROR) << "FSE create states for decoding failed.";
1043     return RET_ERROR;
1044   }
1045   std::vector<int64_t> shape = {static_cast<int64_t>(table_size)};
1046 
1047   auto states_table_tensor_info =
1048     lite::CreateTensorInfo(states_table.data(), sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
1049   auto states_table_node = opt::BuildParameterNode(func_graph, states_table_tensor_info, "states_table");
1050   op_inputs->push_back(states_table_node);
1051 
1052   auto bit_count_table_tensor_info =
1053     lite::CreateTensorInfo(bit_count_table.data(), sizeof(uint8_t) * table_size, shape, kNumberTypeUInt8);
1054   auto bit_count_table_node = opt::BuildParameterNode(func_graph, bit_count_table_tensor_info, "bit_count_table");
1055   op_inputs->push_back(bit_count_table_node);
1056 
1057   auto symbol_table_tensor_info =
1058     lite::CreateTensorInfo(symbol_table.data(), sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
1059   auto symbol_table_node = opt::BuildParameterNode(func_graph, symbol_table_tensor_info, "symbol_table");
1060   op_inputs->push_back(symbol_table_node);
1061 
1062   auto centroids_tensor_info =
1063     lite::CreateTensorInfo(fse_buffer.centroids, sizeof(float) * fse_buffer.centroid_size,
1064                            {static_cast<int64_t>(fse_buffer.centroid_size)}, kNumberTypeFloat32);
1065   auto centroids_node = opt::BuildParameterNode(func_graph, centroids_tensor_info, "centroids");
1066   op_inputs->push_back(centroids_node);
1067 
1068   auto shape_tensor_info = lite::CreateTensorInfo(ConvertShapeVectorToInt32(tensor->shape_c()).data(),
1069                                                   sizeof(int32_t) * tensor->shape_c().size(),
1070                                                   {static_cast<int64_t>(tensor->shape_c().size())}, kNumberTypeInt32);
1071   auto shape_node = opt::BuildParameterNode(func_graph, shape_tensor_info, "input_shape");
1072   op_inputs->push_back(shape_node);
1073 
1074   auto chunk_ends_tensor_info =
1075     lite::CreateTensorInfo(fse_buffer.chunk_ends, sizeof(uint64_t) * fse_buffer.chunk_ends_count,
1076                            {static_cast<int64_t>(fse_buffer.chunk_ends_count)}, kNumberTypeUInt64);
1077   auto chunk_ends_node = opt::BuildParameterNode(func_graph, chunk_ends_tensor_info, "chunk_ends");
1078   op_inputs->push_back(chunk_ends_node);
1079 
1080   return RET_OK;
1081 }
1082 
NewCastNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,int dst_type)1083 CNodePtr InsertQuantNodeManager::NewCastNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
1084                                              int dst_type) {
1085   auto prim_c = std::make_shared<ops::Cast>();
1086   MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1087   auto prim = prim_c->GetPrim();
1088   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1089   MS_LOG(INFO) << "dst_type:" << dst_type;
1090   TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
1091   prim->AddAttr(ops::kDstType, type_ptr);
1092   prim->AddAttr(ATTR_NO_NEED_CONSTANT_FOLDING, MakeValue(true));
1093   std::vector<AnfNodePtr> cast_op_inputs = {NewValueNode(prim), input_node};
1094   auto cast_cnode = func_graph->NewCNode(cast_op_inputs);
1095   cast_cnode->set_fullname_with_scope(input_node->fullname_with_scope() + "-Cast");
1096   cast_cnode->set_abstract(input_node->abstract()->Clone());
1097   auto ret = UpdateDataType(cast_cnode, TypeId(dst_type));
1098   if (ret != RET_OK) {
1099     MS_LOG(ERROR) << cast_cnode->fullname_with_scope() << " set dst_type " << dst_type << " failed.";
1100     return nullptr;
1101   }
1102   return cast_cnode;
1103 }
1104 
NewAscendAntiQuantCNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_node,int dst_type)1105 CNodePtr InsertQuantNodeManager::NewAscendAntiQuantCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
1106                                                          int dst_type) {
1107   auto dst_prim = std::make_shared<acl::AscendAntiQuant>();
1108   if (dst_prim == nullptr) {
1109     return nullptr;
1110   }
1111   dst_prim->AddAttr("scale", MakeValue(1.0f));
1112   dst_prim->AddAttr("offset", MakeValue(0.0f));
1113   MS_LOG(INFO) << "dst_type:" << dst_type;
1114   TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
1115   dst_prim->AddAttr(ops::kOutputDType, type_ptr);
1116   std::vector<AnfNodePtr> cast_op_inputs = {NewValueNode(dst_prim), input_node};
1117   auto anti_cnode = func_graph->NewCNode(cast_op_inputs);
1118   anti_cnode->set_fullname_with_scope(input_node->fullname_with_scope() + "-AntiQuant");
1119   anti_cnode->set_abstract(input_node->abstract()->Clone());
1120   anti_cnode->abstract()->set_type(type_ptr);
1121   auto ret = UpdateDataType(anti_cnode, TypeId(dst_type));
1122   if (ret != RET_OK) {
1123     MS_LOG(ERROR) << anti_cnode->fullname_with_scope() << " set dst_type " << dst_type << " failed.";
1124     return nullptr;
1125   }
1126   return anti_cnode;
1127 }
1128 
NewMulNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_1,const AnfNodePtr & input_2)1129 CNodePtr InsertQuantNodeManager::NewMulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_1,
1130                                             const AnfNodePtr &input_2) {
1131   auto prim_c = std::make_shared<ops::MulFusion>();
1132   MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1133   auto prim = prim_c->GetPrim();
1134   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1135   prim->AddAttr(ATTR_NO_NEED_CONSTANT_FOLDING, MakeValue(true));
1136   std::vector<AnfNodePtr> op_inputs = {NewValueNode(prim), input_1, input_2};
1137   auto cnode = func_graph->NewCNode(op_inputs);
1138   if (cnode == nullptr) {
1139     MS_LOG(ERROR) << "cnode is nullptr.";
1140     return nullptr;
1141   }
1142   cnode->set_fullname_with_scope(input_1->fullname_with_scope() + "-" + input_2->fullname_with_scope() + "-Mul");
1143   cnode->set_abstract(input_1->abstract()->Clone());
1144   return cnode;
1145 }
1146 
NewAddNode(const FuncGraphPtr & func_graph,const AnfNodePtr & input_1,const AnfNodePtr & input_2)1147 CNodePtr InsertQuantNodeManager::NewAddNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_1,
1148                                             const AnfNodePtr &input_2) {
1149   auto prim_c = std::make_shared<ops::AddFusion>();
1150   MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1151   auto prim = prim_c->GetPrim();
1152   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1153   prim->AddAttr(ATTR_NO_NEED_CONSTANT_FOLDING, MakeValue(true));
1154   std::vector<AnfNodePtr> op_inputs = {NewValueNode(prim), input_1, input_2};
1155   auto cnode = func_graph->NewCNode(op_inputs);
1156   if (cnode == nullptr) {
1157     MS_LOG(ERROR) << "cnode is nullptr.";
1158     return nullptr;
1159   }
1160   cnode->set_fullname_with_scope(input_1->fullname_with_scope() + "-" + input_2->fullname_with_scope() + "-Add");
1161   cnode->set_abstract(input_1->abstract()->Clone());
1162   return cnode;
1163 }
1164 
NewQuantCastPrimitive(int src_type,int dst_type,const std::vector<schema::QuantParamT> & input_quant_params,const std::vector<schema::QuantParamT> & output_quant_params,int axis,bool set_quant_flag)1165 ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(int src_type, int dst_type,
1166                                                            const std::vector<schema::QuantParamT> &input_quant_params,
1167                                                            const std::vector<schema::QuantParamT> &output_quant_params,
1168                                                            int axis, bool set_quant_flag) {
1169   auto prim_c = std::make_shared<ops::QuantDTypeCast>();
1170   MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1171   prim_c->Init(src_type, dst_type);
1172   prim_c->set_axis(axis);
1173   auto quant_params_holder = std::make_shared<QuantParamHolder>(input_quant_params.size(), output_quant_params.size());
1174   MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
1175   if (set_quant_flag) {
1176     quant_params_holder->set_quant_type(quant::QUANT_ALL);
1177   }
1178   quant_params_holder->set_input_quant_param(0, input_quant_params);
1179   quant_params_holder->set_output_quant_param(0, output_quant_params);
1180   auto prim = prim_c->GetPrim();
1181   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1182   prim->AddAttr("quant_params", quant_params_holder);
1183   return NewValueNode(prim);
1184 }
1185 
NewQuantCastPrimitive(int src_type,int dst_type,const AnfNodePtr & input_node,const std::vector<schema::QuantParamT> & output_quant_params,int axis,bool set_quant_flag)1186 ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(int src_type, int dst_type, const AnfNodePtr &input_node,
1187                                                            const std::vector<schema::QuantParamT> &output_quant_params,
1188                                                            int axis, bool set_quant_flag) {
1189   auto prim_c = std::make_shared<ops::QuantDTypeCast>();
1190   MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1191   prim_c->Init(src_type, dst_type);
1192   prim_c->set_axis(axis);
1193   auto prim = prim_c->GetPrim();
1194   if (set_quant_flag) {
1195     prim->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_ALL)));
1196   }
1197   // Set quant param to quant_cast_cnode
1198   if (!output_quant_params.empty()) {
1199     auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(output_quant_params);
1200     std::vector<ValuePtr> quantization_list = {quantization_ptr};
1201     auto quant_ptr = std::make_shared<ValueList>(quantization_list);
1202     MS_CHECK_TRUE_MSG(quant_ptr != nullptr, nullptr, "quant_ptr is nullptr.");
1203     prim->AddAttr(quant::kQuantParam, quant_ptr);
1204   } else {
1205     MS_LOG(WARNING) << "New quant cast node's output quant param is empty, input node: "
1206                     << input_node->fullname_with_scope();
1207   }
1208   return NewValueNode(prim);
1209 }
1210 
NewFSEDecodePrimitive(int dst_type,uint64_t curr_chunk,int64_t curr_chunk_index,int64_t curr_bit_count,int64_t table_log)1211 ValueNodePtr InsertQuantNodeManager::NewFSEDecodePrimitive(int dst_type, uint64_t curr_chunk, int64_t curr_chunk_index,
1212                                                            int64_t curr_bit_count, int64_t table_log) {
1213   auto prim_c = std::make_shared<ops::FSEDecode>();
1214   MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1215   prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);
1216 
1217   auto prim = prim_c->GetPrim();
1218   MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1219   return NewValueNode(prim);
1220 }
1221 
InsertAscendQuantNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)1222 int InsertQuantNodeManager::InsertAscendQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
1223   for (size_t i = 1; i < cnode->size(); i++) {
1224     if (cnode->input(i)->isa<CNode>() || IsGraphInput(cnode->input(i))) {
1225       auto ret = InsertAscendQuantNode(func_graph, cnode, i);
1226       if (ret != RET_OK) {
1227         MS_LOG(ERROR) << "InsertAscendQuantNode failed.";
1228         return ret;
1229       }
1230     }
1231   }
1232   return RET_OK;
1233 }
1234 
InsertAscendQuantNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index)1235 int InsertQuantNodeManager::InsertAscendQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
1236                                                   size_t input_index) {
1237   CHECK_NULL_RETURN(func_graph);
1238   CHECK_NULL_RETURN(cnode);
1239   auto x_q_param_origin = quant::GetInputNodeQuantParam(cnode, input_index);
1240   if (x_q_param_origin.empty()) {
1241     auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
1242     CHECK_NULL_RETURN(curr_quant_param_holder);
1243     auto input_quant_param = curr_quant_param_holder->get_input_quant_params();
1244     x_q_param_origin = input_quant_param.at(input_index - kPrimOffset);
1245   }
1246   if (x_q_param_origin.size() != kPerTensor) {
1247     MS_LOG(ERROR) << cnode->fullname_with_scope() << " x quant param size " << x_q_param_origin.size() << " != 1";
1248     return RET_ERROR;
1249   }
1250   auto x_q_param = quant::CloneQuantParam(x_q_param_origin);
1251   x_q_param.at(0).scale = 1 / x_q_param.at(0).scale;
1252   auto input_node = cnode->input(input_index);
1253   CHECK_NULL_RETURN(input_node);
1254   ValueNodePtr new_primitive = NewQuantCastPrimitive(kNumberTypeFloat32, kNumberTypeInt8, input_node, x_q_param);
1255   std::vector<AnfNodePtr> op_inputs = {new_primitive, cnode->input(input_index)};
1256   auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
1257   CHECK_NULL_RETURN(quant_cast_cnode);
1258   quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "-quant-" + std::to_string(input_index));
1259   // set abstract
1260   if (cnode->input(input_index)->abstract() != nullptr) {
1261     auto abstract = cnode->input(input_index)->abstract()->Clone();
1262     quant_cast_cnode->set_abstract(abstract);
1263     if (quant::UpdateDataType(quant_cast_cnode, kNumberTypeInt8) != RET_OK) {
1264       MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
1265       return RET_ERROR;
1266     }
1267   } else {
1268     MS_LOG(ERROR) << "input node abstract nullptr, input node name: " << cnode->fullname_with_scope();
1269     return RET_ERROR;
1270   }
1271   auto manager = func_graph->manager();
1272   if (manager == nullptr) {
1273     manager = Manage(func_graph, true);
1274   }
1275   CHECK_NULL_RETURN(manager);
1276   manager->SetEdge(cnode, input_index, quant_cast_cnode);
1277   MS_LOG(INFO) << cnode->fullname_with_scope() << " Insert Ascend QuantNode, scale: " << x_q_param.at(0).scale;
1278   return RET_OK;
1279 }
1280 
InsertAscendDeQuantNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)1281 int InsertQuantNodeManager::InsertAscendDeQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
1282   CHECK_NULL_RETURN(func_graph);
1283   CHECK_NULL_RETURN(cnode);
1284   auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(kPrimIndex));
1285   if (cnode_primitive == nullptr) {
1286     MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is nullptr.";
1287     return RET_ERROR;
1288   }
1289   auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
1290   CHECK_NULL_RETURN(curr_quant_param_holder);
1291   auto input_quant_param = curr_quant_param_holder->get_input_quant_params();
1292   auto x_q_param = quant::GetInputNodeQuantParam(cnode, Index0 + kPrimOffset);
1293   if (x_q_param.empty()) {
1294     x_q_param = input_quant_param.at(Index0);
1295   }
1296   if (x_q_param.size() != kPerTensor) {
1297     MS_LOG(ERROR) << cnode->fullname_with_scope() << " x quant param size " << x_q_param.size() << " != 1";
1298     return RET_ERROR;
1299   }
1300   auto w_q_params = quant::GetInputNodeQuantParam(cnode, Index1 + kPrimOffset);
1301   if (w_q_params.empty()) {
1302     w_q_params = input_quant_param.at(Index1);
1303   }
1304   if (w_q_params.empty()) {
1305     MS_LOG(ERROR) << cnode->fullname_with_scope() << " w quant param is empty.";
1306     return RET_ERROR;
1307   }
1308   MS_LOG(INFO) << cnode->fullname_with_scope() << " x scale:" << x_q_param.at(0).scale
1309                << " w scale size:" << w_q_params.size();
1310   std::vector<uint64_t> deq_scales(w_q_params.size());
1311   for (size_t i = 0; i < w_q_params.size(); ++i) {
1312     float float32_deq_scale = static_cast<float>(x_q_param.at(0).scale * w_q_params.at(i).scale);
1313     void *ptr = &float32_deq_scale;
1314     uint32_t *uint32_deq_scale = reinterpret_cast<uint32_t *>(ptr);
1315     uint64_t u64_deq_scale = 0;
1316     u64_deq_scale |= *uint32_deq_scale;
1317     deq_scales[i] = u64_deq_scale;
1318   }
1319   auto dtype = kNumberTypeFloat32;
1320   if (cnode->HasAttr("origin_type")) {
1321     auto value = cnode->GetAttr("origin_type");
1322     dtype = static_cast<TypeId>(opt::CastToInt(value).front());
1323   }
1324   auto prim_c = std::make_shared<ops::QuantDTypeCast>();
1325   CHECK_NULL_RETURN(prim_c);
1326 
1327   prim_c->Init(kNumberTypeInt32, dtype);
1328   auto prim = prim_c->GetPrim();
1329   // copy cnode quant param to dequant
1330   if (cnode_primitive->HasAttr(quant::kQuantParam)) {
1331     prim->AddAttr(quant::kQuantParam, cnode_primitive->GetAttr(quant::kQuantParam));
1332   }
1333   auto quant_dtype_cast_primitive = NewValueNode(prim);
1334   std::vector<AnfNodePtr> op_inputs;
1335   op_inputs.push_back(quant_dtype_cast_primitive);
1336   op_inputs.push_back(cnode);
1337   auto deq_scales_tensor_info = lite::CreateTensorInfo(deq_scales.data(), sizeof(uint64_t) * deq_scales.size(),
1338                                                        {static_cast<int64_t>(deq_scales.size())}, kNumberTypeUInt64);
1339   auto deq_scales_node =
1340     opt::BuildParameterNode(func_graph, deq_scales_tensor_info, cnode->fullname_with_scope() + "-deq_scales");
1341   op_inputs.push_back(deq_scales_node);
1342 
1343   auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
1344   CHECK_NULL_RETURN(quant_cast_cnode);
1345   quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "-dequant");
1346   // set abstract
1347   if (cnode->abstract() != nullptr) {
1348     auto abstract = cnode->abstract()->Clone();
1349     quant_cast_cnode->set_abstract(abstract);
1350     if (quant::UpdateDataType(quant_cast_cnode, dtype) != RET_OK) {
1351       MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
1352       return RET_ERROR;
1353     }
1354   } else {
1355     MS_LOG(ERROR) << "input node abstract nullptr, input node name: " << cnode->fullname_with_scope();
1356     return RET_ERROR;
1357   }
1358 
1359   auto manager = func_graph->manager();
1360   if (manager == nullptr) {
1361     manager = Manage(func_graph, true);
1362   }
1363   CHECK_NULL_RETURN(manager);
1364   auto node_users = manager->node_users()[cnode];
1365   for (auto &node_user : node_users) {
1366     manager->SetEdge(node_user.first, node_user.second, quant_cast_cnode);
1367   }
1368   MS_LOG(INFO) << cnode->fullname_with_scope() << " Insert Ascend DeQuant Node.";
1369   return RET_OK;
1370 }
1371 
AdjustTransposeNodeForSingleMatMulNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode)1372 int InsertQuantNodeManager::AdjustTransposeNodeForSingleMatMulNode(const FuncGraphPtr &func_graph,
1373                                                                    const CNodePtr &cnode) {
1374   const std::set<PrimitivePtr> support_transpose_types = {prim::kPrimMatMulFusion, prim::kPrimMatMul,
1375                                                           prim::kPrimBatchMatMul};
1376   if (!CheckNodeInSet(cnode, support_transpose_types)) {
1377     return RET_OK;
1378   }
1379   auto prim_ptr = GetCNodePrimitive(cnode);
1380   CHECK_NULL_RETURN(prim_ptr);
1381 
1382   auto transpose_a = prim_ptr->GetAttr(mindspore::ops::kTransposeA);
1383   auto transpose_b = prim_ptr->GetAttr(mindspore::ops::kTransposeB);
1384 
1385   if (transpose_a != nullptr && GetValue<bool>(transpose_a)) {
1386     MS_LOG(ERROR) << cnode->fullname_with_scope() << " transposeA is true.";
1387     return RET_ERROR;
1388   }
1389   if (transpose_b != nullptr && GetValue<bool>(transpose_b)) {
1390     int ret = RET_ERROR;
1391     MS_LOG(INFO) << cnode->fullname_with_scope() << ":" << cnode->input(kWeightIndex + kPrimOffset)->type_name();
1392     if (cnode->input(kWeightIndex + kPrimOffset)->isa<CNode>()) {
1393       return RET_OK;
1394     } else if (cnode->input(kWeightIndex + kPrimOffset)->isa<Parameter>()) {
1395       auto manager = Manage(func_graph);
1396       CHECK_NULL_RETURN(manager);
1397       auto weight_input = cnode->input(kWeightIndex + 1);
1398       auto dst_prim = GetCNodePrimitive(cnode);
1399       MS_LOG(INFO) << cnode->fullname_with_scope() << " transpose_b is true.";
1400       dst_prim->AddAttr(mindspore::ops::kTransposeB, MakeValue(false));
1401       ParameterPtr param_node;
1402       tensor::TensorPtr tensor_info;
1403       GetParameterAndTensor(weight_input, &param_node, &tensor_info);
1404       if (tensor_info->shape_c().size() == DIMENSION_3D) {
1405         MS_LOG(INFO) << weight_input->fullname_with_scope() << " shape is " << tensor_info->shape_c()
1406                      << " will not do transpose";
1407         return RET_OK;
1408       }
1409       if (tensor_info->shape_c().size() != DIMENSION_2D) {
1410         MS_LOG(ERROR) << weight_input->fullname_with_scope() << " shape is " << tensor_info->shape_c()
1411                       << " is large than 2.";
1412         return RET_ERROR;
1413       }
1414 
1415       if (tensor_info->data_type_c() == kNumberTypeFloat32) {
1416         ret = TransposeData<float>(param_node, tensor_info);
1417       } else if (tensor_info->data_type_c() == kNumberTypeFloat16) {
1418         ret = TransposeData<Float16>(param_node, tensor_info);
1419       } else {
1420         MS_LOG(ERROR) << "transpose data only support Float32 or Float16.";
1421         return RET_OK;
1422       }
1423 
1424       if (ret != RET_OK) {
1425         MS_LOG(ERROR) << weight_input->fullname_with_scope() << " transposeData failed.";
1426         return ret;
1427       }
1428     } else {
1429       MS_LOG(ERROR) << "Dont support type is " << cnode->input(kWeightIndex + kPrimOffset)->type_name();
1430       return RET_ERROR;
1431     }
1432   }
1433   return RET_OK;
1434 }
1435 
AdjustTransposeNodeForMatMul(const FuncGraphPtr & func_graph)1436 int InsertQuantNodeManager::AdjustTransposeNodeForMatMul(const FuncGraphPtr &func_graph) {
1437   auto cnodes = func_graph->GetOrderedCnodes();
1438   for (auto &cnode : cnodes) {
1439     auto ret = AdjustTransposeNodeForSingleMatMulNode(func_graph, cnode);
1440     if (ret != RET_OK) {
1441       MS_LOG(ERROR) << cnode->fullname_with_scope() << " Adjust Transpose Node failed.";
1442       return ret;
1443     }
1444   }
1445   return RET_OK;
1446 }
1447 
InsertTransposeNode(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t index)1448 int InsertQuantNodeManager::InsertTransposeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index) {
1449   auto prim_ptr = GetCNodePrimitive(cnode);
1450   CHECK_NULL_RETURN(prim_ptr);
1451   std::vector<int> perm;
1452   ShapeVector shape;
1453   auto ret = opt::FetchShapeFromAbstract(cnode->input(index)->abstract(), &shape);
1454   if (ret != RET_OK) {
1455     MS_LOG(ERROR) << "Fetch shape from abstract failed.";
1456     return RET_OK;
1457   }
1458   if (shape.size() == DIMENSION_2D) {
1459     perm = {1, 0};
1460   } else if (shape.size() == DIMENSION_3D) {
1461     perm = {0, 2, 1};
1462   } else if (shape.size() == DIMENSION_4D) {
1463     perm = {0, 1, 3, 2};
1464   } else {
1465     MS_LOG(ERROR) << shape.size() << " is invalid.";
1466     return RET_ERROR;
1467   }
1468   auto transpose = opt::GenTransposeNode(func_graph, cnode->input(index), perm,
1469                                          cnode->input(index)->fullname_with_scope() + "-transpose");
1470   auto manager = Manage(func_graph);
1471   MS_ASSERT(manager != nullptr);
1472   manager->SetEdge(cnode, kWeightIndex + kPrimOffset, transpose);
1473   prim_ptr->set_attr(mindspore::ops::kTransposeB, MakeValue(false));
1474   return RET_OK;
1475 }
1476 }  // namespace mindspore::lite::quant
1477