1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/core/graph_kernel_splitter.h"
17 #include <algorithm>
18 #include <vector>
19 #include <string>
20 #include <utility>
21 #include <queue>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "ir/anf.h"
25 #include "utils/anf_utils.h"
26 #include "utils/hash_map.h"
27 #include "utils/hash_set.h"
28 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
29 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
30 #include "backend/common/graph_kernel/split_model/split_model_factory.h"
31
32 namespace mindspore::graphkernel {
33 namespace {
TraverseFuncGraphFromCNode(const CNodePtr & cnode,const std::function<void (AnfNodePtr &)> & callback)34 void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
35 mindspore::HashSet<AnfNodePtr> visited;
36 std::queue<AnfNodePtr> que;
37 que.push(cnode);
38 (void)visited.insert(cnode);
39 while (!que.empty()) {
40 auto ft_node = que.front();
41 que.pop();
42 callback(ft_node);
43 auto ft_cnode = ft_node->cast<CNodePtr>();
44 if (ft_cnode == nullptr) {
45 continue;
46 }
47 for (const auto &in_node : ft_cnode->inputs()) {
48 if (visited.count(in_node) == 0) {
49 que.push(in_node);
50 (void)visited.insert(in_node);
51 }
52 }
53 }
54 }
55
56 // Visited each AnfNode once, use callback to do the job on AnfNode
TraverseFuncGraph(const FuncGraphPtr & root,const std::function<void (AnfNodePtr &)> & callback)57 inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) {
58 TraverseFuncGraphFromCNode(root->get_return(), callback);
59 }
60
61 class Area {
62 public:
Area(const AnfNodePtrList & anf_arr)63 explicit Area(const AnfNodePtrList &anf_arr) {
64 nodes_.insert(anf_arr.cbegin(), anf_arr.cend());
65 for (auto &node : anf_arr) {
66 auto cnode = node->cast<CNodePtr>();
67 if (cnode == nullptr) {
68 continue;
69 }
70 const auto &inputs = cnode->inputs();
71 if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
72 (void)spy_cnodes_.emplace_back(node);
73 }
74 }
75 }
76
77 ~Area() = default;
78
79 // Set the external inputs of spy as a Parameter.
CreateParameters(const FuncGraphPtr & func_graph,mindspore::HashMap<ParameterPtr,AnfNodePtr> * param_node_map)80 void CreateParameters(const FuncGraphPtr &func_graph, mindspore::HashMap<ParameterPtr, AnfNodePtr> *param_node_map) {
81 mindspore::HashMap<AnfNodePtr, ParameterPtr> node_param_map;
82 for (auto node : this->spy_cnodes_) {
83 auto cnode = node->cast<CNodePtr>();
84 MS_EXCEPTION_IF_NULL(cnode);
85 for (size_t i = 1; i < cnode->size(); ++i) {
86 AnfNodePtr in_node = cnode->input(i);
87 if (!IsExternalCNode(in_node)) {
88 continue;
89 }
90 auto it = node_param_map.find(in_node);
91 if (it == node_param_map.end()) {
92 auto new_param = std::make_shared<Parameter>(func_graph);
93 new_param->set_abstract(in_node->abstract());
94 func_graph->add_parameter(new_param);
95 (void)node_param_map.emplace(in_node, new_param);
96 cnode->set_input(i, new_param);
97 } else {
98 cnode->set_input(i, it->second);
99 }
100 }
101 }
102 this->spy_cnodes_.clear(); // spy list is not useful anymore
103 for (auto &&elem : node_param_map) {
104 (void)param_node_map->emplace(elem.second, elem.first);
105 }
106 return;
107 }
108
109 // Make a return node for traitor nodes.
CreateReturnNode(const FuncGraphPtr & func_graph,mindspore::HashMap<AnfNodePtr,size_t> * tuple_node_index)110 void CreateReturnNode(const FuncGraphPtr &func_graph, mindspore::HashMap<AnfNodePtr, size_t> *tuple_node_index) {
111 // If there's no traitor in the area, it means that this area is the last part
112 // of the original FuncGraph, it already contains the original Return node.
113 if (traitor_nodes_.empty()) {
114 for (auto &node : nodes_) {
115 if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
116 func_graph->set_return(node->cast<CNodePtr>());
117 node->set_func_graph(func_graph);
118 return;
119 }
120 }
121 MS_LOG(ERROR) << "Cannot find the return node in " << func_graph->ToString();
122 return;
123 }
124 AnfNodePtrList return_inputs = {NewValueNode(prim::kPrimReturn)};
125 if (traitor_nodes_.size() > 1) {
126 // The area has multiple output, it's necessary to make a tuple for them.
127 AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
128 AbstractBasePtrList abstracts;
129 size_t i = 0;
130 for (auto &traitor : traitor_nodes_) {
131 (void)tuple_node_index->emplace(traitor, i++);
132 (void)maketuple_inputs.emplace_back(traitor);
133 (void)abstracts.emplace_back(traitor->abstract());
134 }
135 auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
136 maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
137 (void)nodes_.insert(maketuple_node);
138 (void)return_inputs.emplace_back(maketuple_node);
139 } else {
140 (void)return_inputs.emplace_back(traitor_nodes_[0]);
141 }
142 auto return_node = func_graph->NewCNode(return_inputs);
143 return_node->set_abstract(return_inputs.back()->abstract());
144 func_graph->set_return(return_node);
145 (void)nodes_.insert(return_node);
146 traitor_nodes_.clear(); // traitor list is not useful anymore
147 return;
148 }
149
AddTraitor(const AnfNodePtr & node)150 void AddTraitor(const AnfNodePtr &node) {
151 if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
152 (void)traitor_nodes_.emplace_back(node);
153 }
154 }
155
nodes() const156 const mindspore::HashSet<AnfNodePtr> &nodes() const { return nodes_; }
spy_cnodes() const157 const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; }
158
159 private:
160 // This is a CNode that does not belong to this area.
IsExternalCNode(const AnfNodePtr & node) const161 bool IsExternalCNode(const AnfNodePtr &node) const { return node->isa<CNode>() && this->nodes_.count(node) == 0; }
162
163 // nodes in this area
164 mindspore::HashSet<AnfNodePtr> nodes_;
165 // if a node's output is used by other Area, it's a traitor
166 std::vector<AnfNodePtr> traitor_nodes_;
167 // if a node use other Area's output, it's a spy
168 std::vector<AnfNodePtr> spy_cnodes_;
169 };
170
171 class AreaGraph {
172 public:
173 using AreaGraphPtr = std::shared_ptr<AreaGraph>;
174
175 // Build an area graph to maintain the relation between areas.
176 // Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
BuildAreaGraph(const std::vector<AnfNodePtrList> & node_groups)177 static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
178 auto area_graph = std::make_shared<AreaGraph>(node_groups);
179 if (area_graph == nullptr) {
180 return nullptr;
181 }
182 if (!area_graph->TopoSort()) {
183 MS_LOG(WARNING) << "The groups have a cycle. The first node is " << node_groups[0][0]->fullname_with_scope();
184 return nullptr;
185 }
186 return area_graph;
187 }
188
189 // Split the graph to multiple areas, and reconnect the edges between the areas.
190 // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
191 // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
SplitGraph(const FuncGraphPtr & main_func_graph,std::vector<CNodePtr> * main_cnodes,std::vector<size_t> * cnode_group_id,const std::function<void (const Area &)> & expand_callback)192 void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
193 std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) {
194 main_cnodes->clear();
195 main_cnodes->resize(areas_.size(), nullptr);
196
197 for (auto &area : this->areas_) {
198 expand_callback(area);
199 }
200
201 for (auto index : topo_order_) {
202 auto ¤t_area = areas_[index];
203 auto sub_func_graph = std::make_shared<FuncGraph>();
204 mindspore::HashMap<ParameterPtr, AnfNodePtr> param_node_map;
205
206 current_area.CreateParameters(sub_func_graph, ¶m_node_map);
207 current_area.CreateReturnNode(sub_func_graph, &node_index_in_returned_tuple_);
208 auto new_main_cnode = this->CreateMainCNode(main_func_graph, sub_func_graph, *main_cnodes, param_node_map);
209 (*main_cnodes)[index] = new_main_cnode;
210 }
211
212 SortCNodes(main_cnodes);
213 *cnode_group_id = std::move(topo_order_); // The topo_order is not used anymore.
214 return;
215 }
216
AreaGraph(const std::vector<AnfNodePtrList> & node_groups)217 explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
218 for (size_t i = 0; i < node_groups.size(); ++i) {
219 (void)areas_.emplace_back(node_groups[i]);
220 for (const auto &node : node_groups[i]) {
221 node_area_map_[node] = i;
222 }
223 }
224 for (auto &area : areas_) {
225 for (auto &spy : area.spy_cnodes()) {
226 auto cnode = spy->cast<CNodePtr>();
227 MS_EXCEPTION_IF_NULL(cnode);
228 size_t v = node_area_map_[spy];
229 for (auto &in_node : cnode->inputs()) {
230 if (!in_node->isa<CNode>()) {
231 continue;
232 }
233 // area edge u -> v
234 size_t u = node_area_map_[in_node];
235 if (u == v) {
236 continue;
237 }
238 areas_[u].AddTraitor(in_node);
239 if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
240 (void)edge_prev_[v].emplace_back(u);
241 }
242 }
243 }
244 }
245 }
246 ~AreaGraph() = default;
247
248 private:
249 // Topological sort the areas.
TopoSort()250 bool TopoSort() {
251 std::vector<int> out_degree(edge_prev_.size(), 0);
252 std::queue<size_t> que;
253 for (auto &prev : edge_prev_) {
254 for (size_t i : prev) {
255 out_degree[i]++;
256 }
257 }
258 for (size_t i = 0; i < out_degree.size(); ++i) {
259 if (out_degree[i] == 0) {
260 que.push(i);
261 }
262 }
263 while (!que.empty()) {
264 size_t u = que.front();
265 que.pop();
266 (void)topo_order_.emplace_back(u);
267 for (size_t i : edge_prev_[u]) {
268 if (--out_degree[i] == 0) {
269 que.push(i);
270 }
271 }
272 }
273 std::reverse(topo_order_.begin(), topo_order_.end());
274 return topo_order_.size() == areas_.size();
275 }
276
277 // Make a CNode in main graph to hold the sub_func_graph.
CreateMainCNode(const FuncGraphPtr & main_func_graph,const FuncGraphPtr & sub_func_graph,const std::vector<CNodePtr> & main_cnodes,const mindspore::HashMap<ParameterPtr,AnfNodePtr> & param_node_map)278 CNodePtr CreateMainCNode(const FuncGraphPtr &main_func_graph, const FuncGraphPtr &sub_func_graph,
279 const std::vector<CNodePtr> &main_cnodes,
280 const mindspore::HashMap<ParameterPtr, AnfNodePtr> ¶m_node_map) {
281 TraceGuard guard(std::make_shared<TraceOpt>(sub_func_graph->debug_info()));
282 AnfNodePtrList main_cnode_inputs = {NewValueNode(sub_func_graph)};
283 for (const auto ¶m : sub_func_graph->parameters()) {
284 // assert the param exists.
285 const auto &input_node = param_node_map.find(param->cast<ParameterPtr>())->second;
286 size_t input_area = node_area_map_[input_node];
287 // if the input node is in a tuple, then we need to create a GetItem fot it.
288 if (node_index_in_returned_tuple_.count(input_node) != 0) {
289 auto idx_val = SizeToLong(node_index_in_returned_tuple_[input_node]);
290 auto idx = NewValueNode(idx_val);
291 idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
292 AnfNodePtrList getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), main_cnodes[input_area], idx};
293 TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info()));
294 auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
295 auto abs_tuple = dyn_cast<abstract::AbstractTuple>(main_cnodes[input_area]->abstract());
296 if (idx_val < SizeToLong(abs_tuple->size())) {
297 getitem_node->set_abstract(abs_tuple->elements()[LongToSize(idx_val)]);
298 } else {
299 getitem_node->set_abstract(main_cnodes[input_area]->abstract());
300 }
301 (void)main_cnode_inputs.emplace_back(getitem_node);
302 } else {
303 (void)main_cnode_inputs.emplace_back(main_cnodes[input_area]);
304 }
305 }
306 auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
307 new_main_cnode->set_abstract(sub_func_graph->output()->abstract());
308 return new_main_cnode;
309 }
310
SortCNodes(std::vector<CNodePtr> * main_cnodes) const311 void SortCNodes(std::vector<CNodePtr> *main_cnodes) const {
312 std::vector<CNodePtr> main_cnodes_sorted;
313 (void)std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
314 [main_cnodes](size_t index) { return main_cnodes->at(index); });
315 *main_cnodes = std::move(main_cnodes_sorted);
316 }
317
318 // Areas in this subgraph
319 std::vector<Area> areas_;
320 // Adjacency table of areas
321 std::vector<std::vector<size_t>> edge_prev_;
322 // Topological order of areas
323 std::vector<size_t> topo_order_;
324 // Map AnfNode to Area id
325 mindspore::HashMap<AnfNodePtr, size_t> node_area_map_;
326 // Map the nodes to their index if there are multiple value in an area
327 mindspore::HashMap<AnfNodePtr, size_t> node_index_in_returned_tuple_;
328 };
329
330 class Splitter {
331 public:
332 using SplitterPtr = std::shared_ptr<Splitter>;
333
Split()334 bool Split() {
335 GenParamMap();
336 auto ori_sub_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
337 if (!split_schemer_->Split(ori_sub_func_graph)) {
338 return false;
339 }
340
341 auto area_graph = AreaGraph::BuildAreaGraph(split_schemer_->split_plan());
342 if (area_graph == nullptr) {
343 return false;
344 }
345
346 // The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
347 std::vector<size_t> cnodes_group_id;
348 area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id,
349 [this](const Area &area) { this->AreaExpand(area); });
350
351 RebuildGraph(cnodes_group_id);
352
353 return true;
354 }
355
MakeSplitter(const CNodePtr & main_cnode,const SplitSchemerPtr & split_schemer)356 static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
357 MS_EXCEPTION_IF_NULL(main_cnode);
358 MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
359 MS_EXCEPTION_IF_NULL(split_schemer);
360 return std::make_shared<Splitter>(main_cnode, split_schemer);
361 }
362
Splitter(const CNodePtr & main_cnode,const SplitSchemerPtr & split_schemer)363 Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer)
364 : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
365 ~Splitter() = default;
366
367 private:
368 // Maintain new subgraphs in main graph.
RebuildGraph(const std::vector<size_t> & cnodes_group_id)369 void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
370 BindFuncGraph();
371 RecoverParameter();
372 SetSplitNodeName(cnodes_group_id);
373 ConnectToMainGraph(cnodes_group_id);
374 UpdateMainNodesKernelInfo();
375 }
376
377 // Rebind nodes to its new sub_func_graph
BindFuncGraph() const378 void BindFuncGraph() const {
379 for (const auto &cnode : new_subgraph_cnodes_) {
380 auto sub_func_graph = GetCNodeFuncGraph(cnode);
381 auto callback = [&sub_func_graph](const AnfNodePtr &node) {
382 if (!node->isa<ValueNode>()) {
383 node->set_func_graph(sub_func_graph);
384 }
385 };
386 TraverseFuncGraph(sub_func_graph, callback);
387 }
388 }
389
390 // Recover the original subgraph's parameter if the new graph needs it
RecoverParameter()391 void RecoverParameter() {
392 for (const auto &cnode : new_subgraph_cnodes_) {
393 auto sub_func_graph = GetCNodeFuncGraph(cnode);
394 auto callback = [&cnode, &sub_func_graph, this](const AnfNodePtr &node) {
395 auto param = node->cast<ParameterPtr>();
396 if (param == nullptr) {
397 return;
398 }
399 auto it = this->param_to_main_graph_node_map_.find(param);
400 if (it != this->param_to_main_graph_node_map_.end()) {
401 auto input = it->second;
402 cnode->add_input(input);
403 sub_func_graph->add_parameter(param);
404 // Avoid repeating parameters.
405 (void)this->param_to_main_graph_node_map_.erase(it);
406 }
407 };
408 TraverseFuncGraph(sub_func_graph, callback);
409 }
410 }
411
InlineSubFuncGraph(const CNodePtr & main_node)412 CNodePtr InlineSubFuncGraph(const CNodePtr &main_node) {
413 auto func_graph = GetCNodeFuncGraph(main_node);
414 const auto &inputs = main_node->inputs();
415 auto output = func_graph->output()->cast<CNodePtr>();
416 MS_EXCEPTION_IF_NULL(output);
417 const auto ¶meters = func_graph->parameters();
418 mindspore::HashMap<AnfNodePtr, AnfNodePtr> param_input;
419 for (size_t i = 0; i < parameters.size(); ++i) {
420 param_input[parameters[i]] = inputs[i + 1];
421 }
422 auto sub_nodes = TopoSort(func_graph->get_return());
423 for (auto node : sub_nodes) {
424 if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
425 cnode->set_func_graph(main_func_graph_);
426 for (size_t i = 1; i < cnode->size(); ++i) {
427 auto iter = param_input.find(cnode->input(i));
428 if (iter != param_input.end()) {
429 cnode->set_input(i, iter->second);
430 }
431 }
432 if (AnfUtils::IsRealKernel(node)) {
433 (void)maingraph_nodes_.emplace_back(node);
434 }
435 }
436 }
437 return output;
438 }
439
SetSplitNodeName(const std::vector<size_t> & cnodes_group_id) const440 void SetSplitNodeName(const std::vector<size_t> &cnodes_group_id) const {
441 auto old_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
442 std::string ori_node_name;
443 if (old_func_graph->has_attr(kAttrNodeName)) {
444 ori_node_name = GetValue<std::string>(old_func_graph->get_attr(kAttrNodeName));
445 } else {
446 ori_node_name = GetValue<std::string>(old_func_graph->get_attr("graph_kernel"));
447 }
448 for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
449 auto group_id = cnodes_group_id[i];
450 if (!split_schemer_->NeedInline(group_id)) {
451 std::string node_name = ori_node_name + "_" + std::to_string(group_id);
452 AnfUtils::SetNodeAttr(kAttrNodeName, MakeValue(node_name), new_subgraph_cnodes_[i]);
453 }
454 }
455 }
456
457 // Set the new sub_func_graph node as input of nodes original main graph.
ConnectToMainGraph(const std::vector<size_t> & cnodes_group_id)458 void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
459 // For single output kernel, the last area contains the original output node (return node),
460 // to replace old subgraph with new subgraphs, just replace the old CNode with new last CNode.
461 // For multiple output kernel, to avoid returning Parameter, the last MakeTuple was distribute to
462 // a new FuncGraph, just inline the last MakeTuple node.
463 mindspore::HashMap<AnfNodePtr, AnfNodePtr> replace_map;
464
465 for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
466 if (split_schemer_->NeedInline(cnodes_group_id[i])) {
467 // Connect the sub_graph's inner node to main_graph
468 auto output = InlineSubFuncGraph(new_subgraph_cnodes_[i]);
469 if (i + 1 == new_subgraph_cnodes_.size()) {
470 replace_map[this->old_subgraph_cnode_] = output;
471 } else {
472 replace_map[new_subgraph_cnodes_[i]] = output;
473 }
474 } else {
475 if (i + 1 == new_subgraph_cnodes_.size()) {
476 replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
477 }
478 (void)maingraph_nodes_.emplace_back(new_subgraph_cnodes_[i]);
479 }
480 }
481
482 TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
483 auto cnode = node->cast<CNodePtr>();
484 if (cnode == nullptr) {
485 return;
486 }
487 for (size_t i = 1; i < cnode->size(); ++i) {
488 auto input_node = cnode->input(i);
489 auto iter = replace_map.find(input_node);
490 if (iter != replace_map.end()) {
491 cnode->set_input(i, iter->second);
492 }
493 }
494 });
495 }
496
UpdateMainNodesKernelInfo() const497 void UpdateMainNodesKernelInfo() const {
498 auto graph_manager = main_func_graph_->manager();
499 MS_EXCEPTION_IF_NULL(graph_manager);
500
501 for (auto node : maingraph_nodes_) {
502 MS_LOG(DEBUG) << "Update kernel_info for " << node->DebugString() << " (" << node->fullname_with_scope() << ")";
503 auto sub_func_graph = GetCNodeFuncGraph(node);
504 if (sub_func_graph != nullptr) {
505 graph_manager->AddFuncGraph(sub_func_graph);
506 auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
507 sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));
508 Callback::Instance()->SetGraphKernelNodeKernelInfo(node);
509 } else {
510 Callback::Instance()->ResetKernelInfo(node);
511 }
512 }
513 }
514
515 // Copy all Parameter and ValueNode that the area used.
AreaExpand(const Area & area)516 void AreaExpand(const Area &area) {
517 mindspore::HashMap<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
518 for (auto sub_node : area.nodes()) {
519 auto sub_cnode = sub_node->cast<CNodePtr>();
520 if (sub_cnode == nullptr) {
521 continue;
522 }
523 for (size_t i = 1; i < sub_cnode->size(); ++i) {
524 auto in_node = sub_cnode->input(i);
525 if (in_node->isa<CNode>()) {
526 continue;
527 }
528 auto it = old_valuenode_and_param_map.find(in_node);
529 if (it != old_valuenode_and_param_map.end()) {
530 sub_cnode->set_input(i, it->second);
531 } else {
532 if (in_node->isa<Parameter>()) {
533 auto param = in_node->cast<ParameterPtr>();
534 auto cp_param = this->ParameterClone(param, in_node->func_graph());
535 old_valuenode_and_param_map[in_node] = cp_param->cast<AnfNodePtr>();
536 sub_cnode->set_input(i, cp_param);
537 }
538 }
539 }
540 }
541 }
542
GenParamMap()543 void GenParamMap() {
544 auto sub_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
545 auto ¶m_arr = sub_func_graph->parameters();
546 for (size_t i = 0; i < param_arr.size(); ++i) {
547 auto param = param_arr[i]->cast<ParameterPtr>();
548 MS_EXCEPTION_IF_NULL(param);
549 param_to_main_graph_node_map_[param] = old_subgraph_cnode_->input(i + 1);
550 }
551 }
552
ParameterClone(const ParameterPtr & param,const FuncGraphPtr & func)553 ParameterPtr ParameterClone(const ParameterPtr ¶m, const FuncGraphPtr &func) {
554 ParameterPtr param_c = std::make_shared<Parameter>(func);
555 param_c->set_name(param->name());
556 param_c->set_abstract(param->abstract());
557 auto node = param_to_main_graph_node_map_[param];
558 param_to_main_graph_node_map_[param_c] = node;
559 return param_c;
560 }
561
562 FuncGraphPtr main_func_graph_;
563 CNodePtr old_subgraph_cnode_; // The cnode that holds the original sub_func_graph
564 std::vector<CNodePtr> new_subgraph_cnodes_; // The cnode list that hold the new sub_func_graph
565 std::vector<AnfNodePtr> maingraph_nodes_; // The nodes in main graph finally, include "call" and inlined node
566 SplitSchemerPtr split_schemer_;
567 mindspore::HashMap<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
568 };
569
570 class CppCostModelSplitSchemer : public CommonSplitSchemer {
571 public:
CppCostModelSplitSchemer(const std::string & processor)572 explicit CppCostModelSplitSchemer(const std::string &processor) : processor_(processor) {}
573 ~CppCostModelSplitSchemer() = default;
Split(const FuncGraphPtr & func_graph)574 bool Split(const FuncGraphPtr &func_graph) override {
575 if (!SplitByCostModel(func_graph)) {
576 return false;
577 }
578 GroupReturnNode(func_graph);
579 return true;
580 }
581
582 protected:
SplitByCostModel(const FuncGraphPtr & func_graph)583 bool SplitByCostModel(const FuncGraphPtr &func_graph) {
584 mindspore::HashMap<inner::NodePtr, AnfNodePtr> op_node_map;
585 auto lg = GkUtils::AnfGraph2LiteGraph(func_graph, &op_node_map);
586 MS_LOG(DEBUG) << "Litegraph: " << lg->ToString();
587 // use the original node index to sort the split_plan's nodes.
588 mindspore::HashMap<AnfNodePtr, size_t> node_idx_map;
589 for (size_t i = 0; i < lg->ops().size(); ++i) {
590 node_idx_map[op_node_map[lg->ops()[i]]] = i;
591 }
592 auto model = inner::SplitModelFactory::Instance().CreateSplitModel(processor_);
593 MS_EXCEPTION_IF_NULL(model);
594 model->Run(lg);
595 auto &areas = model->areas();
596 for (auto &area : areas) {
597 AnfNodePtrList nodes;
598 for (auto &op : area->ops()) {
599 (void)nodes.emplace_back(op_node_map[op]);
600 node_group_[nodes.back()] = split_plan_.size();
601 }
602 std::sort(nodes.begin(), nodes.end(), [&node_idx_map](const AnfNodePtr &a, const AnfNodePtr &b) {
603 return node_idx_map[a] < node_idx_map[b];
604 });
605 (void)split_plan_.emplace_back(std::move(nodes));
606 need_inline_.push_back((area->mode() == inner::AreaMode::BASIC ? 1 : 0));
607 }
608 return split_plan_.size() > 1 || (split_plan_.size() == 1 && NeedInline(0));
609 }
610
611 std::string processor_;
612 };
613 } // namespace
614
GetSplitSchema(const std::string & processor)615 std::shared_ptr<SplitSchemer> GraphKernelSplitter::GetSplitSchema(const std::string &processor) {
616 return std::make_shared<CppCostModelSplitSchemer>(processor);
617 }
618
TrySplit(const CNodePtr & sub_root_cnode)619 bool GraphKernelSplitter::TrySplit(const CNodePtr &sub_root_cnode) {
620 MS_LOG(DEBUG) << "Split process node: " << sub_root_cnode->fullname_with_scope();
621 auto processor = Callback::Instance()->GetTargetFromContext();
622 auto schm = GetSplitSchema(processor);
623 MS_EXCEPTION_IF_NULL(schm);
624 auto splitter = Splitter::MakeSplitter(sub_root_cnode, schm);
625 MS_EXCEPTION_IF_NULL(splitter);
626 bool result = splitter->Split();
627 MS_LOG(DEBUG) << "Split node completed, result: " << result;
628 return result;
629 }
630
Run(const FuncGraphPtr & func_graph)631 bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
632 MS_EXCEPTION_IF_NULL(func_graph);
633 auto mng = func_graph->manager();
634 if (mng == nullptr) {
635 mng = Manage(func_graph, true);
636 func_graph->set_manager(mng);
637 }
638 auto todos = TopoSort(func_graph->get_return());
639
640 // Split subgraphs in reversed topo order,
641 // since the nodes behind the processing node may be modified when splitting.
642 bool changed = false;
643 for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
644 auto node = (*iter)->cast<CNodePtr>();
645 if (node != nullptr && AnfUtils::IsGraphKernel(node)) {
646 changed = TrySplit(node) || changed;
647 }
648 }
649 mng->RemoveRoots();
650 mng->KeepRoots({func_graph});
651 return changed;
652 }
653 } // namespace mindspore::graphkernel
654