1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18
19 #include "mindspore/lite/tools/converter/quantizer/insert_quant_node_manager.h"
20 #include <memory>
21 #include <set>
22 #include <vector>
23 #include <string>
24 #include <algorithm>
25 #include "mindspore/core/ops/math_ops.h"
26 #include "mindspore/core/ops/lite_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "tools/optimizer/graph/node_infershape.h"
29 #include "tools/optimizer/common/gllo_utils.h"
30 #include "tools/optimizer/common/format_utils.h"
31 #include "tools/common/node_util.h"
32 #include "tools/common/tensor_util.h"
33 #include "tools/converter/quantizer/fse_decoder.h"
34 #include "tools/converter/adapter/acl/mapper/tbe_op_def.h"
35 #include "ops/fse_decode.h"
36 #include "ops/op_name.h"
37 #include "ops/auto_generate/gen_lite_ops.h"
38 #include "ops/fusion/mul_fusion.h"
39 #include "ops/fusion/add_fusion.h"
40 #include "ops/fusion/mat_mul_fusion.h"
41 #include "ops/array_ops.h"
42 #include "ir/dtype.h"
43
44 namespace mindspore::lite::quant {
45 namespace {
46 constexpr size_t kMinSize2 = 2;
47 constexpr size_t kMinSize3 = 3;
48 constexpr size_t kTableExtend = 3;
49 constexpr size_t kAlignOffset = 7;
50 constexpr size_t kInt32Mask = 31;
51 constexpr int kLastFisrtIndex = -1;
52 constexpr int kLastSecondIndex = -2;
53 const char *ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding";
54 constexpr char IN_STRATEGY[] = "in_strategy";
55 } // namespace
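// Set the cast node's abstract by cloning from cnode when available, otherwise from input_node.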
56 int InsertQuantNodeManager::SetCastNodeAbstract(const CNodePtr &cnode, const AnfNodePtr &input_node,
57 const CNodePtr &cast_cnode) {
58 CHECK_NULL_RETURN(cnode);
59 CHECK_NULL_RETURN(input_node);
60 CHECK_NULL_RETURN(cast_cnode);
61
62 AbstractBasePtr abstract;
63 if (cnode->abstract() != nullptr) {
64 abstract = cnode->abstract()->Clone();
65 } else if (input_node->abstract() != nullptr) {
66 abstract = input_node->abstract()->Clone();
67 } else {
68 MS_LOG(ERROR) << "Abstract is nullptr, cnode name: " << cnode->fullname_with_scope()
69 << " input node: " << input_node->fullname_with_scope();
70 return RET_NULL_PTR;
71 }
72 cast_cnode->set_abstract(abstract);
73 return RET_OK;
74 }
75
76 // If dtype can be fetched, check data type, otherwise return RET_OK
77 int InsertQuantNodeManager::CheckDataType(const AnfNodePtr &input_node, TypeId check_type_id) const {
78 bool is_graph_input = IsGraphInput(input_node);
79 if (!input_node->isa<mindspore::CNode>() && !is_graph_input) {
80 return RET_NO_CHANGE;
81 }
82 bool is_special_node =
83 input_node->isa<mindspore::CNode>() && opt::IsSpecialType(input_node->cast<mindspore::CNodePtr>());
84 if (!is_special_node || is_graph_input) {
85 TypeId type_id;
86 auto ret = opt::GetDataTypeFromAnfNode(input_node, &type_id);
87 if (ret != RET_OK) {
88 MS_LOG(WARNING) << "Fetch DataType from cnode failed.";
89 return RET_OK;
90 }
91 if (type_id != check_type_id) {
92 return RET_NO_CHANGE;
93 }
94 }
95 return RET_OK;
96 }
97
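// Create a DynamicQuant cnode in front of cnode->input(index): configure the dst type, symmetric mode and
// per-channel activation quant (with prefer axes for MatMul), clone the abstract with the input's shape,
// mark the new node QUANT_DYNAMIC, and rewire it as the new input of cnode.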
98 int InsertQuantNodeManager::InsertDynamicQuantWithIndex(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index,
99 bool activation_channel) {
100 auto primitive = std::make_shared<ops::DynamicQuant>();
101 CHECK_NULL_RETURN(primitive);
102 auto primitive_c = primitive->GetPrim();
103 primitive->set_dst_type(dst_type_);
104 bool symmetric = activation_channel;
105 primitive->set_symmetric(symmetric);
106 primitive->set_activation_channel(activation_channel);
107 if (activation_channel && SetPreferAxes(cnode, index, primitive) != RET_OK) {
108 MS_LOG(ERROR) << "Set prefer axis failed, " << cnode->fullname_with_scope();
109 return RET_ERROR;
110 }
111 auto dynamic_quant_cnode = graph->NewCNode(primitive_c, {cnode->input(index)});
112 CHECK_NULL_RETURN(dynamic_quant_cnode);
113 auto name = cnode->fullname_with_scope() + "_dynamic_cast_node_" + std::to_string(index);
114 dynamic_quant_cnode->set_fullname_with_scope(name);
115 CHECK_NULL_RETURN(cnode->abstract());
116 auto abstract = cnode->abstract()->Clone();
117 if (abstract == nullptr) {
118 MS_LOG(ERROR) << "Abstract of node is nullptr, " << cnode->fullname_with_scope();
119 return RET_NULL_PTR;
120 }
121 dynamic_quant_cnode->set_abstract(abstract);
122 abstract->set_shape(cnode->input(index)->Shape());
123 auto ret = UpdateDataType(dynamic_quant_cnode, dst_type_);
124 if (ret != RET_OK) {
125 MS_LOG(ERROR) << cnode->fullname_with_scope() << " set new dtype failed.";
126 return ret;
127 }
128 ret = MarkDynamicQuantize(dynamic_quant_cnode);
129 if (ret != RET_OK) {
130 MS_LOG(ERROR) << cnode->fullname_with_scope() << " mark quant type failed.";
131 return ret;
132 }
133 cnode->set_input(index, dynamic_quant_cnode);
134 return RET_OK;
135 }
136
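// For MatMul/MatMulFusion, the preferred quantization axes are all batch dimensions plus one operand axis:
//   input A: -2 (rows) when transpose_a is false, otherwise -1
//   input B: -1 (cols) when transpose_b is false, otherwise -2
// Other node types do not need a prefer axis and are left unchanged.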
137 int InsertQuantNodeManager::SetPreferAxes(const CNodePtr &cnode, size_t index,
138 const std::shared_ptr<ops::DynamicQuant> &dynamic_primitive) {
139 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
140 if (primitive->name() == ops::kNameMatMulFusion || primitive->name() == ops::kNameMatMul) {
141 auto matmul_prim = api::MakeShared<ops::MatMul>(primitive);
142 CHECK_NULL_RETURN(matmul_prim);
143 auto shape = opt::GetAnfNodeOutputShape(cnode->input(index), 0);
144 std::vector<int> prefer_axes;
145 for (int i = 0; i < static_cast<int>(shape.size()) - C2NUM; ++i) {
146 prefer_axes.push_back(i);
147 }
148 // For MatMul A
149 if (index == kInputIndex + kPrimOffset) {
150 if (matmul_prim->GetAttr(ops::kTransposeA) != nullptr && matmul_prim->get_transpose_a()) {
151 prefer_axes.push_back(kLastFisrtIndex);
152 dynamic_primitive->set_prefer_axis(kLastFisrtIndex);
153 dynamic_primitive->set_transpose(true);
154 } else {
155 prefer_axes.push_back(kLastSecondIndex);
156 dynamic_primitive->set_prefer_axis(kLastSecondIndex);
157 dynamic_primitive->set_transpose(false);
158 }
159 }
160 // For MatMul B
161 if (index == kWeightIndex + kPrimOffset) {
162 if (matmul_prim->GetAttr(ops::kTransposeB) != nullptr && matmul_prim->get_transpose_b()) {
163 prefer_axes.push_back(kLastSecondIndex);
164 dynamic_primitive->set_prefer_axis(kLastSecondIndex);
165 dynamic_primitive->set_transpose(true);
166 } else {
167 prefer_axes.push_back(kLastFisrtIndex);
168 dynamic_primitive->set_prefer_axis(kLastFisrtIndex);
169 dynamic_primitive->set_transpose(false);
170 }
171 }
172 dynamic_primitive->set_prefer_axes(prefer_axes);
173 } else {
174 MS_LOG(WARNING) << "cnode does not need a prefer axis, cnode name: " << cnode->fullname_with_scope();
175 }
176 return RET_OK;
177 }
178
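// Insert dynamic quant nodes for the activation and weight inputs of a cnode. Per-channel activation quant
// is not supported when both inputs are non-const (CNode or graph input); in that case RET_NOT_SUPPORT is
// returned and the caller skips the node.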
179 int InsertQuantNodeManager::NewDynamicQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
180 bool activation_channel) {
181 auto op_name = cnode->fullname_with_scope();
182 if (cnode->size() < kMinSize3) {
183 MS_LOG(ERROR) << op_name << " cnode size:" << cnode->size() << " < 3.";
184 return RET_ERROR;
185 }
186 auto input = cnode->input(kInputIndex + kPrimOffset);
187 auto weight = cnode->input(kWeightIndex + kPrimOffset);
188 if (activation_channel && (input->isa<mindspore::CNode>() || IsGraphInput(input)) &&
189 (weight->isa<mindspore::CNode>() || IsGraphInput(weight))) {
190 return RET_NOT_SUPPORT;
191 }
192 if (input->isa<mindspore::CNode>() || IsGraphInput(input)) {
193 auto ret = InsertDynamicQuantWithIndex(graph, cnode, kInputIndex + kPrimOffset, activation_channel);
194 if (ret != RET_OK) {
195 MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
196 }
197 }
198 if (weight->isa<mindspore::CNode>() || IsGraphInput(weight)) {
199 auto ret = InsertDynamicQuantWithIndex(graph, cnode, kWeightIndex + kPrimOffset, activation_channel);
200 if (ret != RET_OK) {
201 MS_LOG(ERROR) << "Insert dynamic quant with index failed.";
202 }
203 }
204 return RET_OK;
205 }
206
207 int InsertQuantNodeManager::MarkDynamicQuantize(const CNodePtr &cnode) {
208 CHECK_NULL_RETURN(cnode);
209 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
210 CHECK_NULL_RETURN(primitive);
211 auto quant_param_holder = GetCNodeQuantHolder(primitive);
212 quant_param_holder->set_quant_type(quant::QUANT_DYNAMIC);
213 return RET_OK;
214 }
215
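// Traverse all cnodes and insert dynamic quant nodes: skip user-excluded nodes, non-float32 nodes, special
// ops and unsupported op types; successfully handled nodes are marked QUANT_DYNAMIC and keep a float32
// output dtype.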
216 int InsertQuantNodeManager::InsertDynamicQuantNode(const FuncGraphPtr &graph,
217 const std::set<PrimitivePtr> &support_dynamic_quant_ops,
218 const std::set<std::string> &skip_quant_node,
219 bool activation_channel) {
220 CHECK_NULL_RETURN(graph);
221 auto cnodes = graph->GetOrderedCnodes();
222 for (auto &cnode : cnodes) {
223 auto op_name = cnode->fullname_with_scope();
224 if (skip_quant_node.find(op_name) != skip_quant_node.end()) {
225 MS_LOG(INFO) << op_name << " is skipped for dynamic quant.";
226 continue;
227 }
228 auto ret = CheckDataType(cnode, kNumberTypeFloat32);
229 if (ret == RET_NO_CHANGE) {
230 continue;
231 }
232 if (opt::IsSpecialType(cnode)) {
233 continue;
234 }
235 auto is_support_node = CheckNodeInSet(cnode, support_dynamic_quant_ops);
236 if (!is_support_node) {
237 auto type = NodePrimitiveType(cnode);
238 MS_LOG(INFO) << "node:" << op_name << " type:" << type << " will not be quantized.";
239 continue;
240 }
241 ret = NewDynamicQuantNode(graph, cnode, activation_channel);
242 if (ret == RET_NOT_SUPPORT) {
243 continue;
244 }
245 if (ret != RET_OK) {
246 MS_LOG(ERROR) << "node:" << op_name << " new dynamic quant node failed.";
247 return ret;
248 }
249 ret = MarkDynamicQuantize(cnode);
250 if (ret != RET_OK) {
251 MS_LOG(ERROR) << "node:" << op_name << " mark dynamic quant node failed.";
252 return ret;
253 }
254 ret = UpdateDataType(cnode, kNumberTypeFloat32);
255 if (ret != RET_OK) {
256 MS_LOG(ERROR) << "node:" << op_name << " update datatype failed.";
257 return ret;
258 }
259 }
260 return RET_OK;
261 }
262
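// For every fully quantized (QUANT_ALL) cnode, insert forward quant casts on its inputs and backward
// dequant casts on its outputs; ops in kUint8toFP32Operator already produce float32 and skip the backward cast.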
263 int InsertQuantNodeManager::InsertDequantNode(const FuncGraphPtr &graph) {
264 CHECK_NULL_RETURN(graph);
265 auto cnodes = graph->GetOrderedCnodes();
266 for (auto &cnode : cnodes) {
267 quant::QuantType curr_quant_type;
268 if (GetQuantType(cnode, &curr_quant_type) != RET_OK) {
269 MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
270 return RET_ERROR;
271 }
272 if (curr_quant_type != quant::QUANT_ALL) {
273 MS_LOG(INFO) << "Invalid cnode quant type, cnode name: " << cnode->fullname_with_scope()
274 << " quant type: " << curr_quant_type;
275 continue;
276 }
277 auto status = InsertForwardCastNode(graph, cnode, kNumberTypeFloat32, curr_quant_type);
278 if (status != RET_OK) {
279 MS_LOG(ERROR) << "InsertForwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
280 return status;
281 }
282 // DetectionPostProcess op (Uint8 to Fp32) does not need a backward cast node
283 if (!CheckNodeInSet(cnode, kUint8toFP32Operator)) {
284 status = InsertBackwardCastNode(graph, cnode, kNumberTypeFloat32, curr_quant_type);
285 if (status != RET_OK) {
286 MS_LOG(ERROR) << "InsertBackwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
287 return status;
288 }
289 }
290 } // for
291 return RET_OK;
292 }
293
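// Dispatch helper: FORWARD inserts a quant/dequant cast in front of cnode->input(index); BACKWARD with
// kDeQuant inserts a dequant cast between cnode and output_node.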
294 int InsertQuantNodeManager::InsertQuantDtypeCastNodeNew(const FuncGraphPtr &graph, const CNodePtr &cnode,
295 InsertDirection insert_direction, TypeId cast_dtype,
296 CastNodeType cast_node_type, size_t index,
297 const AnfNodePtr &output_node) {
298 CHECK_NULL_RETURN(graph);
299 CHECK_NULL_RETURN(cnode);
300 if (insert_direction == FORWARD) {
301 return InsertForwardQuantNodeNew(graph, cnode, cast_dtype, index, cast_node_type);
302 } else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
303 return InsertBackwardDeQuantNode(graph, cnode, cast_dtype, index, output_node);
304 }
305 MS_LOG(ERROR) << "Invalid insert direction: " << insert_direction;
306 return RET_NOT_SUPPORT;
307 }
308
309 int InsertQuantNodeManager::InsertQuantDtypeCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
310 InsertDirection insert_direction, TypeId cast_dtype,
311 CastNodeType cast_node_type, size_t index,
312 const AnfNodePtr &output_node) {
313 CHECK_NULL_RETURN(graph);
314 CHECK_NULL_RETURN(cnode);
315 if (insert_direction == FORWARD) {
316 return InsertForwardQuantNode(graph, cnode, cast_dtype, index, cast_node_type);
317 } else if (insert_direction == BACKWARD && cast_node_type == kDeQuant) {
318 return InsertBackwardDeQuantNode(graph, cnode, cast_dtype, index, output_node);
319 }
320 MS_LOG(ERROR) << "Invalid insert direction: " << insert_direction;
321 return RET_NOT_SUPPORT;
322 }
323
324 int InsertQuantNodeManager::InsertForwardQuantNodeNew(const FuncGraphPtr &graph, const CNodePtr &cnode,
325 TypeId cast_dtype, size_t index, CastNodeType cast_node_type) {
326 if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
327 MS_LOG(ERROR) << "Invalid cast dtype: " << cast_dtype;
328 return RET_NOT_SUPPORT;
329 }
330
331 auto input_node = cnode->input(index);
332 CHECK_NULL_RETURN(input_node);
333 if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
334 MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope();
335 return RET_ERROR;
336 }
337 if (CheckDataType(input_node, cast_dtype) != RET_OK) {
338 return RET_NO_CHANGE;
339 }
340 // insert forward cast_node
341 TypeId src_dtype;
342 TypeId dst_dtype;
343 std::vector<schema::QuantParamT> cast_input_quant_params;
344 std::vector<schema::QuantParamT> cast_output_quant_params;
345 if (cast_node_type == kQuant) {
346 src_dtype = cast_dtype;
347 dst_dtype = kNumberTypeInt8;
348 cast_output_quant_params = quant::GetInputNodeQuantParam(cnode, index);
349 std::copy(cast_output_quant_params.cbegin(), cast_output_quant_params.cend(),
350 std::back_inserter(cast_input_quant_params));
351 // Uint8toInt8
352 if (src_dtype == kNumberTypeUInt8) {
353 for (auto &quant_param : cast_input_quant_params) {
354 quant_param.zeroPoint += kU8ZeroPointOffset;
355 }
356 }
357 } else {
358 src_dtype = kNumberTypeInt8;
359 dst_dtype = cast_dtype;
360 auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
361 auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<mindspore::Primitive>>(input_cnode->input(0));
362 if (input_cnode_primitive_c == nullptr) {
363 MS_LOG(DEBUG) << "input: " << index << " " << input_cnode->fullname_with_scope() << ": "
364 << " PrimitiveC is null";
365 return RET_NO_CHANGE;
366 }
367 auto quantization_param_value = input_cnode_primitive_c->GetAttr(quant::kQuantParam);
368 MS_CHECK_TRUE_MSG(quantization_param_value != nullptr, RET_ERROR, "quantization_param_value is nullptr.");
369 auto quantization_param_list = GetValue<std::vector<QuantizationParamPtr>>(quantization_param_value);
370 if (quantization_param_list.empty()) {
371 MS_LOG(ERROR) << input_node->fullname_with_scope() << " quantization param does not exist.";
372 return RET_ERROR;
373 }
374 cast_input_quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_param_list.front());
375 std::copy(cast_input_quant_params.cbegin(), cast_input_quant_params.cend(),
376 std::back_inserter(cast_output_quant_params));
377 }
378 ValueNodePtr new_primitive =
379 NewQuantCastPrimitive(src_dtype, dst_dtype, input_node, cast_output_quant_params, 0, true);
380 CHECK_NULL_RETURN(new_primitive);
381 std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
382 auto quant_cast_cnode = graph->NewCNode(op_inputs);
383 CHECK_NULL_RETURN(quant_cast_cnode);
384 quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
385 "_pre");
386 // set abstract
387 if (input_node->abstract() != nullptr) {
388 auto abstract = input_node->abstract()->Clone();
389 quant_cast_cnode->set_abstract(abstract);
390 if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
391 MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
392 return RET_ERROR;
393 }
394 } else {
395 MS_LOG(INFO) << "input node abstract nullptr, input node name: " << input_node->fullname_with_scope();
396 }
397 auto manager = graph->manager();
398 if (manager == nullptr) {
399 manager = Manage(graph, true);
400 }
401 CHECK_NULL_RETURN(manager);
402 manager->SetEdge(cnode, index, quant_cast_cnode);
403 MS_LOG(INFO) << "InsertForwardQuantNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
404 << " dst_type: " << dst_dtype;
405 return RET_OK;
406 }
407
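// Legacy variant of InsertForwardQuantNodeNew: quant params are taken from the cnode's QuantParamHolder
// ("quant_params" attribute) instead of the kQuantParam quantization attribute.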
408 int InsertQuantNodeManager::InsertForwardQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
409 size_t index, CastNodeType cast_node_type) {
410 if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
411 MS_LOG(ERROR) << "Invalid cast dtype: " << cast_dtype;
412 return RET_NOT_SUPPORT;
413 }
414
415 auto input_node = cnode->input(index);
416 CHECK_NULL_RETURN(input_node);
417 if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
418 MS_LOG(ERROR) << "Invalid input node, input node name: " << input_node->fullname_with_scope();
419 return RET_ERROR;
420 }
421 if (CheckDataType(input_node, cast_dtype) != RET_OK) {
422 return RET_NO_CHANGE;
423 }
424 // insert forward cast_node
425 TypeId src_dtype;
426 TypeId dst_dtype;
427 std::vector<schema::QuantParamT> input_quant_params;
428 std::vector<schema::QuantParamT> output_quant_params;
429 if (cast_node_type == kQuant) {
430 src_dtype = cast_dtype;
431 dst_dtype = kNumberTypeInt8;
432 auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
433 CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
434 if (curr_primitive_quant_param_holder->get_input_quant_params().size() < index) {
435 MS_LOG(ERROR) << "quant param is invalid.";
436 return RET_ERROR;
437 }
438 output_quant_params = curr_primitive_quant_param_holder->get_input_quant_params()[index - 1];
439 std::copy(output_quant_params.cbegin(), output_quant_params.cend(), std::back_inserter(input_quant_params));
440 // Uint8toInt8
441 if (src_dtype == kNumberTypeUInt8) {
442 for (auto &quant_param : input_quant_params) {
443 quant_param.zeroPoint += kU8ZeroPointOffset;
444 }
445 }
446 } else {
447 src_dtype = kNumberTypeInt8;
448 dst_dtype = cast_dtype;
449 auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
450 auto input_cnode_primitive_c = GetValueNode<std::shared_ptr<mindspore::Primitive>>(input_cnode->input(0));
451 if (input_cnode_primitive_c == nullptr) {
452 MS_LOG(DEBUG) << "input: " << index << " " << input_cnode->fullname_with_scope() << ": "
453 << " PrimitiveC is null";
454 return RET_NO_CHANGE;
455 }
456 auto input_primitive_quant_param_holder = GetCNodeQuantHolder(input_cnode_primitive_c);
457 if (input_primitive_quant_param_holder->get_output_quant_params().empty()) {
458 MS_LOG(ERROR) << "output quant param is empty.";
459 return RET_ERROR;
460 }
461 input_quant_params = input_primitive_quant_param_holder->get_output_quant_params()[0];
462 std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
463 }
464 ValueNodePtr new_primitive =
465 NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params, 0, false);
466 CHECK_NULL_RETURN(new_primitive);
467 std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
468 auto quant_cast_cnode = graph->NewCNode(op_inputs);
469 CHECK_NULL_RETURN(quant_cast_cnode);
470 quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
471 "_pre");
472 // set abstract
473 if (input_node->abstract() != nullptr) {
474 auto abstract = input_node->abstract()->Clone();
475 quant_cast_cnode->set_abstract(abstract);
476 if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
477 MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
478 return RET_ERROR;
479 }
480 } else {
481 MS_LOG(INFO) << "input node abstract nullptr, input node name: " << input_node->fullname_with_scope();
482 }
483 auto manager = graph->manager();
484 if (manager == nullptr) {
485 manager = Manage(graph, true);
486 }
487 CHECK_NULL_RETURN(manager);
488 manager->SetEdge(cnode, index, quant_cast_cnode);
489 MS_LOG(INFO) << "InsertForwardQuantNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
490 << " dst_type: " << dst_dtype;
491 return RET_OK;
492 }
493
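// Insert an Int8 -> cast_dtype QuantDTypeCast after cnode and rewire output_node's input `index` to it.
// The cast's quant params come from cnode's output quant params; zero points are shifted by
// kU8ZeroPointOffset when casting back to UInt8.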
494 int InsertQuantNodeManager::InsertBackwardDeQuantNode(const FuncGraphPtr &graph, const CNodePtr &cnode,
495 TypeId cast_dtype, size_t index, const AnfNodePtr &output_node) {
496 if (cast_dtype != kNumberTypeUInt8 && cast_dtype != kNumberTypeFloat32) {
497 MS_LOG(ERROR) << "Invalid cast dtype: " << cast_dtype;
498 return RET_NOT_SUPPORT;
499 }
500 CHECK_NULL_RETURN(output_node);
501 // If cnode or outputnode is QuantDTypeCast, do nothing.
502 if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast) ||
503 opt::CheckPrimitiveType(output_node, prim::kPrimQuantDTypeCast)) {
504 return RET_NO_CHANGE;
505 }
506 auto ret = CheckDataType(output_node, cast_dtype);
507 if (ret != RET_OK) {
508 MS_LOG(ERROR) << "Check data type failed, cnode name: " << output_node->fullname_with_scope();
509 return ret;
510 }
511 auto manager = graph->manager();
512 if (manager == nullptr) {
513 manager = Manage(graph, true);
514 }
515 CHECK_NULL_RETURN(manager);
516
517 // insert backward cast_node
518 TypeId src_dtype = kNumberTypeInt8;
519 TypeId dst_dtype = cast_dtype;
520 std::vector<schema::QuantParamT> input_quant_params;
521 std::vector<schema::QuantParamT> output_quant_params;
522
523 auto curr_primitive_quant_param_holder = GetCNodeQuantHolder(cnode);
524 CHECK_NULL_RETURN(curr_primitive_quant_param_holder);
525 if (curr_primitive_quant_param_holder->get_output_quant_params().empty()) {
526 MS_LOG(ERROR) << "quant param is invalid.";
527 return RET_ERROR;
528 }
529 input_quant_params = curr_primitive_quant_param_holder->get_output_quant_params().front();
530 std::copy(input_quant_params.cbegin(), input_quant_params.cend(), std::back_inserter(output_quant_params));
531 // Int8toUint8
532 if (dst_dtype == kNumberTypeUInt8) {
533 for (auto &quant_param : output_quant_params) {
534 quant_param.zeroPoint += kU8ZeroPointOffset;
535 }
536 }
537 ValueNodePtr new_primitive =
538 NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, output_quant_params, 0, false);
539 CHECK_NULL_RETURN(new_primitive);
540 std::vector<AnfNodePtr> op_inputs = {new_primitive, cnode->cast<AnfNodePtr>()};
541 auto quant_cast_cnode = graph->NewCNode(op_inputs);
542 MS_CHECK_TRUE_MSG(quant_cast_cnode != nullptr, RET_NULL_PTR, "quant_cast_cnode is nullptr.");
543 quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_dtype_cast_" + std::to_string(index) +
544 "_post");
545 if (SetCastNodeAbstract(cnode, output_node, quant_cast_cnode) != RET_OK) {
546 MS_LOG(ERROR) << "SetCastNodeAbstract failed.";
547 return RET_ERROR;
548 }
549 if (quant::UpdateDataType(quant_cast_cnode, dst_dtype) != RET_OK) {
550 MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
551 return RET_ERROR;
552 }
553 manager->SetEdge(output_node, index, quant_cast_cnode);
554 MS_LOG(INFO) << "InsertBackwardDeQuantNode cnode name: " << cnode->fullname_with_scope() << " src dtype:" << src_dtype
555 << " dst_type: " << dst_dtype;
556 return RET_OK;
557 }
558
559 int InsertQuantNodeManager::InsertForwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
560 quant::QuantType curr_quant_type) {
561 // inputs
562 for (size_t index = 1; index < cnode->size(); index++) {
563 auto input_node = cnode->input(index);
564 CHECK_NULL_RETURN(input_node);
565 if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
566 MS_LOG(DEBUG) << "Input node is neither a CNode nor a graph input, skip.";
567 continue;
568 }
569 quant::QuantType pre_quant_type = quant::QUANT_NONE;
570 if (input_node->isa<mindspore::CNode>()) {
571 if (GetQuantType(input_node->cast<mindspore::CNodePtr>(), &pre_quant_type) != RET_OK) {
572 MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
573 return RET_ERROR;
574 }
575 }
576 if (pre_quant_type == quant::QUANT_NONE && curr_quant_type == quant::QUANT_ALL) {
577 auto status = InsertQuantDtypeCastNode(graph, cnode, FORWARD, cast_dtype, kQuant, index, nullptr);
578 if (status != RET_OK && status != RET_NO_CHANGE) {
579 MS_LOG(ERROR) << "InsertQuantDtypeCastNode kQuant failed, cnode name: " << cnode->fullname_with_scope();
580 return status;
581 }
582 }
583 }
584 return RET_OK;
585 }
586
587 int InsertQuantNodeManager::InsertCastNodeForFullQuant(const FuncGraphPtr &graph, const CNodePtr &cnode,
588 TypeId cast_dtype, quant::QuantType curr_quant_type) {
589 // inputs
590 for (size_t index = 1; index < cnode->size(); index++) {
591 auto input_node = cnode->input(index);
592 CHECK_NULL_RETURN(input_node);
593 if (!input_node->isa<mindspore::CNode>() && !IsGraphInput(input_node)) {
594 MS_LOG(DEBUG) << "Input node is neither a CNode nor a graph input, skip.";
595 continue;
596 }
597 quant::QuantType pre_quant_type = quant::QUANT_NONE;
598 if (input_node->isa<mindspore::CNode>()) {
599 if (GetQuantTypeNew(input_node->cast<mindspore::CNodePtr>(), &pre_quant_type) != RET_OK) {
600 MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
601 return RET_ERROR;
602 }
603 }
604 if (pre_quant_type == quant::QUANT_NONE && curr_quant_type == quant::QUANT_ALL) {
605 auto status = InsertQuantDtypeCastNodeNew(graph, cnode, FORWARD, cast_dtype, kQuant, index, nullptr);
606 if (status != RET_OK && status != RET_NO_CHANGE) {
607 MS_LOG(ERROR) << "InsertQuantDtypeCastNode kQuant failed, cnode name: " << cnode->fullname_with_scope();
608 return status;
609 }
610 } else if (pre_quant_type == quant::QUANT_ALL && curr_quant_type == quant::QUANT_NONE) {
611 auto status = InsertQuantDtypeCastNodeNew(graph, cnode, FORWARD, cast_dtype, kDeQuant, index, nullptr);
612 if (status != RET_OK && status != RET_NO_CHANGE) {
613 MS_LOG(ERROR) << "InsertQuantDtypeCastNode kDeQuant failed, cnode name: " << cnode->fullname_with_scope();
614 return status;
615 }
616 }
617 }
618 return RET_OK;
619 }
620
621 int InsertQuantNodeManager::InsertBackwardCastNode(const FuncGraphPtr &graph, const CNodePtr &cnode, TypeId cast_dtype,
622 quant::QuantType curr_quant_type) {
623 // outputs
624 auto manager = graph->manager();
625 if (manager == nullptr) {
626 manager = Manage(graph, true);
627 }
628 CHECK_NULL_RETURN(manager);
629 auto node_users = manager->node_users()[cnode];
630 for (auto &node_user : node_users) {
631 auto output_cnode = node_user.first->cast<CNodePtr>();
632 quant::QuantType post_quant_type;
633 if (GetQuantType(output_cnode, &post_quant_type) != RET_OK) {
634 MS_LOG(ERROR) << "Get quant type failed, cnode name: " << output_cnode->fullname_with_scope();
635 return RET_ERROR;
636 }
637 if (curr_quant_type == quant::QUANT_ALL && post_quant_type == quant::QUANT_NONE) {
638 auto status =
639 InsertQuantDtypeCastNode(graph, cnode, BACKWARD, cast_dtype, kDeQuant, node_user.second, node_user.first);
640 if (status != RET_OK && status != RET_NO_CHANGE) {
641 MS_LOG(ERROR) << "InsertQuantDtypeCastNode dequant failed, cnode name: " << cnode->fullname_with_scope();
642 return status;
643 }
644 }
645 } // node_users
646 return RET_OK;
647 }
648
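// Replace a quantized Parameter input with an on-the-fly dequant: either a QuantDTypeCast that carries the
// quant params as attributes (is_quant_attribute == true) or one that takes explicit scales/zps/
// mean_corrs/var_corrs parameter inputs built by CreateQuantInputCastNode.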
649 int InsertQuantNodeManager::InsertQuantDtypeCastFlyNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
650 size_t input_index, TypeId src_dtype, TypeId dst_dtype,
651 int axis, bool is_quant_attribute) {
652 MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
653 auto cnode_primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
654 if (cnode_primitive == nullptr) {
655 MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
656 return RET_ERROR;
657 }
658 auto input_node = cnode->input(input_index);
659 if (!input_node->isa<mindspore::Parameter>()) {
660 MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
661 return RET_ERROR;
662 }
663 auto input_quant_params = quant::GetInputNodeQuantParam(cnode, input_index);
664
665 CNodePtr quant_cast_cnode = nullptr;
666 if (is_quant_attribute) {
667 ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_quant_params, {}, axis, false);
668 MS_CHECK_TRUE_MSG(new_primitive != nullptr, RET_NULL_PTR, "New quant_cast primitive failed!");
669 std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node};
670 quant_cast_cnode = func_graph->NewCNode(op_inputs);
671 } else {
672 quant_cast_cnode =
673 CreateQuantInputCastNode(func_graph, cnode, input_node, src_dtype, dst_dtype, input_quant_params, axis);
674 }
675 CHECK_NULL_RETURN(quant_cast_cnode);
676 opt::NodeInferShape infer;
677 auto status = infer.InferShape(quant_cast_cnode);
678 if (status != RET_OK) {
679 MS_LOG(ERROR) << quant_cast_cnode->fullname_with_scope() << " InferShape failed.";
680 return RET_ERROR;
681 }
682 auto manager = func_graph->manager();
683 CHECK_NULL_RETURN(manager);
684 auto ret = manager->Replace(input_node, quant_cast_cnode);
685 if (!ret) {
686 MS_LOG(ERROR) << "Replace QuantDtypeCast failed.";
687 return RET_ERROR;
688 }
689 cnode_primitive->DelAttr(quant::kQuantParam);
690 MS_LOG(INFO) << "InsertQuantDtypeCastFlyNode cnode name: " << quant_cast_cnode->fullname_with_scope()
691 << " src_dtype: " << src_dtype << " dst_dtype: " << dst_dtype;
692
693 return RET_OK;
694 }
695
696 CNodePtr InsertQuantNodeManager::CreateQuantInputCastNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
697 const AnfNodePtr input_node, TypeId src_dtype,
698 TypeId dst_dtype,
699 const std::vector<schema::QuantParamT> &input_quant_params,
700 int axis) {
701 ValueNodePtr new_primitive = NewQuantCastPrimitive(src_dtype, dst_dtype, input_node, {}, axis, false);
702 std::vector<float> scales;
703 std::vector<int> zps;
704 std::vector<float> mean_corrs;
705 std::vector<float> var_corrs;
706 for (size_t i = 0; i < input_quant_params.size(); ++i) {
707 scales.push_back(static_cast<float>(input_quant_params.at(i).scale));
708 zps.push_back(static_cast<int>(input_quant_params.at(i).zeroPoint));
709 mean_corrs.push_back(static_cast<float>(input_quant_params.at(i).meanCorr));
710 var_corrs.push_back(static_cast<float>(input_quant_params.at(i).varCorr));
711 }
712 auto scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, "scales");
713 auto zps_node = opt::BuildIntVecParameterNode(func_graph, zps, "zps");
714 auto mean_corrs_node = opt::BuildFloatVecParameterNode(func_graph, mean_corrs, "mean_corrs");
715 auto var_corrs_node = opt::BuildFloatVecParameterNode(func_graph, var_corrs, "var_corrs");
716
717 std::vector<AnfNodePtr> op_inputs = {new_primitive, input_node, scales_node,
718 zps_node, mean_corrs_node, var_corrs_node};
719 auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
720 if (quant_cast_cnode == nullptr) {
721 MS_LOG(ERROR) << "New quant cast node failed.";
722 return nullptr;
723 }
724 auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
725 int index = 0;
726 if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
727 index = 0;
728 }
729 const int quant_dtype_cast_offset = 10000;
730 quant_cast_cnode->set_fullname_with_scope(strings.at(0) + "-QuantDtypeCast-op" +
731 std::to_string(index + quant_dtype_cast_offset));
732 return quant_cast_cnode;
733 }
734
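// Build the scale/zero-point parameter nodes consumed by the Add/Mul anti-quant chain. With
//   scale' = scale * varCorr
//   zp'    = -zeroPoint + meanCorr / (scale * varCorr)
// the chain computes (x + zp') * scale' = scale * varCorr * (x - zeroPoint) + meanCorr, i.e. dequantization
// with mean/variance correction. For per-channel params the abstracts are reshaped so that only `axis`
// keeps the channel count and broadcasting works.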
735 int InsertQuantNodeManager::CalculateScaleZPNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
736 size_t input_index, ParameterPtr *scales_node, ParameterPtr *zps_node,
737 TypeId dst_dtype, int axis) {
738 CHECK_NULL_RETURN(scales_node);
739 CHECK_NULL_RETURN(zps_node);
740 MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
741 auto input_node = cnode->input(input_index);
742 auto input_quant_params = quant::GetInputNodeQuantParam(cnode, input_index);
743 if (input_quant_params.empty()) {
744 MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << input_index << " quant param is empty.";
745 return RET_ERROR;
746 }
747
748 if (dst_dtype == kNumberTypeFloat16) {
749 std::vector<float16> scales;
750 std::vector<float16> zps;
751 for (size_t i = 0; i < input_quant_params.size(); ++i) {
752 scales.push_back(static_cast<float16>(input_quant_params.at(i).scale * input_quant_params.at(i).varCorr));
753 zps.push_back(static_cast<float16>(-input_quant_params.at(i).zeroPoint +
754 input_quant_params.at(i).meanCorr /
755 (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr)));
756 }
757 *scales_node = opt::BuildFloat16VecParameterNode(func_graph, scales, input_node->fullname_with_scope() + "-scales");
758 *zps_node = opt::BuildFloat16VecParameterNode(func_graph, zps, input_node->fullname_with_scope() + "-zps");
759 } else {
760 std::vector<float> scales;
761 std::vector<float> zps;
762 for (size_t i = 0; i < input_quant_params.size(); ++i) {
763 scales.push_back(static_cast<float>(input_quant_params.at(i).scale * input_quant_params.at(i).varCorr));
764 zps.push_back(static_cast<float>(-input_quant_params.at(i).zeroPoint +
765 input_quant_params.at(i).meanCorr /
766 (input_quant_params.at(i).scale * input_quant_params.at(i).varCorr)));
767 }
768 *scales_node = opt::BuildFloatVecParameterNode(func_graph, scales, input_node->fullname_with_scope() + "-scales");
769 *zps_node = opt::BuildFloatVecParameterNode(func_graph, zps, input_node->fullname_with_scope() + "-zps");
770 }
771 if (*scales_node == nullptr || *zps_node == nullptr) {
772 MS_LOG(ERROR) << "Failed to build scales node, zps node ";
773 return RET_ERROR;
774 }
775 if (input_quant_params.size() > 1) {
776 ShapeVector shape;
777 if (opt::FetchShapeFromAbstract(input_node->abstract(), &shape) != lite::RET_OK) {
778 MS_LOG(ERROR) << "fetch shape failed." << input_node->fullname_with_scope();
779 return lite::RET_ERROR;
780 }
781
782 std::vector<int64_t> shape_vector = {};
783 for (size_t i = 0; i < shape.size(); i++) {
784 if (i == static_cast<size_t>(axis)) {
785 shape_vector.push_back((int64_t)input_quant_params.size());
786 } else {
787 shape_vector.push_back(1);
788 }
789 }
790 auto scales_abstract = (*scales_node)->abstract();
791 CHECK_NULL_RETURN(scales_abstract);
792 scales_abstract->set_shape(std::make_shared<abstract::Shape>(shape_vector));
793 auto zps_abstract = (*zps_node)->abstract();
794 CHECK_NULL_RETURN(zps_abstract);
795 zps_abstract->set_shape(std::make_shared<abstract::Shape>(shape_vector));
796 }
797 return RET_OK;
798 }
799
800 int InsertQuantNodeManager::SetParallelStrategy(const CNodePtr &cnode,
801 const std::vector<std::vector<int64_t>> &in_strategy) {
802 auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
803 CHECK_NULL_RETURN(primitive);
804 primitive->AddAttr(IN_STRATEGY, MakeValue(in_strategy));
805 return RET_OK;
806 }
807
808 std::vector<std::vector<int64_t>> InsertQuantNodeManager::GetAddMulNodeParallelStrategy(
809 ShapeVector weight_shape, std::vector<int64_t> weight_strategy, int axis, bool per_channel) {
810 std::vector<std::vector<int64_t>> add_mul_in_strategy;
811 std::vector<int64_t> in_strategy_1 = weight_strategy;
812 add_mul_in_strategy.push_back(in_strategy_1);
813 std::vector<int64_t> in_strategy_2;
814
815 // If per-layer quant, the input2 strategy is set to 1.
816 // If per-channel quant, the input2 strategy at the axis dim follows the matmul input strategy,
817 // and the other dims are set to 1.
818 if (per_channel) {
819 for (size_t i = 0; i < weight_shape.size(); i++) {
820 if (i == static_cast<size_t>(axis) && i < weight_strategy.size()) {
821 in_strategy_2.push_back(weight_strategy.at(i));
822 } else {
823 in_strategy_2.push_back(1);
824 }
825 }
826 } else {
827 in_strategy_2.push_back(1);
828 }
829
830 add_mul_in_strategy.push_back(in_strategy_2);
831 return add_mul_in_strategy;
832 }
833
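// Rewrite a quantized weight Parameter into the pattern
//   parameter -> (AscendAntiQuant | Cast) -> Add(zp) -> Mul(scale) -> original users
// (for Gather the anti-quant chain is appended after the Gather instead). AscendAntiQuant is used on the
// "910b" backend, a plain Cast otherwise; the consumer's in_strategy, when present, is propagated to the
// Add/Mul nodes (and to the AntiQuant on 910b).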
834 int InsertQuantNodeManager::InsertAscendAntiQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
835 size_t input_index, TypeId src_dtype, TypeId dst_dtype, int axis,
836 const std::string &ascend_backend) {
837 auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
838 CHECK_NULL_RETURN(primitive);
839 MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
840 auto input_node = cnode->input(input_index);
841 auto manager = func_graph->manager();
842 CHECK_NULL_RETURN(manager);
843 std::vector<std::vector<int64_t>> cnode_in_strategy;
844 if (primitive->HasAttr(IN_STRATEGY)) {
845 cnode_in_strategy = ExtractStrategy(primitive->GetAttr(IN_STRATEGY));
846 CHECK_LESS_RETURN(cnode_in_strategy.size(), input_index);
847 MS_LOG(INFO) << "cnode: " << cnode->fullname_with_scope() << " in strategy is " << cnode_in_strategy;
848 }
849 if (!input_node->isa<mindspore::Parameter>()) {
850 MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
851 return RET_ERROR;
852 }
853
854 // parameter+cast+add+mul+matmul
855 // parameter+gather+cast+add+mul
856 auto input_quant_params = quant::GetInputNodeQuantParam(cnode, input_index);
857 if (input_quant_params.empty()) {
858 MS_LOG(ERROR) << cnode->fullname_with_scope() << " index: " << input_index << " quant param is empty.";
859 return RET_ERROR;
860 }
861
862 // Insert cast node
863 CNodePtr cast_cnode = nullptr;
864 if (ascend_backend == "910b") {
865 MS_LOG(INFO) << "The ascend_backend is 910b, it will insert antiquant node";
866 if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
867 cast_cnode = NewAscendAntiQuantCNode(func_graph, cnode, dst_dtype);
868 } else {
869 cast_cnode = NewAscendAntiQuantCNode(func_graph, input_node, dst_dtype);
870 }
871 } else {
872 if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
873 cast_cnode = NewCastNode(func_graph, cnode, dst_dtype);
874 } else {
875 cast_cnode = NewCastNode(func_graph, input_node, dst_dtype);
876 }
877 }
878
879 CHECK_NULL_RETURN(cast_cnode);
880 // The cast node does not need a parallel strategy; the antiquant node does.
881 if (primitive->HasAttr(IN_STRATEGY) && ascend_backend == "910b") {
882 std::vector<std::vector<int64_t>> cast_in_strategy;
883 std::vector<int64_t> in_strategy_1 = cnode_in_strategy[input_index - kPrimOffset];
884 cast_in_strategy.push_back(in_strategy_1);
885 auto ret = SetParallelStrategy(cast_cnode, cast_in_strategy);
886 if (ret != RET_OK) {
887 MS_LOG(ERROR) << "Fail to set cnode parallel strategy, cnode: " << cast_cnode->fullname_with_scope();
888 return RET_ERROR;
889 }
890 }
891
892 ParameterPtr scales_node;
893 ParameterPtr zps_node;
894 auto ret = CalculateScaleZPNode(func_graph, cnode, input_index, &scales_node, &zps_node, dst_dtype, axis);
895 if (ret != RET_OK) {
896 MS_LOG(ERROR) << "Fail to calculate scale & zero_point node: " << cnode->fullname_with_scope();
897 return RET_ERROR;
898 }
899
900 auto add_cnode = NewAddNode(func_graph, cast_cnode, zps_node);
901 CHECK_NULL_RETURN(add_cnode);
902
903 auto mul_cnode = NewMulNode(func_graph, add_cnode, scales_node);
904 CHECK_NULL_RETURN(mul_cnode);
905
906 if (primitive->HasAttr(IN_STRATEGY)) {
907 ShapeVector weight_shape;
908 if (opt::FetchShapeFromAbstract(input_node->abstract(), &weight_shape) != lite::RET_OK) {
909 MS_LOG(ERROR) << "fetch shape failed." << input_node->fullname_with_scope();
910 return lite::RET_ERROR;
911 }
912 std::vector<int64_t> weight_strategy = cnode_in_strategy[input_index - kPrimOffset];
913 bool per_channel = input_quant_params.size() > 1;
914 auto add_mul_in_strategy = GetAddMulNodeParallelStrategy(weight_shape, weight_strategy, axis, per_channel);
915
916 // add_cnode & mul_cnode set parallel strategy
917 ret = SetParallelStrategy(add_cnode, add_mul_in_strategy);
918 if (ret != RET_OK) {
919 MS_LOG(ERROR) << "Fail to set add cnode parallel strategy, cnode: " << add_cnode->fullname_with_scope();
920 return RET_ERROR;
921 }
922 ret = SetParallelStrategy(mul_cnode, add_mul_in_strategy);
923 if (ret != RET_OK) {
924 MS_LOG(ERROR) << "Fail to set mul cnode parallel strategy, cnode: " << mul_cnode->fullname_with_scope();
925 return RET_ERROR;
926 }
927 }
928
929 auto node_map = manager->node_users();
930
931 // Remove QuantParam
932 ret = RemoveInputNodeQuantParam(cnode, input_index);
933 if (ret != RET_OK) {
934 MS_LOG(ERROR) << "Fail to Remove node: " << input_node->fullname_with_scope() << " quant param";
935 return RET_ERROR;
936 }
937
938 AnfNodeIndexSet node_user;
939 if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
940 node_user = node_map[cnode];
941 } else {
942 node_user = node_map[input_node];
943 }
944 for (const auto &user : node_user) {
945 manager->SetEdge(user.first, user.second, mul_cnode);
946 }
947 return RET_OK;
948 }
949
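// Replace an FSE-compressed Parameter with an FSEDecode cnode. The original parameter is rewritten to hold
// only the raw chunk buffer, the decode tables/centroids/shape become extra parameter inputs (see
// CreateFSEInputs), and the decode output keeps the original tensor shape.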
950 int InsertQuantNodeManager::InsertFSEDecodeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
951 size_t input_index, TypeId dst_dtype) {
952 auto primitive = GetValueNode<std::shared_ptr<mindspore::Primitive>>(cnode->input(kPrimIndex));
953 if (primitive == nullptr) {
954 MS_LOG(ERROR) << "primitive_c is nullptr: " << cnode->fullname_with_scope();
955 return RET_ERROR;
956 }
957 MS_CHECK_LT(input_index, cnode->size(), RET_ERROR);
958 auto input_node = cnode->input(input_index);
959 if (!input_node->isa<mindspore::Parameter>()) {
960 MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << input_index << " is not parameter node.";
961 return RET_ERROR;
962 }
963 auto shape = input_node->Shape();
964 std::vector<AnfNodePtr> op_inputs;
965 int ret = CreateFSEInputs(func_graph, input_node, &op_inputs, dst_dtype);
966 if (ret != RET_OK) {
967 MS_LOG(ERROR) << "CreateFSEInputs failed.";
968 return RET_ERROR;
969 }
970
971 auto fse_decode_cnode = func_graph->NewCNode(op_inputs);
972 CHECK_NULL_RETURN(fse_decode_cnode);
973 auto strings = SplitStringToVector(cnode->fullname_with_scope(), "-op");
974 int index = 0;
975 if (!ConvertIntNum(strings.at(strings.size() - 1), &index)) {
976 index = 0;
977 }
978 const int fse_decode_offset = 20000;
979 fse_decode_cnode->set_fullname_with_scope(strings.at(0) + "-FSEDecode-op" +
980 std::to_string(index + fse_decode_offset));
981 CHECK_NULL_RETURN(cnode->abstract());
982 auto fse_abstract = cnode->abstract()->Clone();
983 fse_abstract->set_shape(shape);
984 fse_decode_cnode->set_abstract(fse_abstract);
985
986 auto manager = func_graph->manager();
987 CHECK_NULL_RETURN(manager);
988 auto ret_bool = manager->Replace(input_node, fse_decode_cnode);
989 if (!ret_bool) {
990 MS_LOG(ERROR) << "Replace QuantDtypeCast failed.";
991 return RET_ERROR;
992 }
993
994 return RET_OK;
995 }
996
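// Decode the FSE buffer header and assemble the FSEDecode inputs in order:
//   {primitive, chunks, states_table, bit_count_table, symbol_table, centroids, input_shape, chunk_ends}.
// The input Parameter itself is reused as the chunks tensor with shape (1, chunk_size).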
997 int InsertQuantNodeManager::CreateFSEInputs(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
998 std::vector<AnfNodePtr> *op_inputs, TypeId dst_dtype) {
999 CHECK_NULL_RETURN(op_inputs);
1000 if (!input_node->isa<mindspore::Parameter>()) {
1001 MS_LOG(ERROR) << "FSEDecode input is not parameter node.";
1002 return RET_ERROR;
1003 }
1004 auto parameter_ptr = input_node->cast<ParameterPtr>();
1005 CHECK_NULL_RETURN(parameter_ptr);
1006 if (!parameter_ptr->has_default()) {
1007 MS_LOG(ERROR) << input_node->fullname_with_scope() << " parameter does not have a default value.";
1008 return RET_ERROR;
1009 }
1010 auto tensor = parameter_ptr->default_param()->cast<tensor::TensorPtr>();
1011 CHECK_NULL_RETURN(tensor);
1012 int8_t *data8 = reinterpret_cast<int8_t *>(tensor->data_c());
1013 size_t data_size = tensor->DataSize();
1014 FSEBuffer fse_buffer;
1015 auto ret = FSEDecoder::DecodeBuffer(data8, data_size, &fse_buffer);
1016 if (ret != RET_OK) {
1017 MS_LOG(ERROR) << input_node->fullname_with_scope() << " buffer decode failed.";
1018 return RET_ERROR;
1019 }
1020 ValueNodePtr new_primitive = NewFSEDecodePrimitive(dst_dtype, fse_buffer.curr_chunk, fse_buffer.curr_chunk_index,
1021 fse_buffer.curr_bit_count, fse_buffer.table_log);
1022 op_inputs->push_back(new_primitive);
1023
1024 // reshape the chunk buffer to (1, chunk_size)
1025 ShapeVector shape_vector;
1026 shape_vector.push_back(1);
1027 shape_vector.push_back(fse_buffer.chunk_size);
1028 auto chunk_tensor_info =
1029 lite::CreateTensorInfo(fse_buffer.chunks, fse_buffer.chunk_size, shape_vector, kNumberTypeInt8);
1030 parameter_ptr->set_default_param(chunk_tensor_info);
1031 parameter_ptr->set_abstract(chunk_tensor_info->ToAbstract());
1032 op_inputs->push_back(input_node);
1033
1034 size_t table_size = 1u << fse_buffer.table_log;
1035 std::vector<uint16_t> states_table(table_size);
1036 std::vector<uint8_t> bit_count_table(table_size);
1037 std::vector<uint16_t> symbol_table(table_size);
1038
1039 ret = FSEDecoder::FSECreateStatesForDecoding(fse_buffer.frequency, fse_buffer.frequency_count, fse_buffer.table_log,
1040 states_table.data(), bit_count_table.data(), symbol_table.data());
1041 if (ret != RET_OK) {
1042 MS_LOG(ERROR) << "FSE create states for decoding failed.";
1043 return RET_ERROR;
1044 }
1045 std::vector<int64_t> shape = {static_cast<int64_t>(table_size)};
1046
1047 auto states_table_tensor_info =
1048 lite::CreateTensorInfo(states_table.data(), sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
1049 auto states_table_node = opt::BuildParameterNode(func_graph, states_table_tensor_info, "states_table");
1050 op_inputs->push_back(states_table_node);
1051
1052 auto bit_count_table_tensor_info =
1053 lite::CreateTensorInfo(bit_count_table.data(), sizeof(uint8_t) * table_size, shape, kNumberTypeUInt8);
1054 auto bit_count_table_node = opt::BuildParameterNode(func_graph, bit_count_table_tensor_info, "bit_count_table");
1055 op_inputs->push_back(bit_count_table_node);
1056
1057 auto symbol_table_tensor_info =
1058 lite::CreateTensorInfo(symbol_table.data(), sizeof(uint16_t) * table_size, shape, kNumberTypeUInt16);
1059 auto symbol_table_node = opt::BuildParameterNode(func_graph, symbol_table_tensor_info, "symbol_table");
1060 op_inputs->push_back(symbol_table_node);
1061
1062 auto centroids_tensor_info =
1063 lite::CreateTensorInfo(fse_buffer.centroids, sizeof(float) * fse_buffer.centroid_size,
1064 {static_cast<int64_t>(fse_buffer.centroid_size)}, kNumberTypeFloat32);
1065 auto centroids_node = opt::BuildParameterNode(func_graph, centroids_tensor_info, "centroids");
1066 op_inputs->push_back(centroids_node);
1067
1068 auto shape_tensor_info = lite::CreateTensorInfo(ConvertShapeVectorToInt32(tensor->shape_c()).data(),
1069 sizeof(int32_t) * tensor->shape_c().size(),
1070 {static_cast<int64_t>(tensor->shape_c().size())}, kNumberTypeInt32);
1071 auto shape_node = opt::BuildParameterNode(func_graph, shape_tensor_info, "input_shape");
1072 op_inputs->push_back(shape_node);
1073
1074 auto chunk_ends_tensor_info =
1075 lite::CreateTensorInfo(fse_buffer.chunk_ends, sizeof(uint64_t) * fse_buffer.chunk_ends_count,
1076 {static_cast<int64_t>(fse_buffer.chunk_ends_count)}, kNumberTypeUInt64);
1077 auto chunk_ends_node = opt::BuildParameterNode(func_graph, chunk_ends_tensor_info, "chunk_ends");
1078 op_inputs->push_back(chunk_ends_node);
1079
1080 return RET_OK;
1081 }
1082
1083 CNodePtr InsertQuantNodeManager::NewCastNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
1084 int dst_type) {
1085 auto prim_c = std::make_shared<ops::Cast>();
1086 MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1087 auto prim = prim_c->GetPrim();
1088 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1089 MS_LOG(INFO) << "dst_type:" << dst_type;
1090 TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
1091 prim->AddAttr(ops::kDstType, type_ptr);
1092 prim->AddAttr(ATTR_NO_NEED_CONSTANT_FOLDING, MakeValue(true));
1093 std::vector<AnfNodePtr> cast_op_inputs = {NewValueNode(prim), input_node};
1094 auto cast_cnode = func_graph->NewCNode(cast_op_inputs);
1095 cast_cnode->set_fullname_with_scope(input_node->fullname_with_scope() + "-Cast");
1096 cast_cnode->set_abstract(input_node->abstract()->Clone());
1097 auto ret = UpdateDataType(cast_cnode, TypeId(dst_type));
1098 if (ret != RET_OK) {
1099 MS_LOG(ERROR) << cast_cnode->fullname_with_scope() << " set dst_type " << dst_type << " failed.";
1100 return nullptr;
1101 }
1102 return cast_cnode;
1103 }
1104
1105 CNodePtr InsertQuantNodeManager::NewAscendAntiQuantCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
1106 int dst_type) {
1107 auto dst_prim = std::make_shared<acl::AscendAntiQuant>();
1108 if (dst_prim == nullptr) {
1109 return nullptr;
1110 }
1111 dst_prim->AddAttr("scale", MakeValue(1.0f));
1112 dst_prim->AddAttr("offset", MakeValue(0.0f));
1113 MS_LOG(INFO) << "dst_type:" << dst_type;
1114 TypePtr type_ptr = TypeIdToType(TypeId(dst_type));
1115 dst_prim->AddAttr(ops::kOutputDType, type_ptr);
1116 std::vector<AnfNodePtr> cast_op_inputs = {NewValueNode(dst_prim), input_node};
1117 auto anti_cnode = func_graph->NewCNode(cast_op_inputs);
1118 anti_cnode->set_fullname_with_scope(input_node->fullname_with_scope() + "-AntiQuant");
1119 anti_cnode->set_abstract(input_node->abstract()->Clone());
1120 anti_cnode->abstract()->set_type(type_ptr);
1121 auto ret = UpdateDataType(anti_cnode, TypeId(dst_type));
1122 if (ret != RET_OK) {
1123 MS_LOG(ERROR) << anti_cnode->fullname_with_scope() << " set dst_type " << dst_type << " failed.";
1124 return nullptr;
1125 }
1126 return anti_cnode;
1127 }
1128
1129 CNodePtr InsertQuantNodeManager::NewMulNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_1,
1130 const AnfNodePtr &input_2) {
1131 auto prim_c = std::make_shared<ops::MulFusion>();
1132 MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1133 auto prim = prim_c->GetPrim();
1134 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1135 prim->AddAttr(ATTR_NO_NEED_CONSTANT_FOLDING, MakeValue(true));
1136 std::vector<AnfNodePtr> op_inputs = {NewValueNode(prim), input_1, input_2};
1137 auto cnode = func_graph->NewCNode(op_inputs);
1138 if (cnode == nullptr) {
1139 MS_LOG(ERROR) << "cnode is nullptr.";
1140 return nullptr;
1141 }
1142 cnode->set_fullname_with_scope(input_1->fullname_with_scope() + "-" + input_2->fullname_with_scope() + "-Mul");
1143 cnode->set_abstract(input_1->abstract()->Clone());
1144 return cnode;
1145 }
1146
1147 CNodePtr InsertQuantNodeManager::NewAddNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_1,
1148 const AnfNodePtr &input_2) {
1149 auto prim_c = std::make_shared<ops::AddFusion>();
1150 MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1151 auto prim = prim_c->GetPrim();
1152 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1153 prim->AddAttr(ATTR_NO_NEED_CONSTANT_FOLDING, MakeValue(true));
1154 std::vector<AnfNodePtr> op_inputs = {NewValueNode(prim), input_1, input_2};
1155 auto cnode = func_graph->NewCNode(op_inputs);
1156 if (cnode == nullptr) {
1157 MS_LOG(ERROR) << "cnode is nullptr.";
1158 return nullptr;
1159 }
1160 cnode->set_fullname_with_scope(input_1->fullname_with_scope() + "-" + input_2->fullname_with_scope() + "-Add");
1161 cnode->set_abstract(input_1->abstract()->Clone());
1162 return cnode;
1163 }
1164
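// Two ways to build a QuantDTypeCast primitive: this overload stores the quant params in a QuantParamHolder
// ("quant_params" attribute); the overload below stores them as a Quantization value under kQuantParam.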
1165 ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(int src_type, int dst_type,
1166 const std::vector<schema::QuantParamT> &input_quant_params,
1167 const std::vector<schema::QuantParamT> &output_quant_params,
1168 int axis, bool set_quant_flag) {
1169 auto prim_c = std::make_shared<ops::QuantDTypeCast>();
1170 MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1171 prim_c->Init(src_type, dst_type);
1172 prim_c->set_axis(axis);
1173 auto quant_params_holder = std::make_shared<QuantParamHolder>(input_quant_params.size(), output_quant_params.size());
1174 MS_CHECK_TRUE_MSG(quant_params_holder != nullptr, nullptr, "quant_params_holder is nullptr.");
1175 if (set_quant_flag) {
1176 quant_params_holder->set_quant_type(quant::QUANT_ALL);
1177 }
1178 quant_params_holder->set_input_quant_param(0, input_quant_params);
1179 quant_params_holder->set_output_quant_param(0, output_quant_params);
1180 auto prim = prim_c->GetPrim();
1181 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1182 prim->AddAttr("quant_params", quant_params_holder);
1183 return NewValueNode(prim);
1184 }
1185
1186 ValueNodePtr InsertQuantNodeManager::NewQuantCastPrimitive(int src_type, int dst_type, const AnfNodePtr &input_node,
1187 const std::vector<schema::QuantParamT> &output_quant_params,
1188 int axis, bool set_quant_flag) {
1189 auto prim_c = std::make_shared<ops::QuantDTypeCast>();
1190 MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1191 prim_c->Init(src_type, dst_type);
1192 prim_c->set_axis(axis);
1193 auto prim = prim_c->GetPrim();
1194 if (set_quant_flag) {
1195 prim->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_ALL)));
1196 }
1197 // Set quant param to quant_cast_cnode
1198 if (!output_quant_params.empty()) {
1199 auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(output_quant_params);
1200 std::vector<ValuePtr> quantization_list = {quantization_ptr};
1201 auto quant_ptr = std::make_shared<ValueList>(quantization_list);
1202 MS_CHECK_TRUE_MSG(quant_ptr != nullptr, nullptr, "quant_ptr is nullptr.");
1203 prim->AddAttr(quant::kQuantParam, quant_ptr);
1204 } else {
1205 MS_LOG(WARNING) << "New quant cast node's output quant param is empty, input node: "
1206 << input_node->fullname_with_scope();
1207 }
1208 return NewValueNode(prim);
1209 }
1210
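// Wraps an ops::FSEDecode primitive initialized with the decoder state (current chunk, chunk index, bit count and
// table log) into a value node for building an FSEDecode cnode.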
1211 ValueNodePtr InsertQuantNodeManager::NewFSEDecodePrimitive(int dst_type, uint64_t curr_chunk, int64_t curr_chunk_index,
1212 int64_t curr_bit_count, int64_t table_log) {
1213 auto prim_c = std::make_shared<ops::FSEDecode>();
1214 MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr.");
1215 prim_c->Init(dst_type, curr_chunk, curr_chunk_index, curr_bit_count, table_log);
1216
1217 auto prim = prim_c->GetPrim();
1218 MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr");
1219 return NewValueNode(prim);
1220 }
1221
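// Inserts a float32 -> int8 quant cast in front of every operand of cnode that is itself a CNode or a graph input,
// e.g. x(fp32) -> QuantDTypeCast(fp32->int8) -> cnode; constant parameter inputs such as weights are left untouched.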
1222 int InsertQuantNodeManager::InsertAscendQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
1223 for (size_t i = 1; i < cnode->size(); i++) {
1224 if (cnode->input(i)->isa<CNode>() || IsGraphInput(cnode->input(i))) {
1225 auto ret = InsertAscendQuantNode(func_graph, cnode, i);
1226 if (ret != RET_OK) {
1227 MS_LOG(ERROR) << "InsertAscendQuantNode failed.";
1228 return ret;
1229 }
1230 }
1231 }
1232 return RET_OK;
1233 }
1234
1235 int InsertQuantNodeManager::InsertAscendQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
1236 size_t input_index) {
1237 CHECK_NULL_RETURN(func_graph);
1238 CHECK_NULL_RETURN(cnode);
1239 auto x_q_param_origin = quant::GetInputNodeQuantParam(cnode, input_index);
1240 if (x_q_param_origin.empty()) {
1241 auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
1242 CHECK_NULL_RETURN(curr_quant_param_holder);
1243 auto input_quant_param = curr_quant_param_holder->get_input_quant_params();
1244 x_q_param_origin = input_quant_param.at(input_index - kPrimOffset);
1245 }
1246 if (x_q_param_origin.size() != kPerTensor) {
1247 MS_LOG(ERROR) << cnode->fullname_with_scope() << " x quant param size " << x_q_param_origin.size() << " != 1";
1248 return RET_ERROR;
1249 }
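// The inserted cast is given the reciprocal of the recorded activation scale, presumably because the Ascend quant
// op multiplies its input by the scale rather than dividing by it.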
1250 auto x_q_param = quant::CloneQuantParam(x_q_param_origin);
1251 x_q_param.at(0).scale = 1 / x_q_param.at(0).scale;
1252 auto input_node = cnode->input(input_index);
1253 CHECK_NULL_RETURN(input_node);
1254 ValueNodePtr new_primitive = NewQuantCastPrimitive(kNumberTypeFloat32, kNumberTypeInt8, input_node, x_q_param);
1255 std::vector<AnfNodePtr> op_inputs = {new_primitive, cnode->input(input_index)};
1256 auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
1257 CHECK_NULL_RETURN(quant_cast_cnode);
1258 quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "-quant-" + std::to_string(input_index));
1259 // set abstract
1260 if (cnode->input(input_index)->abstract() != nullptr) {
1261 auto abstract = cnode->input(input_index)->abstract()->Clone();
1262 quant_cast_cnode->set_abstract(abstract);
1263 if (quant::UpdateDataType(quant_cast_cnode, kNumberTypeInt8) != RET_OK) {
1264 MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
1265 return RET_ERROR;
1266 }
1267 } else {
1268 MS_LOG(ERROR) << "Input node abstract is nullptr, cnode name: " << cnode->fullname_with_scope();
1269 return RET_ERROR;
1270 }
1271 auto manager = func_graph->manager();
1272 if (manager == nullptr) {
1273 manager = Manage(func_graph, true);
1274 }
1275 CHECK_NULL_RETURN(manager);
1276 manager->SetEdge(cnode, input_index, quant_cast_cnode);
1277 MS_LOG(INFO) << cnode->fullname_with_scope() << " Insert Ascend QuantNode, scale: " << x_q_param.at(0).scale;
1278 return RET_OK;
1279 }
1280
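// Appends a QuantDTypeCast (int32 -> original dtype) behind cnode, feeds it the packed per-channel deq_scales
// parameter and redirects all of cnode's users to the new dequant node.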
1281 int InsertQuantNodeManager::InsertAscendDeQuantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
1282 CHECK_NULL_RETURN(func_graph);
1283 CHECK_NULL_RETURN(cnode);
1284 auto cnode_primitive = GetValueNode<PrimitivePtr>(cnode->input(kPrimIndex));
1285 if (cnode_primitive == nullptr) {
1286 MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is nullptr.";
1287 return RET_ERROR;
1288 }
1289 auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
1290 CHECK_NULL_RETURN(curr_quant_param_holder);
1291 auto input_quant_param = curr_quant_param_holder->get_input_quant_params();
1292 auto x_q_param = quant::GetInputNodeQuantParam(cnode, Index0 + kPrimOffset);
1293 if (x_q_param.empty()) {
1294 x_q_param = input_quant_param.at(Index0);
1295 }
1296 if (x_q_param.size() != kPerTensor) {
1297 MS_LOG(ERROR) << cnode->fullname_with_scope() << " x quant param size " << x_q_param.size() << " != 1";
1298 return RET_ERROR;
1299 }
1300 auto w_q_params = quant::GetInputNodeQuantParam(cnode, Index1 + kPrimOffset);
1301 if (w_q_params.empty()) {
1302 w_q_params = input_quant_param.at(Index1);
1303 }
1304 if (w_q_params.empty()) {
1305 MS_LOG(ERROR) << cnode->fullname_with_scope() << " w quant param is empty.";
1306 return RET_ERROR;
1307 }
1308 MS_LOG(INFO) << cnode->fullname_with_scope() << " x scale:" << x_q_param.at(0).scale
1309 << " w scale size:" << w_q_params.size();
1310 std::vector<uint64_t> deq_scales(w_q_params.size());
1311 for (size_t i = 0; i < w_q_params.size(); ++i) {
1312 float float32_deq_scale = static_cast<float>(x_q_param.at(0).scale * w_q_params.at(i).scale);
1313 void *ptr = &float32_deq_scale;
1314 uint32_t *uint32_deq_scale = reinterpret_cast<uint32_t *>(ptr);
1315 uint64_t u64_deq_scale = 0;
1316 u64_deq_scale |= *uint32_deq_scale;
1317 deq_scales[i] = u64_deq_scale;
1318 }
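// The dequant output dtype defaults to float32 and is overridden by the node's "origin_type" attribute if present.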
1319 auto dtype = kNumberTypeFloat32;
1320 if (cnode->HasAttr("origin_type")) {
1321 auto value = cnode->GetAttr("origin_type");
1322 dtype = static_cast<TypeId>(opt::CastToInt(value).front());
1323 }
1324 auto prim_c = std::make_shared<ops::QuantDTypeCast>();
1325 CHECK_NULL_RETURN(prim_c);
1326
1327 prim_c->Init(kNumberTypeInt32, dtype);
1328 auto prim = prim_c->GetPrim();
1329 // copy cnode quant param to dequant
1330 if (cnode_primitive->HasAttr(quant::kQuantParam)) {
1331 prim->AddAttr(quant::kQuantParam, cnode_primitive->GetAttr(quant::kQuantParam));
1332 }
1333 auto quant_dtype_cast_primitive = NewValueNode(prim);
1334 std::vector<AnfNodePtr> op_inputs;
1335 op_inputs.push_back(quant_dtype_cast_primitive);
1336 op_inputs.push_back(cnode);
1337 auto deq_scales_tensor_info = lite::CreateTensorInfo(deq_scales.data(), sizeof(uint64_t) * deq_scales.size(),
1338 {static_cast<int64_t>(deq_scales.size())}, kNumberTypeUInt64);
1339 auto deq_scales_node =
1340 opt::BuildParameterNode(func_graph, deq_scales_tensor_info, cnode->fullname_with_scope() + "-deq_scales");
1341 op_inputs.push_back(deq_scales_node);
1342
1343 auto quant_cast_cnode = func_graph->NewCNode(op_inputs);
1344 CHECK_NULL_RETURN(quant_cast_cnode);
1345 quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "-dequant");
1346 // set abstract
1347 if (cnode->abstract() != nullptr) {
1348 auto abstract = cnode->abstract()->Clone();
1349 quant_cast_cnode->set_abstract(abstract);
1350 if (quant::UpdateDataType(quant_cast_cnode, dtype) != RET_OK) {
1351 MS_LOG(ERROR) << "UpdateDataType failed, cnode name: " << quant_cast_cnode->fullname_with_scope();
1352 return RET_ERROR;
1353 }
1354 } else {
1355 MS_LOG(ERROR) << "cnode abstract is nullptr, cnode name: " << cnode->fullname_with_scope();
1356 return RET_ERROR;
1357 }
1358
1359 auto manager = func_graph->manager();
1360 if (manager == nullptr) {
1361 manager = Manage(func_graph, true);
1362 }
1363 CHECK_NULL_RETURN(manager);
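// Redirect every existing consumer of cnode to read from the new dequant node instead.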
1364 auto node_users = manager->node_users()[cnode];
1365 for (auto &node_user : node_users) {
1366 manager->SetEdge(node_user.first, node_user.second, quant_cast_cnode);
1367 }
1368 MS_LOG(INFO) << cnode->fullname_with_scope() << " Insert Ascend DeQuant Node.";
1369 return RET_OK;
1370 }
1371
1372 int InsertQuantNodeManager::AdjustTransposeNodeForSingleMatMulNode(const FuncGraphPtr &func_graph,
1373 const CNodePtr &cnode) {
1374 const std::set<PrimitivePtr> support_transpose_types = {prim::kPrimMatMulFusion, prim::kPrimMatMul,
1375 prim::kPrimBatchMatMul};
1376 if (!CheckNodeInSet(cnode, support_transpose_types)) {
1377 return RET_OK;
1378 }
1379 auto prim_ptr = GetCNodePrimitive(cnode);
1380 CHECK_NULL_RETURN(prim_ptr);
1381
1382 auto transpose_a = prim_ptr->GetAttr(mindspore::ops::kTransposeA);
1383 auto transpose_b = prim_ptr->GetAttr(mindspore::ops::kTransposeB);
1384
1385 if (transpose_a != nullptr && GetValue<bool>(transpose_a)) {
1386 MS_LOG(ERROR) << cnode->fullname_with_scope() << " transposeA is true.";
1387 return RET_ERROR;
1388 }
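// When transpose_b is set: a CNode weight is left untouched, while a Parameter weight gets transpose_b cleared and,
// in the 2-D case, its data transposed in place (3-D weights are logged and skipped).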
1389 if (transpose_b != nullptr && GetValue<bool>(transpose_b)) {
1390 int ret = RET_ERROR;
1391 MS_LOG(INFO) << cnode->fullname_with_scope() << ":" << cnode->input(kWeightIndex + kPrimOffset)->type_name();
1392 if (cnode->input(kWeightIndex + kPrimOffset)->isa<CNode>()) {
1393 return RET_OK;
1394 } else if (cnode->input(kWeightIndex + kPrimOffset)->isa<Parameter>()) {
1395 auto manager = Manage(func_graph);
1396 CHECK_NULL_RETURN(manager);
1397 auto weight_input = cnode->input(kWeightIndex + 1);
1398 auto dst_prim = GetCNodePrimitive(cnode);
1399 MS_LOG(INFO) << cnode->fullname_with_scope() << " transpose_b is true.";
1400 dst_prim->AddAttr(mindspore::ops::kTransposeB, MakeValue(false));
1401 ParameterPtr param_node;
1402 tensor::TensorPtr tensor_info;
1403 GetParameterAndTensor(weight_input, &param_node, &tensor_info);
1404 if (tensor_info->shape_c().size() == DIMENSION_3D) {
1405 MS_LOG(INFO) << weight_input->fullname_with_scope() << " shape is " << tensor_info->shape_c()
1406 << " will not do transpose";
1407 return RET_OK;
1408 }
1409 if (tensor_info->shape_c().size() != DIMENSION_2D) {
1410 MS_LOG(ERROR) << weight_input->fullname_with_scope() << " shape is " << tensor_info->shape_c()
1411 << " is large than 2.";
1412 return RET_ERROR;
1413 }
1414
1415 if (tensor_info->data_type_c() == kNumberTypeFloat32) {
1416 ret = TransposeData<float>(param_node, tensor_info);
1417 } else if (tensor_info->data_type_c() == kNumberTypeFloat16) {
1418 ret = TransposeData<Float16>(param_node, tensor_info);
1419 } else {
1420 MS_LOG(ERROR) << "TransposeData only supports Float32 or Float16, but got data type " << tensor_info->data_type_c();
1421 return RET_ERROR;
1422 }
1423
1424 if (ret != RET_OK) {
1425 MS_LOG(ERROR) << weight_input->fullname_with_scope() << " transposeData failed.";
1426 return ret;
1427 }
1428 } else {
1429 MS_LOG(ERROR) << "Unsupported weight node type: " << cnode->input(kWeightIndex + kPrimOffset)->type_name();
1430 return RET_ERROR;
1431 }
1432 }
1433 return RET_OK;
1434 }
1435
1436 int InsertQuantNodeManager::AdjustTransposeNodeForMatMul(const FuncGraphPtr &func_graph) {
1437 auto cnodes = func_graph->GetOrderedCnodes();
1438 for (auto &cnode : cnodes) {
1439 auto ret = AdjustTransposeNodeForSingleMatMulNode(func_graph, cnode);
1440 if (ret != RET_OK) {
1441 MS_LOG(ERROR) << cnode->fullname_with_scope() << " Adjust Transpose Node failed.";
1442 return ret;
1443 }
1444 }
1445 return RET_OK;
1446 }
1447
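// Builds a Transpose over cnode->input(index) that swaps its two innermost axes, rewires it as the weight input of
// the MatMul-like node and clears the transpose_b attribute.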
1448 int InsertQuantNodeManager::InsertTransposeNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index) {
1449 auto prim_ptr = GetCNodePrimitive(cnode);
1450 CHECK_NULL_RETURN(prim_ptr);
1451 std::vector<int> perm;
1452 ShapeVector shape;
1453 auto ret = opt::FetchShapeFromAbstract(cnode->input(index)->abstract(), &shape);
1454 if (ret != RET_OK) {
1455 MS_LOG(ERROR) << "Fetch shape from abstract failed.";
1456 return ret;
1457 }
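// Swap the two innermost dimensions; leading batch dimensions keep their order.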
1458 if (shape.size() == DIMENSION_2D) {
1459 perm = {1, 0};
1460 } else if (shape.size() == DIMENSION_3D) {
1461 perm = {0, 2, 1};
1462 } else if (shape.size() == DIMENSION_4D) {
1463 perm = {0, 1, 3, 2};
1464 } else {
1465 MS_LOG(ERROR) << "Shape size " << shape.size() << " is invalid, only 2D/3D/4D inputs are supported.";
1466 return RET_ERROR;
1467 }
1468 auto transpose = opt::GenTransposeNode(func_graph, cnode->input(index), perm,
1469 cnode->input(index)->fullname_with_scope() + "-transpose");
1470 auto manager = Manage(func_graph);
1471 MS_ASSERT(manager != nullptr);
1472 manager->SetEdge(cnode, kWeightIndex + kPrimOffset, transpose);
1473 prim_ptr->set_attr(mindspore::ops::kTransposeB, MakeValue(false));
1474 return RET_OK;
1475 }
1476 } // namespace mindspore::lite::quant
1477