1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
17 #include <algorithm>
18 #include <vector>
19 #include <string>
20 #include <unordered_set>
21 #include <utility>
22 #include <queue>
23 #include <map>
24 #include <unordered_map>
25 #include "frontend/optimizer/irpass.h"
26 #include "pipeline/jit/parse/python_adapter.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "backend/kernel_compiler/common_utils.h"
29 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
30 #include "debug/anf_ir_dump.h"
31 #include "utils/context/graph_kernel_flags.h"
32
33 namespace mindspore {
34 namespace kernel {
35 namespace {
GetStitchInfo(const nlohmann::json & kernel_json)36 StitchInfo GetStitchInfo(const nlohmann::json &kernel_json) {
37 StitchInfo info;
38 if (kernel_json.find(kJsonKeyBufferStitch) != kernel_json.end()) {
39 nlohmann::json buffer_stitch = kernel_json[kJsonKeyBufferStitch];
40 if (buffer_stitch.find(kJsonKeyStitchOp) != buffer_stitch.end()) {
41 std::vector<std::string> stitch_ops = buffer_stitch[kJsonKeyStitchOp];
42 info.stitch_ops = stitch_ops;
43 }
44 if (buffer_stitch.find(kJsonKeyStitchAtomicOp) != buffer_stitch.end()) {
45 std::vector<std::string> stitch_atomic_ops = buffer_stitch[kJsonKeyStitchAtomicOp];
46 info.stitch_atomic_ops = stitch_atomic_ops;
47 }
48 }
49 return info;
50 }
51
GetRecomputeOps(const nlohmann::json & kernel_json)52 std::set<std::string> GetRecomputeOps(const nlohmann::json &kernel_json) {
53 if (kernel_json.find(kJsonKeyRecomputeOps) != kernel_json.end()) {
54 std::vector<std::string> recompute_ops = kernel_json[kJsonKeyRecomputeOps];
55 return std::set<std::string>(recompute_ops.begin(), recompute_ops.end());
56 }
57 return std::set<std::string>();
58 }
59
IsRecomputeOp(const nlohmann::json & op_desc,const std::set<std::string> & recompute_ops)60 bool IsRecomputeOp(const nlohmann::json &op_desc, const std::set<std::string> &recompute_ops) {
61 std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
62 if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) {
63 return false;
64 }
65 std::string tensor_name = output_descs[0][kJsonKeyTensorName];
66 if (recompute_ops.count(tensor_name)) {
67 return true;
68 }
69 return false;
70 }
71
NewRecomputeNode(const AnfNodePtr & orig_node,std::map<AnfNodePtr,AnfNodePtr> * node_map)72 CNodePtr NewRecomputeNode(const AnfNodePtr &orig_node, std::map<AnfNodePtr, AnfNodePtr> *node_map) {
73 auto func_graph = orig_node->func_graph();
74 MS_EXCEPTION_IF_NULL(func_graph);
75 auto cnode = orig_node->cast<CNodePtr>();
76 MS_EXCEPTION_IF_NULL(cnode);
77 TraceGuard guard(std::make_shared<TraceOpt>(cnode->debug_info()));
78 auto orig_inputs = cnode->inputs();
79 std::vector<AnfNodePtr> inputs;
80 for (auto inp : orig_inputs) {
81 if (node_map->find(inp) == node_map->end()) {
82 inputs.push_back(inp);
83 continue;
84 }
85 inputs.push_back((*node_map)[inp]);
86 }
87 CNodePtr cp_node = func_graph->NewCNode(inputs);
88 func_graph->AddNode(cp_node);
89 cp_node->set_abstract(cnode->abstract());
90 cp_node->set_forward(cnode->forward().first, cnode->forward().second);
91 cp_node->set_inputs_value(cnode->inputs_value());
92 ScopePtr scope = (orig_node->scope() != kDefaultScope) ? orig_node->scope() : kDefaultScope;
93 cp_node->set_scope(scope);
94 cp_node->set_kernel_info(cnode->kernel_info_ptr());
95 cp_node->set_primal_attrs(cnode->primal_attrs());
96 cp_node->set_primal_debug_infos(cnode->primal_debug_infos());
97 (*node_map)[orig_node] = cp_node;
98 return cp_node->cast<CNodePtr>();
99 }
100
SetStitchAttr(const nlohmann::json & op_desc,const StitchInfo & info,const CNodePtr & node)101 void SetStitchAttr(const nlohmann::json &op_desc, const StitchInfo &info, const CNodePtr &node) {
102 std::vector<nlohmann::json> output_descs = op_desc[kJsonKeyOutputDesc];
103 if (output_descs.empty() || output_descs[0].find(kJsonKeyTensorName) == output_descs[0].end()) return;
104 std::string tensor_name = output_descs[0][kJsonKeyTensorName];
105 if (std::find(info.stitch_ops.begin(), info.stitch_ops.end(), tensor_name) != info.stitch_ops.end()) {
106 AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("common"), node);
107 MS_LOG(INFO) << "Enable common stitch fusion by " << node->fullname_with_scope();
108 }
109 if (std::find(info.stitch_atomic_ops.begin(), info.stitch_atomic_ops.end(), tensor_name) !=
110 info.stitch_atomic_ops.end()) {
111 AnfAlgo::SetNodeAttr(kAttrStitch, MakeValue("atomic"), node);
112 MS_LOG(INFO) << "Enable atomic add stitch fusion by " << node->fullname_with_scope();
113 }
114 }
115
116 // replace original region root op by its copy in this res_graphs
ConnectRecomputeOps(AnfNodePtrList * res_graphs,const AnfNodePtr & orig_region_root,const AnfNodePtr & cp_region_root)117 void ConnectRecomputeOps(AnfNodePtrList *res_graphs, const AnfNodePtr &orig_region_root,
118 const AnfNodePtr &cp_region_root) {
119 for (auto &node : *res_graphs) {
120 auto cnode = node->cast<CNodePtr>();
121 auto inputs = cnode->inputs();
122 for (size_t i = 1; i < inputs.size(); ++i) {
123 if (inputs[i] != orig_region_root) continue;
124 cnode->set_input(i, cp_region_root);
125 }
126 }
127 }
128 } // namespace
129
DecodeSplitNodes(const nlohmann::json & kernel_json,const std::map<std::string,AnfNodePtr> & address_node_map,AnfNodePtrList * res_graphs)130 bool SplitNodesDecoder::DecodeSplitNodes(const nlohmann::json &kernel_json,
131 const std::map<std::string, AnfNodePtr> &address_node_map,
132 AnfNodePtrList *res_graphs) {
133 MS_EXCEPTION_IF_NULL(res_graphs);
134 MS_LOG(DEBUG) << "start decode, " << kernel_json;
135 // decode cnodes in graph.
136 std::vector<nlohmann::json> op_node_descs = kernel_json[kJsonKeyOpDesc];
137 if (op_node_descs.empty()) {
138 MS_LOG(ERROR) << "Error decode, no cnodes for graph." << kernel_json;
139 return false;
140 }
141 StitchInfo info = GetStitchInfo(kernel_json);
142 auto recompute_ops = GetRecomputeOps(kernel_json);
143 // key_value: original_copied
144 std::map<AnfNodePtr, AnfNodePtr> node_map;
145 // nodes would be copied
146 AnfNodePtrList orig_region_nodes;
147 // nodes would not be copied
148 AnfNodePtrList no_cp_nodes;
149 for (const auto &op_desc : op_node_descs) {
150 if (op_desc.find(kJsonKeyPtrAddress) == op_desc.end() || op_desc[kJsonKeyPtrAddress].is_null()) {
151 MS_LOG(ERROR) << "Decode failed, key: " << kJsonKeyPtrAddress << " not found in: " << op_desc;
152 return false;
153 }
154
155 std::string ptr_address = op_desc[kJsonKeyPtrAddress];
156 if (address_node_map.count(ptr_address) == 0) {
157 MS_LOG(ERROR) << "Decode failed, ptr_address not found in map.";
158 return false;
159 }
160 auto node = address_node_map.at(ptr_address)->cast<CNodePtr>();
161 if (IsRecomputeOp(op_desc, recompute_ops)) {
162 auto cp_node = NewRecomputeNode(node, &node_map);
163 orig_region_nodes.push_back(node);
164 SetStitchAttr(op_desc, info, cp_node);
165 res_graphs->push_back(cp_node);
166 continue;
167 }
168 SetStitchAttr(op_desc, info, node);
169 res_graphs->push_back(node);
170 no_cp_nodes.push_back(node);
171 }
172 for (auto orig_node : orig_region_nodes) {
173 ConnectRecomputeOps(&no_cp_nodes, orig_node, node_map[orig_node]);
174 }
175 MS_LOG(DEBUG) << "decode cnodes success, size: " << res_graphs->size();
176 return true;
177 }
178 } // namespace kernel
179
180 namespace opt {
181 namespace {
TraverseFuncGraphFromCNode(const CNodePtr & cnode,const std::function<void (AnfNodePtr &)> & callback)182 void TraverseFuncGraphFromCNode(const CNodePtr &cnode, const std::function<void(AnfNodePtr &)> &callback) {
183 std::unordered_set<AnfNodePtr> visited;
184 std::queue<AnfNodePtr> que;
185 que.push(cnode);
186 visited.insert(cnode);
187 while (!que.empty()) {
188 auto ft_node = que.front();
189 que.pop();
190 callback(ft_node);
191 auto ft_cnode = ft_node->cast<CNodePtr>();
192 if (ft_cnode == nullptr) continue;
193 for (const auto &in_node : ft_cnode->inputs()) {
194 if (visited.count(in_node) == 0) {
195 que.push(in_node);
196 visited.insert(in_node);
197 }
198 }
199 }
200 }
201
// Traverse the whole func_graph starting from its Return node, visiting each
// AnfNode exactly once and invoking `callback` on it.
inline void TraverseFuncGraph(const FuncGraphPtr &root, const std::function<void(AnfNodePtr &)> &callback) {
  TraverseFuncGraphFromCNode(root->get_return(), callback);
}
206
// An Area is one node group of the split plan. Inside the group it tracks:
//   - "spy" cnodes: nodes that consume a value produced outside this area;
//   - "traitor" nodes: nodes whose output is consumed by another area.
// These lists drive the conversion of each area into a standalone FuncGraph
// (spies' external inputs become Parameters, traitors become outputs).
class Area {
 public:
  explicit Area(const AnfNodePtrList &anf_arr) {
    nodes_.insert(anf_arr.begin(), anf_arr.end());
    // Record every cnode that reads at least one input from outside this area.
    for (auto &node : anf_arr) {
      auto cnode = node->cast<CNodePtr>();
      if (cnode == nullptr) continue;
      const auto &inputs = cnode->inputs();
      if (std::any_of(inputs.begin(), inputs.end(), [this](const AnfNodePtr &node) { return IsExternalCNode(node); })) {
        spy_cnodes_.emplace_back(node);
      }
    }
  }

  ~Area() = default;

  // Set the external inputs of spy as a Parameter.
  // Each distinct external input gets exactly one Parameter (deduplicated via
  // node_param_map); the reverse mapping Parameter -> original external node is
  // returned to the caller through param_node_map.
  void CreateParameters(const FuncGraphPtr &func_graph, std::unordered_map<ParameterPtr, AnfNodePtr> *param_node_map) {
    std::unordered_map<AnfNodePtr, ParameterPtr> node_param_map;
    for (auto node : this->spy_cnodes_) {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      for (size_t i = 1; i < cnode->inputs().size(); ++i) {
        AnfNodePtr in_node = cnode->input(i);
        if (!IsExternalCNode(in_node)) continue;
        auto it = node_param_map.find(in_node);
        if (it == node_param_map.end()) {
          // First occurrence of this external value: create a new Parameter.
          auto new_param = std::make_shared<Parameter>(func_graph);
          new_param->set_abstract(in_node->abstract());
          func_graph->add_parameter(new_param);
          node_param_map.insert(std::make_pair(in_node, new_param));
          cnode->set_input(i, new_param);
        } else {
          // Reuse the Parameter created for a previous occurrence.
          cnode->set_input(i, it->second);
        }
      }
    }
    this->spy_cnodes_.clear();  // spy list is not useful anymore
    for (auto &&elem : node_param_map) {
      param_node_map->insert(std::make_pair(elem.second, elem.first));
    }
    return;
  }

  // Make a return node for traitor nodes.
  // With multiple traitors a MakeTuple collects them and each traitor's tuple
  // index is recorded in tuple_node_index; with a single traitor it is returned
  // directly; with none, the area must already hold the original Return node.
  void CreateReturnNode(const FuncGraphPtr &func_graph, std::unordered_map<AnfNodePtr, size_t> *tuple_node_index) {
    // If there's no traitor in the area, it means that this area is the last part
    // of the original FuncGraph, it already contains the original Return node.
    if (traitor_nodes_.empty()) {
      for (auto &node : nodes_) {
        if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
          func_graph->set_return(node->cast<CNodePtr>());
          node->set_func_graph(func_graph);
          return;
        }
      }
      MS_LOG(ERROR) << "Cannot find the return node in " << func_graph->ToString();
      return;
    }
    AnfNodePtrList return_inputs = {NewValueNode(prim::kPrimReturn)};
    if (traitor_nodes_.size() > 1) {
      // The area has multiple output, it's necessary to make a tuple for them.
      AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      AbstractBasePtrList abstracts;
      size_t i = 0;
      for (auto &traitor : traitor_nodes_) {
        tuple_node_index->insert(std::make_pair(traitor, i++));
        maketuple_inputs.emplace_back(traitor);
        abstracts.emplace_back(traitor->abstract());
      }
      auto maketuple_node = func_graph->NewCNode(maketuple_inputs);
      maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstracts));
      nodes_.insert(maketuple_node);
      return_inputs.emplace_back(maketuple_node);
    } else {
      // Single output: return the traitor directly.
      return_inputs.emplace_back(traitor_nodes_[0]);
    }
    auto return_node = func_graph->NewCNode(return_inputs);
    return_node->set_abstract(return_inputs.back()->abstract());
    func_graph->set_return(return_node);
    nodes_.insert(return_node);
    traitor_nodes_.clear();  // traitor list is not useful anymore
    return;
  }

  // Record a node whose output is used by another area (deduplicated).
  void AddTraitor(const AnfNodePtr &node) {
    if (std::find(traitor_nodes_.begin(), traitor_nodes_.end(), node) == traitor_nodes_.end()) {
      traitor_nodes_.emplace_back(node);
    }
  }

  const std::unordered_set<AnfNodePtr> &nodes() const { return nodes_; }
  const std::vector<AnfNodePtr> &spy_cnodes() const { return spy_cnodes_; }

 private:
  // This is a CNode that does not belong to this area.
  bool IsExternalCNode(const AnfNodePtr &node) const { return node->isa<CNode>() && this->nodes_.count(node) == 0; }

  // nodes in this area
  std::unordered_set<AnfNodePtr> nodes_;
  // if a node's output is used by other Area, it's a traitor
  std::vector<AnfNodePtr> traitor_nodes_;
  // if a node use other Area's output, it's a spy
  std::vector<AnfNodePtr> spy_cnodes_;
};
312
// AreaGraph maintains the dependency relation between Areas, topologically
// sorts them, and converts each Area into a sub-FuncGraph held by a new cnode
// in the main graph, rewiring cross-area edges through Parameters / TupleGetItem.
class AreaGraph {
 public:
  using AreaGraphPtr = std::shared_ptr<AreaGraph>;

