• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "runtime/device/ascend/ascend_kernel_runtime.h"
17 #include <string>
18 #include <vector>
19 #include <memory>
20 #include <utility>
21 #include <algorithm>
22 #include <set>
23 #include "utils/signal_util.h"
24 #include "runtime/device/ascend/ascend_device_address.h"
25 #include "runtime/device/ascend/distribute/ascend_collective.h"
26 #include "utils/ms_context.h"
27 #include "utils/context/context_extends.h"
28 #include "utils/mpi/mpi_config.h"
29 #include "common/trans.h"
30 #include "runtime/rt.h"
31 #include "runtime/device/ascend/ascend_stream_assign.h"
32 #include "runtime/device/ascend/ge_runtime/model_runner.h"
33 #include "runtime/device/ascend/tasksink/task_generator.h"
34 #include "backend/session/anf_runtime_algorithm.h"
35 #include "backend/session/kernel_build_client.h"
36 #ifndef ENABLE_SECURITY
37 #include "runtime/device/ascend/profiling/profiling_manager.h"
38 #include "runtime/device/ascend/profiling/profiling_utils.h"
39 #endif
40 #include "runtime/device/ascend/ascend_memory_manager.h"
41 #include "runtime/device/ascend/ascend_event.h"
42 #ifndef ENABLE_SECURITY
43 #include "debug/data_dump/dump_json_parser.h"
44 #include "debug/data_dump/e2e_dump.h"
45 #endif
46 #include "toolchain/adx_datadump_server.h"
47 #include "utils/trace_base.h"
48 #include "graphengine/inc/external/acl/error_codes/rt_error_codes.h"
49 #include "common/util/error_manager/error_manager.h"
50 #include "debug/anf_ir_dump.h"
51 #include "frontend/parallel/context.h"
52 #include "utils/comm_manager.h"
53 #include "utils/runtime_error_codes.h"
54 #ifdef MEM_REUSE_DEBUG
55 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
56 #include "debug/env_config_parser.h"
57 #endif
58 #include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
59 #include "utils/config_manager.h"
60 #include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h"
61 #include "runtime/hccl_adapter/hccl_adapter.h"
62 #ifdef ENABLE_TDTQUE
63 #include "minddata/dataset/engine/tdt/tdt_handle.h"
64 using mindspore::dataset::TdtHandle;
65 #endif
66 #ifdef ENABLE_DUMP_IR
67 #include "debug/rdr/running_data_recorder.h"
68 #endif
69 
70 #include "backend/session/pynative_task_manager.h"
71 #include "profiler/device/profiling.h"
72 
73 #ifndef ENABLE_SECURITY
74 using mindspore::device::ascend::ProfilingManager;
75 using mindspore::device::ascend::ProfilingUtils;
76 #endif
77 using mindspore::device::ascend::tasksink::TaskGenerator;
78 using mindspore::ge::model_runner::ModelRunner;
79 using mindspore::kernel::tbe::TbeUtils;
80 using std::vector;
81 
// Indices into the per-task runtime-info tuple returned by
// ModelRunner::GetRuntimeInfoMap(): task id, stream id and (presumably) the
// task args — confirm against ModelRunner.
constexpr uint32_t kTupleTaskId = 0;
constexpr uint32_t kTupleStreamId = 1;
constexpr uint32_t kTupleArgs = 2;
// Largest task id per stream that profiling can represent.
constexpr uint32_t kProfilingMaxTaskIdInStream = 65531;
// Module name used when (de)registering the rt task-fail callback.
constexpr auto kModuleName = "MindSpore";
// Maximum accepted file-system path length (commonly PATH_MAX).
constexpr size_t kPathMax = 4096;
88 
89 namespace mindspore::device::ascend {
90 static thread_local rtContext_t thread_local_rt_context{nullptr};
91 constexpr auto kUnknowErrorString = "Unknown error occurred";
92 namespace {
GetRankIdStr()93 std::string GetRankIdStr() {
94   auto context_ptr = MsContext::GetInstance();
95   MS_EXCEPTION_IF_NULL(context_ptr);
96   if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
97     MS_LOG(INFO) << "Get hccl rankid from mpi";
98     auto rank = HcclCollectiveGroup::instance().GetRankId();
99     return std::to_string(rank);
100   }
101   auto rank_id_str = common::GetEnv("RANK_ID");
102   if (rank_id_str.empty()) {
103     MS_LOG(EXCEPTION) << "Invalid environment variable 'RANK_ID', it should not be empty.";
104   }
105   return rank_id_str;
106 }
107 
IntHandler(int,siginfo_t *,void *)108 void IntHandler(int, siginfo_t *, void *) {
109   mindspore::kernel::AscendKernelBuildClient::Instance().Close();
110   int this_pid = getpid();
111   MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
112   (void)kill(this_pid, SIGTERM);
113 }
114 }  // namespace
115 
// Non-overflow task failures recorded by TaskFailCallback and reported later
// by DumpTaskExceptionInfo.
std::vector<rtExceptionInfo> AscendKernelRuntime::task_fail_infoes_ = {};
// Graph currently being executed; used by the task-fail callback to map a
// (stream id, task id) pair back to a kernel node.
const session::KernelGraph *current_graph_ = nullptr;
// Per (stream id + task id + graph id) overflow counters used to throttle
// repeated overflow warnings.
std::map<std::string, uint32_t> AscendKernelRuntime::overflow_tasks_;
~AscendKernelRuntime()119 AscendKernelRuntime::~AscendKernelRuntime() {
120   graph_model_map_.clear();
121   current_graph_ = nullptr;
122   rt_context_ = nullptr;
123 }
124 
SetContext()125 void AscendKernelRuntime::SetContext() {
126   if (rt_context_ == nullptr) {
127     return;
128   }
129   if (thread_local_rt_context == rt_context_) {
130     return;
131   }
132   auto ret = rtCtxSetCurrent(rt_context_);
133   thread_local_rt_context = rt_context_;
134   if (ret != RT_ERROR_NONE) {
135     MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
136   }
137 }
138 
SetCurrentContext()139 void AscendKernelRuntime::SetCurrentContext() {
140   if (rt_context_ == nullptr) {
141     return;
142   }
143   auto ret = rtCtxSetCurrent(rt_context_);
144   if (ret != RT_ERROR_NONE) {
145     MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
146   }
147 }
148 
ClearGraphModelMap()149 void AscendKernelRuntime::ClearGraphModelMap() {
150   SetCurrentContext();
151 #ifndef ENABLE_SECURITY
152   for (auto &iter : graph_data_dumper_) {
153     MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first;
154     auto &data_dumper = iter.second;
155     MS_EXCEPTION_IF_NULL(data_dumper);
156     data_dumper->UnloadDumpInfo();
157     data_dumper->OpDebugUnregister();
158   }
159   graph_data_dumper_.clear();
160   // tell users which dump kernel name not used
161   DumpJsonParser::GetInstance().PrintUnusedKernel();
162 #endif
163 
164   graph_dynamic_kernel_map_.clear();
165   graph_kernel_events_map_.clear();
166   for (auto &iter : graph_model_map_) {
167     MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
168     ModelRunner::Instance().UnloadModel(iter.first);
169   }
170 }
171 
// Release everything the runtime holds for one graph: its memory scheduler,
// data dumper, dynamic kernels, kernel events and the loaded ge model.
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  SetCurrentContext();
  auto mem_scheduler = mem_scheduler_manager_.GetMemScheduler(graph_id);
  if (mem_scheduler != nullptr) {
    mem_scheduler->Clear();
  }
  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
