1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/ascend/ascend_stream_assign.h"
18
19 #include <algorithm>
20 #include <utility>
21
22 #include "ir/manager.h"
23 #include "utils/ms_context.h"
24 #include "utils/ms_utils.h"
25 #include "frontend/parallel/context.h"
26 #include "frontend/parallel/device_manager.h"
27 #include "backend/session/anf_runtime_algorithm.h"
28 #include "runtime/device/kernel_adjust.h"
29 #include "backend/optimizer/common/helper.h"
30 #include "backend/kernel_compiler/oplib/oplib.h"
31 #include "utils/utils.h"
32
33 #ifdef ENABLE_DUMP_IR
34 #include "debug/rdr/running_data_recorder.h"
35 #endif
36
37 namespace mindspore {
38 namespace device {
39 namespace ascend {
40 namespace {
// Number of devices per server; rank ids are bucketed by this value to decide whether a group stays on one server.
constexpr uint32_t kDeviceNumOfServer{8};
// Device count from which a world-sized hcom group is charged kTaskNumPerWorldHcomNode tasks.
constexpr uint32_t kDeviceNumThreshold{1024};
// Group name used for hcom nodes that do not keep their original group.
constexpr char kDefaultGroup[] = "__default_group";
// Attribute key under which the assigned stream id is stored on every node.
constexpr auto kAttrStreamID = "stream_id";

// Limit the total stream number is checked against once all nodes are assigned.
constexpr uint32_t kMaxStreamNum{1024};
// Extra streams counted for every hcom stream when the total stream number is computed.
constexpr uint32_t kHcomSecondaryStreamNum{3};

// Task budget of a single hcom stream.
constexpr uint32_t kMaxTaskNumPerStream{1010};
// Node budget of a single common or independent stream.
constexpr uint32_t kMaxCommonNodeNumPerStream{350};

// Task counts charged per hcom node, selected by GetHcomTaskNum.
constexpr uint32_t kTaskNumPerHcomNode{200};
constexpr uint32_t kTaskNumPerWorldHcomNode{250};
constexpr uint32_t kTaskNumPerSameServerHcomNode{125};
constexpr uint32_t kTaskNumPerHcomSendRecvNode{15};

// Minimum number of hcom nodes required by CheckScenario.
constexpr size_t kHcomNum{2};
// Distance from the end of the hcom list to the hcom node treated as the last gradient communication.
constexpr size_t kLastGradHcomOffset{2};
// Expected element count of the {last gradient node, overflow status node} list.
constexpr size_t kLastGradAndStatusNum{2};
60
IsSameServer(const std::vector<uint32_t> & rank_ids)61 bool IsSameServer(const std::vector<uint32_t> &rank_ids) {
62 auto min_iter = min_element(rank_ids.begin(), rank_ids.end());
63 uint32_t min = (min_iter != rank_ids.end()) ? *min_iter : 0;
64 auto max_iter = max_element(rank_ids.begin(), rank_ids.end());
65 uint32_t max = (max_iter != rank_ids.end()) ? *max_iter : 0;
66 return ((max - min < kDeviceNumOfServer) && (min / kDeviceNumOfServer == max / kDeviceNumOfServer));
67 }
68
// Translates the group attribute of an hcom node into the group name used for stream assignment.
//   ALL_GROUP_PARALLEL: the original group name is kept.
//   NO_GROUP_PARALLEL : kDefaultGroup is used for every node.
//   any other mode    : a group listed in the device manager keeps its name only if all of its ranks are on
//                       one server (IsSameServer); every other group, listed or not, becomes kDefaultGroup.
// The parallel context and, for the last mode, the device manager are guarded by MS_EXCEPTION_IF_NULL.
string DoGetHcomGroup(const string &original_group) {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  // Guard the singleton before dereferencing it.
  MS_EXCEPTION_IF_NULL(parallel_context);
  const string communi_parallel_mode = parallel_context->communi_parallel_mode();
  if (communi_parallel_mode == parallel::ALL_GROUP_PARALLEL) {
    return original_group;
  }
  if (communi_parallel_mode == parallel::NO_GROUP_PARALLEL) {
    return kDefaultGroup;
  }

  MS_EXCEPTION_IF_NULL(parallel::g_device_manager);
  // Bind by const reference: the group list is only read, so no copy is needed.
  const auto &group_info = parallel::g_device_manager->group_info();
  for (const auto &info : group_info) {
    if (info.first == original_group) {
      return IsSameServer(info.second) ? original_group : string(kDefaultGroup);
    }
  }

  // world group is not in group_info.
  return kDefaultGroup;
}
97
GetHcomGroup(const CNodePtr & cnode)98 string GetHcomGroup(const CNodePtr &cnode) {
99 MS_EXCEPTION_IF_NULL(cnode);
100 if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
101 MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
102 }
103
104 auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
105 auto new_group = DoGetHcomGroup(group_name);
106 MS_LOG_INFO << "hcom node: " << cnode->fullname_with_scope() << ", old group: " << group_name
107 << ", new group: " << new_group;
108
109 return new_group;
110 }
111
GetHcomTaskNum(const CNodePtr & cnode)112 uint32_t GetHcomTaskNum(const CNodePtr &cnode) {
113 MS_EXCEPTION_IF_NULL(cnode);
114 if (!AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
115 MS_LOG_EXCEPTION << "Hcom node " << cnode->fullname_with_scope() << " has no group attribute.";
116 }
117
118 if (parallel::g_device_manager == nullptr) {
119 MS_LOG(INFO) << "Device manager is nullptr.";
120 return kTaskNumPerHcomNode;
121 }
122
123 auto node_name = AnfAlgo::GetCNodeName(cnode);
124 if (node_name == kHcomSendOpName || node_name == kReceiveOpName) {
125 return kTaskNumPerHcomSendRecvNode;
126 }
127
128 MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
129 auto device_num = parallel::ParallelContext::GetInstance()->device_num();
130 auto group_name = AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
131 auto group_info = parallel::g_device_manager->group_info();
132 for (const auto &info : group_info) {
133 if (info.first != group_name) {
134 continue;
135 }
136 const auto &rank_ids = info.second;
137 if (IsSameServer(rank_ids)) {
138 return kTaskNumPerSameServerHcomNode;
139 } else if (rank_ids.size() == static_cast<size_t>(device_num) && device_num >= kDeviceNumThreshold) {
140 return kTaskNumPerWorldHcomNode;
141 } else {
142 return kTaskNumPerHcomNode;
143 }
144 }
145
146 // world group is not in group_info.
147 if (device_num >= kDeviceNumThreshold) {
148 return kTaskNumPerWorldHcomNode;
149 } else {
150 return kTaskNumPerHcomNode;
151 }
152 }
153
GetHcomAndOverflowMarker(const NotNull<KernelGraphPtr> & graph_ptr,vector<CNodePtr> * hcom_nodes)154 CNodePtr GetHcomAndOverflowMarker(const NotNull<KernelGraphPtr> &graph_ptr, vector<CNodePtr> *hcom_nodes) {
155 MS_EXCEPTION_IF_NULL(hcom_nodes);
156 auto cnode_ptr_list = graph_ptr->execution_order();
157 CNodePtr overflow_marker = nullptr;
158 std::string kNPUGetFloatStatusOpName = "NPUGetFloatStatus";
159 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
160 auto cur_cnode_ptr = cnode_ptr_list[i];
161 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
162 if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kNPUGetFloatStatusOpName) {
163 overflow_marker = cur_cnode_ptr;
164 } else if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL) {
165 hcom_nodes->emplace_back(cur_cnode_ptr);
166 } else if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
167 auto graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
168 AnfAlgo::SetGraphId(graph_id, cnode_ptr_list[i - 1].get());
169 }
170 }
171 return overflow_marker;
172 }
173
HasRefNodes(const vector<CNodePtr> & moved_backward_cnodes)174 bool HasRefNodes(const vector<CNodePtr> &moved_backward_cnodes) {
175 for (auto &cnode : moved_backward_cnodes) {
176 std::string op_name = AnfAlgo::GetCNodeName(cnode);
177 auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
178 if (op_info != nullptr && op_info->is_ref()) {
179 MS_LOG(INFO) << "Find RefNode: " << op_name << ", full name: " << cnode->fullname_with_scope();
180 return true;
181 }
182 }
183 return false;
184 }
185
GetStreamKind(uint32_t cur_stream_id,uint32_t pre_stream_id,uint32_t next_stream_id)186 StreamActiveKind GetStreamKind(uint32_t cur_stream_id, uint32_t pre_stream_id, uint32_t next_stream_id) {
187 // pre_stream_id equal to UINT32_MAX means no node active current StreamActive
188 // next_stream_id equal to UINT32_MAX means current StreamActive active no node
189 if (pre_stream_id == UINT32_MAX || next_stream_id == UINT32_MAX) {
190 return kInvalid;
191 }
192
193 if (cur_stream_id == pre_stream_id && cur_stream_id == next_stream_id) {
194 return kMiddle;
195 }
196
197 if (cur_stream_id == pre_stream_id) {
198 return kTail;
199 }
200
201 if (cur_stream_id == next_stream_id) {
202 return kHead;
203 }
204
205 return kInvalid;
206 }
SetNodeStreamIDAttr(const NotNull<KernelGraphPtr> & graph_ptr)207 void SetNodeStreamIDAttr(const NotNull<KernelGraphPtr> &graph_ptr) {
208 auto exec_orders = graph_ptr->execution_order();
209 for (auto node : exec_orders) {
210 AnfAlgo::SetNodeAttr(kAttrStreamID, MakeValue<uint32_t>(AnfAlgo::GetStreamId(node)), node);
211 }
212 }
213 } // namespace
214
AssignStream(const NotNull<KernelGraphPtr> & graph_ptr)215 void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
216 if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
217 MS_LOG(INFO) << "Communication parallel mode: " << parallel::ParallelContext::GetInstance()->communi_parallel_mode()
218 << ".";
219
220 Reset();
221 SetLoopSink();
222 ReorderIndependentOrders(graph_ptr);
223 TrailingTimeOptimizationByReorder(graph_ptr);
224
225 AssignAllNodesStream(graph_ptr);
226 UpdateAtomicAddrCleanStreamId(graph_ptr);
227 InsertStreamActive(graph_ptr);
228 InsertEventForHcomParallel(graph_ptr);
229 InsertEventForIndependentParallel(graph_ptr);
230 GetIndependentMaxTarget(graph_ptr);
231 InsertCtrlForIndependentParallel(graph_ptr);
232 AdjustAtomicAddrCleanOrder(graph_ptr);
233
234 GetNeedActiveStreams(graph_ptr);
235
236 MS_LOG(INFO) << "Before check resource assign";
237 graph_ptr->PrintGraphExecuteOrder();
238
239 CheckResourceAssign(graph_ptr);
240 MS_LOG(INFO) << "After finish stream assign";
241 #ifdef ENABLE_DUMP_IR
242 SubModuleId module = SubModuleId::SM_SESSION;
243 std::string name = "assign_stream." + std::to_string(graph_ptr->graph_id());
244 const std::vector<CNodePtr> &exec_order = graph_ptr->execution_order();
245 (void)mindspore::RDR::RecordStreamExecOrder(module, name, exec_order);
246 #endif
247 graph_ptr->PrintGraphExecuteOrder();
248 SetNodeStreamIDAttr(graph_ptr);
249 FindStreamRelations(graph_ptr);
250 PrintStreamRelations();
251 GetStreamRelations();
252 PrintStreamGroups();
253 FindEventRelations(graph_ptr);
254 }
255 }
256
SetLoopSink()257 void AscendStreamAssign::SetLoopSink() {
258 if (KernelAdjust::NeedInsertSwitch()) {
259 loop_sink_ = true;
260 } else {
261 loop_sink_ = false;
262 }
263 }
264
265 // section 1
ReorderIndependentOrders(const NotNull<KernelGraphPtr> & graph_ptr)266 void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr) {
267 std::vector<CNodePtr> exe_orders;
268 std::vector<CNodePtr> independents;
269 std::vector<CNodePtr> others;
270
271 auto cnode_ptr_list = graph_ptr->execution_order();
272 MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size();
273 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
274 auto cur_cnode_ptr = cnode_ptr_list[i];
275 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
276 if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
277 independents.emplace_back(cur_cnode_ptr);
278 } else {
279 others.emplace_back(cur_cnode_ptr);
280 }
281 }
282
283 if (others.empty() || independents.empty()) {
284 MS_LOG(INFO) << "Independent or others is empty, no need reorder";
285 return;
286 }
287
288 std::set<CNode *> processed;
289 for (size_t i = 0; i < others.size(); i++) {
290 auto begin = others.begin() + i;
291 auto end = begin + 1;
292 bool flag = false;
293 for (size_t j = 0; j < independents.size(); j++) {
294 auto cur_independent = independents[j];
295 auto it = std::find(processed.begin(), processed.end(), cur_independent.get());
296 if (it != processed.end()) {
297 continue;
298 }
299
300 auto res = FindTargetOp(begin, end, cur_independent, false);
301 if (res != end) {
302 flag = true;
303 exe_orders.emplace_back(cur_independent);
304 exe_orders.emplace_back(*begin);
305 processed.emplace(cur_independent.get());
306 break;
307 }
308 }
309
310 if (!flag) {
311 exe_orders.emplace_back(*begin);
312 }
313 }
314
315 MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size();
316 if (processed.size() != independents.size()) {
317 MS_LOG(WARNING) << "Processed independent nodes size is not equal to exiting independent nodes size";
318 return;
319 }
320
321 graph_ptr->set_execution_order(exe_orders);
322 }
323
CheckScenario(const NotNull<KernelGraphPtr> & graph_ptr,vector<CNodePtr> * last_grad_and_status)324 void AscendStreamAssign::CheckScenario(const NotNull<KernelGraphPtr> &graph_ptr,
325 vector<CNodePtr> *last_grad_and_status) {
326 MS_EXCEPTION_IF_NULL(last_grad_and_status);
327 auto cnode_ptr_list = graph_ptr->execution_order();
328 vector<CNodePtr> hcom_nodes;
329 auto overflow_marker = GetHcomAndOverflowMarker(graph_ptr, &hcom_nodes);
330 if (hcom_nodes.size() < kHcomNum || overflow_marker == nullptr) {
331 MS_LOG(INFO) << "Current model isn't in distribute or mix-precision mode, no optimization needed";
332 last_grad_and_status->clear();
333 return;
334 }
335
336 auto overflow_marker_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), overflow_marker);
337 auto last_hcom_ptr = hcom_nodes[hcom_nodes.size() - 1];
338 auto last_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_hcom_ptr);
339 auto last_grad_hcom_ptr = hcom_nodes[hcom_nodes.size() - kLastGradHcomOffset];
340 auto last_grad_hcom_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_hcom_ptr);
341 if (last_grad_hcom_pos > overflow_marker_pos || last_hcom_pos < overflow_marker_pos) {
342 MS_LOG(INFO) << "Grads average done after overflow judgement or status aren't allgathered, no optimization needed";
343 last_grad_and_status->clear();
344 return;
345 }
346
347 auto last_inputs = GetLastInputCnode(graph_ptr, last_grad_hcom_ptr);
348 if (last_inputs.empty() || last_inputs.size() > 1 || IsHcom(last_inputs[0])) {
349 MS_LOG(INFO) << "Inputs of last gradients allreduce is empty or include other allreduce, no optimization needed";
350 last_grad_and_status->clear();
351 return;
352 }
353 auto last_grad_ptr = last_inputs[0];
354 MS_LOG(DEBUG) << "Last Hcom: " << last_grad_hcom_ptr->fullname_with_scope()
355 << "; last input: " << last_grad_ptr->fullname_with_scope();
356 auto last_grad_hcom_graph_id = AnfAlgo::GetGraphId(last_grad_hcom_ptr.get());
357 auto last_grad_graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get());
358 auto overflow_marker_graph_id = AnfAlgo::GetGraphId(overflow_marker.get());
359 if (last_grad_graph_id != last_grad_hcom_graph_id || last_grad_graph_id != overflow_marker_graph_id) {
360 MS_LOG(INFO) << "The grads and grad_hcom or overflow marker were not on the same subgraph, no optimization needed";
361 last_grad_and_status->clear();
362 return;
363 }
364
365 auto label_switch_pos = find_if(last_grad_hcom_pos, cnode_ptr_list.end(),
366 [](CNodePtr &node) -> bool { return AnfAlgo::GetCNodeName(node) == "LabelSwitch"; });
367 if (label_switch_pos == cnode_ptr_list.end()) {
368 MS_LOG(INFO) << "No branches after getting overflow status, no optimization needed";
369 last_grad_and_status->clear();
370 return;
371 }
372 last_grad_and_status->emplace_back(last_grad_ptr);
373 last_grad_and_status->emplace_back(overflow_marker);
374 return;
375 }
376
GetCNodesNeededMoved(vector<CNodePtr> * moved_backward_cnodes,vector<CNodePtr> * moved_forward_cnodes,const vector<CNodePtr> & last_grad_and_status,const NotNull<KernelGraphPtr> & graph_ptr)377 CNodePtr AscendStreamAssign::GetCNodesNeededMoved(vector<CNodePtr> *moved_backward_cnodes,
378 vector<CNodePtr> *moved_forward_cnodes,
379 const vector<CNodePtr> &last_grad_and_status,
380 const NotNull<KernelGraphPtr> &graph_ptr) {
381 MS_EXCEPTION_IF_NULL(moved_backward_cnodes);
382 MS_EXCEPTION_IF_NULL(moved_forward_cnodes);
383 auto cnode_ptr_list = graph_ptr->execution_order();
384 if (last_grad_and_status.size() != kLastGradAndStatusNum) {
385 return nullptr;
386 }
387 auto last_grad_ptr = last_grad_and_status[0];
388 auto float_status_ptr = last_grad_and_status[1];
389 auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr);
390 auto float_status_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), float_status_ptr);
391 if (last_grad_pos == cnode_ptr_list.end() || float_status_pos == cnode_ptr_list.end()) {
392 return nullptr;
393 }
394 auto graph_id = AnfAlgo::GetGraphId(last_grad_ptr.get());
395 moved_backward_cnodes->insert(moved_backward_cnodes->end(), last_grad_pos + 1, float_status_pos);
396
397 auto it = float_status_pos;
398 while (AnfAlgo::GetGraphId((*it).get()) == graph_id && it < cnode_ptr_list.end()) {
399 if (AnfAlgo::GetCNodeName(*it) == kAtomicAddrCleanOpName) {
400 it++;
401 continue;
402 }
403 auto inputs = GetInputKernels(*it);
404 bool is_independent = true;
405 for (auto &input : inputs) {
406 if (find(moved_backward_cnodes->begin(), moved_backward_cnodes->end(), input) != moved_backward_cnodes->end()) {
407 is_independent = false;
408 break;
409 }
410 }
411 if (is_independent) {
412 if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) {
413 moved_forward_cnodes->emplace_back(*(it - 1));
414 }
415 moved_forward_cnodes->emplace_back(*it);
416 } else {
417 if (AnfAlgo::GetCNodeName(*(it - 1)) == kAtomicAddrCleanOpName) {
418 moved_backward_cnodes->emplace_back(*(it - 1));
419 }
420 moved_backward_cnodes->emplace_back(*it);
421 }
422 it++;
423 }
424
425 size_t total_moved_size = LongToSize(it - last_grad_pos - 1);
426 if (HasRefNodes(*moved_backward_cnodes) ||
427 moved_backward_cnodes->size() + moved_forward_cnodes->size() != total_moved_size) {
428 MS_LOG(INFO) << "Ref node was found or invalid number of moved nodes, give up optimization";
429 return nullptr;
430 }
431 return GetTargetOutputNode(*moved_backward_cnodes, *it, graph_ptr);
432 }
433
GetTargetOutputNode(const vector<CNodePtr> & moved_backward_cnodes,const CNodePtr first_node,const NotNull<KernelGraphPtr> & graph_ptr)434 CNodePtr AscendStreamAssign::GetTargetOutputNode(const vector<CNodePtr> &moved_backward_cnodes,
435 const CNodePtr first_node, const NotNull<KernelGraphPtr> &graph_ptr) {
436 auto cnode_ptr_list = graph_ptr->execution_order();
437 if (moved_backward_cnodes.empty() || !first_node) {
438 return nullptr;
439 }
440 uint32_t subgraph_id = 0;
441 bool get_subgraph_id = false;
442 auto it = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), first_node);
443 CNodePtr first_output_node_ptr = nullptr;
444 while (!get_subgraph_id && it < cnode_ptr_list.end()) {
445 auto inputs = GetInputKernels(*it);
446 for (auto &input : inputs) {
447 if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end()) {
448 get_subgraph_id = true;
449 subgraph_id = AnfAlgo::GetGraphId((*it).get());
450 first_output_node_ptr = *it;
451 break;
452 }
453 }
454 it++;
455 }
456 if (subgraph_id == 0) {
457 MS_LOG(INFO) << "The nodes moved backward were not used by any other nodes, no need moved";
458 return nullptr;
459 }
460
461 for (; it < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*it).get()) != subgraph_id; it++) {
462 auto inputs = GetInputKernels(*it);
463 for (auto &input : inputs) {
464 if (find(moved_backward_cnodes.begin(), moved_backward_cnodes.end(), input) != moved_backward_cnodes.end()) {
465 MS_LOG(INFO) << "The nodes moved backward were used by nodes on different subgraphs, no need moved";
466 return nullptr;
467 }
468 }
469 }
470 return first_output_node_ptr;
471 }
472
FinetuneSubgraphExecOrder(vector<CNodePtr> * cnodes)473 bool AscendStreamAssign::FinetuneSubgraphExecOrder(vector<CNodePtr> *cnodes) {
474 MS_EXCEPTION_IF_NULL(cnodes);
475 auto hcom_pos = find_if(cnodes->begin(), cnodes->end(),
476 [](CNodePtr &node_ptr) -> bool { return AnfAlgo::GetCNodeName(node_ptr) == "AllReduce"; });
477 if (hcom_pos == cnodes->end()) {
478 return false;
479 }
480 CNodePtr hcom_ptr = *hcom_pos;
481
482 vector<CNodePtr> ori_cnodes(cnodes->begin(), cnodes->end());
483 cnodes->clear();
484 vector<CNodePtr> atomic_addr_clean;
485 for (auto iter = ori_cnodes.begin(); iter < ori_cnodes.end(); ++iter) {
486 if (AnfAlgo::GetCNodeName(*iter) == kAtomicAddrCleanOpName) {
487 atomic_addr_clean.emplace_back(*iter);
488 continue;
489 }
490 auto last_input_pos = cnodes->end();
491 for (auto &input : GetInputKernels(*iter)) {
492 auto pos = find(cnodes->begin(), cnodes->end(), input);
493 if (pos != cnodes->end()) {
494 last_input_pos = (last_input_pos == cnodes->end() || last_input_pos < pos) ? pos : last_input_pos;
495 }
496 }
497 if (last_input_pos == cnodes->end()) {
498 auto hcom_it = find(cnodes->begin(), cnodes->end(), hcom_ptr);
499 if (hcom_it == cnodes->end() || AnfAlgo::GetCNodeName(*iter) == kLabelGotoOpName ||
500 AnfAlgo::GetCNodeName(*iter) == kLabelSetOpName || AnfAlgo::GetCNodeName(*iter) == kLabelSwitchOpName) {
501 cnodes->emplace_back(*iter);
502 } else {
503 cnodes->insert(hcom_it, *iter);
504 }
505 } else {
506 cnodes->insert(last_input_pos + 1, *iter);
507 }
508 }
509
510 for (auto &node : atomic_addr_clean) {
511 auto first_input_pos = cnodes->end();
512 for (auto &input : GetInputKernels(node)) {
513 auto pos = find(cnodes->begin(), cnodes->end(), input);
514 first_input_pos = (first_input_pos == cnodes->end() || first_input_pos > pos) ? pos : first_input_pos;
515 }
516 if (first_input_pos == cnodes->end()) {
517 return false;
518 } else {
519 cnodes->insert(first_input_pos, node);
520 }
521 }
522 return cnodes->size() == ori_cnodes.size();
523 }
524
525 // performance optimization for trailing time in distribute mode
526 // allreduce of the last batch of gradients and the optimizer can be done parallel
TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> & graph_ptr)527 void AscendStreamAssign::TrailingTimeOptimizationByReorder(const NotNull<KernelGraphPtr> &graph_ptr) {
528 vector<CNodePtr> last_grad_and_status;
529 CheckScenario(graph_ptr, &last_grad_and_status);
530 if (last_grad_and_status.empty()) {
531 MS_LOG(INFO) << "Unsuitable scenario, no optimization needed";
532 return;
533 }
534
535 auto cnode_ptr_list = graph_ptr->execution_order();
536 vector<CNodePtr> moved_forward_cnodes;
537 vector<CNodePtr> moved_backward_cnodes;
538 CNodePtr first_output_ptr =
539 GetCNodesNeededMoved(&moved_backward_cnodes, &moved_forward_cnodes, last_grad_and_status, graph_ptr);
540 if (moved_backward_cnodes.empty() || first_output_ptr == nullptr) {
541 MS_LOG(INFO) << "Unsuitable scenario, no optimization needed";
542 return;
543 }
544
545 uint32_t subgraph_id = AnfAlgo::GetGraphId(first_output_ptr.get());
546 auto last_grad_ptr = last_grad_and_status[0];
547 auto last_grad_pos = find(cnode_ptr_list.begin(), cnode_ptr_list.end(), last_grad_ptr);
548 vector<CNodePtr> cnodes(cnode_ptr_list.begin(), last_grad_pos + 1);
549 cnodes.insert(cnodes.end(), moved_forward_cnodes.begin(), moved_forward_cnodes.end());
550 auto pos = last_grad_pos + moved_forward_cnodes.size() + moved_backward_cnodes.size() + 1;
551 while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) != subgraph_id) {
552 cnodes.emplace_back(*pos);
553 ++pos;
554 }
555
556 vector<CNodePtr> subgraph_cnodes;
557 while (pos < cnode_ptr_list.end() && AnfAlgo::GetGraphId((*pos).get()) == subgraph_id) {
558 if (AnfAlgo::GetCNodeName(*pos) == kLabelGotoOpName) {
559 break;
560 }
561 if (*pos != first_output_ptr) {
562 subgraph_cnodes.emplace_back(*pos);
563 } else {
564 subgraph_cnodes.insert(subgraph_cnodes.end(), moved_backward_cnodes.begin(), moved_backward_cnodes.end());
565 subgraph_cnodes.emplace_back(*pos);
566 }
567 ++pos;
568 }
569
570 if (!FinetuneSubgraphExecOrder(&subgraph_cnodes) || subgraph_cnodes.empty()) {
571 MS_LOG(INFO) << "Finetune subgraph execute order failed, no optimization needed";
572 return;
573 }
574
575 cnodes.insert(cnodes.end(), subgraph_cnodes.begin(), subgraph_cnodes.end());
576 cnodes.insert(cnodes.end(), pos, cnode_ptr_list.end());
577 if (cnodes.size() != cnode_ptr_list.size()) {
578 return;
579 }
580 for (auto &node : subgraph_cnodes) {
581 AnfAlgo::SetGraphId(subgraph_id, node.get());
582 }
583
584 graph_ptr->set_execution_order(cnodes);
585 }
586
587 // section 2
AssignAllNodesStream(const NotNull<KernelGraphPtr> & graph_ptr)588 void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr) {
589 auto cnode_ptr_list = graph_ptr->execution_order();
590 bool exit_independent = false;
591 bool exit_hcom = false;
592 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
593 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
594 CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
595 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
596 // node has been assigned stream before
597 if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
598 continue;
599 }
600
601 if (IsHcom(cur_cnode_ptr)) {
602 exit_hcom = true;
603 continue;
604 }
605
606 if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
607 exit_independent = true;
608 continue;
609 }
610
611 AssignCommonStreamId(cur_cnode_ptr);
612 }
613
614 auto common_stream_num = resource_manager.get_cur_stream_num();
615
616 if (exit_hcom) {
617 AssignHcom(graph_ptr);
618 }
619 auto hcom_stream_num = resource_manager.get_cur_stream_num() - common_stream_num;
620
621 if (exit_independent) {
622 AssignIndependent(graph_ptr);
623 }
624 auto independent_stream_num = resource_manager.get_cur_stream_num() - common_stream_num - hcom_stream_num;
625 auto total_stream_num =
626 resource_manager.get_cur_stream_num() + Uint32tMulWithOverflowCheck(hcom_stream_num, kHcomSecondaryStreamNum);
627 MS_LOG(INFO) << "Total stream number: " << total_stream_num << ", common stream number: " << common_stream_num
628 << ", hcom stream number: " << hcom_stream_num << "*" << (kHcomSecondaryStreamNum + 1)
629 << ", independent stream number: " << independent_stream_num << ".";
630
631 if (total_stream_num > kMaxStreamNum) {
632 MS_LOG(EXCEPTION) << "Total stream number " << total_stream_num << " exceeds the limit of " << kMaxStreamNum
633 << ", search details information in mindspore's FAQ.";
634 }
635
636 MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
637 }
638
AssignCommonStreamId(const CNodePtr & cur_cnode_ptr)639 void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
640 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
641 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
642 uint32_t cur_common_stream_id = 0;
643 uint32_t cur_stream_num = resource_manager.get_cur_stream_num();
644 if (cur_stream_num == 0) {
645 cur_common_stream_id = resource_manager.ApplyNewStream();
646 } else {
647 cur_common_stream_id = resource_manager.GetCurAllocStreamId();
648 }
649
650 auto it = common_stream_map_.find(cur_common_stream_id);
651 if (it == common_stream_map_.end()) {
652 AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
653 common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
654 } else {
655 if (it->second < kMaxCommonNodeNumPerStream) {
656 AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
657 it->second++;
658 } else {
659 cur_common_stream_id = resource_manager.ApplyNewStream();
660 AnfAlgo::SetStreamId(cur_common_stream_id, cur_cnode_ptr.get());
661 common_stream_map_.insert(std::make_pair(cur_common_stream_id, 1));
662 }
663 }
664 }
665
AssignHcom(const NotNull<KernelGraphPtr> & graph_ptr)666 void AscendStreamAssign::AssignHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
667 auto cnode_ptr_list = graph_ptr->execution_order();
668 std::map<std::string, std::map<uint32_t, std::vector<CNodePtr>>> group_graph_nodes_map;
669 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
670 CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
671 // node has been assigned stream before
672 if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
673 continue;
674 }
675
676 if (IsHcom(cur_cnode_ptr)) {
677 auto group_name = GetHcomGroup(cur_cnode_ptr);
678 auto hcom_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
679 auto iter = group_graph_nodes_map.find(group_name);
680 if (iter == group_graph_nodes_map.end()) {
681 std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
682 graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
683 group_graph_nodes_map[group_name] = graph_nodes_map;
684 } else {
685 auto &graph_nodes_map = iter->second;
686 auto it = graph_nodes_map.find(hcom_graph_id);
687 if (it == graph_nodes_map.end()) {
688 graph_nodes_map[hcom_graph_id] = {cur_cnode_ptr};
689 } else {
690 it->second.emplace_back(cur_cnode_ptr);
691 }
692 }
693 }
694 }
695
696 MS_LOG(INFO) << "hcom diff group size:" << group_graph_nodes_map.size();
697 for (const auto &item : group_graph_nodes_map) {
698 MS_LOG_INFO << "group id:" << item.first << "; diff graph id size:" << item.second.size();
699 }
700
701 for (const auto &diff_group : group_graph_nodes_map) {
702 // group id:
703 std::map<uint32_t, std::set<uint32_t>> hcom_graph_map;
704 for (const auto &item : diff_group.second) {
705 bool new_graph = true;
706 auto graph_id = item.first;
707 hcom_graph_map[graph_id] = {};
708 for (const auto &hcom_node_ptr : item.second) {
709 auto assigned_stream_id = AssignHcomStreamId(hcom_node_ptr, new_graph);
710 hcom_graph_map[graph_id].emplace(assigned_stream_id);
711 new_graph = false;
712 }
713 }
714 group_hcom_graph_map_[diff_group.first] = hcom_graph_map;
715 }
716 }
717
AssignHcomStreamId(const CNodePtr & cur_cnode_ptr,bool new_graph)718 uint32_t AscendStreamAssign::AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
719 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
720 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
721 auto task_num = GetHcomTaskNum(cur_cnode_ptr);
722
723 uint32_t cur_hcom_stream_id;
724 if (new_graph) {
725 cur_hcom_stream_id = resource_manager.ApplyNewStream();
726 } else {
727 cur_hcom_stream_id = resource_manager.GetCurAllocStreamId();
728 }
729 auto it = hcom_stream_map_.find(cur_hcom_stream_id);
730 if (it == hcom_stream_map_.end()) {
731 AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
732 hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
733 } else {
734 if (it->second <= kMaxTaskNumPerStream - task_num) {
735 AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
736 it->second = Uint32tAddWithOverflowCheck(it->second, task_num);
737 } else {
738 cur_hcom_stream_id = resource_manager.ApplyNewStream();
739 AnfAlgo::SetStreamId(cur_hcom_stream_id, cur_cnode_ptr.get());
740 hcom_stream_map_.emplace(cur_hcom_stream_id, task_num);
741 }
742 }
743 return cur_hcom_stream_id;
744 }
745
AssignIndependent(const NotNull<KernelGraphPtr> & graph_ptr)746 void AscendStreamAssign::AssignIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
747 auto cnode_ptr_list = graph_ptr->execution_order();
748 std::map<uint32_t, std::vector<CNodePtr>> graph_nodes_map;
749 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
750 CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
751 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
752 if (AnfAlgo::GetStreamId(cur_cnode_ptr) != kInvalidStreamId) {
753 continue;
754 }
755 if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
756 auto independent_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
757 auto it = graph_nodes_map.find(independent_graph_id);
758 if (it == graph_nodes_map.end()) {
759 graph_nodes_map[independent_graph_id] = {cur_cnode_ptr};
760 } else {
761 it->second.emplace_back(cur_cnode_ptr);
762 }
763 }
764 }
765
766 MS_LOG(INFO) << "independent diff graph id size:" << graph_nodes_map.size();
767 for (const auto &item : graph_nodes_map) {
768 bool new_graph = true;
769 auto graph_id = item.first;
770 independent_graph_map_[graph_id] = {};
771 for (const auto &independent_node_ptr : item.second) {
772 auto assigned_stream_id = AssignIndependentStreamId(independent_node_ptr, new_graph);
773 independent_graph_map_[graph_id].emplace(assigned_stream_id);
774 new_graph = false;
775 }
776 }
777 MS_LOG(INFO) << "stream nums:" << independent_stream_map_.size();
778 }
779
AssignIndependentStreamId(const CNodePtr & cur_cnode_ptr,bool new_graph)780 uint32_t AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph) {
781 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
782 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
783 uint32_t cur_independent_stream_id;
784 if (new_graph) {
785 cur_independent_stream_id = resource_manager.ApplyNewStream();
786 } else {
787 cur_independent_stream_id = resource_manager.GetCurAllocStreamId();
788 }
789 auto it = independent_stream_map_.find(cur_independent_stream_id);
790 if (it == independent_stream_map_.end()) {
791 AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
792 independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
793 } else {
794 if (it->second < kMaxCommonNodeNumPerStream) {
795 AnfAlgo::SetStreamId(it->first, cur_cnode_ptr.get());
796 it->second++;
797 } else {
798 cur_independent_stream_id = resource_manager.ApplyNewStream();
799 AnfAlgo::SetStreamId(cur_independent_stream_id, cur_cnode_ptr.get());
800 independent_stream_map_.insert(std::make_pair(cur_independent_stream_id, 1));
801 }
802 }
803
804 return cur_independent_stream_id;
805 }
806
807 // section 3
UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> & graph_ptr)808 void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
809 MS_LOG(INFO) << "Start";
810 auto cnode_ptr_list = graph_ptr->execution_order();
811 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
812 CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
813 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
814 // update AtomicAddrClean stream same with the next node
815 if (i > 0 && AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]) == kAtomicAddrCleanOpName) {
816 AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get());
817 }
818 }
819 MS_LOG(INFO) << "End";
820 }
821
822 // section 4
InsertStreamActive(const NotNull<KernelGraphPtr> & graph_ptr)823 void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) {
824 InsertStreamActiveForCommon(graph_ptr);
825 InsertStreamActiveForIndependent(graph_ptr);
826 InsertStreamActiveForParallel(graph_ptr);
827 }
828
InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> & graph_ptr)829 void AscendStreamAssign::InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
830 if (group_hcom_graph_map_.empty() && independent_graph_map_.empty()) {
831 MS_LOG(INFO) << "Hcom and independent is empty";
832 return;
833 }
834 auto root_graph_id = graph_ptr->graph_id();
835 if (root_graph_id == kInvalidGraphId) {
836 MS_LOG(INFO) << "Root graph id is invalid";
837 return;
838 }
839
840 std::map<uint32_t, std::set<uint32_t>> other_graph;
841 std::set<uint32_t> hcom_streams;
842 for (const auto &graph_nodes : group_hcom_graph_map_) {
843 for (const auto &item : graph_nodes.second) {
844 MS_LOG(INFO) << "Graph id:" << item.first;
845 if (item.first == root_graph_id) {
846 if (loop_sink_) {
847 hcom_streams.insert(item.second.begin(), item.second.end());
848 }
849 } else {
850 auto it = other_graph.find(item.first);
851 if (it == other_graph.end()) {
852 other_graph[item.first] = item.second;
853 } else {
854 for (const auto &stream : item.second) {
855 it->second.emplace(stream);
856 }
857 }
858 }
859 }
860 }
861
862 if (!hcom_streams.empty()) {
863 ActiveRootGraphHcom(graph_ptr, hcom_streams);
864 }
865
866 MS_LOG(INFO) << "Independent graph map size:" << independent_graph_map_.size();
867 for (const auto &item : independent_graph_map_) {
868 MS_LOG(DEBUG) << "Graph id:" << item.first;
869 if (item.first == root_graph_id) {
870 if (loop_sink_) {
871 ActiveRootGraphIndependent(graph_ptr, item.second);
872 }
873 } else {
874 auto it = other_graph.find(item.first);
875 if (it == other_graph.end()) {
876 other_graph[item.first] = item.second;
877 } else {
878 for (const auto &stream : item.second) {
879 it->second.emplace(stream);
880 }
881 }
882 }
883 }
884
885 ActiveOtherGraphParallel(graph_ptr, other_graph);
886 }
887
ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> & graph_ptr,std::map<uint32_t,std::set<uint32_t>> other_graph)888 void AscendStreamAssign::ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
889 std::map<uint32_t, std::set<uint32_t>> other_graph) {
890 MS_LOG(INFO) << "Other graph size:" << other_graph.size();
891 if (other_graph.empty()) {
892 return;
893 }
894
895 auto root_graph_id = graph_ptr->graph_id();
896
897 std::vector<CNodePtr> update_stream_list;
898 auto exe_order = graph_ptr->execution_order();
899 for (size_t i = 0; i < exe_order.size(); i++) {
900 auto cur_cnode_ptr = exe_order[i];
901 auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode_ptr.get());
902 if (cur_graph_id == root_graph_id) {
903 update_stream_list.emplace_back(cur_cnode_ptr);
904 continue;
905 }
906
907 auto it = other_graph.find(cur_graph_id);
908 if (it == other_graph.end()) {
909 update_stream_list.emplace_back(cur_cnode_ptr);
910 continue;
911 }
912
913 auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
914 CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
915 // 1.set stream id
916 AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
917 // 2.set active stream ids
918 std::vector<uint32_t> active_index_list;
919 std::copy(it->second.begin(), it->second.end(), std::back_inserter(active_index_list));
920 AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
921
922 // find position for insert streamactive
923 if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kLabelSetOpName) {
924 update_stream_list.emplace_back(cur_cnode_ptr);
925 update_stream_list.emplace_back(active_ptr);
926 } else {
927 update_stream_list.emplace_back(active_ptr);
928 update_stream_list.emplace_back(cur_cnode_ptr);
929 }
930 other_graph.erase(it);
931 }
932 graph_ptr->set_execution_order(update_stream_list);
933 }
934
ActiveRootGraphHcom(const NotNull<KernelGraphPtr> & graph_ptr,const std::set<uint32_t> & hcom_streams)935 void AscendStreamAssign::ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr,
936 const std::set<uint32_t> &hcom_streams) {
937 MS_LOG(INFO) << "Active root graph hcom start";
938 std::vector<CNodePtr> update_cnode_list;
939 auto exe_orders = graph_ptr->execution_order();
940 for (size_t i = 0; i < exe_orders.size(); i++) {
941 CNodePtr cur_cnode_ptr = exe_orders[i];
942 if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) {
943 update_cnode_list.emplace_back(cur_cnode_ptr);
944 continue;
945 }
946
947 if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) {
948 update_cnode_list.emplace_back(cur_cnode_ptr);
949 continue;
950 }
951
952 auto kind = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrStreamSwitchKind);
953 if (kind != kFpBpStreamSwitch) {
954 update_cnode_list.emplace_back(cur_cnode_ptr);
955 continue;
956 }
957
958 auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
959 MS_LOG(INFO) << "FpBpStreamswtich stream id:" << AnfAlgo::GetStreamId(cur_cnode_ptr)
960 << "; true branch stream id:" << true_stream_id;
961 CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
962 AnfAlgo::SetStreamId(true_stream_id, active_ptr.get());
963 vector<uint32_t> active_ids;
964 // active hcom stream
965 std::copy(hcom_streams.begin(), hcom_streams.end(), std::back_inserter(active_ids));
966 AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_ids), active_ptr);
967 update_cnode_list.emplace_back(cur_cnode_ptr);
968 update_cnode_list.emplace_back(active_ptr);
969 std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
970 break;
971 }
972
973 hcom_stream_activated_ = true;
974 graph_ptr->set_execution_order(update_cnode_list);
975 }
976
ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> & graph_ptr,const std::set<uint32_t> & independent_streams)977 void AscendStreamAssign::ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr,
978 const std::set<uint32_t> &independent_streams) {
979 MS_LOG(DEBUG) << "Start active root graph independent";
980 std::vector<CNodePtr> update_cnode_list;
981 auto exe_orders = graph_ptr->execution_order();
982 for (size_t i = 0; i < exe_orders.size(); i++) {
983 CNodePtr cur_cnode_ptr = exe_orders[i];
984 if (AnfAlgo::GetCNodeName(cur_cnode_ptr) != kStreamSwitchOpName) {
985 update_cnode_list.emplace_back(cur_cnode_ptr);
986 continue;
987 }
988
989 if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, cur_cnode_ptr)) {
990 update_cnode_list.emplace_back(cur_cnode_ptr);
991 continue;
992 }
993
994 auto kind = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrStreamSwitchKind);
995 if (kind != kIndependentStreamSwitch) {
996 update_cnode_list.emplace_back(cur_cnode_ptr);
997 continue;
998 }
999
1000 // first independetn stream id is minimum and order by std map;
1001 auto first_independent_stream = *(independent_streams.begin());
1002 AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_independent_stream), cur_cnode_ptr);
1003 update_cnode_list.emplace_back(cur_cnode_ptr);
1004 std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
1005 break;
1006 }
1007
1008 independent_stream_activated_ = true;
1009 graph_ptr->set_execution_order(update_cnode_list);
1010 }
InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> & graph_ptr)1011 void AscendStreamAssign::InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
1012 MS_LOG(INFO) << "Start";
1013 GetProcessedStream(graph_ptr);
1014 std::vector<CNodePtr> update_cnode_list;
1015 CNodePtr cur_cnode_ptr = nullptr;
1016 CNodePtr pre_cnode_ptr = nullptr;
1017 uint32_t pre_stream_id = UINT32_MAX;
1018
1019 auto cnode_ptr_list = graph_ptr->execution_order();
1020 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1021 cur_cnode_ptr = cnode_ptr_list[i];
1022 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1023 if (AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
1024 update_cnode_list.emplace_back(cur_cnode_ptr);
1025 continue;
1026 }
1027
1028 if (IsHcom(cur_cnode_ptr)) {
1029 update_cnode_list.emplace_back(cur_cnode_ptr);
1030 continue;
1031 }
1032 uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1033 bool processed = IsProcessedStream(cur_stream_id);
1034 // 1)inner stream assign, need insert active op
1035 if (!processed) {
1036 MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id;
1037 CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
1038 // 1.set stream id
1039 AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get());
1040 // 2.set active stream ids
1041 std::vector<uint32_t> active_index_list{cur_stream_id};
1042 AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
1043 if (i > 0) {
1044 auto pre_node = AnfAlgo::GetCNodeName(cnode_ptr_list[i - 1]);
1045 if (pre_node == kLabelSwitchOpName || pre_node == kLabelGotoOpName) {
1046 update_cnode_list.insert(update_cnode_list.end() - 1, active_ptr);
1047 AnfAlgo::SetStreamId(cur_stream_id, cnode_ptr_list[i - 1].get());
1048 } else {
1049 update_cnode_list.emplace_back(active_ptr);
1050 }
1051 }
1052 }
1053 if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
1054 MS_LOG(INFO) << "Insert StreamActive op after FP StreamSwitch for stream parallel";
1055 update_cnode_list.emplace_back(cur_cnode_ptr);
1056 } else {
1057 update_cnode_list.emplace_back(cur_cnode_ptr);
1058 }
1059
1060 processed_streams_.emplace(cur_stream_id);
1061 pre_stream_id = cur_stream_id;
1062 pre_cnode_ptr = cur_cnode_ptr;
1063 }
1064 graph_ptr->set_execution_order(update_cnode_list);
1065 }
1066
InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> & graph_ptr)1067 void AscendStreamAssign::InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr) {
1068 auto root_graph_id = graph_ptr->graph_id();
1069 if (root_graph_id == kInvalidGraphId) {
1070 return;
1071 }
1072 std::set<uint32_t> independent_streams;
1073 for (const auto &item : independent_graph_map_) {
1074 if (item.first == root_graph_id) {
1075 independent_streams = item.second;
1076 }
1077 }
1078
1079 // Root graph independent stream size is not more than one, no need insert active
1080 if (independent_streams.size() <= 1) {
1081 return;
1082 }
1083 std::vector<CNodePtr> update_cnode_list;
1084 auto exe_orders = graph_ptr->execution_order();
1085
1086 // first independent is been activated, active other independent stream
1087 std::vector<uint32_t> streams;
1088 std::copy(independent_streams.begin(), independent_streams.end(), std::back_inserter(streams));
1089 std::sort(streams.begin(), streams.end());
1090 uint32_t node_num = 0;
1091 for (size_t i = 0; i < exe_orders.size(); i++) {
1092 auto cur_cnode_ptr = exe_orders[i];
1093 update_cnode_list.emplace_back(cur_cnode_ptr);
1094 if (!AnfAlgo::IsIndependentNode(cur_cnode_ptr)) {
1095 continue;
1096 }
1097
1098 if (AnfAlgo::GetGraphId(cur_cnode_ptr.get()) != root_graph_id) {
1099 continue;
1100 }
1101
1102 node_num++;
1103 auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1104 auto it = std::find(streams.begin(), streams.end(), cur_stream_id);
1105 if (it == streams.end()) {
1106 MS_LOG(EXCEPTION) << "Can't find independent stream id:" << cur_stream_id;
1107 } else if (it == streams.end() - 1) {
1108 std::copy(exe_orders.begin() + i + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
1109 break;
1110 } else {
1111 if (node_num == kMaxCommonNodeNumPerStream) {
1112 CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
1113 // 1.set stream id
1114 AnfAlgo::SetStreamId(cur_stream_id, active_ptr.get());
1115 // 2.set active stream ids
1116 std::vector<uint32_t> active_index_list{*(it + 1)};
1117 AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
1118 update_cnode_list.emplace_back(active_ptr);
1119 node_num = 0;
1120 }
1121 }
1122 }
1123 graph_ptr->set_execution_order(update_cnode_list);
1124 }
1125
GetProcessedStream(const NotNull<KernelGraphPtr> & graph_ptr)1126 void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr) {
1127 // 0 stream is activated at first
1128 processed_streams_.emplace(0);
1129 auto cnode_ptr_list = graph_ptr->execution_order();
1130 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1131 auto cur_cnode_ptr = cnode_ptr_list[i];
1132 uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1133
1134 if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
1135 if (AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, cur_cnode_ptr)) {
1136 auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrTrueBranchStream);
1137 processed_streams_.emplace(true_stream_id);
1138 }
1139
1140 if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
1141 continue;
1142 }
1143 auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
1144 if (need_active) {
1145 processed_streams_.emplace(cur_stream_id);
1146 }
1147 }
1148 }
1149 for (const auto &item : processed_streams_) {
1150 MS_LOG(INFO) << "Before active:" << item << " is been processed";
1151 }
1152 }
1153
CheckStreamSwitch(const CNodePtr & switch_ptr)1154 bool AscendStreamAssign::CheckStreamSwitch(const CNodePtr &switch_ptr) {
1155 if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, switch_ptr)) {
1156 return false;
1157 }
1158
1159 auto need_active = AnfAlgo::GetNodeAttr<bool>(switch_ptr, kStreamNeedActivedFirst);
1160 if (!need_active) {
1161 return false;
1162 }
1163
1164 if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, switch_ptr)) {
1165 return false;
1166 }
1167
1168 auto kind = AnfAlgo::GetNodeAttr<uint32_t>(switch_ptr, kAttrStreamSwitchKind);
1169 if (kind == kEosStreamSwitch || kind == kGetNextStreamSwitch) {
1170 return false;
1171 }
1172
1173 return true;
1174 }
1175
IsProcessedStream(uint32_t stream_id)1176 bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
1177 auto it = std::find(processed_streams_.begin(), processed_streams_.end(), stream_id);
1178 if (it != processed_streams_.end()) {
1179 return true;
1180 }
1181 return false;
1182 }
1183
IsAllOutGraphOut(const KernelGraphPtr & graph,const CNodePtr & cnode)1184 bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNodePtr &cnode) {
1185 MS_EXCEPTION_IF_NULL(graph);
1186 MS_EXCEPTION_IF_NULL(cnode);
1187 auto cnode_out_num = AnfAlgo::GetOutputTensorNum(cnode);
1188 auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
1189 std::set<int> output_index_set;
1190 // Assign Communicate Op Memory firstly.
1191 for (const auto &node : nodes) {
1192 auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
1193 MS_EXCEPTION_IF_NULL(item_with_index.first);
1194 if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
1195 continue;
1196 }
1197 if (item_with_index.first == cnode) {
1198 output_index_set.insert(item_with_index.second);
1199 }
1200 }
1201
1202 MS_LOG(INFO) << "Node " << cnode->fullname_with_scope() << " has " << cnode_out_num
1203 << " outputs, in graph output num:" << output_index_set.size();
1204 return cnode_out_num == output_index_set.size();
1205 }
1206
FindGraphEnd(vector<CNodePtr>::iterator begin,vector<CNodePtr>::iterator end)1207 vector<CNodePtr>::iterator AscendStreamAssign::FindGraphEnd(vector<CNodePtr>::iterator begin,
1208 vector<CNodePtr>::iterator end) {
1209 while (begin != end) {
1210 if (AnfAlgo::HasNodeAttr(kAttrFpBpEnd, *begin)) {
1211 MS_LOG(INFO) << "FpBp end op is " << (*begin)->fullname_with_scope();
1212 return begin;
1213 }
1214 ++begin;
1215 }
1216 return end;
1217 }
1218
1219 // section5
InsertEventForHcomParallel(const NotNull<KernelGraphPtr> & graph_ptr)1220 void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
1221 MS_LOG(INFO) << "Start";
1222 InsertEventCommonDependHcom(graph_ptr);
1223 InsertEventHcomDependCommonBak(graph_ptr);
1224 InsertEventHcomDependHcom(graph_ptr);
1225 MS_LOG(INFO) << "End";
1226 }
1227
InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> & graph_ptr)1228 void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
1229 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1230 auto cnode_ptr_list = graph_ptr->execution_order();
1231 vector<CNodePtr> cnodes = cnode_ptr_list;
1232 uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1233 auto it = cnodes.begin();
1234 while (it != cnodes.end()) {
1235 MS_EXCEPTION_IF_NULL(*it);
1236 if (IsHcom(*it)) {
1237 auto cur_hcom_node = *it;
1238 CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
1239 it = cnodes.insert(it + 1, send_cnode_ptr);
1240
1241 auto target = FindTargetOp(it, cnodes.end(), cur_hcom_node, true);
1242 if (target == cnodes.end()) {
1243 if (IsAllOutGraphOut(graph_ptr, cur_hcom_node)) {
1244 // if hcom's all output is graph output, we need to insert send/recv to fpbp end in data sink mode
1245 target = FindGraphEnd(it, cnodes.end());
1246 }
1247
1248 if (target == cnodes.end()) {
1249 MS_EXCEPTION_IF_NULL(*(it - 1));
1250 MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
1251 << ", can't find target for insert recv op, no insert send/recv";
1252 it = cnodes.erase(it);
1253 continue;
1254 }
1255 }
1256
1257 // deal recv op
1258 uint32_t stream_id = AnfAlgo::GetStreamId(*target);
1259 CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id);
1260 (void)cnodes.insert(target, recv_cnode_ptr);
1261 cur_event_id = resource_manager.ApplyNewEvent();
1262 }
1263 ++it;
1264 }
1265 // one event allocated additional, should delete
1266 resource_manager.DeleteEvent();
1267 graph_ptr->set_execution_order(cnodes);
1268 MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
1269 }
1270
1271 // after memory reuse is correct, use this function
InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> & graph_ptr)1272 void AscendStreamAssign::InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr) {
1273 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1274 auto cnode_ptr_list = graph_ptr->execution_order();
1275 vector<CNodePtr> cnodes;
1276 CNodePtr cur_cnode_ptr = nullptr;
1277 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1278 cur_cnode_ptr = cnode_ptr_list[i];
1279 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1280 if (i == 0) {
1281 cnodes.emplace_back(cur_cnode_ptr);
1282 continue;
1283 }
1284
1285 if (!IsHcom(cur_cnode_ptr)) {
1286 cnodes.emplace_back(cur_cnode_ptr);
1287 continue;
1288 }
1289
1290 // get the input which located in the last exe orders
1291 vector<CNodePtr> inputs_cnode = GetLastInputCnode(graph_ptr, cur_cnode_ptr);
1292 if (inputs_cnode.empty()) {
1293 cnodes.emplace_back(cur_cnode_ptr);
1294 MS_LOG(WARNING) << "Hcom op:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << " can't find inputs nodes";
1295 continue;
1296 }
1297
1298 MS_LOG(INFO) << "Current hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
1299 << "; inputs cnode size:" << inputs_cnode.size();
1300
1301 for (size_t j = 0; j < inputs_cnode.size(); j++) {
1302 auto &cur_input = inputs_cnode.at(j);
1303 MS_LOG(INFO) << "The index:" << j << " input, name:" << AnfAlgo::GetCNodeName(cur_input);
1304 uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1305 auto pre_stream_id = AnfAlgo::GetStreamId(cur_input);
1306 auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id);
1307 auto it = std::find(cnodes.begin(), cnodes.end(), cur_input);
1308 if (it == cnodes.end()) {
1309 MS_LOG_EXCEPTION << "Hcom:" << AnfAlgo::GetCNodeName(cur_cnode_ptr)
1310 << " can't find input node:" << AnfAlgo::GetCNodeName(cur_input);
1311 }
1312 cnodes.insert(it + 1, send);
1313 uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1314 auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
1315 cnodes.emplace_back(recv);
1316 cnodes.emplace_back(cur_cnode_ptr);
1317 }
1318 }
1319
1320 graph_ptr->set_execution_order(cnodes);
1321 MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
1322 }
1323
GetLastInputCnode(const NotNull<KernelGraphPtr> & graph_ptr,const CNodePtr & cur_cnode_ptr)1324 vector<CNodePtr> AscendStreamAssign::GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr,
1325 const CNodePtr &cur_cnode_ptr) {
1326 auto group_name = GetHcomGroup(cur_cnode_ptr);
1327 auto input_cnodes = GetInputKernels(cur_cnode_ptr);
1328 if (input_cnodes.empty()) {
1329 return {};
1330 }
1331 // record max index node for each stream
1332 std::map<uint32_t, std::pair<CNodePtr, uint32_t>> result;
1333 for (size_t i = 0; i < input_cnodes.size(); i++) {
1334 auto &cur_input = input_cnodes.at(i);
1335 auto stream_id = AnfAlgo::GetStreamId(cur_input);
1336 auto cur_index = GetIndexByKey(graph_ptr, cur_input.get());
1337 if (cur_index == UINT32_MAX) {
1338 MS_LOG_EXCEPTION << "The input node:" << AnfAlgo::GetCNodeName(cur_input) << " is not found in graph";
1339 }
1340 auto it = result.find(stream_id);
1341 if (it == result.end()) {
1342 result[stream_id] = std::make_pair(cur_input, cur_index);
1343 } else {
1344 auto max_index = it->second.second;
1345 if (cur_index > max_index) {
1346 result[stream_id] = std::make_pair(cur_input, cur_index);
1347 }
1348 }
1349 }
1350
1351 vector<CNodePtr> final_inputs;
1352 CNodePtr max_common_cnode = nullptr;
1353 for (const auto &item : result) {
1354 if (IsHcom(item.second.first)) {
1355 auto cur_group = GetHcomGroup(item.second.first);
1356 if (cur_group == group_name) {
1357 continue;
1358 } else {
1359 final_inputs.emplace_back(item.second.first);
1360 }
1361 } else {
1362 max_common_cnode = item.second.first;
1363 }
1364 }
1365
1366 if (max_common_cnode != nullptr) {
1367 final_inputs.emplace_back(max_common_cnode);
1368 }
1369 return final_inputs;
1370 }
1371
GetInputKernels(const CNodePtr & cnode)1372 vector<CNodePtr> AscendStreamAssign::GetInputKernels(const CNodePtr &cnode) {
1373 MS_EXCEPTION_IF_NULL(cnode);
1374 vector<CNodePtr> input_cnodes;
1375 queue<CNodePtr> nop_nodes;
1376 auto inputs = cnode->inputs();
1377 for (size_t i = 1; i < inputs.size(); i++) {
1378 auto real_input = AnfAlgo::VisitKernel(inputs[i], 0);
1379 auto node = real_input.first;
1380 MS_EXCEPTION_IF_NULL(node);
1381 if (opt::IsNopNode(node)) {
1382 nop_nodes.push(node->cast<CNodePtr>());
1383 while (!nop_nodes.empty()) {
1384 auto cur_node = nop_nodes.front();
1385 nop_nodes.pop();
1386 auto new_inputs = cur_node->inputs();
1387 for (size_t j = 1; j < new_inputs.size(); j++) {
1388 auto new_real_input = AnfAlgo::VisitKernel(new_inputs[j], 0);
1389 auto new_node = new_real_input.first;
1390 MS_EXCEPTION_IF_NULL(new_node);
1391 if (opt::IsNopNode(new_node)) {
1392 nop_nodes.push(new_node->cast<CNodePtr>());
1393 } else if (new_node->isa<CNode>()) {
1394 input_cnodes.emplace_back(new_node->cast<CNodePtr>());
1395 }
1396 }
1397 }
1398 } else if (node->isa<CNode>()) {
1399 input_cnodes.emplace_back(node->cast<CNodePtr>());
1400 }
1401 }
1402 return input_cnodes;
1403 }
1404
InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> & graph_ptr)1405 void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
1406 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1407 auto cnode_ptr_list = graph_ptr->execution_order();
1408 vector<CNodePtr> cnodes;
1409 CNodePtr cur_cnode_ptr = nullptr;
1410 uint32_t pre_stream_id = UINT32_MAX;
1411 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1412 cur_cnode_ptr = cnode_ptr_list[i];
1413 uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1414 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1415 if (i == 0) {
1416 cnodes.emplace_back(cur_cnode_ptr);
1417 pre_stream_id = cur_stream_id;
1418 continue;
1419 }
1420
1421 if (!IsHcom(cur_cnode_ptr)) {
1422 cnodes.emplace_back(cur_cnode_ptr);
1423 pre_stream_id = cur_stream_id;
1424 continue;
1425 }
1426
1427 if (cur_stream_id == pre_stream_id) {
1428 cnodes.emplace_back(cur_cnode_ptr);
1429 pre_stream_id = cur_stream_id;
1430 continue;
1431 }
1432
1433 if (!IsHcom(cnode_ptr_list[i - 1])) {
1434 uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1435 auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, pre_stream_id);
1436 cnodes.emplace_back(send);
1437 auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_stream_id);
1438 cnodes.emplace_back(recv);
1439 cnodes.emplace_back(cur_cnode_ptr);
1440 } else {
1441 cnodes.emplace_back(cur_cnode_ptr);
1442 }
1443 pre_stream_id = cur_stream_id;
1444 }
1445
1446 graph_ptr->set_execution_order(cnodes);
1447 MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
1448 }
1449
GetStreamIDHcomMap(const std::vector<CNodePtr> & cnode_ptr_list,const std::string & group,size_t graph_id)1450 std::vector<std::pair<uint32_t, vector<size_t>>> AscendStreamAssign::GetStreamIDHcomMap(
1451 const std::vector<CNodePtr> &cnode_ptr_list, const std::string &group, size_t graph_id) {
1452 std::vector<std::pair<uint32_t, vector<size_t>>> stream_indices;
1453 for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
1454 auto cur_cnode = cnode_ptr_list[i];
1455 if (!IsHcom(cur_cnode)) {
1456 continue;
1457 }
1458
1459 uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
1460 auto group_name = GetHcomGroup(cur_cnode);
1461 auto cur_graph_id = AnfAlgo::GetGraphId(cur_cnode.get());
1462 MS_LOG(INFO) << "Hcom node name:" << AnfAlgo::GetCNodeName(cur_cnode) << "; group:" << group_name
1463 << "; stream id:" << cur_stream_id;
1464 if (group_name != group || cur_graph_id != graph_id) {
1465 continue;
1466 }
1467
1468 bool exit = false;
1469 for (auto &item : stream_indices) {
1470 if (item.first == cur_stream_id) {
1471 item.second.emplace_back(i);
1472 exit = true;
1473 break;
1474 }
1475 }
1476 if (!exit) {
1477 stream_indices.emplace_back(std::make_pair(cur_stream_id, std::vector<size_t>{i}));
1478 }
1479 }
1480 return stream_indices;
1481 }
1482
InsertEventHcomDependHcomAtSameGroup(const NotNull<KernelGraphPtr> & graph_ptr,std::pair<std::string,std::map<uint32_t,std::set<uint32_t>>> group_item)1483 void AscendStreamAssign::InsertEventHcomDependHcomAtSameGroup(
1484 const NotNull<KernelGraphPtr> &graph_ptr, std::pair<std::string, std::map<uint32_t, std::set<uint32_t>>> group_item) {
1485 for (const auto &graph_item : group_item.second) {
1486 auto stream_indices = GetStreamIDHcomMap(graph_ptr->execution_order(), group_item.first, graph_item.first);
1487 constexpr size_t kStreamMax = 2;
1488 if (stream_indices.size() < kStreamMax) {
1489 MS_LOG(INFO) << "Group:" << group_item.first << ", Graph: " << graph_item.first
1490 << " different stream hcom size is less than 2, no need insert event between them";
1491 continue;
1492 }
1493 InsertEventBetweenHcom(graph_ptr, stream_indices);
1494 }
1495 }
1496
InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> & graph_ptr)1497 void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
1498 if (group_hcom_graph_map_.empty()) {
1499 return;
1500 }
1501 for (const auto &group_item : group_hcom_graph_map_) {
1502 InsertEventHcomDependHcomAtSameGroup(graph_ptr, group_item);
1503 }
1504 }
1505
InsertEventBetweenHcom(const NotNull<KernelGraphPtr> & graph_ptr,const std::vector<std::pair<uint32_t,vector<size_t>>> & hcom_index)1506 void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
1507 const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index) {
1508 vector<CNodePtr> orders;
1509 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1510 auto cnode_ptr_list = graph_ptr->execution_order();
1511 uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1512 if (hcom_index.empty()) {
1513 MS_LOG(EXCEPTION) << "Hcom stream number is empty";
1514 }
1515 size_t first_stream_last_index = hcom_index[0].second.back();
1516 size_t last_stream_first_index = hcom_index.back().second.front();
1517 MS_LOG(INFO) << "First stream last index:" << first_stream_last_index
1518 << "; last stream first index:" << last_stream_first_index;
1519 std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_stream_last_index, std::back_inserter(orders));
1520 for (size_t i = first_stream_last_index; i <= last_stream_first_index; i++) {
1521 auto cur_cnode = cnode_ptr_list[i];
1522 if (!IsSatisfiedHcom(hcom_index, cur_cnode, i)) {
1523 orders.emplace_back(cur_cnode);
1524 continue;
1525 }
1526 auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
1527 if (i == first_stream_last_index) {
1528 // first fusion hcom
1529 orders.emplace_back(cur_cnode);
1530 auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1531 orders.emplace_back(send);
1532 } else if (i == last_stream_first_index) {
1533 // last fusion hcom
1534 auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1535 orders.emplace_back(recv);
1536 orders.emplace_back(cur_cnode);
1537 } else {
1538 size_t cur_stream_hcom_size = UINT32_MAX;
1539 size_t first_index = UINT32_MAX;
1540 size_t last_index = UINT32_MAX;
1541 for (const auto &item : hcom_index) {
1542 if (item.first == cur_hcom_stream_id) {
1543 cur_stream_hcom_size = item.second.size();
1544 first_index = item.second.front();
1545 last_index = item.second.back();
1546 }
1547 }
1548
1549 if (cur_stream_hcom_size == 1) {
1550 auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1551 orders.emplace_back(recv);
1552 cur_event_id = resource_manager.ApplyNewEvent();
1553 orders.emplace_back(cur_cnode);
1554 auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1555 orders.emplace_back(send);
1556 } else {
1557 // current stream, first hcom:add recv op
1558 if (i == first_index) {
1559 auto recv = CreateRecvApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1560 orders.emplace_back(recv);
1561 cur_event_id = resource_manager.ApplyNewEvent();
1562 orders.emplace_back(cur_cnode);
1563 } else if (i == last_index) {
1564 // current stream, last hcom:add send op
1565 orders.emplace_back(cur_cnode);
1566 auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
1567 orders.emplace_back(send);
1568 } else {
1569 // current stream, not first and last op
1570 orders.emplace_back(cur_cnode);
1571 }
1572 }
1573 }
1574 }
1575 std::copy(cnode_ptr_list.begin() + last_stream_first_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
1576 graph_ptr->set_execution_order(orders);
1577 }
1578
IsSatisfiedHcom(const std::vector<std::pair<uint32_t,vector<size_t>>> & hcom_index,const CNodePtr & node_ptr,size_t index)1579 bool AscendStreamAssign::IsSatisfiedHcom(const std::vector<std::pair<uint32_t, vector<size_t>>> &hcom_index,
1580 const CNodePtr &node_ptr, size_t index) {
1581 MS_EXCEPTION_IF_NULL(node_ptr);
1582 auto cur_hcom_stream_id = AnfAlgo::GetStreamId(node_ptr);
1583 for (const auto &item : hcom_index) {
1584 if (item.first == cur_hcom_stream_id) {
1585 auto it = std::find(item.second.begin(), item.second.end(), index);
1586 if (it != item.second.end()) {
1587 return true;
1588 }
1589 }
1590 }
1591 return false;
1592 }
1593
1594 // section6
InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> & graph_ptr)1595 void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
1596 MS_LOG(INFO) << "Start";
1597 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1598 auto cnode_ptr_list = graph_ptr->execution_order();
1599 vector<CNodePtr> cnodes = cnode_ptr_list;
1600 uint32_t cur_event_id = resource_manager.ApplyNewEvent();
1601 std::map<CNodePtr, CNodePtr> cnode_send_map;
1602 std::map<CNodePtr, std::vector<CNodePtr>> cnode_recv_map;
1603 auto it = cnodes.begin();
1604 while (it != cnodes.end()) {
1605 MS_EXCEPTION_IF_NULL(*it);
1606 if (AnfAlgo::IsIndependentNode(*it)) {
1607 MS_LOG(DEBUG) << "Deal independent op[" << (*it)->DebugString() << "]";
1608 CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
1609
1610 auto target = FindTargetOp(it + 1, cnodes.end(), *it, false);
1611 if (target == cnodes.end()) {
1612 MS_LOG(DEBUG) << "Independent node[" << (*it)->fullname_with_scope()
1613 << "] can't find target for insert recv op, no insert send/recv";
1614 it++;
1615 continue;
1616 }
1617
1618 // deal recv op
1619 uint32_t stream_id = AnfAlgo::GetStreamId(*target);
1620 CNodePtr recv_cnode_ptr = CreateRecvApplyKernel(graph_ptr, cur_event_id, stream_id);
1621
1622 cnode_send_map.insert(std::make_pair(*it, send_cnode_ptr));
1623 auto result = cnode_recv_map.find(*target);
1624 if (result == cnode_recv_map.end()) {
1625 std::vector<CNodePtr> recv_cnodes = {recv_cnode_ptr};
1626 cnode_recv_map.insert(std::make_pair(*target, recv_cnodes));
1627 } else {
1628 result->second.push_back(recv_cnode_ptr);
1629 }
1630 cur_event_id = resource_manager.ApplyNewEvent();
1631 }
1632 ++it;
1633 }
1634 // one event allocated additional, should delete
1635 resource_manager.DeleteEvent();
1636
1637 std::vector<CNodePtr> new_cnodes;
1638 for (const auto &cnode : cnodes) {
1639 auto result_recv = cnode_recv_map.find(cnode);
1640 if (result_recv != cnode_recv_map.end()) {
1641 for (const auto &recv : result_recv->second) {
1642 new_cnodes.push_back(recv);
1643 }
1644 }
1645 new_cnodes.push_back(cnode);
1646 auto result_send = cnode_send_map.find(cnode);
1647 if (result_send != cnode_send_map.end()) {
1648 new_cnodes.push_back(result_send->second);
1649 }
1650 }
1651
1652 graph_ptr->set_execution_order(new_cnodes);
1653 MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num();
1654 MS_LOG(INFO) << "End";
1655 }
1656
GetIndependentMaxTarget(const NotNull<KernelGraphPtr> & graph_ptr)1657 void AscendStreamAssign::GetIndependentMaxTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
1658 MS_LOG(INFO) << "Start";
1659 auto cnode_ptr_list = graph_ptr->execution_order();
1660 for (size_t i = 0; i < cnode_ptr_list.size(); i++) {
1661 auto cur_node = cnode_ptr_list[i];
1662 auto key = cur_node.get();
1663 if (!AnfAlgo::IsIndependentNode(cur_node)) {
1664 continue;
1665 }
1666
1667 bool flag = false;
1668 for (size_t j = cnode_ptr_list.size() - 1; j > i; j--) {
1669 auto target_node = cnode_ptr_list[j];
1670 auto inputs = target_node->inputs();
1671 for (size_t m = 1; m < inputs.size(); m++) {
1672 auto input = inputs[m];
1673 MS_EXCEPTION_IF_NULL(input);
1674 if (opt::IsNopNode(input)) {
1675 auto cnode = input->cast<CNodePtr>();
1676 auto new_inputs = cnode->inputs();
1677 for (size_t k = 1; k < new_inputs.size(); k++) {
1678 auto new_real_input = AnfAlgo::VisitKernel(new_inputs[k], 0);
1679 if (key == new_real_input.first.get()) {
1680 MS_LOG(DEBUG) << "Nop node find max target op:" << AnfAlgo::GetCNodeName(cur_node);
1681 independent_targets_.emplace(target_node.get());
1682 flag = true;
1683 break;
1684 }
1685 }
1686 } else {
1687 auto real_input = AnfAlgo::VisitKernel(input, 0);
1688 if (key == real_input.first.get()) {
1689 MS_LOG(DEBUG) << "Find max target op:" << AnfAlgo::GetCNodeName(cur_node);
1690 independent_targets_.emplace(target_node.get());
1691 flag = true;
1692 }
1693 }
1694 if (flag) {
1695 break;
1696 }
1697 }
1698 }
1699 }
1700
1701 MS_LOG(INFO) << "End";
1702 }
1703
GetIndexByKey(const NotNull<KernelGraphPtr> & graph_ptr,const CNodeKey & key)1704 uint32_t AscendStreamAssign::GetIndexByKey(const NotNull<KernelGraphPtr> &graph_ptr, const CNodeKey &key) {
1705 auto &exe_orders = graph_ptr->execution_order();
1706 for (uint32_t i = 0; i < exe_orders.size(); i++) {
1707 CNodeKey node_key = exe_orders[i].get();
1708 if (node_key == key) {
1709 return i;
1710 }
1711 }
1712
1713 return UINT32_MAX;
1714 }
1715
GetMaxIndexTarget(const NotNull<KernelGraphPtr> & graph_ptr)1716 uint32_t AscendStreamAssign::GetMaxIndexTarget(const NotNull<KernelGraphPtr> &graph_ptr) {
1717 if (independent_targets_.empty()) {
1718 return UINT32_MAX;
1719 }
1720
1721 std::set<uint32_t> indices;
1722 for (const auto &key : independent_targets_) {
1723 auto index = GetIndexByKey(graph_ptr, key);
1724 if (index == UINT32_MAX) {
1725 MS_LOG(EXCEPTION) << "graph has no correspond key";
1726 }
1727 indices.emplace(index);
1728 }
1729
1730 return *(std::max_element(indices.begin(), indices.end()));
1731 }
1732
GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> & graph_ptr)1733 uint32_t AscendStreamAssign::GetIndependentStreamSwitchStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
1734 auto &exe_orders = graph_ptr->execution_order();
1735 for (const auto &item : exe_orders) {
1736 if (AnfAlgo::GetCNodeName(item) == kStreamSwitchOpName) {
1737 if (!AnfAlgo::HasNodeAttr(kAttrStreamSwitchKind, item)) {
1738 continue;
1739 }
1740 auto kind = AnfAlgo::GetNodeAttr<uint32_t>(item, kAttrStreamSwitchKind);
1741 if (kind == kIndependentStreamSwitch) {
1742 return AnfAlgo::GetStreamId(item);
1743 }
1744 }
1745 }
1746 return kInvalidStreamId;
1747 }
1748
InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> & graph_ptr)1749 void AscendStreamAssign::InsertCtrlForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
1750 if (independent_targets_.empty()) {
1751 return;
1752 }
1753
1754 uint32_t independent_switch_stream = GetIndependentStreamSwitchStreamId(graph_ptr);
1755 if (independent_switch_stream == kInvalidStreamId) {
1756 return;
1757 }
1758
1759 auto max_index = GetMaxIndexTarget(graph_ptr);
1760 auto &exe_orders = graph_ptr->execution_order();
1761 if (max_index >= exe_orders.size()) {
1762 MS_LOG(EXCEPTION) << "Max target index:" << max_index << " is greater than graph orders size:" << exe_orders.size();
1763 }
1764
1765 auto max_node_stream = AnfAlgo::GetStreamId(exe_orders[max_index]);
1766
1767 CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
1768 // 1.set stream id
1769 AnfAlgo::SetStreamId(max_node_stream, active_ptr.get());
1770 // 2.set active stream ids
1771 std::vector<uint32_t> active_index_list{independent_switch_stream};
1772 AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
1773
1774 std::vector<CNodePtr> update_cnode_list;
1775 std::copy(exe_orders.begin(), exe_orders.begin() + max_index + 1, std::back_inserter(update_cnode_list));
1776 update_cnode_list.emplace_back(active_ptr);
1777 std::copy(exe_orders.begin() + max_index + 1, exe_orders.end(), std::back_inserter(update_cnode_list));
1778 graph_ptr->set_execution_order(update_cnode_list);
1779 }
1780
1781 // section7
GetNeedActiveStreams(const NotNull<KernelGraphPtr> & graph_ptr)1782 void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr) {
1783 CNodePtr cur_cnode_ptr = nullptr;
1784 auto cnode_ptr_list = graph_ptr->execution_order();
1785
1786 // 1)stream witch kStreamNeedActivedFirst attr should be activated;
1787 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1788 cur_cnode_ptr = cnode_ptr_list[i];
1789 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1790 if (!AnfAlgo::HasNodeAttr(kStreamNeedActivedFirst, cur_cnode_ptr)) {
1791 continue;
1792 }
1793
1794 auto need_active = AnfAlgo::GetNodeAttr<bool>(cur_cnode_ptr, kStreamNeedActivedFirst);
1795 if (need_active) {
1796 auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1797 MS_LOG(INFO) << "Stream id:" << stream_id << " is need activated at first";
1798 need_first_active_streams_.push_back(stream_id);
1799 }
1800 }
1801
1802 // 2)independent stream:if has not been activate, push to need active vector
1803 auto root_graph_id = graph_ptr->graph_id();
1804 if (!independent_stream_activated_) {
1805 auto it = independent_graph_map_.find(root_graph_id);
1806 if (it != independent_graph_map_.end()) {
1807 need_first_active_streams_.push_back(*(it->second.begin()));
1808 }
1809 }
1810
1811 // 3)hcom stream:if has not been activate, push to need active vector
1812 if (!hcom_stream_activated_) {
1813 for (const auto &item : group_hcom_graph_map_) {
1814 auto &hcom_graph_map = item.second;
1815 auto it = hcom_graph_map.find(root_graph_id);
1816 if (it != hcom_graph_map.end()) {
1817 std::copy(it->second.begin(), it->second.end(), std::back_inserter(need_first_active_streams_));
1818 }
1819 }
1820 }
1821
1822 // 4)first stream 0 should be activated first;
1823 auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), 0);
1824 if (it == need_first_active_streams_.end()) {
1825 need_first_active_streams_.emplace_back(0);
1826 }
1827 MS_LOG(INFO) << "Finally, need active first stream include:";
1828 for (const auto &item : need_first_active_streams_) {
1829 MS_LOG(INFO) << "stream id:" << item;
1830 }
1831 }
1832
1833 // section8
CheckResourceAssign(const NotNull<KernelGraphPtr> & graph_ptr)1834 void AscendStreamAssign::CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
1835 CheckStreamAssign(graph_ptr);
1836 CheckEventAssign(graph_ptr);
1837 }
1838
CheckStreamAssign(const NotNull<KernelGraphPtr> & graph_ptr)1839 void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
1840 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1841 std::set<uint32_t> streams;
1842 uint32_t max_stream = 0;
1843 uint32_t min_stream = kInvalidStreamId;
1844 auto cnode_ptr_list = graph_ptr->execution_order();
1845 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1846 CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
1847 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1848 uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
1849 if (stream_id == kInvalidStreamId) {
1850 MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream";
1851 }
1852
1853 (void)streams.emplace(stream_id);
1854 if (stream_id > max_stream) {
1855 max_stream = stream_id;
1856 }
1857 if (stream_id < min_stream) {
1858 min_stream = stream_id;
1859 }
1860 }
1861
1862 // check stream assign
1863 if (!streams.empty()) {
1864 if (min_stream != 0) {
1865 MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream;
1866 }
1867 uint32_t assigned_stream_num = resource_manager.get_cur_stream_num();
1868 if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) {
1869 MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream
1870 << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size();
1871 }
1872 }
1873 }
1874
CheckEventAssign(const NotNull<KernelGraphPtr> & graph_ptr)1875 void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr) {
1876 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
1877 std::map<uint32_t, std::vector<CNodePtr>> event_map;
1878 uint32_t max_event_id = 0;
1879 uint32_t min_event_id = kInvalidEventId;
1880 auto cnode_ptr_list = graph_ptr->execution_order();
1881 for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
1882 CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
1883 MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
1884 auto name = AnfAlgo::GetCNodeName(cur_cnode_ptr);
1885 if (name == kSendOpName || name == kRecvOpName) {
1886 uint32_t event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId);
1887 if (event_id > max_event_id) {
1888 max_event_id = event_id;
1889 }
1890
1891 if (event_id < min_event_id) {
1892 min_event_id = event_id;
1893 }
1894 auto it = event_map.find(event_id);
1895 if (it == event_map.end()) {
1896 event_map[event_id] = {cur_cnode_ptr};
1897 } else {
1898 event_map[event_id].emplace_back(cur_cnode_ptr);
1899 }
1900 }
1901 }
1902 // check event assign
1903 if (!event_map.empty()) {
1904 if (min_event_id != 0) {
1905 MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id;
1906 }
1907 uint32_t assigned_event_num = resource_manager.get_cur_event_num();
1908 if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) {
1909 MS_LOG(EXCEPTION) << "Event should be consecutive, however, assigned event num is: " << assigned_event_num
1910 << ", max event id:" << max_event_id << ", event map is:" << event_map;
1911 }
1912 for (const auto &item : event_map) {
1913 if (item.second.size() != 2) {
1914 MS_LOG(EXCEPTION) << "Send/recv should be in pair and share one event id, invalid event id is:" << item.first
1915 << ", event size is:" << item.second.size();
1916 }
1917 auto first_name = AnfAlgo::GetCNodeName(item.second[0]);
1918 auto second_name = AnfAlgo::GetCNodeName(item.second[1]);
1919 if (!(first_name == kSendOpName && second_name == kRecvOpName)) {
1920 MS_LOG(EXCEPTION) << "Send should be before recv, invalid event id is:" << item.first;
1921 }
1922 }
1923 }
1924 }
1925
1926 // section9
CreateSendApplyKernel(const NotNull<KernelGraphPtr> & graph_ptr,uint32_t event_id,uint32_t stream_id)1927 CNodePtr AscendStreamAssign::CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id,
1928 uint32_t stream_id) {
1929 auto send_op = std::make_shared<Primitive>(kSendOpName);
1930 MS_EXCEPTION_IF_NULL(send_op);
1931 auto send_apply = std::make_shared<ValueNode>(send_op);
1932 MS_EXCEPTION_IF_NULL(send_apply);
1933 std::vector<AnfNodePtr> send_input_list = {send_apply};
1934 CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list);
1935 MS_EXCEPTION_IF_NULL(send_node_ptr);
1936 kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
1937 selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
1938 AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
1939 AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
1940 auto abstract_none = std::make_shared<abstract::AbstractNone>();
1941 MS_EXCEPTION_IF_NULL(abstract_none);
1942 send_node_ptr->set_abstract(abstract_none);
1943 AnfAlgo::SetStreamId(stream_id, send_node_ptr.get());
1944 return send_node_ptr;
1945 }
1946
CreateRecvApplyKernel(const NotNull<KernelGraphPtr> & graph_ptr,uint32_t event_id,uint32_t stream_id)1947 CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id,
1948 uint32_t stream_id) {
1949 auto recv_op = std::make_shared<Primitive>(kRecvOpName);
1950 MS_EXCEPTION_IF_NULL(recv_op);
1951 auto recv_apply = std::make_shared<ValueNode>(recv_op);
1952 MS_EXCEPTION_IF_NULL(recv_apply);
1953 std::vector<AnfNodePtr> recv_input_list = {recv_apply};
1954 CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list);
1955 MS_EXCEPTION_IF_NULL(recv_node_ptr);
1956 kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
1957 selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
1958 AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
1959 AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
1960 AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get());
1961 auto abstract_none = std::make_shared<abstract::AbstractNone>();
1962 MS_EXCEPTION_IF_NULL(abstract_none);
1963 recv_node_ptr->set_abstract(abstract_none);
1964 return recv_node_ptr;
1965 }
1966
IsNopNodeTarget(const AnfNodePtr & nop_node,const CNodePtr & target_node,const CNodePtr & cur_node,bool exclude_hcom)1967 bool AscendStreamAssign::IsNopNodeTarget(const AnfNodePtr &nop_node, const CNodePtr &target_node,
1968 const CNodePtr &cur_node, bool exclude_hcom) {
1969 MS_EXCEPTION_IF_NULL(nop_node);
1970 auto cnode = nop_node->cast<CNodePtr>();
1971 auto new_inputs = cnode->inputs();
1972 for (size_t i = 1; i < new_inputs.size(); i++) {
1973 if (opt::IsNopNode(new_inputs[i])) {
1974 if (IsNopNodeTarget(new_inputs[i], target_node, cur_node, exclude_hcom)) {
1975 return true;
1976 }
1977 } else {
1978 auto new_real_input = AnfAlgo::VisitKernel(new_inputs[i], 0);
1979 if (target_node == new_real_input.first) {
1980 if (!(exclude_hcom && IsHcom(cur_node))) {
1981 return true;
1982 }
1983 }
1984 }
1985 }
1986 return false;
1987 }
1988
FindTargetOp(vector<CNodePtr>::iterator begin,vector<CNodePtr>::iterator end,const CNodePtr & node,bool exclude_hcom)1989 vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::iterator begin,
1990 vector<CNodePtr>::iterator end, const CNodePtr &node,
1991 bool exclude_hcom) {
1992 while (begin != end) {
1993 auto inputs = (*begin)->inputs();
1994 for (size_t i = 1; i < inputs.size(); i++) {
1995 auto input = inputs[i];
1996 MS_EXCEPTION_IF_NULL(input);
1997 if (opt::IsNopNode(input)) {
1998 if (IsNopNodeTarget(input, node, *begin, exclude_hcom)) {
1999 return begin;
2000 }
2001 } else {
2002 auto real_input = AnfAlgo::VisitKernel(input, 0);
2003 if (node == real_input.first) {
2004 if (!(exclude_hcom && IsHcom(*begin))) {
2005 MS_LOG(DEBUG) << "Nop node find target op[" << (*begin)->DebugString() << "]";
2006 return begin;
2007 }
2008 }
2009 }
2010 }
2011 ++begin;
2012 }
2013 return end;
2014 }
2015
IsTaskSink()2016 bool AscendStreamAssign::IsTaskSink() {
2017 auto ms_context = MsContext::GetInstance();
2018 MS_EXCEPTION_IF_NULL(ms_context);
2019 if (!ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
2020 MS_LOG(INFO) << "Task sink mode is not enable";
2021 return false;
2022 } else {
2023 MS_LOG(INFO) << "Task sink mode is enable";
2024 return true;
2025 }
2026 }
2027
GetWaitStreams(vector<uint32_t> * wait_active_stream_list)2028 void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
2029 MS_EXCEPTION_IF_NULL(wait_active_stream_list);
2030 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
2031 uint32_t total_stream_num = resource_manager.get_cur_stream_num();
2032 if (total_stream_num == 0) {
2033 MS_LOG(INFO) << "The total_common_stream_num is zero";
2034 return;
2035 }
2036
2037 // common stream:active first common stream
2038 for (uint32_t i = 0; i < total_stream_num; i++) {
2039 auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i);
2040 if (it == need_first_active_streams_.end()) {
2041 MS_LOG(INFO) << "Wait common stream id = " << i;
2042 wait_active_stream_list->push_back(i);
2043 }
2044 }
2045 }
2046
IsHcom(const CNodePtr & apply_kernel)2047 bool AscendStreamAssign::IsHcom(const CNodePtr &apply_kernel) {
2048 MS_EXCEPTION_IF_NULL(apply_kernel);
2049 return AnfAlgo::GetKernelType(apply_kernel) == HCCL_KERNEL;
2050 }
2051
GetHcomStreams(std::vector<uint32_t> * streams)2052 void AscendStreamAssign::GetHcomStreams(std::vector<uint32_t> *streams) {
2053 MS_EXCEPTION_IF_NULL(streams);
2054 for (const auto &item : hcom_stream_map_) {
2055 streams->emplace_back(item.first);
2056 }
2057 }
2058
Reset()2059 void AscendStreamAssign::Reset() {
2060 independent_stream_activated_ = false;
2061 hcom_stream_activated_ = false;
2062 loop_sink_ = false;
2063 independent_stream_map_.clear();
2064 hcom_stream_map_.clear();
2065 common_stream_map_.clear();
2066 processed_streams_.clear();
2067 need_first_active_streams_.clear();
2068 stream_groups_.clear();
2069 stream_relations_.clear();
2070 event_map_.clear();
2071 independent_targets_.clear();
2072 independent_graph_map_.clear();
2073 group_hcom_graph_map_.clear();
2074 middle_active_streams_.clear();
2075 }
2076
2077 // section 10
IsVecExist(const std::vector<uint32_t> & group)2078 bool AscendStreamAssign::IsVecExist(const std::vector<uint32_t> &group) {
2079 auto group_size = group.size();
2080 if (group_size == 0) {
2081 return false;
2082 }
2083 for (const auto &item : stream_groups_) {
2084 if (item.size() < group.size()) {
2085 continue;
2086 }
2087
2088 bool flag = true;
2089 for (size_t i = 0; i < group_size; i++) {
2090 if (item[i] != group.at(i)) {
2091 flag = false;
2092 break;
2093 }
2094 }
2095
2096 if (flag) {
2097 return true;
2098 } else {
2099 continue;
2100 }
2101 }
2102
2103 return false;
2104 }
2105
DFS(uint32_t start,std::vector<uint32_t> * group)2106 void AscendStreamAssign::DFS(uint32_t start, std::vector<uint32_t> *group) {
2107 MS_EXCEPTION_IF_NULL(group);
2108 auto it = stream_relations_.find(start);
2109 if (it == stream_relations_.end()) {
2110 if (!IsVecExist(*group)) {
2111 stream_groups_.emplace_back(*group);
2112 } else {
2113 MS_LOG(WARNING) << "DFS find same stream group, Not expected";
2114 }
2115 return;
2116 }
2117
2118 vector<uint32_t> active_streams = stream_relations_[start];
2119
2120 for (const auto &item : active_streams) {
2121 group->emplace_back(item);
2122 DFS(item, group);
2123 group->pop_back();
2124 }
2125 }
2126
GetStreamRelations()2127 void AscendStreamAssign::GetStreamRelations() {
2128 auto starts = middle_active_streams_;
2129 for (const auto &stream : need_first_active_streams_) {
2130 starts.emplace(stream);
2131 }
2132
2133 for (const auto &start : starts) {
2134 vector<uint32_t> group{start};
2135 DFS(start, &group);
2136 }
2137 }
2138
FindStreamRelations(const NotNull<KernelGraphPtr> & graph_ptr)2139 void AscendStreamAssign::FindStreamRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
2140 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
2141 auto stream_num = resource_manager.get_cur_stream_num();
2142 if (stream_num <= 1) {
2143 return;
2144 }
2145
2146 auto exe_orders = graph_ptr->execution_order();
2147 for (size_t i = 0; i < exe_orders.size(); i++) {
2148 auto cur_cnode = exe_orders[i];
2149 auto name = AnfAlgo::GetCNodeName(cur_cnode);
2150 if (name != kStreamSwitchOpName && name != kStreamActiveOpName) {
2151 continue;
2152 }
2153
2154 // support:streamswitch is begin of the stream
2155 if (name == kStreamSwitchOpName) {
2156 GetStreamSwitchStreamRelation(cur_cnode);
2157 }
2158
2159 if (name == kStreamActiveOpName) {
2160 GetStreamActiveStreamRelation(graph_ptr, i);
2161 }
2162 }
2163 }
2164
GetStreamSwitchStreamRelation(const CNodePtr & node_ptr)2165 void AscendStreamAssign::GetStreamSwitchStreamRelation(const CNodePtr &node_ptr) {
2166 MS_EXCEPTION_IF_NULL(node_ptr);
2167 auto cur_stream_id = AnfAlgo::GetStreamId(node_ptr);
2168 auto true_stream_id = AnfAlgo::GetNodeAttr<uint32_t>(node_ptr, kAttrTrueBranchStream);
2169 if (true_stream_id <= cur_stream_id) {
2170 MS_LOG(ERROR) << "StreamSwitch self stream id " << cur_stream_id
2171 << " is greater than true branch stream id:" << true_stream_id;
2172 }
2173 auto it = stream_relations_.find(cur_stream_id);
2174 if (it == stream_relations_.end()) {
2175 stream_relations_[cur_stream_id] = {true_stream_id};
2176 } else {
2177 auto iter =
2178 std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), true_stream_id);
2179 if (iter == stream_relations_[cur_stream_id].end()) {
2180 stream_relations_[cur_stream_id].emplace_back(true_stream_id);
2181 }
2182 }
2183 }
2184
GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> & graph_ptr,size_t index)2185 void AscendStreamAssign::GetStreamActiveStreamRelation(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
2186 StreamActiveKind kind = GetStreamActiveKind(graph_ptr, index);
2187 if (kind == kInvalid) {
2188 MS_LOG(INFO) << "Invalid streamActive kind";
2189 return;
2190 }
2191
2192 auto orders = graph_ptr->execution_order();
2193 if (index >= orders.size()) {
2194 MS_LOG(EXCEPTION) << "Invalid index.";
2195 }
2196 auto cur_cnode = orders[index];
2197 auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
2198 auto active_list = AnfAlgo::GetNodeAttr<vector<uint32_t>>(cur_cnode, kAttrActiveStreamList);
2199 if (kind == kHead) {
2200 uint32_t active_current_stream_id = GetStreamByActivedStream(cur_stream_id);
2201 if (active_current_stream_id == kInvalidStreamId) {
2202 MS_LOG(EXCEPTION) << "No stream to active streamactive stream: " << cur_stream_id;
2203 }
2204
2205 for (const auto &item : active_list) {
2206 if (item <= active_current_stream_id) {
2207 MS_LOG(WARNING) << "Activated stream is less than activing stream";
2208 continue;
2209 }
2210 auto it = std::find(stream_relations_[active_current_stream_id].begin(),
2211 stream_relations_[active_current_stream_id].end(), item);
2212 if (it == stream_relations_[active_current_stream_id].end()) {
2213 stream_relations_[active_current_stream_id].emplace_back(item);
2214 }
2215 }
2216 }
2217
2218 if (kind == kMiddle) {
2219 for (const auto &stream : active_list) {
2220 if (stream <= cur_stream_id) {
2221 MS_LOG(INFO) << "MIDDLE StreamActive active stream is less than self stream, no need deal";
2222 } else {
2223 MS_LOG(INFO) << "MIDDLE StreamActive :" << cur_stream_id << ", active target stream:" << stream;
2224 middle_active_streams_.emplace(stream);
2225 }
2226 }
2227 }
2228
2229 if (kind == kTail) {
2230 auto it = stream_relations_.find(cur_stream_id);
2231 if (it == stream_relations_.end()) {
2232 stream_relations_[cur_stream_id] = active_list;
2233 } else {
2234 for (const auto &stream : active_list) {
2235 if (stream <= cur_stream_id) {
2236 MS_LOG(WARNING) << "Activated stream is less than activing stream";
2237 continue;
2238 }
2239 auto iter = std::find(stream_relations_[cur_stream_id].begin(), stream_relations_[cur_stream_id].end(), stream);
2240 if (iter == stream_relations_[cur_stream_id].end()) {
2241 stream_relations_[cur_stream_id].emplace_back(stream);
2242 }
2243 }
2244 }
2245 }
2246 }
2247
GetStreamActiveKind(const NotNull<KernelGraphPtr> & graph_ptr,size_t index)2248 StreamActiveKind AscendStreamAssign::GetStreamActiveKind(const NotNull<KernelGraphPtr> &graph_ptr, size_t index) {
2249 auto exe_orders = graph_ptr->execution_order();
2250 if (index >= exe_orders.size()) {
2251 MS_LOG(EXCEPTION) << "Invalid op index:" << index;
2252 }
2253
2254 auto cur_cnode = exe_orders[index];
2255 auto cur_stream_id = AnfAlgo::GetStreamId(cur_cnode);
2256 if (AnfAlgo::GetCNodeName(cur_cnode) != kStreamActiveOpName) {
2257 MS_LOG(EXCEPTION) << "Current node name [" << AnfAlgo::GetCNodeName(cur_cnode) << "] is not StreamActive.";
2258 }
2259
2260 if (index == 0) {
2261 return kInvalid;
2262 }
2263
2264 if (index == exe_orders.size() - 1) {
2265 return kInvalid;
2266 }
2267
2268 uint32_t pre_stream_id = UINT32_MAX;
2269 uint32_t next_stream_id = UINT32_MAX;
2270 int32_t start = SizeToInt(index) - 1;
2271 for (int32_t i = start; i >= 0; i--) {
2272 auto cnode = exe_orders[IntToSize(i)];
2273 auto name = AnfAlgo::GetCNodeName(cnode);
2274 if (name == kSendOpName || name == kRecvOpName) {
2275 continue;
2276 }
2277 auto stream = AnfAlgo::GetStreamId(cnode);
2278 auto it = hcom_stream_map_.find(stream);
2279 if (it != hcom_stream_map_.end()) {
2280 continue;
2281 }
2282
2283 it = independent_stream_map_.find(stream);
2284 if (it != independent_stream_map_.end()) {
2285 continue;
2286 }
2287
2288 pre_stream_id = stream;
2289 break;
2290 }
2291
2292 for (size_t i = index + 1; i < exe_orders.size(); i++) {
2293 auto cnode = exe_orders[i];
2294 if (AnfAlgo::GetCNodeName(cnode) == kSendOpName || AnfAlgo::GetCNodeName(cnode) == kRecvOpName) {
2295 continue;
2296 }
2297
2298 auto stream = AnfAlgo::GetStreamId(cnode);
2299 auto it = hcom_stream_map_.find(stream);
2300 if (it != hcom_stream_map_.end()) {
2301 continue;
2302 }
2303
2304 it = independent_stream_map_.find(stream);
2305 if (it != independent_stream_map_.end()) {
2306 continue;
2307 }
2308
2309 next_stream_id = stream;
2310 break;
2311 }
2312
2313 return GetStreamKind(cur_stream_id, pre_stream_id, next_stream_id);
2314 }
2315
GetStreamByActivedStream(uint32_t actived_stream_id)2316 uint32_t AscendStreamAssign::GetStreamByActivedStream(uint32_t actived_stream_id) {
2317 if (stream_relations_.empty()) {
2318 return kInvalidStreamId;
2319 }
2320
2321 for (const auto &item : stream_relations_) {
2322 auto it = std::find(item.second.begin(), item.second.end(), actived_stream_id);
2323 if (it != item.second.end()) {
2324 return item.first;
2325 }
2326 }
2327
2328 return kInvalidStreamId;
2329 }
2330
PrintStreamRelations()2331 void AscendStreamAssign::PrintStreamRelations() {
2332 MS_LOG(INFO) << "Stream relations size:" << stream_relations_.size();
2333 for (const auto &item : stream_relations_) {
2334 MS_LOG(INFO) << "Stream:" << item.first;
2335 for (const auto &stream : item.second) {
2336 MS_LOG(INFO) << "--activated stream id:" << stream;
2337 }
2338 }
2339 }
2340
PrintStreamGroups()2341 void AscendStreamAssign::PrintStreamGroups() {
2342 MS_LOG(INFO) << "Stream group size:" << stream_groups_.size();
2343 for (const auto &item : stream_groups_) {
2344 MS_LOG(INFO) << "Group:";
2345 for (const auto &stream : item) {
2346 MS_LOG(INFO) << "Stream id:" << stream;
2347 }
2348 }
2349 }
2350
2351 // section 11
IsSatisfiedEvent(uint32_t send_stream_id,uint32_t recv_stream_id) const2352 bool AscendStreamAssign::IsSatisfiedEvent(uint32_t send_stream_id, uint32_t recv_stream_id) const {
2353 size_t send_group = 0;
2354 size_t recv_group = 0;
2355 bool send_flag = true;
2356 bool recv_flag = true;
2357 for (size_t i = 0; i < stream_groups_.size(); i++) {
2358 auto group = stream_groups_[i];
2359 if (send_flag) {
2360 auto it = std::find(group.begin(), group.end(), send_stream_id);
2361 if (it != group.end()) {
2362 send_group = i;
2363 send_flag = false;
2364 }
2365 }
2366
2367 if (recv_flag) {
2368 auto it = std::find(group.begin(), group.end(), recv_stream_id);
2369 if (it != group.end()) {
2370 recv_group = i;
2371 recv_flag = false;
2372 }
2373 }
2374 }
2375
2376 if (!(send_flag || recv_flag)) {
2377 return (send_group != recv_group);
2378 }
2379
2380 return false;
2381 }
2382
FindEventRelations(const NotNull<KernelGraphPtr> & graph_ptr)2383 void AscendStreamAssign::FindEventRelations(const NotNull<KernelGraphPtr> &graph_ptr) {
2384 AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
2385 auto event_nums = resource_manager.get_cur_event_num();
2386 if (event_nums == 0) {
2387 return;
2388 }
2389 auto exe_orders = graph_ptr->execution_order();
2390 // find all event info
2391 for (size_t i = 0; i < exe_orders.size(); i++) {
2392 auto cur_cnode = exe_orders[i];
2393 auto name = AnfAlgo::GetCNodeName(cur_cnode);
2394 if (name == kSendOpName) {
2395 event_map_[cur_cnode] = {};
2396 }
2397
2398 if (name == kRecvOpName) {
2399 auto recv_event_id = AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode, kAttrEventId);
2400 for (auto &item : event_map_) {
2401 auto send_event_id = AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
2402 if (recv_event_id == send_event_id) {
2403 item.second = cur_cnode;
2404 break;
2405 }
2406 }
2407 }
2408 }
2409
2410 // delete useless event info
2411 auto begin = event_map_.begin();
2412 while (begin != event_map_.end()) {
2413 auto send_stream_id = AnfAlgo::GetStreamId(begin->first);
2414 auto recv_stream_id = AnfAlgo::GetStreamId(begin->second);
2415 bool flag = IsSatisfiedEvent(send_stream_id, recv_stream_id);
2416 if (!flag) {
2417 begin = event_map_.erase(begin);
2418 } else {
2419 ++begin;
2420 }
2421 }
2422
2423 MS_LOG(INFO) << "Satisfied event info";
2424 for (const auto &item : event_map_) {
2425 MS_LOG(INFO) << "Event_id:" << AnfAlgo::GetNodeAttr<uint32_t>(item.first, kAttrEventId);
2426 }
2427 }
2428
2429 // section12
AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> & graph_ptr)2430 void AscendStreamAssign::AdjustAtomicAddrCleanOrder(const NotNull<KernelGraphPtr> &graph_ptr) {
2431 // Eg:[atomic, recv, memcpy] should be [recv, atomic, memcpy]
2432 std::vector<CNodePtr> update_orders;
2433 auto &exe_orders = graph_ptr->execution_order();
2434 size_t i = 0;
2435 while (i < exe_orders.size()) {
2436 auto cur_cnode = exe_orders.at(i);
2437 if (AnfAlgo::GetCNodeName(cur_cnode) != kAtomicAddrCleanOpName) {
2438 update_orders.emplace_back(cur_cnode);
2439 i++;
2440 continue;
2441 }
2442 while (i < exe_orders.size() - 1) {
2443 i++;
2444 auto next_cnode = exe_orders.at(i);
2445 auto next_cnode_name = AnfAlgo::GetCNodeName(next_cnode);
2446 if (next_cnode_name == kSendOpName || next_cnode_name == kRecvOpName) {
2447 update_orders.emplace_back(next_cnode);
2448 } else {
2449 update_orders.emplace_back(cur_cnode);
2450 break;
2451 }
2452 }
2453 }
2454 graph_ptr->set_execution_order(update_orders);
2455 }
2456 } // namespace ascend
2457 } // namespace device
2458 } // namespace mindspore
2459