• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/shape_ops_splitter.h"
17 #include <algorithm>
18 #include <vector>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <queue>
23 #include <map>
24 #include "frontend/optimizer/irpass.h"
25 #include "pipeline/jit/parse/python_adapter.h"
26 #include "backend/session/anf_runtime_algorithm.h"
27 #include "backend/kernel_compiler/common_utils.h"
28 #include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
29 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
30 #include "debug/anf_ir_dump.h"
31 
32 namespace mindspore {
33 namespace opt {
34 namespace {
// Creates a fresh CNode in the same func graph as `anf_node` that has the same
// inputs and copies over the abstract, forward info, input values, scope,
// kernel info pointer, primal attrs and primal debug infos of the original.
// Note: the kernel info is shared by pointer with the original, not deep-copied.
// Throws (via MS_EXCEPTION_IF_NULL) if `anf_node` has no func graph or is not a CNode.
AnfNodePtr CloneCNode(const AnfNodePtr &anf_node) {
  auto owner_graph = anf_node->func_graph();
  MS_EXCEPTION_IF_NULL(owner_graph);
  auto origin = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin);

  // Attribute the new node's debug info to the original while it is being created.
  TraceGuard guard(std::make_shared<TraceOpt>(origin->debug_info()));
  CNodePtr copy = owner_graph->NewCNode(origin->inputs());

  copy->set_abstract(origin->abstract());
  copy->set_forward(origin->forward().first, origin->forward().second);
  copy->set_inputs_value(origin->inputs_value());
  // The original ternary `(s != kDefaultScope) ? s : kDefaultScope` always evaluated to `s`.
  copy->set_scope(anf_node->scope());
  copy->set_kernel_info(origin->kernel_info_ptr());
  copy->set_primal_attrs(origin->primal_attrs());
  copy->set_primal_debug_infos(origin->primal_debug_infos());
  return copy;
}
52 
SplitNode(const AnfNodePtr & node,const FuncGraphManagerPtr & mng)53 void SplitNode(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
54   const auto &index_set = mng->node_users()[node];
55   std::map<AnfNodePtr, std::vector<int>> users_info;
56   std::for_each(index_set.cbegin(), index_set.cend(), [&users_info](const std::pair<AnfNodePtr, int> &iter) {
57     users_info[iter.first].push_back(iter.second);
58   });
59 
60   AnfNodePtrList split_nodes;
61   for (size_t i = 0; i < users_info.size(); ++i) {
62     split_nodes.push_back(CloneCNode(node));
63   }
64 
65   size_t i = 0;
66   for (auto [user, indices] : users_info) {
67     auto user_node = user->cast<CNodePtr>();
68     MS_EXCEPTION_IF_NULL(user_node);
69     for (auto index : indices) {
70       user_node->set_input(IntToSize(index), split_nodes[i]);
71     }
72     i++;
73   }
74 }
75 }  // namespace
76 
IsMultiUserShapeOps(const AnfNodePtr & node,const FuncGraphManagerPtr & mng) const77 bool ShapeOpsSplitter::IsMultiUserShapeOps(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) const {
78   auto &users = mng->node_users();
79   std::set<AnfNodePtr> user_set;
80   std::transform(users[node].cbegin(), users[node].cend(), std::inserter(user_set, user_set.end()),
81                  [](const std::pair<AnfNodePtr, int> &iter) { return iter.first; });
82   return user_set.size() > 1 && std::any_of(shape_ops_.begin(), shape_ops_.end(),
83                                             [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
84 }
85 
Process(const FuncGraphPtr & func_graph)86 bool ShapeOpsSplitter::Process(const FuncGraphPtr &func_graph) {
87   MS_EXCEPTION_IF_NULL(func_graph);
88   auto mng = func_graph->manager();
89   if (mng == nullptr) {
90     mng = Manage(func_graph, true);
91     func_graph->set_manager(mng);
92   }
93   bool changed = false;
94   auto todos = TopoSort(func_graph->get_return());
95   for (const auto &anf_node : todos) {
96     auto node = anf_node->cast<CNodePtr>();
97     if (node != nullptr && IsMultiUserShapeOps(node, mng)) {
98       SplitNode(node, mng);
99       changed = true;
100     }
101   }
102   if (changed) {
103     mng->RemoveRoots();
104     mng->KeepRoots({func_graph});
105   }
106   return changed;
107 }
108 
Run(const FuncGraphPtr & func_graph)109 bool ShapeOpsSplitter::Run(const FuncGraphPtr &func_graph) {
110   MS_EXCEPTION_IF_NULL(func_graph);
111   auto mng = func_graph->manager();
112   if (mng == nullptr) {
113     mng = Manage(func_graph, true);
114     func_graph->set_manager(mng);
115   }
116 
117   auto todos = TopoSort(func_graph->get_return());
118   bool result = false;
119   for (const auto &anf_node : todos) {
120     if (AnfAlgo::IsGraphKernel(anf_node)) {
121       auto sub_graph = AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
122       bool changed = false;
123       do {
124         changed = Process(sub_graph);
125         result = result || changed;
126       } while (changed);
127     }
128   }
129 
130   if (result) {
131     mng->RemoveRoots();
132     mng->KeepRoots({func_graph});
133   }
134   return result;
135 }
136 }  // namespace opt
137 }  // namespace mindspore
138