#ifndef ENABLE_SECURITY
  if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
    MS_LOG(DEBUG) << "Unload dump info " << graph_id;
    auto &data_dumper = dumper_iter->second;
    MS_EXCEPTION_IF_NULL(data_dumper);
    // Balance the dump registration done at load time before dropping the dumper.
    data_dumper->UnloadDumpInfo();
    data_dumper->OpDebugUnregister();
    graph_data_dumper_.erase(dumper_iter);
  } else {
    MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
  }
#endif

  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " dynamic kernels";
  if (auto dynamic_kernel_iter = graph_dynamic_kernel_map_.find(graph_id);
      dynamic_kernel_iter != graph_dynamic_kernel_map_.end()) {
    MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel";
    graph_dynamic_kernel_map_.erase(dynamic_kernel_iter);
  }
  auto events_iter = graph_kernel_events_map_.find(graph_id);
  if (events_iter != graph_kernel_events_map_.end()) {
    graph_kernel_events_map_.erase(events_iter);
  }
  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
  if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
    MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
    ModelRunner::Instance().UnloadModel(graph_id);
    graph_model_map_.erase(model_iter);
  } else {
    MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
  }
}
211 
GetModelStream(uint32_t graph_id) const212 void *AscendKernelRuntime::GetModelStream(uint32_t graph_id) const {
213   return ModelRunner::Instance().GetModelStream(graph_id);
214 }
215 
ClearGlobalIdleMem()216 void AscendKernelRuntime::ClearGlobalIdleMem() {
217   if (mem_manager_ != nullptr) {
218     mem_manager_->ClearGlobalIdleMem();
219   }
220 }
221 
NeedDestroyHccl()222 bool AscendKernelRuntime::NeedDestroyHccl() {
223   auto context_ptr = MsContext::GetInstance();
224   MS_EXCEPTION_IF_NULL(context_ptr);
225   if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
226     MS_LOG(INFO) << "Hccl is not enabled";
227     return false;
228   }
229   // Note: make sure hcom_connectivity_detection api never be used.
230   return true;
231 }
232 
233 #ifndef ENABLE_SECURITY
AsyncDataDumpUninit()234 void AsyncDataDumpUninit() {
235   if (DumpJsonParser::GetInstance().async_dump_enabled()) {
236     if (AdxDataDumpServerUnInit() != 0) {
237       MS_LOG(ERROR) << "Adx data dump server uninit failed";
238     }
239   }
240 }
241 
ReportProfilingData()242 void AscendKernelRuntime::ReportProfilingData() {
243   auto context = MsContext::GetInstance();
244   MS_EXCEPTION_IF_NULL(context);
245   if (ProfilingManager::GetInstance().IsProfiling() &&
246       context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
247     // Save Profiling Framework data
248     OpNameTaskStreamReporter reporter(device_id_, "nonsink", stream_id_task_id_op_name_map_);
249     reporter.ReportData();
250   }
251 }
252 #endif
253 
// Finalize the Ascend runtime in dependency order: flush debugger/profiling
// data, unload ge models, stop async dump, destroy hccl, free device memory,
// unregister the task-fail callback, then reset the device.
void AscendKernelRuntime::ReleaseDeviceRes() {
  MS_LOG(INFO) << "Ascend finalize start";
#ifdef ENABLE_DEBUGGER
  if (debugger_ && debugger_->debugger_enabled()) {
    debugger_->SetTrainingDone(true);
    bool ret = debugger_->SendMetadata(false);
    if (!ret) {
      MS_LOG(ERROR) << "Failed to SendMetadata when finalize";
    }
  }
#endif
  if (!initialized_) {
    return;
  }
  SetCurrentContext();
#ifndef ENABLE_SECURITY
  ReportProfilingData();
#endif
  // release ge runtime
  ClearGraphModelMap();

#ifndef ENABLE_SECURITY
  AsyncDataDumpUninit();
#endif

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);

  // DestroyHccl must be called before FreeDeviceMemory
  (void)DestroyHccl();
  if (mem_manager_ != nullptr) {
    mem_manager_->FreeDeviceMemory();
  }

  // Deregister the task-fail callback installed in Init() by passing nullptr.
  auto rt_ret = rtRegTaskFailCallbackByModule(kModuleName, nullptr);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret;
  }

  (void)ResetDevice(device_id);
#ifndef ENABLE_SECURITY
  (void)ProfilingManager::GetInstance().StopProfiling();
