• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/core/graph_kernel_splitter.h"
17 #include <algorithm>
18 #include <vector>
19 #include <string>
20 #include <utility>
21 #include <queue>
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "ir/anf.h"
25 #include "utils/anf_utils.h"
26 #include "utils/hash_map.h"
27 #include "utils/hash_set.h"
28 #include "backend/common/graph_kernel/core/graph_kernel_callback.h"
29 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
30 #include "backend/common/graph_kernel/split_model/split_model_factory.h"
31 
32 namespace mindspore::graphkernel {
33 namespace {
TraverseFuncGraphFromCNode(const CNodePtr & cnode,const std::function<void (AnfNodePtr &)> & callback)34 void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
35   mindspore::HashSet<AnfNodePtr> visited;
36   std::queue<AnfNodePtr> que;
37   que.push(cnode);
38   (void)visited.insert(cnode);
39   while (!que.empty()) {
40     auto ft_node = que.front();
41     que.pop();
42     callback(ft_node);
43     auto ft_cnode = ft_node->cast<CNodePtr>();
44     if (ft_cnode == nullptr) {
45       continue;
46     }
47     for (const auto &in_node : ft_cnode->inputs()) {
48       if (visited.count(in_node) == 0) {
49         que.push(in_node);
50         (void)visited.insert(in_node);
51       }
52     }
53   }
54 }
55 
56 // Visited each AnfNode once, use callback to do the job on AnfNode
TraverseFuncGraph(const FuncGraphPtr & root,const std::function<void (AnfNodePtr &)> & callback)57 inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) {
58   TraverseFuncGraphFromCNode(root->get_return(), callback);
59 }
60 
61 class Area {
62  public:
Area(const AnfNodePtrList & anf_arr)63   explicit Area(const AnfNodePtrList &anf_arr) {
64     nodes_.insert(anf_arr.cbegin(), anf_arr.cend());
65     for (auto &node : anf_arr) {
66       auto cnode = node->cast<CNodePtr>();
67       if (cnode == nullptr) {
68         continue;
69       }
70       const auto &inputs = cnode->inputs();
71       if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
72         (void)spy_cnodes_.emplace_back(node);
73       }
74     }
75   }
76 
77   ~Area() = default;
78 
79   // Set the external inputs of spy as a Parameter.
CreateParameters(const FuncGraphPtr & func_graph,mindspore::HashMap<ParameterPtr,AnfNodePtr> * param_node_map)80   void CreateParameters(const FuncGraphPtr &func_graph, mindspore::HashMap<ParameterPtr, AnfNodePtr> *param_node_map) {
81     mindspore::HashMap<AnfNodePtr, ParameterPtr> node_param_map;
82     for (auto node : this->spy_cnodes_) {
83       auto cnode = node->cast<CNodePtr>();
84       MS_EXCEPTION_IF_NULL(cnode);
85       for (size_t i = 1; i < cnode->size(); ++i) {
86         AnfNodePtr in_node = cnode->input(i);
87         if (!IsExternalCNode(in_node)) {
88           continue;
89         }
90         auto it = node_param_map.find(in_node);
91         if (it == node_param_map.end()) {
92           auto new_param = std::make_shared<Parameter>(func_graph);
93           new_param->set_abstract(in_node->abstract());
94           func_graph->add_parameter(new_param);
95           (void)node_param_map.emplace(in_node, new_param);
96           cnode->set_input(i, new_param);
97         } else {
98           cnode->set_input(i, it->second);
99         }
100       }
101     }
102     this->spy_cnodes_.clear();  // spy list is not useful anymore
103     for (auto &&elem : node_param_map) {
104       (void)param_node_map->emplace(elem.second, elem.first);
105     }
106     return;
107   }
108 
109   // Make a return node for traitor nodes.
CreateReturnNode(const FuncGraphPtr & func_graph,mindspore::HashMap<AnfNodePtr,size_t> * tuple_node_index)110   void CreateReturnNode(const FuncGraphPtr &func_graph, mindspore::HashMap<AnfNodePtr, size_t> *tuple_node_index) {
111     // If there's no traitor in the area, it means that this area is the last part
112     // of the original FuncGraph, it already contains the original Return node.
113     if (traitor_nodes_.empty()) {
114       for (auto &node : nodes_) {
115         if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
116           func_graph->set_return(node->cast<CNodePtr>());
117           node->set_func_graph(func_graph);
118           return;
119         }
120       }
121       MS_LOG(ERROR) << "Cannot find the return node in " << func_graph->ToString();
122       return;
123     }
124     AnfNodePtrList return_inputs = {NewValueNode(prim::kPrimReturn)};
125     if (traitor_nodes_.size() > 1) {
126       // The area has multiple output, it's necessary to make a tuple for them.
127       AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
128       AbstractBasePtrList abstracts;
129       size_t i = 0;
130       for (auto &traitor : traitor_nodes_) {
131         (void)tuple_node_index->emplace(traitor, i++);
132         (void)maketuple_inputs.emplace_back(traitor);
133         (void)abstracts.emplace_back(traitor->abstract());
134       }
135       auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
136       maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
137       (void)nodes_.insert(maketuple_node);
138       (void)return_inputs.emplace_back(maketuple_node);
139     } else {
140       (void)return_inputs.emplace_back(traitor_nodes_[0]);
141     }
142     auto return_node = func_graph->NewCNode(return_inputs);
143     return_node->set_abstract(return_inputs.back()->abstract());
144     func_graph->set_return(return_node);
145     (void)nodes_.insert(return_node);
146     traitor_nodes_.clear();  // traitor list is not useful anymore
147     return;
148   }
149 
AddTraitor(const AnfNodePtr & node)150   void AddTraitor(const AnfNodePtr &node) {
151     if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
152       (void)traitor_nodes_.emplace_back(node);
153     }
154   }
155 
nodes() const156   const mindspore::HashSet<AnfNodePtr> &nodes() const { return nodes_; }
spy_cnodes() const157   const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; }
158 
159  private:
160   // This is a CNode that does not belong to this area.
IsExternalCNode(const AnfNodePtr & node) const161   bool IsExternalCNode(const AnfNodePtr &node) const { return node->isa<CNode>() && this->nodes_.count(node) == 0; }
162 
163   // nodes in this area
164   mindspore::HashSet<AnfNodePtr> nodes_;
165   // if a node's output is used by other Area, it's a traitor
166   std::vector<AnfNodePtr> traitor_nodes_;
167   // if a node use other Area's output, it's a spy
168   std::vector<AnfNodePtr> spy_cnodes_;
169 };
170 
171 class AreaGraph {
172  public:
173   using AreaGraphPtr = std::shared_ptr<AreaGraph>;
174 
175   // Build an area graph to maintain the relation between areas.
176   // Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
BuildAreaGraph(const std::vector<AnfNodePtrList> & node_groups)177   static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
178     auto area_graph = std::make_shared<AreaGraph>(node_groups);
179     if (area_graph == nullptr) {
180       return nullptr;
181     }
182     if (!area_graph->TopoSort()) {
183       MS_LOG(WARNING) << "The groups have a cycle. The first node is " << node_groups[0][0]->fullname_with_scope();
184       return nullptr;
185     }
186     return area_graph;
187   }
188 
189   // Split the graph to multiple areas, and reconnect the edges between the areas.
190   // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
191   // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
SplitGraph(const FuncGraphPtr & main_func_graph,std::vector<CNodePtr> * main_cnodes,std::vector<size_t> * cnode_group_id,const std::function<void (const Area &)> & expand_callback)192   void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
193                   std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) {
194     main_cnodes->clear();
195     main_cnodes->resize(areas_.size(), nullptr);
196 
197     for (auto &area : this->areas_) {
198       expand_callback(area);
199     }
200 
201     for (auto index : topo_order_) {
202       auto &current_area = areas_[index];
203       auto sub_func_graph = std::make_shared<FuncGraph>();
204       mindspore::HashMap<ParameterPtr, AnfNodePtr> param_node_map;
205 
206       current_area.CreateParameters(sub_func_graph, &param_node_map);
207       current_area.CreateReturnNode(sub_func_graph, &node_index_in_returned_tuple_);
208       auto new_main_cnode = this->CreateMainCNode(main_func_graph, sub_func_graph, *main_cnodes, param_node_map);
209       (*main_cnodes)[index] = new_main_cnode;
210     }
211 
212     SortCNodes(main_cnodes);
213     *cnode_group_id = std::move(topo_order_);  // The topo_order is not used anymore.
214     return;
215   }
216 
AreaGraph(const std::vector<AnfNodePtrList> & node_groups)217   explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
218     for (size_t i = 0; i < node_groups.size(); ++i) {
219       (void)areas_.emplace_back(node_groups[i]);
220       for (const auto &node : node_groups[i]) {
221         node_area_map_[node] = i;
222       }
223     }
224     for (auto &area : areas_) {
225       for (auto &spy : area.spy_cnodes()) {
226         auto cnode = spy->cast<CNodePtr>();
227         MS_EXCEPTION_IF_NULL(cnode);
228         size_t v = node_area_map_[spy];
229         for (auto &in_node : cnode->inputs()) {
230           if (!in_node->isa<CNode>()) {
231             continue;
232           }
233           // area edge u -> v
234           size_t u = node_area_map_[in_node];
235           if (u == v) {
236             continue;
237           }
238           areas_[u].AddTraitor(in_node);
239           if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
240             (void)edge_prev_[v].emplace_back(u);
241           }
242         }
243       }
244     }
245   }
246   ~AreaGraph() = default;
247 
248  private:
249   // Topological sort the areas.
TopoSort()250   bool TopoSort() {
251     std::vector<int> out_degree(edge_prev_.size(), 0);
252     std::queue<size_t> que;
253     for (auto &prev : edge_prev_) {
254       for (size_t i : prev) {
255         out_degree[i]++;
256       }
257     }
258     for (size_t i = 0; i < out_degree.size(); ++i) {
259       if (out_degree[i] == 0) {
260         que.push(i);
261       }
262     }
263     while (!que.empty()) {
264       size_t u = que.front();
265       que.pop();
266       (void)topo_order_.emplace_back(u);
267       for (size_t i : edge_prev_[u]) {
268         if (--out_degree[i] == 0) {
269           que.push(i);
270         }
271       }
272     }
273     std::reverse(topo_order_.begin(), topo_order_.end());
274     return topo_order_.size() == areas_.size();
275   }
276 
277   // Make a CNode in main graph to hold the sub_func_graph.
CreateMainCNode(const FuncGraphPtr & main_func_graph,const FuncGraphPtr & sub_func_graph,const std::vector<CNodePtr> & main_cnodes,const mindspore::HashMap<ParameterPtr,AnfNodePtr> & param_node_map)278   CNodePtr CreateMainCNode(const FuncGraphPtr &main_func_graph, const FuncGraphPtr &sub_func_graph,
279                            const std::vector<CNodePtr> &main_cnodes,
280                            const mindspore::HashMap<ParameterPtr, AnfNodePtr> &param_node_map) {
281     TraceGuard guard(std::make_shared<TraceOpt>(sub_func_graph->debug_info()));
282     AnfNodePtrList main_cnode_inputs = {NewValueNode(sub_func_graph)};
283     for (const auto &param : sub_func_graph->parameters()) {
284       // assert the param exists.
285       const auto &input_node = param_node_map.find(param->cast<ParameterPtr>())->second;
286       size_t input_area = node_area_map_[input_node];
287       // if the input node is in a tuple, then we need to create a GetItem fot it.
288       if (node_index_in_returned_tuple_.count(input_node) != 0) {
289         auto idx_val = SizeToLong(node_index_in_returned_tuple_[input_node]);
290         auto idx = NewValueNode(idx_val);
291         idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
292         AnfNodePtrList getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), main_cnodes[input_area], idx};
293         TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info()));
294         auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
295         auto abs_tuple = dyn_cast<abstract::AbstractTuple>(main_cnodes[input_area]->abstract());
296         if (idx_val < SizeToLong(abs_tuple->size())) {
297           getitem_node->set_abstract(abs_tuple->elements()[LongToSize(idx_val)]);
298         } else {
299           getitem_node->set_abstract(main_cnodes[input_area]->abstract());
300         }
301         (void)main_cnode_inputs.emplace_back(getitem_node);
302       } else {
303         (void)main_cnode_inputs.emplace_back(main_cnodes[input_area]);
304       }
305     }
306     auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
307     new_main_cnode->set_abstract(sub_func_graph->output()->abstract());
308     return new_main_cnode;
309   }
310 
SortCNodes(std::vector<CNodePtr> * main_cnodes) const311   void SortCNodes(std::vector<CNodePtr> *main_cnodes) const {
312     std::vector<CNodePtr> main_cnodes_sorted;
313     (void)std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
314                          [main_cnodes](size_t index) { return main_cnodes->at(index); });
315     *main_cnodes = std::move(main_cnodes_sorted);
316   }
317 
318   // Areas in this subgraph
319   std::vector<Area> areas_;
320   // Adjacency table of areas
321   std::vector<std::vector<size_t>> edge_prev_;
322   // Topological order of areas
323   std::vector<size_t> topo_order_;
324   // Map AnfNode to Area id
325   mindspore::HashMap<AnfNodePtr, size_t> node_area_map_;
326   // Map the nodes to their index if there are multiple value in an area
327   mindspore::HashMap<AnfNodePtr, size_t> node_index_in_returned_tuple_;
328 };
329 
330 class Splitter {
331  public:
332   using SplitterPtr = std::shared_ptr<Splitter>;
333 
Split()334   bool Split() {
335     GenParamMap();
336     auto ori_sub_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
337     if (!split_schemer_->Split(ori_sub_func_graph)) {
338       return false;
339     }
340 
341     auto area_graph = AreaGraph::BuildAreaGraph(split_schemer_->split_plan());
342     if (area_graph == nullptr) {
343       return false;
344     }
345 
346     // The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
347     std::vector<size_t> cnodes_group_id;
348     area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id,
349                            [this](const Area &area) { this->AreaExpand(area); });
350 
351     RebuildGraph(cnodes_group_id);
352 
353     return true;
354   }
355 
MakeSplitter(const CNodePtr & main_cnode,const SplitSchemerPtr & split_schemer)356   static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
357     MS_EXCEPTION_IF_NULL(main_cnode);
358     MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
359     MS_EXCEPTION_IF_NULL(split_schemer);
360     return std::make_shared<Splitter>(main_cnode, split_schemer);
361   }
362 
Splitter(const CNodePtr & main_cnode,const SplitSchemerPtr & split_schemer)363   Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer)
364       : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
365   ~Splitter() = default;
366 
367  private:
368   // Maintain new subgraphs in main graph.
RebuildGraph(const std::vector<size_t> & cnodes_group_id)369   void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
370     BindFuncGraph();
371     RecoverParameter();
372     SetSplitNodeName(cnodes_group_id);
373     ConnectToMainGraph(cnodes_group_id);
374     UpdateMainNodesKernelInfo();
375   }
376 
377   // Rebind nodes to its new sub_func_graph
BindFuncGraph() const378   void BindFuncGraph() const {
379     for (const auto &cnode : new_subgraph_cnodes_) {
380       auto sub_func_graph = GetCNodeFuncGraph(cnode);
381       auto callback = [&sub_func_graph](const AnfNodePtr &node) {
382         if (!node->isa<ValueNode>()) {
383           node->set_func_graph(sub_func_graph);
384         }
385       };
386       TraverseFuncGraph(sub_func_graph, callback);
387     }
388   }
389 
390   // Recover the original subgraph's parameter if the new graph needs it
RecoverParameter()391   void RecoverParameter() {
392     for (const auto &cnode : new_subgraph_cnodes_) {
393       auto sub_func_graph = GetCNodeFuncGraph(cnode);
394       auto callback = [&cnode, &sub_func_graph, this](const AnfNodePtr &node) {
395         auto param = node->cast<ParameterPtr>();
396         if (param == nullptr) {
397           return;
398         }
399         auto it = this->param_to_main_graph_node_map_.find(param);
400         if (it != this->param_to_main_graph_node_map_.end()) {
401           auto input = it->second;
402           cnode->add_input(input);
403           sub_func_graph->add_parameter(param);
404           // Avoid repeating parameters.
405           (void)this->param_to_main_graph_node_map_.erase(it);
406         }
407       };
408       TraverseFuncGraph(sub_func_graph, callback);
409     }
410   }
411 
InlineSubFuncGraph(const CNodePtr & main_node)412   CNodePtr InlineSubFuncGraph(const CNodePtr &main_node) {
413     auto func_graph = GetCNodeFuncGraph(main_node);
414     const auto &inputs = main_node->inputs();
415     auto output = func_graph->output()->cast<CNodePtr>();
416     MS_EXCEPTION_IF_NULL(output);
417     const auto &parameters = func_graph->parameters();
418     mindspore::HashMap<AnfNodePtr, AnfNodePtr> param_input;
419     for (size_t i = 0; i < parameters.size(); ++i) {
420       param_input[parameters[i]] = inputs[i + 1];
421     }
422     auto sub_nodes = TopoSort(func_graph->get_return());
423     for (auto node : sub_nodes) {
424       if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
425         cnode->set_func_graph(main_func_graph_);
426         for (size_t i = 1; i < cnode->size(); ++i) {
427           auto iter = param_input.find(cnode->input(i));
428           if (iter != param_input.end()) {
429             cnode->set_input(i, iter->second);
430           }
431         }
432         if (AnfUtils::IsRealKernel(node)) {
433           (void)maingraph_nodes_.emplace_back(node);
434         }
435       }
436     }
437     return output;
438   }
439 
SetSplitNodeName(const std::vector<size_t> & cnodes_group_id) const440   void SetSplitNodeName(const std::vector<size_t> &cnodes_group_id) const {
441     auto old_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
442     std::string ori_node_name;
443     if (old_func_graph->has_attr(kAttrNodeName)) {
444       ori_node_name = GetValue<std::string>(old_func_graph->get_attr(kAttrNodeName));
445     } else {
446       ori_node_name = GetValue<std::string>(old_func_graph->get_attr("graph_kernel"));
447     }
448     for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
449       auto group_id = cnodes_group_id[i];
450       if (!split_schemer_->NeedInline(group_id)) {
451         std::string node_name = ori_node_name + "_" + std::to_string(group_id);
452         AnfUtils::SetNodeAttr(kAttrNodeName, MakeValue(node_name), new_subgraph_cnodes_[i]);
453       }
454     }
455   }
456 
457   // Set the new sub_func_graph node as input of nodes original main graph.
ConnectToMainGraph(const std::vector<size_t> & cnodes_group_id)458   void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
459     // For single output kernel, the last area contains the original output node (return node),
460     //  to replace old subgraph with new subgraphs, just replace the old CNode with new last CNode.
461     // For multiple output kernel, to avoid returning Parameter, the last MakeTuple was distribute to
462     //  a new FuncGraph, just inline the last MakeTuple node.
463     mindspore::HashMap<AnfNodePtr, AnfNodePtr> replace_map;
464 
465     for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
466       if (split_schemer_->NeedInline(cnodes_group_id[i])) {
467         // Connect the sub_graph's inner node to main_graph
468         auto output = InlineSubFuncGraph(new_subgraph_cnodes_[i]);
469         if (i + 1 == new_subgraph_cnodes_.size()) {
470           replace_map[this->old_subgraph_cnode_] = output;
471         } else {
472           replace_map[new_subgraph_cnodes_[i]] = output;
473         }
474       } else {
475         if (i + 1 == new_subgraph_cnodes_.size()) {
476           replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
477         }
478         (void)maingraph_nodes_.emplace_back(new_subgraph_cnodes_[i]);
479       }
480     }
481 
482     TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
483       auto cnode = node->cast<CNodePtr>();
484       if (cnode == nullptr) {
485         return;
486       }
487       for (size_t i = 1; i < cnode->size(); ++i) {
488         auto input_node = cnode->input(i);
489         auto iter = replace_map.find(input_node);
490         if (iter != replace_map.end()) {
491           cnode->set_input(i, iter->second);
492         }
493       }
494     });
495   }
496 
UpdateMainNodesKernelInfo() const497   void UpdateMainNodesKernelInfo() const {
498     auto graph_manager = main_func_graph_->manager();
499     MS_EXCEPTION_IF_NULL(graph_manager);
500 
501     for (auto node : maingraph_nodes_) {
502       MS_LOG(DEBUG) << "Update kernel_info for " << node->DebugString() << " (" << node->fullname_with_scope() << ")";
503       auto sub_func_graph = GetCNodeFuncGraph(node);
504       if (sub_func_graph != nullptr) {
505         graph_manager->AddFuncGraph(sub_func_graph);
506         auto attr = GkUtils::ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
507         sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));
508         Callback::Instance()->SetGraphKernelNodeKernelInfo(node);
509       } else {
510         Callback::Instance()->ResetKernelInfo(node);
511       }
512     }
513   }
514 
515   // Copy all Parameter and ValueNode that the area used.
AreaExpand(const Area & area)516   void AreaExpand(const Area &area) {
517     mindspore::HashMap<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
518     for (auto sub_node : area.nodes()) {
519       auto sub_cnode = sub_node->cast<CNodePtr>();
520       if (sub_cnode == nullptr) {
521         continue;
522       }
523       for (size_t i = 1; i < sub_cnode->size(); ++i) {
524         auto in_node = sub_cnode->input(i);
525         if (in_node->isa<CNode>()) {
526           continue;
527         }
528         auto it = old_valuenode_and_param_map.find(in_node);
529         if (it != old_valuenode_and_param_map.end()) {
530           sub_cnode->set_input(i, it->second);
531         } else {
532           if (in_node->isa<Parameter>()) {
533             auto param = in_node->cast<ParameterPtr>();
534             auto cp_param = this->ParameterClone(param, in_node->func_graph());
535             old_valuenode_and_param_map[in_node] = cp_param->cast<AnfNodePtr>();
536             sub_cnode->set_input(i, cp_param);
537           }
538         }
539       }
540     }
541   }
542 
GenParamMap()543   void GenParamMap() {
544     auto sub_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
545     auto &param_arr = sub_func_graph->parameters();
546     for (size_t i = 0; i < param_arr.size(); ++i) {
547       auto param = param_arr[i]->cast<ParameterPtr>();
548       MS_EXCEPTION_IF_NULL(param);
549       param_to_main_graph_node_map_[param] = old_subgraph_cnode_->input(i + 1);
550     }
551   }
552 
ParameterClone(const ParameterPtr & param,const FuncGraphPtr & func)553   ParameterPtr ParameterClone(const ParameterPtr &param, const FuncGraphPtr &func) {
554     ParameterPtr param_c = std::make_shared<Parameter>(func);
555     param_c->set_name(param->name());
556     param_c->set_abstract(param->abstract());
557     auto node = param_to_main_graph_node_map_[param];
558     param_to_main_graph_node_map_[param_c] = node;
559     return param_c;
560   }
561 
562   FuncGraphPtr main_func_graph_;
563   CNodePtr old_subgraph_cnode_;                // The cnode that holds the original sub_func_graph
564   std::vector<CNodePtr> new_subgraph_cnodes_;  // The cnode list that hold the new sub_func_graph
565   std::vector<AnfNodePtr> maingraph_nodes_;    // The nodes in main graph finally, include "call" and inlined node
566   SplitSchemerPtr split_schemer_;
567   mindspore::HashMap<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
568 };
569 
570 class CppCostModelSplitSchemer : public CommonSplitSchemer {
571  public:
CppCostModelSplitSchemer(const std::string & processor)572   explicit CppCostModelSplitSchemer(const std::string &processor) : processor_(processor) {}
573   ~CppCostModelSplitSchemer() = default;
Split(const FuncGraphPtr & func_graph)574   bool Split(const FuncGraphPtr &func_graph) override {
575     if (!SplitByCostModel(func_graph)) {
576       return false;
577     }
578     GroupReturnNode(func_graph);
579     return true;
580   }
581 
582  protected:
SplitByCostModel(const FuncGraphPtr & func_graph)583   bool SplitByCostModel(const FuncGraphPtr &func_graph) {
584     mindspore::HashMap<inner::NodePtr, AnfNodePtr> op_node_map;
585     auto lg = GkUtils::AnfGraph2LiteGraph(func_graph, &op_node_map);
586     MS_LOG(DEBUG) << "Litegraph: " << lg->ToString();
587     // use the original node index to sort the split_plan's nodes.
588     mindspore::HashMap<AnfNodePtr, size_t> node_idx_map;
589     for (size_t i = 0; i < lg->ops().size(); ++i) {
590       node_idx_map[op_node_map[lg->ops()[i]]] = i;
591     }
592     auto model = inner::SplitModelFactory::Instance().CreateSplitModel(processor_);
593     MS_EXCEPTION_IF_NULL(model);
594     model->Run(lg);
595     auto &areas = model->areas();
596     for (auto &area : areas) {
597       AnfNodePtrList nodes;
598       for (auto &op : area->ops()) {
599         (void)nodes.emplace_back(op_node_map[op]);
600         node_group_[nodes.back()] = split_plan_.size();
601       }
602       std::sort(nodes.begin(), nodes.end(), [&node_idx_map](const AnfNodePtr &a, const AnfNodePtr &b) {
603         return node_idx_map[a] < node_idx_map[b];
604       });
605       (void)split_plan_.emplace_back(std::move(nodes));
606       need_inline_.push_back((area->mode() == inner::AreaMode::BASIC ? 1 : 0));
607     }
608     return split_plan_.size() > 1 || (split_plan_.size() == 1 && NeedInline(0));
609   }
610 
611   std::string processor_;
612 };
613 }  // namespace
614 
GetSplitSchema(const std::string & processor)615 std::shared_ptr<SplitSchemer> GraphKernelSplitter::GetSplitSchema(const std::string &processor) {
616   return std::make_shared<CppCostModelSplitSchemer>(processor);
617 }
618 
TrySplit(const CNodePtr & sub_root_cnode)619 bool GraphKernelSplitter::TrySplit(const CNodePtr &sub_root_cnode) {
620   MS_LOG(DEBUG) << "Split process node: " << sub_root_cnode->fullname_with_scope();
621   auto processor = Callback::Instance()->GetTargetFromContext();
622   auto schm = GetSplitSchema(processor);
623   MS_EXCEPTION_IF_NULL(schm);
624   auto splitter = Splitter::MakeSplitter(sub_root_cnode, schm);
625   MS_EXCEPTION_IF_NULL(splitter);
626   bool result = splitter->Split();
627   MS_LOG(DEBUG) << "Split node completed, result: " << result;
628   return result;
629 }
630 
Run(const FuncGraphPtr & func_graph)631 bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
632   MS_EXCEPTION_IF_NULL(func_graph);
633   auto mng = func_graph->manager();
634   if (mng == nullptr) {
635     mng = Manage(func_graph, true);
636     func_graph->set_manager(mng);
637   }
638   auto todos = TopoSort(func_graph->get_return());
639 
640   // Split subgraphs in reversed topo order,
641   // since the nodes behind the processing node may be modified when splitting.
642   bool changed = false;
643   for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
644     auto node = (*iter)->cast<CNodePtr>();
645     if (node != nullptr && AnfUtils::IsGraphKernel(node)) {
646       changed = TrySplit(node) || changed;
647     }
648   }
649   mng->RemoveRoots();
650   mng->KeepRoots({func_graph});
651   return changed;
652 }
653 }  // namespace mindspore::graphkernel
654