• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/kernel_adjust.h"
18 
19 #include <map>
20 #include <algorithm>
21 #include <string>
22 #include <vector>
23 #include <utility>
24 
25 #include "backend/session/anf_runtime_algorithm.h"
26 #include "utils/ms_context.h"
27 #include "common/trans.h"
28 #include "utils/config_manager.h"
29 #include "utils/ms_utils.h"
30 #include "backend/kernel_compiler/kernel_build_info.h"
31 #include "utils/utils.h"
32 #include "runtime/device/ascend/profiling/profiling_manager.h"
33 #include "runtime/base.h"
34 #include "runtime/device/ascend/ascend_stream_assign.h"
35 #include "utils/shape_utils.h"
36 
namespace {
constexpr auto kGradients = "Gradients";
constexpr auto kSpecifyParameter = "accu_status";
// Declared constexpr like its neighbours: it is a fixed file-local constant, and a mutable namespace-scope global
// would allow accidental modification from anywhere in this translation unit.
constexpr size_t kNPUShape = 8;
constexpr size_t kLastHandleDiff = 2;
}  // namespace
43 namespace mindspore {
44 namespace device {
45 #ifndef ENABLE_SECURITY
46 using device::ascend::ProfilingUtils;
47 #endif
ReorderGetNext(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)48 void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
49   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
50   const std::vector<CNodePtr> &origin_cnode_list = kernel_graph_ptr->execution_order();
51   std::vector<CNodePtr> getnext_list;
52   std::vector<CNodePtr> other_list;
53   for (const auto &cnode : origin_cnode_list) {
54     if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
55       getnext_list.emplace_back(cnode);
56     } else {
57       other_list.emplace_back(cnode);
58     }
59   }
60   std::vector<CNodePtr> new_order_list;
61   new_order_list.insert(new_order_list.end(), getnext_list.begin(), getnext_list.end());
62   new_order_list.insert(new_order_list.end(), other_list.begin(), other_list.end());
63   kernel_graph_ptr->set_execution_order(new_order_list);
64 }
65 
NeedInsertSwitch()66 bool KernelAdjust::NeedInsertSwitch() {
67   auto context_ptr = MsContext::GetInstance();
68   MS_EXCEPTION_IF_NULL(context_ptr);
69   return (context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK) &&
70           context_ptr->get_param<bool>(MS_CTX_ENABLE_LOOP_SINK) && ConfigManager::GetInstance().iter_num() > 1);
71 }
72 
CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> & graph_ptr,uint32_t event_id)73 CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
74                                              uint32_t event_id) {
75   MS_EXCEPTION_IF_NULL(graph_ptr);
76   auto send_op = std::make_shared<Primitive>(kSendOpName);
77   MS_EXCEPTION_IF_NULL(send_op);
78   auto send_apply = std::make_shared<ValueNode>(send_op);
79   MS_EXCEPTION_IF_NULL(send_apply);
80   std::vector<AnfNodePtr> send_input_list = {send_apply};
81   CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list);
82   MS_EXCEPTION_IF_NULL(send_node_ptr);
83   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
84   selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
85   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
86   AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
87   auto abstract_none = std::make_shared<abstract::AbstractNone>();
88   MS_EXCEPTION_IF_NULL(abstract_none);
89   send_node_ptr->set_abstract(abstract_none);
90   return send_node_ptr;
91 }
92 
CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> & graph_ptr,uint32_t event_id)93 CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
94                                              uint32_t event_id) {
95   MS_EXCEPTION_IF_NULL(graph_ptr);
96   auto recv_op = std::make_shared<Primitive>(kRecvOpName);
97   MS_EXCEPTION_IF_NULL(recv_op);
98   auto recv_apply = std::make_shared<ValueNode>(recv_op);
99   MS_EXCEPTION_IF_NULL(recv_apply);
100   std::vector<AnfNodePtr> recv_input_list = {recv_apply};
101   CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list);
102   MS_EXCEPTION_IF_NULL(recv_node_ptr);
103   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
104   selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
105   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
106   AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
107   auto abstract_none = std::make_shared<abstract::AbstractNone>();
108   MS_EXCEPTION_IF_NULL(abstract_none);
109   recv_node_ptr->set_abstract(abstract_none);
110   return recv_node_ptr;
111 }
112 
ExistGetNext(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)113 bool KernelAdjust::ExistGetNext(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
114   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
115   const std::vector<CNodePtr> &cnode_list = kernel_graph_ptr->execution_order();
116   for (const auto &cnode : cnode_list) {
117     if (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName) {
118       return true;
119     }
120   }
121   return false;
122 }
123 
ExistIndependent(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)124 bool KernelAdjust::ExistIndependent(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
125   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
126   const auto &exe_orders = kernel_graph_ptr->execution_order();
127   for (const auto &node : exe_orders) {
128     if (AnfAlgo::IsIndependentNode(node) && AnfAlgo::GetGraphId(node.get()) == kernel_graph_ptr->graph_id()) {
129       MS_LOG(INFO) << "graph exit independent node";
130       return true;
131     }
132   }
133 
134   return false;
135 }
136 
InsertIndepentParallel(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input,std::vector<CNodePtr> * exec_order)137 void KernelAdjust::InsertIndepentParallel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
138                                           const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
139                                           std::vector<CNodePtr> *exec_order) {
140   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
141   MS_EXCEPTION_IF_NULL(exec_order);
142   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
143   CNodePtr independent_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kIndependentStreamSwitch);
144   MS_EXCEPTION_IF_NULL(independent_switch_app);
145   uint32_t independent_switch_stream_id = resource_manager.ApplyNewStream();
146   AnfAlgo::SetStreamId(independent_switch_stream_id, independent_switch_app.get());
147   AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), independent_switch_app);
148   AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kIndependentStreamSwitch), independent_switch_app);
149   (*exec_order).push_back(independent_switch_app);
150   MS_LOG(INFO) << "Independent op loop insert Stream Switch " << independent_switch_app->fullname_with_scope();
151 }
152 
InsertFpBpLoopStreamSwitch(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input,std::vector<CNodePtr> * exec_order,uint32_t * fpbp_stream_id,uint32_t * fpbp_switch_stream_id)153 void KernelAdjust::InsertFpBpLoopStreamSwitch(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
154                                               const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
155                                               std::vector<CNodePtr> *exec_order, uint32_t *fpbp_stream_id,
156                                               uint32_t *fpbp_switch_stream_id) {
157   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
158   MS_EXCEPTION_IF_NULL(exec_order);
159   MS_EXCEPTION_IF_NULL(fpbp_stream_id);
160   MS_EXCEPTION_IF_NULL(fpbp_switch_stream_id);
161   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
162   *fpbp_switch_stream_id = resource_manager.ApplyNewStream();
163   *fpbp_stream_id = resource_manager.ApplyNewStream();
164   CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kFpBpStreamSwitch);
165   MS_EXCEPTION_IF_NULL(fpbp_switch_app);
166   AnfAlgo::SetStreamId(*fpbp_switch_stream_id, fpbp_switch_app.get());
167   AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
168   // update fpbp loop stream switch true_branch_stream attr
169   AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(*fpbp_stream_id), fpbp_switch_app);
170   AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kFpBpStreamSwitch), fpbp_switch_app);
171   (*exec_order).push_back(fpbp_switch_app);
172   MS_LOG(INFO) << "FpBp loop insert Stream Switch " << fpbp_switch_app->fullname_with_scope();
173 }
174 
CopyMemcpyList(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::vector<CNodePtr> & orders,size_t order_index,std::vector<CNodePtr> * memcpy_list,std::vector<CNodePtr> * other_list)175 void KernelAdjust::CopyMemcpyList(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
176                                   const std::vector<CNodePtr> &orders, size_t order_index,
177                                   std::vector<CNodePtr> *memcpy_list, std::vector<CNodePtr> *other_list) {
178   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
179   MS_EXCEPTION_IF_NULL(memcpy_list);
180   MS_EXCEPTION_IF_NULL(other_list);
181   CNodePtr cur_cnode = nullptr;
182   for (size_t idx = order_index + 1; idx < orders.size(); idx++) {
183     cur_cnode = orders[idx];
184     if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) {
185       auto pre_node = orders[idx - 1];
186       auto pre_kernel_name = AnfAlgo::GetCNodeName(pre_node);
187       if (pre_kernel_name == kAtomicAddrCleanOpName) {
188         (*other_list).pop_back();
189         (*memcpy_list).push_back(pre_node);
190       }
191       (*memcpy_list).emplace_back(cur_cnode);
192     } else {
193       (*other_list).emplace_back(cur_cnode);
194     }
195   }
196 }
197 
InsertEosDoneRecv(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t eos_done_event_id,uint32_t fpbp_stream_id)198 void KernelAdjust::InsertEosDoneRecv(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
199                                      std::vector<CNodePtr> *exec_order, uint32_t eos_done_event_id,
200                                      uint32_t fpbp_stream_id) {
201   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
202   MS_EXCEPTION_IF_NULL(exec_order);
203   CNodePtr eos_done_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_done_event_id);
204   AnfAlgo::SetStreamId(fpbp_stream_id, eos_done_recv.get());
205   (*exec_order).push_back(eos_done_recv);
206   MS_LOG(INFO) << "FpBp loop insert EoS done Recv " << eos_done_recv->fullname_with_scope();
207 }
208 
InsertGetNextLoopStreamActive(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,const std::vector<uint32_t> & getnext_active_streams)209 void KernelAdjust::InsertGetNextLoopStreamActive(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
210                                                  std::vector<CNodePtr> *exec_order,
211                                                  const std::vector<uint32_t> &getnext_active_streams) {
212   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
213   MS_EXCEPTION_IF_NULL(exec_order);
214   CNodePtr getnext_active_app = CreateStreamActiveOp(kernel_graph_ptr);
215   MS_EXCEPTION_IF_NULL(getnext_active_app);
216   AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(getnext_active_streams),
217                        getnext_active_app);
218   (*exec_order).push_back(getnext_active_app);
219   MS_LOG(INFO) << "FpBp loop insert GetNext loop Stream Active " << getnext_active_app->fullname_with_scope();
220 }
221 
InsertFpBpStartRecv(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t fpbp_start_event_id,uint32_t fpbp_stream_id)222 void KernelAdjust::InsertFpBpStartRecv(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
223                                        std::vector<CNodePtr> *exec_order, uint32_t fpbp_start_event_id,
224                                        uint32_t fpbp_stream_id) {
225   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
226   MS_EXCEPTION_IF_NULL(exec_order);
227   CNodePtr fpbp_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, fpbp_start_event_id);
228   AnfAlgo::SetStreamId(fpbp_stream_id, fpbp_start_recv.get());
229   (*exec_order).push_back(fpbp_start_recv);
230   MS_LOG(INFO) << "FpBp loop insert FpBp start Recv " << fpbp_start_recv->fullname_with_scope();
231 }
232 
InsertNextLoopAssignAdd(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input,uint32_t fpbp_stream_id)233 void KernelAdjust::InsertNextLoopAssignAdd(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
234                                            std::vector<CNodePtr> *exec_order,
235                                            const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
236                                            uint32_t fpbp_stream_id) {
237   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
238   MS_EXCEPTION_IF_NULL(exec_order);
239   CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, false);
240   MS_EXCEPTION_IF_NULL(assign_add_one);
241   AnfAlgo::SetStreamId(fpbp_stream_id, assign_add_one.get());
242   (*exec_order).push_back(assign_add_one);
243   MS_LOG(INFO) << "FpBp loop insert next loop AssignAdd " << assign_add_one->fullname_with_scope();
244 }
245 
InsertCurrentLoopAssignAdd(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input)246 void KernelAdjust::InsertCurrentLoopAssignAdd(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
247                                               std::vector<CNodePtr> *exec_order,
248                                               const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
249   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
250   MS_EXCEPTION_IF_NULL(exec_order);
251   CNodePtr cur_assign_add = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input, true);
252   MS_EXCEPTION_IF_NULL(cur_assign_add);
253   AnfAlgo::SetNodeAttr(kAttrFpBpEnd, MakeValue<bool>(true), cur_assign_add);
254   (*exec_order).push_back(cur_assign_add);
255   MS_LOG(INFO) << "FpBp loop insert current loop AssignAdd " << cur_assign_add->fullname_with_scope();
256 }
257 
InsertFpBpAndEosLoopStreamActive(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,const std::vector<uint32_t> & fpbp_active_streams)258 void KernelAdjust::InsertFpBpAndEosLoopStreamActive(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
259                                                     std::vector<CNodePtr> *exec_order,
260                                                     const std::vector<uint32_t> &fpbp_active_streams) {
261   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
262   MS_EXCEPTION_IF_NULL(exec_order);
263   CNodePtr fpbp_active_app = CreateStreamActiveOp(kernel_graph_ptr);
264   MS_EXCEPTION_IF_NULL(fpbp_active_app);
265   AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(fpbp_active_streams), fpbp_active_app);
266   (*exec_order).push_back(fpbp_active_app);
267   MS_LOG(INFO) << "FpBp loop insert FpBp loop and Eos loop Stream Active " << fpbp_active_app->fullname_with_scope();
268 }
269 
InsertSwitchLoopInput(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input)270 void KernelAdjust::InsertSwitchLoopInput(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
271                                          const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
272   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
273   std::vector<AnfNodePtr> *mute_inputs = kernel_graph_ptr->MutableInputs();
274   MS_EXCEPTION_IF_NULL(mute_inputs);
275   mute_inputs->push_back(switch_loop_input.at(kCurLoopCountParamName));
276   mute_inputs->push_back(switch_loop_input.at(kNextLoopCountParamName));
277   mute_inputs->push_back(switch_loop_input.at(kEpochParamName));
278   mute_inputs->push_back(switch_loop_input.at(kIterLoopParamName));
279   mute_inputs->push_back(switch_loop_input.at(kOneParamName));
280   for (const auto &input : kernel_graph_ptr->inputs()) {
281     MS_EXCEPTION_IF_NULL(input);
282     if (input->isa<Parameter>()) {
283       ParameterPtr param_ptr = input->cast<ParameterPtr>();
284       if (param_ptr == nullptr) {
285         MS_EXCEPTION(NotSupportError) << "Cast to parameter point failed !";
286       }
287     }
288   }
289 }
290 
InsertGetNextLoopStreamSwitch(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t * getnext_switch_stream_id,uint32_t * getnext_stream_id,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input)291 void KernelAdjust::InsertGetNextLoopStreamSwitch(
292   const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, std::vector<CNodePtr> *exec_order,
293   uint32_t *getnext_switch_stream_id, uint32_t *getnext_stream_id,
294   const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input) {
295   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
296   MS_EXCEPTION_IF_NULL(exec_order);
297   MS_EXCEPTION_IF_NULL(getnext_switch_stream_id);
298   MS_EXCEPTION_IF_NULL(getnext_stream_id);
299   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
300   *getnext_switch_stream_id = resource_manager.ApplyNewStream();
301   *getnext_stream_id = resource_manager.ApplyNewStream();
302   CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kGetNextStreamSwitch);
303   MS_EXCEPTION_IF_NULL(getnext_switch_app);
304   AnfAlgo::SetStreamId(*getnext_switch_stream_id, getnext_switch_app.get());
305   // update getnext loop stream switch true_branch_stream attr
306   AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), getnext_switch_app);
307   AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(*getnext_stream_id), getnext_switch_app);
308   AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kGetNextStreamSwitch), getnext_switch_app);
309   (*exec_order).push_back(getnext_switch_app);
310   MS_LOG(INFO) << "GetNext loop insert Stream Switch " << getnext_switch_app->fullname_with_scope();
311 }
312 
SetBeforeGetNextStreamID(std::vector<CNodePtr> * exec_order,const std::vector<CNodePtr> & orders,size_t * order_index,CNodePtr getnext_cnode,uint32_t getnext_stream_id)313 void KernelAdjust::SetBeforeGetNextStreamID(std::vector<CNodePtr> *exec_order, const std::vector<CNodePtr> &orders,
314                                             size_t *order_index, CNodePtr getnext_cnode, uint32_t getnext_stream_id) {
315   MS_EXCEPTION_IF_NULL(exec_order);
316   MS_EXCEPTION_IF_NULL(order_index);
317   for (; *order_index < orders.size(); (*order_index)++) {
318     auto node = orders[*order_index];
319     (*exec_order).push_back(node);
320     AnfAlgo::SetStreamId(getnext_stream_id, (*exec_order)[(*exec_order).size() - 1].get());
321     if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
322       getnext_cnode = node;
323       break;
324     }
325   }
326 }
327 
InsertGetNextLoopFpBpStartSend(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t * fpbp_start_event_id,uint32_t getnext_stream_id)328 void KernelAdjust::InsertGetNextLoopFpBpStartSend(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
329                                                   std::vector<CNodePtr> *exec_order, uint32_t *fpbp_start_event_id,
330                                                   uint32_t getnext_stream_id) {
331   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
332   MS_EXCEPTION_IF_NULL(exec_order);
333   MS_EXCEPTION_IF_NULL(fpbp_start_event_id);
334   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
335   *fpbp_start_event_id = resource_manager.ApplyNewEvent();
336   CNodePtr fpbp_start_send = CreateSendApplyKernel(kernel_graph_ptr, *fpbp_start_event_id);
337   AnfAlgo::SetStreamId(getnext_stream_id, fpbp_start_send.get());
338   (*exec_order).push_back(fpbp_start_send);
339   MS_LOG(INFO) << "GetNext loop insert FpBp start Send " << fpbp_start_send->fullname_with_scope();
340 }
341 
InsertGetNextLoopEosStartSend(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t * eos_start_event_id,uint32_t getnext_stream_id)342 void KernelAdjust::InsertGetNextLoopEosStartSend(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
343                                                  std::vector<CNodePtr> *exec_order, uint32_t *eos_start_event_id,
344                                                  uint32_t getnext_stream_id) {
345   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
346   MS_EXCEPTION_IF_NULL(exec_order);
347   MS_EXCEPTION_IF_NULL(eos_start_event_id);
348   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
349   *eos_start_event_id = resource_manager.ApplyNewEvent();
350   CNodePtr eos_start_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_start_event_id);
351   AnfAlgo::SetStreamId(getnext_stream_id, eos_start_send.get());
352   (*exec_order).push_back(eos_start_send);
353   MS_LOG(INFO) << "GetNext loop insert EoS start Send " << eos_start_send->fullname_with_scope();
354 }
355 
InsertEosStreamSwitch(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input,std::vector<CNodePtr> * exec_order,uint32_t * eos_switch_stream_id,uint32_t * eos_stream_id)356 void KernelAdjust::InsertEosStreamSwitch(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
357                                          const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
358                                          std::vector<CNodePtr> *exec_order, uint32_t *eos_switch_stream_id,
359                                          uint32_t *eos_stream_id) {
360   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
361   MS_EXCEPTION_IF_NULL(exec_order);
362   MS_EXCEPTION_IF_NULL(eos_switch_stream_id);
363   MS_EXCEPTION_IF_NULL(eos_stream_id);
364   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
365   *eos_switch_stream_id = resource_manager.ApplyNewStream();
366   *eos_stream_id = resource_manager.ApplyNewStream();
367   CNodePtr eos_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input, kEosStreamSwitch);
368   MS_EXCEPTION_IF_NULL(eos_switch_app);
369   AnfAlgo::SetStreamId(*eos_switch_stream_id, eos_switch_app.get());
370   AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), eos_switch_app);
371   // update eos loop stream switch true_branch_stream attr
372   AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(*eos_stream_id), eos_switch_app);
373   AnfAlgo::SetNodeAttr(kAttrStreamSwitchKind, MakeValue<uint32_t>(kEosStreamSwitch), eos_switch_app);
374   (*exec_order).push_back(eos_switch_app);
375   MS_LOG(INFO) << "EoS loop insert Stream Switch " << eos_switch_app->fullname_with_scope();
376 }
377 
InsertGetNextLoopEosStartRecv(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t eos_start_event_id,uint32_t eos_stream_id)378 void KernelAdjust::InsertGetNextLoopEosStartRecv(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
379                                                  std::vector<CNodePtr> *exec_order, uint32_t eos_start_event_id,
380                                                  uint32_t eos_stream_id) {
381   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
382   MS_EXCEPTION_IF_NULL(exec_order);
383   CNodePtr eos_start_recv = CreateRecvApplyKernel(kernel_graph_ptr, eos_start_event_id);
384   AnfAlgo::SetStreamId(eos_stream_id, eos_start_recv.get());
385   (*exec_order).push_back(eos_start_recv);
386   MS_LOG(INFO) << "EoS loop insert EoS Recv " << eos_start_recv->fullname_with_scope();
387 }
388 
InsertEosOp(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,const CNodePtr & getnext_cnode,uint32_t eos_stream_id)389 void KernelAdjust::InsertEosOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
390                                std::vector<CNodePtr> *exec_order, const CNodePtr &getnext_cnode,
391                                uint32_t eos_stream_id) {
392   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
393   MS_EXCEPTION_IF_NULL(exec_order);
394   MS_EXCEPTION_IF_NULL(getnext_cnode);
395   CNodePtr end_of_sequence_op = CreateEndOfSequenceOP(kernel_graph_ptr, getnext_cnode);
396   MS_EXCEPTION_IF_NULL(end_of_sequence_op);
397   AnfAlgo::SetStreamId(eos_stream_id, end_of_sequence_op.get());
398   (*exec_order).push_back(end_of_sequence_op);
399   MS_LOG(INFO) << "EoS loop insert Eos Op " << end_of_sequence_op->fullname_with_scope();
400 }
401 
InsertEosDoneSend(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::vector<CNodePtr> * exec_order,uint32_t * eos_done_event_id,uint32_t eos_stream_id)402 void KernelAdjust::InsertEosDoneSend(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
403                                      std::vector<CNodePtr> *exec_order, uint32_t *eos_done_event_id,
404                                      uint32_t eos_stream_id) {
405   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
406   MS_EXCEPTION_IF_NULL(exec_order);
407   MS_EXCEPTION_IF_NULL(eos_done_event_id);
408   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
409   *eos_done_event_id = resource_manager.ApplyNewEvent();
410   CNodePtr eos_done_send = CreateSendApplyKernel(kernel_graph_ptr, *eos_done_event_id);
411   AnfAlgo::SetStreamId(eos_stream_id, eos_done_send.get());
412   (*exec_order).push_back(eos_done_send);
413   MS_LOG(INFO) << "EoS loop insert EoS done Send " << eos_done_send->fullname_with_scope();
414 }
415 
InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)416 void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
417   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
418   device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
419   resource_manager.ResetResource();
420   if (!NeedInsertSwitch()) {
421     return;
422   }
423   if (kernel_graph_ptr->is_dynamic_shape()) {
424     MS_LOG(INFO) << "KernelGraph:" << kernel_graph_ptr->graph_id() << " is dynamic shape, skip InsertSwitchLoop";
425     return;
426   }
427   bool exist_getnext = ExistGetNext(kernel_graph_ptr);
428   bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX && exist_getnext;
429   MS_LOG(INFO) << "GetNext exist:" << exist_getnext << " End of Sequence mode:" << eos_mode
430                << " iter num:" << ConfigManager::GetInstance().iter_num();
431   if (exist_getnext) {
432     ReorderGetNext(kernel_graph_ptr);
433   }
434   std::map<std::string, mindspore::ParameterPtr> switch_loop_input;
435   CreateSwitchOpParameters(kernel_graph_ptr, &switch_loop_input);
436   InsertSwitchLoopInput(kernel_graph_ptr, switch_loop_input);
437 
438   const std::vector<CNodePtr> &orders = kernel_graph_ptr->execution_order();
439   if (orders.empty()) {
440     MS_LOG(EXCEPTION) << "graph " << kernel_graph_ptr->graph_id() << " execution order is empty";
441   }
442 
443   std::vector<CNodePtr> exec_order;
444   CNodePtr getnext_cnode;
445   uint32_t getnext_switch_stream_id = UINT32_MAX;
446   uint32_t fpbp_start_event_id = UINT32_MAX;
447   uint32_t eos_start_event_id = UINT32_MAX;
448   uint32_t getnext_stream_id = UINT32_MAX;
449   size_t order_index = 0;
450 
451   if (exist_getnext) {
452     InsertGetNextLoopStreamSwitch(kernel_graph_ptr, &exec_order, &getnext_switch_stream_id, &getnext_stream_id,
453                                   switch_loop_input);
454     SetBeforeGetNextStreamID(&exec_order, orders, &order_index, getnext_cnode, getnext_stream_id);
455     InsertGetNextLoopFpBpStartSend(kernel_graph_ptr, &exec_order, &fpbp_start_event_id, getnext_stream_id);
456     if (eos_mode) {
457       InsertGetNextLoopEosStartSend(kernel_graph_ptr, &exec_order, &eos_start_event_id, getnext_stream_id);
458     }
459   }
460 
461   uint32_t eos_switch_stream_id = UINT32_MAX;
462   uint32_t eos_stream_id = UINT32_MAX;
463   uint32_t eos_done_event_id = UINT32_MAX;
464   std::vector<uint32_t> fpbp_active_streams;
465   if (eos_mode) {
466     InsertEosStreamSwitch(kernel_graph_ptr, switch_loop_input, &exec_order, &eos_switch_stream_id, &eos_stream_id);
467     InsertGetNextLoopEosStartRecv(kernel_graph_ptr, &exec_order, eos_start_event_id, eos_stream_id);
468     InsertEosOp(kernel_graph_ptr, &exec_order, getnext_cnode, eos_stream_id);
469     InsertEosDoneSend(kernel_graph_ptr, &exec_order, &eos_done_event_id, eos_stream_id);
470     fpbp_active_streams.push_back(eos_switch_stream_id);
471   }
472 
473   bool exist_independent = ExistIndependent(kernel_graph_ptr);
474   if (exist_independent) {
475     InsertIndepentParallel(kernel_graph_ptr, switch_loop_input, &exec_order);
476   }
477 
478   uint32_t fpbp_stream_id = UINT32_MAX;
479   uint32_t fpbp_switch_stream_id = UINT32_MAX;
480   InsertFpBpLoopStreamSwitch(kernel_graph_ptr, switch_loop_input, &exec_order, &fpbp_stream_id, &fpbp_switch_stream_id);
481 
482   if (exist_getnext) {
483     InsertFpBpStartRecv(kernel_graph_ptr, &exec_order, fpbp_start_event_id, fpbp_stream_id);
484   }
485   InsertNextLoopAssignAdd(kernel_graph_ptr, &exec_order, switch_loop_input, fpbp_stream_id);
486 
487   std::vector<CNodePtr> memcpy_list;
488   std::vector<CNodePtr> other_list;
489   if (exist_getnext) {
490     CopyMemcpyList(kernel_graph_ptr, orders, order_index, &memcpy_list, &other_list);
491     (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
492   } else {
493     other_list = orders;
494   }
495 
496   if (eos_mode) {
497     InsertEosDoneRecv(kernel_graph_ptr, &exec_order, eos_done_event_id, fpbp_stream_id);
498   }
499   std::vector<uint32_t> getnext_active_streams;
500   if (exist_getnext) {
501     // small loop active
502     getnext_active_streams.push_back(getnext_switch_stream_id);
503     InsertGetNextLoopStreamActive(kernel_graph_ptr, &exec_order, getnext_active_streams);
504   }
505 
506   (void)std::copy(other_list.begin(), other_list.end(), std::back_inserter(exec_order));
507   InsertCurrentLoopAssignAdd(kernel_graph_ptr, &exec_order, switch_loop_input);
508   // big loop active
509   fpbp_active_streams.push_back(fpbp_switch_stream_id);
510   InsertFpBpAndEosLoopStreamActive(kernel_graph_ptr, &exec_order, fpbp_active_streams);
511   kernel_graph_ptr->set_execution_order(exec_order);
512 }
513 
CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,std::map<std::string,mindspore::ParameterPtr> * switch_loop_input)514 void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
515                                             std::map<std::string, mindspore::ParameterPtr> *switch_loop_input) {
516   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
517   MS_EXCEPTION_IF_NULL(switch_loop_input);
518   ShapeVector shp = {1};
519   tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
520   MS_EXCEPTION_IF_NULL(tensor_ptr);
521   mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract();
522   if (paremeter_abstract_ptr == nullptr) {
523     MS_LOG(EXCEPTION) << "create abstract before insert switch op failed!";
524   }
525 
526   ParameterPtr cur_loop_count = std::make_shared<Parameter>(kernel_graph_ptr);
527   MS_EXCEPTION_IF_NULL(cur_loop_count);
528   cur_loop_count->set_name(kCurLoopCountParamName);
529   cur_loop_count->set_abstract(paremeter_abstract_ptr);
530   ParameterPtr loop_count_cur = kernel_graph_ptr->NewParameter(cur_loop_count);
531   (*switch_loop_input)[kCurLoopCountParamName] = loop_count_cur;
532 
533   ParameterPtr next_loop_count = std::make_shared<Parameter>(kernel_graph_ptr);
534   MS_EXCEPTION_IF_NULL(next_loop_count);
535   next_loop_count->set_name(kNextLoopCountParamName);
536   next_loop_count->set_abstract(paremeter_abstract_ptr);
537   ParameterPtr loop_count_next = kernel_graph_ptr->NewParameter(next_loop_count);
538   (*switch_loop_input)[kNextLoopCountParamName] = loop_count_next;
539 
540   ParameterPtr iter_loop = std::make_shared<Parameter>(kernel_graph_ptr);
541   iter_loop->set_name(kIterLoopParamName);
542   iter_loop->set_abstract(paremeter_abstract_ptr);
543   ParameterPtr iter_loop_new = kernel_graph_ptr->NewParameter(iter_loop);
544   (*switch_loop_input)[kIterLoopParamName] = iter_loop_new;
545 
546   ParameterPtr one = std::make_shared<Parameter>(kernel_graph_ptr);
547   one->set_name(kOneParamName);
548   one->set_abstract(paremeter_abstract_ptr);
549   ParameterPtr one_new = kernel_graph_ptr->NewParameter(one);
550   (*switch_loop_input)[kOneParamName] = one_new;
551 
552   ParameterPtr epoch = std::make_shared<Parameter>(kernel_graph_ptr);
553   MS_EXCEPTION_IF_NULL(epoch);
554   epoch->set_name(kEpochParamName);
555   epoch->set_abstract(paremeter_abstract_ptr);
556   ParameterPtr epoch_new = kernel_graph_ptr->NewParameter(epoch);
557   (*switch_loop_input)[kEpochParamName] = epoch_new;
558 }
559 
CreateMngKernelBuilder(const std::vector<std::string> & formats,const std::vector<TypeId> & type_ids)560 kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBuilder(
561   const std::vector<std::string> &formats, const std::vector<TypeId> &type_ids) {
562   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
563   selected_kernel_builder.SetInputsFormat(formats);
564   selected_kernel_builder.SetInputsDeviceType(type_ids);
565 
566   selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
567   selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
568   selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
569   return selected_kernel_builder;
570 }
571 
CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input,StreamSwitchKind kind)572 CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
573                                             const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
574                                             StreamSwitchKind kind) {
575   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
576     {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
577   auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
578   auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
579   std::vector<AnfNodePtr> inputs;
580   inputs.push_back(NewValueNode(stream_switch));
581   if (kind == kFpBpStreamSwitch || kind == kEosStreamSwitch) {
582     inputs.push_back(switch_loop_input.at(kNextLoopCountParamName));
583   } else if (kind == kGetNextStreamSwitch || kind == kIndependentStreamSwitch) {
584     inputs.push_back(switch_loop_input.at(kNextLoopCountParamName));
585   } else {
586     MS_LOG(ERROR) << "unknown stream switch kind: " << kind;
587   }
588 
589   inputs.push_back(switch_loop_input.at(kIterLoopParamName));
590   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
591   CNodePtr stream_switch_app = kernel_graph_ptr->NewCNode(inputs);
592   MS_EXCEPTION_IF_NULL(stream_switch_app);
593   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_switch_app.get());
594   stream_switch_app->set_abstract(typeNone_abstract);
595   // set attr: cond_ RT_LESS
596   int condition = static_cast<int>(RT_LESS);
597   ValuePtr cond = MakeValue(condition);
598   AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app);
599   // set attr:data_type
600   int data_type = static_cast<int>(RT_SWITCH_INT64);
601   ValuePtr dt = MakeValue(data_type);
602   AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app);
603   // set distinction label and graph id
604   return stream_switch_app;
605 }
606 
CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)607 CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
608   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
609     {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
610   abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>();
611   auto stream_active_others = std::make_shared<Primitive>(kStreamActiveOpName);
612   std::vector<AnfNodePtr> inputs;
613   inputs.push_back(NewValueNode(stream_active_others));
614   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
615   CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs);
616   MS_EXCEPTION_IF_NULL(stream_active_others_app);
617   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get());
618   stream_active_others_app->set_abstract(typeNone_abstract);
619   return stream_active_others_app;
620 }
621 
CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const CNodePtr & node,size_t output_idx)622 CNodePtr KernelAdjust::CreatTupleGetItemNode(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
623                                              const CNodePtr &node, size_t output_idx) {
624   auto idx = NewValueNode(SizeToLong(output_idx));
625   MS_EXCEPTION_IF_NULL(idx);
626   auto imm = std::make_shared<Int64Imm>(SizeToInt(output_idx));
627   auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
628   idx->set_abstract(abstract_scalar);
629   CNodePtr tuple_getitem = kernel_graph_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx});
630   MS_EXCEPTION_IF_NULL(tuple_getitem);
631   tuple_getitem->set_scope(node->scope());
632   std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
633   TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
634   AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
635   return tuple_getitem;
636 }
637 
CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const CNodePtr & getnext_cnode)638 CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
639                                              const CNodePtr &getnext_cnode) {
640   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
641   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
642   selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
643   selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8});
644 
645   selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
646   selected_kernel_builder.SetProcessor(kernel::Processor::AICPU);
647   selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL);
648 
649   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
650   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeUInt8});
651   // EndOfSequence
652   auto end_of_sequence = std::make_shared<Primitive>(kEndOfSequence);
653   std::vector<AnfNodePtr> inputs;
654   inputs.push_back(NewValueNode(end_of_sequence));
655   // GetNext output 0 is EndOfSequence's input
656   auto tuple_get_item = CreatTupleGetItemNode(kernel_graph_ptr, getnext_cnode, 0);
657   inputs.push_back(tuple_get_item);
658   CNodePtr end_of_sequence_node = kernel_graph_ptr->NewCNode(inputs);
659   MS_EXCEPTION_IF_NULL(end_of_sequence_node);
660   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), end_of_sequence_node.get());
661   std::vector<std::string> input_names = {"x"};
662   ValuePtr input_names_v = MakeValue(input_names);
663   AnfAlgo::SetNodeAttr("input_names", input_names_v, end_of_sequence_node);
664   std::vector<std::string> output_names = {"y"};
665   ValuePtr output_names_v = MakeValue(output_names);
666   AnfAlgo::SetNodeAttr("output_names", output_names_v, end_of_sequence_node);
667   end_of_sequence_node->set_abstract(tuple_get_item->abstract());
668   return end_of_sequence_node;
669 }
670 
CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const std::map<std::string,mindspore::ParameterPtr> & switch_loop_input,bool cur_loop)671 CNodePtr KernelAdjust::CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
672                                                 const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input,
673                                                 bool cur_loop) {
674   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
675   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
676     {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
677   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
678   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
679   // AssignAdd
680   auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
681   std::vector<AnfNodePtr> inputs;
682   inputs.push_back(NewValueNode(assign_add));
683   if (cur_loop) {
684     inputs.push_back(switch_loop_input.at(kCurLoopCountParamName));
685   } else {
686     inputs.push_back(switch_loop_input.at(kNextLoopCountParamName));
687   }
688 
689   inputs.push_back(switch_loop_input.at(kOneParamName));
690   CNodePtr assign_add_one = kernel_graph_ptr->NewCNode(inputs);
691   MS_EXCEPTION_IF_NULL(assign_add_one);
692   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_one.get());
693   std::vector<std::string> input_names = {"ref", "value"};
694   std::vector<std::string> output_names = {"output"};
695   ValuePtr input_names_v = MakeValue(input_names);
696   ValuePtr output_names_v = MakeValue(output_names);
697   AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_one);
698   AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_one);
699   selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
700   MS_EXCEPTION_IF_NULL(switch_loop_input.at(kCurLoopCountParamName));
701   assign_add_one->set_abstract(switch_loop_input.at(kCurLoopCountParamName)->abstract());
702   // add AssignAdd op to kernel ref node map
703   session::AnfWithOutIndex final_pair = std::make_pair(assign_add_one, 0);
704   session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(assign_add_one, 0), 0);
705   kernel_graph_ptr->AddRefCorrespondPairs(final_pair, kernel_with_index);
706   return assign_add_one;
707 }
708 
StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)709 bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
710   auto &dump_json_parser = DumpJsonParser::GetInstance();
711   bool sink_mode = (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE || kernel_graph_ptr->IsDatasetGraph());
712   if (!sink_mode && dump_json_parser.async_dump_enabled()) {
713     InitCtrlInputs(kernel_graph_ptr);
714     return true;
715   }
716   if (!NeedInsertSwitch()) {
717     return true;
718   }
719   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
720   if (kernel_graph_ptr->is_dynamic_shape()) {
721     MS_LOG(INFO) << "Skip StepLoadCtrlInputs";
722     return true;
723   }
724   auto input_nodes = kernel_graph_ptr->inputs();
725   std::vector<tensor::TensorPtr> inputs;
726   LoadSwitchInputs(&inputs);
727   std::shared_ptr<std::vector<tensor::TensorPtr>> inputsPtr = std::make_shared<std::vector<tensor::TensorPtr>>(inputs);
728   kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr);
729   size_t input_ctrl_size = inputs.size();
730   // inputs_node:include four ctrl nodes in the back. such as:conv,loop_cnt, ites_loop, zero, one.
731   // deal four ctrl nodes.
732   for (size_t i = 0; i < inputs.size(); ++i) {
733     auto tensor = inputs[i];
734     MS_EXCEPTION_IF_NULL(tensor);
735     size_t deal_index = input_nodes.size() - input_ctrl_size + i;
736     if (deal_index >= input_nodes.size()) {
737       MS_LOG(EXCEPTION) << "deal_index[" << deal_index << "] out of range";
738     }
739     auto input_node = input_nodes[deal_index];
740     bool need_sync = false;
741     MS_EXCEPTION_IF_NULL(input_node);
742     if (input_node->isa<Parameter>()) {
743       auto pk_node = input_node->cast<ParameterPtr>();
744       MS_EXCEPTION_IF_NULL(pk_node);
745       if (tensor->NeedSyncHostToDevice() || !pk_node->has_default()) {
746         need_sync = true;
747       }
748     }
749     if (need_sync) {
750       auto pk_node = input_node->cast<ParameterPtr>();
751       MS_EXCEPTION_IF_NULL(pk_node);
752       auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
753       MS_EXCEPTION_IF_NULL(device_address);
754       tensor->set_device_address(device_address);
755       if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
756                                             LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
757                                             tensor->device_info().host_format_)) {
758         MS_LOG(INFO) << "SyncHostToDevice failed.";
759         return false;
760       }
761     }
762     tensor->set_sync_status(kNoNeedSync);
763   }
764   return true;
765 }
766 
LoadSwitchInputs(std::vector<tensor::TensorPtr> * inputs)767 void KernelAdjust::LoadSwitchInputs(std::vector<tensor::TensorPtr> *inputs) {
768   MS_LOG(INFO) << "---------------- LoadSwitchInputs---";
769   MS_EXCEPTION_IF_NULL(inputs);
770   // current loop count
771   ShapeVector shp = {1};
772   tensor::TensorPtr cur_loop_count = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
773   MS_EXCEPTION_IF_NULL(cur_loop_count);
774   int32_t *val = nullptr;
775   val = static_cast<int32_t *>(cur_loop_count->data_c());
776   MS_EXCEPTION_IF_NULL(val);
777   *val = 0;
778   inputs->push_back(cur_loop_count);
779 
780   // next loop count
781   tensor::TensorPtr next_loop_count = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
782   MS_EXCEPTION_IF_NULL(next_loop_count);
783   val = static_cast<int32_t *>(next_loop_count->data_c());
784   MS_EXCEPTION_IF_NULL(val);
785   *val = 0;
786   inputs->push_back(next_loop_count);
787 
788   // Epoch in device
789   tensor::TensorPtr epoch_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
790   MS_EXCEPTION_IF_NULL(epoch_tensor);
791   val = static_cast<int32_t *>(epoch_tensor->data_c());
792   MS_EXCEPTION_IF_NULL(val);
793   *val = 0;
794   inputs->push_back(epoch_tensor);
795 
796   // total loop count per iter
797   tensor::TensorPtr iter_loop_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
798   MS_EXCEPTION_IF_NULL(iter_loop_tensor);
799   val = static_cast<int32_t *>(iter_loop_tensor->data_c());
800   MS_EXCEPTION_IF_NULL(val);
801   if (ConfigManager::GetInstance().dataset_mode() == DS_NORMAL_MODE) {
802     MS_LOG(INFO) << "iter_loop_tensor not used in dataset_mode DS_NORMAL_MODE";
803     *val = 0;
804   } else {
805     *val = SizeToInt(LongToSize(ConfigManager::GetInstance().iter_num()));
806   }
807   MS_LOG(INFO) << "iter_loop_tensor = " << *val;
808   inputs->push_back(iter_loop_tensor);
809 
810   tensor::TensorPtr one_tensor = std::make_shared<tensor::Tensor>(kInt32->type_id(), shp);
811   MS_EXCEPTION_IF_NULL(one_tensor);
812   val = static_cast<int32_t *>(one_tensor->data_c());
813   MS_EXCEPTION_IF_NULL(val);
814   *val = 1;
815   inputs->push_back(one_tensor);
816 
817   MS_LOG(INFO) << "---------------- LoadSwitchInputs End--";
818 }
819 
InitCtrlInputs(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)820 void KernelAdjust::InitCtrlInputs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
821   MS_LOG(INFO) << " -------------------------- InitCtrlInputs Start-- ";
822   std::vector<tensor::TensorPtr> inputs;
823   // prepare default values for CtrlInputs
824   LoadSwitchInputs(&inputs);
825   std::shared_ptr<std::vector<tensor::TensorPtr>> inputsPtr = std::make_shared<std::vector<tensor::TensorPtr>>(inputs);
826   kernel_graph_ptr->set_input_ctrl_tensors(inputsPtr);
827   for (size_t i = 0; i < inputs.size(); ++i) {
828     auto tensor = inputs[i];
829     MS_EXCEPTION_IF_NULL(tensor);
830     device::DeviceAddressPtr device_address = std::make_shared<device::ascend::AscendDeviceAddress>(
831       nullptr, LongToSize(tensor->data().nbytes()), tensor->device_info().host_format_, tensor->data_type());
832     auto ms_context = MsContext::GetInstance();
833     MS_EXCEPTION_IF_NULL(ms_context);
834     auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
835     auto runtime_instance = KernelRuntimeManager::Instance().GetSingleKernelRuntime(kAscendDevice, device_id);
836     if (runtime_instance->MallocMem(kStaticMem, LongToSize(tensor->data().nbytes()), device_address) == nullptr) {
837       MS_LOG(EXCEPTION) << "Cannot alloc address when flag is : " << kStaticMem
838                         << " , tensor size is : " << tensor->data().nbytes();
839     }
840     MS_EXCEPTION_IF_NULL(device_address);
841     tensor->set_device_address(device_address);
842     if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
843                                           tensor->data_c(), tensor->device_info().host_format_)) {
844       MS_LOG(EXCEPTION) << "SyncHostToDevice failed for InitCtrlInputs.";
845     }
846   }
847   MS_LOG(INFO) << " ------------------------- InitCtrlInputs End--";
848 }
849 
850 #ifndef ENABLE_SECURITY
Profiling(NotNull<session::KernelGraph * > kernel_graph_ptr)851 void KernelAdjust::Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr) {
852   if (!ascend::ProfilingManager::GetInstance().IsProfiling()) {
853     MS_LOG(INFO) << "No need to profiling";
854     return;
855   }
856   ProfilingTraceInfo profiling_trace_info = ProfilingUtils::GenerateProfilingTrace(*kernel_graph_ptr);
857   if (!profiling_trace_info.IsValid()) {
858     MS_LOG(INFO) << "[profiling] no profiling node found!";
859     return;
860   }
861   InsertProfilingKernel(profiling_trace_info, kernel_graph_ptr);
862 }
863 
InsertProfilingKernel(const ProfilingTraceInfo & profiling_trace_info,NotNull<session::KernelGraph * > kernel_graph_ptr)864 void KernelAdjust::InsertProfilingKernel(const ProfilingTraceInfo &profiling_trace_info,
865                                          NotNull<session::KernelGraph *> kernel_graph_ptr) {
866   MS_LOG(INFO) << "[profiling] Insert profiling kernel start";
867   if (!profiling_trace_info.IsValid()) {
868     MS_LOG(WARNING) << "Profiling trace point not found";
869     return;
870   }
871   std::vector<CNodePtr> new_cnode_list;
872   std::vector<CNodePtr> cnode_ptr_list = kernel_graph_ptr->execution_order();
873   if (cnode_ptr_list.empty()) {
874     MS_LOG(ERROR) << "No CNode in graph " << kernel_graph_ptr->graph_id();
875     return;
876   }
877   for (const auto &cnode_ptr : cnode_ptr_list) {
878     ProfilingUtils::InsertProfilingTraceFp(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
879                                            NOT_NULL(&new_cnode_list));
880     new_cnode_list.emplace_back(cnode_ptr);
881     ProfilingUtils::InsertProfilingCustomOp(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
882                                             NOT_NULL(&new_cnode_list));
883     ProfilingUtils::InsertProfilingTraceBpEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
884                                               NOT_NULL(&new_cnode_list));
885     ProfilingUtils::InsertProfilingTraceIterEnd(cnode_ptr, profiling_trace_info, kernel_graph_ptr,
886                                                 NOT_NULL(&new_cnode_list));
887   }
888   kernel_graph_ptr->set_execution_order(new_cnode_list);
889 }
890 #endif
891 
CreateNPUGetFloatStatus(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const CNodePtr & npu_alloc_cnode)892 CNodePtr KernelAdjust::CreateNPUGetFloatStatus(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
893                                                const CNodePtr &npu_alloc_cnode) {
894   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
895   MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
896   auto npu_get_primitive = std::make_shared<Primitive>(kNPUGetFloatStatusOpName);
897   std::vector<AnfNodePtr> npu_get_inputs = {NewValueNode(npu_get_primitive), npu_alloc_cnode};
898   auto npu_get_cnode = kernel_graph_ptr->NewCNode(npu_get_inputs);
899   MS_EXCEPTION_IF_NULL(npu_get_cnode);
900   npu_alloc_cnode->set_scope(kDefaultScope);
901   npu_get_cnode->set_abstract(npu_alloc_cnode->abstract());
902 
903   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
904   selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
905   selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
906   selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
907   selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
908   selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
909   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
910   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
911   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), npu_get_cnode.get());
912   return npu_get_cnode;
913 }
914 
CreateNPUClearStatus(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const CNodePtr & npu_alloc_cnode)915 CNodePtr KernelAdjust::CreateNPUClearStatus(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
916                                             const CNodePtr &npu_alloc_cnode) {
917   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
918   MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
919   auto npu_clear_primitive = std::make_shared<Primitive>(kNPUClearFloatStatusOpName);
920   std::vector<AnfNodePtr> npu_clear_inputs = {NewValueNode(npu_clear_primitive), npu_alloc_cnode};
921   auto npu_clear_cnode = kernel_graph_ptr->NewCNode(npu_clear_inputs);
922   MS_EXCEPTION_IF_NULL(npu_clear_cnode);
923   npu_alloc_cnode->set_scope(kDefaultScope);
924   npu_clear_cnode->set_abstract(npu_alloc_cnode->abstract());
925 
926   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
927   selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
928   selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
929   selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
930   selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
931   selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
932   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
933   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
934   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), npu_clear_cnode.get());
935 
936   return npu_clear_cnode;
937 }
938 
CreateNPUAllocStatus(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)939 CNodePtr KernelAdjust::CreateNPUAllocStatus(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
940   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
941   // create npu_alloc_cnode
942   auto npu_alloc_primitive = std::make_shared<Primitive>(kNPUAllocFloatStatusOpName);
943   std::vector<AnfNodePtr> npu_alloc_inputs = {NewValueNode(npu_alloc_primitive)};
944   auto npu_alloc_cnode = kernel_graph_ptr->NewCNode(npu_alloc_inputs);
945   MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
946   npu_alloc_cnode->set_scope(kDefaultScope);
947   std::vector<size_t> npu_output_shape = {kNPUShape};
948   AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {npu_output_shape}, npu_alloc_cnode.get());
949 
950   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
951   selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
952   selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
953   selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
954   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
955   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
956   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), npu_alloc_cnode.get());
957   return npu_alloc_cnode;
958 }
959 
CreateAssignAdd(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const CNodePtr & npu_alloc_cnode,const AnfNodePtr & specify_para)960 CNodePtr KernelAdjust::CreateAssignAdd(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
961                                        const CNodePtr &npu_alloc_cnode, const AnfNodePtr &specify_para) {
962   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
963   MS_EXCEPTION_IF_NULL(npu_alloc_cnode);
964   MS_EXCEPTION_IF_NULL(specify_para);
965   auto assign_add_primitive = std::make_shared<Primitive>(kAssignAddOpName);
966   std::vector<AnfNodePtr> assign_add_inputs = {NewValueNode(assign_add_primitive), specify_para, npu_alloc_cnode};
967   auto assign_add_cnode = kernel_graph_ptr->NewCNode(assign_add_inputs);
968   MS_EXCEPTION_IF_NULL(assign_add_cnode);
969   assign_add_cnode->set_scope(kDefaultScope);
970   assign_add_cnode->set_abstract(specify_para->abstract());
971 
972   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
973     {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat32});
974   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
975   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
976 
977   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_add_cnode.get());
978   std::vector<std::string> input_names = {"ref", "value"};
979   std::vector<std::string> output_names = {"output"};
980   ValuePtr input_names_v = MakeValue(input_names);
981   ValuePtr output_names_v = MakeValue(output_names);
982   AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_add_cnode);
983   AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_add_cnode);
984   selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
985 
986   session::AnfWithOutIndex final_pair = std::make_pair(assign_add_cnode, 0);
987   session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(assign_add_cnode, 0), 0);
988   kernel_graph_ptr->AddRefCorrespondPairs(final_pair, kernel_with_index);
989   return assign_add_cnode;
990 }
991 
CreateAssign(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr,const AnfNodePtr & specify_para)992 CNodePtr KernelAdjust::CreateAssign(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
993                                     const AnfNodePtr &specify_para) {
994   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
995   MS_EXCEPTION_IF_NULL(specify_para);
996 
997   std::vector<float> reset(kNPUShape, 0.0);
998   ShapeVector reset_shape({static_cast<int64_t>(kNPUShape)});
999   auto shp_buf_size = sizeof(float) * reset.size();
1000   auto reset_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, reset_shape, reset.data(), shp_buf_size);
1001   auto reset_value_node = std::make_shared<ValueNode>(reset_tensor);
1002   MS_EXCEPTION_IF_NULL(reset_value_node);
1003   reset_value_node->set_abstract(specify_para->abstract());
1004   kernel_graph_ptr->AddValueNodeToGraph(reset_value_node);
1005 
1006   auto kernel_info = std::make_shared<device::KernelInfo>();
1007   MS_EXCEPTION_IF_NULL(kernel_info);
1008   reset_value_node->set_kernel_info(kernel_info);
1009   kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
1010   builder1.SetOutputsFormat({kOpFormat_DEFAULT});
1011   builder1.SetOutputsDeviceType({kNumberTypeFloat32});
1012   AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), reset_value_node.get());
1013 
1014   auto assign_primitive = std::make_shared<Primitive>(kAssignOpName);
1015   std::vector<AnfNodePtr> assign_inputs = {NewValueNode(assign_primitive), specify_para, reset_value_node};
1016   auto assign_cnode = kernel_graph_ptr->NewCNode(assign_inputs);
1017   MS_EXCEPTION_IF_NULL(assign_cnode);
1018   assign_cnode->set_scope(kDefaultScope);
1019   assign_cnode->set_abstract(specify_para->abstract());
1020 
1021   kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
1022     {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat32});
1023   selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
1024   selected_kernel_builder.SetOutputsDeviceType({kNumberTypeFloat32});
1025 
1026   AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), assign_cnode.get());
1027   std::vector<std::string> input_names = {"ref", "value"};
1028   std::vector<std::string> output_names = {"output"};
1029   ValuePtr input_names_v = MakeValue(input_names);
1030   ValuePtr output_names_v = MakeValue(output_names);
1031   AnfAlgo::SetNodeAttr("input_names", input_names_v, assign_cnode);
1032   AnfAlgo::SetNodeAttr("output_names", output_names_v, assign_cnode);
1033   selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
1034 
1035   session::AnfWithOutIndex final_pair = std::make_pair(assign_cnode, 0);
1036   session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(assign_cnode, 0), 0);
1037   kernel_graph_ptr->AddRefCorrespondPairs(final_pair, kernel_with_index);
1038   return assign_cnode;
1039 }
1040 
InsertOverflowCheckOperations(const std::shared_ptr<session::KernelGraph> & kernel_graph_ptr)1041 void KernelAdjust::InsertOverflowCheckOperations(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
1042   MS_LOG(INFO) << "Start Insert Overflow Check Operations.";
1043 
1044   MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
1045   auto parameters = kernel_graph_ptr->parameters();
1046   AnfNodePtr specify_para;
1047   bool not_find = true;
1048   for (size_t i = 0; i < parameters.size(); i++) {
1049     auto para_fullname = parameters[i]->fullname_with_scope();
1050     if (para_fullname.find(kSpecifyParameter) != std::string::npos) {
1051       not_find = false;
1052       specify_para = parameters[i];
1053       break;
1054     }
1055   }
1056 
1057   if (not_find) {
1058     MS_LOG(INFO) << "Not find parameter named " << kSpecifyParameter;
1059     return;
1060   }
1061 
1062   bool first_grad_op = true;
1063   CNodePtr npu_alloc_cnode;
1064   std::vector<CNodePtr> new_execution_order;
1065   auto execution_order = kernel_graph_ptr->execution_order();
1066   for (size_t i = 0; i < execution_order.size() - 1; i++) {
1067     new_execution_order.push_back(execution_order[i]);
1068     auto cur_full_name = execution_order[i]->fullname_with_scope();
1069     auto next_full_name = execution_order[i + 1]->fullname_with_scope();
1070     auto cur_stream_id = AnfAlgo::GetStreamId(execution_order[i]);
1071     auto next_stream_id = AnfAlgo::GetStreamId(execution_order[i + 1]);
1072 
1073     if (cur_full_name.find(kGradients) == std::string::npos && next_full_name.find(kGradients) != std::string::npos) {
1074       if (first_grad_op) {
1075         npu_alloc_cnode = CreateNPUAllocStatus(kernel_graph_ptr);
1076         auto npu_clear_cnode = CreateNPUClearStatus(kernel_graph_ptr, npu_alloc_cnode);
1077         auto assign_cnode = CreateAssign(kernel_graph_ptr, specify_para);
1078         AnfAlgo::SetStreamId(next_stream_id, npu_alloc_cnode.get());
1079         AnfAlgo::SetStreamId(next_stream_id, npu_clear_cnode.get());
1080         AnfAlgo::SetStreamId(next_stream_id, assign_cnode.get());
1081         new_execution_order.push_back(npu_alloc_cnode);
1082         new_execution_order.push_back(npu_clear_cnode);
1083         new_execution_order.push_back(assign_cnode);
1084         first_grad_op = false;
1085       } else {
1086         auto npu_clear_cnode = CreateNPUClearStatus(kernel_graph_ptr, npu_alloc_cnode);
1087         AnfAlgo::SetStreamId(next_stream_id, npu_clear_cnode.get());
1088         new_execution_order.push_back(npu_clear_cnode);
1089       }
1090     }
1091     if (cur_full_name.find(kGradients) != std::string::npos && next_full_name.find(kGradients) == std::string::npos) {
1092       auto npu_get_cnode = CreateNPUGetFloatStatus(kernel_graph_ptr, npu_alloc_cnode);
1093       auto assign_add_cnode = CreateAssignAdd(kernel_graph_ptr, npu_alloc_cnode, specify_para);
1094       AnfAlgo::SetStreamId(cur_stream_id, npu_get_cnode.get());
1095       AnfAlgo::SetStreamId(cur_stream_id, assign_add_cnode.get());
1096       new_execution_order.push_back(npu_get_cnode);
1097       new_execution_order.push_back(assign_add_cnode);
1098     }
1099     if (i == execution_order.size() - kLastHandleDiff) {
1100       new_execution_order.push_back(execution_order[i + 1]);
1101       if (next_full_name.find(kGradients) != std::string::npos) {
1102         auto npu_get_cnode = CreateNPUGetFloatStatus(kernel_graph_ptr, npu_alloc_cnode);
1103         auto assign_add_cnode = CreateAssignAdd(kernel_graph_ptr, npu_alloc_cnode, specify_para);
1104         AnfAlgo::SetStreamId(cur_stream_id, npu_get_cnode.get());
1105         AnfAlgo::SetStreamId(cur_stream_id, assign_add_cnode.get());
1106         new_execution_order.push_back(npu_get_cnode);
1107         new_execution_order.push_back(assign_add_cnode);
1108       }
1109     }
1110   }
1111 
1112   kernel_graph_ptr->set_execution_order(new_execution_order);
1113 }
1114 }  // namespace device
1115 }  // namespace mindspore
1116