  // Build an area graph to maintain the relation between areas.
  // Input node_groups: A group list, each element is a AnfNode list representing the node set in this group.
  // Returns nullptr when the groups form a dependency cycle.
  static AreaGraphPtr BuildAreaGraph(const std::vector<AnfNodePtrList> &node_groups) {
    auto area_graph = std::make_shared<AreaGraph>(node_groups);
    if (area_graph == nullptr) return nullptr;
    if (!area_graph->TopoSort()) {
      MS_LOG(WARNING) << "The groups have a cycle.";
      return nullptr;
    }
    return area_graph;
  }

  // Split the graph to multiple areas, and reconnect the edges between the areas.
  // The output `main_cnodes` is a topo-sorted cnode list in main graph, holding the new sub_func_graphs.
  // The output `cnode_group_id` represents the indices of main_cnodes before topo-sorting.
  void SplitGraph(const FuncGraphPtr &main_func_graph, std::vector<CNodePtr> *main_cnodes,
                  std::vector<size_t> *cnode_group_id, const std::function<void(const Area &)> &expand_callback) {
    main_cnodes->clear();
    main_cnodes->resize(areas_.size(), nullptr);

    // Let the caller localize shared Parameters/ValueNodes before detaching areas.
    for (auto &area : this->areas_) {
      expand_callback(area);
    }

    // Process areas in topological order so each area's producer cnodes already exist.
    for (auto index : topo_order_) {
      auto &current_area = areas_[index];
      auto sub_func_graph = std::make_shared<FuncGraph>();
      std::unordered_map<ParameterPtr, AnfNodePtr> param_node_map;

      current_area.CreateParameters(sub_func_graph, &param_node_map);
      current_area.CreateReturnNode(sub_func_graph, &node_index_in_returned_tuple_);
      auto new_main_cnode = this->CreateMainCNode(main_func_graph, sub_func_graph, *main_cnodes, param_node_map);
      (*main_cnodes)[index] = new_main_cnode;
    }

    SortCNodes(main_cnodes);
    *cnode_group_id = std::move(topo_order_);  // The topo_order is not used anymore.
    return;
  }

