• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/ascend/ascend_stream_assign.h"
18 
19 #include <algorithm>
20 #include <utility>
21 
22 #include "ir/manager.h"
23 #include "utils/ms_context.h"
24 #include "utils/ms_utils.h"
25 #include "frontend/parallel/context.h"
26 #include "frontend/parallel/device_manager.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "runtime/device/kernel_adjust.h"
29 #include "backend/optimizer/common/helper.h"
30 #include "backend/kernel_compiler/oplib/oplib.h"
31 #include "utils/utils.h"
32 
33 #ifdef ENABLE_DUMP_IR
34 #include "debug/rdr/running_data_recorder.h"
35 #endif
36 
37 namespace mindspore {
38 namespace device {
39 namespace ascend {
40 namespace {
// Number of devices per server; rank ids are bucketed by this value to decide whether they share a server.
constexpr uint32_t kDeviceNumOfServer{8};
// Device count from which a world-wide hcom node is charged kTaskNumPerWorldHcomNode tasks.
constexpr uint32_t kDeviceNumThreshold{1024};
// Group name returned when an hcom node does not keep its own group.
const char kDefaultGroup[] = "__default_group";
// Attribute key under which the assigned stream id is recorded on every node.
constexpr auto kAttrStreamID = "stream_id";

// Upper limit for the total stream number computed after assignment.
constexpr uint32_t kMaxStreamNum{1024};
// Factor applied to the hcom stream count when the total stream number is computed.
constexpr uint32_t kHcomSecondaryStreamNum{3};

// Task budget of one hcom stream.
constexpr uint32_t kMaxTaskNumPerStream{1010};
// Maximum number of common nodes placed on one stream.
constexpr uint32_t kMaxCommonNodeNumPerStream{350};

// Task counts charged for one hcom node, depending on its kind and group.
constexpr uint32_t kTaskNumPerHcomNode{200};
constexpr uint32_t kTaskNumPerWorldHcomNode{250};
constexpr uint32_t kTaskNumPerSameServerHcomNode{125};
constexpr uint32_t kTaskNumPerHcomSendRecvNode{15};

// Minimum number of hcom nodes required by the trailing-time scenario check.
constexpr size_t kHcomNum{2};
// Offset, counted from the end of the hcom list, of the hcom node treated as the last gradient hcom.
constexpr size_t kLastGradHcomOffset{2};
// Expected number of entries in the "last grad and status" list.
constexpr size_t kLastGradAndStatusNum{2};
60 
IsSameServer(const std::vector<uint32_t> & rank_ids)61 bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
62   auto min_iter = min_element(rank_ids.begin(), rank_ids.end());
63   uint32_t min = (min_iter != rank_ids.end()) ? *min_iter : 0;
64   auto max_iter = max_element(rank_ids.begin(), rank_ids.end());
65   uint32_t max = (max_iter != rank_ids.end()) ? *max_iter : 0;
66   return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
67 }
68 
// Maps the group attribute of an hcom node to the group name used for stream assignment.
//  - ALL_GROUP_PARALLEL: the original group is kept.
//  - NO_GROUP_PARALLEL: kDefaultGroup is returned.
//  - any other mode: the original group is kept only when it is listed in the device manager's
//    group_info and all of its ranks are on the same server; otherwise kDefaultGroup is returned.
// Raises an exception when the parallel context or (in the last case) the device manager is null.
string DoGetHcomGroup(const string &original_group) {
  // Checked here in the same way GetHcomTaskNum checks it, instead of dereferencing blindly.
  const auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  const string communi_parallel_mode = parallel_context->communi_parallel_mode();
  if (communi_parallel_mode == parallel::ALL_GROUP_PARALLEL) {
    return original_group;
  }

  if (communi_parallel_mode == parallel::NO_GROUP_PARALLEL) {
    return kDefaultGroup;
  }

  MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
  // Bound by const reference: the table is only read, so no copy is needed.
  const auto &group_info = parallel::g_device_manager->group_info();
  for (const auto &info : group_info) {
    if (info.first != original_group) {
      continue;
    }
    return IsSameServer(info.second) ? string(original_group) : string(kDefaultGroup);
  }

  // world group is not in group_info.
  return kDefaultGroup;
}
97 
GetHcomGroup(const CNodePtr & cnode)98 string GetHcomGroup(const CNodePtr &cnode) {
99   MS_EXCEPTION_IF_NULL(cnode);
100   if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
101     MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
102   }
103 
104   auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
105   auto new_group = DoGetHcomGroup(group_name);
106   MS_LOG_INFO << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name
107               << ", new group: " << new_group;
108 
109   return new_group;
110 }
111 
GetHcomTaskNum(const CNodePtr & cnode)112 uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
113   MS_EXCEPTION_IF_NULL(cnode);
114   if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
115     MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
116   }
117 
118   if (parallel::g_device_manager == nullptr) {
119     MS_LOG(INFO) << "Device manager is nullptr.";
120     return kTaskNumPerHcomNode;
121   }
122 
123   auto node_name = AnfAlgo::GetCNodeName(cnode);
124   if (node_name == kHcomSendOpName || node_name == kReceiveOpName) {
125     return kTaskNumPerHcomSendRecvNode;
126   }
127 
128   MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
129   auto device_num = parallel::ParallelContext::GetInstance()->device_num();
130   auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
131   auto group_info = parallel::g_device_manager->group_info();
132   for (const auto &info : group_info) {
133     if (info.first != group_name) {
134       continue;
135     }
136     const auto &rank_ids = info.second;
137     if (IsSameServer(rank_ids)) {
138       return kTaskNumPerSameServerHcomNode;
139     } else if (rank_ids.size() == static_cast<size_t>(device_num) && device_num >= kDeviceNumThreshold) {
140       return kTaskNumPerWorldHcomNode;
141     } else {
142       return kTaskNumPerHcomNode;
143     }
144   }
145 
146   // world group is not in group_info.
147   if (device_num >= kDeviceNumThreshold) {
148     return kTaskNumPerWorldHcomNode;
149   } else {
150     return kTaskNumPerHcomNode;
151   }
152 }
153 
GetHcomAndOverflowMarker(const NotNull<KernelGraphPtr> & graph_ptr,vector<CNodePtr> * hcom_nodes)154 CNodePtr GetHcomAndOverflowMarker(const NotNull<KernelGraphPtr> &graph_ptr, vector<CNodePtr> *hcom_nodes) {
155   MS_EXCEPTION_IF_NULL(hcom_nodes);
156   auto cnode_ptr_list = graph_ptr->execution_order();
157   CNodePtr overflow_marker = nullptr;
158   std::string kNPUGetFloatStatusOpName = "NPUGetFloatStatus";
159   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
160     auto cur_cnode_ptr = cnode_ptr_list[i];
161     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
162     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kNPUGetFloatStatusOpName) {
163       overflow_marker = cur_cnode_ptr;
164     } else if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL) {
165       hcom_nodes->emplace_back(cur_cnode_ptr);
166     } else if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
167       auto graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
168       AnfAlgo::SetGraphId(graph_id, cnode_ptr_list[i - 1].get());
169     }
170   }
171   return overflow_marker;
172 }
173 
HasRefNodes(const vector<CNodePtr> & moved_backward_cnodes)174 bool HasRefNodes(const vector<CNodePtr> &moved_backward_cnodes) {
175   for (auto &cnode : moved_backward_cnodes) {
176     std::string op_name = AnfAlgo::GetCNodeName(cnode);
177     auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
178     if (op_info != nullptr && op_info->is_ref()) {
179       MS_LOG(INFO) << "Find RefNode: " << op_name << ", full name: " << cnode->fullname_with_scope();
180       return true;
181     }
182   }
183   return false;
184 }
185 
GetStreamKind(uint32_t cur_stream_id,uint32_t pre_stream_id,uint32_t next_stream_id)186 StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, uint32_t next_stream_id) {
187   // pre_stream_id equal to UINT32_MAX means no node active current StreamActive
188   // next_stream_id equal to UINT32_MAX means current StreamActive active no node
189   if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
190     return kInvalid;
191   }
192 
193   if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
194     return kMiddle;
195   }
196 
197   if (cur_stream_id == pre_stream_id) {
198     return kTail;
199   }
200 
201   if (cur_stream_id == next_stream_id) {
202     return kHead;
203   }
204 
205   return kInvalid;
206 }
SetNodeStreamIDAttr(const NotNull<KernelGraphPtr> & graph_ptr)207 void SetNodeStreamIDAttr(const NotNull<KernelGraphPtr> &graph_ptr) {
208   auto exec_orders = graph_ptr->execution_order();
209   for (auto node : exec_orders) {
210     AnfAlgo::SetNodeAttr(kAttrStreamID, MakeValue<uint32_t>(AnfAlgo::GetStreamId(node)), node);
211   }
212 }
213 }  // namespace
214 
AssignStream(const NotNull<KernelGraphPtr> & graph_ptr)215 void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
216   if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
217     MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
218                  << ".";
219 
220     Reset();
221     SetLoopSink();
222     ReorderIndependentOrders(graph_ptr);
223     TrailingTimeOptimizationByReorder(graph_ptr);
224 
225     AssignAllNodesStream(graph_ptr);
226     UpdateAtomicAddrCleanStreamId(graph_ptr);
227     InsertStreamActive(graph_ptr);
228     InsertEventForHcomParallel(graph_ptr);
229     InsertEventForIndependentParallel(graph_ptr);
230     GetIndependentMaxTarget(graph_ptr);
231     InsertCtrlForIndependentParallel(graph_ptr);
232     AdjustAtomicAddrCleanOrder(graph_ptr);
233 
234     GetNeedActiveStreams(graph_ptr);
235 
236     MS_LOG(INFO) << "Before check resource assign";
237     graph_ptr->PrintGraphExecuteOrder();
238 
239     CheckResourceAssign(graph_ptr);
240     MS_LOG(INFO) << "After finish stream assign";
241 #ifdef ENABLE_DUMP_IR
242     SubModuleId module = SubModuleId::SM_SESSION;
243     std::string name = "assign_stream." + std::to_string(graph_ptr->graph_id());
244     const std::vector<CNodePtr> &exec_order = graph_ptr->execution_order();
245     (void)mindspore::RDR::RecordStreamExecOrder(module, name, exec_order);
246 #endif
247     graph_ptr->PrintGraphExecuteOrder();
248     SetNodeStreamIDAttr(graph_ptr);
249     FindStreamRelations(graph_ptr);
250     PrintStreamRelations();
251     GetStreamRelations();
252     PrintStreamGroups();
253     FindEventRelations(graph_ptr);
254   }
255 }
256 
SetLoopSink()257 void AscendStreamAssign::SetLoopSink() {
258   if (KernelAdjust::NeedInsertSwitch()) {
259     loop_sink_ = true;
260   } else {
261     loop_sink_ = false;
262   }
263 }
264 
265 // section 1
ReorderIndependentOrders(const NotNull<KernelGraphPtr> & graph_ptr)266 void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr) {
267   std::vector<CNodePtr> exe_orders;
268   std::vector<CNodePtr> independents;
269   std::vector<CNodePtr> others;
270 
271   auto cnode_ptr_list = graph_ptr->execution_order();
272   MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size();
273   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
274     auto cur_cnode_ptr = cnode_ptr_list[i];
275     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
276     if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
277       independents.emplace_back(cur_cnode_ptr);
278     } else {
279       others.emplace_back(cur_cnode_ptr);
280     }
281   }
282 
283   if (others.empty() || independents.empty()) {
284     MS_LOG(INFO) << "Independent or others is empty, no need reorder";
285     return;
286   }
287 
288   std::set<CNode *> processed;
289   for (size_t i = 0; i < others.size(); i++) {
290     auto begin = others.begin() + i;
291     auto end = begin + 1;
292     bool flag = false;
293     for (size_t j = 0; j < independents.size(); j++) {
294       auto cur_independent = independents[j];
295       auto it = std::find(processed.begin(), processed.end(), cur_independent.get());
296       if (it != processed.end()) {
297         continue;
298       }
299 
300       auto res = FindTargetOp(begin, end, cur_independent, false);
301       if (res != end) {
302         flag = true;
303         exe_orders.emplace_back(cur_independent);
304         exe_orders.emplace_back(*begin);
305         processed.emplace(cur_independent.get());
306         break;
307       }
308     }
309 
310     if (!flag) {
311       exe_orders.emplace_back(*begin);
312     }
313   }
314 
315   MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size();
316   if (processed.size() != independents.size()) {
317     MS_LOG(WARNING) << "Processed independent nodes size is not equal to exiting independent nodes size";
318     return;
319   }
320 
321   graph_ptr->set_execution_order(exe_orders);
322 }
323 
CheckScenario(const NotNull<KernelGraphPtr> & graph_ptr,vector<CNodePtr> * last_grad_and_status)324 void AscendStreamAssign::CheckScenario(const NotNull<KernelGraphPtr> &graph_ptr,
325                                        vector<CNodePtr> *last_grad_and_status) {
326   MS_EXCEPTION_IF_NULL(last_grad_and_status);
327   auto cnode_ptr_list = graph_ptr->execution_order();
328   vector<CNodePtr> hcom_nodes;
329   auto overflow_marker = GetHcomAndOverflowMarker(graph_ptr, &hcom_nodes);
330   if (hcom_nodes.size() < kHcomNum || overflow_marker == nullptr) {
331     MS_LOG(INFO) << "Current model isn't in distribute or mix-precision mode, no optimization needed";
332     last_grad_and_status->clear();
333     return;
334   }
335 
336   auto overflow_marker_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), overflow_marker);
337   auto last_hcom_ptr = hcom_nodes[hcom_nodes.size() - 1];
338   auto last_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_hcom_ptr);
339   auto last_grad_hcom_ptr = hcom_nodes[hcom_nodes.size() - kLastGradHcomOffset];
340   auto last_grad_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_hcom_ptr);
341   if (last_grad_hcom_pos > overflow_marker_pos || last_hcom_pos < overflow_marker_pos) {
342     MS_LOG(INFO) << "Grads average done after overflow judgement or status aren't allgathered, no optimization needed";
343     last_grad_and_status->clear();
344     return;
345   }
346 
347   auto last_inputs = GetLastInputCnode(graph_ptr, last_grad_hcom_ptr);
348   if (last_inputs.empty() || last_inputs.size() > 1 || IsHcom(last_inputs[0])) {
349     MS_LOG(INFO) << "Inputs of last gradients allreduce is empty or include other allreduce, no optimization needed";
350     last_grad_and_status->clear();
351     return;
352   }
353   auto last_grad_ptr = last_inputs[0];
354   MS_LOG(DEBUG) << "Last Hcom: " << last_grad_hcom_ptr->fullname_with_scope()
355                 << "; last input: " << last_grad_ptr->fullname_with_scope();
356   auto last_grad_hcom_graph_id = AnfAlgo::GetGraphId(last_grad_hcom_ptr.get());
357   auto last_grad_graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get());
358   auto overflow_marker_graph_id = AnfAlgo::GetGraphId(overflow_marker.get());
359   if (last_grad_graph_id != last_grad_hcom_graph_id || last_grad_graph_id != overflow_marker_graph_id) {
360     MS_LOG(INFO) << "The grads and grad_hcom or overflow marker were not on the same subgraph, no optimization needed";
361     last_grad_and_status->clear();
362     return;
363   }
364 
365   auto label_switch_pos = find_if(last_grad_hcom_pos, cnode_ptr_list.end(),
366                                   [](CNodePtr &node) -> bool { return AnfAlgo::GetCNodeName(node) == "LabelSwitch"; });
367   if (label_switch_pos == cnode_ptr_list.end()) {
368     MS_LOG(INFO) << "No branches after getting overflow status, no optimization needed";
369     last_grad_and_status->clear();
370     return;
371   }
372   last_grad_and_status->emplace_back(last_grad_ptr);
373   last_grad_and_status->emplace_back(overflow_marker);
374   return;
375 }
376 
GetCNodesNeededMoved(vector<CNodePtr> * moved_backward_cnodes,vector<CNodePtr> * moved_forward_cnodes,const vector<CNodePtr> & last_grad_and_status,const NotNull<KernelGraphPtr> & graph_ptr)377 CNodePtr AscendStreamAssign::GetCNodesNeededMoved(vector<CNodePtr> *moved_backward_cnodes,
378                                                   vector<CNodePtr> *moved_forward_cnodes,
379                                                   const vector<CNodePtr> &last_grad_and_status,
380                                                   const NotNull<KernelGraphPtr> &graph_ptr) {
381   MS_EXCEPTION_IF_NULL(moved_backward_cnodes);
382   MS_EXCEPTION_IF_NULL(moved_forward_cnodes);
383   auto cnode_ptr_list = graph_ptr->execution_order();
384   if (last_grad_and_status.size() != kLastGradAndStatusNum) {
385     return nullptr;
386   }
387   auto last_grad_ptr = last_grad_and_status[0];
388   auto float_status_ptr = last_grad_and_status[1];
389   auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr);
390   auto float_status_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), float_status_ptr);
391   if (last_grad_pos == cnode_ptr_list.end() || float_status_pos == cnode_ptr_list.end()) {
392     return nullptr;
393   }
394   auto graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get());
395   moved_backward_cnodes->insert(moved_backward_cnodes->end(), last_grad_pos + 1, float_status_pos);
396 
397   auto it = float_status_pos;
398   while (AnfAlgo::GetGraphId((*it).get()) == graph_id && it < cnode_ptr_list.end()) {
399     if (AnfAlgo::GetCNodeName(*it) == kAtomicAddrCleanOpName) {
400       it++;
401       continue;
402     }
403     auto inputs = GetInputKernels(*it);
404     bool is_independent = true;
405     for (auto &input : inputs) {
406       if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) {
407         is_independent = false;
408         break;
409       }
410     }
411     if (is_independent) {
412       if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) {
413         moved_forward_cnodes->emplace_back(*(it - 1));
414       }
415       moved_forward_cnodes->emplace_back(*it);
416     } else {
417       if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) {
418         moved_backward_cnodes->emplace_back(*(it - 1));
419       }
420       moved_backward_cnodes->emplace_back(*it);
421     }
422     it++;
423   }
424 
425   size_t total_moved_size = LongToSize(it - last_grad_pos - 1);
426   if (HasRefNodes(*moved_backward_cnodes) ||
427       moved_backward_cnodes->size() + moved_forward_cnodes->size() != total_moved_size) {
428     MS_LOG(INFO) << "Ref node was found or invalid number of moved nodes, give up optimization";
429     return nullptr;
430   }
431   return GetTargetOutputNode(*moved_backward_cnodes, *it, graph_ptr);
432 }
433 
GetTargetOutputNode(const vector<CNodePtr> & moved_backward_cnodes,const CNodePtr first_node,const NotNull<KernelGraphPtr> & graph_ptr)434 CNodePtr AscendStreamAssign::GetTargetOutputNode(const vector<CNodePtr> &moved_backward_cnodes,
435                                                  const CNodePtr first_node, const NotNull<KernelGraphPtr> &graph_ptr) {
436   auto cnode_ptr_list = graph_ptr->execution_order();
437   if (moved_backward_cnodes.empty() || !first_node) {
438     return nullptr;
439   }
440   uint32_t subgraph_id = 0;
441   bool get_subgraph_id = false;
442   auto it = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), first_node);
443   CNodePtr first_output_node_ptr = nullptr;
444   while (!get_subgraph_id && it < cnode_ptr_list.end()) {
445     auto inputs = GetInputKernels(*it);
446     for (auto &input : inputs) {
447       if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end()) {
448         get_subgraph_id = true;
449         subgraph_id = AnfAlgo::GetGraphId((*it).get());
450         first_output_node_ptr = *it;
451         break;
452       }
453     }
454     it++;
455   }
456   if (subgraph_id == 0) {
457     MS_LOG(INFO) << "The nodes moved backward were not used by any other nodes, no need moved";
458     return nullptr;
459   }
460 
461   for (; it < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*it).get()) != subgraph_id; it++) {
462     auto inputs = GetInputKernels(*it);
463     for (auto &input : inputs) {
464       if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end()) {
465         MS_LOG(INFO) << "The nodes moved backward were used by nodes on different subgraphs, no need moved";
466         return nullptr;
467       }
468     }
469   }
470   return first_output_node_ptr;
471 }
472 
FinetuneSubgraphExecOrder(vector<CNodePtr> * cnodes)473 bool AscendStreamAssign::FinetuneSubgraphExecOrder(vector<CNodePtr> *cnodes) {
474   MS_EXCEPTION_IF_NULL(cnodes);
475   auto hcom_pos = find_if(cnodes->begin(), cnodes->end(),
476                           [](CNodePtr &node_ptr) -> bool { return AnfAlgo::GetCNodeName(node_ptr) == "AllReduce"; });
477   if (hcom_pos == cnodes->end()) {
478     return false;
479   }
480   CNodePtr hcom_ptr = *hcom_pos;
481 
482   vector<CNodePtr> ori_cnodes(cnodes->begin(), cnodes->end());
483   cnodes->clear();
484   vector<CNodePtr> atomic_addr_clean;
485   for (auto iter = ori_cnodes.begin(); iter < ori_cnodes.end(); ++iter) {
486     if (AnfAlgo::GetCNodeName(*iter) == kAtomicAddrCleanOpName) {
487       atomic_addr_clean.emplace_back(*iter);
488       continue;
489     }
490     auto last_input_pos = cnodes->end();
491     for (auto &input : GetInputKernels(*iter)) {
492       auto pos = find(cnodes->begin(), cnodes->end(), input);
493       if (pos != cnodes->end()) {
494         last_input_pos = (last_input_pos == cnodes->end() || last_input_pos < pos) ? pos : last_input_pos;
495       }
496     }
497     if (last_input_pos == cnodes->end()) {
498       auto hcom_it = find(cnodes->begin(), cnodes->end(), hcom_ptr);
499       if (hcom_it == cnodes->end() || AnfAlgo::GetCNodeName(*iter) == kLabelGotoOpName ||
500           AnfAlgo::GetCNodeName(*iter) == kLabelSetOpName || AnfAlgo::GetCNodeName(*iter) == kLabelSwitchOpName) {
501         cnodes->emplace_back(*iter);
502       } else {
503         cnodes->insert(hcom_it, *iter);
504       }
505     } else {
506       cnodes->insert(last_input_pos + 1, *iter);
507     }
508   }
509 
510   for (auto &node : atomic_addr_clean) {
511     auto first_input_pos = cnodes->end();
512     for (auto &input : GetInputKernels(node)) {
513       auto pos = find(cnodes->begin(), cnodes->end(), input);
514       first_input_pos = (first_input_pos == cnodes->end() || first_input_pos > pos) ? pos : first_input_pos;
515     }
516     if (first_input_pos == cnodes->end()) {
517       return false;
518     } else {
519       cnodes->insert(first_input_pos, node);
520     }
521   }
522   return cnodes->size() == ori_cnodes.size();
523 }
524 
525 // performance optimization for trailing time in distribute mode
526 // allreduce of the last batch of gradients and the optimizer can be done parallel
TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> & graph_ptr)527 void AscendStreamAssign::TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> &graph_ptr) {
528   vector<CNodePtr> last_grad_and_status;
529   CheckScenario(graph_ptr, &last_grad_and_status);
530   if (last_grad_and_status.empty()) {
531     MS_LOG(INFO) << "Unsuitable scenario, no optimization needed";
532     return;
533   }
534 
535   auto cnode_ptr_list = graph_ptr->execution_order();
536   vector<CNodePtr> moved_forward_cnodes;
537   vector<CNodePtr> moved_backward_cnodes;
538   CNodePtr first_output_ptr =
539     GetCNodesNeededMoved(&moved_backward_cnodes, &moved_forward_cnodes, last_grad_and_status, graph_ptr);
540   if (moved_backward_cnodes.empty() || first_output_ptr == nullptr) {
541     MS_LOG(INFO) << "Unsuitable scenario, no optimization needed";
542     return;
543   }
544 
545   uint32_t subgraph_id = AnfAlgo::GetGraphId(first_output_ptr.get());
546   auto last_grad_ptr = last_grad_and_status[0];
547   auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr);
548   vector<CNodePtr> cnodes(cnode_ptr_list.begin(), last_grad_pos + 1);
549   cnodes.insert(cnodes.end(), moved_forward_cnodes.begin(), moved_forward_cnodes.end());
550   auto pos = last_grad_pos + moved_forward_cnodes.size() + moved_backward_cnodes.size() + 1;
551   while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) != subgraph_id) {
552     cnodes.emplace_back(*pos);
553     ++pos;
554   }
555 
556   vector<CNodePtr> subgraph_cnodes;
557   while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) == subgraph_id) {
558     if (AnfAlgo::GetCNodeName(*pos) == kLabelGotoOpName) {
559       break;
560     }
561     if (*pos != first_output_ptr) {
562       subgraph_cnodes.emplace_back(*pos);
563     } else {
564       subgraph_cnodes.insert(subgraph_cnodes.end(), moved_backward_cnodes.begin(), moved_backward_cnodes.end());
565       subgraph_cnodes.emplace_back(*pos);
566     }
567     ++pos;
568   }
569 
570   if (!FinetuneSubgraphExecOrder(&subgraph_cnodes) || subgraph_cnodes.empty()) {
571     MS_LOG(INFO) << "Finetune subgraph execute order failed, no optimization needed";
572     return;
573   }
574 
575   cnodes.insert(cnodes.end(), subgraph_cnodes.begin(), subgraph_cnodes.end());
576   cnodes.insert(cnodes.end(), pos, cnode_ptr_list.end());
577   if (cnodes.size() != cnode_ptr_list.size()) {
578     return;
579   }
580   for (auto &node : subgraph_cnodes) {
581     AnfAlgo::SetGraphId(subgraph_id, node.get());
582   }
583 
584   graph_ptr->set_execution_order(cnodes);
585 }
586 
587 // section 2
AssignAllNodesStream(const NotNull<KernelGraphPtr> & graph_ptr)588 void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr) {
589   auto cnode_ptr_list = graph_ptr->execution_order();
590   bool exit_independent = false;
591   bool exit_hcom = false;
592   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
593   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
594     CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
595     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
596     // node has been assigned stream before
597     if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
598       continue;
599     }
600 
601     if (IsHcom(cur_cnode_ptr)) {
602       exit_hcom = true;
603       continue;
604     }
605 
606     if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
607       exit_independent = true;
608       continue;
609     }
610 
611     AssignCommonStreamId(cur_cnode_ptr);
612   }
613 
614   auto common_stream_num = resource_manager.get_cur_stream_num();
615 
616   if (exit_hcom) {
617     AssignHcom(graph_ptr);
618   }
619   auto hcom_stream_num = resource_manager.get_cur_stream_num() - common_stream_num;
620 
621   if (exit_independent) {
622     AssignIndependent(graph_ptr);
623   }
624   auto independent_stream_num = resource_manager.get_cur_stream_num() - common_stream_num - hcom_stream_num;
625   auto total_stream_num =
626     resource_manager.get_cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum);
627   MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num
628                << ", hcom stream number: " << hcom_stream_num << "*" << (kHcomSecondaryStreamNum + 1)
629                << ", independent stream number: " << independent_stream_num << ".";
630 
631   if (total_stream_num > kMaxStreamNum) {
632     MS_LOG(EXCEPTION) << "Total stream number " << total_stream_num << " exceeds the limit of " << kMaxStreamNum
633                       << ", search details information in mindspore's FAQ.";
634   }
635 
636   MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
637 }
638 
AssignCommonStreamId(const CNodePtr & cur_cnode_ptr)639 void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
640   MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
641   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
642   uint32_t cur_common_stream_id = 0;
643   uint32_t cur_stream_num = resource_manager.get_cur_stream_num();
644   if (cur_stream_num == 0) {
645     cur_common_stream_id = resource_manager.ApplyNewStream();
646   } else {
647     cur_common_stream_id = resource_manager.GetCurAllocStreamId();
648   }
649 
650   auto it = common_stream_map_.find(cur_common_stream_id);
651   if (it == common_stream_map_.end()) {
652     AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
653     common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
654   } else {
655     if (it->second < kMaxCommonNodeNumPerStream) {
656       AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
657       it->second++;
658     } else {
659       cur_common_stream_id = resource_manager.ApplyNewStream();
660       AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
661       common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
662     }
663   }
664 }
665 
AssignHcom(const NotNull<KernelGraphPtr> & graph_ptr)666 void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
667   auto cnode_ptr_list = graph_ptr->execution_order();
668   std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>> group_graph_nodes_map;
669   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
670     CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
671     // node has been assigned stream before
672     if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
673       continue;
674     }
675 
676     if (IsHcom(cur_cnode_ptr)) {
677       auto group_name = GetHcomGroup(cur_cnode_ptr);
678       auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
679       auto iter = group_graph_nodes_map.find(group_name);
680       if (iter == group_graph_nodes_map.end()) {
681         std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
682         graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
683         group_graph_nodes_map[group_name] = graph_nodes_map;
684       } else {
685         auto &graph_nodes_map = iter->second;
686         auto it = graph_nodes_map.find(hcom_graph_id);
687         if (it == graph_nodes_map.end()) {
688           graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
689         } else {
690           it->second.emplace_back(cur_cnode_ptr);
691         }
692       }
693     }
694   }
695 
696   MS_LOG(INFO) << "hcom diff group size:" << group_graph_nodes_map.size();
697   for (const auto &item : group_graph_nodes_map) {
698     MS_LOG_INFO << "group id:" << item.first << "; diff graph id size:" << item.second.size();
699   }
700 
701   for (const auto &diff_group : group_graph_nodes_map) {
702     // group id:
703     std::map<uint32_t, std::set<uint32_t>> hcom_graph_map;
704     for (const auto &item : diff_group.second) {
705       bool new_graph = true;
706       auto graph_id = item.first;
707       hcom_graph_map[graph_id] = {};
708       for (const auto &hcom_node_ptr : item.second) {
709         auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph);
710         hcom_graph_map[graph_id].emplace(assigned_stream_id);
711         new_graph = false;
712       }
713     }
714     group_hcom_graph_map_[diff_group.first] = hcom_graph_map;
715   }
716 }
717 
AssignHcomStreamId(const CNodePtr & cur_cnode_ptr,bool new_graph)718 uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
719   MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
720   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
721   auto task_num = GetHcomTaskNum(cur_cnode_ptr);
722 
723   uint32_t cur_hcom_stream_id;
724   if (new_graph) {
725     cur_hcom_stream_id = resource_manager.ApplyNewStream();
726   } else {
727     cur_hcom_stream_id = resource_manager.GetCurAllocStreamId();
728   }
729   auto it = hcom_stream_map_.find(cur_hcom_stream_id);
730   if (it == hcom_stream_map_.end()) {
731     AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
732     hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
733   } else {
734     if (it->second <= kMaxTaskNumPerStream - task_num) {
735       AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
736       it->second = Uint32tAddWithOverflowCheck(it->second, task_num);
737     } else {
738       cur_hcom_stream_id = resource_manager.ApplyNewStream();
739       AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
740       hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
741     }
742   }
743   return cur_hcom_stream_id;
744 }
745 
AssignIndependent(const NotNull<KernelGraphPtr> & graph_ptr)746 void AscendStreamAssign::AssignIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
747   auto cnode_ptr_list = graph_ptr->execution_order();
748   std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
749   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
750     CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
751     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
752     if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
753       continue;
754     }
755     if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
756       auto independent_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
757       auto it = graph_nodes_map.find(independent_graph_id);
758       if (it == graph_nodes_map.end()) {
759         graph_nodes_map[independent_graph_id] = {cur_cnode_ptr};
760       } else {
761         it->second.emplace_back(cur_cnode_ptr);
762       }
763     }
764   }
765 
766   MS_LOG(INFO) << "independent diff graph id size:" << graph_nodes_map.size();
767   for (const auto &item : graph_nodes_map) {
768     bool new_graph = true;
769     auto graph_id = item.first;
770     independent_graph_map_[graph_id] = {};
771     for (const auto &independent_node_ptr : item.second) {
772       auto assigned_stream_id = AssignIndependentStreamId(independent_node_ptr, new_graph);
773       independent_graph_map_[graph_id].emplace(assigned_stream_id);
774       new_graph = false;
775     }
776   }
777   MS_LOG(INFO) << "stream nums:" << independent_stream_map_.size();
778 }
779 
AssignIndependentStreamId(const CNodePtr & cur_cnode_ptr,bool new_graph)780 uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
781   MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
782   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
783   uint32_t cur_independent_stream_id;
784   if (new_graph) {
785     cur_independent_stream_id = resource_manager.ApplyNewStream();
786   } else {
787     cur_independent_stream_id = resource_manager.GetCurAllocStreamId();
788   }
789   auto it = independent_stream_map_.find(cur_independent_stream_id);
790   if (it == independent_stream_map_.end()) {
791     AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
792     independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
793   } else {
794     if (it->second < kMaxCommonNodeNumPerStream) {
795       AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
796       it->second++;
797     } else {
798       cur_independent_stream_id = resource_manager.ApplyNewStream();
799       AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
800       independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
801     }
802   }
803 
804   return cur_independent_stream_id;
805 }
806 
807 // section 3
UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> & graph_ptr)808 void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
809   MS_LOG(INFO) << "Start";
810   auto cnode_ptr_list = graph_ptr->execution_order();
811   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
812     CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
813     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
814     // update AtomicAddrClean stream same with the next node
815     if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
816       AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get());
817     }
818   }
819   MS_LOG(INFO) << "End";
820 }
821 
822 // section 4
InsertStreamActive(const NotNull<KernelGraphPtr> & graph_ptr)823 void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) {
824   InsertStreamActiveForCommon(graph_ptr);
825   InsertStreamActiveForIndependent(graph_ptr);
826   InsertStreamActiveForParallel(graph_ptr);
827 }
828 
InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> & graph_ptr)829 void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
830   if (group_hcom_graph_map_.empty() && independent_graph_map_.empty()) {
831     MS_LOG(INFO) << "Hcom and independent is empty";
832     return;
833   }
834   auto root_graph_id = graph_ptr->graph_id();
835   if (root_graph_id == kInvalidGraphId) {
836     MS_LOG(INFO) << "Root graph id is invalid";
837     return;
838   }
839 
840   std::map<uint32_t, std::set<uint32_t>> other_graph;
841   std::set<uint32_t> hcom_streams;
842   for (const auto &graph_nodes : group_hcom_graph_map_) {
843     for (const auto &item : graph_nodes.second) {
844       MS_LOG(INFO) << "Graph id:" << item.first;
845       if (item.first == root_graph_id) {
846         if (loop_sink_) {
847           hcom_streams.insert(item.second.begin(), item.second.end());
848         }
849       } else {
850         auto it = other_graph.find(item.first);
851         if (it == other_graph.end()) {
852           other_graph[item.first] = item.second;
853         } else {
854           for (const auto &stream : item.second) {
855             it->second.emplace(stream);
856           }
857         }
858       }
859     }
860   }
861 
862   if (!hcom_streams.empty()) {
863     ActiveRootGraphHcom(graph_ptr, hcom_streams);
864   }
865 
866   MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size();
867   for (const auto &item : independent_graph_map_) {
868     MS_LOG(DEBUG) << "Graph id:" << item.first;
869     if (item.first == root_graph_id) {
870       if (loop_sink_) {
871         ActiveRootGraphIndependent(graph_ptr, item.second);
872       }
873     } else {
874       auto it = other_graph.find(item.first);
875       if (it == other_graph.end()) {
876         other_graph[item.first] = item.second;
877       } else {
878         for (const auto &stream : item.second) {
879           it->second.emplace(stream);
880         }
881       }
882     }
883   }
884 
885   ActiveOtherGraphParallel(graph_ptr, other_graph);
886 }
887 
ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> & graph_ptr,std::map<uint32_t,std::set<uint32_t>> other_graph)888 void AscendStreamAssign::ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
889                                                   std::map<uint32_t, std::set<uint32_t>> other_graph) {
890   MS_LOG(INFO) << "Other graph size:" << other_graph.size();
891   if (other_graph.empty()) {
892     return;
893   }
894 
895   auto root_graph_id = graph_ptr->graph_id();
896 
897   std::vector<CNodePtr> update_stream_list;
898   auto exe_order = graph_ptr->execution_order();
899   for (size_t i = 0; i < exe_order.size(); i++) {
900     auto cur_cnode_ptr = exe_order[i];
901     auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
902     if (cur_graph_id == root_graph_id) {
903       update_stream_list.emplace_back(cur_cnode_ptr);
904       continue;
905     }
906 
907     auto it = other_graph.find(cur_graph_id);
908     if (it == other_graph.end()) {
909       update_stream_list.emplace_back(cur_cnode_ptr);
910       continue;
911     }
912 
913     auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
914     CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
915     // 1.set stream id
916     AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
917     // 2.set active stream ids
918     std::vector<uint32_t> active_index_list;
919     std::copy(it->second.begin(), it->second.end(), std::back_inserter(active_index_list));
920     AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
921 
922     // find position for insert streamactive
923     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kLabelSetOpName) {
924       update_stream_list.emplace_back(cur_cnode_ptr);
925       update_stream_list.emplace_back(active_ptr);
926     } else {
927       update_stream_list.emplace_back(active_ptr);
928       update_stream_list.emplace_back(cur_cnode_ptr);
929     }
930     other_graph.erase(it);
931   }
932   graph_ptr->set_execution_order(update_stream_list);
933 }
934 
ActiveRootGraphHcom(const NotNull<KernelGraphPtr> & graph_ptr,const std::set<uint32_t> & hcom_streams)935 void AscendStreamAssign::ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr,
936                                              const std::set<uint32_t> &hcom_streams) {
937   MS_LOG(INFO) << "Active root graph hcom start";
938   std::vector<CNodePtr> update_cnode_list;
939   auto exe_orders = graph_ptr->execution_order();
940   for (size_t i = 0; i < exe_orders.size(); i++) {
941     CNodePtr cur_cnode_ptr = exe_orders[i];
942     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) {
943       update_cnode_list.emplace_back(cur_cnode_ptr);
944       continue;
945     }
946 
947     if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) {
948       update_cnode_list.emplace_back(cur_cnode_ptr);
949       continue;
950     }
951 
952     auto kind = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrStreamSwitchKind);
953     if (kind != kFpBpStreamSwitch) {
954       update_cnode_list.emplace_back(cur_cnode_ptr);
955       continue;
956     }
957 
958     auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
959     MS_LOG(INFO) << "FpBpStreamswtich stream id:" << AnfAlgo::GetStreamId(cur_cnode_ptr)
960                  << "; true branch stream id:" << true_stream_id;
961     CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
962     AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
963     vector<uint32_t> active_ids;
964     // active hcom stream
965     std::copy(hcom_streams.begin(), hcom_streams.end(), std::back_inserter(active_ids));
966     AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
967     update_cnode_list.emplace_back(cur_cnode_ptr);
968     update_cnode_list.emplace_back(active_ptr);
969     std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
970     break;
971   }
972 
973   hcom_stream_activated_ = true;
974   graph_ptr->set_execution_order(update_cnode_list);
975 }
976 
ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> & graph_ptr,const std::set<uint32_t> & independent_streams)977 void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr,
978                                                     const std::set<uint32_t> &independent_streams) {
979   MS_LOG(DEBUG) << "Start active root graph independent";
980   std::vector<CNodePtr> update_cnode_list;
981   auto exe_orders = graph_ptr->execution_order();
982   for (size_t i = 0; i < exe_orders.size(); i++) {
983     CNodePtr cur_cnode_ptr = exe_orders[i];
984     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) {
985       update_cnode_list.emplace_back(cur_cnode_ptr);
986       continue;
987     }
988 
989     if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) {
990       update_cnode_list.emplace_back(cur_cnode_ptr);
991       continue;
992     }
993 
994     auto kind = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrStreamSwitchKind);
995     if (kind != kIndependentStreamSwitch) {
996       update_cnode_list.emplace_back(cur_cnode_ptr);
997       continue;
998     }
999 
1000     // first independetn stream id is minimum and order by std map;
1001     auto first_independent_stream = *(independent_streams.begin());
1002     AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), cur_cnode_ptr);
1003     update_cnode_list.emplace_back(cur_cnode_ptr);
1004     std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
1005     break;
1006   }
1007 
1008   independent_stream_activated_ = true;
1009   graph_ptr->set_execution_order(update_cnode_list);
1010 }
InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> & graph_ptr)1011 void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
1012   MS_LOG(INFO) << "Start";
1013   GetProcessedStream(graph_ptr);
1014   std::vector<CNodePtr> update_cnode_list;
1015   CNodePtr cur_cnode_ptr = nullptr;
1016   CNodePtr pre_cnode_ptr = nullptr;
1017   uint32_t pre_stream_id = UINT32_MAX;
1018 
1019   auto cnode_ptr_list = graph_ptr->execution_order();
1020   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1021     cur_cnode_ptr = cnode_ptr_list[i];
1022     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1023     if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
1024       update_cnode_list.emplace_back(cur_cnode_ptr);
1025       continue;
1026     }
1027 
1028     if (IsHcom(cur_cnode_ptr)) {
1029       update_cnode_list.emplace_back(cur_cnode_ptr);
1030       continue;
1031     }
1032     uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1033     bool processed = IsProcessedStream(cur_stream_id);
1034     // 1)inner stream assign, need insert active op
1035     if (!processed) {
1036       MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id;
1037       CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
1038       // 1.set stream id
1039       AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get());
1040       // 2.set active stream ids
1041       std::vector<uint32_t> active_index_list{cur_stream_id};
1042       AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
1043       if (i > 0) {
1044         auto pre_node = AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]);
1045         if (pre_node == kLabelSwitchOpName || pre_node == kLabelGotoOpName) {
1046           update_cnode_list.insert(update_cnode_list.end() - 1, active_ptr);
1047           AnfAlgo::SetStreamId(cur_stream_id, cnode_ptr_list[i - 1].get());
1048         } else {
1049           update_cnode_list.emplace_back(active_ptr);
1050         }
1051       }
1052     }
1053     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
1054       MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
1055       update_cnode_list.emplace_back(cur_cnode_ptr);
1056     } else {
1057       update_cnode_list.emplace_back(cur_cnode_ptr);
1058     }
1059 
1060     processed_streams_.emplace(cur_stream_id);
1061     pre_stream_id = cur_stream_id;
1062     pre_cnode_ptr = cur_cnode_ptr;
1063   }
1064   graph_ptr->set_execution_order(update_cnode_list);
1065 }
1066 
InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> & graph_ptr)1067 void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
1068   auto root_graph_id = graph_ptr->graph_id();
1069   if (root_graph_id == kInvalidGraphId) {
1070     return;
1071   }
1072   std::set<uint32_t> independent_streams;
1073   for (const auto &item : independent_graph_map_) {
1074     if (item.first == root_graph_id) {
1075       independent_streams = item.second;
1076     }
1077   }
1078 
1079   // Root graph independent stream size is not more than one, no need insert active
1080   if (independent_streams.size() <= 1) {
1081     return;
1082   }
1083   std::vector<CNodePtr> update_cnode_list;
1084   auto exe_orders = graph_ptr->execution_order();
1085 
1086   // first independent is been activated, active other independent stream
1087   std::vector<uint32_t> streams;
1088   std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams));
1089   std::sort(streams.begin(), streams.end());
1090   uint32_t node_num = 0;
1091   for (size_t i = 0; i < exe_orders.size(); i++) {
1092     auto cur_cnode_ptr = exe_orders[i];
1093     update_cnode_list.emplace_back(cur_cnode_ptr);
1094     if (!AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
1095       continue;
1096     }
1097 
1098     if (AnfAlgo::GetGraphId(cur_cnode_ptr.get()) != root_graph_id) {
1099       continue;
1100     }
1101 
1102     node_num++;
1103     auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1104     auto it = std::find(streams.begin(), streams.end(), cur_stream_id);
1105     if (it == streams.end()) {
1106       MS_LOG(EXCEPTION) << "Can't find independent stream id:" << cur_stream_id;
1107     } else if (it == streams.end() - 1) {
1108       std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
1109       break;
1110     } else {
1111       if (node_num == kMaxCommonNodeNumPerStream) {
1112         CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
1113         // 1.set stream id
1114         AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
1115         // 2.set active stream ids
1116         std::vector<uint32_t> active_index_list{*(it + 1)};
1117         AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
1118         update_cnode_list.emplace_back(active_ptr);
1119         node_num = 0;
1120       }
1121     }
1122   }
1123   graph_ptr->set_execution_order(update_cnode_list);
1124 }
1125 
GetProcessedStream(const NotNull<KernelGraphPtr> & graph_ptr)1126 void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr) {
1127   // 0 stream is activated at first
1128   processed_streams_.emplace(0);
1129   auto cnode_ptr_list = graph_ptr->execution_order();
1130   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1131     auto cur_cnode_ptr = cnode_ptr_list[i];
1132     uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1133 
1134     if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
1135       if (AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, cur_cnode_ptr)) {
1136         auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
1137         processed_streams_.emplace(true_stream_id);
1138       }
1139 
1140       if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
1141         continue;
1142       }
1143       auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
1144       if (need_active) {
1145         processed_streams_.emplace(cur_stream_id);
1146       }
1147     }
1148   }
1149   for (const auto &item : processed_streams_) {
1150     MS_LOG(INFO) << "Before active:" << item << " is been processed";
1151   }
1152 }
1153 
CheckStreamSwitch(const CNodePtr & switch_ptr)1154 bool AscendStreamAssign::CheckStreamSwitch(const CNodePtr &switch_ptr) {
1155   if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
1156     return false;
1157   }
1158 
1159   auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
1160   if (!need_active) {
1161     return false;
1162   }
1163 
1164   if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, switch_ptr)) {
1165     return false;
1166   }
1167 
1168   auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
1169   if (kind == kEosStreamSwitch || kind == kGetNextStreamSwitch) {
1170     return false;
1171   }
1172 
1173   return true;
1174 }
1175 
IsProcessedStream(uint32_t stream_id)1176 bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
1177   auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id);
1178   if (it != processed_streams_.end()) {
1179     return true;
1180   }
1181   return false;
1182 }
1183 
IsAllOutGraphOut(const KernelGraphPtr & graph,const CNodePtr & cnode)1184 bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) {
1185   MS_EXCEPTION_IF_NULL(graph);
1186   MS_EXCEPTION_IF_NULL(cnode);
1187   auto cnode_out_num = AnfAlgo::GetOutputTensorNum(cnode);
1188   auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
1189   std::set<int> output_index_set;
1190   // Assign Communicate Op Memory firstly.
1191   for (const auto &node : nodes) {
1192     auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
1193     MS_EXCEPTION_IF_NULL(item_with_index.first);
1194     if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
1195       continue;
1196     }
1197     if (item_with_index.first == cnode) {
1198       output_index_set.insert(item_with_index.second);
1199     }
1200   }
1201 
1202   MS_LOG(INFO) << "Node " << cnode->fullname_with_scope() << " has " << cnode_out_num
1203                << " outputs, in graph output num:" << output_index_set.size();
1204   return cnode_out_num == output_index_set.size();
1205 }
1206 
FindGraphEnd(vector<CNodePtr>::iterator begin,vector<CNodePtr>::iterator end)1207 vector<CNodePtr>::iterator AscendStreamAssign::FindGraphEnd(vector<CNodePtr>::iterator begin,
1208                                                             vector<CNodePtr>::iterator end) {
1209   while (begin != end) {
1210     if (AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *begin)) {
1211       MS_LOG(INFO) << "FpBp end op is " << (*begin)->fullname_with_scope();
1212       return begin;
1213     }
1214     ++begin;
1215   }
1216   return end;
1217 }
1218 
1219 // section5
InsertEventForHcomParallel(const NotNull<KernelGraphPtr> & graph_ptr)1220 void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
1221   MS_LOG(INFO) << "Start";
1222   InsertEventCommonDependHcom(graph_ptr);
1223   InsertEventHcomDependCommonBak(graph_ptr);
1224   InsertEventHcomDependHcom(graph_ptr);
1225   MS_LOG(INFO) << "End";
1226 }
1227 
InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> & graph_ptr)1228 void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
1229   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1230   auto cnode_ptr_list = graph_ptr->execution_order();
1231   vector<CNodePtr> cnodes = cnode_ptr_list;
1232   uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1233   auto it = cnodes.begin();
1234   while (it != cnodes.end()) {
1235     MS_EXCEPTION_IF_NULL(*it);
1236     if (IsHcom(*it)) {
1237       auto cur_hcom_node = *it;
1238       CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
1239       it = cnodes.insert(it + 1, send_cnode_ptr);
1240 
1241       auto target = FindTargetOp(it, cnodes.end(), cur_hcom_node, true);
1242       if (target == cnodes.end()) {
1243         if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) {
1244           // if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode
1245           target = FindGraphEnd(it, cnodes.end());
1246         }
1247 
1248         if (target == cnodes.end()) {
1249           MS_EXCEPTION_IF_NULL(*(it - 1));
1250           MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
1251                           << ", can't find target for insert recv op, no insert send/recv";
1252           it = cnodes.erase(it);
1253           continue;
1254         }
1255       }
1256 
1257       // deal recv op
1258       uint32_t stream_id = AnfAlgo::GetStreamId(*target);
1259       CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id);
1260       (void)cnodes.insert(target, recv_cnode_ptr);
1261       cur_event_id = resource_manager.ApplyNewEvent();
1262     }
1263     ++it;
1264   }
1265   // one event allocated additional, should delete
1266   resource_manager.DeleteEvent();
1267   graph_ptr->set_execution_order(cnodes);
1268   MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
1269 }
1270 
1271 // after memory reuse is correct, use this function
InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> & graph_ptr)1272 void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr) {
1273   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1274   auto cnode_ptr_list = graph_ptr->execution_order();
1275   vector<CNodePtr> cnodes;
1276   CNodePtr cur_cnode_ptr = nullptr;
1277   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1278     cur_cnode_ptr = cnode_ptr_list[i];
1279     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1280     if (i == 0) {
1281       cnodes.emplace_back(cur_cnode_ptr);
1282       continue;
1283     }
1284 
1285     if (!IsHcom(cur_cnode_ptr)) {
1286       cnodes.emplace_back(cur_cnode_ptr);
1287       continue;
1288     }
1289 
1290     // get the input which located in the last exe orders
1291     vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
1292     if (inputs_cnode.empty()) {
1293       cnodes.emplace_back(cur_cnode_ptr);
1294       MS_LOG(WARNING) << "Hcom op:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " can't find inputs nodes";
1295       continue;
1296     }
1297 
1298     MS_LOG(INFO) << "Current hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
1299                  << "; inputs cnode size:" << inputs_cnode.size();
1300 
1301     for (size_t j = 0; j < inputs_cnode.size(); j++) {
1302       auto &cur_input = inputs_cnode.at(j);
1303       MS_LOG(INFO) << "The index:" << j << " input, name:" << AnfAlgo::GetCNodeName(cur_input);
1304       uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1305       auto pre_stream_id = AnfAlgo::GetStreamId(cur_input);
1306       auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id);
1307       auto it = std::find(cnodes.begin(), cnodes.end(), cur_input);
1308       if (it == cnodes.end()) {
1309         MS_LOG_EXCEPTION << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
1310                          << " can't find input node:" << AnfAlgo::GetCNodeName(cur_input);
1311       }
1312       cnodes.insert(it + 1, send);
1313       uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1314       auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
1315       cnodes.emplace_back(recv);
1316       cnodes.emplace_back(cur_cnode_ptr);
1317     }
1318   }
1319 
1320   graph_ptr->set_execution_order(cnodes);
1321   MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
1322 }
1323 
GetLastInputCnode(const NotNull<KernelGraphPtr> & graph_ptr,const CNodePtr & cur_cnode_ptr)1324 vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
1325                                                        const CNodePtr &cur_cnode_ptr) {
1326   auto group_name = GetHcomGroup(cur_cnode_ptr);
1327   auto input_cnodes = GetInputKernels(cur_cnode_ptr);
1328   if (input_cnodes.empty()) {
1329     return {};
1330   }
1331   // record max index node for each stream
1332   std::map<uint32_t, std::pair<CNodePtr, uint32_t>> result;
1333   for (size_t i = 0; i < input_cnodes.size(); i++) {
1334     auto &cur_input = input_cnodes.at(i);
1335     auto stream_id = AnfAlgo::GetStreamId(cur_input);
1336     auto cur_index = GetIndexByKey(graph_ptr, cur_input.get());
1337     if (cur_index == UINT32_MAX) {
1338       MS_LOG_EXCEPTION << "The input node:" << AnfAlgo::GetCNodeName(cur_input) << " is not found in graph";
1339     }
1340     auto it = result.find(stream_id);
1341     if (it == result.end()) {
1342       result[stream_id] = std::make_pair(cur_input, cur_index);
1343     } else {
1344       auto max_index = it->second.second;
1345       if (cur_index > max_index) {
1346         result[stream_id] = std::make_pair(cur_input, cur_index);
1347       }
1348     }
1349   }
1350 
1351   vector<CNodePtr> final_inputs;
1352   CNodePtr max_common_cnode = nullptr;
1353   for (const auto &item : result) {
1354     if (IsHcom(item.second.first)) {
1355       auto cur_group = GetHcomGroup(item.second.first);
1356       if (cur_group == group_name) {
1357         continue;
1358       } else {
1359         final_inputs.emplace_back(item.second.first);
1360       }
1361     } else {
1362       max_common_cnode = item.second.first;
1363     }
1364   }
1365 
1366   if (max_common_cnode != nullptr) {
1367     final_inputs.emplace_back(max_common_cnode);
1368   }
1369   return final_inputs;
1370 }
1371 
GetInputKernels(const CNodePtr & cnode)1372 vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &cnode) {
1373   MS_EXCEPTION_IF_NULL(cnode);
1374   vector<CNodePtr> input_cnodes;
1375   queue<CNodePtr> nop_nodes;
1376   auto inputs = cnode->inputs();
1377   for (size_t i = 1; i < inputs.size(); i++) {
1378     auto real_input = AnfAlgo::VisitKernel(inputs[i], 0);
1379     auto node = real_input.first;
1380     MS_EXCEPTION_IF_NULL(node);
1381     if (opt::IsNopNode(node)) {
1382       nop_nodes.push(node->cast<CNodePtr>());
1383       while (!nop_nodes.empty()) {
1384         auto cur_node = nop_nodes.front();
1385         nop_nodes.pop();
1386         auto new_inputs = cur_node->inputs();
1387         for (size_t j = 1; j < new_inputs.size(); j++) {
1388           auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
1389           auto new_node = new_real_input.first;
1390           MS_EXCEPTION_IF_NULL(new_node);
1391           if (opt::IsNopNode(new_node)) {
1392             nop_nodes.push(new_node->cast<CNodePtr>());
1393           } else if (new_node->isa<CNode>()) {
1394             input_cnodes.emplace_back(new_node->cast<CNodePtr>());
1395           }
1396         }
1397       }
1398     } else if (node->isa<CNode>()) {
1399       input_cnodes.emplace_back(node->cast<CNodePtr>());
1400     }
1401   }
1402   return input_cnodes;
1403 }
1404 
InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> & graph_ptr)1405 void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
1406   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1407   auto cnode_ptr_list = graph_ptr->execution_order();
1408   vector<CNodePtr> cnodes;
1409   CNodePtr cur_cnode_ptr = nullptr;
1410   uint32_t pre_stream_id = UINT32_MAX;
1411   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1412     cur_cnode_ptr = cnode_ptr_list[i];
1413     uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1414     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1415     if (i == 0) {
1416       cnodes.emplace_back(cur_cnode_ptr);
1417       pre_stream_id = cur_stream_id;
1418       continue;
1419     }
1420 
1421     if (!IsHcom(cur_cnode_ptr)) {
1422       cnodes.emplace_back(cur_cnode_ptr);
1423       pre_stream_id = cur_stream_id;
1424       continue;
1425     }
1426 
1427     if (cur_stream_id == pre_stream_id) {
1428       cnodes.emplace_back(cur_cnode_ptr);
1429       pre_stream_id = cur_stream_id;
1430       continue;
1431     }
1432 
1433     if (!IsHcom(cnode_ptr_list[i - 1])) {
1434       uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1435       auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id);
1436       cnodes.emplace_back(send);
1437       auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
1438       cnodes.emplace_back(recv);
1439       cnodes.emplace_back(cur_cnode_ptr);
1440     } else {
1441       cnodes.emplace_back(cur_cnode_ptr);
1442     }
1443     pre_stream_id = cur_stream_id;
1444   }
1445 
1446   graph_ptr->set_execution_order(cnodes);
1447   MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
1448 }
1449 
GetStreamIDHcomMap(const std::vector<CNodePtr> & cnode_ptr_list,const std::string & group,size_t graph_id)1450 std::vector<std::pair<uint32_t, vector<size_t>>> AscendStreamAssign::GetStreamIDHcomMap(
1451   const std::vector<CNodePtr> &cnode_ptr_list, const std::string &group, size_t graph_id) {
1452   std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
1453   for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
1454     auto cur_cnode = cnode_ptr_list[i];
1455     if (!IsHcom(cur_cnode)) {
1456       continue;
1457     }
1458 
1459     uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
1460     auto group_name = GetHcomGroup(cur_cnode);
1461     auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode.get());
1462     MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
1463                  << "; stream id:" << cur_stream_id;
1464     if (group_name != group || cur_graph_id != graph_id) {
1465       continue;
1466     }
1467 
1468     bool exit = false;
1469     for (auto &item : stream_indices) {
1470       if (item.first == cur_stream_id) {
1471         item.second.emplace_back(i);
1472         exit = true;
1473         break;
1474       }
1475     }
1476     if (!exit) {
1477       stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
1478     }
1479   }
1480   return stream_indices;
1481 }
1482 
InsertEventHcomDependHcomAtSameGroup(const NotNull<KernelGraphPtr> & graph_ptr,std::pair<std::string,std::map<uint32_t,std::set<uint32_t>>> group_item)1483 void AscendStreamAssign::InsertEventHcomDependHcomAtSameGroup(
1484   const NotNull<KernelGraphPtr> &graph_ptr, std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item) {
1485   for (const auto &graph_item : group_item.second) {
1486     auto stream_indices = GetStreamIDHcomMap(graph_ptr->execution_order(), group_item.first, graph_item.first);
1487     constexpr size_t kStreamMax = 2;
1488     if (stream_indices.size() < kStreamMax) {
1489       MS_LOG(INFO) << "Group:" << group_item.first << ", Graph: " << graph_item.first
1490                    << " different stream hcom size is less than 2, no need insert event between them";
1491       continue;
1492     }
1493     InsertEventBetweenHcom(graph_ptr, stream_indices);
1494   }
1495 }
1496 
InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> & graph_ptr)1497 void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
1498   if (group_hcom_graph_map_.empty()) {
1499     return;
1500   }
1501   for (const auto &group_item : group_hcom_graph_map_) {
1502     InsertEventHcomDependHcomAtSameGroup(graph_ptr, group_item);
1503   }
1504 }
1505 
InsertEventBetweenHcom(const NotNull<KernelGraphPtr> & graph_ptr,const std::vector<std::pair<uint32_t,vector<size_t>>> & hcom_index)1506 void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
1507                                                 const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index) {
1508   vector<CNodePtr> orders;
1509   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1510   auto cnode_ptr_list = graph_ptr->execution_order();
1511   uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1512   if (hcom_index.empty()) {
1513     MS_LOG(EXCEPTION) << "Hcom stream number is empty";
1514   }
1515   size_t first_stream_last_index = hcom_index[0].second.back();
1516   size_t last_stream_first_index = hcom_index.back().second.front();
1517   MS_LOG(INFO) << "First stream last index:" << first_stream_last_index
1518                << "; last stream first index:" << last_stream_first_index;
1519   std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders));
1520   for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) {
1521     auto cur_cnode = cnode_ptr_list[i];
1522     if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) {
1523       orders.emplace_back(cur_cnode);
1524       continue;
1525     }
1526     auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
1527     if (i == first_stream_last_index) {
1528       // first fusion hcom
1529       orders.emplace_back(cur_cnode);
1530       auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1531       orders.emplace_back(send);
1532     } else if (i == last_stream_first_index) {
1533       // last fusion hcom
1534       auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1535       orders.emplace_back(recv);
1536       orders.emplace_back(cur_cnode);
1537     } else {
1538       size_t cur_stream_hcom_size = UINT32_MAX;
1539       size_t first_index = UINT32_MAX;
1540       size_t last_index = UINT32_MAX;
1541       for (const auto &item : hcom_index) {
1542         if (item.first == cur_hcom_stream_id) {
1543           cur_stream_hcom_size = item.second.size();
1544           first_index = item.second.front();
1545           last_index = item.second.back();
1546         }
1547       }
1548 
1549       if (cur_stream_hcom_size == 1) {
1550         auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1551         orders.emplace_back(recv);
1552         cur_event_id = resource_manager.ApplyNewEvent();
1553         orders.emplace_back(cur_cnode);
1554         auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1555         orders.emplace_back(send);
1556       } else {
1557         // current stream, first hcom:add recv op
1558         if (i == first_index) {
1559           auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1560           orders.emplace_back(recv);
1561           cur_event_id = resource_manager.ApplyNewEvent();
1562           orders.emplace_back(cur_cnode);
1563         } else if (i == last_index) {
1564           // current stream, last hcom:add send op
1565           orders.emplace_back(cur_cnode);
1566           auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1567           orders.emplace_back(send);
1568         } else {
1569           // current stream, not first and last op
1570           orders.emplace_back(cur_cnode);
1571         }
1572       }
1573     }
1574   }
1575   std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
1576   graph_ptr->set_execution_order(orders);
1577 }
1578 
IsSatisfiedHcom(const std::vector<std::pair<uint32_t,vector<size_t>>> & hcom_index,const CNodePtr & node_ptr,size_t index)1579 bool AscendStreamAssign::IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index,
1580                                          const CNodePtr &node_ptr, size_t index) {
1581   MS_EXCEPTION_IF_NULL(node_ptr);
1582   auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr);
1583   for (const auto &item : hcom_index) {
1584     if (item.first == cur_hcom_stream_id) {
1585       auto it = std::find(item.second.begin(), item.second.end(), index);
1586       if (it != item.second.end()) {
1587         return true;
1588       }
1589     }
1590   }
1591   return false;
1592 }
1593 
1594 // section6
InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> & graph_ptr)1595 void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
1596   MS_LOG(INFO) << "Start";
1597   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1598   auto cnode_ptr_list = graph_ptr->execution_order();
1599   vector<CNodePtr> cnodes = cnode_ptr_list;
1600   uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1601   std::map<CNodePtr, CNodePtr> cnode_send_map;
1602   std::map<CNodePtr, std::vector<CNodePtr>> cnode_recv_map;
1603   auto it = cnodes.begin();
1604   while (it != cnodes.end()) {
1605     MS_EXCEPTION_IF_NULL(*it);
1606     if (AnfAlgo::IsIndependentNode(*it)) {
1607       MS_LOG(DEBUG) << "Deal independent op[" << (*it)->DebugString() << "]";
1608       CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
1609 
1610       auto target = FindTargetOp(it + 1, cnodes.end(), *it, false);
1611       if (target == cnodes.end()) {
1612         MS_LOG(DEBUG) << "Independent node[" << (*it)->fullname_with_scope()
1613                       << "] can't find target for insert recv op, no insert send/recv";
1614         it++;
1615         continue;
1616       }
1617 
1618       // deal recv op
1619       uint32_t stream_id = AnfAlgo::GetStreamId(*target);
1620       CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id);
1621 
1622       cnode_send_map.insert(std::make_pair(*it, send_cnode_ptr));
1623       auto result = cnode_recv_map.find(*target);
1624       if (result == cnode_recv_map.end()) {
1625         std::vector<CNodePtr> recv_cnodes = {recv_cnode_ptr};
1626         cnode_recv_map.insert(std::make_pair(*target, recv_cnodes));
1627       } else {
1628         result->second.push_back(recv_cnode_ptr);
1629       }
1630       cur_event_id = resource_manager.ApplyNewEvent();
1631     }
1632     ++it;
1633   }
1634   // one event allocated additional, should delete
1635   resource_manager.DeleteEvent();
1636 
1637   std::vector<CNodePtr> new_cnodes;
1638   for (const auto &cnode : cnodes) {
1639     auto result_recv = cnode_recv_map.find(cnode);
1640     if (result_recv != cnode_recv_map.end()) {
1641       for (const auto &recv : result_recv->second) {
1642         new_cnodes.push_back(recv);
1643       }
1644     }
1645     new_cnodes.push_back(cnode);
1646     auto result_send = cnode_send_map.find(cnode);
1647     if (result_send != cnode_send_map.end()) {
1648       new_cnodes.push_back(result_send->second);
1649     }
1650   }
1651 
1652   graph_ptr->set_execution_order(new_cnodes);
1653   MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num();
1654   MS_LOG(INFO) << "End";
1655 }
1656 
GetIndependentMaxTarget(const NotNull<KernelGraphPtr> & graph_ptr)1657 void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
1658   MS_LOG(INFO) << "Start";
1659   auto cnode_ptr_list = graph_ptr->execution_order();
1660   for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
1661     auto cur_node = cnode_ptr_list[i];
1662     auto key = cur_node.get();
1663     if (!AnfAlgo::IsIndependentNode(cur_node)) {
1664       continue;
1665     }
1666 
1667     bool flag = false;
1668     for (size_t j = cnode_ptr_list.size() - 1; j > i; j--) {
1669       auto target_node = cnode_ptr_list[j];
1670       auto inputs = target_node->inputs();
1671       for (size_t m = 1; m < inputs.size(); m++) {
1672         auto input = inputs[m];
1673         MS_EXCEPTION_IF_NULL(input);
1674         if (opt::IsNopNode(input)) {
1675           auto cnode = input->cast<CNodePtr>();
1676           auto new_inputs = cnode->inputs();
1677           for (size_t k = 1; k < new_inputs.size(); k++) {
1678             auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0);
1679             if (key == new_real_input.first.get()) {
1680               MS_LOG(DEBUG) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node);
1681               independent_targets_.emplace(target_node.get());
1682               flag = true;
1683               break;
1684             }
1685           }
1686         } else {
1687           auto real_input = AnfAlgo::VisitKernel(input, 0);
1688           if (key == real_input.first.get()) {
1689             MS_LOG(DEBUG) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node);
1690             independent_targets_.emplace(target_node.get());
1691             flag = true;
1692           }
1693         }
1694         if (flag) {
1695           break;
1696         }
1697       }
1698     }
1699   }
1700 
1701   MS_LOG(INFO) << "End";
1702 }
1703 
GetIndexByKey(const NotNull<KernelGraphPtr> & graph_ptr,const CNodeKey & key)1704 uint32_t AscendStreamAssign::GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key) {
1705   auto &exe_orders = graph_ptr->execution_order();
1706   for (uint32_t i = 0; i < exe_orders.size(); i++) {
1707     CNodeKey node_key = exe_orders[i].get();
1708     if (node_key == key) {
1709       return i;
1710     }
1711   }
1712 
1713   return UINT32_MAX;
1714 }
1715 
GetMaxIndexTarget(const NotNull<KernelGraphPtr> & graph_ptr)1716 uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
1717   if (independent_targets_.empty()) {
1718     return UINT32_MAX;
1719   }
1720 
1721   std::set<uint32_t> indices;
1722   for (const auto &key : independent_targets_) {
1723     auto index = GetIndexByKey(graph_ptr, key);
1724     if (index == UINT32_MAX) {
1725       MS_LOG(EXCEPTION) << "graph has no correspond key";
1726     }
1727     indices.emplace(index);
1728   }
1729 
1730   return *(std::max_element(indices.begin(), indices.end()));
1731 }
1732 
GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> & graph_ptr)1733 uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
1734   auto &exe_orders = graph_ptr->execution_order();
1735   for (const auto &item : exe_orders) {
1736     if (AnfAlgo::GetCNodeName(item) == kStreamSwitchOpName) {
1737       if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, item)) {
1738         continue;
1739       }
1740       auto kind = AnfAlgo::GetNodeAttr<uint32_t>(item, kAttrStreamSwitchKind);
1741       if (kind == kIndependentStreamSwitch) {
1742         return AnfAlgo::GetStreamId(item);
1743       }
1744     }
1745   }
1746   return kInvalidStreamId;
1747 }
1748 
InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> & graph_ptr)1749 void AscendStreamAssign::InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
1750   if (independent_targets_.empty()) {
1751     return;
1752   }
1753 
1754   uint32_t independent_switch_stream = GetIndependentStreamSwitchStreamId(graph_ptr);
1755   if (independent_switch_stream == kInvalidStreamId) {
1756     return;
1757   }
1758 
1759   auto max_index = GetMaxIndexTarget(graph_ptr);
1760   auto &exe_orders = graph_ptr->execution_order();
1761   if (max_index >= exe_orders.size()) {
1762     MS_LOG(EXCEPTION) << "Max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size();
1763   }
1764 
1765   auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]);
1766 
1767   CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
1768   // 1.set stream id
1769   AnfAlgo::SetStreamId(max_node_stream, active_ptr.get());
1770   // 2.set active stream ids
1771   std::vector<uint32_t> active_index_list{independent_switch_stream};
1772   AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
1773 
1774   std::vector<CNodePtr> update_cnode_list;
1775   std::copy(exe_orders.begin(), exe_orders.begin() + max_index + 1, std::back_inserter(update_cnode_list));
1776   update_cnode_list.emplace_back(active_ptr);
1777   std::copy(exe_orders.begin() + max_index + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
1778   graph_ptr->set_execution_order(update_cnode_list);
1779 }
1780 
1781 // section7
GetNeedActiveStreams(const NotNull<KernelGraphPtr> & graph_ptr)1782 void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr) {
1783   CNodePtr cur_cnode_ptr = nullptr;
1784   auto cnode_ptr_list = graph_ptr->execution_order();
1785 
1786   // 1)stream witch kStreamNeedActivedFirst attr should be activated;
1787   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1788     cur_cnode_ptr = cnode_ptr_list[i];
1789     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1790     if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
1791       continue;
1792     }
1793 
1794     auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
1795     if (need_active) {
1796       auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1797       MS_LOG(INFO) << "Stream id:" << stream_id << " is need activated at first";
1798       need_first_active_streams_.push_back(stream_id);
1799     }
1800   }
1801 
1802   // 2)independent stream:if has not been activate, push to need active vector
1803   auto root_graph_id = graph_ptr->graph_id();
1804   if (!independent_stream_activated_) {
1805     auto it = independent_graph_map_.find(root_graph_id);
1806     if (it != independent_graph_map_.end()) {
1807       need_first_active_streams_.push_back(*(it->second.begin()));
1808     }
1809   }
1810 
1811   // 3)hcom stream:if has not been activate, push to need active vector
1812   if (!hcom_stream_activated_) {
1813     for (const auto &item : group_hcom_graph_map_) {
1814       auto &hcom_graph_map = item.second;
1815       auto it = hcom_graph_map.find(root_graph_id);
1816       if (it != hcom_graph_map.end()) {
1817         std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_));
1818       }
1819     }
1820   }
1821 
1822   // 4)first stream 0 should be activated first;
1823   auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), 0);
1824   if (it == need_first_active_streams_.end()) {
1825     need_first_active_streams_.emplace_back(0);
1826   }
1827   MS_LOG(INFO) << "Finally, need active first stream include:";
1828   for (const auto &item : need_first_active_streams_) {
1829     MS_LOG(INFO) << "stream id:" << item;
1830   }
1831 }
1832 
1833 // section8
CheckResourceAssign(const NotNull<KernelGraphPtr> & graph_ptr)1834 void AscendStreamAssign::CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
1835   CheckStreamAssign(graph_ptr);
1836   CheckEventAssign(graph_ptr);
1837 }
1838 
CheckStreamAssign(const NotNull<KernelGraphPtr> & graph_ptr)1839 void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
1840   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1841   std::set<uint32_t> streams;
1842   uint32_t max_stream = 0;
1843   uint32_t min_stream = kInvalidStreamId;
1844   auto cnode_ptr_list = graph_ptr->execution_order();
1845   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1846     CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
1847     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1848     uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1849     if (stream_id == kInvalidStreamId) {
1850       MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream";
1851     }
1852 
1853     (void)streams.emplace(stream_id);
1854     if (stream_id > max_stream) {
1855       max_stream = stream_id;
1856     }
1857     if (stream_id < min_stream) {
1858       min_stream = stream_id;
1859     }
1860   }
1861 
1862   // check stream assign
1863   if (!streams.empty()) {
1864     if (min_stream != 0) {
1865       MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream;
1866     }
1867     uint32_t assigned_stream_num = resource_manager.get_cur_stream_num();
1868     if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) {
1869       MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream
1870                         << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size();
1871     }
1872   }
1873 }
1874 
CheckEventAssign(const NotNull<KernelGraphPtr> & graph_ptr)1875 void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
1876   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1877   std::map<uint32_t, std::vector<CNodePtr>> event_map;
1878   uint32_t max_event_id = 0;
1879   uint32_t min_event_id = kInvalidEventId;
1880   auto cnode_ptr_list = graph_ptr->execution_order();
1881   for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1882     CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
1883     MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1884     auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr);
1885     if (name == kSendOpName || name == kRecvOpName) {
1886       uint32_t event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId);
1887       if (event_id > max_event_id) {
1888         max_event_id = event_id;
1889       }
1890 
1891       if (event_id < min_event_id) {
1892         min_event_id = event_id;
1893       }
1894       auto it = event_map.find(event_id);
1895       if (it == event_map.end()) {
1896         event_map[event_id] = {cur_cnode_ptr};
1897       } else {
1898         event_map[event_id].emplace_back(cur_cnode_ptr);
1899       }
1900     }
1901   }
1902   // check event assign
1903   if (!event_map.empty()) {
1904     if (min_event_id != 0) {
1905       MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id;
1906     }
1907     uint32_t assigned_event_num = resource_manager.get_cur_event_num();
1908     if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) {
1909       MS_LOG(EXCEPTION) << "Event should be consecutive, however, assigned event num is: " << assigned_event_num
1910                         << ", max event id:" << max_event_id << ", event map is:" << event_map;
1911     }
1912     for (const auto &item : event_map) {
1913       if (item.second.size() != 2) {
1914         MS_LOG(EXCEPTION) << "Send/recv should be in pair and share one event id, invalid event id is:" << item.first
1915                           << ", event size is:" << item.second.size();
1916       }
1917       auto first_name = AnfAlgo::GetCNodeName(item.second[0]);
1918       auto second_name = AnfAlgo::GetCNodeName(item.second[1]);
1919       if (!(first_name == kSendOpName && second_name == kRecvOpName)) {
1920         MS_LOG(EXCEPTION) << "Send should be before recv, invalid event id is:" << item.first;
1921       }
1922     }
1923   }
1924 }
1925 
1926 // section9
CreateSendApplyKernel(const NotNull<KernelGraphPtr> & graph_ptr,uint32_t event_id,uint32_t stream_id)1927 CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id,
1928                                                    uint32_t stream_id) {
1929   auto send_op = std::make_shared<Primitive>(kSendOpName);
1930   MS_EXCEPTION_IF_NULL(send_op);
1931   auto send_apply = std::make_shared<ValueNode>(send_op);
1932   MS_EXCEPTION_IF_NULL(send_apply);
1933   std::vector<AnfNodePtr> send_input_list = {send_apply};
1934   CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list);
1935   MS_EXCEPTION_IF_NULL(send_node_ptr);
1936   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
1937   selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
1938   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
1939   AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
1940   auto abstract_none = std::make_shared<abstract::AbstractNone>();
1941   MS_EXCEPTION_IF_NULL(abstract_none);
1942   send_node_ptr->set_abstract(abstract_none);
1943   AnfAlgo::SetStreamId(stream_id, send_node_ptr.get());
1944   return send_node_ptr;
1945 }
1946 
CreateRecvApplyKernel(const NotNull<KernelGraphPtr> & graph_ptr,uint32_t event_id,uint32_t stream_id)1947 CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id,
1948                                                    uint32_t stream_id) {
1949   auto recv_op = std::make_shared<Primitive>(kRecvOpName);
1950   MS_EXCEPTION_IF_NULL(recv_op);
1951   auto recv_apply = std::make_shared<ValueNode>(recv_op);
1952   MS_EXCEPTION_IF_NULL(recv_apply);
1953   std::vector<AnfNodePtr> recv_input_list = {recv_apply};
1954   CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list);
1955   MS_EXCEPTION_IF_NULL(recv_node_ptr);
1956   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
1957   selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
1958   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
1959   AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
1960   AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get());
1961   auto abstract_none = std::make_shared<abstract::AbstractNone>();
1962   MS_EXCEPTION_IF_NULL(abstract_none);
1963   recv_node_ptr->set_abstract(abstract_none);
1964   return recv_node_ptr;
1965 }
1966 
IsNopNodeTarget(const AnfNodePtr & nop_node,const CNodePtr & target_node,const CNodePtr & cur_node,bool exclude_hcom)1967 bool AscendStreamAssign::IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node,
1968                                          const CNodePtr &cur_node, bool exclude_hcom) {
1969   MS_EXCEPTION_IF_NULL(nop_node);
1970   auto cnode = nop_node->cast<CNodePtr>();
1971   auto new_inputs = cnode->inputs();
1972   for (size_t i = 1; i < new_inputs.size(); i++) {
1973     if (opt::IsNopNode(new_inputs[i])) {
1974       if (IsNopNodeTarget(new_inputs[i], target_node, cur_node, exclude_hcom)) {
1975         return true;
1976       }
1977     } else {
1978       auto new_real_input = AnfAlgo::VisitKernel(new_inputs[i], 0);
1979       if (target_node == new_real_input.first) {
1980         if (!(exclude_hcom && IsHcom(cur_node))) {
1981           return true;
1982         }
1983       }
1984     }
1985   }
1986   return false;
1987 }
1988 
FindTargetOp(vector<CNodePtr>::iterator begin,vector<CNodePtr>::iterator end,const CNodePtr & node,bool exclude_hcom)1989 vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin,
1990                                                             vector<CNodePtr>::iterator end, const CNodePtr &node,
1991                                                             bool exclude_hcom) {
1992   while (begin != end) {
1993     auto inputs = (*begin)->inputs();
1994     for (size_t i = 1; i < inputs.size(); i++) {
1995       auto input = inputs[i];
1996       MS_EXCEPTION_IF_NULL(input);
1997       if (opt::IsNopNode(input)) {
1998         if (IsNopNodeTarget(input, node, *begin, exclude_hcom)) {
1999           return begin;
2000         }
2001       } else {
2002         auto real_input = AnfAlgo::VisitKernel(input, 0);
2003         if (node == real_input.first) {
2004           if (!(exclude_hcom && IsHcom(*begin))) {
2005             MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
2006             return begin;
2007           }
2008         }
2009       }
2010     }
2011     ++begin;
2012   }
2013   return end;
2014 }
2015 
IsTaskSink()2016 bool AscendStreamAssign::IsTaskSink() {
2017   auto ms_context = MsContext::GetInstance();
2018   MS_EXCEPTION_IF_NULL(ms_context);
2019   if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
2020     MS_LOG(INFO) << "Task sink mode is not enable";
2021     return false;
2022   } else {
2023     MS_LOG(INFO) << "Task sink mode is enable";
2024     return true;
2025   }
2026 }
2027 
GetWaitStreams(vector<uint32_t> * wait_active_stream_list)2028 void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
2029   MS_EXCEPTION_IF_NULL(wait_active_stream_list);
2030   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
2031   uint32_t total_stream_num = resource_manager.get_cur_stream_num();
2032   if (total_stream_num == 0) {
2033     MS_LOG(INFO) << "The total_common_stream_num is zero";
2034     return;
2035   }
2036 
2037   // common stream:active first common stream
2038   for (uint32_t i = 0; i < total_stream_num; i++) {
2039     auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i);
2040     if (it == need_first_active_streams_.end()) {
2041       MS_LOG(INFO) << "Wait common stream id = " << i;
2042       wait_active_stream_list->push_back(i);
2043     }
2044   }
2045 }
2046 
IsHcom(const CNodePtr & apply_kernel)2047 bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) {
2048   MS_EXCEPTION_IF_NULL(apply_kernel);
2049   return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL;
2050 }
2051 
GetHcomStreams(std::vector<uint32_t> * streams)2052 void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
2053   MS_EXCEPTION_IF_NULL(streams);
2054   for (const auto &item : hcom_stream_map_) {
2055     streams->emplace_back(item.first);
2056   }
2057 }
2058 
Reset()2059 void AscendStreamAssign::Reset() {
2060   independent_stream_activated_ = false;
2061   hcom_stream_activated_ = false;
2062   loop_sink_ = false;
2063   independent_stream_map_.clear();
2064   hcom_stream_map_.clear();
2065   common_stream_map_.clear();
2066   processed_streams_.clear();
2067   need_first_active_streams_.clear();
2068   stream_groups_.clear();
2069   stream_relations_.clear();
2070   event_map_.clear();
2071   independent_targets_.clear();
2072   independent_graph_map_.clear();
2073   group_hcom_graph_map_.clear();
2074   middle_active_streams_.clear();
2075 }
2076 
2077 // section 10
IsVecExist(const std::vector<uint32_t> & group)2078 bool AscendStreamAssign::IsVecExist(const std::vector<uint32_t> &group) {
2079   auto group_size = group.size();
2080   if (group_size == 0) {
2081     return false;
2082   }
2083   for (const auto &item : stream_groups_) {
2084     if (item.size() < group.size()) {
2085       continue;
2086     }
2087 
2088     bool flag = true;
2089     for (size_t i = 0; i < group_size; i++) {
2090       if (item[i] != group.at(i)) {
2091         flag = false;
2092         break;
2093       }
2094     }
2095 
2096     if (flag) {
2097       return true;
2098     } else {
2099       continue;
2100     }
2101   }
2102 
2103   return false;
2104 }
2105 
DFS(uint32_t start,std::vector<uint32_t> * group)2106 void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) {
2107   MS_EXCEPTION_IF_NULL(group);
2108   auto it = stream_relations_.find(start);
2109   if (it == stream_relations_.end()) {
2110     if (!IsVecExist(*group)) {
2111       stream_groups_.emplace_back(*group);
2112     } else {
2113       MS_LOG(WARNING) << "DFS find same stream group, Not expected";
2114     }
2115     return;
2116   }
2117 
2118   vector<uint32_t> active_streams = stream_relations_[start];
2119 
2120   for (const auto &item : active_streams) {
2121     group->emplace_back(item);
2122     DFS(item, group);
2123     group->pop_back();
2124   }
2125 }
2126 
GetStreamRelations()2127 void AscendStreamAssign::GetStreamRelations() {
2128   auto starts = middle_active_streams_;
2129   for (const auto &stream : need_first_active_streams_) {
2130     starts.emplace(stream);
2131   }
2132 
2133   for (const auto &start : starts) {
2134     vector<uint32_t> group{start};
2135     DFS(start, &group);
2136   }
2137 }
2138 
FindStreamRelations(const NotNull<KernelGraphPtr> & graph_ptr)2139 void AscendStreamAssign::FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
2140   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
2141   auto stream_num = resource_manager.get_cur_stream_num();
2142   if (stream_num <= 1) {
2143     return;
2144   }
2145 
2146   auto exe_orders = graph_ptr->execution_order();
2147   for (size_t i = 0; i < exe_orders.size(); i++) {
2148     auto cur_cnode = exe_orders[i];
2149     auto name = AnfAlgo::GetCNodeName(cur_cnode);
2150     if (name != kStreamSwitchOpName && name != kStreamActiveOpName) {
2151       continue;
2152     }
2153 
2154     // support:streamswitch is begin of the stream
2155     if (name == kStreamSwitchOpName) {
2156       GetStreamSwitchStreamRelation(cur_cnode);
2157     }
2158 
2159     if (name == kStreamActiveOpName) {
2160       GetStreamActiveStreamRelation(graph_ptr, i);
2161     }
2162   }
2163 }
2164 
GetStreamSwitchStreamRelation(const CNodePtr & node_ptr)2165 void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) {
2166   MS_EXCEPTION_IF_NULL(node_ptr);
2167   auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr);
2168   auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(node_ptr, kAttrTrueBranchStream);
2169   if (true_stream_id <= cur_stream_id) {
2170     MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id
2171                   << " is greater than true branch stream id:" << true_stream_id;
2172   }
2173   auto it = stream_relations_.find(cur_stream_id);
2174   if (it == stream_relations_.end()) {
2175     stream_relations_[cur_stream_id] = {true_stream_id};
2176   } else {
2177     auto iter =
2178       std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id);
2179     if (iter == stream_relations_[cur_stream_id].end()) {
2180       stream_relations_[cur_stream_id].emplace_back(true_stream_id);
2181     }
2182   }
2183 }
2184 
GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> & graph_ptr,size_t index)2185 void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
2186   StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index);
2187   if (kind == kInvalid) {
2188     MS_LOG(INFO) << "Invalid streamActive kind";
2189     return;
2190   }
2191 
2192   auto orders = graph_ptr->execution_order();
2193   if (index >= orders.size()) {
2194     MS_LOG(EXCEPTION) << "Invalid index.";
2195   }
2196   auto cur_cnode = orders[index];
2197   auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
2198   auto active_list = AnfAlgo::GetNodeAttr<vector<uint32_t>>(cur_cnode, kAttrActiveStreamList);
2199   if (kind == kHead) {
2200     uint32_t active_current_stream_id = GetStreamByActivedStream(cur_stream_id);
2201     if (active_current_stream_id == kInvalidStreamId) {
2202       MS_LOG(EXCEPTION) << "No stream to active streamactive stream: " << cur_stream_id;
2203     }
2204 
2205     for (const auto &item : active_list) {
2206       if (item <= active_current_stream_id) {
2207         MS_LOG(WARNING) << "Activated stream is less than activing stream";
2208         continue;
2209       }
2210       auto it = std::find(stream_relations_[active_current_stream_id].begin(),
2211                           stream_relations_[active_current_stream_id].end(), item);
2212       if (it == stream_relations_[active_current_stream_id].end()) {
2213         stream_relations_[active_current_stream_id].emplace_back(item);
2214       }
2215     }
2216   }
2217 
2218   if (kind == kMiddle) {
2219     for (const auto &stream : active_list) {
2220       if (stream <= cur_stream_id) {
2221         MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal";
2222       } else {
2223         MS_LOG(INFO) << "MIDDLE StreamActive :" << cur_stream_id << ", active target stream:" << stream;
2224         middle_active_streams_.emplace(stream);
2225       }
2226     }
2227   }
2228 
2229   if (kind == kTail) {
2230     auto it = stream_relations_.find(cur_stream_id);
2231     if (it == stream_relations_.end()) {
2232       stream_relations_[cur_stream_id] = active_list;
2233     } else {
2234       for (const auto &stream : active_list) {
2235         if (stream <= cur_stream_id) {
2236           MS_LOG(WARNING) << "Activated stream is less than activing stream";
2237           continue;
2238         }
2239         auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
2240         if (iter == stream_relations_[cur_stream_id].end()) {
2241           stream_relations_[cur_stream_id].emplace_back(stream);
2242         }
2243       }
2244     }
2245   }
2246 }
2247 
GetStreamActiveKind(const NotNull<KernelGraphPtr> & graph_ptr,size_t index)2248 StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
2249   auto exe_orders = graph_ptr->execution_order();
2250   if (index >= exe_orders.size()) {
2251     MS_LOG(EXCEPTION) << "Invalid op index:" << index;
2252   }
2253 
2254   auto cur_cnode = exe_orders[index];
2255   auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
2256   if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) {
2257     MS_LOG(EXCEPTION) << "Current node name [" << AnfAlgo::GetCNodeName(cur_cnode) << "] is not StreamActive.";
2258   }
2259 
2260   if (index == 0) {
2261     return kInvalid;
2262   }
2263 
2264   if (index == exe_orders.size() - 1) {
2265     return kInvalid;
2266   }
2267 
2268   uint32_t pre_stream_id = UINT32_MAX;
2269   uint32_t next_stream_id = UINT32_MAX;
2270   int32_t start = SizeToInt(index) - 1;
2271   for (int32_t i = start; i >= 0; i--) {
2272     auto cnode = exe_orders[IntToSize(i)];
2273     auto name = AnfAlgo::GetCNodeName(cnode);
2274     if (name == kSendOpName || name == kRecvOpName) {
2275       continue;
2276     }
2277     auto stream = AnfAlgo::GetStreamId(cnode);
2278     auto it = hcom_stream_map_.find(stream);
2279     if (it != hcom_stream_map_.end()) {
2280       continue;
2281     }
2282 
2283     it = independent_stream_map_.find(stream);
2284     if (it != independent_stream_map_.end()) {
2285       continue;
2286     }
2287 
2288     pre_stream_id = stream;
2289     break;
2290   }
2291 
2292   for (size_t i = index + 1; i < exe_orders.size(); i++) {
2293     auto cnode = exe_orders[i];
2294     if (AnfAlgo::GetCNodeName(cnode) == kSendOpName || AnfAlgo::GetCNodeName(cnode) == kRecvOpName) {
2295       continue;
2296     }
2297 
2298     auto stream = AnfAlgo::GetStreamId(cnode);
2299     auto it = hcom_stream_map_.find(stream);
2300     if (it != hcom_stream_map_.end()) {
2301       continue;
2302     }
2303 
2304     it = independent_stream_map_.find(stream);
2305     if (it != independent_stream_map_.end()) {
2306       continue;
2307     }
2308 
2309     next_stream_id = stream;
2310     break;
2311   }
2312 
2313   return GetStreamKind(cur_stream_id, pre_stream_id, next_stream_id);
2314 }
2315 
GetStreamByActivedStream(uint32_t actived_stream_id)2316 uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
2317   if (stream_relations_.empty()) {
2318     return kInvalidStreamId;
2319   }
2320 
2321   for (const auto &item : stream_relations_) {
2322     auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id);
2323     if (it != item.second.end()) {
2324       return item.first;
2325     }
2326   }
2327 
2328   return kInvalidStreamId;
2329 }
2330 
PrintStreamRelations()2331 void AscendStreamAssign::PrintStreamRelations() {
2332   MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size();
2333   for (const auto &item : stream_relations_) {
2334     MS_LOG(INFO) << "Stream:" << item.first;
2335     for (const auto &stream : item.second) {
2336       MS_LOG(INFO) << "--activated stream id:" << stream;
2337     }
2338   }
2339 }
2340 
PrintStreamGroups()2341 void AscendStreamAssign::PrintStreamGroups() {
2342   MS_LOG(INFO) << "Stream group size:" << stream_groups_.size();
2343   for (const auto &item : stream_groups_) {
2344     MS_LOG(INFO) << "Group:";
2345     for (const auto &stream : item) {
2346       MS_LOG(INFO) << "Stream id:" << stream;
2347     }
2348   }
2349 }
2350 
2351 // section 11
IsSatisfiedEvent(uint32_t send_stream_id,uint32_t recv_stream_id) const2352 bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const {
2353   size_t send_group = 0;
2354   size_t recv_group = 0;
2355   bool send_flag = true;
2356   bool recv_flag = true;
2357   for (size_t i = 0; i < stream_groups_.size(); i++) {
2358     auto group = stream_groups_[i];
2359     if (send_flag) {
2360       auto it = std::find(group.begin(), group.end(), send_stream_id);
2361       if (it != group.end()) {
2362         send_group = i;
2363         send_flag = false;
2364       }
2365     }
2366 
2367     if (recv_flag) {
2368       auto it = std::find(group.begin(), group.end(), recv_stream_id);
2369       if (it != group.end()) {
2370         recv_group = i;
2371         recv_flag = false;
2372       }
2373     }
2374   }
2375 
2376   if (!(send_flag || recv_flag)) {
2377     return (send_group != recv_group);
2378   }
2379 
2380   return false;
2381 }
2382 
FindEventRelations(const NotNull<KernelGraphPtr> & graph_ptr)2383 void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
2384   AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
2385   auto event_nums = resource_manager.get_cur_event_num();
2386   if (event_nums == 0) {
2387     return;
2388   }
2389   auto exe_orders = graph_ptr->execution_order();
2390   // find all event info
2391   for (size_t i = 0; i < exe_orders.size(); i++) {
2392     auto cur_cnode = exe_orders[i];
2393     auto name = AnfAlgo::GetCNodeName(cur_cnode);
2394     if (name == kSendOpName) {
2395       event_map_[cur_cnode] = {};
2396     }
2397 
2398     if (name == kRecvOpName) {
2399       auto recv_event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode, kAttrEventId);
2400       for (auto &item : event_map_) {
2401         auto send_event_id = AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
2402         if (recv_event_id == send_event_id) {
2403           item.second = cur_cnode;
2404           break;
2405         }
2406       }
2407     }
2408   }
2409 
2410   // delete useless event info
2411   auto begin = event_map_.begin();
2412   while (begin != event_map_.end()) {
2413     auto send_stream_id = AnfAlgo::GetStreamId(begin->first);
2414     auto recv_stream_id = AnfAlgo::GetStreamId(begin->second);
2415     bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id);
2416     if (!flag) {
2417       begin = event_map_.erase(begin);
2418     } else {
2419       ++begin;
2420     }
2421   }
2422 
2423   MS_LOG(INFO) << "Satisfied event info";
2424   for (const auto &item : event_map_) {
2425     MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
2426   }
2427 }
2428 
2429 // section12
AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> & graph_ptr)2430 void AscendStreamAssign::AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr) {
2431   // Eg:[atomic, recv, memcpy] should be [recv, atomic, memcpy]
2432   std::vector<CNodePtr> update_orders;
2433   auto &exe_orders = graph_ptr->execution_order();
2434   size_t i = 0;
2435   while (i < exe_orders.size()) {
2436     auto cur_cnode = exe_orders.at(i);
2437     if (AnfAlgo::GetCNodeName(cur_cnode) != kAtomicAddrCleanOpName) {
2438       update_orders.emplace_back(cur_cnode);
2439       i++;
2440       continue;
2441     }
2442     while (i < exe_orders.size() - 1) {
2443       i++;
2444       auto next_cnode = exe_orders.at(i);
2445       auto next_cnode_name = AnfAlgo::GetCNodeName(next_cnode);
2446       if (next_cnode_name == kSendOpName || next_cnode_name == kRecvOpName) {
2447         update_orders.emplace_back(next_cnode);
2448       } else {
2449         update_orders.emplace_back(cur_cnode);
2450         break;
2451       }
2452     }
2453   }
2454   graph_ptr->set_execution_order(update_orders);
2455 }
2456 }  // namespace ascend
2457 }  // namespace device
2458 }  // namespace mindspore
2459