1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
17 #include <algorithm>
18 #include <vector>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <queue>
23 #include <map>
24 #include "frontend/optimizer/irpass.h"
25 #include "pipeline/jit/parse/python_adapter.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "backend/kernel_compiler/common_utils.h"
28 #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
29 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
30 #include "debug/anf_ir_dump.h"
31
32 namespace mindspore {
33 namespace opt {
34 namespace {
// Creates a copy of the CNode `anf_node` inside the same func_graph.
//
// The copy gets the same input list as the original and inherits its abstract,
// forward info, inputs_value, scope, kernel info, primal attrs and primal debug
// infos. Its debug info is traced back to the original through a TraceOpt guard
// that is active while the new node is constructed.
// Note: the kernel info is passed as the original's kernel_info_ptr(), i.e. the
// pointer is copied rather than the kernel info object being duplicated.
//
// @param anf_node  node to copy; must be a CNode that belongs to a func_graph.
// @return the newly created CNode.
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
  auto owner_graph = anf_node->func_graph();
  MS_EXCEPTION_IF_NULL(owner_graph);
  auto origin = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin);

  // Keep the trace guard alive for the rest of the function so NewCNode records
  // the original node as the source of the clone's debug info.
  TraceGuard trace_guard(std::make_shared<TraceOpt>(origin->debug_info()));
  CNodePtr clone = owner_graph->NewCNode(origin->inputs());

  clone->set_abstract(origin->abstract());
  const auto &forward_info = origin->forward();
  clone->set_forward(forward_info.first, forward_info.second);
  clone->set_inputs_value(origin->inputs_value());
  // Choosing between scope() and kDefaultScope always yields scope() itself,
  // so the scope is copied directly.
  clone->set_scope(anf_node->scope());
  clone->set_kernel_info(origin->kernel_info_ptr());
  clone->set_primal_attrs(origin->primal_attrs());
  clone->set_primal_debug_infos(origin->primal_debug_infos());
  return clone;
}
52
SplitNode(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)53 void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
54 const auto &index_set = mng->node_users()[node];
55 std::map<AnfNodePtr, std::vector<int>> users_info;
56 std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) {
57 users_info[iter.first].push_back(iter.second);
58 });
59
60 AnfNodePtrList split_nodes;
61 for (size_t i = 0; i < users_info.size(); ++i) {
62 split_nodes.push_back(CloneCNode(node));
63 }
64
65 size_t i = 0;
66 for (auto [user, indices] : users_info) {
67 auto user_node = user->cast<CNodePtr>();
68 MS_EXCEPTION_IF_NULL(user_node);
69 for (auto index : indices) {
70 user_node->set_input(IntToSize(index), split_nodes[i]);
71 }
72 i++;
73 }
74 }
75 } // namespace
76
IsMultiUserShapeOps(const AnfNodePtr & node,const FuncGraphManagerPtr & mng) const77 bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) const {
78 auto &users = mng->node_users();
79 std::set<AnfNodePtr> user_set;
80 std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()),
81 [](const std::pair<AnfNodePtr, int> &iter) { return iter.first; });
82 return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(),
83 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
84 }
85
Process(const FuncGraphPtr & func_graph)86 bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
87 MS_EXCEPTION_IF_NULL(func_graph);
88 auto mng = func_graph->manager();
89 if (mng == nullptr) {
90 mng = Manage(func_graph, true);
91 func_graph->set_manager(mng);
92 }
93 bool changed = false;
94 auto todos = TopoSort(func_graph->get_return());
95 for (const auto &anf_node : todos) {
96 auto node = anf_node->cast<CNodePtr>();
97 if (node != nullptr && IsMultiUserShapeOps(node, mng)) {
98 SplitNode(node, mng);
99 changed = true;
100 }
101 }
102 if (changed) {
103 mng->RemoveRoots();
104 mng->KeepRoots({func_graph});
105 }
106 return changed;
107 }
108
Run(const FuncGraphPtr & func_graph)109 bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
110 MS_EXCEPTION_IF_NULL(func_graph);
111 auto mng = func_graph->manager();
112 if (mng == nullptr) {
113 mng = Manage(func_graph, true);
114 func_graph->set_manager(mng);
115 }
116
117 auto todos = TopoSort(func_graph->get_return());
118 bool result = false;
119 for (const auto &anf_node : todos) {
120 if (AnfAlgo::IsGraphKernel(anf_node)) {
121 auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
122 bool changed = false;
123 do {
124 changed = Process(sub_graph);
125 result = result || changed;
126 } while (changed);
127 }
128 }
129
130 if (result) {
131 mng->RemoveRoots();
132 mng->KeepRoots({func_graph});
133 }
134 return result;
135 }
136 } // namespace opt
137 } // namespace mindspore
138