1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/shape_ops_splitter.h"
17
18 #include <algorithm>
19 #include <vector>
20 #include <set>
21 #include <utility>
22 #include <map>
23 #include "ir/anf.h"
24 #include "utils/anf_utils.h"
25
26 namespace mindspore::graphkernel {
27 namespace {
// Creates a new CNode in the same func graph as `anf_node`, with the same input
// list and scope as the original, and node info copied via CloneCNodeInfo.
// The new node is created under a TraceGuard built from the original's debug info.
// Throws (via MS_EXCEPTION_IF_NULL) if `anf_node` has no owning func graph or is not a CNode.
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
  auto func_graph = anf_node->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
  CNodePtr node = func_graph->NewCNode(cnode->inputs());
  // The clone simply inherits the original's scope. The former
  // `(scope != kDefaultScope) ? scope : kDefaultScope` yielded this same value in both branches.
  node->set_scope(anf_node->scope());
  node->CloneCNodeInfo(cnode);
  return node;
}
40
SplitNode(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)41 void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
42 const auto &index_set = mng->node_users()[node];
43 std::map<AnfNodePtr, std::vector<int>> users_info;
44 (void)std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) {
45 users_info[iter.first].push_back(iter.second);
46 });
47
48 AnfNodePtrList split_nodes;
49 for (size_t i = 0; i < users_info.size(); ++i) {
50 split_nodes.push_back(CloneCNode(node));
51 }
52
53 size_t i = 0;
54 for (const auto &[user, indices] : users_info) {
55 auto user_node = user->cast<CNodePtr>();
56 MS_EXCEPTION_IF_NULL(user_node);
57 for (auto index : indices) {
58 user_node->set_input(IntToSize(index), split_nodes[i]);
59 }
60 i++;
61 }
62 }
63 } // namespace
64
IsMultiUserShapeOps(const AnfNodePtr & node,const FuncGraphManagerPtr & mng) const65 bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) const {
66 auto &users = mng->node_users();
67 std::set<AnfNodePtr> user_set;
68 (void)std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()),
69 [](const std::pair<AnfNodePtr, int> &iter) { return iter.first; });
70 return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(),
71 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
72 }
73
Process(const FuncGraphPtr & func_graph) const74 bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) const {
75 MS_EXCEPTION_IF_NULL(func_graph);
76 auto mng = func_graph->manager();
77 if (mng == nullptr) {
78 mng = Manage(func_graph, true);
79 func_graph->set_manager(mng);
80 }
81 bool changed = false;
82 auto todos = TopoSort(func_graph->get_return());
83 for (const auto &anf_node : todos) {
84 auto node = anf_node->cast<CNodePtr>();
85 if (node != nullptr && IsMultiUserShapeOps(node, mng)) {
86 SplitNode(node, mng);
87 changed = true;
88 }
89 }
90 if (changed) {
91 mng->RemoveRoots();
92 mng->KeepRoots({func_graph});
93 }
94 return changed;
95 }
96
Run(const FuncGraphPtr & func_graph)97 bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
98 MS_EXCEPTION_IF_NULL(func_graph);
99 auto mng = func_graph->manager();
100 if (mng == nullptr) {
101 mng = Manage(func_graph, true);
102 func_graph->set_manager(mng);
103 }
104
105 auto todos = TopoSort(func_graph->get_return());
106 bool result = false;
107 for (const auto &anf_node : todos) {
108 if (AnfUtils::IsGraphKernel(anf_node)) {
109 auto sub_graph = GetCNodeFuncGraph(anf_node);
110 bool changed = false;
111 do {
112 changed = Process(sub_graph);
113 result = result || changed;
114 } while (changed);
115 }
116 }
117
118 if (result) {
119 mng->RemoveRoots();
120 mng->KeepRoots({func_graph});
121 }
122 return result;
123 }
124 } // namespace mindspore::graphkernel
125