  // Build the Areas from node groups and collect inter-area edges; a producer
  // node used across areas is registered as a traitor of its own area.
  explicit AreaGraph(const std::vector<AnfNodePtrList> &node_groups) : edge_prev_(node_groups.size()) {
    for (size_t i = 0; i < node_groups.size(); ++i) {
      areas_.emplace_back(node_groups[i]);
      for (const auto &node : node_groups[i]) {
        node_area_map_[node] = i;
      }
    }
    for (auto &area : areas_) {
      for (auto &spy : area.spy_cnodes()) {
        auto cnode = spy->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(cnode);
        size_t v = node_area_map_[spy];
        for (auto &in_node : cnode->inputs()) {
          if (!in_node->isa<CNode>()) continue;
          // area edge u -> v
          size_t u = node_area_map_[in_node];
          if (u == v) continue;
          areas_[u].AddTraitor(in_node);
          if (std::find(edge_prev_[v].begin(), edge_prev_[v].end(), u) == edge_prev_[v].end()) {
            edge_prev_[v].emplace_back(u);
          }
        }
      }
    }
  }
  ~AreaGraph() = default;

 private:
  // Topological sort the areas.
  // Kahn's algorithm over the reversed edges; returns false when some area was
  // never enqueued, i.e. the area graph contains a cycle.
  bool TopoSort() {
    std::vector<int> out_degree(edge_prev_.size(), 0);
    std::queue<size_t> que;
    for (auto &prev : edge_prev_) {
      for (size_t i : prev) {
        out_degree[i]++;
      }
    }
    for (size_t i = 0; i < out_degree.size(); ++i) {
      if (out_degree[i] == 0) que.push(i);
    }
    while (!que.empty()) {
      size_t u = que.front();
      que.pop();
      topo_order_.emplace_back(u);
      for (size_t i : edge_prev_[u]) {
        if (--out_degree[i] == 0) que.push(i);
      }
    }
    // The traversal visits consumers first; reverse to get producer-first order.
    std::reverse(topo_order_.begin(), topo_order_.end());
    return topo_order_.size() == areas_.size();
  }

