1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/common/graph_kernel/parallel_fusion.h"
18
19 #include <algorithm>
20 #include <list>
21 #include <queue>
22 #include <unordered_map>
23 #include <utility>
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "backend/common/graph_kernel/graph_kernel_flags.h"
27 #include "kernel/kernel.h"
28 #include "backend/common/graph_kernel/graph_kernel_helper.h"
29 #include "kernel/common_utils.h"
30 #include "frontend/operator/ops.h"
31 #include "ir/func_graph_cloner.h"
32 #include "backend/common/graph_kernel/core/update_state_formatter.h"
33 #include "backend/common/graph_kernel/core/graph_builder.h"
34
35 namespace mindspore::graphkernel {
36 namespace {
37 // Cuda's parameter table can accept maximum 4KB, so the number of parameters should be less than 512.
38 constexpr size_t CUDA_PARA_LIMIT = 512;
39
// Remove "pass-through" nodes (Reshape/ExpandDims/Squeeze/TupleGetItem) from the
// relation graph, reconnecting each predecessor directly to every successor so
// that the parallel analysis sees through shape-only operations.
void ProcessThroughPassCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
  std::set<AnfNodePtr> to_be_erased;
  for (const auto &[node, node_rel] : (*node_rels)) {
    if (!IsOneOfPrimitiveCNode(
          node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem})) {
      continue;
    }
    (void)to_be_erased.insert(node);
    // Bridge every (pre -> node -> next) chain into a direct (pre -> next) edge.
    for (const auto &pre : node_rel.pres) {
      auto &pre_nexts = (*node_rels)[pre].nexts;
      (void)pre_nexts.erase(node);
      for (const auto &next : node_rel.nexts) {
        (void)pre_nexts.insert(next);
        auto &next_pres = (*node_rels)[next].pres;
        (void)next_pres.erase(node);
        (void)next_pres.insert(pre);
      }
    }
  }
  // Drop the pass-through nodes themselves once all relinking is done.
  // NOTE(review): a pass-through node with successors but no predecessor would
  // leave stale `pres` entries in its successors — presumably these ops always
  // have a CNode input here; confirm against GenAnalysisGraph's construction.
  for (const auto &node : to_be_erased) {
    (void)node_rels->erase(node);
  }
}
63
// Remove "tail" MakeTuple nodes (and their TupleGetItem consumers, if any) from
// the relation graph. A MakeTuple qualifies when all of its successors are
// TupleGetItem nodes and neither it nor any of those getitems has a further
// consumer recorded in node_rels — i.e. the tuple only packs graph outputs.
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
  AnfNodePtrList latter_to_be_erased;
  for (auto &[node, node_rel] : (*node_rels)) {
    if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
      continue;
    }

    // Candidate set: the MakeTuple itself plus all of its successors.
    AnfNodePtrList check_next_list;
    check_next_list.push_back(node);

    // Any non-getitem successor disqualifies this MakeTuple.
    bool disinterested = false;
    for (auto &successor : node_rel.nexts) {
      if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
        disinterested = true;
        break;
      }
      check_next_list.push_back(successor);
    }
    if (disinterested) {
      continue;
    }

    // Erasable only when nothing in the candidate set has a consumer.
    // NOTE(review): operator[] here default-inserts an empty relation for
    // getitems missing from node_rels (e.g. ones already processed by
    // ProcessThroughPassCNode) — that empty relation then passes the check.
    if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
                     [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
      continue;
    }

    latter_to_be_erased.push_back(node);
  }

  // Delete Tail MakeTuple(including its getitem nodes).
  for (const auto &node : latter_to_be_erased) {
    // Detach the MakeTuple from all of its producers first.
    for (auto &pre : (*node_rels)[node].pres) {
      (void)(*node_rels)[pre].nexts.erase(node);
    }

    // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
    for (auto &getitem : (*node_rels)[node].nexts) {
      (void)node_rels->erase(getitem);
    }

    (void)node_rels->erase(node);
  }
}
108
IsSingleInputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)109 bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
110 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
111 return true;
112 }
113 return false;
114 }
115
IsSingleOutputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)116 bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
117 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
118 return true;
119 }
120 return false;
121 }
122
IsMultiInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)123 bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
124 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
125 return true;
126 }
127 return false;
128 }
129
IsMultiOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)130 bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
131 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
132 return true;
133 }
134 return false;
135 }
136
IsNoInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)137 bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
138 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
139 return true;
140 }
141 return false;
142 }
143
IsNoOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)144 bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
145 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
146 return true;
147 }
148 return false;
149 }
150
// Cut ignorable "local" relations introduced by no-input nodes so more streams
// become parallelizable (see diagram). For each no-input node, every
// single-in/single-out chain leaving it is followed; the edge at the chain's
// end is deleted. The chain tail goes into `virtual_noout_nodes` and the node
// after it into `ignore_noin_nodes`, so GetInterestNodeIds will not later
// mistake them for genuine no-output / no-input nodes.
void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes,
                           std::set<AnfNodePtr> *ignore_noin_nodes) {
  // 1. Local relation
  // Graph as following left part, relation D->B and D->E(D is a no input node)
  // will make B and E to be multiply inputs node.
  // But for parallel, this local relation can ignore for B and E, which make
  // them be able to be paralleled.
  //
  // ************************************
  // *                                  *
  // *   |   |            |   |         *
  // *   A   D            A   D         *
  // *   |  /|            |  / \        *
  // *   | C |            | C   F       *
  // *   |/ /             | |   |       *
  // *   B F      ====>   B x   x       *
  // *   | /              |             *
  // *   |/               |             *
  // *   E                E             *
  // *   |                |             *
  // *                                  *
  // ************************************
  AnfNodePtrList no_input_nodes;
  for (const auto &node_rel : *node_rels) {
    auto &node = node_rel.first;
    if (IsNoInputsNode(*node_rels, node)) {
      no_input_nodes.push_back(node);
    }
  }

  // Pairs of (chain tail, node after the tail) whose edge will be removed.
  std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;

  for (const auto &ninode : no_input_nodes) {
    // Copy the successor set before walking; operator[] below may touch the map.
    AnfNodePtrList cnexts((*node_rels)[ninode].nexts.cbegin(), (*node_rels)[ninode].nexts.cend());
    for (const auto &n : cnexts) {
      // Follow the single-in/single-out chain starting at this successor.
      AnfNodePtr serial_tail = ninode;
      AnfNodePtr cur_node = n;
      while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
        serial_tail = cur_node;
        cur_node = *((*node_rels)[cur_node].nexts.begin());
      }
      (void)latter_delete.emplace_back(serial_tail, cur_node);
    }
  }

  // Delete relation.
  for (const auto &[serial_tail, cur_node] : latter_delete) {
    (void)virtual_noout_nodes->insert(serial_tail);
    (void)ignore_noin_nodes->insert(cur_node);
    (void)(*node_rels)[serial_tail].nexts.erase(cur_node);
    (void)(*node_rels)[cur_node].pres.erase(serial_tail);
    MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
                 << cur_node->fullname_with_scope();
  }
}
206
GetInterestNodeIds(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const std::set<AnfNodePtr> & virtual_noout_nodes,const std::set<AnfNodePtr> & ignore_noin_nodes)207 std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
208 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
209 const std::set<AnfNodePtr> &ignore_noin_nodes) {
210 AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
211 std::list<std::function<void(const AnfNodePtr &)>> func_list = {
212 [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
213 if (IsMultiInputsNode(node_rels, node)) {
214 multi_inputs_nodes.push_back(node);
215 }
216 },
217 [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
218 if (IsMultiOutputsNode(node_rels, node)) {
219 multi_outputs_nodes.push_back(node);
220 }
221 },
222 [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
223 if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
224 no_input_nodes.push_back(node);
225 }
226 },
227 [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
228 if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
229 no_output_nodes.push_back(node);
230 }
231 }};
232
233 for (const auto &node_rel : node_rels) {
234 for (const auto &func : func_list) {
235 func(node_rel.first);
236 }
237 }
238
239 return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
240 }
241
WhiteOpsFilter(const AnfNodePtr & node)242 bool WhiteOpsFilter(const AnfNodePtr &node) { return common::AnfAlgo::IsGraphKernel(node); }
243
244 // Parallel cannot work with stitching for now.
Parallelizable(const AnfNodePtr & node)245 bool Parallelizable(const AnfNodePtr &node) { return WhiteOpsFilter(node) && !IsBufferStitchNode(node); }
246
// Collect serial streams of parallelizable nodes starting from `nodes`.
// For every start node, walk along its chain (direction chosen by
// `is_backward`) gathering nodes accepted by `filter_func`, marking them in
// `seen`. A result with only one stream has no partner to run in parallel
// with, so it is dropped and its nodes released back.
std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
                                            const std::function<bool(const AnfNodePtr &)> &filter_func,
                                            const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
                                            std::set<AnfNodePtr> *seen) {
  // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
  // For backward search, the other multi-inputs node can be contained in.
  // For forward search, the other multi-outputs node can be contained in.
  auto get_contain_node_set = [is_backward](const NodeRelation &info) { return is_backward ? info.pres : info.nexts; };
  auto get_exclude_node_set = [is_backward](const NodeRelation &info) { return is_backward ? info.nexts : info.pres; };

  std::vector<AnfNodePtrList> group;
  for (const auto &node : nodes) {
    AnfNodePtrList stream;
    AnfNodePtr n = node;
    // Walk while the node is unvisited, known, and has at most one relation in
    // the excluded direction; stop when the walk direction branches (!= 1).
    for (auto iter = node_rels.find(n);
         seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
         iter = node_rels.find(n)) {
      if (filter_func(n)) {
        stream.push_back(n);
        (void)seen->insert(n);
      }
      if (get_contain_node_set(iter->second).size() != 1) {
        break;
      }
      n = *(get_contain_node_set(iter->second).cbegin());
    }
    if (stream.size() > 0) {
      group.push_back(stream);
    }
  }

  // A lone stream cannot be parallelized: undo its `seen` marks so the nodes
  // stay available for later searches, and return an empty result.
  if (group.size() == 1) {
    for (const auto &drop : group[0]) {
      (void)seen->erase(drop);
    }
    group.clear();
  }

  return group;
}
287
SearchStreamFromMultiRelationNode(const AnfNodePtrList & multi_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)288 void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
289 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
290 std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
291 auto get_related_nodes = [is_backward](const NodeRelation &info) { return is_backward ? info.pres : info.nexts; };
292 for (const auto &node : multi_nodes) {
293 if (auto iter = node_rels.find(node); iter != node_rels.end()) {
294 const auto &pre_nodes = get_related_nodes(iter->second);
295 AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
296 groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
297 }
298 }
299
300 // Erase empty groups.
301 for (auto iter = groups->begin(); iter != groups->end();) {
302 if (iter->size() == 0) {
303 iter = groups->erase(iter);
304 } else {
305 ++iter;
306 }
307 }
308 }
309
SearchStreamFromUnidirectionalNode(const AnfNodePtrList & ud_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)310 void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
311 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
312 std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
313 groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
314
315 // Erase empty groups.
316 for (auto iter = groups->begin(); iter != groups->end();) {
317 if (iter->size() == 0) {
318 iter = groups->erase(iter);
319 } else {
320 ++iter;
321 }
322 }
323 }
324
DumpNode(const AnfNodePtr & node)325 std::string DumpNode(const AnfNodePtr &node) {
326 auto cnode = node->cast<CNodePtr>();
327 MS_EXCEPTION_IF_NULL(cnode);
328 std::stringstream buf;
329 buf << (common::AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
330 << cnode->ToString();
331 return buf.str();
332 }
333
DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> & groups,const std::string & title="")334 void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
335 MS_LOG(INFO) << "[" << title << "]"
336 << "There are " << groups.size() << " parallel groups, their detail is: ";
337 int i = 0;
338 for (const auto &group : groups) {
339 std::stringstream buf;
340 buf << "[" << i << " group] " << group.size() << ":\n";
341 for (const auto &nodes : group) {
342 buf << " " << nodes.size() << ": [<";
343 for (const auto &node : nodes) {
344 buf << "(" << DumpNode(node) << ") -> ";
345 }
346 buf << ">]\n";
347 }
348 i++;
349 MS_LOG(INFO) << buf.str();
350 }
351 }
352
DumpParallelFusionDetail(const AnfNodePtrList & source,const AnfNodePtr & target)353 void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
354 std::stringstream buf;
355 buf << "Parallel fusion detail: ";
356 for (const auto &node : source) {
357 buf << "(" << DumpNode(node) << ") + ";
358 }
359 buf << "==>"
360 << "(" << DumpNode(target) << ")";
361 MS_LOG(INFO) << buf.str();
362 }
363
ParameterLimit(const AnfNodePtrList & nodes)364 inline bool ParameterLimit(const AnfNodePtrList &nodes) {
365 if (nodes.empty()) {
366 MS_LOG(EXCEPTION) << "Nodes is empty, can not check condition.";
367 }
368
369 bool res = true;
370 auto processor_type = AnfAlgo::GetProcessor(nodes[0]);
371 if (processor_type == kernel::Processor::CUDA) {
372 // The number of inputs and outputs for a valid kernel should be less than cuda's limit.
373 size_t para_count = 0;
374 for (const auto &node : nodes) {
375 para_count += common::AnfAlgo::GetInputTensorNum(node);
376 para_count += AnfAlgo::GetOutputTensorNum(node);
377 }
378 res = para_count <= CUDA_PARA_LIMIT;
379 }
380
381 return res;
382 }
383
ExtraFusionCondition(const AnfNodePtrList & nodes)384 bool ExtraFusionCondition(const AnfNodePtrList &nodes) { return ParameterLimit(nodes); }
385 } // namespace
386
GenAnalysisGraph(const AnfNodePtrList & nodes)387 OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
388 // Based on anf node input information, build a simple graph for latter analyzation.
389 OrderedMap<AnfNodePtr, NodeRelation> node_rels;
390 auto get_info = [&node_rels](const AnfNodePtr &node) {
391 if (node_rels.count(node) == 0) {
392 (void)node_rels.emplace(node, NodeRelation());
393 }
394 return &(node_rels[node]);
395 };
396
397 for (const auto &node : nodes) {
398 if (!node->isa<CNode>()) {
399 continue;
400 }
401
402 auto prior_node = get_info(node);
403 for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
404 if (!input->isa<CNode>()) {
405 continue;
406 }
407 auto behind_node = get_info(input);
408 (void)prior_node->pres.insert(input);
409 (void)behind_node->nexts.insert(node);
410 }
411 }
412
413 ProcessThroughPassCNode(&node_rels);
414 ProcessTailMakeTupleCNode(&node_rels);
415 ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
416
417 return node_rels;
418 }
419
// Find groups of parallelizable streams via dependency analysis on the
// relation graph.
std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
  const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
  // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
  auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
    GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);

  // Get streams and group them
  std::set<AnfNodePtr> seen;
  std::vector<std::vector<AnfNodePtrList>> groups;

  // Backward searches start at convergence/sink points, forward searches at
  // divergence/source points. `seen` is shared so each node joins at most one
  // group; the four calls are ordered by search priority.
  SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
  SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
  SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);

  DumpParallelGroups(groups, "Dependency Analyze");
  return groups;
}
438
GetAvaliableNodesByOffset(int start,const std::vector<size_t> & offsets,const std::vector<bool> & used,const AnfNodePtrList & nodes,const std::set<int> & excludes) const439 std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
440 int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
441 const std::set<int> &excludes) const {
442 // Get unused nodes by offset index, the result will contain the node with start index.
443 int node_limit = static_cast<int>(nodes.size());
444 if (start >= node_limit) {
445 MS_LOG(EXCEPTION) << "Index offset should be less than the limit of given nodes " << node_limit << ", but got "
446 << start;
447 }
448 AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
449 std::vector<int> valid_indices;
450 std::vector<size_t> unused;
451 for (size_t i = IntToSize(start); i < used.size(); ++i) {
452 if (!used[i] && excludes.count(i) == 0) {
453 unused.push_back(i);
454 }
455 }
456 size_t limit = unused.size();
457 for (auto offset : offsets) {
458 if (offset >= limit) {
459 MS_LOG(EXCEPTION) << "Index offset should be less than the limit of unused nodes " << limit << ", but got "
460 << offset;
461 }
462 if (SizeToInt(unused[offset]) >= node_limit) {
463 MS_LOG(EXCEPTION) << "Index offset should be less than the limit of nodes " << node_limit << ", but got "
464 << unused[offset];
465 }
466 valid_indices.push_back(unused[offset]);
467 target_nodes.push_back(nodes[unused[offset]]);
468 }
469
470 return std::make_tuple(target_nodes, valid_indices);
471 }
472
// For each unused head candidate, binary-search the largest count of remaining
// (sorted) partner candidates that still fuses with positive benefit and passes
// ExtraFusionCondition, then record that set as one ParallelInfo.
// Returns (which ORIGINAL candidates were consumed, collected fusion plans).
std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
  size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
  std::map<AnfNodePtr, int> *sorted_indices) {
  // Look up a node's recorded position; a missing record is an internal error.
  auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
    MS_EXCEPTION_IF_NULL(node);
    if (indices->find(node) == indices->end()) {
      MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
    }
    return (*indices)[node];
  };

  std::vector<ParallelInfo> parallel_infos;
  std::vector<bool> origin_candidates_used(origin_size, false);
  std::vector<bool> sorted_candidates_used(candidates.size(), false);
  // `offset` = number of extra candidates the accepted plan consumed this round;
  // it is reset at the top of every iteration before being used in the step.
  size_t offset;
  for (size_t i = 0; i < candidates.size(); i += offset + 1) {
    offset = 0;
    if (sorted_candidates_used[i]) {
      continue;
    }

    int max_benefit = 0;
    ParallelInfo best_parallel_info;
    // Count the partners still available after position i.
    size_t unused_num = 0;
    for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
      unused_num += sorted_candidates_used[j] ? 0 : 1;
    }
    if (unused_num < 1) {
      break;
    }

    // Cap partner count by the configured maximum fusion size (minus the head).
    unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);

    // Binary search the largest partner count (offsets 1..mid over unused
    // candidates) that is both allowed and profitable.
    size_t begin = 1, end = unused_num;
    while (begin <= end) {
      size_t mid = (begin + end) / 2;
      std::vector<size_t> tc(mid);
      for (size_t idx = 0; idx < mid; idx++) {
        tc[idx] = idx + 1;
      }
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
      if (ExtraFusionCondition(other_candidates)) {
        int benefit;
        std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
        if (benefit > 0) {
          begin = mid + 1;
          continue;
        }
      }
      end = mid - 1;
    }

    // begin - 1 is the largest feasible partner count found by the search.
    if (begin > 1) {
      std::vector<size_t> tc(begin - 1);
      for (size_t idx = 0; idx < begin - 1; idx++) {
        tc[idx] = idx + 1;
      }
      AnfNodePtrList other_candidates;
      std::tie(other_candidates, std::ignore) =
        GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
      auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
      if (benefit <= 0) {
        MS_LOG(EXCEPTION) << "Internal error in candidate search! benefit should be greater than 0, but got "
                          << benefit;
      }
      max_benefit = benefit;
      best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
      offset = begin - 1;
    }

    if (max_benefit > 0) {
      parallel_infos.push_back(best_parallel_info);
      // Mark the fused nodes as consumed in both index spaces.
      for (const auto &node : best_parallel_info.nodes()) {
        sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
        origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
      }
    }
  }

  // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
  // (parallel_infos.size() is 0 in this branch, so this marks candidates[0].)
  if (parallel_infos.size() == 0) {
    origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
  }

  return std::make_tuple(origin_candidates_used, parallel_infos);
}
561
SearchFuseNodesInCandidates(const AnfNodePtrList & cs)562 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
563 const AnfNodePtrList &cs) {
564 std::map<AnfNodePtr, int> origin_indices;
565 std::vector<size_t> indices;
566 for (size_t i = 0; i < cs.size(); ++i) {
567 if (cs[i]) {
568 origin_indices[cs[i]] = SizeToInt(i);
569 indices.push_back(i);
570 }
571 }
572
573 // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
574 std::map<size_t, int64_t> cal_amounts;
575 for (auto id : indices) {
576 cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
577 }
578 std::sort(indices.begin(), indices.end(),
579 [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
580
581 AnfNodePtrList candidates;
582 for (size_t i = 0; i < indices.size(); ++i) {
583 candidates.push_back(cs[indices[i]]);
584 }
585
586 std::map<AnfNodePtr, int> sorted_indices;
587 for (size_t i = 0; i < candidates.size(); ++i) {
588 sorted_indices[candidates[i]] = SizeToInt(i);
589 }
590
591 return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
592 }
593
// Repeatedly take the current head node of every stream in the group as the
// candidate set, try to fuse subsets of them, and advance the consumed streams
// — until fewer than two live candidates remain.
void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
                                                      std::vector<ParallelInfo> *parallel_infos) {
  // One cursor (and end sentinel) per stream.
  std::vector<AnfNodePtrList::const_iterator> tails;
  std::vector<AnfNodePtrList::const_iterator> ended;
  for (const auto &node_list : group) {
    tails.push_back(node_list.begin());
    ended.push_back(node_list.end());
  }
  // Current head of each stream; a null AnfNodePtr marks an exhausted stream.
  auto get_candidates = [&tails, &ended]() {
    AnfNodePtrList candidates;
    for (size_t id = 0; id < tails.size(); ++id) {
      candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
    }
    return candidates;
  };
  // Advance the cursor of every stream whose head was consumed this round.
  auto update_tails = [&tails](const std::vector<bool> &used) {
    if (used.size() != tails.size()) {
      MS_LOG(EXCEPTION) << "Judged nodes size is different from left ones size: " << used.size() << " vs "
                        << tails.size();
    }
    for (size_t id = 0; id < used.size(); ++id) {
      if (used[id]) {
        ++tails[id];
      }
    }
  };
  // Number of non-null (still live) candidates.
  auto valid_candidate_num = [](const AnfNodePtrList &cs) {
    return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
  };

  auto candidates = get_candidates();
  while (valid_candidate_num(candidates) > 1) {
    auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
    // Append this round's fusion plans to the output.
    (void)std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
                         [](const ParallelInfo &pi) { return pi; });
    update_tails(used);
    candidates = get_candidates();
  }
}
633
// Run the cost-model search over every group and collect all fusion plans.
std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
  const std::vector<std::vector<AnfNodePtrList>> &groups) {
  // Find core-fusable groups with cost model.
  std::vector<ParallelInfo> parallel_infos;
  for (const auto &group : groups) {
    SearchFuseNodesInParallelGroup(group, &parallel_infos);
  }

  return parallel_infos;
}
644
// Attach the parallel dim info of each fused segment to that segment's output
// node(s), then attach the overall fusion info to one representative node.
void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &parallel_info) {
  // NOTE(review): if parallel_info holds no segments, attach_node stays null
  // when passed to SetFusionInfoAttrToNode — callers appear to guarantee at
  // least one node; confirm.
  AnfNodePtr attach_node;
  // Dim info should be attach to each segment's output.
  for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
    const auto &fuse_nodes = parallel_info.nodes();
    // info = {segment index, parallel dim of this segment}.
    std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
    if (!common::AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
      attach_node = fuse_nodes[i];
      SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
    } else {
      // For a graph kernel, mark the output node(s) inside its sub graph.
      auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
      auto out_node = node_g->output();
      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
        // Multi-output sub graph: mark every real output behind the MakeTuple.
        auto inputs = out_node->cast<CNodePtr>()->inputs();
        for (size_t j = 1; j < inputs.size(); ++j) {
          SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
        }
        attach_node = inputs[1];
      } else {
        attach_node = out_node;
        SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
      }
    }
  }

  // Fusion info is ok to attach to one of the segments.
  SetFusionInfoAttrToNode(attach_node, parallel_info);
}
673
SetFusionInfoAttrToNode(const AnfNodePtr & node,const ParallelInfo & parallel_info)674 void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info) {
675 auto fusion_type = parallel_info.fusion_info()->FusionType();
676 common::AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
677 if (parallel_info.fusion_info()->ExistTypeInfo()) {
678 if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
679 common::AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
680 MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
681 }
682 }
683 }
684
CreateParallelOpSubGraphs(const std::vector<ParallelInfo> & parallel_infos,const std::shared_ptr<session::KernelGraph> & kernel_graph)685 bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos,
686 const std::shared_ptr<session::KernelGraph> &kernel_graph) {
687 bool changed = false;
688
689 for (size_t i = 0; i < parallel_infos.size(); ++i) {
690 const auto &fuse_nodes = parallel_infos[i].nodes();
691 if (fuse_nodes.size() <= 1) {
692 continue;
693 }
694 changed = true;
695 SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
696 auto sg_node = ReplaceNodesWithGraphKernelNode(fuse_nodes, kernel_graph, "parallel");
697 common::AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
698 DumpParallelFusionDetail(fuse_nodes, sg_node);
699 }
700
701 return changed;
702 }
703
CollectCapturedNodes(const std::vector<ParallelInfo> & infos)704 std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
705 std::set<AnfNodePtr> captured;
706 (void)std::for_each(infos.cbegin(), infos.cend(), [&captured](const ParallelInfo &info) {
707 captured.insert(info.nodes().cbegin(), info.nodes().cend());
708 });
709 return captured;
710 }
711
// Build parallel groups by reverse BFS levels: nodes whose out-degree drops to
// zero in the same round do not depend on each other and form one group, each
// node being a single-node stream. Nodes in `exclude` (already captured by the
// dependency-analysis pass) are skipped.
std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
                                                                const std::set<AnfNodePtr> &exclude) {
  std::vector<std::vector<AnfNodePtrList>> groups;
  // BFS
  std::queue<AnfNodePtr> node_que;
  std::unordered_map<AnfNodePtr, int> outdegrees;
  // Seed the queue with sink nodes (no consumers).
  for (const auto &[node, ref] : node_rels) {
    outdegrees[node] = SizeToInt(ref.nexts.size());
    if (outdegrees[node] == 0) {
      node_que.push(node);
    }
  }

  // Counts down to verify every node gets visited (detects cycles at the end).
  int total_node_num = SizeToInt(node_rels.size());
  while (!node_que.empty()) {
    std::vector<AnfNodePtrList> group;
    // Consume exactly one BFS level per outer iteration.
    int node_size = SizeToInt(node_que.size());
    while (node_size != 0) {
      node_size--;
      auto node = node_que.front();
      node_que.pop();
      if (exclude.count(node) == 0 && Parallelizable(node)) {
        (void)group.emplace_back(AnfNodePtrList({node}));
      }
      --total_node_num;
      auto iter = node_rels.find(node);
      if (iter == node_rels.end()) {
        MS_LOG(EXCEPTION) << "Internal error in node relationship!";
      }
      // Release producers whose consumers have all been processed.
      for (const auto &pre : iter->second.pres) {
        if (--outdegrees[pre] == 0) {
          node_que.push(pre);
        }
      }
    }
    if (!group.empty()) {
      groups.push_back(group);
    }
  }

  if (total_node_num > 0) {
    MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
  }
  DumpParallelGroups(groups, "BFS");
  return groups;
}
758
Run(const FuncGraphPtr & graph)759 bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
760 MS_EXCEPTION_IF_NULL(graph);
761 parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
762
763 (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
764 auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
765 MS_EXCEPTION_IF_NULL(kernel_graph);
766
767 cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
768 MS_EXCEPTION_IF_NULL(cost_model_ptr_);
769
770 auto nodes = TopoSort(kernel_graph->get_return());
771 std::reverse(nodes.begin(), nodes.end());
772
773 auto node_rels = GenAnalysisGraph(nodes);
774 auto groups = SearchParallelGroups(node_rels);
775 auto parallel_infos = SearchFusableParallelCNodes(groups);
776
777 // Search in BFS for left nodes.
778 if (parallel_level_ > 0) {
779 auto exclued_nodes = CollectCapturedNodes(parallel_infos);
780 auto groups_bfs = GetParallelGroupsByBfs(node_rels, exclued_nodes);
781 auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
782
783 (void)parallel_infos.insert(parallel_infos.cend(), bfs_parallel_infos.cbegin(), bfs_parallel_infos.cend());
784 }
785
786 // Create core-fuse subgraph and change origin graph.
787 bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
788 (void)std::make_shared<SpreadUpdateState>()->Run(graph);
789 return changed;
790 }
791 } // namespace mindspore::graphkernel
792