• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/parallel_fusion.h"
18 
19 #include <algorithm>
20 #include <list>
21 #include <queue>
22 #include <unordered_map>
23 #include <utility>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "backend/common/graph_kernel/graph_kernel_flags.h"
27 #include "kernel/kernel.h"
28 #include "backend/common/graph_kernel/graph_kernel_helper.h"
29 #include "kernel/common_utils.h"
30 #include "frontend/operator/ops.h"
31 #include "ir/func_graph_cloner.h"
32 #include "backend/common/graph_kernel/core/update_state_formatter.h"
33 #include "backend/common/graph_kernel/core/graph_builder.h"
34 
35 namespace mindspore::graphkernel {
36 namespace {
37 // CUDA's parameter table accepts at most 4KB, so a kernel may take at most 512 parameters.
38 constexpr size_t CUDA_PARA_LIMIT = 512;
39 
ProcessThroughPassCNode(OrderedMap<AnfNodePtr,NodeRelation> * node_rels)40 void ProcessThroughPassCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
41   std::set<AnfNodePtr> to_be_erased;
42   for (const auto &[node, node_rel] : (*node_rels)) {
43     if (!IsOneOfPrimitiveCNode(
44           node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem})) {
45       continue;
46     }
47     (void)to_be_erased.insert(node);
48     for (const auto &pre : node_rel.pres) {
49       auto &pre_nexts = (*node_rels)[pre].nexts;
50       (void)pre_nexts.erase(node);
51       for (const auto &next : node_rel.nexts) {
52         (void)pre_nexts.insert(next);
53         auto &next_pres = (*node_rels)[next].pres;
54         (void)next_pres.erase(node);
55         (void)next_pres.insert(pre);
56       }
57     }
58   }
59   for (const auto &node : to_be_erased) {
60     (void)node_rels->erase(node);
61   }
62 }
63 
ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr,NodeRelation> * node_rels)64 void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
65   AnfNodePtrList latter_to_be_erased;
66   for (auto &[node, node_rel] : (*node_rels)) {
67     if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
68       continue;
69     }
70 
71     AnfNodePtrList check_next_list;
72     check_next_list.push_back(node);
73 
74     bool disinterested = false;
75     for (auto &successor : node_rel.nexts) {
76       if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
77         disinterested = true;
78         break;
79       }
80       check_next_list.push_back(successor);
81     }
82     if (disinterested) {
83       continue;
84     }
85 
86     if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
87                      [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
88       continue;
89     }
90 
91     latter_to_be_erased.push_back(node);
92   }
93 
94   // Delete Tail MakeTuple(including its getitem nodes).
95   for (const auto &node : latter_to_be_erased) {
96     for (auto &pre : (*node_rels)[node].pres) {
97       (void)(*node_rels)[pre].nexts.erase(node);
98     }
99 
100     // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
101     for (auto &getitem : (*node_rels)[node].nexts) {
102       (void)node_rels->erase(getitem);
103     }
104 
105     (void)node_rels->erase(node);
106   }
107 }
108 
IsSingleInputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)109 bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
110   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
111     return true;
112   }
113   return false;
114 }
115 
IsSingleOutputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)116 bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
117   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
118     return true;
119   }
120   return false;
121 }
122 
IsMultiInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)123 bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
124   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
125     return true;
126   }
127   return false;
128 }
129 
IsMultiOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)130 bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
131   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
132     return true;
133   }
134   return false;
135 }
136 
IsNoInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)137 bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
138   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
139     return true;
140   }
141   return false;
142 }
143 
IsNoOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)144 bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
145   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
146     return true;
147   }
148   return false;
149 }
150 
ProcessLocalStructure(OrderedMap<AnfNodePtr,NodeRelation> * node_rels,std::set<AnfNodePtr> * virtual_noout_nodes,std::set<AnfNodePtr> * ignore_noin_nodes)151 void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes,
152                            std::set<AnfNodePtr> *ignore_noin_nodes) {
153   // 1. Local relation
154   // Graph as following left part, relation D->B and D->E(D is a no input node)
155   // will make B and E to be multiply inputs node.
156   // But for parallel, this local relation can ignore for B and E, which make
157   // them be able to be paralleled.
158   //
159   // ************************************
160   // *                                  *
161   // * |                    |           *
162   // * A   D                A      D    *
163   // * |  /|                |     / \   *
164   // * | C |                |    C   F  *
165   // * |/  /                |    |   |  *
166   // * B  F      ====>      B    x   x  *
167   // * | /                  |           *
168   // * |/                   |           *
169   // * E                    E           *
170   // * |                    |           *
171   // *                                  *
172   // ************************************
173   AnfNodePtrList no_input_nodes;
174   for (const auto &node_rel : *node_rels) {
175     auto &node = node_rel.first;
176     if (IsNoInputsNode(*node_rels, node)) {
177       no_input_nodes.push_back(node);
178     }
179   }
180 
181   std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
182 
183   for (const auto &ninode : no_input_nodes) {
184     AnfNodePtrList cnexts((*node_rels)[ninode].nexts.cbegin(), (*node_rels)[ninode].nexts.cend());
185     for (const auto &n : cnexts) {
186       AnfNodePtr serial_tail = ninode;
187       AnfNodePtr cur_node = n;
188       while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
189         serial_tail = cur_node;
190         cur_node = *((*node_rels)[cur_node].nexts.begin());
191       }
192       (void)latter_delete.emplace_back(serial_tail, cur_node);
193     }
194   }
195 
196   // Delete relation.
197   for (const auto &[serial_tail, cur_node] : latter_delete) {
198     (void)virtual_noout_nodes->insert(serial_tail);
199     (void)ignore_noin_nodes->insert(cur_node);
200     (void)(*node_rels)[serial_tail].nexts.erase(cur_node);
201     (void)(*node_rels)[cur_node].pres.erase(serial_tail);
202     MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
203                  << cur_node->fullname_with_scope();
204   }
205 }
206 
GetInterestNodeIds(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const std::set<AnfNodePtr> & virtual_noout_nodes,const std::set<AnfNodePtr> & ignore_noin_nodes)207 std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
208   const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
209   const std::set<AnfNodePtr> &ignore_noin_nodes) {
210   AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
211   std::list<std::function<void(const AnfNodePtr &)>> func_list = {
212     [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
213       if (IsMultiInputsNode(node_rels, node)) {
214         multi_inputs_nodes.push_back(node);
215       }
216     },
217     [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
218       if (IsMultiOutputsNode(node_rels, node)) {
219         multi_outputs_nodes.push_back(node);
220       }
221     },
222     [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
223       if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
224         no_input_nodes.push_back(node);
225       }
226     },
227     [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
228       if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
229         no_output_nodes.push_back(node);
230       }
231     }};
232 
233   for (const auto &node_rel : node_rels) {
234     for (const auto &func : func_list) {
235       func(node_rel.first);
236     }
237   }
238 
239   return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
240 }
241 
WhiteOpsFilter(const AnfNodePtr & node)242 bool WhiteOpsFilter(const AnfNodePtr &node) { return common::AnfAlgo::IsGraphKernel(node); }
243 
244 // Parallel cannot work with stitching for now.
Parallelizable(const AnfNodePtr & node)245 bool Parallelizable(const AnfNodePtr &node) { return WhiteOpsFilter(node) && !IsBufferStitchNode(node); }
246 
SearchFromNodes(const AnfNodePtrList & nodes,const std::function<bool (const AnfNodePtr &)> & filter_func,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::set<AnfNodePtr> * seen)247 std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
248                                             const std::function<bool(const AnfNodePtr &)> &filter_func,
249                                             const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
250                                             std::set<AnfNodePtr> *seen) {
251   // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
252   // For backward search, the other multi-inputs node can be contained in.
253   // For forward search, the other multi-outputs node can be contained in.
254   auto get_contain_node_set = [is_backward](const NodeRelation &info) { return is_backward ? info.pres : info.nexts; };
255   auto get_exclude_node_set = [is_backward](const NodeRelation &info) { return is_backward ? info.nexts : info.pres; };
256 
257   std::vector<AnfNodePtrList> group;
258   for (const auto &node : nodes) {
259     AnfNodePtrList stream;
260     AnfNodePtr n = node;
261     for (auto iter = node_rels.find(n);
262          seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
263          iter = node_rels.find(n)) {
264       if (filter_func(n)) {
265         stream.push_back(n);
266         (void)seen->insert(n);
267       }
268       if (get_contain_node_set(iter->second).size() != 1) {
269         break;
270       }
271       n = *(get_contain_node_set(iter->second).cbegin());
272     }
273     if (stream.size() > 0) {
274       group.push_back(stream);
275     }
276   }
277 
278   if (group.size() == 1) {
279     for (const auto &drop : group[0]) {
280       (void)seen->erase(drop);
281     }
282     group.clear();
283   }
284 
285   return group;
286 }
287 
SearchStreamFromMultiRelationNode(const AnfNodePtrList & multi_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)288 void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
289                                        const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
290                                        std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
291   auto get_related_nodes = [is_backward](const NodeRelation &info) { return is_backward ? info.pres : info.nexts; };
292   for (const auto &node : multi_nodes) {
293     if (auto iter = node_rels.find(node); iter != node_rels.end()) {
294       const auto &pre_nodes = get_related_nodes(iter->second);
295       AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
296       groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
297     }
298   }
299 
300   // Erase empty groups.
301   for (auto iter = groups->begin(); iter != groups->end();) {
302     if (iter->size() == 0) {
303       iter = groups->erase(iter);
304     } else {
305       ++iter;
306     }
307   }
308 }
309 
SearchStreamFromUnidirectionalNode(const AnfNodePtrList & ud_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)310 void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
311                                         const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
312                                         std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
313   groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
314 
315   // Erase empty groups.
316   for (auto iter = groups->begin(); iter != groups->end();) {
317     if (iter->size() == 0) {
318       iter = groups->erase(iter);
319     } else {
320       ++iter;
321     }
322   }
323 }
324 
DumpNode(const AnfNodePtr & node)325 std::string DumpNode(const AnfNodePtr &node) {
326   auto cnode = node->cast<CNodePtr>();
327   MS_EXCEPTION_IF_NULL(cnode);
328   std::stringstream buf;
329   buf << (common::AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
330       << cnode->ToString();
331   return buf.str();
332 }
333 
DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> & groups,const std::string & title="")334 void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
335   MS_LOG(INFO) << "[" << title << "]"
336                << "There are " << groups.size() << " parallel groups, their detail is: ";
337   int i = 0;
338   for (const auto &group : groups) {
339     std::stringstream buf;
340     buf << "[" << i << " group] " << group.size() << ":\n";
341     for (const auto &nodes : group) {
342       buf << "  " << nodes.size() << ": [<";
343       for (const auto &node : nodes) {
344         buf << "(" << DumpNode(node) << ") -> ";
345       }
346       buf << ">]\n";
347     }
348     i++;
349     MS_LOG(INFO) << buf.str();
350   }
351 }
352 
DumpParallelFusionDetail(const AnfNodePtrList & source,const AnfNodePtr & target)353 void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
354   std::stringstream buf;
355   buf << "Parallel fusion detail: ";
356   for (const auto &node : source) {
357     buf << "(" << DumpNode(node) << ") + ";
358   }
359   buf << "==>"
360       << "(" << DumpNode(target) << ")";
361   MS_LOG(INFO) << buf.str();
362 }
363 
ParameterLimit(const AnfNodePtrList & nodes)364 inline bool ParameterLimit(const AnfNodePtrList &nodes) {
365   if (nodes.empty()) {
366     MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition.";
367   }
368 
369   bool res = true;
370   auto processor_type = AnfAlgo::GetProcessor(nodes[0]);
371   if (processor_type == kernel::Processor::CUDA) {
372     // The number of inputs and outputs for a valid kernel should be less than cuda's limit.
373     size_t para_count = 0;
374     for (const auto &node : nodes) {
375       para_count += common::AnfAlgo::GetInputTensorNum(node);
376       para_count += AnfAlgo::GetOutputTensorNum(node);
377     }
378     res = para_count <= CUDA_PARA_LIMIT;
379   }
380 
381   return res;
382 }
383 
ExtraFusionCondition(const AnfNodePtrList & nodes)384 bool ExtraFusionCondition(const AnfNodePtrList &nodes) { return ParameterLimit(nodes); }
385 }  // namespace
386 
GenAnalysisGraph(const AnfNodePtrList & nodes)387 OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
388   // Based on anf node input information, build a simple graph for latter analyzation.
389   OrderedMap<AnfNodePtr, NodeRelation> node_rels;
390   auto get_info = [&node_rels](const AnfNodePtr &node) {
391     if (node_rels.count(node) == 0) {
392       (void)node_rels.emplace(node, NodeRelation());
393     }
394     return &(node_rels[node]);
395   };
396 
397   for (const auto &node : nodes) {
398     if (!node->isa<CNode>()) {
399       continue;
400     }
401 
402     auto prior_node = get_info(node);
403     for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
404       if (!input->isa<CNode>()) {
405         continue;
406       }
407       auto behind_node = get_info(input);
408       (void)prior_node->pres.insert(input);
409       (void)behind_node->nexts.insert(node);
410     }
411   }
412 
413   ProcessThroughPassCNode(&node_rels);
414   ProcessTailMakeTupleCNode(&node_rels);
415   ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
416 
417   return node_rels;
418 }
419 
SearchParallelGroups(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels)420 std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
421   const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
422   // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
423   auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
424     GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);
425 
426   // Get streams and group them
427   std::set<AnfNodePtr> seen;
428   std::vector<std::vector<AnfNodePtrList>> groups;
429 
430   SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
431   SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
432   SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
433   SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
434 
435   DumpParallelGroups(groups, "Dependency Analyze");
436   return groups;
437 }
438 
GetAvaliableNodesByOffset(int start,const std::vector<size_t> & offsets,const std::vector<bool> & used,const AnfNodePtrList & nodes,const std::set<int> & excludes) const439 std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
440   int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
441   const std::set<int> &excludes) const {
442   // Get unused nodes by offset index, the result will contain the node with start index.
443   int node_limit = static_cast<int>(nodes.size());
444   if (start >= node_limit) {
445     MS_LOG(EXCEPTION) << "Index offset should be less than the limit of given nodes " << node_limit << ", but got "
446                       << start;
447   }
448   AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
449   std::vector<int> valid_indices;
450   std::vector<size_t> unused;
451   for (size_t i = IntToSize(start); i < used.size(); ++i) {
452     if (!used[i] && excludes.count(i) == 0) {
453       unused.push_back(i);
454     }
455   }
456   size_t limit = unused.size();
457   for (auto offset : offsets) {
458     if (offset >= limit) {
459       MS_LOG(EXCEPTION) << "Index offset should be less than the limit of unused nodes " << limit << ", but got "
460                         << offset;
461     }
462     if (SizeToInt(unused[offset]) >= node_limit) {
463       MS_LOG(EXCEPTION) << "Index offset should be less than the limit of nodes " << node_limit << ", but got "
464                         << unused[offset];
465     }
466     valid_indices.push_back(unused[offset]);
467     target_nodes.push_back(nodes[unused[offset]]);
468   }
469 
470   return std::make_tuple(target_nodes, valid_indices);
471 }
472 
DoSearchInSortedCandidates(size_t origin_size,const AnfNodePtrList & candidates,std::map<AnfNodePtr,int> * origin_indices,std::map<AnfNodePtr,int> * sorted_indices)473 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
474   size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
475   std::map<AnfNodePtr, int> *sorted_indices) {
476   auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
477     MS_EXCEPTION_IF_NULL(node);
478     if (indices->find(node) == indices->end()) {
479       MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
480     }
481     return (*indices)[node];
482   };
483 
484   std::vector<ParallelInfo> parallel_infos;
485   std::vector<bool> origin_candidates_used(origin_size, false);
486   std::vector<bool> sorted_candidates_used(candidates.size(), false);
487   size_t offset;
488   for (size_t i = 0; i < candidates.size(); i += offset + 1) {
489     offset = 0;
490     if (sorted_candidates_used[i]) {
491       continue;
492     }
493 
494     int max_benefit = 0;
495     ParallelInfo best_parallel_info;
496     size_t unused_num = 0;
497     for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
498       unused_num += sorted_candidates_used[j] ? 0 : 1;
499     }
500     if (unused_num < 1) {
501       break;
502     }
503 
504     unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);
505 
506     size_t begin = 1, end = unused_num;
507     while (begin <= end) {
508       size_t mid = (begin + end) / 2;
509       std::vector<size_t> tc(mid);
510       for (size_t idx = 0; idx < mid; idx++) {
511         tc[idx] = idx + 1;
512       }
513       AnfNodePtrList other_candidates;
514       std::tie(other_candidates, std::ignore) =
515         GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
516       if (ExtraFusionCondition(other_candidates)) {
517         int benefit;
518         std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
519         if (benefit > 0) {
520           begin = mid + 1;
521           continue;
522         }
523       }
524       end = mid - 1;
525     }
526 
527     if (begin > 1) {
528       std::vector<size_t> tc(begin - 1);
529       for (size_t idx = 0; idx < begin - 1; idx++) {
530         tc[idx] = idx + 1;
531       }
532       AnfNodePtrList other_candidates;
533       std::tie(other_candidates, std::ignore) =
534         GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
535       auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
536       if (benefit <= 0) {
537         MS_LOG(EXCEPTION) << "Internal error in candidate search! benefit should be greater than 0, but got "
538                           << benefit;
539       }
540       max_benefit = benefit;
541       best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
542       offset = begin - 1;
543     }
544 
545     if (max_benefit > 0) {
546       parallel_infos.push_back(best_parallel_info);
547       for (const auto &node : best_parallel_info.nodes()) {
548         sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
549         origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
550       }
551     }
552   }
553 
554   // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
555   if (parallel_infos.size() == 0) {
556     origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
557   }
558 
559   return std::make_tuple(origin_candidates_used, parallel_infos);
560 }
561 
SearchFuseNodesInCandidates(const AnfNodePtrList & cs)562 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
563   const AnfNodePtrList &cs) {
564   std::map<AnfNodePtr, int> origin_indices;
565   std::vector<size_t> indices;
566   for (size_t i = 0; i < cs.size(); ++i) {
567     if (cs[i]) {
568       origin_indices[cs[i]] = SizeToInt(i);
569       indices.push_back(i);
570     }
571   }
572 
573   // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
574   std::map<size_t, int64_t> cal_amounts;
575   for (auto id : indices) {
576     cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
577   }
578   std::sort(indices.begin(), indices.end(),
579             [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
580 
581   AnfNodePtrList candidates;
582   for (size_t i = 0; i < indices.size(); ++i) {
583     candidates.push_back(cs[indices[i]]);
584   }
585 
586   std::map<AnfNodePtr, int> sorted_indices;
587   for (size_t i = 0; i < candidates.size(); ++i) {
588     sorted_indices[candidates[i]] = SizeToInt(i);
589   }
590 
591   return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
592 }
593 
SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> & group,std::vector<ParallelInfo> * parallel_infos)594 void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
595                                                       std::vector<ParallelInfo> *parallel_infos) {
596   std::vector<AnfNodePtrList::const_iterator> tails;
597   std::vector<AnfNodePtrList::const_iterator> ended;
598   for (const auto &node_list : group) {
599     tails.push_back(node_list.begin());
600     ended.push_back(node_list.end());
601   }
602   auto get_candidates = [&tails, &ended]() {
603     AnfNodePtrList candidates;
604     for (size_t id = 0; id < tails.size(); ++id) {
605       candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
606     }
607     return candidates;
608   };
609   auto update_tails = [&tails](const std::vector<bool> &used) {
610     if (used.size() != tails.size()) {
611       MS_LOG(EXCEPTION) << "Judged nodes size is different from left ones size: " << used.size() << " vs "
612                         << tails.size();
613     }
614     for (size_t id = 0; id < used.size(); ++id) {
615       if (used[id]) {
616         ++tails[id];
617       }
618     }
619   };
620   auto valid_candidate_num = [](const AnfNodePtrList &cs) {
621     return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
622   };
623 
624   auto candidates = get_candidates();
625   while (valid_candidate_num(candidates) > 1) {
626     auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
627     (void)std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
628                          [](const ParallelInfo &pi) { return pi; });
629     update_tails(used);
630     candidates = get_candidates();
631   }
632 }
633 
SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> & groups)634 std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
635   const std::vector<std::vector<AnfNodePtrList>> &groups) {
636   // Find core-fusable groups with cost model.
637   std::vector<ParallelInfo> parallel_infos;
638   for (const auto &group : groups) {
639     SearchFuseNodesInParallelGroup(group, &parallel_infos);
640   }
641 
642   return parallel_infos;
643 }
644 
SetFusedParallelOpAttrToReturnNode(const ParallelInfo & parallel_info)645 void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
646   AnfNodePtr attach_node;
647   // Dim info should be attach to each segment's output.
648   for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
649     const auto &fuse_nodes = parallel_info.nodes();
650     std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
651     if (!common::AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
652       attach_node = fuse_nodes[i];
653       SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
654     } else {
655       auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
656       auto out_node = node_g->output();
657       if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
658         auto inputs = out_node->cast<CNodePtr>()->inputs();
659         for (size_t j = 1; j < inputs.size(); ++j) {
660           SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
661         }
662         attach_node = inputs[1];
663       } else {
664         attach_node = out_node;
665         SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
666       }
667     }
668   }
669 
670   // Fusion info is ok to attach to one of the segments.
671   SetFusionInfoAttrToNode(attach_node, parallel_info);
672 }
673 
SetFusionInfoAttrToNode(const AnfNodePtr & node,const ParallelInfo & parallel_info)674 void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
675   auto fusion_type = parallel_info.fusion_info()->FusionType();
676   common::AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
677   if (parallel_info.fusion_info()->ExistTypeInfo()) {
678     if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
679       common::AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
680                                    MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
681     }
682   }
683 }
684 
CreateParallelOpSubGraphs(const std::vector<ParallelInfo> & parallel_infos,const std::shared_ptr<session::KernelGraph> & kernel_graph)685 bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
686                                                  const std::shared_ptr<session::KernelGraph> &kernel_graph) {
687   bool changed = false;
688 
689   for (size_t i = 0; i < parallel_infos.size(); ++i) {
690     const auto &fuse_nodes = parallel_infos[i].nodes();
691     if (fuse_nodes.size() <= 1) {
692       continue;
693     }
694     changed = true;
695     SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
696     auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
697     common::AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
698     DumpParallelFusionDetail(fuse_nodes, sg_node);
699   }
700 
701   return changed;
702 }
703 
CollectCapturedNodes(const std::vector<ParallelInfo> & infos)704 std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
705   std::set<AnfNodePtr> captured;
706   (void)std::for_each(infos.cbegin(), infos.cend(), [&captured](const ParallelInfo &info) {
707     captured.insert(info.nodes().cbegin(), info.nodes().cend());
708   });
709   return captured;
710 }
711 
GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const std::set<AnfNodePtr> & exclude)712 std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
713                                                                 const std::set<AnfNodePtr> &exclude) {
714   std::vector<std::vector<AnfNodePtrList>> groups;
715   // BFS
716   std::queue<AnfNodePtr> node_que;
717   std::unordered_map<AnfNodePtr, int> outdegrees;
718   for (const auto &[node, ref] : node_rels) {
719     outdegrees[node] = SizeToInt(ref.nexts.size());
720     if (outdegrees[node] == 0) {
721       node_que.push(node);
722     }
723   }
724 
725   int total_node_num = SizeToInt(node_rels.size());
726   while (!node_que.empty()) {
727     std::vector<AnfNodePtrList> group;
728     int node_size = SizeToInt(node_que.size());
729     while (node_size != 0) {
730       node_size--;
731       auto node = node_que.front();
732       node_que.pop();
733       if (exclude.count(node) == 0 && Parallelizable(node)) {
734         (void)group.emplace_back(AnfNodePtrList({node}));
735       }
736       --total_node_num;
737       auto iter = node_rels.find(node);
738       if (iter == node_rels.end()) {
739         MS_LOG(EXCEPTION) << "Internal error in node relationship!";
740       }
741       for (const auto &pre : iter->second.pres) {
742         if (--outdegrees[pre] == 0) {
743           node_que.push(pre);
744         }
745       }
746     }
747     if (!group.empty()) {
748       groups.push_back(group);
749     }
750   }
751 
752   if (total_node_num > 0) {
753     MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
754   }
755   DumpParallelGroups(groups, "BFS");
756   return groups;
757 }
758 
Run(const FuncGraphPtr & graph)759 bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
760   MS_EXCEPTION_IF_NULL(graph);
761   parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
762 
763   (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
764   auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
765   MS_EXCEPTION_IF_NULL(kernel_graph);
766 
767   cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
768   MS_EXCEPTION_IF_NULL(cost_model_ptr_);
769 
770   auto nodes = TopoSort(kernel_graph->get_return());
771   std::reverse(nodes.begin(), nodes.end());
772 
773   auto node_rels = GenAnalysisGraph(nodes);
774   auto groups = SearchParallelGroups(node_rels);
775   auto parallel_infos = SearchFusableParallelCNodes(groups);
776 
777   // Search in BFS for left nodes.
778   if (parallel_level_ > 0) {
779     auto exclued_nodes = CollectCapturedNodes(parallel_infos);
780     auto groups_bfs = GetParallelGroupsByBfs(node_rels, exclued_nodes);
781     auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
782 
783     (void)parallel_infos.insert(parallel_infos.cend(), bfs_parallel_infos.cbegin(), bfs_parallel_infos.cend());
784   }
785 
786   // Create core-fuse subgraph and change origin graph.
787   bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
788   (void)std::make_shared<SpreadUpdateState>()->Run(graph);
789   return changed;
790 }
791 }  // namespace mindspore::graphkernel
792