1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/graph_kernel/parallel_fusion.h"
18
19 #include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
20 #include "frontend/operator/ops.h"
21 #include "ir/func_graph_cloner.h"
22 #include "vm/segment_runner.h"
23 #include "backend/optimizer/graph_kernel/update_state_formatter.h"
24
25 namespace mindspore {
26 namespace opt {
27 namespace {
IsOneOf(const AnfNodePtr & node,const std::vector<PrimitivePtr> & ops_prim)28 bool IsOneOf(const AnfNodePtr &node, const std::vector<PrimitivePtr> &ops_prim) {
29 return std::any_of(ops_prim.cbegin(), ops_prim.cend(),
30 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
31 }
32
ProcessThroughPassCNode(const std::function<bool (const AnfNodePtr &)> & pass_fn,OrderedMap<AnfNodePtr,NodeRelation> * node_rels)33 void ProcessThroughPassCNode(const std::function<bool(const AnfNodePtr &)> &pass_fn,
34 OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
35 std::set<AnfNodePtr> latter_to_be_erased;
36 for (const auto &[node, node_rel] : (*node_rels)) {
37 if (!pass_fn(node) || latter_to_be_erased.count(node) != 0) {
38 continue;
39 }
40
41 auto nexts = node_rel.nexts;
42 std::vector<AnfNodePtr> pre_nodes;
43 std::queue<AnfNodePtr> node_que;
44 node_que.push(node);
45
46 // Find until all pre nodes get false from pass_fn, and collect all these predecessor nodes.
47 while (!node_que.empty()) {
48 auto cur_node = node_que.front();
49 node_que.pop();
50
51 if (!pass_fn(cur_node)) {
52 pre_nodes.push_back(cur_node);
53 continue;
54 }
55
56 latter_to_be_erased.insert(cur_node);
57 auto predecessors = (*node_rels)[cur_node].pres;
58 if (predecessors.empty()) {
59 continue;
60 }
61
62 for (const auto &pre_node : predecessors) {
63 (*node_rels)[cur_node].pres.erase(pre_node);
64 (*node_rels)[pre_node].nexts.erase(cur_node);
65 node_que.push(pre_node);
66 }
67 }
68
69 // Modify the relation: delete node <-> next_node, add pre node <-> next_node.
70 for (const auto &next_node : nexts) {
71 (*node_rels)[next_node].pres.erase(node);
72 for (const auto &cur_node : pre_nodes) {
73 (*node_rels)[next_node].pres.insert(cur_node);
74 (*node_rels)[cur_node].nexts.insert(next_node);
75 }
76 }
77 }
78
79 for (const auto &node : latter_to_be_erased) {
80 node_rels->erase(node);
81 }
82 }
83
ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr,NodeRelation> * node_rels)84 void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
85 AnfNodePtrList latter_to_be_erased;
86 for (auto &[node, node_rel] : (*node_rels)) {
87 if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
88 continue;
89 }
90
91 AnfNodePtrList check_next_list;
92 check_next_list.push_back(node);
93
94 bool disinterested = false;
95 for (auto &successor : node_rel.nexts) {
96 if (!IsPrimitiveCNode(successor, prim::kPrimTupleGetItem)) {
97 disinterested = true;
98 break;
99 }
100 check_next_list.push_back(successor);
101 }
102 if (disinterested) {
103 continue;
104 }
105
106 if (!std::all_of(check_next_list.cbegin(), check_next_list.cend(),
107 [&node_rels](const AnfNodePtr &n) -> bool { return (*node_rels)[n].nexts.empty(); })) {
108 continue;
109 }
110
111 latter_to_be_erased.push_back(node);
112 }
113
114 // Delete Tail MakeTuple(including its getitem nodes).
115 for (const auto &node : latter_to_be_erased) {
116 for (auto &pre : (*node_rels)[node].pres) {
117 (*node_rels)[pre].nexts.erase(node);
118 }
119
120 // Tail MakeTuple is just be consumed by nothing or invalid getitem node.
121 for (auto &getitem : (*node_rels)[node].nexts) {
122 node_rels->erase(getitem);
123 }
124
125 node_rels->erase(node);
126 }
127 }
128
IsSingleInputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)129 bool IsSingleInputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
130 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 1) {
131 return true;
132 }
133 return false;
134 }
135
IsSingleOutputNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)136 bool IsSingleOutputNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
137 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 1) {
138 return true;
139 }
140 return false;
141 }
142
IsMultiInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)143 bool IsMultiInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
144 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() > 1) {
145 return true;
146 }
147 return false;
148 }
149
IsMultiOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)150 bool IsMultiOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
151 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() > 1) {
152 return true;
153 }
154 return false;
155 }
156
IsNoInputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)157 bool IsNoInputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
158 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.pres.size() == 0) {
159 return true;
160 }
161 return false;
162 }
163
IsNoOutputsNode(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const AnfNodePtr & node)164 bool IsNoOutputsNode(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const AnfNodePtr &node) {
165 if (auto iter = node_rels.find(node); iter != node_rels.end() && iter->second.nexts.size() == 0) {
166 return true;
167 }
168 return false;
169 }
170
ProcessLocalStructure(OrderedMap<AnfNodePtr,NodeRelation> * node_rels,std::set<AnfNodePtr> * virtual_noout_nodes,std::set<AnfNodePtr> * ignore_noin_nodes)171 void ProcessLocalStructure(OrderedMap<AnfNodePtr, NodeRelation> *node_rels, std::set<AnfNodePtr> *virtual_noout_nodes,
172 std::set<AnfNodePtr> *ignore_noin_nodes) {
173 // 1. Local relation
174 // Graph as following left part, relation D->B and D->E(D is a no input node)
175 // will make B and E to be multiply inputs node.
176 // But for parallel, this local relation can ignore for B and E, which make
177 // them be able to be paralleled.
178 //
179 // ************************************
180 // * *
181 // * | | *
182 // * A D A D *
183 // * | /| | / \ *
184 // * | C | | C F *
185 // * |/ / | | | *
186 // * B F ====> B x x *
187 // * | / | *
188 // * |/ | *
189 // * E E *
190 // * | | *
191 // * *
192 // ************************************
193 AnfNodePtrList no_input_nodes;
194 for (const auto &node_rel : *node_rels) {
195 auto &node = node_rel.first;
196 if (IsNoInputsNode(*node_rels, node)) {
197 no_input_nodes.push_back(node);
198 }
199 }
200
201 std::vector<std::pair<AnfNodePtr, AnfNodePtr>> latter_delete;
202
203 for (const auto &ninode : no_input_nodes) {
204 AnfNodePtrList cnexts((*node_rels)[ninode].nexts.begin(), (*node_rels)[ninode].nexts.end());
205 for (const auto &n : cnexts) {
206 AnfNodePtr serial_tail = ninode;
207 AnfNodePtr cur_node = n;
208 while (IsSingleInputNode(*node_rels, cur_node) && IsSingleOutputNode(*node_rels, cur_node)) {
209 serial_tail = cur_node;
210 cur_node = *((*node_rels)[cur_node].nexts.begin());
211 }
212 latter_delete.emplace_back(serial_tail, cur_node);
213 }
214 }
215
216 // Delete relation.
217 for (const auto &[serial_tail, cur_node] : latter_delete) {
218 virtual_noout_nodes->insert(serial_tail);
219 ignore_noin_nodes->insert(cur_node);
220 (*node_rels)[serial_tail].nexts.erase(cur_node);
221 (*node_rels)[cur_node].pres.erase(serial_tail);
222 MS_LOG(INFO) << "Process local relation delete relation: " << serial_tail->fullname_with_scope() << " -> "
223 << cur_node->fullname_with_scope();
224 }
225 }
226
GetInterestNodeIds(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,const std::set<AnfNodePtr> & virtual_noout_nodes,const std::set<AnfNodePtr> & ignore_noin_nodes)227 std::tuple<AnfNodePtrList, AnfNodePtrList, AnfNodePtrList, AnfNodePtrList> GetInterestNodeIds(
228 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, const std::set<AnfNodePtr> &virtual_noout_nodes,
229 const std::set<AnfNodePtr> &ignore_noin_nodes) {
230 AnfNodePtrList multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes;
231 std::list<std::function<void(const AnfNodePtr &)>> func_list = {
232 [&node_rels, &multi_inputs_nodes](const AnfNodePtr &node) {
233 if (IsMultiInputsNode(node_rels, node)) {
234 multi_inputs_nodes.push_back(node);
235 }
236 },
237 [&node_rels, &multi_outputs_nodes](const AnfNodePtr &node) {
238 if (IsMultiOutputsNode(node_rels, node)) {
239 multi_outputs_nodes.push_back(node);
240 }
241 },
242 [&node_rels, &no_input_nodes, &ignore_noin_nodes](const AnfNodePtr &node) {
243 if (IsNoInputsNode(node_rels, node) && ignore_noin_nodes.count(node) == 0) {
244 no_input_nodes.push_back(node);
245 }
246 },
247 [&node_rels, &no_output_nodes, &virtual_noout_nodes](const AnfNodePtr &node) {
248 if (IsNoOutputsNode(node_rels, node) && virtual_noout_nodes.count(node) == 0) {
249 no_output_nodes.push_back(node);
250 }
251 }};
252
253 for (const auto &node_rel : node_rels) {
254 for (const auto &func : func_list) {
255 func(node_rel.first);
256 }
257 }
258
259 return std::make_tuple(multi_inputs_nodes, multi_outputs_nodes, no_input_nodes, no_output_nodes);
260 }
261
WhiteOpsFilter(const AnfNodePtr & node)262 bool WhiteOpsFilter(const AnfNodePtr &node) {
263 std::vector<PrimitivePtr> whiteable_ops = {}; // Not special for now.
264 return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
265 }
266
SearchFromNodes(const AnfNodePtrList & nodes,const std::function<bool (const AnfNodePtr &)> & filter_func,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::set<AnfNodePtr> * seen)267 std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
268 const std::function<bool(const AnfNodePtr &)> &filter_func,
269 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
270 std::set<AnfNodePtr> *seen) {
271 // Start from multi-inputs node, stop on seen node or multi-inputs or multi-outputs nodes.
272 // For backward search, the other multi-inputs node can be contained in.
273 // For forward search, the other multi-outputs node can be contained in.
274 auto get_contain_node_set = is_backward ? [](const NodeRelation &info) { return info.pres; }
275 : [](const NodeRelation &info) { return info.nexts; };
276 auto get_exclude_node_set = is_backward ? [](const NodeRelation &info) { return info.nexts; }
277 : [](const NodeRelation &info) { return info.pres; };
278 std::vector<AnfNodePtrList> group;
279 for (const auto &node : nodes) {
280 AnfNodePtrList stream;
281 AnfNodePtr n = node;
282 for (auto iter = node_rels.find(n);
283 seen->count(n) == 0 && iter != node_rels.end() && get_exclude_node_set(iter->second).size() <= 1;
284 iter = node_rels.find(n)) {
285 if (filter_func(n)) {
286 stream.push_back(n);
287 seen->insert(n);
288 }
289 if (get_contain_node_set(iter->second).size() != 1) {
290 break;
291 }
292 n = *(get_contain_node_set(iter->second).begin());
293 }
294 if (stream.size() > 0) {
295 group.push_back(stream);
296 }
297 }
298
299 if (group.size() == 1) {
300 for (const auto &drop : group[0]) {
301 seen->erase(drop);
302 }
303 group.clear();
304 }
305
306 return group;
307 }
308
SearchStreamFromMultiRelationNode(const AnfNodePtrList & multi_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)309 void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
310 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
311 std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
312 auto get_related_nodes = is_backward ? [](const NodeRelation &info) { return info.pres; }
313 : [](const NodeRelation &info) { return info.nexts; };
314 for (const auto &node : multi_nodes) {
315 if (auto iter = node_rels.find(node); iter != node_rels.end()) {
316 const auto &pre_nodes = get_related_nodes(iter->second);
317 AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
318 groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
319 }
320 }
321
322 // Erase empty groups.
323 for (auto iter = groups->begin(); iter != groups->end();) {
324 if (iter->size() == 0) {
325 iter = groups->erase(iter);
326 } else {
327 ++iter;
328 }
329 }
330 }
331
SearchStreamFromUnidirectionalNode(const AnfNodePtrList & ud_nodes,const OrderedMap<AnfNodePtr,NodeRelation> & node_rels,bool is_backward,std::vector<std::vector<AnfNodePtrList>> * groups,std::set<AnfNodePtr> * seen)332 void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
333 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
334 std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
335 groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
336
337 // Erase empty groups.
338 for (auto iter = groups->begin(); iter != groups->end();) {
339 if (iter->size() == 0) {
340 iter = groups->erase(iter);
341 } else {
342 ++iter;
343 }
344 }
345 }
346
DumpNode(const AnfNodePtr & node)347 std::string DumpNode(const AnfNodePtr &node) {
348 auto cnode = node->cast<CNodePtr>();
349 MS_EXCEPTION_IF_NULL(cnode);
350 std::stringstream buf;
351 buf << (AnfAlgo::IsGraphKernel(cnode) ? "[graph]" : "[primitive]") << cnode->fullname_with_scope() << "|"
352 << cnode->ToString();
353 return buf.str();
354 }
355
DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> & groups)356 void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups) {
357 MS_LOG(INFO) << "There are " << groups.size() << " parallel groups, their detail is: ";
358 int i = 0;
359 for (const auto group : groups) {
360 std::stringstream buf;
361 buf << "[" << i << " group] " << group.size() << ":\n";
362 for (const auto nodes : group) {
363 buf << " " << nodes.size() << ": [<";
364 for (const auto node : nodes) {
365 buf << "(" << DumpNode(node) << ") -> ";
366 }
367 buf << ">]\n";
368 }
369 i++;
370 MS_LOG(INFO) << buf.str();
371 }
372 }
373
DumpParallelFusionDetail(const AnfNodePtrList & source,const AnfNodePtr & target)374 void DumpParallelFusionDetail(const AnfNodePtrList &source, const AnfNodePtr &target) {
375 std::stringstream buf;
376 buf << "Parallel fusion detail: ";
377 for (const auto &node : source) {
378 buf << "(" << DumpNode(node) << ") + ";
379 }
380 buf << "==>"
381 << "(" << DumpNode(target) << ")";
382 MS_LOG(INFO) << buf.str();
383 }
384 } // namespace
385
GenAnalysisGraph(const AnfNodePtrList & nodes)386 OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const AnfNodePtrList &nodes) {
387 // Based on anf node input information, build a simple graph for latter analyzation.
388 OrderedMap<AnfNodePtr, NodeRelation> node_rels;
389 auto get_info = [&node_rels](const AnfNodePtr &node) {
390 if (node_rels.count(node) == 0) {
391 (void)node_rels.emplace(node, NodeRelation());
392 }
393 return &(node_rels[node]);
394 };
395
396 for (const auto &node : nodes) {
397 if (!node->isa<CNode>()) {
398 continue;
399 }
400
401 auto prior_node = get_info(node);
402 for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
403 if (!input->isa<CNode>()) {
404 continue;
405 }
406 auto behind_node = get_info(input);
407 prior_node->pres.insert(input);
408 behind_node->nexts.insert(node);
409 }
410 }
411
412 ProcessThroughPassCNode(
413 [](const AnfNodePtr &node) {
414 return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
415 },
416 &node_rels);
417 ProcessTailMakeTupleCNode(&node_rels);
418 ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
419
420 return node_rels;
421 }
422
SearchParallelGroups(const OrderedMap<AnfNodePtr,NodeRelation> & node_rels)423 std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
424 const OrderedMap<AnfNodePtr, NodeRelation> &node_rels) {
425 // Get interesting nodes: multi-inputs nodes, multi-outputs nodes, no input nodes and no output nodes.
426 auto [mul_ins_nodes, mul_outs_nodes, no_in_nodes, no_out_nodes] =
427 GetInterestNodeIds(node_rels, virtual_noout_nodes_, ignore_noin_nodes_);
428
429 // Get streams and group them
430 std::set<AnfNodePtr> seen;
431 std::vector<std::vector<AnfNodePtrList>> groups;
432
433 SearchStreamFromMultiRelationNode(mul_ins_nodes, node_rels, true, &groups, &seen);
434 SearchStreamFromUnidirectionalNode(no_out_nodes, node_rels, true, &groups, &seen);
435 SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
436 SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
437
438 DumpParallelGroups(groups);
439 return groups;
440 }
441
GetAvaliableNodesByOffset(int start,const std::vector<size_t> & offsets,const std::vector<bool> & used,const AnfNodePtrList & nodes,const std::set<int> & excludes)442 std::tuple<AnfNodePtrList, std::vector<int>> ParallelOpFusion::GetAvaliableNodesByOffset(
443 int start, const std::vector<size_t> &offsets, const std::vector<bool> &used, const AnfNodePtrList &nodes,
444 const std::set<int> &excludes) {
445 // Get unused nodes by offset index, the result will contain the node with start index.
446 int node_limit = static_cast<int>(nodes.size());
447 if (start >= node_limit) {
448 MS_LOG(EXCEPTION) << "Index offset is exceed the limit of given nodes.";
449 }
450 AnfNodePtrList target_nodes = {nodes[IntToSize(start)]};
451 std::vector<int> valid_indices;
452 std::vector<size_t> unused;
453 for (size_t i = IntToSize(start); i < used.size(); ++i) {
454 if (!used[i] && excludes.count(i) == 0) {
455 unused.push_back(i);
456 }
457 }
458 size_t limit = unused.size();
459 for (auto offset : offsets) {
460 if (offset >= limit) {
461 MS_LOG(EXCEPTION) << "Index offset is exceed the limit of unused nodes.";
462 }
463 if (SizeToInt(unused[offset]) >= node_limit) {
464 MS_LOG(EXCEPTION) << "Index offset is exceed the limit of nodes.";
465 }
466 valid_indices.push_back(unused[offset]);
467 target_nodes.push_back(nodes[unused[offset]]);
468 }
469
470 return std::make_tuple(target_nodes, valid_indices);
471 }
472
DoSearchInSortedCandidates(size_t origin_size,const AnfNodePtrList & candidates,std::map<AnfNodePtr,int> * origin_indices,std::map<AnfNodePtr,int> * sorted_indices)473 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::DoSearchInSortedCandidates(
474 size_t origin_size, const AnfNodePtrList &candidates, std::map<AnfNodePtr, int> *origin_indices,
475 std::map<AnfNodePtr, int> *sorted_indices) {
476 auto get_index = [](std::map<AnfNodePtr, int> *indices, const AnfNodePtr &node) -> int {
477 MS_EXCEPTION_IF_NULL(node);
478 if (indices->find(node) == indices->end()) {
479 MS_LOG(EXCEPTION) << "There is no index record for node " << node->ToString();
480 }
481 return (*indices)[node];
482 };
483
484 std::vector<ParallelInfo> parallel_infos;
485 std::vector<bool> origin_candidates_used(origin_size, false);
486 std::vector<bool> sorted_candidates_used(candidates.size(), false);
487
488 for (size_t i = 0; i < candidates.size(); ++i) {
489 if (sorted_candidates_used[i]) {
490 continue;
491 }
492
493 int max_benefit = 0;
494 ParallelInfo best_parallel_info;
495 size_t unused_num = 0;
496 for (size_t j = i + 1; j < sorted_candidates_used.size(); ++j) {
497 unused_num += sorted_candidates_used[j] ? 0 : 1;
498 }
499 if (unused_num < 1) {
500 break;
501 }
502
503 unused_num = std::min(unused_num, config_.max_num_for_fuse() - 1);
504
505 size_t begin = 1, end = unused_num;
506 while (begin <= end) {
507 size_t mid = (begin + end) / 2;
508 std::vector<size_t> tc(mid);
509 std::iota(tc.begin(), tc.end(), 1);
510 AnfNodePtrList other_candidates;
511 std::tie(other_candidates, std::ignore) =
512 GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
513 int benefit;
514 std::tie(std::ignore, benefit, std::ignore) = cost_model_ptr_->CalFuseInfo(other_candidates);
515 if (benefit > 0) {
516 begin = mid + 1;
517 } else {
518 end = mid - 1;
519 }
520 }
521
522 if (begin > 1) {
523 std::vector<size_t> tc(begin - 1);
524 std::iota(tc.begin(), tc.end(), 1);
525 AnfNodePtrList other_candidates;
526 std::tie(other_candidates, std::ignore) =
527 GetAvaliableNodesByOffset(SizeToInt(i), tc, sorted_candidates_used, candidates, std::set<int>());
528 auto [dim_infos, benefit, fusion_info] = cost_model_ptr_->CalFuseInfo(other_candidates);
529 if (benefit <= 0) {
530 MS_LOG(EXCEPTION) << "Internal error in candidate search!";
531 }
532 max_benefit = benefit;
533 best_parallel_info = ParallelInfo(other_candidates, dim_infos, fusion_info);
534 i += begin - 1;
535 }
536
537 if (max_benefit > 0) {
538 parallel_infos.push_back(best_parallel_info);
539 for (const auto &node : best_parallel_info.nodes()) {
540 sorted_candidates_used[IntToSize(get_index(sorted_indices, node))] = true;
541 origin_candidates_used[IntToSize(get_index(origin_indices, node))] = true;
542 }
543 }
544 }
545
546 // Current nodes is not suitable to fuse, so pop first node to try other fusion possibility.
547 if (parallel_infos.size() == 0) {
548 origin_candidates_used[IntToSize(get_index(origin_indices, candidates[parallel_infos.size()]))] = true;
549 }
550
551 return std::make_tuple(origin_candidates_used, parallel_infos);
552 }
553
SearchFuseNodesInCandidates(const AnfNodePtrList & cs)554 std::tuple<std::vector<bool>, std::vector<ParallelInfo>> ParallelOpFusion::SearchFuseNodesInCandidates(
555 const AnfNodePtrList &cs) {
556 std::map<AnfNodePtr, int> origin_indices;
557 std::vector<size_t> indices;
558 for (size_t i = 0; i < cs.size(); ++i) {
559 if (cs[i]) {
560 (void)origin_indices.emplace(cs[i], i);
561 indices.push_back(i);
562 }
563 }
564
565 // A calculated heavy node can cover more lighter nodes' cost, so sort them first.
566 std::map<size_t, int> cal_amounts;
567 for (auto id : indices) {
568 cal_amounts[id] = cost_model_ptr_->GetNodeCalAmount(cs[id]);
569 }
570 std::sort(indices.begin(), indices.end(),
571 [&cal_amounts](size_t a, size_t b) { return cal_amounts[a] > cal_amounts[b]; });
572
573 AnfNodePtrList candidates;
574 for (size_t i = 0; i < indices.size(); ++i) {
575 candidates.push_back(cs[indices[i]]);
576 }
577
578 std::map<AnfNodePtr, int> sorted_indices;
579 for (size_t i = 0; i < candidates.size(); ++i) {
580 (void)sorted_indices.emplace(candidates[i], i);
581 }
582
583 return DoSearchInSortedCandidates(cs.size(), candidates, &origin_indices, &sorted_indices);
584 }
585
SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> & group,std::vector<ParallelInfo> * parallel_infos)586 void ParallelOpFusion::SearchFuseNodesInParallelGroup(const std::vector<AnfNodePtrList> &group,
587 std::vector<ParallelInfo> *parallel_infos) {
588 std::vector<AnfNodePtrList::const_iterator> tails;
589 std::vector<AnfNodePtrList::const_iterator> ended;
590 for (const auto &node_list : group) {
591 tails.push_back(node_list.begin());
592 ended.push_back(node_list.end());
593 }
594 auto get_candidates = [&tails, &ended]() {
595 AnfNodePtrList candidates;
596 for (size_t id = 0; id < tails.size(); ++id) {
597 candidates.push_back(tails[id] != ended[id] ? *tails[id] : AnfNodePtr());
598 }
599 return candidates;
600 };
601 auto update_tails = [&tails](const std::vector<bool> &used) {
602 if (used.size() != tails.size()) {
603 MS_LOG(EXCEPTION) << "Judged nodes size is not equal to left ones!";
604 }
605 for (size_t id = 0; id < used.size(); ++id) {
606 if (used[id]) {
607 ++tails[id];
608 }
609 }
610 };
611 auto valid_candidate_num = [](const AnfNodePtrList &cs) {
612 return std::count_if(cs.begin(), cs.end(), [](const AnfNodePtr &n) { return n != nullptr; });
613 };
614
615 auto candidates = get_candidates();
616 while (valid_candidate_num(candidates) > 1) {
617 auto [used, fnds] = SearchFuseNodesInCandidates(candidates);
618 std::transform(fnds.cbegin(), fnds.cend(), std::back_insert_iterator(*parallel_infos),
619 [](const ParallelInfo &pi) { return pi; });
620 update_tails(used);
621 candidates = get_candidates();
622 }
623 }
624
SearchFusableParallelCNodes(const std::vector<std::vector<AnfNodePtrList>> & groups)625 std::vector<ParallelInfo> ParallelOpFusion::SearchFusableParallelCNodes(
626 const std::vector<std::vector<AnfNodePtrList>> &groups) {
627 // Find core-fusable groups with cost model.
628 std::vector<ParallelInfo> parallel_infos;
629 for (const auto &group : groups) {
630 SearchFuseNodesInParallelGroup(group, ¶llel_infos);
631 }
632
633 return parallel_infos;
634 }
635
SetFusedParallelOpAttrToReturnNode(const ParallelInfo & parallel_info)636 void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo ¶llel_info) {
637 AnfNodePtr attach_node;
638 // Dim info should be attach to each segment's output.
639 for (size_t i = 0; i < parallel_info.GetSize(); ++i) {
640 const auto &fuse_nodes = parallel_info.nodes();
641 std::vector<size_t> info = {i, std::dynamic_pointer_cast<CommonDimInfo>(parallel_info.dims()[i])->dim_info()};
642 if (!AnfAlgo::IsGraphKernel(fuse_nodes[i])) {
643 attach_node = fuse_nodes[i];
644 SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), fuse_nodes[i]);
645 } else {
646 auto node_g = GetValueNode<FuncGraphPtr>((fuse_nodes[i]->cast<CNodePtr>())->input(0));
647 auto out_node = node_g->output();
648 if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
649 auto inputs = out_node->cast<CNodePtr>()->inputs();
650 for (size_t j = 1; j < inputs.size(); ++j) {
651 SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), inputs[j]);
652 }
653 attach_node = inputs[1];
654 } else {
655 attach_node = out_node;
656 SetNodeAttrSafely(kAttrParallelDimInfo, MakeValue<std::vector<size_t>>(info), out_node);
657 }
658 }
659 }
660
661 // Fusion info is ok to attach to one of the segments.
662 SetFusionInfoAttrToNode(attach_node, parallel_info);
663 }
664
SetFusionInfoAttrToNode(const AnfNodePtr & node,const ParallelInfo & parallel_info)665 void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info) {
666 auto fusion_type = parallel_info.fusion_info()->FusionType();
667 AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
668 if (parallel_info.fusion_info()->ExistTypeInfo()) {
669 if (auto pipeline_fusion = std::dynamic_pointer_cast<BlockPipelineFusionInfo>(parallel_info.fusion_info())) {
670 AnfAlgo::SetNodeAttr(kAttrParallelTypeInfo,
671 MakeValue<std::vector<std::vector<int>>>(pipeline_fusion->PipelineIds()), node);
672 }
673 }
674 }
675
CreateParallelOpSubGraphs(const std::vector<ParallelInfo> & parallel_infos,const std::shared_ptr<session::KernelGraph> & kernel_graph)676 bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo> ¶llel_infos,
677 const std::shared_ptr<session::KernelGraph> &kernel_graph) {
678 bool changed = false;
679
680 for (size_t i = 0; i < parallel_infos.size(); ++i) {
681 const auto &fuse_nodes = parallel_infos[i].nodes();
682 if (fuse_nodes.size() <= 1) {
683 continue;
684 }
685 changed = true;
686 SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
687 AnfNodePtr sg_node;
688 std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
689 AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
690 DumpParallelFusionDetail(fuse_nodes, sg_node);
691 }
692
693 return changed;
694 }
695
Run(const FuncGraphPtr & graph)696 bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
697 MS_EXCEPTION_IF_NULL(graph);
698 (void)std::make_shared<ShrinkUpdateState>()->Run(graph);
699 auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
700 MS_EXCEPTION_IF_NULL(kernel_graph);
701
702 cost_model_ptr_ = ParellelCostModelWarehouse::Instance().GetParallelCostModel(target_);
703 MS_EXCEPTION_IF_NULL(cost_model_ptr_);
704
705 auto nodes = TopoSort(kernel_graph->get_return());
706 std::reverse(nodes.begin(), nodes.end());
707
708 auto node_rels = GenAnalysisGraph(nodes);
709 auto groups = SearchParallelGroups(node_rels);
710 auto parallel_infos = SearchFusableParallelCNodes(groups);
711
712 // Create core-fuse subgraph and change origin graph.
713 bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
714 (void)std::make_shared<SpreadUpdateState>()->Run(graph);
715 return changed;
716 }
717 } // namespace opt
718 } // namespace mindspore
719