• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/optimizer/graph_kernel/parallel_fusion.h"
18 
19 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
20 #include "frontend/operator/ops.h"
21 #include "ir/func_graph_cloner.h"
22 #include "vm/segment_runner.h"
23 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
24 
25 namespace mindspore {
26 namespace opt {
27 namespace {
IsOneOf(const AnfNodePtr & node,const std::vector<PrimitivePtr> & ops_prim)28 bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
29   return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
30                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
31 }
32 
ProcessThroughPassCNode(const std::function<bool (const AnfNodePtr &)> & pass_fn,OrderedMap<AnfNodePtr,NodeRelation> * node_rels)33 void ProcessThroughPassCNode(const std::function<bool(const AnfNodePtr &)> &pass_fn,
34                              OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
35   std::set<AnfNodePtr> latter_to_be_erased;
36   for (const auto &[node, node_rel] : (*node_rels)) {
37     if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
38       continue;
39     }
40 
41     auto nexts = node_rel.nexts;
42     std::vector<AnfNodePtr> pre_nodes;
43     std::queue<AnfNodePtr> node_que;
44     node_que.push(node);
45 
46     // Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes.
47     while (!node_que.empty()) {
48       auto cur_node = node_que.front();
49       node_que.pop();
50 
51       if (!pass_fn(cur_node)) {
52         pre_nodes.push_back(cur_node);
53         continue;
54       }
55 
56       latter_to_be_erased.insert(cur_node);
57       auto predecessors = (*node_rels)[cur_node].pres;
58       if (predecessors.empty()) {
59         continue;
60       }
61 
62       for (const auto &pre_node : predecessors) {
63         (*node_rels)[cur_node].pres.erase(pre_node);
64         (*node_rels)[pre_node].nexts.erase(cur_node);
65         node_que.push(pre_node);
66       }
67     }
68 
69     // Modify the relation: delete node <-> next_node, add pre node <-> next_node.
70     for (const auto &next_node : nexts) {
71       (*node_rels)[next_node].pres.erase(node);
72       for (const auto &cur_node : pre_nodes) {
73         (*node_rels)[next_node].pres.insert(cur_node);
74         (*node_rels)[cur_node].nexts.insert(next_node);
75       }
76     }
77   }
78 
79   for (const auto &node : latter_to_be_erased) {
80     node_rels->erase(node);
81   }
82 }
83 
ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr,NodeRelation> * node_rels)84 void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
85   AnfNodePtrList latter_to_be_erased;
86   for (auto &[node, node_rel] : (*node_rels)) {
87     if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
88       continue;
89     }
90 
91     AnfNodePtrList check_next_list;
92     check_next_list.push_back(node);
93 
94     bool disinterested = false;
95     for (auto &successor : node_rel.nexts) {
96       if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
97         disinterested = true;
98         break;
99       }
100       check_next_list.push_back(successor);
101     }
102     if (disinterested) {
103       continue;
104     }
105 
106     if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
107                      [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
108       continue;
109     }
110 
111     latter_to_be_erased.push_back(node);
112   }
113 
114   // Delete Tail MakeTuple(including its getitem nodes).
115   for (const auto &node : latter_to_be_erased) {
116     for (auto &pre : (*node_rels)[node].pres) {
117       (*node_rels)[pre].nexts.erase(node);
118     }
119 
120     // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
121     for (auto &getitem : (*node_rels)[node].nexts) {
122       node_rels->erase(getitem);
123     }
124 
125     node_rels->erase(node);
126   }
127 }
128 
IsSingleInputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)129 bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
130   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
131     return true;
132   }
133   return false;
134 }
135 
IsSingleOutputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)136 bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
137   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
138     return true;
139   }
140   return false;
141 }
142 
IsMultiInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)143 bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
144   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
145     return true;
146   }
147   return false;
148 }
149 
IsMultiOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)150 bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
151   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
152     return true;
153   }
154   return false;
155 }
156 
IsNoInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)157 bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
158   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
159     return true;
160   }
161   return false;
162 }
163 
IsNoOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)164 bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
165   if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
166     return true;
167   }
168   return false;
169 }
170 
ProcessLocalStructure(OrderedMap<AnfNodePtr,NodeRelation> * node_rels,std::set<AnfNodePtr> * virtual_noout_nodes,std::set<AnfNodePtr> * ignore_noin_nodes)171 void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes,
172                            std::set<AnfNodePtr> *ignore_noin_nodes) {
173   // 1. Local relation
174   // Graph as following left part, relation D->B and D->E(D is a no input node)
175   // will make B and E to be multiply inputs node.
176   // But for parallel, this local relation can ignore for B and E, which make
177   // them be able to be paralleled.
178   //
179   // ************************************
180   // *                                  *
181   // * |                    |           *
182   // * A   D                A      D    *
183   // * |  /|                |     / \   *
184   // * | C |                |    C   F  *
185   // * |/  /                |    |   |  *
186   // * B  F      ====>      B    x   x  *
187   // * | /                  |           *
188   // * |/                   |           *
189   // * E                    E           *
190   // * |                    |           *
191   // *                                  *
192   // ************************************
193   AnfNodePtrList no_input_nodes;
194   for (const auto &node_rel : *node_rels) {
195     auto &node = node_rel.first;
196     if (IsNoInputsNode(*node_rels, node)) {
197       no_input_nodes.push_back(node);
198     }
199   }
200 
201   std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
202 
203   for (const auto &ninode : no_input_nodes) {
204     AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end());
205     for (const auto &n : cnexts) {
206       AnfNodePtr serial_tail = ninode;
207       AnfNodePtr cur_node = n;
208       while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
209         serial_tail = cur_node;
210         cur_node = *((*node_rels)[cur_node].nexts.begin());
211       }
212       latter_delete.emplace_back(serial_tail, cur_node);
213     }
214   }
215 
216   // Delete relation.
217   for (const auto &[serial_tail, cur_node] : latter_delete) {
218     virtual_noout_nodes->insert(serial_tail);
219     ignore_noin_nodes->insert(cur_node);
220     (*node_rels)[serial_tail].nexts.erase(cur_node);
221     (*node_rels)[cur_node].pres.erase(serial_tail);
222     MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
223                  << cur_node->fullname_with_scope();
224   }
225 }
226 
GetInterestNodeIds(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const std::set<AnfNodePtr> & virtual_noout_nodes,const std::set<AnfNodePtr> & ignore_noin_nodes)227 std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
228   const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
229   const std::set<AnfNodePtr> &ignore_noin_nodes) {
230   AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
231   std::list<std::function<void(const AnfNodePtr &)>> func_list = {
232     [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
233       if (IsMultiInputsNode(node_rels, node)) {
234         multi_inputs_nodes.push_back(node);
235       }
236     },
237     [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
238       if (IsMultiOutputsNode(node_rels, node)) {
239         multi_outputs_nodes.push_back(node);
240       }
241     },
242     [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
243       if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
244         no_input_nodes.push_back(node);
245       }
246     },
247     [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
248       if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
249         no_output_nodes.push_back(node);
250       }
251     }};
252 
253   for (const auto &node_rel : node_rels) {
254     for (const auto &func : func_list) {
255       func(node_rel.first);
256     }
257   }
258 
259   return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
260 }
261 
WhiteOpsFilter(const AnfNodePtr & node)262 bool WhiteOpsFilter(const AnfNodePtr &node) {
263   std::vector<PrimitivePtr> whiteable_ops = {};  // Not special for now.
264   return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
265 }
266 
SearchFromNodes(const AnfNodePtrList & nodes,const std::function<bool (const AnfNodePtr &)> & filter_func,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::set<AnfNodePtr> * seen)267 std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
268                                             const std::function<bool(const AnfNodePtr &)> &filter_func,
269                                             const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
270                                             std::set<AnfNodePtr> *seen) {
271   // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
272   // For backward search, the other multi-inputs node can be contained in.
273   // For forward search, the other multi-outputs node can be contained in.
274   auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; }
275                                           : [](const NodeRelation &info) { return info.nexts; };
276   auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; }
277                                           : [](const NodeRelation &info) { return info.pres; };
278   std::vector<AnfNodePtrList> group;
279   for (const auto &node : nodes) {
280     AnfNodePtrList stream;
281     AnfNodePtr n = node;
282     for (auto iter = node_rels.find(n);
283          seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
284          iter = node_rels.find(n)) {
285       if (filter_func(n)) {
286         stream.push_back(n);
287         seen->insert(n);
288       }
289       if (get_contain_node_set(iter->second).size() != 1) {
290         break;
291       }
292       n = *(get_contain_node_set(iter->second).begin());
293     }
294     if (stream.size() > 0) {
295       group.push_back(stream);
296     }
297   }
298 
299   if (group.size() == 1) {
300     for (const auto &drop : group[0]) {
301       seen->erase(drop);
302     }
303     group.clear();
304   }
305 
306   return group;
307 }
308 
SearchStreamFromMultiRelationNode(const AnfNodePtrList & multi_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)309 void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
310                                        const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
311                                        std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
312   auto get_related_nodes = is_backward ? [](const NodeRelation &info) { return info.pres; }
313                                        : [](const NodeRelation &info) { return info.nexts; };
314   for (const auto &node : multi_nodes) {
315     if (auto iter = node_rels.find(node); iter != node_rels.end()) {
316       const auto &pre_nodes = get_related_nodes(iter->second);
317       AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
318       groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
319     }
320   }
321 
322   // Erase empty groups.
323   for (auto iter = groups->begin(); iter != groups->end();) {
324     if (iter->size() == 0) {
325       iter = groups->erase(iter);
326     } else {
327       ++iter;
328     }
329   }
330 }
331 
SearchStreamFromUnidirectionalNode(const AnfNodePtrList & ud_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)332 void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
333                                         const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
334                                         std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
335   groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
336 
337   // Erase empty groups.
338   for (auto iter = groups->begin(); iter != groups->end();) {
339     if (iter->size() == 0) {
340       iter = groups->erase(iter);
341     } else {
342       ++iter;
343     }
344   }
345 }
346 
DumpNode(const AnfNodePtr & node)347 std::string DumpNode(const AnfNodePtr &node) {
348   auto cnode = node->cast<CNodePtr>();
349   MS_EXCEPTION_IF_NULL(cnode);
350   std::stringstream buf;
351   buf << (AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
352       << cnode->ToString();
353   return buf.str();
354 }
355 
DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> & groups)356 void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups) {
357   MS_LOG(INFO) << "There are " << groups.size() << " parallel groups, their detail is: ";
358   int i = 0;
359   for (const auto group : groups) {
360     std::stringstream buf;
361     buf << "[" << i << " group] " << group.size() << ":\n";
362     for (const auto nodes : group) {
363       buf << "  " << nodes.size() << ": [<";
364       for (const auto node : nodes) {
365         buf << "(" << DumpNode(node) << ") -> ";
366       }
367       buf << ">]\n";
368     }
369     i++;
370     MS_LOG(INFO) << buf.str();
371   }
372 }
373 
DumpParallelFusionDetail(const AnfNodePtrList & source,const AnfNodePtr & target)374 void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
375   std::stringstream buf;
376   buf << "Parallel fusion detail: ";
377   for (const auto &node : source) {
378     buf << "(" << DumpNode(node) << ") + ";
379   }
380   buf << "==>"
381       << "(" << DumpNode(target) << ")";
382   MS_LOG(INFO) << buf.str();
383 }
384 }  // namespace
385 
GenAnalysisGraph(const AnfNodePtrList & nodes)386 OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
387   // Based on anf node input information, build a simple graph for latter analyzation.
388   OrderedMap<AnfNodePtr, NodeRelation> node_rels;
389   auto get_info = [&node_rels](const AnfNodePtr &node) {
390     if (node_rels.count(node) == 0) {
391       (void)node_rels.emplace(node, NodeRelation());
392     }
393     return &(node_rels[node]);
394   };
395 
396   for (const auto &node : nodes) {
397     if (!node->isa<CNode>()) {
398       continue;
399     }
400 
401     auto prior_node = get_info(node);
402     for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
403       if (!input->isa<CNode>()) {
404         continue;
405       }
406       auto behind_node = get_info(input);
407       prior_node->pres.insert(input);
408       behind_node->nexts.insert(node);
409     }
410   }
411 
412   ProcessThroughPassCNode(
413     [](const AnfNodePtr &node) {
414       return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
415     },
416     &node_rels);
417   ProcessTailMakeTupleCNode(&node_rels);
418   ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
419 
420   return node_rels;
421 }
422 
SearchParallelGroups(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels)423 std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
424   const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
425   // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
426   auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
427     GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);
428 
429   // Get streams and group them
430   std::set<AnfNodePtr> seen;
431   std::vector<std::vector<AnfNodePtrList>> groups;
432 
433   SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
434   SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
435   SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
436   SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
437 
438   DumpParallelGroups(groups);
439   return groups;
440 }
441 
GetAvaliableNodesByOffset(int start,const std::vector<size_t> & offsets,const std::vector<bool> & used,const AnfNodePtrList & nodes,const std::set<int> & excludes)442 std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
443   int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
444   const std::set<int> &excludes) {
445   // Get unused nodes by offset index, the result will contain the node with start index.
446   int node_limit = static_cast<int>(nodes.size());
447   if (start >= node_limit) {
448     MS_LOG(EXCEPTION) << "Index offset is exceed the limit of given nodes.";
449   }
450   AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
451   std::vector<int> valid_indices;
452   std::vector<size_t> unused;
453   for (size_t i = IntToSize(start); i < used.size(); ++i) {
454     if (!used[i] && excludes.count(i) == 0) {
455       unused.push_back(i);
456     }
457   }
458   size_t limit = unused.size();
459   for (auto offset : offsets) {
460     if (offset >= limit) {
461       MS_LOG(EXCEPTION) << "Index offset is exceed the limit of unused nodes.";
462     }
463     if (SizeToInt(unused[offset]) >= node_limit) {
464       MS_LOG(EXCEPTION) << "Index offset is exceed the limit of nodes.";
465     }
466     valid_indices.push_back(unused[offset]);
467     target_nodes.push_back(nodes[unused[offset]]);
468   }
469 
470   return std::make_tuple(target_nodes, valid_indices);
471 }
472 
DoSearchInSortedCandidates(size_t origin_size,const AnfNodePtrList & candidates,std::map<AnfNodePtr,int> * origin_indices,std::map<AnfNodePtr,int> * sorted_indices)473 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
474   size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
475   std::map<AnfNodePtr, int> *sorted_indices) {
476   auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
477     MS_EXCEPTION_IF_NULL(node);
478     if (indices->find(node) == indices->end()) {
479       MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
480     }
481     return (*indices)[node];
482   };
483 
484   std::vector<ParallelInfo> parallel_infos;
485   std::vector<bool> origin_candidates_used(origin_size, false);
486   std::vector<bool> sorted_candidates_used(candidates.size(), false);
487 
488   for (size_t i = 0; i < candidates.size(); ++i) {
489     if (sorted_candidates_used[i]) {
490       continue;
491     }
492 
493     int max_benefit = 0;
494     ParallelInfo best_parallel_info;
495     size_t unused_num = 0;
496     for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
497       unused_num += sorted_candidates_used[j] ? 0 : 1;
498     }
499     if (unused_num < 1) {
500       break;
501     }
502 
503     unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);
504 
505     size_t begin = 1, end = unused_num;
506     while (begin <= end) {
507       size_t mid = (begin + end) / 2;
508       std::vector<size_t> tc(mid);
509       std::iota(tc.begin(), tc.end(), 1);
510       AnfNodePtrList other_candidates;
511       std::tie(other_candidates, std::ignore) =
512         GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
513       int benefit;
514       std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
515       if (benefit > 0) {
516         begin = mid + 1;
517       } else {
518         end = mid - 1;
519       }
520     }
521 
522     if (begin > 1) {
523       std::vector<size_t> tc(begin - 1);
524       std::iota(tc.begin(), tc.end(), 1);
525       AnfNodePtrList other_candidates;
526       std::tie(other_candidates, std::ignore) =
527         GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
528       auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
529       if (benefit <= 0) {
530         MS_LOG(EXCEPTION) << "Internal error in candidate search!";
531       }
532       max_benefit = benefit;
533       best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
534       i += begin - 1;
535     }
536 
537     if (max_benefit > 0) {
538       parallel_infos.push_back(best_parallel_info);
539       for (const auto &node : best_parallel_info.nodes()) {
540         sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
541         origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
542       }
543     }
544   }
545 
546   // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
547   if (parallel_infos.size() == 0) {
548     origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
549   }
550 
551   return std::make_tuple(origin_candidates_used, parallel_infos);
552 }
553 
SearchFuseNodesInCandidates(const AnfNodePtrList & cs)554 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
555   const AnfNodePtrList &cs) {
556   std::map<AnfNodePtr, int> origin_indices;
557   std::vector<size_t> indices;
558   for (size_t i = 0; i < cs.size(); ++i) {
559     if (cs[i]) {
560       (void)origin_indices.emplace(cs[i], i);
561       indices.push_back(i);
562     }
563   }
564 
565   // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
566   std::map<size_t, int> cal_amounts;
567   for (auto id : indices) {
568     cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
569   }
570   std::sort(indices.begin(), indices.end(),
571             [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
572 
573   AnfNodePtrList candidates;
574   for (size_t i = 0; i < indices.size(); ++i) {
575     candidates.push_back(cs[indices[i]]);
576   }
577 
578   std::map<AnfNodePtr, int> sorted_indices;
579   for (size_t i = 0; i < candidates.size(); ++i) {
580     (void)sorted_indices.emplace(candidates[i], i);
581   }
582 
583   return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
584 }
585 
SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> & group,std::vector<ParallelInfo> * parallel_infos)586 void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
587                                                       std::vector<ParallelInfo> *parallel_infos) {
588   std::vector<AnfNodePtrList::const_iterator> tails;
589   std::vector<AnfNodePtrList::const_iterator> ended;
590   for (const auto &node_list : group) {
591     tails.push_back(node_list.begin());
592     ended.push_back(node_list.end());
593   }
594   auto get_candidates = [&tails, &ended]() {
595     AnfNodePtrList candidates;
596     for (size_t id = 0; id < tails.size(); ++id) {
597       candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
598     }
599     return candidates;
600   };
601   auto update_tails = [&tails](const std::vector<bool> &used) {
602     if (used.size() != tails.size()) {
603       MS_LOG(EXCEPTION) << "Judged nodes size is not equal to left ones!";
604     }
605     for (size_t id = 0; id < used.size(); ++id) {
606       if (used[id]) {
607         ++tails[id];
608       }
609     }
610   };
611   auto valid_candidate_num = [](const AnfNodePtrList &cs) {
612     return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
613   };
614 
615   auto candidates = get_candidates();
616   while (valid_candidate_num(candidates) > 1) {
617     auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
618     std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
619                    [](const ParallelInfo &pi) { return pi; });
620     update_tails(used);
621     candidates = get_candidates();
622   }
623 }
624 
SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> & groups)625 std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
626   const std::vector<std::vector<AnfNodePtrList>> &groups) {
627   // Find core-fusable groups with cost model.
628   std::vector<ParallelInfo> parallel_infos;
629   for (const auto &group : groups) {
630     SearchFuseNodesInParallelGroup(group, &parallel_infos);
631   }
632 
633   return parallel_infos;
634 }
635 
SetFusedParallelOpAttrToReturnNode(const ParallelInfo & parallel_info)636 void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
637   AnfNodePtr attach_node;
638   // Dim info should be attach to each segment's output.
639   for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
640     const auto &fuse_nodes = parallel_info.nodes();
641     std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
642     if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
643       attach_node = fuse_nodes[i];
644       SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
645     } else {
646       auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
647       auto out_node = node_g->output();
648       if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
649         auto inputs = out_node->cast<CNodePtr>()->inputs();
650         for (size_t j = 1; j < inputs.size(); ++j) {
651           SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
652         }
653         attach_node = inputs[1];
654       } else {
655         attach_node = out_node;
656         SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
657       }
658     }
659   }
660 
661   // Fusion info is ok to attach to one of the segments.
662   SetFusionInfoAttrToNode(attach_node, parallel_info);
663 }
664 
SetFusionInfoAttrToNode(const AnfNodePtr & node,const ParallelInfo & parallel_info)665 void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
666   auto fusion_type = parallel_info.fusion_info()->FusionType();
667   AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
668   if (parallel_info.fusion_info()->ExistTypeInfo()) {
669     if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
670       AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
671                            MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
672     }
673   }
674 }
675 
CreateParallelOpSubGraphs(const std::vector<ParallelInfo> & parallel_infos,const std::shared_ptr<session::KernelGraph> & kernel_graph)676 bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> &parallel_infos,
677                                                  const std::shared_ptr<session::KernelGraph> &kernel_graph) {
678   bool changed = false;
679 
680   for (size_t i = 0; i < parallel_infos.size(); ++i) {
681     const auto &fuse_nodes = parallel_infos[i].nodes();
682     if (fuse_nodes.size() <= 1) {
683       continue;
684     }
685     changed = true;
686     SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
687     AnfNodePtr sg_node;
688     std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
689     AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
690     DumpParallelFusionDetail(fuse_nodes, sg_node);
691   }
692 
693   return changed;
694 }
695 
Run(const FuncGraphPtr & graph)696 bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
697   MS_EXCEPTION_IF_NULL(graph);
698   (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
699   auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
700   MS_EXCEPTION_IF_NULL(kernel_graph);
701 
702   cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
703   MS_EXCEPTION_IF_NULL(cost_model_ptr_);
704 
705   auto nodes = TopoSort(kernel_graph->get_return());
706   std::reverse(nodes.begin(), nodes.end());
707 
708   auto node_rels = GenAnalysisGraph(nodes);
709   auto groups = SearchParallelGroups(node_rels);
710   auto parallel_infos = SearchFusableParallelCNodes(groups);
711 
712   // Create core-fuse subgraph and change origin graph.
713   bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
714   (void)std::make_shared<SpreadUpdateState>()->Run(graph);
715   return changed;
716 }
717 }  // namespace opt
718 }  // namespace mindspore
719