1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
17 #include "utils/ms_context.h"
18 #include "backend/optimizer/common/fusion_id_allocator.h"
19 #include "backend/session/anf_runtime_algorithm.h"
20
21 namespace mindspore {
22 namespace opt {
CheckEltWiseNode(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)23 bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
24 auto manager = kernel_graph.manager();
25 MS_EXCEPTION_IF_NULL(manager);
26 MS_EXCEPTION_IF_NULL(node);
27 if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
28 return false;
29 }
30 auto cnode = node->cast<CNodePtr>();
31 MS_EXCEPTION_IF_NULL(cnode);
32 size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
33 return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
34 AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_USE &&
35 cnode->inputs().size() == ELTWISE_INPUT_SIZE;
36 }
37
CheckDoubleInEltWiseNode(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)38 bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
39 auto manager = kernel_graph.manager();
40 MS_EXCEPTION_IF_NULL(manager);
41 MS_EXCEPTION_IF_NULL(node);
42 if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
43 return false;
44 }
45 auto cnode = node->cast<CNodePtr>();
46 MS_EXCEPTION_IF_NULL(cnode);
47 size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
48 return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
49 AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_USE &&
50 cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
51 }
52
CheckMultiOutputEltWiseNode(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)53 bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
54 auto manager = kernel_graph.manager();
55 MS_EXCEPTION_IF_NULL(manager);
56 MS_EXCEPTION_IF_NULL(node);
57 if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
58 return false;
59 }
60 auto cnode = node->cast<CNodePtr>();
61 MS_EXCEPTION_IF_NULL(cnode);
62 size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
63 return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
64 AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_MULTI_USE &&
65 cnode->inputs().size() == ELTWISE_INPUT_SIZE;
66 }
67
GetNotUpdateStateUserNums(const session::KernelGraph & kernel_graph,const AnfNodePtr & node)68 size_t FusionBasePass::GetNotUpdateStateUserNums(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
69 MS_EXCEPTION_IF_NULL(node);
70 auto manager = kernel_graph.manager();
71 MS_EXCEPTION_IF_NULL(manager);
72 auto user_nodes = manager->node_users()[node];
73 size_t not_updatestate_users = 0;
74 for (auto &user : user_nodes) {
75 auto user_node = user.first;
76 if (!AnfAlgo::CheckPrimitiveType(user_node, prim::kPrimUpdateState)) {
77 not_updatestate_users++;
78 }
79 }
80 return not_updatestate_users;
81 }
82
SetRecordFusionId(const std::unordered_set<AnfNodePtr> & record)83 void FusionBasePass::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
84 auto id = fusion_id_allocator->AllocateFusionId();
85 for (auto node : record) {
86 fusion_id_allocator->SetFusionId(node, id);
87 }
88 }
89
MatchUBFusionPattern(const session::KernelGraph & kernel_graph)90 bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
91 auto manager = kernel_graph.manager();
92 MS_EXCEPTION_IF_NULL(manager);
93 auto return_node = kernel_graph.get_return();
94 MS_EXCEPTION_IF_NULL(return_node);
95 if (return_node->inputs().size() <= 1) {
96 return false;
97 }
98 MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
99 FusedNodeRecord candidate_fusion;
100 MatchSingleFusionPattern(kernel_graph, &candidate_fusion);
101 if (candidate_fusion.empty()) {
102 return false;
103 }
104 MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
105 return true;
106 }
107
Run(const FuncGraphPtr & graph)108 bool FusionBasePass::Run(const FuncGraphPtr &graph) {
109 MS_EXCEPTION_IF_NULL(graph);
110 auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
111 MS_EXCEPTION_IF_NULL(kernel_graph);
112 return MatchUBFusionPattern(*kernel_graph);
113 }
114 } // namespace opt
115 } // namespace mindspore
116