• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/kernel_packet/symbol_engine_extender.h"
18 
19 #include <algorithm>
20 #include <memory>
21 #include <functional>
22 #include <vector>
23 #include "utils/anf_utils.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/arithmetic_ops.h"
28 #include "mindspore/core/symbolic_shape/operation_builder.h"
29 #include "include/common/utils/anfalgo.h"
30 #include "backend/common/graph_kernel/core/graph_builder.h"
31 #include "backend/common/graph_kernel/kernel_packet/kernel_packet_engine.h"
32 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
33 #include "backend/common/graph_kernel/graph_kernel_flags.h"
34 #include "include/backend/anf_runtime_algorithm.h"
35 #include "backend/common/pass/insert_type_transform_op.h"
36 
37 namespace mindspore::graphkernel::packet {
38 using symshape::DependOn;
39 
IsHostOp(const AnfNodePtr & node)40 inline bool IsHostOp(const AnfNodePtr &node) {
41   if (!node->isa<CNode>()) {
42     return false;
43   }
44   if (AnfAlgo::IsKernelSelectBackoffOp(node)) {
45     return true;
46   }
47   // ops inserted in InsertTypeTransformOp
48   return opt::IsBackOffOp(node->cast<CNodePtr>());
49 }
50 
IsDeviceOp(const AnfNodePtr & node)51 inline bool IsDeviceOp(const AnfNodePtr &node) {
52   if (!AnfUtils::IsRealKernel(node) || IsHostOp(node) || node->kernel_info() == nullptr) {
53     return false;
54   }
55   auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
56   if (build_info != nullptr && build_info->valid()) {
57     return true;
58   }
59   return false;
60 }
61 
CheckBaseNode(const AnfNodePtr & node)62 bool SymbolEngineExtender::CheckBaseNode(const AnfNodePtr &node) {
63   if (GetCNodePrimitive(node) == nullptr) {
64     return false;
65   }
66   if (!IsDeviceOp(node)) {
67     return false;
68   }
69   auto &flags = GraphKernelFlags::GetInstance();
70   if (!flags.enable_packet_ops_only.empty()) {
71     if (std::find(flags.enable_packet_ops_only.begin(), flags.enable_packet_ops_only.end(),
72                   AnfUtils::GetCNodeName(node)) == flags.enable_packet_ops_only.end()) {
73       return false;
74     }
75   } else if (std::find(flags.disable_packet_ops.begin(), flags.disable_packet_ops.end(),
76                        AnfUtils::GetCNodeName(node)) != flags.disable_packet_ops.end()) {
77     return false;
78   }
79   MS_LOG(DEBUG) << "Search from the base node: " << node->DebugString();
80   return true;
81 }
82 
FindShapeDependHostNode(const CNodePtr & node,HashSet<AnfNodePtr> * visited,HashSet<AnfNodePtr> * valid_nodes)83 void SymbolEngineExtender::FindShapeDependHostNode(const CNodePtr &node, HashSet<AnfNodePtr> *visited,
84                                                    HashSet<AnfNodePtr> *valid_nodes) {
85   if (!visited->insert(node).second) {
86     return;
87   }
88   auto prim = GetCNodePrimitive(node);
89   if (prim == nullptr) {
90     return;
91   }
92   if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
93     return;
94   }
95   auto depends = symshape::GetShapeDepends(prim, node->size() - 1);
96   if (depends.empty()) {
97     MS_LOG(DEBUG) << "The node " << node->fullname_with_scope() << " shape depend status is empty.";
98     return;
99   }
100   MS_LOG(DEBUG) << "Add " << node->fullname_with_scope() << " into candidates.";
101   (void)valid_nodes->insert(node);
102   for (size_t i = 0; i < depends.size(); i++) {
103     auto inp = node->input(i + 1)->cast<CNodePtr>();
104     if (inp == nullptr) {
105       continue;
106     }
107     // assume that building shape for host op does not depend input value again.
108     if (depends[i] == DependOn::kShape && IsHostOp(inp)) {
109       FindShapeDependHostNode(inp, visited, valid_nodes);
110     }
111   }
112 }
113 
FindValueDependNode(const CNodePtr & node,HashSet<AnfNodePtr> * visited,HashSet<AnfNodePtr> * valid_nodes)114 void SymbolEngineExtender::FindValueDependNode(const CNodePtr &node, HashSet<AnfNodePtr> *visited,
115                                                HashSet<AnfNodePtr> *valid_nodes) {
116   if (!visited->insert(node).second) {
117     return;
118   }
119   if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
120     return;
121   }
122   auto prim = GetCNodePrimitive(node);
123   if (prim == nullptr) {
124     return;
125   }
126   auto depends = symshape::GetValueDepends(prim, node->size() - 1);
127   // always try to fuse host op, if the node does not support symbolic value, the whole packet will be dropped.
128   // only fuse device op when it supports building symbolic value.
129   if (depends.empty() && !IsHostOp(node)) {
130     MS_LOG(DEBUG) << "The " << node->fullname_with_scope() << " is not host op and value depend status is empty.";
131     return;
132   }
133   MS_LOG(DEBUG) << "Add " << node->fullname_with_scope() << " into candidates.";
134   (void)valid_nodes->insert(node);
135   for (size_t i = 0; i < depends.size(); i++) {
136     auto inp = node->input(i + 1)->cast<CNodePtr>();
137     if (inp == nullptr) {
138       continue;
139     }
140     if (depends[i] == DependOn::kValue) {
141       FindValueDependNode(inp, visited, valid_nodes);
142     } else if (IsHostOp(inp)) {
143       MS_LOG(DEBUG) << "The input[" << i << "] is host op.";
144       FindShapeDependHostNode(inp, visited, valid_nodes);
145     }
146   }
147 }
148 
FindCandidates(const CNodePtr & base_node)149 AnfNodePtrList SymbolEngineExtender::FindCandidates(const CNodePtr &base_node) {
150   HashSet<AnfNodePtr> visited;
151   HashSet<AnfNodePtr> valid_nodes;
152   auto depends = symshape::GetShapeDepends(GetCNodePrimitive(base_node), base_node->size() - 1);
153   if (depends.empty()) {
154     return {};
155   }
156   // use dfs to find the clusterable nodes.
157   for (size_t i = 0; i < depends.size(); i++) {
158     auto inp = base_node->input(i + 1);
159     if (!inp->isa<CNode>()) {
160       continue;
161     }
162     if (depends[i] == DependOn::kValue) {
163       MS_LOG(DEBUG) << "The input[" << i << "] " << inp->fullname_with_scope() << " is value-depended.";
164       FindValueDependNode(inp->cast<CNodePtr>(), &visited, &valid_nodes);
165     } else if (IsHostOp(inp)) {
166       MS_LOG(DEBUG) << "The input[" << i << "] " << inp->fullname_with_scope()
167                     << " is not value-depended, but it's a host op.";
168       FindValueDependNode(inp->cast<CNodePtr>(), &visited, &valid_nodes);
169     }
170   }
171   if (valid_nodes.empty()) {
172     return {};
173   }
174   (void)valid_nodes.insert(base_node);
175 
176   return TopoSort(base_node, SuccIncoming, [&valid_nodes](const AnfNodePtr &node) -> IncludeType {
177     return valid_nodes.count(node) > 0 ? FOLLOW : EXCLUDE;
178   });
179 }
180 
FindOnlyDependShapeInputs(const FuncGraphPtr & fg) const181 ValuePtr SymbolEngineExtender::FindOnlyDependShapeInputs(const FuncGraphPtr &fg) const {
182   const auto &params = fg->parameters();
183   std::vector<bool> only_depend_shape(params.size(), true);
184   auto engine = fg->symbol_engine();
185   MS_EXCEPTION_IF_NULL(engine);
186   // depend value when infer
187   for (size_t i = 0; i < params.size(); i++) {
188     if (engine->IsDependValue(params[i])) {
189       only_depend_shape[i] = false;
190     }
191   }
192   // depend value when launch kernel
193   auto kernel = fg->output()->cast<CNodePtr>();
194   MS_EXCEPTION_IF_NULL(kernel);
195   for (auto inp : kernel->inputs()) {
196     auto iter = std::find(params.begin(), params.end(), inp);
197     if (iter != params.end()) {
198       only_depend_shape[iter - params.begin()] = false;
199     }
200   }
201   return MakeValue<std::vector<bool>>(only_depend_shape);
202 }
203 
CreatePacketNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & sub_fg,const AnfNodePtrList & inputs)204 CNodePtr CreatePacketNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) {
205   std::vector<AnfNodePtr> fn_inputs;
206   fn_inputs.reserve(inputs.size() + 1);
207   (void)fn_inputs.emplace_back(NewValueNode(sub_fg));
208   (void)fn_inputs.insert(fn_inputs.end(), inputs.cbegin(), inputs.cend());
209   auto new_cnode = main_fg->NewCNode(fn_inputs);
210   new_cnode->set_abstract(sub_fg->output()->abstract());
211   new_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
212   return new_cnode;
213 }
214 
ExtendNode(const AnfNodePtr & node,const FuncGraphPtr & main_fg)215 bool SymbolEngineExtender::ExtendNode(const AnfNodePtr &node, const FuncGraphPtr &main_fg) {
216   ClusterConfig config;
217   config.inline_sub_func_graph = false;
218   config.sort_parameter = true;
219 
220   auto cnode = node->cast<CNodePtr>();
221   MS_EXCEPTION_IF_NULL(cnode);
222 
223   auto nodes = FindCandidates(cnode);
224   if (nodes.size() <= 1) {
225     return false;
226   }
227   MS_LOG(DEBUG) << "The size of list of nodes to be clustered: " << nodes.size();
228   config.only_output_basenode = node;
229   // Check if the symbol engine supports inferring for the graph, if not, skip cluster of this graph
230   auto [fg, inputs, outputs] = BuildSingleGraphFromNodes(nodes, config);
231   if (outputs.size() != 1) {
232     MS_LOG(DEBUG) << "The size of outputs should be 1, but got " << outputs.size();
233     return false;
234   }
235   auto symbol_engine = KernelPacketEngine::Build(fg);
236   if (!symbol_engine->SupportInfer()) {
237     MS_LOG(DEBUG) << "Symbol engine doesn't support infer shape from node: " << node->fullname_with_scope();
238     return false;
239   }
240   auto new_cnode = CreatePacketNode(main_fg, fg, inputs);
241   if (!common::AnfAlgo::IsDynamicShape(new_cnode)) {
242     MS_LOG(DEBUG) << "The node " << new_cnode->DebugString() << " is not dynamic shape";
243     return false;
244   }
245   auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", "extend");
246   fg->set_attr(kAttrKernelPacketNode, MakeValue(fuse_op_name));
247   fg->set_attr("only_depend_shape", FindOnlyDependShapeInputs(fg));
248   new_cnode->AddAttr(kAttrToPrim, MakeValue(AnfUtils::GetCNodeName(node) + "_packet"));
249   MS_LOG(INFO) << "Replace " << node->fullname_with_scope() << " with " << new_cnode->fullname_with_scope();
250   (void)main_fg->manager()->Replace(node, new_cnode);
251   return true;
252 }
253 
Run(const FuncGraphPtr & func_graph)254 bool SymbolEngineExtender::Run(const FuncGraphPtr &func_graph) {
255   // Find the manager for the FuncGraph.
256   auto mng = func_graph->manager();
257   MS_EXCEPTION_IF_NULL(mng);
258   // Find all cnodes.
259   auto cnodes = TopoSort(func_graph->output(), SuccIncoming, [](const AnfNodePtr &node) {
260     if (node->isa<CNode>()) {
261       return FOLLOW;
262     }
263     return EXCLUDE;
264   });
265 
266   bool changed = false;
267   // Process each subgraph.
268   for (auto cnode : cnodes) {
269     if (!CheckBaseNode(cnode)) {
270       continue;
271     }
272     if (ExtendNode(cnode, func_graph)) {
273       changed = true;
274     }
275   }
276   return changed;
277 }
278 }  // namespace mindspore::graphkernel::packet
279