#endif
  current_graph_ = nullptr;
  // Graph mode without task sink initialized collectives via mpi (see
  // GetRankIdStr); finalize them symmetrically.
  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
      !context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
    HcclCollectiveGroup::instance().FinalizeCollective();
  }
  MS_LOG(INFO) << "Ascend finalize end";
}
305 
306 #ifndef ENABLE_SECURITY
PreInit()307 void AscendKernelRuntime::PreInit() {
308   const auto error_manager_ret = ErrorManager::GetInstance().Init();
309   if (error_manager_ret != 0) {
310     MS_LOG(WARNING) << "Init ErrorManager failed.";
311   }
312   auto ret = ProfilingManager::GetInstance().StartupProfiling(device_id_);
313   if (!ret) {
314     const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
315     if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
316       MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
317     }
318     MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed.";
319   }
320 }
321 #endif
322 
// One-time runtime initialization: device setup, memory manager allocation
// and task-fail callback registration. Re-entrant: once initialized_, later
// calls only rebind the runtime context on the calling thread.
bool AscendKernelRuntime::Init() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto execution_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);

  auto profiler_manager = profiler::ProfilerManager::GetInstance();
  MS_EXCEPTION_IF_NULL(profiler_manager);
  auto profiling_flag = profiler_manager->GetProfilingEnableFlag();
  if (execution_mode == kPynativeMode && profiling_flag) {
    pynative_mode_profiling_flag_ = true;
  }
  if (initialized_) {
    // Already initialized; just make the context current on this thread.
    SetCurrentContext();
    return true;
  }
  const auto error_manager_ret = ErrorManager::GetInstance().Init();
  if (error_manager_ret != 0) {
    MS_LOG(WARNING) << "Init ErrorManager failed.";
  }
  try {
    // Start up profiling before rtSetDevice
    bool ret = InitDevice();
    if (!ret) {
      return ret;
    }
#ifdef ENABLE_DEBUGGER
    SetDebugger();
#endif
    mem_manager_ = std::make_shared<AscendMemoryManager>();
    MS_EXCEPTION_IF_NULL(mem_manager_);
    mem_manager_->MallocDeviceMemory();

    // Set callback func when exception error
    auto rt_ret = rtRegTaskFailCallbackByModule(kModuleName, TaskFailCallback);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret;
    }
  } catch (const std::exception &e) {
    // Prefer the detailed Ascend error message when one is available;
    // otherwise rethrow the original exception unchanged.
    const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
    if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
      MS_LOG(EXCEPTION) << "Ascend error occurred, error message: " << error_message
                        << "\nFirst error scene API: " << e.what();
    }
    throw;
  }

  initialized_ = true;
  return true;
}
372 
LoadData(const session::KernelGraph &)373 bool AscendKernelRuntime::LoadData(const session::KernelGraph & /* graph */) {
374 #ifdef ENABLE_DEBUGGER
375   MS_LOG(INFO) << "Start load step";
376   MS_EXCEPTION_IF_NULL(debugger_);
377   for (const auto &graph_ptr : debugger_->GetGraphPtrList()) {
378     debugger_->SetGraphPtr(graph_ptr);
379     // load output
380     debugger_->LoadGraphOutputs();
381     // load parameters
382     debugger_->LoadParametersAndConst();
383   }
384 #endif
385   return true;
386 }
387 
KernelMemNotReuse(const AnfNodePtr & node)388 bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
389   MS_EXCEPTION_IF_NULL(node);
390   bool need_dump = false;
391 #ifndef ENABLE_SECURITY
392   auto &dump_json_parser = DumpJsonParser::GetInstance();
393   if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 1) {
394     auto op_name = node->fullname_with_scope();
395     if (dump_json_parser.NeedDump(op_name)) {
396       need_dump = true;
397     }
398   }
399 #endif
400   return need_dump;
401 }
402 
CreateDeviceAddress(void * device_ptr,size_t device_size,const string & format,TypeId type_id) const403 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
404                                                           TypeId type_id) const {
405   return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
406 }
407 
CreateDeviceAddress(void * device_ptr,size_t device_size,const string & format,TypeId type_id,const KernelWithIndex & node_index) const408 DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
409                                                           TypeId type_id, const KernelWithIndex &node_index) const {
410   return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
411 }
412 
Load(const session::KernelGraph & graph,bool is_task_sink)413 bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {
414   if (!is_task_sink) {
415     MS_LOG(INFO) << "Graph mode with not task sink";
416     GenKernelEvents(graph);
417     return true;
418   }
419 
420   if (!GenTask(graph)) {
421     return false;
422   }
423   if (!LoadTask(graph)) {
424     return false;
425   }
426   return true;
427 }
428 
GenDynamicKernel(const session::KernelGraph & graph)429 bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph &graph) {
430   MS_LOG(INFO) << "GenDynamicKernel start";
431   auto cnode_list = graph.execution_order();
432   std::vector<DynamicKernelPtr> dynamic_kernels;
433   for (const auto &cnode : cnode_list) {
434     MS_EXCEPTION_IF_NULL(cnode);
435     MS_LOG(INFO) << "Generate node:" << cnode->fullname_with_scope() << " dynamic kernel";
436     auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
437     MS_EXCEPTION_IF_NULL(kernel_mod);
438     auto dynamic_kernel = kernel_mod->GenDynamicKernel(cnode, stream_);
439     if (dynamic_kernel == nullptr) {
440       MS_LOG(EXCEPTION) << "Dynamic shape is not supported with the operator [" << AnfAlgo::GetCNodeName(cnode) << "].";
441     }
442     dynamic_kernel->Initialize();
443     dynamic_kernels.emplace_back(dynamic_kernel);
444   }
445   graph_dynamic_kernel_map_[graph.graph_id()] = std::move(dynamic_kernels);
446   MS_LOG(INFO) << "GenDynamicKernel end";
447   return true;
448 }
449 
// Compile a static-shape graph into a DavinciModel: generate a TaskInfo per
// kernel, record the list in task_map_ and the model in graph_model_map_.
// Dynamic-shape graphs are redirected to GenDynamicKernel instead.
bool AscendKernelRuntime::GenTask(const session::KernelGraph &graph) {
  SetCurrentContext();
  if (graph.is_dynamic_shape()) {
    // Dynamic shape is incompatible with dataset sink (iter_num > 1) and
    // with asynchronous dump.
    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
      MS_LOG(EXCEPTION) << "Dynamic shape is not supported with dataset_sink_mode.";
    }
#ifndef ENABLE_SECURITY
    if (DumpJsonParser::GetInstance().async_dump_enabled()) {
      MS_LOG(EXCEPTION) << "Dynamic shape is not supported with Asynchronous Dump. Please use Synchronous Dump.";
    }
#endif
    MS_LOG(INFO) << "Dynamic Shape Graph Generate Dynamic kernel";
    return GenDynamicKernel(graph);
  }
  MS_LOG(INFO) << "GenTask start. GraphId:" << graph.graph_id();
#ifndef ENABLE_SECURITY
  DumpJsonParser::GetInstance().UpdateNeedDumpKernels(graph);
#endif
#ifdef MEM_REUSE_DEBUG
  if (!EnvConfigParser::GetInstance().GetSysMemreuse()) {
    // Get normal graph ir for memreuse
    mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(&graph);
  }