  // Make a CNode in main graph to hold the sub_func_graph.
  // Each Parameter of the sub graph is fed by the producer area's cnode; when
  // the producer returns a tuple, a TupleGetItem is inserted to pick the value.
  CNodePtr CreateMainCNode(const FuncGraphPtr &main_func_graph, const FuncGraphPtr &sub_func_graph,
                           const std::vector<CNodePtr> &main_cnodes,
                           const std::unordered_map<ParameterPtr, AnfNodePtr> &param_node_map) {
    TraceGuard guard(std::make_shared<TraceOpt>(sub_func_graph->debug_info()));
    AnfNodePtrList main_cnode_inputs = {NewValueNode(sub_func_graph)};
    for (const auto &param : sub_func_graph->parameters()) {
      // assert the param exists.
      const auto &input_node = param_node_map.find(param->cast<ParameterPtr>())->second;
      size_t input_area = node_area_map_[input_node];
      // if the input node is in a tuple, then we need to create a GetItem fot it.
      if (node_index_in_returned_tuple_.count(input_node) != 0) {
        auto idx_val = SizeToLong(node_index_in_returned_tuple_[input_node]);
        auto idx = NewValueNode(idx_val);
        idx->set_abstract(std::make_shared<abstract::AbstractScalar>(idx_val));
        AnfNodePtrList getitem_inputs = {NewValueNode(prim::kPrimTupleGetItem), main_cnodes[input_area], idx};
        TraceGuard g_sub(std::make_shared<TraceOpt>(main_cnodes[input_area]->debug_info()));
        auto getitem_node = main_func_graph->NewCNode(getitem_inputs);
        auto abs_tuple = dyn_cast<abstract::AbstractTuple>(main_cnodes[input_area]->abstract());
        // Fall back to the whole abstract when the index is out of the tuple range.
        if (idx_val < SizeToLong(abs_tuple->size())) {
          getitem_node->set_abstract(abs_tuple->elements()[LongToSize(idx_val)]);
        } else {
          getitem_node->set_abstract(main_cnodes[input_area]->abstract());
        }
        main_cnode_inputs.emplace_back(getitem_node);
      } else {
        // Single-output producer area: use its cnode directly.
        main_cnode_inputs.emplace_back(main_cnodes[input_area]);
      }
    }
    auto new_main_cnode = main_func_graph->NewCNode(main_cnode_inputs);
    new_main_cnode->set_abstract(sub_func_graph->output()->abstract());
    return new_main_cnode;
  }

