1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/kernel_packet/symbol_engine_extender.h"
18
19 #include <algorithm>
20 #include <memory>
21 #include <functional>
22 #include <vector>
23 #include "utils/anf_utils.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/arithmetic_ops.h"
28 #include "mindspore/core/symbolic_shape/operation_builder.h"
29 #include "include/common/utils/anfalgo.h"
30 #include "backend/common/graph_kernel/core/graph_builder.h"
31 #include "backend/common/graph_kernel/kernel_packet/kernel_packet_engine.h"
32 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
33 #include "backend/common/graph_kernel/graph_kernel_flags.h"
34 #include "include/backend/anf_runtime_algorithm.h"
35 #include "backend/common/pass/insert_type_transform_op.h"
36
37 namespace mindspore::graphkernel::packet {
38 using symshape::DependOn;
39
IsHostOp(const AnfNodePtr & node)40 inline bool IsHostOp(const AnfNodePtr &node) {
41 if (!node->isa<CNode>()) {
42 return false;
43 }
44 if (AnfAlgo::IsKernelSelectBackoffOp(node)) {
45 return true;
46 }
47 // ops inserted in InsertTypeTransformOp
48 return opt::IsBackOffOp(node->cast<CNodePtr>());
49 }
50
IsDeviceOp(const AnfNodePtr & node)51 inline bool IsDeviceOp(const AnfNodePtr &node) {
52 if (!AnfUtils::IsRealKernel(node) || IsHostOp(node) || node->kernel_info() == nullptr) {
53 return false;
54 }
55 auto build_info = AnfAlgo::GetSelectKernelBuildInfo(node);
56 if (build_info != nullptr && build_info->valid()) {
57 return true;
58 }
59 return false;
60 }
61
CheckBaseNode(const AnfNodePtr & node)62 bool SymbolEngineExtender::CheckBaseNode(const AnfNodePtr &node) {
63 if (GetCNodePrimitive(node) == nullptr) {
64 return false;
65 }
66 if (!IsDeviceOp(node)) {
67 return false;
68 }
69 auto &flags = GraphKernelFlags::GetInstance();
70 if (!flags.enable_packet_ops_only.empty()) {
71 if (std::find(flags.enable_packet_ops_only.begin(), flags.enable_packet_ops_only.end(),
72 AnfUtils::GetCNodeName(node)) == flags.enable_packet_ops_only.end()) {
73 return false;
74 }
75 } else if (std::find(flags.disable_packet_ops.begin(), flags.disable_packet_ops.end(),
76 AnfUtils::GetCNodeName(node)) != flags.disable_packet_ops.end()) {
77 return false;
78 }
79 MS_LOG(DEBUG) << "Search from the base node: " << node->DebugString();
80 return true;
81 }
82
FindShapeDependHostNode(const CNodePtr & node,HashSet<AnfNodePtr> * visited,HashSet<AnfNodePtr> * valid_nodes)83 void SymbolEngineExtender::FindShapeDependHostNode(const CNodePtr &node, HashSet<AnfNodePtr> *visited,
84 HashSet<AnfNodePtr> *valid_nodes) {
85 if (!visited->insert(node).second) {
86 return;
87 }
88 auto prim = GetCNodePrimitive(node);
89 if (prim == nullptr) {
90 return;
91 }
92 if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
93 return;
94 }
95 auto depends = symshape::GetShapeDepends(prim, node->size() - 1);
96 if (depends.empty()) {
97 MS_LOG(DEBUG) << "The node " << node->fullname_with_scope() << " shape depend status is empty.";
98 return;
99 }
100 MS_LOG(DEBUG) << "Add " << node->fullname_with_scope() << " into candidates.";
101 (void)valid_nodes->insert(node);
102 for (size_t i = 0; i < depends.size(); i++) {
103 auto inp = node->input(i + 1)->cast<CNodePtr>();
104 if (inp == nullptr) {
105 continue;
106 }
107 // assume that building shape for host op does not depend input value again.
108 if (depends[i] == DependOn::kShape && IsHostOp(inp)) {
109 FindShapeDependHostNode(inp, visited, valid_nodes);
110 }
111 }
112 }
113
FindValueDependNode(const CNodePtr & node,HashSet<AnfNodePtr> * visited,HashSet<AnfNodePtr> * valid_nodes)114 void SymbolEngineExtender::FindValueDependNode(const CNodePtr &node, HashSet<AnfNodePtr> *visited,
115 HashSet<AnfNodePtr> *valid_nodes) {
116 if (!visited->insert(node).second) {
117 return;
118 }
119 if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
120 return;
121 }
122 auto prim = GetCNodePrimitive(node);
123 if (prim == nullptr) {
124 return;
125 }
126 auto depends = symshape::GetValueDepends(prim, node->size() - 1);
127 // always try to fuse host op, if the node does not support symbolic value, the whole packet will be dropped.
128 // only fuse device op when it supports building symbolic value.
129 if (depends.empty() && !IsHostOp(node)) {
130 MS_LOG(DEBUG) << "The " << node->fullname_with_scope() << " is not host op and value depend status is empty.";
131 return;
132 }
133 MS_LOG(DEBUG) << "Add " << node->fullname_with_scope() << " into candidates.";
134 (void)valid_nodes->insert(node);
135 for (size_t i = 0; i < depends.size(); i++) {
136 auto inp = node->input(i + 1)->cast<CNodePtr>();
137 if (inp == nullptr) {
138 continue;
139 }
140 if (depends[i] == DependOn::kValue) {
141 FindValueDependNode(inp, visited, valid_nodes);
142 } else if (IsHostOp(inp)) {
143 MS_LOG(DEBUG) << "The input[" << i << "] is host op.";
144 FindShapeDependHostNode(inp, visited, valid_nodes);
145 }
146 }
147 }
148
FindCandidates(const CNodePtr & base_node)149 AnfNodePtrList SymbolEngineExtender::FindCandidates(const CNodePtr &base_node) {
150 HashSet<AnfNodePtr> visited;
151 HashSet<AnfNodePtr> valid_nodes;
152 auto depends = symshape::GetShapeDepends(GetCNodePrimitive(base_node), base_node->size() - 1);
153 if (depends.empty()) {
154 return {};
155 }
156 // use dfs to find the clusterable nodes.
157 for (size_t i = 0; i < depends.size(); i++) {
158 auto inp = base_node->input(i + 1);
159 if (!inp->isa<CNode>()) {
160 continue;
161 }
162 if (depends[i] == DependOn::kValue) {
163 MS_LOG(DEBUG) << "The input[" << i << "] " << inp->fullname_with_scope() << " is value-depended.";
164 FindValueDependNode(inp->cast<CNodePtr>(), &visited, &valid_nodes);
165 } else if (IsHostOp(inp)) {
166 MS_LOG(DEBUG) << "The input[" << i << "] " << inp->fullname_with_scope()
167 << " is not value-depended, but it's a host op.";
168 FindValueDependNode(inp->cast<CNodePtr>(), &visited, &valid_nodes);
169 }
170 }
171 if (valid_nodes.empty()) {
172 return {};
173 }
174 (void)valid_nodes.insert(base_node);
175
176 return TopoSort(base_node, SuccIncoming, [&valid_nodes](const AnfNodePtr &node) -> IncludeType {
177 return valid_nodes.count(node) > 0 ? FOLLOW : EXCLUDE;
178 });
179 }
180
FindOnlyDependShapeInputs(const FuncGraphPtr & fg) const181 ValuePtr SymbolEngineExtender::FindOnlyDependShapeInputs(const FuncGraphPtr &fg) const {
182 const auto ¶ms = fg->parameters();
183 std::vector<bool> only_depend_shape(params.size(), true);
184 auto engine = fg->symbol_engine();
185 MS_EXCEPTION_IF_NULL(engine);
186 // depend value when infer
187 for (size_t i = 0; i < params.size(); i++) {
188 if (engine->IsDependValue(params[i])) {
189 only_depend_shape[i] = false;
190 }
191 }
192 // depend value when launch kernel
193 auto kernel = fg->output()->cast<CNodePtr>();
194 MS_EXCEPTION_IF_NULL(kernel);
195 for (auto inp : kernel->inputs()) {
196 auto iter = std::find(params.begin(), params.end(), inp);
197 if (iter != params.end()) {
198 only_depend_shape[iter - params.begin()] = false;
199 }
200 }
201 return MakeValue<std::vector<bool>>(only_depend_shape);
202 }
203
CreatePacketNode(const FuncGraphPtr & main_fg,const FuncGraphPtr & sub_fg,const AnfNodePtrList & inputs)204 CNodePtr CreatePacketNode(const FuncGraphPtr &main_fg, const FuncGraphPtr &sub_fg, const AnfNodePtrList &inputs) {
205 std::vector<AnfNodePtr> fn_inputs;
206 fn_inputs.reserve(inputs.size() + 1);
207 (void)fn_inputs.emplace_back(NewValueNode(sub_fg));
208 (void)fn_inputs.insert(fn_inputs.end(), inputs.cbegin(), inputs.cend());
209 auto new_cnode = main_fg->NewCNode(fn_inputs);
210 new_cnode->set_abstract(sub_fg->output()->abstract());
211 new_cnode->set_kernel_info(std::make_shared<device::KernelInfo>());
212 return new_cnode;
213 }
214
ExtendNode(const AnfNodePtr & node,const FuncGraphPtr & main_fg)215 bool SymbolEngineExtender::ExtendNode(const AnfNodePtr &node, const FuncGraphPtr &main_fg) {
216 ClusterConfig config;
217 config.inline_sub_func_graph = false;
218 config.sort_parameter = true;
219
220 auto cnode = node->cast<CNodePtr>();
221 MS_EXCEPTION_IF_NULL(cnode);
222
223 auto nodes = FindCandidates(cnode);
224 if (nodes.size() <= 1) {
225 return false;
226 }
227 MS_LOG(DEBUG) << "The size of list of nodes to be clustered: " << nodes.size();
228 config.only_output_basenode = node;
229 // Check if the symbol engine supports inferring for the graph, if not, skip cluster of this graph
230 auto [fg, inputs, outputs] = BuildSingleGraphFromNodes(nodes, config);
231 if (outputs.size() != 1) {
232 MS_LOG(DEBUG) << "The size of outputs should be 1, but got " << outputs.size();
233 return false;
234 }
235 auto symbol_engine = KernelPacketEngine::Build(fg);
236 if (!symbol_engine->SupportInfer()) {
237 MS_LOG(DEBUG) << "Symbol engine doesn't support infer shape from node: " << node->fullname_with_scope();
238 return false;
239 }
240 auto new_cnode = CreatePacketNode(main_fg, fg, inputs);
241 if (!common::AnfAlgo::IsDynamicShape(new_cnode)) {
242 MS_LOG(DEBUG) << "The node " << new_cnode->DebugString() << " is not dynamic shape";
243 return false;
244 }
245 auto fuse_op_name = GkUtils::ExtractGraphKernelName(nodes, "", "extend");
246 fg->set_attr(kAttrKernelPacketNode, MakeValue(fuse_op_name));
247 fg->set_attr("only_depend_shape", FindOnlyDependShapeInputs(fg));
248 new_cnode->AddAttr(kAttrToPrim, MakeValue(AnfUtils::GetCNodeName(node) + "_packet"));
249 MS_LOG(INFO) << "Replace " << node->fullname_with_scope() << " with " << new_cnode->fullname_with_scope();
250 (void)main_fg->manager()->Replace(node, new_cnode);
251 return true;
252 }
253
Run(const FuncGraphPtr & func_graph)254 bool SymbolEngineExtender::Run(const FuncGraphPtr &func_graph) {
255 // Find the manager for the FuncGraph.
256 auto mng = func_graph->manager();
257 MS_EXCEPTION_IF_NULL(mng);
258 // Find all cnodes.
259 auto cnodes = TopoSort(func_graph->output(), SuccIncoming, [](const AnfNodePtr &node) {
260 if (node->isa<CNode>()) {
261 return FOLLOW;
262 }
263 return EXCLUDE;
264 });
265
266 bool changed = false;
267 // Process each subgraph.
268 for (auto cnode : cnodes) {
269 if (!CheckBaseNode(cnode)) {
270 continue;
271 }
272 if (ExtendNode(cnode, func_graph)) {
273 changed = true;
274 }
275 }
276 return changed;
277 }
278 } // namespace mindspore::graphkernel::packet
279