#endif
  vector<std::shared_ptr<TaskInfo>> task_info_list;
  auto anf_node_list = graph.execution_order();
  auto task_generator = TaskGenerator();
  if (!task_generator.GenTasks(anf_node_list, &task_info_list, graph.graph_id())) {
    return false;
  }
  // Store the task_info_list
  auto insert_ret = task_map_.insert(std::make_pair(graph.graph_id(), task_info_list));
  if (!insert_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
  }
  // Graph may have no compute node, such TensorAddGrad.
  if (task_info_list.empty()) {
    MS_LOG(WARNING) << "Graph " << graph.graph_id() << " have no compute node";
    return true;
  }
  AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
  AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
  // the streams' flag not HEAD_STREAM
  std::vector<uint32_t> wait_active_stream_list;
  assign_instance.GetWaitStreams(&wait_active_stream_list);
  std::vector<uint32_t> force_copy_stream_list;
  assign_instance.GetHcomStreams(&force_copy_stream_list);
  MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num()
               << ", total event num:" << resource_manager.get_cur_event_num()
               << ", total label num:" << graph.label_num()
               << ", wait_active_stream_list size:" << wait_active_stream_list.size()
               << ", force_copy_stream_list size:" << force_copy_stream_list.size();
  // The zero arguments are unused DavinciModel parameters (batch/addr info);
  // see the DavinciModel constructor for their meaning.
  auto model = std::make_shared<ge::model_runner::DavinciModel>(
    task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0,
    resource_manager.get_cur_stream_num(), graph.label_num(), resource_manager.get_cur_event_num(), 0);
  auto ret = graph_model_map_.insert(std::make_pair(graph.graph_id(), model));
  if (!ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
  }
  MS_LOG(INFO) << "TaskGenerator GetTaskInfo end...";
  return true;
}
512 
// Load the DavinciModel generated by GenTask onto the device, distribute its
// tasks, and hook up profiling reporting and data dump. Dynamic-shape graphs
// and graphs with an empty task list are no-ops that succeed.
bool AscendKernelRuntime::LoadTask(const session::KernelGraph &graph) {
  SetCurrentContext();
  if (graph.is_dynamic_shape()) {
    MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
    return true;
  }

  MS_LOG(INFO) << "LoadTask start. GraphId:" << graph.graph_id();
  if (GraphWithEmptyTaskList(graph)) {
    MS_LOG(WARNING) << "LoadTask end, task list is empty";
    return true;
  }

  auto model_iter = graph_model_map_.find(graph.graph_id());
  if (model_iter == graph_model_map_.end()) {
    MS_LOG(ERROR) << "GraphId:" << graph.graph_id() << " Invalid! Graph LoadTask without GenTask.";
    return false;
  }

  MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first;
  ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second);

#ifndef ENABLE_SECURITY
  // Register the async data dumper for this graph before tasks are distributed.
  std::function<void *()> model_handle =
    std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first);
  DistributeDebugTask(graph, NOT_NULL(model_handle));
#endif

  try {
    ModelRunner::Instance().DistributeTask(model_iter->first);
  } catch (const std::exception &e) {
#ifdef ENABLE_DUMP_IR
    // Flush the running-data recorder so the failure context is preserved.
    mindspore::RDR::TriggerAll();
#endif
    MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what();
  }

#ifndef ENABLE_SECURITY
  if (ProfilingManager::GetInstance().IsProfiling()) {
    auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first);
    auto stream_ids = ModelRunner::Instance().GetStreamIdList(model_iter->first);
    ProfilingUtils::ReportProfilingData(task_ids, stream_ids, graph);
  }
  LaunchDataDump(graph.graph_id());