  // Reorder main_cnodes from group order into topological order.
  void SortCNodes(std::vector<CNodePtr> *main_cnodes) const {
    std::vector<CNodePtr> main_cnodes_sorted;
    std::transform(topo_order_.begin(), topo_order_.end(), std::back_inserter(main_cnodes_sorted),
                   [main_cnodes](size_t index) { return main_cnodes->at(index); });
    *main_cnodes = std::move(main_cnodes_sorted);
  }

  // Areas in this subgraph
  std::vector<Area> areas_;
  // Adjacency table of areas
  std::vector<std::vector<size_t>> edge_prev_;
  // Topological order of areas
  std::vector<size_t> topo_order_;
  // Map AnfNode to Area id
  std::unordered_map<AnfNodePtr, size_t> node_area_map_;
  // Map the nodes to their index if there are multiple value in an area
  std::unordered_map<AnfNodePtr, size_t> node_index_in_returned_tuple_;
};
461
// Interface of split strategies: Split() computes the node groups of the split
// plan, and NeedInline() reports whether a group should be inlined back into
// the main graph instead of becoming a sub-FuncGraph.
class SplitSchemer {
 public:
  SplitSchemer() = default;
  virtual ~SplitSchemer() = default;
  // Compute a split plan for func_graph; returns false when no split is performed.
  virtual bool Split(const FuncGraphPtr &func_graph) = 0;
  // Default strategy: never inline a group.
  virtual bool NeedInline(size_t group_id) const { return false; }
  const std::vector<AnfNodePtrList> &split_plan() const { return split_plan_; }

 protected:
  // Groups of nodes produced by Split(); one entry per new sub graph.
  std::vector<AnfNodePtrList> split_plan_;
};
473
// Splitter replaces one composite (GraphKernel) cnode by several smaller
// sub-FuncGraph cnodes according to the SplitSchemer's plan, then rewires the
// main graph: parameters are recovered, inlined groups are merged back, and
// kernel info is refreshed.
class Splitter {
 public:
  using SplitSchemerPtr = std::shared_ptr<SplitSchemer>;
  using SplitterPtr = std::shared_ptr<Splitter>;

  // Run the whole split pipeline; returns false when the schemer decides not
  // to split or the plan has a cycle.
  bool Split() {
    GenParamMap();
    auto ori_sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
    if (!split_schemer_->Split(ori_sub_func_graph)) {
      return false;
    }

    auto area_graph = AreaGraph::BuildAreaGraph(split_schemer_->split_plan());
    if (area_graph == nullptr) {
      return false;
    }

    // The output new_subgraph_cnodes are topo sorted, use a list to store its order in split_plan.
    std::vector<size_t> cnodes_group_id;
    area_graph->SplitGraph(main_func_graph_, &new_subgraph_cnodes_, &cnodes_group_id,
                           [this](const Area &area) { this->AreaExpand(area); });

    RebuildGraph(cnodes_group_id);

    return true;
  }

