• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/shape_ops_splitter.h"
17 
18 #include <algorithm>
19 #include <vector>
20 #include <set>
21 #include <utility>
22 #include <map>
23 #include "ir/anf.h"
24 #include "utils/anf_utils.h"
25 
26 namespace mindspore::graphkernel {
27 namespace {
// Creates a copy of the CNode `anf_node` inside the FuncGraph that owns it.
//
// The copy is built from the original's input list, keeps the original's scope, and then
// receives the original's CNode info through CNode::CloneCNodeInfo. A TraceGuard carrying the
// original's debug info stays alive for the whole construction, so the copy's debug info is
// derived from the original (presumably for IR dumps / error tracing — confirm).
//
// Throws (via MS_EXCEPTION_IF_NULL) if `anf_node` has no owning graph or is not a CNode.
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
  auto owner_graph = anf_node->func_graph();
  MS_EXCEPTION_IF_NULL(owner_graph);
  auto origin = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin);

  // Must outlive NewCNode (and the info copy below), hence declared before them at function scope.
  TraceGuard trace_guard(std::make_shared<TraceOpt>(origin->debug_info()));
  CNodePtr copy = owner_graph->NewCNode(origin->inputs());
  // The copy always inherits the original's scope, kDefaultScope included.
  copy->set_scope(anf_node->scope());
  copy->CloneCNodeInfo(origin);
  return copy;
}
40 
SplitNode(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)41 void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
42   const auto &index_set = mng->node_users()[node];
43   std::map<AnfNodePtr, std::vector<int>> users_info;
44   (void)std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) {
45     users_info[iter.first].push_back(iter.second);
46   });
47 
48   AnfNodePtrList split_nodes;
49   for (size_t i = 0; i < users_info.size(); ++i) {
50     split_nodes.push_back(CloneCNode(node));
51   }
52 
53   size_t i = 0;
54   for (const auto &[user, indices] : users_info) {
55     auto user_node = user->cast<CNodePtr>();
56     MS_EXCEPTION_IF_NULL(user_node);
57     for (auto index : indices) {
58       user_node->set_input(IntToSize(index), split_nodes[i]);
59     }
60     i++;
61   }
62 }
63 }  // namespace
64 
IsMultiUserShapeOps(const AnfNodePtr & node,const FuncGraphManagerPtr & mng) const65 bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) const {
66   auto &users = mng->node_users();
67   std::set<AnfNodePtr> user_set;
68   (void)std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()),
69                        [](const std::pair<AnfNodePtr, int> &iter) { return iter.first; });
70   return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(),
71                                             [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
72 }
73 
Process(const FuncGraphPtr & func_graph) const74 bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) const {
75   MS_EXCEPTION_IF_NULL(func_graph);
76   auto mng = func_graph->manager();
77   if (mng == nullptr) {
78     mng = Manage(func_graph, true);
79     func_graph->set_manager(mng);
80   }
81   bool changed = false;
82   auto todos = TopoSort(func_graph->get_return());
83   for (const auto &anf_node : todos) {
84     auto node = anf_node->cast<CNodePtr>();
85     if (node != nullptr && IsMultiUserShapeOps(node, mng)) {
86       SplitNode(node, mng);
87       changed = true;
88     }
89   }
90   if (changed) {
91     mng->RemoveRoots();
92     mng->KeepRoots({func_graph});
93   }
94   return changed;
95 }
96 
Run(const FuncGraphPtr & func_graph)97 bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
98   MS_EXCEPTION_IF_NULL(func_graph);
99   auto mng = func_graph->manager();
100   if (mng == nullptr) {
101     mng = Manage(func_graph, true);
102     func_graph->set_manager(mng);
103   }
104 
105   auto todos = TopoSort(func_graph->get_return());
106   bool result = false;
107   for (const auto &anf_node : todos) {
108     if (AnfUtils::IsGraphKernel(anf_node)) {
109       auto sub_graph = GetCNodeFuncGraph(anf_node);
110       bool changed = false;
111       do {
112         changed = Process(sub_graph);
113         result = result || changed;
114       } while (changed);
115     }
116   }
117 
118   if (result) {
119     mng->RemoveRoots();
120     mng->KeepRoots({func_graph});
121   }
122   return result;
123 }
124 }  // namespace mindspore::graphkernel
125