#endif

  ModelRunner::Instance().LoadModelComplete(model_iter->first);
  return true;
}
562 
563 #ifndef ENABLE_SECURITY
DistributeDebugTask(const session::KernelGraph & graph,const NotNull<std::function<void * ()>> & model_handle)564 void AscendKernelRuntime::DistributeDebugTask(const session::KernelGraph &graph,
565                                               const NotNull<std::function<void *()>> &model_handle) {
566   if (!DumpJsonParser::GetInstance().async_dump_enabled()) {
567     return;
568   }
569   MS_LOG(INFO) << "Start Distribute Debug Task";
570   auto data_dumper = std::make_shared<DataDumper>(&graph, model_handle);
571   MS_EXCEPTION_IF_NULL(data_dumper);
572   auto ret = graph_data_dumper_.try_emplace(graph.graph_id(), data_dumper);
573   data_dumper->OpDebugRegister();
574   if (!ret.second) {
575     MS_LOG(WARNING) << "[DataDump] Insert graphId:" << graph.graph_id() << " data dumper failed";
576   }
577 }
578 
LaunchDataDump(GraphId graph_id)579 void AscendKernelRuntime::LaunchDataDump(GraphId graph_id) {
580   if (!DumpJsonParser::GetInstance().async_dump_enabled()) {
581     return;
582   }
583   MS_LOG(INFO) << "Start Launch Dump Data";
584   auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph_id);
585   if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
586     auto &data_dumper = dumper_iter->second;
587     MS_EXCEPTION_IF_NULL(data_dumper);
588     data_dumper->set_runtime_info(runtime_info_map);
589     data_dumper->LoadDumpInfo();
590   } else {
591     MS_LOG(EXCEPTION) << "GraphId:" << graph_id << " not found";
592   }
593 }
594 #endif
595 
// rt callback invoked by the driver when a task fails. AICore overflow is
// throttled per (stream id + task id + graph id) key: the first occurrence is
// logged, then every kOverflowThreshold-th one. All other failures are stored
// in task_fail_infoes_ for later reporting by DumpTaskExceptionInfo.
void AscendKernelRuntime::TaskFailCallback(rtExceptionInfo *task_fail_info) {
  if (task_fail_info == nullptr || current_graph_ == nullptr) {
    MS_LOG(ERROR) << "Execute TaskFailCallback failed. task_fail_info or current_graph_ is nullptr";
    return;
  }

  // Serialize access to overflow_tasks_ / task_fail_infoes_ — the mutex
  // implies this callback can run concurrently.
  static std::mutex exception_mutex;
  constexpr uint32_t kOverflowThreshold = 5;
  std::lock_guard<std::mutex> lock(exception_mutex);
  if (task_fail_info->retcode == ACL_ERROR_RT_AICORE_OVER_FLOW) {
    auto node = AscendKernelRuntime::GetErrorNodeName(task_fail_info->streamid, task_fail_info->taskid);
    if (!node) {
      MS_LOG(WARNING) << "Node run task overflow, node name is unknown.";
    } else {
      auto key = std::to_string(task_fail_info->streamid) + std::to_string(task_fail_info->taskid) +
                 std::to_string(current_graph_->graph_id());
      // Log on first sighting or when the counter reaches the threshold;
      // the counter then restarts at 1.
      if (overflow_tasks_.find(key) == overflow_tasks_.end() || overflow_tasks_[key] == kOverflowThreshold) {
        // print overflow info
        MS_LOG(WARNING) << "Node run task overflow, node name: " << node->fullname_with_scope()
                        << "Task overflow infos task_id: " << task_fail_info->taskid
                        << ", stream_id: " << task_fail_info->streamid << ", tid: " << task_fail_info->tid
                        << ", device_id: " << task_fail_info->deviceid << ", retcode: " << task_fail_info->retcode
                        << " (" << GetErrorMsg(task_fail_info->retcode) << ")"
                        << ", trace: " << trace::DumpSourceLines(node);
        overflow_tasks_[key] = 1;
      } else {
        overflow_tasks_[key]++;
      }
    }
  } else {
    task_fail_infoes_.push_back(*task_fail_info);
  }
}
629 
GetErrorNodeName(uint32_t streamid,uint32_t taskid)630 CNodePtr AscendKernelRuntime::GetErrorNodeName(uint32_t streamid, uint32_t taskid) {
631   if (current_graph_ == nullptr) {
632     return nullptr;
633   }
634   auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(current_graph_->graph_id());
635   for (const auto &iter : runtime_info_map) {
636     MS_EXCEPTION_IF_NULL(iter.second);
637     auto task_id = std::get<kTupleTaskId>(*iter.second);
638     auto stream_id = std::get<kTupleStreamId>(*iter.second);
639     if (task_id == taskid && stream_id == streamid) {
640       auto &execute_node = current_graph_->execution_order();
641       auto node = std::find_if(execute_node.begin(), execute_node.end(), [&iter](const auto &node) {
642         MS_EXCEPTION_IF_NULL(node);
643         return node->UniqueName() == iter.first;
644       });
645       if (node != execute_node.end()) {
646         return *node;
647       }
648     }
649   }
650   return nullptr;
651 }
652 
GetDumpPath()653 std::string AscendKernelRuntime::GetDumpPath() {
654   uint32_t rank_id = 0;
655   auto inst = parallel::ParallelContext::GetInstance();
656   MS_EXCEPTION_IF_NULL(inst);
657   if (inst->parallel_mode() != parallel::STAND_ALONE) {
658     if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
659       MS_LOG(WARNING) << "Get rank id failed.";
660     }
661   }
662 
663   auto ms_om_path = common::GetEnv("MS_OM_PATH");
664   std::string path;
665   const auto kSuffix = "/node_dump";
666   if (ms_om_path.empty()) {
667     MS_LOG(WARNING) << "MS_OM_PATH is null, so dump to process local path, as ./rank_id/node_dump/...";
668     path = "./rank_" + std::to_string(rank_id) + kSuffix;
669   } else {
670     path = ms_om_path + "/rank_" + std::to_string(rank_id) + kSuffix;
671   }
672   return path;
673 }
674 
675 #ifndef ENABLE_SECURITY
// Report every recorded task failure and dump each failing node's
// input/output data under GetDumpPath(). Any pre-existing dump directory is
// removed first so the dump reflects only this failure.
void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph & /* graph */) {
  const std::string path = GetDumpPath();
  if (access(path.c_str(), F_OK) == 0) {
    if (!DeleteDumpDir(path)) {
      MS_LOG(ERROR) << "Delete dump directory " << path << " failed";
    }
  }
  for (const auto &task_fail_info : task_fail_infoes_) {
    MS_LOG(ERROR) << "Task fail infos task_id: " << task_fail_info.taskid << ", stream_id: " << task_fail_info.streamid
                  << ", tid: " << task_fail_info.tid << ", device_id: " << task_fail_info.deviceid
                  << ", retcode: " << task_fail_info.retcode << " (" << GetErrorMsg(task_fail_info.retcode) << ")";
    auto node = AscendKernelRuntime::GetErrorNodeName(task_fail_info.streamid, task_fail_info.taskid);
    // Dump error data in local path
    if (node == nullptr) {
      // The failing task could not be mapped to a node; log-only above.
      continue;
    }
    auto full_scope_name = node->fullname_with_scope();
    MS_LOG(ERROR) << "Dump node (" << full_scope_name << ") task error input/output data to: " << path
                  << " trace: " << trace::DumpSourceLines(node);
    E2eDump::DumpInputImpl(node, false, path, &full_scope_name, nullptr);
    E2eDump::DumpOutputImpl(node, false, path, &full_scope_name, nullptr);
  }
}
699 #endif
700 
// Execute a graph: via the sink path (RunTask, with wall-clock timing logged)
// or by launching kernels one by one (LaunchKernels).
bool AscendKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_sink) {
  const uint64_t kUSecondInSecond = 1000000;
  // Route Ctrl-C through IntHandler so the kernel-build client is closed
  // before the process terminates.
  SignalGuard sg(IntHandler);
  bool ret = false;

  if (is_task_sink) {
#if defined(_WIN32) || defined(_WIN64)
    auto start_time = std::chrono::steady_clock::now();
#else
    struct timeval start_time {};
    struct timeval end_time {};
    (void)gettimeofday(&start_time, nullptr);
#endif
    ret = RunTask(graph);
#if defined(_WIN32) || defined(_WIN64)
    auto end_time = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time;
    MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
#else
    (void)gettimeofday(&end_time, nullptr);
    // Elapsed wall time in microseconds.
    uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
    cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
    MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
#endif
  } else {
    ret = LaunchKernels(graph);
  }

  return ret;
}
731 
// Assign an execution stream to every kernel:
//  - communication ops get (or lazily create) a dedicated stream per group,
//  - independent nodes (non-PyNative mode) run on independent_stream_ (id 1),
//  - everything else runs on the default stream_ (id 0).
// AtomicAddrClean kernels then inherit the stream of the kernel that follows
// them. The index of the last kernel on each tracked stream is appended to
// *last_stream_nodes.
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
                                             std::vector<size_t> *last_stream_nodes) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  std::map<void *, size_t> last_kernel;  // stream -> index of its last kernel
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &node = kernels[i];
    auto kernel_mod = AnfAlgo::GetKernelMod(node);
    auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
    MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
    if (AnfAlgo::IsCommunicationOp(node)) {
      auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
      auto iter = group_stream_id_map_.find(group);
      if (iter == group_stream_id_map_.end()) {
        // First communication op of this group: create its stream on demand.
        void *stream = nullptr;
        auto ret = rtStreamCreate(&stream, 0);
        if (ret != RT_ERROR_NONE) {
          MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
        }
        auto id = SizeToUint(stream_id_map_.size());
        group_stream_id_map_[group] = id;
        stream_id_map_[id] = stream;
        AnfAlgo::SetStreamId(id, node.get());
        ascend_kernel_mod->SetStream(stream);
        last_kernel[stream] = i;
      } else {
        auto id = iter->second;
        AnfAlgo::SetStreamId(id, node.get());
        ascend_kernel_mod->SetStream(stream_id_map_[id]);
        last_kernel[stream_id_map_[id]] = i;
      }
    } else if (AnfAlgo::IsIndependentNode(node) && mode != kPynativeMode) {
      AnfAlgo::SetStreamId(1, node.get());
      ascend_kernel_mod->SetStream(independent_stream_);
      last_kernel[independent_stream_] = i;
    } else {
      AnfAlgo::SetStreamId(0, node.get());
      ascend_kernel_mod->SetStream(stream_);
    }
  }
  // Second pass: each AtomicAddrClean takes the stream of the next kernel in
  // order (presumably the kernel whose memory it cleans — confirm).
  for (size_t i = 1; i < kernels.size(); ++i) {
    if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) {
      auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
      AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get());
      auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i - 1]);
      auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
      MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
      ascend_kernel_mod->SetStream(stream_id_map_[stream_id]);
    }
  }
  (void)std::transform(last_kernel.begin(), last_kernel.end(), std::back_inserter(*last_stream_nodes),
                       [](const std::pair<void *, size_t> &item) { return item.second; });
}
786 
GenKernelEvents(const session::KernelGraph & graph)787 void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
788   auto &kernels = graph.execution_order();
789   if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
790     return;
791   }
792   std::vector<size_t> last_stream_nodes;
793   SetKernelModStream(kernels, &last_stream_nodes);
794   auto kernel_events =
795     std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
796   auto &kernel_pre_run_events = kernel_events.first;
797   auto &kernel_post_run_events = kernel_events.second;
798   kernel_pre_run_events.resize(kernels.size());
799   kernel_post_run_events.resize(kernels.size());
800   for (size_t i = 0; i < kernels.size(); ++i) {
801     auto &kernel = kernels[i];
802     auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
803     if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
804       MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created";
805     }
806     auto wait_stream = stream_id_map_[curr_stream_id];
807     auto stream_num = stream_id_map_.size();
808     std::vector<bool> stream_hit(stream_num, false);
809     std::vector<AnfNodePtr> used_kernels;
810     std::set<AnfNodePtr> visited_kernels;
811     AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
812     bool found_depend = false;
813     for (int k = SizeToInt(i) - 1; k >= 0; --k) {
814       auto pre_cnode = kernels[IntToSize(k)];
815       auto pre_cnode_stream_id = AnfAlgo::GetStreamId(pre_cnode);
816       if (pre_cnode_stream_id == curr_stream_id) {
817         found_depend = true;
818         continue;
819       }
820       for (auto &visited : used_kernels) {
821         if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) {
822           stream_hit[pre_cnode_stream_id] = true;
823           found_depend = true;
824           auto record_stream = stream_id_map_[pre_cnode_stream_id];
825           auto event = CreateDeviceEvent();
826           event->set_wait_stream(wait_stream);
827           event->set_record_stream(record_stream);
828           kernel_post_run_events[IntToSize(k)].emplace_back([event]() { event->RecordEvent(); });
829           kernel_pre_run_events[i].emplace_back([event]() { event->WaitEvent(); });
830         }
831       }
832     }
833     if (!found_depend && wait_stream != stream_) {
834       auto pre_event = CreateDeviceEvent();
835       pre_event->set_wait_stream(wait_stream);
836       pre_event->set_record_stream(stream_);
837       kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->RecordEvent(); });
838       kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->WaitEvent(); });
839     }
840   }
841   ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes);
842   graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
843 }
844 
// For each kernel that is the last one on a non-default stream
// (`last_stream_nodes`, produced by SetKernelModStream), checks whether some
// later kernel consumes its output. If not, appends a record/wait event pair
// to that kernel's post-run events so the default stream waits for it —
// otherwise its work could still be in flight when the graph "finishes".
void AscendKernelRuntime::ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
                                               std::vector<std::vector<std::function<void()>>> *kernel_run_events,
                                               const std::vector<size_t> &last_stream_nodes) {
  for (auto &i : last_stream_nodes) {
    if (i >= kernels.size()) {
      MS_LOG(ERROR) << "Node index exceed kernel size.";
      continue;
    }
    auto &kernel = kernels[i];
    MS_EXCEPTION_IF_NULL(kernel);
    bool found_nearest_child = false;
    // Scan the kernels that run after `kernel` for one that takes it as input.
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      // inputs()[0] is the primitive, so real inputs number size() - 1.
      auto input_size = child->inputs().size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        // Resolve through nop nodes etc. to the real producing kernel.
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        break;
      }
    }
    if (!found_nearest_child) {
      // No consumer: make the default stream wait on this kernel's stream.
      auto post_event = CreateDeviceEvent();
      MS_EXCEPTION_IF_NULL(post_event);
      auto id = AnfAlgo::GetStreamId(kernel);
      auto record_stream = stream_id_map_[id];
      post_event->set_wait_stream(stream_);
      post_event->set_record_stream(record_stream);
      (*kernel_run_events)[i].emplace_back([post_event]() { post_event->RecordEvent(); });
      (*kernel_run_events)[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
}
883 
RunDynamicKernelAsync(const session::KernelGraph & graph)884 bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph &graph) {
885   MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph.graph_id();
886   auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
887   if (iter == graph_dynamic_kernel_map_.end()) {
888     MS_LOG(ERROR) << "GraphId:" << graph.graph_id() << " Not Found! Please generator executor first";
889     return false;
890   }
891 
892   auto dynamic_kernels = iter->second;
893   for (const auto &dynamic_kernel : dynamic_kernels) {
894     MS_EXCEPTION_IF_NULL(dynamic_kernel);
895     if (dynamic_kernel->have_depends() || dynamic_kernel->GetKernelType() == KernelType::HCCL_KERNEL) {
896       MS_LOG(INFO) << "Match Dynamic Kernel, Start SyncStream";
897       if (!SyncStream()) {
898         MS_LOG(ERROR) << "SyncStream failed";
899         return false;
900       }
901     }
902 
903     if (dynamic_kernel->is_dynamic_shape()) {
904       dynamic_kernel->InferShape();
905       dynamic_kernel->UpdateArgs();
906     }
907 
908     dynamic_kernel->Execute();
909     dynamic_kernel->PostExecute();
910   }
911 
912   if (!SyncStream()) {
913     MS_LOG(ERROR) << "SyncStream failed";
914     return false;
915   }
916 
917   return true;
918 }
919 
// Runs a graph in task-sink mode via the ModelRunner. Dynamic-shape graphs are
// redirected to the dynamic-kernel executor. Returns false on failure; on a
// model-run exception the task exception info is dumped (non-security builds)
// and the TDT channel is torn down (TDT builds) before returning.
bool AscendKernelRuntime::RunTask(const session::KernelGraph &graph) {
  current_graph_ = &graph;
  SetCurrentContext();
  if (graph.is_dynamic_shape()) {
    // Sinked task execution cannot handle dynamic shapes; use the async path.
    MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
    return RunDynamicKernelAsync(graph);
  }

  MS_LOG(INFO) << "RunTask start. GraphId:" << graph.graph_id();

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (GraphWithEmptyTaskList(graph)) {
    // A graph with no tasks is considered trivially successful.
    MS_LOG(WARNING) << "RunTask end, no task info found";
    return true;
  }

  if (!CheckGraphIdValid(graph.graph_id())) {
    MS_LOG(ERROR) << "GraphId:" << graph.graph_id() << " Invalid! Graph RunTask without GenTask.";
    return false;
  }

  try {
    ModelRunner::Instance().RunModel(graph.graph_id());
  } catch (const std::exception &) {
#ifndef ENABLE_SECURITY
    DumpTaskExceptionInfo(graph);
#endif
#ifdef ENABLE_TDTQUE
    // Run task error, we should call TdtHostDestroy to release tdt to avoid DeviceQueueOp hostPush hung
    // case1: cpu usage 100% cause thread/process exit, but some tdt thread remain in backend
    if (!TdtHandle::DestroyHandle()) {
      MS_LOG(WARNING) << "Destroy tdt channel failed.";
    } else {
      MS_LOG(INFO) << "Destroy tdt channel success.";
    }
#endif
    return false;
  }
  // Successful run: clear any task failure info left by a previous attempt.
  task_fail_infoes_.clear();
  return true;
}
962 
SyncStream()963 bool AscendKernelRuntime::SyncStream() {
964   SetCurrentContext();
965   session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
966   for (auto &iter : stream_id_map_) {
967     if (rtStreamSynchronize(iter.second) != RT_ERROR_NONE) {  // o for switch stream
968       MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
969       return false;
970     }
971   }
972   return true;
973 }
974 
MemcpyAsync(void * dst,const void * src,uint64_t size,int32_t kind)975 bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) {
976   SetCurrentContext();
977   if (stream_ == nullptr) {
978     MS_LOG(ERROR) << "MemcpyAsync failed. stream_ is nullptr";
979     return false;
980   }
981 
982   auto copy_kind = static_cast<rtMemcpyKind_t>(kind);
983   if (copy_kind != RT_MEMCPY_HOST_TO_DEVICE_EX && copy_kind != RT_MEMCPY_DEVICE_TO_DEVICE) {
984     MS_LOG(EXCEPTION) << "Memory copy async not support cache host buffer in kind: " << kind;
985   }
986   if (RT_ERROR_NONE != rtMemcpyAsync(dst, size, src, size, static_cast<rtMemcpyKind_t>(kind), stream_)) {
987     MS_LOG(ERROR) << "Call runtime rtMemcpyAsync error.";
988     return false;
989   }
990   return true;
991 }
992 
CreateContext()993 void AscendKernelRuntime::CreateContext() {
994   if (rt_context_ == nullptr) {
995     auto ret = rtCtxCreate(&rt_context_, 0, UintToInt(device_id_));
996     if (ret != RT_ERROR_NONE) {
997       MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
998     }
999   }
1000   SetCurrentContext();
1001 }
1002 
InitDevice()1003 bool AscendKernelRuntime::InitDevice() {
1004   int device_count = 0;
1005   auto ret = rtGetDeviceCount(&device_count);
1006   if (ret != RT_ERROR_NONE) {
1007     MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
1008   }
1009 
1010   ret = rtSetDevice(UintToInt(device_id_));
1011   if (ret != RT_ERROR_NONE) {
1012     MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
1013   }
1014 
1015   auto context_ptr = MsContext::GetInstance();
1016   MS_EXCEPTION_IF_NULL(context_ptr);
1017   if (context_ptr == nullptr) {
1018     MS_LOG(ERROR) << "Get MsContext instance failed";
1019     return false;
1020   }
1021   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
1022     if (!HcclInit()) {
1023       MS_LOG(ERROR) << "HcclInit init failed";
1024       return false;
1025     }
1026   }
1027 
1028   // Context will be created by rtSetDevice
1029   ret = rtCtxGetCurrent(&rt_context_);
1030   if (ret != RT_ERROR_NONE || rt_context_ == nullptr) {
1031     MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]";
1032     return false;
1033   }
1034 
1035   ret = rtStreamCreateWithFlags(&stream_, 0, RT_STREAM_HUGE);
1036   if (ret != RT_ERROR_NONE) {
1037     MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
1038   }
1039   ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
1040   if (ret != RT_ERROR_NONE) {
1041     MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
1042   }
1043   ret = rtStreamCreate(&communication_stream_, 0);
1044   if (ret != RT_ERROR_NONE) {
1045     MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
1046   }
1047   const int kCommunicationStreamID = 2;
1048   stream_id_map_[0] = stream_;
1049   stream_id_map_[1] = independent_stream_;
1050   stream_id_map_[kCommunicationStreamID] = communication_stream_;
1051   group_stream_id_map_[kHcclWorldGroup] = kCommunicationStreamID;
1052   return true;
1053 }
1054 
ResetDevice(uint32_t device_id)1055 bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
1056   SetCurrentContext();
1057   int32_t ret;
1058   for (auto &iter : stream_id_map_) {
1059     ret = rtStreamDestroy(iter.second);
1060     if (ret != RT_ERROR_NONE) {
1061       MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
1062     }
1063     iter.second = nullptr;
1064   }
1065   ret = rtDeviceReset(UintToInt(device_id));
1066   if (ret != RT_ERROR_NONE) {
1067     MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
1068   }
1069   // set to nullptr as its not created, only bounded to existing context
1070   rt_context_ = nullptr;
1071   return true;
1072 }
1073 
// Initializes HCCL for collective communication. Requires the TSD daemon to
// already be open. In non-task-sink graph mode, initialization goes through
// the collective-group path; otherwise a rank-table json config is read from
// the MINDSPORE_HCCL_CONFIG_PATH (or RANK_TABLE_FILE) environment variable.
bool AscendKernelRuntime::HcclInit() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context::IsTsdOpened(context_ptr)) {
    MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
  }
  MS_LOG(INFO) << "Do hcom init";
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (!is_task_sink && mode == kGraphMode) {
    // Non-sink graph mode: init without a rank table and build the world
    // group from all ranks [0, rank_size).
    (void)hccl::HcclAdapter::GetInstance().InitHccl();
    auto rank_size = HcclCollectiveGroup::instance().GetRankSize();
    std::vector<unsigned int> ranks(rank_size);
    std::iota(std::begin(ranks), std::end(ranks), 0);
    HcclCollectiveGroup::instance().CreateCommGroup(kHcclWorldGroup, ranks);
    return true;
  }

  // Rank-table path: MINDSPORE_HCCL_CONFIG_PATH takes precedence over
  // RANK_TABLE_FILE.
  auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
  if (config_path_str == nullptr) {
    config_path_str = std::getenv("RANK_TABLE_FILE");
    if (config_path_str == nullptr) {
      MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE";
      return false;
    }
  }
  if (strlen(config_path_str) >= kPathMax) {
    MS_LOG(ERROR) << "File path oversize";
    return false;
  }
  std::string rank_id_str = GetRankIdStr();
  // realpath(path, nullptr) mallocs the result; it is freed below.
  auto full_path = realpath(config_path_str, nullptr);
  if (full_path == nullptr) {
    MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";
    return false;
  }
  MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
  bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str,
                                                       full_path, mode == kGraphMode);
  free(full_path);
  if (!ret) {
    MS_LOG(ERROR) << "Hcom init failed.";
    return false;
  }
  return true;
}
1120 
DestroyHccl()1121 bool AscendKernelRuntime::DestroyHccl() {
1122   if (!NeedDestroyHccl()) {
1123     MS_LOG(INFO) << "Hccl is not enable, no need to close.";
1124     return true;
1125   }
1126   bool res = hccl::HcclAdapter::GetInstance().FinalizeHccl();
1127   if (!res) {
1128     MS_LOG(ERROR) << "Hccl destroy failed";
1129     return false;
1130   }
1131   MS_LOG(INFO) << "Hccl destroy successful.";
1132   auto context_ptr = MsContext::GetInstance();
1133   MS_EXCEPTION_IF_NULL(context_ptr);
1134   context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
1135   return true;
1136 }
1137 
GraphWithEmptyTaskList(const session::KernelGraph & graph) const1138 bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph &graph) const {
1139   auto iter = task_map_.find(graph.graph_id());
1140   if (iter == task_map_.end()) {
1141     MS_LOG(EXCEPTION) << "Unknown graph ptr";
1142   }
1143   return iter->second.empty();
1144 }
1145 
CheckGraphIdValid(GraphId graph_id) const1146 bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const {
1147   return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end();
1148 }
1149 
KernelLaunchProfiling(const std::string & kernel_name)1150 void AscendKernelRuntime::KernelLaunchProfiling(const std::string &kernel_name) {
1151   auto profiler_manager = profiler::ProfilerManager::GetInstance();
1152   MS_EXCEPTION_IF_NULL(profiler_manager);
1153   if (!profiler_manager->GetProfilingEnableFlag()) {
1154     return;
1155   }
1156 
1157   // save task info
1158   uint32_t stream_id;
1159   uint32_t task_id;
1160   auto rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id);
1161   if (rt_ret != RT_ERROR_NONE) {
1162     MS_LOG(EXCEPTION) << "Profiling get task_id stream_id failed";
1163   }
1164   std::pair<uint32_t, uint32_t> stream_task_pair = {stream_id, task_id};
1165   auto try_emplace_ret = stream_id_task_id_op_name_map_.try_emplace(stream_task_pair, kernel_name);
1166   if (!try_emplace_ret.second) {
1167     MS_LOG(WARNING) << "Profiling duplicate key, task_id:" << stream_task_pair.second
1168                     << " stream_id:" << stream_task_pair.first << " name:" << kernel_name;
1169   }
1170   if (stream_id_task_id_op_name_map_.size() > kProfilingMaxTaskIdInStream) {
1171     MS_LOG(EXCEPTION) << "Too many profiling data";
1172   }
1173 }
1174 
CreateDeviceEvent()1175 std::shared_ptr<DeviceEvent> AscendKernelRuntime::CreateDeviceEvent() {
1176   auto ascend_event = std::make_shared<AscendEvent>();
1177   MS_EXCEPTION_IF_NULL(ascend_event);
1178   return ascend_event;
1179 }
1180 
CreateDeviceTimeEvent()1181 std::shared_ptr<DeviceEvent> AscendKernelRuntime::CreateDeviceTimeEvent() {
1182   auto ascend_time_event = std::make_shared<AscendTimeEvent>();
1183   MS_EXCEPTION_IF_NULL(ascend_time_event);
1184   return ascend_time_event;
1185 }
1186 
GetAvailableMemMaxSize() const1187 uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
1188   auto ascend_mem_manager = std::dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
1189   MS_EXCEPTION_IF_NULL(ascend_mem_manager);
1190   return ascend_mem_manager->GetDeviceMemSize();
1191 }
1192 
DeleteDumpDir(const std::string & path)1193 bool AscendKernelRuntime::DeleteDumpDir(const std::string &path) {
1194   string real_path = GetRealPath(path);
1195   if (DeleteDumpFile(real_path) == -1) {
1196     return false;
1197   }
1198   if (rmdir(real_path.c_str()) == -1) {
1199     MS_LOG(WARNING) << "Delete dir " << real_path << " failed!";
1200   }
1201   return true;
1202 }
1203 
DeleteDumpFile(std::string path)1204 int AscendKernelRuntime::DeleteDumpFile(std::string path) {
1205   DIR *dir;
1206   struct dirent *dirinfo;
1207   struct stat statbuf;
1208   string filepath;
1209   int result = 0;
1210   lstat(path.c_str(), &statbuf);
1211 
1212   if (S_ISREG(statbuf.st_mode)) {
1213     result = remove(path.c_str());
1214   } else if (S_ISDIR(statbuf.st_mode)) {
1215     if ((dir = opendir(path.c_str())) == nullptr) {
1216       return -1;
1217     }
1218 
1219     while (!result && (dirinfo = readdir(dir))) {
1220       if (path[path.size() - 1] != '/') {
1221         path = path + "/";
1222       }
1223       MS_EXCEPTION_IF_NULL(dirinfo);
1224       filepath = path + dirinfo->d_name;
1225       if (strcmp(dirinfo->d_name, ".") == 0 || strcmp(dirinfo->d_name, "..") == 0) continue;
1226       result = DeleteDumpFile(filepath);
1227       if (!result) {
1228         if (rmdir(filepath.c_str()) == -1) {
1229           MS_LOG(WARNING) << "Delete dir " << filepath << " failed!";
1230         }
1231       }
1232     }
1233     if (closedir(dir) == -1) {
1234       MS_LOG(WARNING) << "Dump dir " << path << " close failed!";
1235     }
1236   }
1237   return result;
1238 }
1239 
GetRealPath(const std::string & path)1240 std::string AscendKernelRuntime::GetRealPath(const std::string &path) {
1241   char real_path_mem[kPathMax] = {0};
1242   char *real_path_ret = realpath(path.c_str(), real_path_mem);
1243   if (real_path_ret == nullptr) {
1244     return "";
1245   }
1246   return std::string(real_path_mem);
1247 }
1248 }  // namespace mindspore::device::ascend
1249