  static SplitterPtr MakeSplitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer) {
    MS_EXCEPTION_IF_NULL(main_cnode);
    MS_EXCEPTION_IF_NULL(main_cnode->func_graph());
    MS_EXCEPTION_IF_NULL(split_schemer);
    return std::make_shared<Splitter>(main_cnode, split_schemer);
  }

  Splitter(const CNodePtr &main_cnode, const SplitSchemerPtr &split_schemer)
      : main_func_graph_(main_cnode->func_graph()), old_subgraph_cnode_(main_cnode), split_schemer_(split_schemer) {}
  ~Splitter() = default;

 private:
  // Rebuild kernel info of nodes that were inlined back into the main graph.
  void ResetInlinedNodesKernelInfo() const {
    for (const auto &node : inlined_nodes_) {
      ResetKernelInfo(node);
    }
  }

  // Maintain new subgraphs in main graph.
  void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
    BindFuncGraph();
    RecoverParameter();
    ConnectToMainGraph(cnodes_group_id);
    UpdateSubGraphInfo();
    ResetInlinedNodesKernelInfo();
  }

  // Rebind nodes to its new sub_func_graph
  void BindFuncGraph() const {
    for (const auto &cnode : new_subgraph_cnodes_) {
      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
      auto callback = [&sub_func_graph](const AnfNodePtr &node) {
        // ValueNodes are shared; only cnodes/parameters are rebound.
        if (!node->isa<ValueNode>()) {
          node->set_func_graph(sub_func_graph);
        }
      };
      TraverseFuncGraph(sub_func_graph, callback);
    }
  }

  // Recover the original subgraph's parameter if the new graph needs it
  void RecoverParameter() {
    for (const auto &cnode : new_subgraph_cnodes_) {
      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
      auto callback = [&cnode, &sub_func_graph, this](const AnfNodePtr &node) {
        auto param = node->cast<ParameterPtr>();
        if (param == nullptr) return;
        auto it = this->param_to_main_graph_node_map_.find(param);
        if (it != this->param_to_main_graph_node_map_.end()) {
          // Feed the main-graph value into the new cnode and register the
          // Parameter on the new sub graph.
          cnode->add_input(it->second);
          sub_func_graph->add_parameter(param);
          // Avoid repeating parameters.
          this->param_to_main_graph_node_map_.erase(it);
        }
      };
      TraverseFuncGraph(sub_func_graph, callback);
    }
  }

  // Inline a sub func graph into the main graph: substitute its Parameters
  // with the caller's inputs and move all cnodes into main_func_graph_.
  // Returns the sub graph's output cnode.
  CNodePtr InlineSubFuncGraph(const CNodePtr &main_node) {
    auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(main_node);
    const auto &inputs = main_node->inputs();
    auto output = func_graph->output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(output);
    const auto &parameters = func_graph->parameters();
    std::unordered_map<AnfNodePtr, AnfNodePtr> param_input;
    for (size_t i = 0; i < parameters.size(); ++i) {
      // inputs[0] is the func graph value node; real inputs start at 1.
      param_input[parameters[i]] = inputs[i + 1];
    }
    auto sub_nodes = TopoSort(func_graph->get_return());
    for (auto node : sub_nodes) {
      if (auto cnode = node->cast<CNodePtr>(); cnode != nullptr) {
        cnode->set_func_graph(main_func_graph_);
        for (size_t i = 1; i < cnode->inputs().size(); ++i) {
          auto iter = param_input.find(cnode->input(i));
          if (iter != param_input.end()) {
            cnode->set_input(i, iter->second);
          }
        }
        // Inlined real kernels need their kernel info rebuilt later.
        if (AnfAlgo::IsRealKernel(node)) {
          inlined_nodes_.emplace_back(node);
        }
      }
    }
    return output;
  }

  // Set the new sub_func_graph node as input of nodes original main graph.
  void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
    // For single output kernel, the last area contains the original output node (return node),
    // to replace old subgraph with new subgraphs, just replace the old CNode with new last CNode.
    // For multiple output kernel, to avoid returning Parameter, the last MakeTuple was distribute to
    // a new FuncGraph, just inline the last MakeTuple node.
    std::vector<CNodePtr> tmp_subgraph_cnodes;
    std::unordered_map<AnfNodePtr, AnfNodePtr> replace_map;

    for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
      if (split_schemer_->NeedInline(cnodes_group_id[i])) {
        // Connect the sub_graph's inner node to main_graph
        auto output = InlineSubFuncGraph(new_subgraph_cnodes_[i]);
        if (i + 1 == new_subgraph_cnodes_.size()) {
          replace_map[this->old_subgraph_cnode_] = output;
        } else {
          replace_map[new_subgraph_cnodes_[i]] = output;
        }
      } else {
        if (i + 1 == new_subgraph_cnodes_.size()) {
          replace_map[this->old_subgraph_cnode_] = new_subgraph_cnodes_.back();
        }
        tmp_subgraph_cnodes.emplace_back(new_subgraph_cnodes_[i]);
      }
    }
    // Keep only the cnodes that were not inlined.
    new_subgraph_cnodes_ = std::move(tmp_subgraph_cnodes);

    // Rewrite every user in the main graph according to replace_map.
    TraverseFuncGraph(main_func_graph_, [&replace_map](const AnfNodePtr &node) {
      auto cnode = node->cast<CNodePtr>();
      if (cnode == nullptr) return;
      for (size_t i = 1; i < cnode->inputs().size(); ++i) {
        auto input_node = cnode->input(i);
        auto iter = replace_map.find(input_node);
        if (iter != replace_map.end()) {
          cnode->set_input(i, iter->second);
        }
      }
    });
  }

  // Register the new sub graphs in the manager and refresh their GraphKernel
  // attribute and kernel info.
  void UpdateSubGraphInfo() const {
    auto graph_manager = main_func_graph_->manager();
    MS_EXCEPTION_IF_NULL(graph_manager);

    for (auto cnode : new_subgraph_cnodes_) {
      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
      // add new sub_func_graph to manager
      graph_manager->AddFuncGraph(sub_func_graph);

      // set GraphKernel attr
      auto attr = ExtractGraphKernelName(TopoSort(sub_func_graph->get_return()), "", "split");
      sub_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(attr));

      // set kernel info
      AnfNodePtrList inputs(cnode->inputs().begin() + 1, cnode->inputs().end());
      AnfNodePtrList outputs;
      kernel::GetFuncGraphOutputNodes(sub_func_graph, &outputs);
      SetNewKernelInfo(cnode, sub_func_graph, inputs, outputs);
    }
  }

  // Copy all Parameter and ValueNode that the area used.
  void AreaExpand(const Area &area) {
    std::unordered_map<AnfNodePtr, AnfNodePtr> old_valuenode_and_param_map;
    for (auto sub_node : area.nodes()) {
      auto sub_cnode = sub_node->cast<CNodePtr>();
      if (sub_cnode == nullptr) continue;
      for (size_t i = 1; i < sub_cnode->inputs().size(); ++i) {
        auto in_node = sub_cnode->input(i);
        if (in_node->isa<CNode>()) continue;
        auto it = old_valuenode_and_param_map.find(in_node);
        if (it != old_valuenode_and_param_map.end()) {
          sub_cnode->set_input(i, it->second);
        } else {
          if (in_node->isa<Parameter>()) {
            // Clone the Parameter so each area owns a private copy.
            auto param = in_node->cast<ParameterPtr>();
            auto cp_param = this->ParameterClone(param, in_node->func_graph());
            old_valuenode_and_param_map[in_node] = cp_param->cast<AnfNodePtr>();
            sub_cnode->set_input(i, cp_param);
          }
        }
      }
    }
  }

  // Record which main-graph value feeds each Parameter of the original sub graph.
  void GenParamMap() {
    auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(old_subgraph_cnode_);
    auto &param_arr = sub_func_graph->parameters();
    for (size_t i = 0; i < param_arr.size(); ++i) {
      auto param = param_arr[i]->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param);
      param_to_main_graph_node_map_[param] = old_subgraph_cnode_->input(i + 1);
    }
  }

  // Clone a Parameter into func, inheriting its main-graph source node.
  ParameterPtr ParameterClone(const ParameterPtr &param, const FuncGraphPtr &func) {
    ParameterPtr param_c = std::make_shared<Parameter>(func);
    param_c->set_name(param->name());
    param_c->set_abstract(param->abstract());
    param_to_main_graph_node_map_[param_c] = param_to_main_graph_node_map_[param];
    return param_c;
  }

  FuncGraphPtr main_func_graph_;
  CNodePtr old_subgraph_cnode_;                // The cnode that holds the original sub_func_graph
  std::vector<CNodePtr> new_subgraph_cnodes_;  // The cnode list that hold the new sub_func_graph
  std::vector<AnfNodePtr> inlined_nodes_;
  SplitSchemerPtr split_schemer_;
  std::unordered_map<ParameterPtr, AnfNodePtr> param_to_main_graph_node_map_;
};
698
699 class CostModelSplitSchemer : public SplitSchemer {
700 public:
701 virtual ~CostModelSplitSchemer() = default;
Split(const FuncGraphPtr & func_graph)702 bool Split(const FuncGraphPtr &func_graph) override {
703 if (!func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
704 MS_EXCEPTION(NotSupportError) << "func_graph must be a GraphKernel node.";
705 }
706 func_graph_ = func_graph;
707 this->Run();
708 return !split_plan_.empty();
709 }
710
NeedInline(size_t group_id) const711 bool NeedInline(size_t group_id) const override {
712 if (group_id >= need_inline_.size()) {
713 MS_LOG(EXCEPTION) << "The group_id " << group_id << " should be less than the group num " << need_inline_.size();
714 }
715 return need_inline_[group_id] != 0;
716 }
717
718 protected:
SplitByCostModel()719 virtual bool SplitByCostModel() {
720 // Use an address map to record the anf node address when converting to json,
721 // it will recover the original node after split.
722 std::map<std::string, AnfNodePtr> address_node_map;
723
724 // convert anf-ir to json
725 nlohmann::json json_desc;
726 DumpOption dump_option;
727 dump_option.is_before_select_kernel = false;
728 dump_option.save_ptr_address = true;
729 if (!AnfToJsonDesc(topo_valid_nodes_, dump_option, &json_desc, &address_node_map)) {
730 MS_LOG(ERROR) << "Collect json desc failed.";
731 return false;
732 }
733
734 // call costmodel split function.
735 auto json_desc_str = json_desc.dump();
736 auto flags_str = CollectSplitFlags();
737 MS_LOG(DEBUG) << "CallPyFn: [" << kGraphKernelSplitFunc << "] with input json: " << json_desc_str
738 << ". flag: " << flags_str;
739 auto ret = parse::python_adapter::CallPyFn(kGraphKernelModule, kGraphKernelSplitFunc, json_desc_str, flags_str);
740 if (py::isinstance<py::none>(ret)) {
741 MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
742 << json_desc_str << ". flag: " << flags_str;
743 return false;
744 }
745 std::string split_graphs_str = py::cast<std::string>(ret);
746 if (split_graphs_str.empty()) {
747 MS_LOG(ERROR) << "CallPyFn: [" << kGraphKernelSplitFunc << "] return invalid result. input json:\n"
748 << json_desc_str << ". flag: " << flags_str;
749 return false;
750 }
751
752 if (!DecodeJson(split_graphs_str, address_node_map)) {
753 MS_LOG(ERROR) << "Failed to decode split graphs. input json:\n" << split_graphs_str;
754 return false;
755 }
756 return true;
757 }
758
DecodeJson(const std::string & json_desc,const std::map<std::string,AnfNodePtr> & address_node_map)759 virtual bool DecodeJson(const std::string &json_desc, const std::map<std::string, AnfNodePtr> &address_node_map) {
760 auto kernel_json = nlohmann::json::parse(json_desc);
761 std::vector<nlohmann::json> graph_descs = kernel_json[kJsonKeyGraphDesc];
762 std::vector<std::string> graph_modes = kernel_json[kJsonKeyGraphMode];
763 if (graph_modes.size() != graph_descs.size()) {
764 MS_LOG(ERROR) << "Size of graph_mode " << graph_modes.size() << " mismatch graph_desc " << graph_descs.size();
765 return false;
766 }
767
768 // recover json to anfnode.
769 split_plan_.clear();
770 for (const auto &graph_desc : graph_descs) {
771 AnfNodePtrList res_graph;
772 if (!kernel::SplitNodesDecoder::DecodeSplitNodes(graph_desc, address_node_map, &res_graph)) {
773 MS_LOG(ERROR) << "Failed decode sub graph, " << graph_desc;
774 return false;
775 }
776 split_plan_.emplace_back(std::move(res_graph));
777 }
778
779 // ops to be inlined.
780 need_inline_.clear();
781 std::transform(graph_modes.begin(), graph_modes.end(), std::back_inserter(need_inline_),
782 [](const std::string &mode) { return mode == "basic" ? 1 : 0; });
783 return true;
784 }
785
Run()786 virtual void Run() {
787 auto mng = func_graph_->manager();
788 if (mng == nullptr) {
789 mng = Manage(func_graph_, true);
790 func_graph_->set_manager(mng);
791 }
792 GetValidKernelNodes();
793 // call CostModel to get a split plan.
794 if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) {
795 split_plan_.clear();
796 need_inline_.clear();
797 return;
798 } else if (split_plan_.size() == 1 && !NeedInline(0)) {
799 // In this case, the CostModel decided to keep the whole graph unchanged.
800 split_plan_.clear();
801 need_inline_.clear();
802 return;
803 } else {
804 MS_LOG(DEBUG) << "CostModel split succeeded. The kernel is split to " << split_plan_.size() << " parts.";
805 }
806 MapNodeGroup();
807 GroupReturnNode();
808 GroupVirtualNodes();
809 }
810
IsValidKernelNode(const AnfNodePtr & node) const811 virtual bool IsValidKernelNode(const AnfNodePtr &node) const {
812 if (!node->isa<CNode>()) return false;
813 if (AnfAlgo::IsRealKernel(node)) return true;
814 return false;
815 }
816
GetValidKernelNodes()817 virtual void GetValidKernelNodes() {
818 topo_all_nodes_ = TopoSort(func_graph_->get_return());
819 topo_valid_nodes_.clear();
820 std::copy_if(topo_all_nodes_.begin(), topo_all_nodes_.end(), std::back_inserter(topo_valid_nodes_),
821 [this](const AnfNodePtr &node) { return IsValidKernelNode(node); });
822 }
823
MapNodeGroup()824 void MapNodeGroup() {
825 node_group_.clear();
826 for (size_t i = 0; i < split_plan_.size(); ++i) {
827 for (const auto &node : split_plan_[i]) {
828 node_group_[node] = i;
829 }
830 }
831 }
832
833 // group the return node and last MakeTuple node (if exists).
GroupReturnNode()834 virtual void GroupReturnNode() {
835 AnfNodePtrList outputs;
836 kernel::GetFuncGraphOutputNodes(func_graph_, &outputs);
837 auto ret_node = func_graph_->get_return();
838 auto output = func_graph_->output();
839 MS_EXCEPTION_IF_NULL(output);
840
841 if (IsValidKernelNode(output)) {
842 auto group_id = node_group_[ret_node] = node_group_[output];
843 split_plan_[group_id].emplace_back(ret_node);
844 return;
845 }
846 // assign the make_tuple node to a new group.
847 if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimMakeTuple)) {
848 auto group_id = split_plan_.size();
849 split_plan_.emplace_back(AnfNodePtrList{output, ret_node});
850 need_inline_.emplace_back(1);
851 node_group_[ret_node] = node_group_[output] = group_id;
852 return;
853 }
854 }
855
856 // assign virtual node to the same group of its input.
GroupVirtualNodes()857 virtual void GroupVirtualNodes() {
858 for (const auto &node : topo_all_nodes_) {
859 if (node_group_.count(node)) continue;
860 auto cnode = node->cast<CNodePtr>();
861 if (cnode == nullptr) continue;
862 bool found = false;
863 for (const auto &input : cnode->inputs()) {
864 auto iter = node_group_.find(input);
865 if (iter != node_group_.end()) {
866 node_group_[node] = iter->second;
867 split_plan_[iter->second].emplace_back(node);
868 found = true;
869 break;
870 }
871 }
872 if (!found) {
873 MS_LOG(WARNING) << cnode->fullname_with_scope() << " is ungrouped.";
874 }
875 }
876 }
877
CollectSplitFlags()878 virtual std::string CollectSplitFlags() {
879 const auto &flags = context::GraphKernelFlags::GetInstance();
880 nlohmann::json flag_json;
881 flag_json["dump_as_text"] = flags.dump_as_text;
882 flag_json["enable_stitch_fusion"] = flags.enable_stitch_fusion;
883 flag_json["enable_recompute_fusion"] = flags.enable_recompute_fusion;
884 return flag_json.dump();
885 }
886
  std::shared_ptr<FuncGraph> func_graph_;  // the GraphKernel sub graph being split
  AnfNodePtrList topo_all_nodes_;  // all nodes of func_graph_ in topo order
  AnfNodePtrList topo_valid_nodes_;  // real kernel cnodes filtered from topo_all_nodes_
  std::unordered_map<AnfNodePtr, size_t> node_group_;  // node -> index of its group in split_plan_
  std::vector<int> need_inline_;  // per group: nonzero means the group is inlined ("basic" mode)
