• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define USE_DEPRECATED_API
17 #include "tools/optimizer/fusion/concat_concat_fusion.h"
18 #include <vector>
19 #include "mindspore/core/ops/array_ops.h"
20 #include "ir/func_graph.h"
21 #include "tools/optimizer/common/gllo_utils.h"
22 #include "ops/op_name.h"
23 
24 namespace mindspore {
25 namespace opt {
Run(const FuncGraphPtr & func_graph)26 bool ConcatConcatFusion::Run(const FuncGraphPtr &func_graph) {
27   MS_ASSERT(func_graph != nullptr);
28   auto node_list = TopoSort(func_graph->get_return());
29   for (auto &node : node_list) {
30     if (!utils::isa<CNode>(node) || !CheckPrimitiveType(node, prim::kPrimConcat)) {
31       continue;
32     }
33     auto cnode = node->cast<CNodePtr>();
34     if (Process(func_graph, cnode) != lite::RET_OK) {
35       MS_LOG(ERROR) << "Do ConcatConcatFusion failed, node name is " << node->fullname_with_scope();
36       return false;
37     }
38   }
39   UpdateManager(func_graph);
40   return true;
41 }
42 
Process(const FuncGraphPtr & func_graph,const CNodePtr & cnode)43 int ConcatConcatFusion::Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
44   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
45   auto prim = GetCNodePrimitive(cnode);
46   if (prim == nullptr) {
47     MS_LOG(ERROR) << "Concat's prim is a nullptr, node name is " << cnode->fullname_with_scope();
48     return lite::RET_NULL_PTR;
49   }
50   if (IsQuantParameterNode(prim)) {
51     return lite::RET_OK;
52   }
53   auto axis = prim->GetAttr(ops::kAxis) != nullptr ? GetValue<int64_t>(prim->GetAttr(ops::kAxis)) : 0;
54   auto &inputs = cnode->inputs();
55   std::vector<AnfNodePtr> new_inputs;
56   for (const auto &node : inputs) {
57     if (!utils::isa<CNode>(node) || !CheckPrimitiveType(node, prim::kPrimConcat)) {
58       new_inputs.push_back(node);
59       continue;
60     }
61     auto pre_concat = node->cast<CNodePtr>();
62     if (IsMultiOutputTensors(func_graph, pre_concat)) {
63       new_inputs.push_back(node);
64       continue;
65     }
66     auto pre_prim = GetCNodePrimitive(pre_concat);
67     if (pre_prim == nullptr) {
68       MS_LOG(ERROR) << "Concat's prim is a nullptr, node name is " << pre_concat->fullname_with_scope();
69       return lite::RET_NULL_PTR;
70     }
71     if (IsQuantParameterNode(pre_prim)) {
72       new_inputs.push_back(node);
73       continue;
74     }
75     auto pre_axis = pre_prim->GetAttr(ops::kAxis) != nullptr ? GetValue<int64_t>(pre_prim->GetAttr(ops::kAxis)) : 0;
76     if (pre_axis != axis) {
77       new_inputs.push_back(node);
78       continue;
79     }
80     auto pre_inputs = pre_concat->inputs();
81     new_inputs.insert(new_inputs.end(), pre_inputs.begin() + 1, pre_inputs.end());
82   }
83   cnode->set_inputs(new_inputs);
84   return lite::RET_OK;
85 }
86 }  // namespace opt
87 }  // namespace mindspore
88