1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define USE_DEPRECATED_API
17 #include "tools/optimizer/fusion/concat_concat_fusion.h"
18 #include <vector>
19 #include "mindspore/core/ops/array_ops.h"
20 #include "ir/func_graph.h"
21 #include "tools/optimizer/common/gllo_utils.h"
22 #include "ops/op_name.h"
23
24 namespace mindspore {
25 namespace opt {
Run(const FuncGraphPtr & func_graph)26 bool ConcatConcatFusion::Run(const FuncGraphPtr &func_graph) {
27 MS_ASSERT(func_graph != nullptr);
28 auto node_list = TopoSort(func_graph->get_return());
29 for (auto &node : node_list) {
30 if (!utils::isa<CNode>(node) || !CheckPrimitiveType(node, prim::kPrimConcat)) {
31 continue;
32 }
33 auto cnode = node->cast<CNodePtr>();
34 if (Process(func_graph, cnode) != lite::RET_OK) {
35 MS_LOG(ERROR) << "Do ConcatConcatFusion failed, node name is " << node->fullname_with_scope();
36 return false;
37 }
38 }
39 UpdateManager(func_graph);
40 return true;
41 }
42
Process(const FuncGraphPtr & func_graph,const CNodePtr & cnode)43 int ConcatConcatFusion::Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
44 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
45 auto prim = GetCNodePrimitive(cnode);
46 if (prim == nullptr) {
47 MS_LOG(ERROR) << "Concat's prim is a nullptr, node name is " << cnode->fullname_with_scope();
48 return lite::RET_NULL_PTR;
49 }
50 if (IsQuantParameterNode(prim)) {
51 return lite::RET_OK;
52 }
53 auto axis = prim->GetAttr(ops::kAxis) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kAxis)) : 0;
54 auto &inputs = cnode->inputs();
55 std::vector<AnfNodePtr> new_inputs;
56 for (const auto &node : inputs) {
57 if (!utils::isa<CNode>(node) || !CheckPrimitiveType(node, prim::kPrimConcat)) {
58 new_inputs.push_back(node);
59 continue;
60 }
61 auto pre_concat = node->cast<CNodePtr>();
62 if (IsMultiOutputTensors(func_graph, pre_concat)) {
63 new_inputs.push_back(node);
64 continue;
65 }
66 auto pre_prim = GetCNodePrimitive(pre_concat);
67 if (pre_prim == nullptr) {
68 MS_LOG(ERROR) << "Concat's prim is a nullptr, node name is " << pre_concat->fullname_with_scope();
69 return lite::RET_NULL_PTR;
70 }
71 if (IsQuantParameterNode(pre_prim)) {
72 new_inputs.push_back(node);
73 continue;
74 }
75 auto pre_axis = pre_prim->GetAttr(ops::kAxis) != nullptr ? GetValue<int64_t>(pre_prim->GetAttr(ops::kAxis)) : 0;
76 if (pre_axis != axis) {
77 new_inputs.push_back(node);
78 continue;
79 }
80 auto pre_inputs = pre_concat->inputs();
81 new_inputs.insert(new_inputs.end(), pre_inputs.begin() + 1, pre_inputs.end());
82 }
83 cnode->set_inputs(new_inputs);
84 return lite::RET_OK;
85 }
86 } // namespace opt
87 } // namespace mindspore
88