892 };
893
TrySplit(const CNodePtr & sub_root_cnode)894 bool TrySplit(const CNodePtr &sub_root_cnode) {
895 MS_LOG(DEBUG) << "Split process node: " << sub_root_cnode->fullname_with_scope();
896 auto splitter = Splitter::MakeSplitter(sub_root_cnode, std::make_shared<CostModelSplitSchemer>());
897 MS_EXCEPTION_IF_NULL(splitter);
898 bool result = splitter->Split();
899 MS_LOG(DEBUG) << "Split node completed, result: " << result;
900 return result;
901 }
902 } // namespace
903
Run(const FuncGraphPtr & func_graph)904 bool GraphKernelSplitter::Run(const FuncGraphPtr &func_graph) {
905 MS_EXCEPTION_IF_NULL(func_graph);
906 auto mng = func_graph->manager();
907 if (mng == nullptr) {
908 mng = Manage(func_graph, true);
909 func_graph->set_manager(mng);
910 }
911 auto todos = TopoSort(func_graph->get_return());
912
913 // Split subgraphs in reversed topo order,
914 // since the nodes behind the processing node may be modified when splitting.
915 bool changed = false;
916 for (auto iter = todos.crbegin(); iter != todos.crend(); ++iter) {
917 auto node = (*iter)->cast<CNodePtr>();
918 if (node != nullptr && AnfAlgo::IsGraphKernel(node)) {
919 changed = TrySplit(node) || changed;
920 }
921 }
922 mng->RemoveRoots();
923 mng->KeepRoots({func_graph});
924 return changed;
925 }
926 } // namespace opt
927 } // namespace mindspore
928