• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
17 #include <algorithm>
18 #include <vector>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <queue>
23 #include <map>
24 #include <unordered_map>
25 #include "frontend/optimizer/irpass.h"
26 #include "pipeline/jit/parse/python_adapter.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "backend/kernel_compiler/common_utils.h"
29 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
30 #include "debug/anf_ir_dump.h"
31 #include "utils/context/graph_kernel_flags.h"
32 
33 namespace mindspore {
34 namespace kernel {
35 namespace {
GetStitchInfo(const nlohmann::json & kernel_json)36 StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) {
37   StitchInfo info;
38   if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
39     nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
40     if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
41       std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
42       info.stitch_ops = stitch_ops;
43     }
44     if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
45       std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
46       info.stitch_atomic_ops = stitch_atomic_ops;
47     }
48   }
49   return info;
50 }
51 
GetRecomputeOps(const nlohmann::json & kernel_json)52 std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) {
53   if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
54     std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
55     return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
56   }
57   return std::set<std::string>();
58 }
59 
IsRecomputeOp(const nlohmann::json & op_desc,const std::set<std::string> & recompute_ops)60 bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) {
61   std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
62   if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
63     return false;
64   }
65   std::string tensor_name = output_descs[0][kJsonKeyTensorName];
66   if (recompute_ops.count(tensor_name)) {
67     return true;
68   }
69   return false;
70 }
71 
NewRecomputeNode(const AnfNodePtr & orig_node,std::map<AnfNodePtr,AnfNodePtr> * node_map)72 CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) {
73   auto func_graph = orig_node->func_graph();
74   MS_EXCEPTION_IF_NULL(func_graph);
75   auto cnode = orig_node->cast<CNodePtr>();
76   MS_EXCEPTION_IF_NULL(cnode);
77   TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
78   auto orig_inputs = cnode->inputs();
79   std::vector<AnfNodePtr> inputs;
80   for (auto inp : orig_inputs) {
81     if (node_map->find(inp) == node_map->end()) {
82       inputs.push_back(inp);
83       continue;
84     }
85     inputs.push_back((*node_map)[inp]);
86   }
87   CNodePtr cp_node = func_graph->NewCNode(inputs);
88   func_graph->AddNode(cp_node);
89   cp_node->set_abstract(cnode->abstract());
90   cp_node->set_forward(cnode->forward().first, cnode->forward().second);
91   cp_node->set_inputs_value(cnode->inputs_value());
92   ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
93   cp_node->set_scope(scope);
94   cp_node->set_kernel_info(cnode->kernel_info_ptr());
95   cp_node->set_primal_attrs(cnode->primal_attrs());
96   cp_node->set_primal_debug_infos(cnode->primal_debug_infos());
97   (*node_map)[orig_node] = cp_node;
98   return cp_node->cast<CNodePtr>();
99 }
100 
SetStitchAttr(const nlohmann::json & op_desc,const StitchInfo & info,const CNodePtr & node)101 void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
102   std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
103   if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
104   std::string tensor_name = output_descs[0][kJsonKeyTensorName];
105   if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
106     AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
107     MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
108   }
109   if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
110       info.stitch_atomic_ops.end()) {
111     AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
112     MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
113   }
114 }
115 
116 // replace original region root op by its copy in this res_graphs
ConnectRecomputeOps(AnfNodePtrList * res_graphs,const AnfNodePtr & orig_region_root,const AnfNodePtr & cp_region_root)117 void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
118                          const AnfNodePtr &cp_region_root) {
119   for (auto &node : *res_graphs) {
120     auto cnode = node->cast<CNodePtr>();
121     auto inputs = cnode->inputs();
122     for (size_t i = 1; i < inputs.size(); ++i) {
123       if (inputs[i] != orig_region_root) continue;
124       cnode->set_input(i, cp_region_root);
125     }
126   }
127 }
128 }  // namespace
129 
DecodeSplitNodes(const nlohmann::json & kernel_json,const std::map<std::string,AnfNodePtr> & address_node_map,AnfNodePtrList * res_graphs)130 bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
131                                          const std::map<std::string, AnfNodePtr> &address_node_map,
132                                          AnfNodePtrList *res_graphs) {
133   MS_EXCEPTION_IF_NULL(res_graphs);
134   MS_LOG(DEBUG) << "start decode, " << kernel_json;
135   // decode cnodes in graph.
136   std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
137   if (op_node_descs.empty()) {
138     MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
139     return false;
140   }
141   StitchInfo info = GetStitchInfo(kernel_json);
142   auto recompute_ops = GetRecomputeOps(kernel_json);
143   // key_value: original_copied
144   std::map<AnfNodePtr, AnfNodePtr> node_map;
145   // nodes would be copied
146   AnfNodePtrList orig_region_nodes;
147   // nodes would not be copied
148   AnfNodePtrList no_cp_nodes;
149   for (const auto &op_desc : op_node_descs) {
150     if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
151       MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
152       return false;
153     }
154 
155     std::string ptr_address = op_desc[kJsonKeyPtrAddress];
156     if (address_node_map.count(ptr_address) == 0) {
157       MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
158       return false;
159     }
160     auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
161     if (IsRecomputeOp(op_desc, recompute_ops)) {
162       auto cp_node = NewRecomputeNode(node, &node_map);
163       orig_region_nodes.push_back(node);
164       SetStitchAttr(op_desc, info, cp_node);
165       res_graphs->push_back(cp_node);
166       continue;
167     }
168     SetStitchAttr(op_desc, info, node);
169     res_graphs->push_back(node);
170     no_cp_nodes.push_back(node);
171   }
172   for (auto orig_node : orig_region_nodes) {
173     ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
174   }
175   MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
176   return true;
177 }
178 }  // namespace kernel
179 
180 namespace opt {
181 namespace {
TraverseFuncGraphFromCNode(const CNodePtr & cnode,const std::function<void (AnfNodePtr &)> & callback)182 void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
183   std::unordered_set<AnfNodePtr> visited;
184   std::queue<AnfNodePtr> que;
185   que.push(cnode);
186   visited.insert(cnode);
187   while (!que.empty()) {
188     auto ft_node = que.front();
189     que.pop();
190     callback(ft_node);
191     auto ft_cnode = ft_node->cast<CNodePtr>();
192     if (ft_cnode == nullptr) continue;
193     for (const auto &in_node : ft_cnode->inputs()) {
194       if (visited.count(in_node) == 0) {
195         que.push(in_node);
196         visited.insert(in_node);
197       }
198     }
199   }
200 }
201 
202 // Visited each AnfNode once, use callback to do the job on AnfNode
TraverseFuncGraph(const FuncGraphPtr & root,const std::function<void (AnfNodePtr &)> & callback)203 inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) {
204   TraverseFuncGraphFromCNode(root->get_return(), callback);
205 }
206 
207 class Area {
208  public:
Area(const AnfNodePtrList & anf_arr)209   explicit Area(const AnfNodePtrList &anf_arr) {
210     nodes_.insert(anf_arr.begin(), anf_arr.end());
211     for (auto &node : anf_arr) {
212       auto cnode = node->cast<CNodePtr>();
213       if (cnode == nullptr) continue;
214       const auto &inputs = cnode->inputs();
215       if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
216         spy_cnodes_.emplace_back(node);
217       }
218     }
219   }
220 
221   ~Area() = default;
222 
223   // Set the external inputs of spy as a Parameter.
CreateParameters(const FuncGraphPtr & func_graph,std::unordered_map<ParameterPtr,AnfNodePtr> * param_node_map)224   void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) {
225     std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map;
226     for (auto node : this->spy_cnodes_) {
227       auto cnode = node->cast<CNodePtr>();
228       MS_EXCEPTION_IF_NULL(cnode);
229       for (size_t i = 1; i < cnode->inputs().size(); ++i) {
230         AnfNodePtr in_node = cnode->input(i);
231         if (!IsExternalCNode(in_node)) continue;
232         auto it = node_param_map.find(in_node);
233         if (it == node_param_map.end()) {
234           auto new_param = std::make_shared<Parameter>(func_graph);
235           new_param->set_abstract(in_node->abstract());
236           func_graph->add_parameter(new_param);
237           node_param_map.insert(std::make_pair(in_node, new_param));
238           cnode->set_input(i, new_param);
239         } else {
240           cnode->set_input(i, it->second);
241         }
242       }
243     }
244     this->spy_cnodes_.clear();  // spy list is not useful anymore
245     for (auto &&elem : node_param_map) {
246       param_node_map->insert(std::make_pair(elem.second, elem.first));
247     }
248     return;
249   }
250 
251   // Make a return node for traitor nodes.
CreateReturnNode(const FuncGraphPtr & func_graph,std::unordered_map<AnfNodePtr,size_t> * tuple_node_index)252   void CreateReturnNode(const FuncGraphPtr &func_graph, std::unordered_map<AnfNodePtr, size_t> *tuple_node_index) {
253     // If there's no traitor in the area, it means that this area is the last part
254     // of the original FuncGraph, it already contains the original Return node.
255     if (traitor_nodes_.empty()) {
256       for (auto &node : nodes_) {
257         if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
258           func_graph->set_return(node->cast<CNodePtr>());
259           node->set_func_graph(func_graph);
260           return;
261         }
262       }
263       MS_LOG(ERROR) << "Cannot find the return node in " << func_graph->ToString();
264       return;
265     }
266     AnfNodePtrList return_inputs = {NewValueNode(prim::kPrimReturn)};
267     if (traitor_nodes_.size() > 1) {
268       // The area has multiple output, it's necessary to make a tuple for them.
269       AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
270       AbstractBasePtrList abstracts;
271       size_t i = 0;
272       for (auto &traitor : traitor_nodes_) {
273         tuple_node_index->insert(std::make_pair(traitor, i++));
274         maketuple_inputs.emplace_back(traitor);
275         abstracts.emplace_back(traitor->abstract());
276       }
277       auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
278       maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
279       nodes_.insert(maketuple_node);
280       return_inputs.emplace_back(maketuple_node);
281     } else {
282       return_inputs.emplace_back(traitor_nodes_[0]);
283     }
284     auto return_node = func_graph->NewCNode(return_inputs);
285     return_node->set_abstract(return_inputs.back()->abstract());
286     func_graph->set_return(return_node);
287     nodes_.insert(return_node);
288     traitor_nodes_.clear();  // traitor list is not useful anymore
289     return;
290   }
291 
AddTraitor(const AnfNodePtr & node)292   void AddTraitor(const AnfNodePtr &node) {
293     if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
294       traitor_nodes_.emplace_back(node);
295     }
296   }
297 
nodes() const298   const std::unordered_set<AnfNodePtr> &nodes() const { return nodes_; }
spy_cnodes() const299   const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; }
300 
301  private:
302   // This is a CNode that does not belong to this area.
IsExternalCNode(const AnfNodePtr & node) const303   bool IsExternalCNode(const AnfNodePtr &node) const { return node->isa<CNode>() && this->nodes_.count(node) == 0; }
304 
305   // nodes in this area
306   std::unordered_set<AnfNodePtr> nodes_;
307   // if a node's output is used by other Area, it's a traitor
308   std::vector<AnfNodePtr> traitor_nodes_;
309   // if a node use other Area's output, it's a spy
310   std::vector<AnfNodePtr> spy_cnodes_;
311 };
312 
313 class AreaGraph {
314  public:
315   using AreaGraphPtr = std::shared_ptr<AreaGraph>;
316 
317   // Build an area graph to maintain the relation between areas.
318   // Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
BuildAreaGraph(const std::vector<AnfNodePtrList> & node_groups)319   static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
320     auto area_graph = std::make_shared<AreaGraph>(node_groups);
321     if (area_graph == nullptr) return nullptr;
322     if (!area_graph->TopoSort()) {
323       MS_LOG(WARNING) << "The groups have a cycle.";
324       return nullptr;
325     }
326     return area_graph;
327   }
328 
329   // Split the graph to multiple areas, and reconnect the edges between the areas.
330   // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
331   // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
SplitGraph(const FuncGraphPtr & main_func_graph,std::vector<CNodePtr> * main_cnodes,std::vector<size_t> * cnode_group_id,const std::function<void (const Area &)> & expand_callback)332   void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
333                   std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) {
334     main_cnodes->clear();
335     main_cnodes->resize(areas_.size(), nullptr);
336 
337     for (auto &area : this->areas_) {
338       expand_callback(area);
339     }
340 
341     for (auto index : topo_order_) {
342       auto &current_area = areas_[index];
343       auto sub_func_graph = std::make_shared<FuncGraph>();
344       std::unordered_map<ParameterPtr, AnfNodePtr> param_node_map;
345 
346       current_area.CreateParameters(sub_func_graph, &param_node_map);
347       current_area.CreateReturnNode(sub_func_graph, &node_index_in_returned_tuple_);
348       auto new_main_cnode = this->CreateMainCNode(main_func_graph, sub_func_graph, *main_cnodes, param_node_map);
349       (*main_cnodes)[index] = new_main_cnode;
350     }
351 
352     SortCNodes(main_cnodes);
353     *cnode_group_id = std::move(topo_order_);  // The topo_order is not used anymore.
354     return;
355   }
356 
AreaGraph(const std::vector<AnfNodePtrList> & node_groups)357   explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
358     for (size_t i = 0; i < node_groups.size(); ++i) {
359       areas_.emplace_back(node_groups[i]);
360       for (const auto &node : node_groups[i]) {
361         node_area_map_[node] = i;
362       }
363     }
364     for (auto &area : areas_) {
365       for (auto &spy : area.spy_cnodes()) {
366         auto cnode = spy->cast<CNodePtr>();
367         MS_EXCEPTION_IF_NULL(cnode);
368         size_t v = node_area_map_[spy];
369         for (auto &in_node : cnode->inputs()) {
370           if (!in_node->isa<CNode>()) continue;
371           // area edge u -> v
372           size_t u = node_area_map_[in_node];
373           if (u == v) continue;
374           areas_[u].AddTraitor(in_node);
375           if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
376             edge_prev_[v].emplace_back(u);
377           }
378         }
379       }
380     }
381   }
382   ~AreaGraph() = default;
383 
384  private:
385   // Topological sort the areas.
TopoSort()386   bool TopoSort() {
387     std::vector<int> out_degree(edge_prev_.size(), 0);
388     std::queue<size_t> que;
389     for (auto &prev : edge_prev_) {
390       for (size_t i : prev) {
391         out_degree[i]++;
392       }
393     }
394     for (size_t i = 0; i < out_degree.size(); ++i) {
395       if (out_degree[i] == 0) que.push(i);
396     }
397     while (!que.empty()) {
398       size_t u = que.front();
399       que.pop();
400       topo_order_.emplace_back(u);
401       for (size_t i : edge_prev_[u]) {
402         if (--out_degree[i] == 0) que.push(i);
403       }
404     }
405     std::reverse(topo_order_.begin(), topo_order_.end());
406     return topo_order_.size() == areas_.size();
407   }
408 
409   // Make a CNode in main graph to hold the sub_func_graph.
CreateMainCNode(const FuncGraphPtr & main_func_graph,const FuncGraphPtr & sub_func_graph,const std::vector<CNodePtr> & main_cnodes,const std::unordered_map<ParameterPtr,AnfNodePtr> & param_node_map)410   CNodePtr CreateMainCNode(const FuncGraphPtr &main_func_graph, const FuncGraphPtr &sub_func_graph,
411                            const std::vector<CNodePtr> &main_cnodes,
412                            const std::unordered_map<ParameterPtr, AnfNodePtr> &param_node_map) {
413     TraceGuard guard(std::make_shared<TraceOpt>(sub_func_graph->debug_info()));
414     AnfNodePtrList main_cnode_inputs = {NewValueNode(sub_func_graph)};
415     for (const auto &param : sub_func_graph->parameters()) {
416       // assert the param exists.
417       const auto &input_node = param_node_map.find(param->cast<ParameterPtr>())->second;
418       size_t input_area = node_area_map_[input_node];
419       // if the input node is in a tuple, then we need to create a GetItem fot it.
420       if (node_index_in_returned_tuple_.count(input_node) != 0) {
421         auto idx_val = SizeToLong(node_index_in_returned_tuple_[input_node]);
422         auto idx = NewValueNode(idx_val);
423         idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
424         AnfNodePtrList getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), main_cnodes[input_area], idx};
425         TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info()));
426         auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
427         auto abs_tuple = dyn_cast<abstract::AbstractTuple>(main_cnodes[input_area]->abstract());
428         if (idx_val < SizeToLong(abs_tuple->size())) {
429           getitem_node->set_abstract(abs_tuple->elements()[LongToSize(idx_val)]);
430         } else {
431           getitem_node->set_abstract(main_cnodes[input_area]->abstract());
432         }
433         main_cnode_inputs.emplace_back(getitem_node);
434       } else {
435         main_cnode_inputs.emplace_back(main_cnodes[input_area]);
436       }
437     }
438     auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
439     new_main_cnode->set_abstract(sub_func_graph->output()->abstract());
440     return new_main_cnode;
441   }
442 
SortCNodes(std::vector<CNodePtr> * main_cnodes) const443   void SortCNodes(std::vector<CNodePtr> *main_cnodes) const {
444     std::vector<CNodePtr> main_cnodes_sorted;
445     std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
446                    [main_cnodes](size_t index) { return main_cnodes->at(index); });
447     *main_cnodes = std::move(main_cnodes_sorted);
448   }
449 
450   // Areas in this subgraph
451   std::vector<Area> areas_;
452   // Adjacency table of areas
453   std::vector<std::vector<size_t>> edge_prev_;
454   // Topological order of areas
455   std::vector<size_t> topo_order_;
456   // Map AnfNode to Area id
457   std::unordered_map<AnfNodePtr, size_t> node_area_map_;
458   // Map the nodes to their index if there are multiple value in an area
459   std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_;
460 };
461 
462 class SplitSchemer {
463  public:
464   SplitSchemer() = default;
465   virtual ~SplitSchemer() = default;
466   virtual bool Split(const FuncGraphPtr &func_graph) = 0;
NeedInline(size_t group_id) const467   virtual bool NeedInline(size_t group_id) const { return false; }
split_plan() const468   const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }
469 
470  protected:
471   std::vector<AnfNodePtrList> split_plan_;
472 };
473 
474 class Splitter {
475  public:
476   using SplitSchemerPtr = std::shared_ptr<SplitSchemer>;
477   using SplitterPtr = std::shared_ptr<Splitter>;
478 
Split()479   bool Split() {
480     GenParamMap();
481     auto ori_sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
482     if (!split_schemer_->Split(ori_sub_func_graph)) {
483       return false;
484     }
485 
486     auto area_graph = AreaGraph::BuildAreaGraph(split_schemer_->split_plan());
487     if (area_graph == nullptr) {
488       return false;
489     }
490 
491     // The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
492     std::vector<size_t> cnodes_group_id;
493     area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id,
494                            [this](const Area &area) { this->AreaExpand(area); });
495 
496     RebuildGraph(cnodes_group_id);
497 
498     return true;
499   }
500 
MakeSplitter(const CNodePtr & main_cnode,const SplitSchemerPtr & split_schemer)501   static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
502     MS_EXCEPTION_IF_NULL(main_cnode);
503     MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
504     MS_EXCEPTION_IF_NULL(split_schemer);
505     return std::make_shared<Splitter>(main_cnode, split_schemer);
506   }
507 
Splitter(const CNodePtr & main_cnode,const SplitSchemerPtr & split_schemer)508   Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer)
509       : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
510   ~Splitter() = default;
511 
512  private:
ResetInlinedNodesKernelInfo() const513   void ResetInlinedNodesKernelInfo() const {
514     for (const auto &node : inlined_nodes_) {
515       ResetKernelInfo(node);
516     }
517   }
518 
519   // Maintain new subgraphs in main graph.
RebuildGraph(const std::vector<size_t> & cnodes_group_id)520   void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
521     BindFuncGraph();
522     RecoverParameter();
523     ConnectToMainGraph(cnodes_group_id);
524     UpdateSubGraphInfo();
525     ResetInlinedNodesKernelInfo();
526   }
527 
528   // Rebind nodes to its new sub_func_graph
BindFuncGraph() const529   void BindFuncGraph() const {
530     for (const auto &cnode : new_subgraph_cnodes_) {
531       auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
532       auto callback = [&sub_func_graph](const AnfNodePtr &node) {
533         if (!node->isa<ValueNode>()) {
534           node->set_func_graph(sub_func_graph);
535         }
536       };
537       TraverseFuncGraph(sub_func_graph, callback);
538     }
539   }
540 
541   // Recover the original subgraph's parameter if the new graph needs it
RecoverParameter()542   void RecoverParameter() {
543     for (const auto &cnode : new_subgraph_cnodes_) {
544       auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
545       auto callback = [&cnode, &sub_func_graph, this](const AnfNodePtr &node) {
546         auto param = node->cast<ParameterPtr>();
547         if (param == nullptr) return;
548         auto it = this->param_to_main_graph_node_map_.find(param);
549         if (it != this->param_to_main_graph_node_map_.end()) {
550           cnode->add_input(it->second);
551           sub_func_graph->add_parameter(param);
552           // Avoid repeating parameters.
553           this->param_to_main_graph_node_map_.erase(it);
554         }
555       };
556       TraverseFuncGraph(sub_func_graph, callback);
557     }
558   }
559 
InlineSubFuncGraph(const CNodePtr & main_node)560   CNodePtr InlineSubFuncGraph(const CNodePtr &main_node) {
561     auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(main_node);
562     const auto &inputs = main_node->inputs();
563     auto output = func_graph->output()->cast<CNodePtr>();
564     MS_EXCEPTION_IF_NULL(output);
565     const auto &parameters = func_graph->parameters();
566     std::unordered_map<AnfNodePtr, AnfNodePtr> param_input;
567     for (size_t i = 0; i < parameters.size(); ++i) {
568       param_input[parameters[i]] = inputs[i + 1];
569     }
570     auto sub_nodes = TopoSort(func_graph->get_return());
571     for (auto node : sub_nodes) {
572       if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
573         cnode->set_func_graph(main_func_graph_);
574         for (size_t i = 1; i < cnode->inputs().size(); ++i) {
575           auto iter = param_input.find(cnode->input(i));
576           if (iter != param_input.end()) {
577             cnode->set_input(i, iter->second);
578           }
579         }
580         if (AnfAlgo::IsRealKernel(node)) {
581           inlined_nodes_.emplace_back(node);
582         }
583       }
584     }
585     return output;
586   }
587 
588   // Set the new sub_func_graph node as input of nodes original main graph.
ConnectToMainGraph(const std::vector<size_t> & cnodes_group_id)589   void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
590     // For single output kernel, the last area contains the original output node (return node),
591     //  to replace old subgraph with new subgraphs, just replace the old CNode with new last CNode.
592     // For multiple output kernel, to avoid returning Parameter, the last MakeTuple was distribute to
593     //  a new FuncGraph, just inline the last MakeTuple node.
594     std::vector<CNodePtr> tmp_subgraph_cnodes;
595     std::unordered_map<AnfNodePtr, AnfNodePtr> replace_map;
596 
597     for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
598       if (split_schemer_->NeedInline(cnodes_group_id[i])) {
599         // Connect the sub_graph's inner node to main_graph
600         auto output = InlineSubFuncGraph(new_subgraph_cnodes_[i]);
601         if (i + 1 == new_subgraph_cnodes_.size()) {
602           replace_map[this->old_subgraph_cnode_] = output;
603         } else {
604           replace_map[new_subgraph_cnodes_[i]] = output;
605         }
606       } else {
607         if (i + 1 == new_subgraph_cnodes_.size()) {
608           replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
609         }
610         tmp_subgraph_cnodes.emplace_back(new_subgraph_cnodes_[i]);
611       }
612     }
613     new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);
614 
615     TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
616       auto cnode = node->cast<CNodePtr>();
617       if (cnode == nullptr) return;
618       for (size_t i = 1; i < cnode->inputs().size(); ++i) {
619         auto input_node = cnode->input(i);
620         auto iter = replace_map.find(input_node);
621         if (iter != replace_map.end()) {
622           cnode->set_input(i, iter->second);
623         }
624       }
625     });
626   }
627 
UpdateSubGraphInfo() const628   void UpdateSubGraphInfo() const {
629     auto graph_manager = main_func_graph_->manager();
630     MS_EXCEPTION_IF_NULL(graph_manager);
631 
632     for (auto cnode : new_subgraph_cnodes_) {
633       auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
634       // add new sub_func_graph to manager
635       graph_manager->AddFuncGraph(sub_func_graph);
636 
637       // set GraphKernel attr
638       auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
639       sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));
640 
641       // set kernel info
642       AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
643       AnfNodePtrList outputs;
644       kernel::GetFuncGraphOutputNodes(sub_func_graph, &outputs);
645       SetNewKernelInfo(cnode, sub_func_graph, inputs, outputs);
646     }
647   }
648 
649   // Copy all Parameter and ValueNode that the area used.
AreaExpand(const Area & area)650   void AreaExpand(const Area &area) {
651     std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
652     for (auto sub_node : area.nodes()) {
653       auto sub_cnode = sub_node->cast<CNodePtr>();
654       if (sub_cnode == nullptr) continue;
655       for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) {
656         auto in_node = sub_cnode->input(i);
657         if (in_node->isa<CNode>()) continue;
658         auto it = old_valuenode_and_param_map.find(in_node);
659         if (it != old_valuenode_and_param_map.end()) {
660           sub_cnode->set_input(i, it->second);
661         } else {
662           if (in_node->isa<Parameter>()) {
663             auto param = in_node->cast<ParameterPtr>();
664             auto cp_param = this->ParameterClone(param, in_node->func_graph());
665             old_valuenode_and_param_map[in_node] = cp_param->cast<AnfNodePtr>();
666             sub_cnode->set_input(i, cp_param);
667           }
668         }
669       }
670     }
671   }
672 
GenParamMap()673   void GenParamMap() {
674     auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
675     auto &param_arr = sub_func_graph->parameters();
676     for (size_t i = 0; i < param_arr.size(); ++i) {
677       auto param = param_arr[i]->cast<ParameterPtr>();
678       MS_EXCEPTION_IF_NULL(param);
679       param_to_main_graph_node_map_[param] = old_subgraph_cnode_->input(i + 1);
680     }
681   }
682 
ParameterClone(const ParameterPtr & param,const FuncGraphPtr & func)683   ParameterPtr ParameterClone(const ParameterPtr &param, const FuncGraphPtr &func) {
684     ParameterPtr param_c = std::make_shared<Parameter>(func);
685     param_c->set_name(param->name());
686     param_c->set_abstract(param->abstract());
687     param_to_main_graph_node_map_[param_c] = param_to_main_graph_node_map_[param];
688     return param_c;
689   }
690 
  FuncGraphPtr main_func_graph_;               // presumably the graph that owns old_subgraph_cnode_ — set outside this view, confirm
  CNodePtr old_subgraph_cnode_;                // The cnode that holds the original sub_func_graph
  std::vector<CNodePtr> new_subgraph_cnodes_;  // The cnode list that hold the new sub_func_graph
  std::vector<AnfNodePtr> inlined_nodes_;      // NOTE(review): usage not visible in this chunk; presumably nodes inlined back to the main graph
  SplitSchemerPtr split_schemer_;              // strategy object deciding the split plan (e.g. CostModelSplitSchemer)
  std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;  // sub-graph Parameter -> its main-graph input node
};
698 
// A SplitSchemer that delegates the split decision to the python cost model:
// the GraphKernel sub graph is serialized to json, passed to the python
// function kGraphKernelSplitFunc, and the returned json is decoded back into
// node groups (split_plan_) plus per-group inline flags (need_inline_).
class CostModelSplitSchemer : public SplitSchemer {
 public:
  virtual ~CostModelSplitSchemer() = default;

  // Run the whole split flow on `func_graph`, which must carry the
  // GraphKernel attribute. Returns true iff a usable split plan was produced.
  bool Split(const FuncGraphPtr &func_graph) override {
    if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
      MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node.";
    }
    func_graph_ = func_graph;
    this->Run();
    return !split_plan_.empty();
  }

  // Whether group `group_id` should be inlined into the main graph
  // (groups decoded with mode "basic", see DecodeJson). Throws if the id is
  // out of range.
  bool NeedInline(size_t group_id) const override {
    if (group_id >= need_inline_.size()) {
      MS_LOG(EXCEPTION) << "The group_id " << group_id << " should be less than the group num " << need_inline_.size();
    }
    return need_inline_[group_id] != 0;
  }

 protected:
  // Convert topo_valid_nodes_ to json, call the python cost-model split
  // function, and decode the answer into split_plan_/need_inline_.
  // Returns false (after logging) on any serialization, python-call or
  // decode failure.
  virtual bool SplitByCostModel() {
    // Use an address map to record the anf node address when converting to json,
    // it will recover the original node after split.
    std::map<std::string, AnfNodePtr> address_node_map;

    // convert anf-ir to json
    nlohmann::json json_desc;
    DumpOption dump_option;
    dump_option.is_before_select_kernel = false;
    dump_option.save_ptr_address = true;  // addresses are needed to map json entries back to anf nodes
    if (!AnfToJsonDesc(topo_valid_nodes_, dump_option, &json_desc, &address_node_map)) {
      MS_LOG(ERROR) << "Collect json desc failed.";
      return false;
    }

    // call costmodel split function.
    auto json_desc_str = json_desc.dump();
    auto flags_str = CollectSplitFlags();
    MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json: " << json_desc_str
                  << ". flag: " << flags_str;
    auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str, flags_str);
    // A python-side failure surfaces as None or an empty string.
    if (py::isinstance<py::none>(ret)) {
      MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
                    << json_desc_str << ". flag: " << flags_str;
      return false;
    }
    std::string split_graphs_str = py::cast<std::string>(ret);
    if (split_graphs_str.empty()) {
      MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
                    << json_desc_str << ". flag: " << flags_str;
      return false;
    }

    if (!DecodeJson(split_graphs_str, address_node_map)) {
      MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
      return false;
    }
    return true;
  }

  // Decode the json produced by the cost model: kJsonKeyGraphDesc holds one
  // sub-graph description per group, kJsonKeyGraphMode the matching mode
  // strings; a "basic" mode marks the group for inlining.
  virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) {
    auto kernel_json = nlohmann::json::parse(json_desc);
    std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
    std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
    if (graph_modes.size() != graph_descs.size()) {
      MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size();
      return false;
    }

    // recover json to anfnode.
    split_plan_.clear();
    for (const auto &graph_desc : graph_descs) {
      AnfNodePtrList res_graph;
      if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
        MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
        return false;
      }
      split_plan_.emplace_back(std::move(res_graph));
    }

    // ops to be inlined.
    need_inline_.clear();
    std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
                   [](const std::string &mode) { return mode == "basic" ? 1 : 0; });
    return true;
  }

  // Main flow: collect the kernel nodes, ask the cost model for a plan, then
  // attach return/virtual nodes to the groups of their producers. On failure
  // (or when the model keeps the graph whole) the plan is cleared so Split()
  // reports "no split".
  virtual void Run() {
    auto mng = func_graph_->manager();
    if (mng == nullptr) {
      mng = Manage(func_graph_, true);
      func_graph_->set_manager(mng);
    }
    GetValidKernelNodes();
    // call CostModel to get a split plan.
    if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) {
      split_plan_.clear();
      need_inline_.clear();
      return;
    } else if (split_plan_.size() == 1 && !NeedInline(0)) {
      // In this case, the CostModel decided to keep the whole graph unchanged.
      split_plan_.clear();
      need_inline_.clear();
      return;
    } else {
      MS_LOG(DEBUG) << "CostModel split succeeded. The kernel is split to " << split_plan_.size() << " parts.";
    }
    MapNodeGroup();
    GroupReturnNode();
    GroupVirtualNodes();
  }

  // A node takes part in the split only if it is a CNode that is a real
  // kernel (per AnfAlgo::IsRealKernel).
  virtual bool IsValidKernelNode(const AnfNodePtr &node) const {
    if (!node->isa<CNode>()) return false;
    if (AnfAlgo::IsRealKernel(node)) return true;
    return false;
  }

  // Fill topo_all_nodes_ with the whole graph in topological order, and
  // topo_valid_nodes_ with its real-kernel subset (order preserved).
  virtual void GetValidKernelNodes() {
    topo_all_nodes_ = TopoSort(func_graph_->get_return());
    topo_valid_nodes_.clear();
    std::copy_if(topo_all_nodes_.begin(), topo_all_nodes_.end(), std::back_inserter(topo_valid_nodes_),
                 [this](const AnfNodePtr &node) { return IsValidKernelNode(node); });
  }

  // Build node_group_: each planned node -> index of its group in split_plan_.
  void MapNodeGroup() {
    node_group_.clear();
    for (size_t i = 0; i < split_plan_.size(); ++i) {
      for (const auto &node : split_plan_[i]) {
        node_group_[node] = i;
      }
    }
  }

  // group the return node and last MakeTuple node (if exists).
  virtual void GroupReturnNode() {
    AnfNodePtrList outputs;
    kernel::GetFuncGraphOutputNodes(func_graph_, &outputs);
    auto ret_node = func_graph_->get_return();
    auto output = func_graph_->output();
    MS_EXCEPTION_IF_NULL(output);

    // Single real-kernel output: the return node joins that output's group.
    if (IsValidKernelNode(output)) {
      auto group_id = node_group_[ret_node] = node_group_[output];
      split_plan_[group_id].emplace_back(ret_node);
      return;
    }
    // assign the make_tuple node to a new group.
    if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
      auto group_id = split_plan_.size();
      split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
      need_inline_.emplace_back(1);  // the MakeTuple group is always inlined
      node_group_[ret_node] = node_group_[output] = group_id;
      return;
    }
    // Otherwise the return node stays ungrouped; GroupVirtualNodes may still
    // attach it through its inputs.
  }

  // assign virtual node to the same group of its input.
  virtual void GroupVirtualNodes() {
    for (const auto &node : topo_all_nodes_) {
      if (node_group_.count(node)) continue;
      auto cnode = node->cast<CNodePtr>();
      if (cnode == nullptr) continue;
      bool found = false;
      // Adopt the group of the first already-grouped input.
      for (const auto &input : cnode->inputs()) {
        auto iter = node_group_.find(input);
        if (iter != node_group_.end()) {
          node_group_[node] = iter->second;
          split_plan_[iter->second].emplace_back(node);
          found = true;
          break;
        }
      }
      if (!found) {
        MS_LOG(WARNING) << cnode->fullname_with_scope() << " is ungrouped.";
      }
    }
  }

  // Pack the graph-kernel flags consumed by the python splitter into a json
  // string.
  virtual std::string CollectSplitFlags() {
    const auto &flags = context::GraphKernelFlags::GetInstance();
    nlohmann::json flag_json;
    flag_json["dump_as_text"] = flags.dump_as_text;
    flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion;
    flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion;
    return flag_json.dump();
  }

  std::shared_ptr<FuncGraph> func_graph_;  // the GraphKernel sub graph being split
  AnfNodePtrList topo_all_nodes_;  // all nodes of func_graph_ in topo order
  AnfNodePtrList topo_valid_nodes_;  // real-kernel nodes only (input of the cost model)
  std::unordered_map<AnfNodePtr, size_t> node_group_;  // node -> group index in split_plan_
  std::vector<int> need_inline_;  // per group: non-zero means inline it into the main graph
};
893 
TrySplit(const CNodePtr & sub_root_cnode)894 bool TrySplit(const CNodePtr &sub_root_cnode) {
895   MS_LOG(DEBUG) << "Split process node: " << sub_root_cnode->fullname_with_scope();
896   auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared<CostModelSplitSchemer>());
897   MS_EXCEPTION_IF_NULL(splitter);
898   bool result = splitter->Split();
899   MS_LOG(DEBUG) << "Split node completed, result: " << result;
900   return result;
901 }
902 }  // namespace
903 
Run(const FuncGraphPtr & func_graph)904 bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
905   MS_EXCEPTION_IF_NULL(func_graph);
906   auto mng = func_graph->manager();
907   if (mng == nullptr) {
908     mng = Manage(func_graph, true);
909     func_graph->set_manager(mng);
910   }
911   auto todos = TopoSort(func_graph->get_return());
912 
913   // Split subgraphs in reversed topo order,
914   // since the nodes behind the processing node may be modified when splitting.
915   bool changed = false;
916   for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
917     auto node = (*iter)->cast<CNodePtr>();
918     if (node != nullptr && AnfAlgo::IsGraphKernel(node)) {
919       changed = TrySplit(node) || changed;
920     }
921   }
922   mng->RemoveRoots();
923   mng->KeepRoots({func_graph});
924   return changed;
925 }
926 }  // namespace opt
927 }  // namespace mindspore
928