1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "runtime/device/ascend/ascend_kernel_runtime.h"
17 #include <string>
18 #include <vector>
19 #include <memory>
20 #include <utility>
21 #include <algorithm>
22 #include <set>
23 #include "utils/signal_util.h"
24 #include "runtime/device/ascend/ascend_device_address.h"
25 #include "runtime/device/ascend/distribute/ascend_collective.h"
26 #include "utils/ms_context.h"
27 #include "utils/context/context_extends.h"
28 #include "utils/mpi/mpi_config.h"
29 #include "common/trans.h"
30 #include "runtime/rt.h"
31 #include "runtime/device/ascend/ascend_stream_assign.h"
32 #include "runtime/device/ascend/ge_runtime/model_runner.h"
33 #include "runtime/device/ascend/tasksink/task_generator.h"
34 #include "backend/session/anf_runtime_algorithm.h"
35 #include "backend/session/kernel_build_client.h"
36 #ifndef ENABLE_SECURITY
37 #include "runtime/device/ascend/profiling/profiling_manager.h"
38 #include "runtime/device/ascend/profiling/profiling_utils.h"
39 #endif
40 #include "runtime/device/ascend/ascend_memory_manager.h"
41 #include "runtime/device/ascend/ascend_event.h"
42 #ifndef ENABLE_SECURITY
43 #include "debug/data_dump/dump_json_parser.h"
44 #include "debug/data_dump/e2e_dump.h"
45 #endif
46 #include "toolchain/adx_datadump_server.h"
47 #include "utils/trace_base.h"
48 #include "graphengine/inc/external/acl/error_codes/rt_error_codes.h"
49 #include "common/util/error_manager/error_manager.h"
50 #include "debug/anf_ir_dump.h"
51 #include "frontend/parallel/context.h"
52 #include "utils/comm_manager.h"
53 #include "utils/runtime_error_codes.h"
54 #ifdef MEM_REUSE_DEBUG
55 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
56 #include "debug/env_config_parser.h"
57 #endif
58 #include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
59 #include "utils/config_manager.h"
60 #include "runtime/device/ascend/profiling/reporter/op_name_task_stream_reporter.h"
61 #include "runtime/hccl_adapter/hccl_adapter.h"
62 #ifdef ENABLE_TDTQUE
63 #include "minddata/dataset/engine/tdt/tdt_handle.h"
64 using mindspore::dataset::TdtHandle;
65 #endif
66 #ifdef ENABLE_DUMP_IR
67 #include "debug/rdr/running_data_recorder.h"
68 #endif
69
70 #include "backend/session/pynative_task_manager.h"
71 #include "profiler/device/profiling.h"
72
73 #ifndef ENABLE_SECURITY
74 using mindspore::device::ascend::ProfilingManager;
75 using mindspore::device::ascend::ProfilingUtils;
76 #endif
77 using mindspore::device::ascend::tasksink::TaskGenerator;
78 using mindspore::ge::model_runner::ModelRunner;
79 using mindspore::kernel::tbe::TbeUtils;
80 using std::vector;
81
// Indices into the runtime-info tuple returned by ModelRunner::GetRuntimeInfoMap
// (see GetErrorNodeName below, which reads the task-id and stream-id slots).
constexpr uint32_t kTupleTaskId = 0;
constexpr uint32_t kTupleStreamId = 1;
constexpr uint32_t kTupleArgs = 2;
// NOTE(review): presumably the largest task id profiling can represent in one
// stream — confirm against the profiling component before relying on it.
constexpr uint32_t kProfilingMaxTaskIdInStream = 65531;
// Module name used to (un)register the runtime task-fail callback.
constexpr auto kModuleName = "MindSpore";
// Maximum file-system path length accepted here (matches common PATH_MAX).
constexpr size_t kPathMax = 4096;

namespace mindspore::device::ascend {
// Caches the last runtime context bound to the calling thread so SetContext()
// can skip redundant rtCtxSetCurrent calls.
static thread_local rtContext_t thread_local_rt_context{nullptr};
// Substring of ErrorManager messages that carry no useful detail; such
// messages are suppressed when reporting Ascend errors.
constexpr auto kUnknowErrorString = "Unknown error occurred";
92 namespace {
GetRankIdStr()93 std::string GetRankIdStr() {
94 auto context_ptr = MsContext::GetInstance();
95 MS_EXCEPTION_IF_NULL(context_ptr);
96 if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
97 MS_LOG(INFO) << "Get hccl rankid from mpi";
98 auto rank = HcclCollectiveGroup::instance().GetRankId();
99 return std::to_string(rank);
100 }
101 auto rank_id_str = common::GetEnv("RANK_ID");
102 if (rank_id_str.empty()) {
103 MS_LOG(EXCEPTION) << "Invalid environment variable 'RANK_ID', it should not be empty.";
104 }
105 return rank_id_str;
106 }
107
IntHandler(int,siginfo_t *,void *)108 void IntHandler(int, siginfo_t *, void *) {
109 mindspore::kernel::AscendKernelBuildClient::Instance().Close();
110 int this_pid = getpid();
111 MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
112 (void)kill(this_pid, SIGTERM);
113 }
114 } // namespace
115
// Task failures recorded by TaskFailCallback; consumed by DumpTaskExceptionInfo.
std::vector<rtExceptionInfo> AscendKernelRuntime::task_fail_infoes_ = {};
// Graph currently being executed; read by the static task-fail callback to map
// failed (stream, task) ids back to graph nodes.
const session::KernelGraph *current_graph_ = nullptr;
// Per-task overflow counters keyed by "<stream><task><graph>", used to
// throttle repeated overflow warnings in TaskFailCallback.
std::map<std::string, uint32_t> AscendKernelRuntime::overflow_tasks_;
// Drop graph models and clear raw pointers; device resources themselves are
// released earlier via ReleaseDeviceRes().
AscendKernelRuntime::~AscendKernelRuntime() {
  graph_model_map_.clear();
  current_graph_ = nullptr;
  rt_context_ = nullptr;
}
124
SetContext()125 void AscendKernelRuntime::SetContext() {
126 if (rt_context_ == nullptr) {
127 return;
128 }
129 if (thread_local_rt_context == rt_context_) {
130 return;
131 }
132 auto ret = rtCtxSetCurrent(rt_context_);
133 thread_local_rt_context = rt_context_;
134 if (ret != RT_ERROR_NONE) {
135 MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
136 }
137 }
138
SetCurrentContext()139 void AscendKernelRuntime::SetCurrentContext() {
140 if (rt_context_ == nullptr) {
141 return;
142 }
143 auto ret = rtCtxSetCurrent(rt_context_);
144 if (ret != RT_ERROR_NONE) {
145 MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
146 }
147 }
148
// Tear down per-graph device state for every graph: unload and unregister all
// data dumpers, drop dynamic kernels and kernel events, and unload each
// DavinciModel from the ge runtime. Called during finalization.
void AscendKernelRuntime::ClearGraphModelMap() {
  SetCurrentContext();
#ifndef ENABLE_SECURITY
  for (auto &iter : graph_data_dumper_) {
    MS_LOG(INFO) << "[DataDump] Unload data dumper:" << iter.first;
    auto &data_dumper = iter.second;
    MS_EXCEPTION_IF_NULL(data_dumper);
    // Unload dump info and balance the OpDebugRegister done at load time.
    data_dumper->UnloadDumpInfo();
    data_dumper->OpDebugUnregister();
  }
  graph_data_dumper_.clear();
  // tell users which dump kernel name not used
  DumpJsonParser::GetInstance().PrintUnusedKernel();
#endif

  graph_dynamic_kernel_map_.clear();
  graph_kernel_events_map_.clear();
  for (auto &iter : graph_model_map_) {
    MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
    ModelRunner::Instance().UnloadModel(iter.first);
  }
}
171
// Release everything the runtime holds for a single graph: its memory
// scheduler state, data dumper, dynamic kernels, inter-stream kernel events
// and the loaded DavinciModel.
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  SetCurrentContext();
  auto mem_scheduler = mem_scheduler_manager_.GetMemScheduler(graph_id);
  if (mem_scheduler != nullptr) {
    mem_scheduler->Clear();
  }
  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " data dumper";
#ifndef ENABLE_SECURITY
  if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
    MS_LOG(DEBUG) << "Unload dump info " << graph_id;
    auto &data_dumper = dumper_iter->second;
    MS_EXCEPTION_IF_NULL(data_dumper);
    // Unload and unregister before erasing so the debug registration stays balanced.
    data_dumper->UnloadDumpInfo();
    data_dumper->OpDebugUnregister();
    graph_data_dumper_.erase(dumper_iter);
  } else {
    MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
  }
#endif

  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " dynamic kernels";
  if (auto dynamic_kernel_iter = graph_dynamic_kernel_map_.find(graph_id);
      dynamic_kernel_iter != graph_dynamic_kernel_map_.end()) {
    MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel";
    graph_dynamic_kernel_map_.erase(dynamic_kernel_iter);
  }
  auto events_iter = graph_kernel_events_map_.find(graph_id);
  if (events_iter != graph_kernel_events_map_.end()) {
    graph_kernel_events_map_.erase(events_iter);
  }
  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
  if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
    MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
    ModelRunner::Instance().UnloadModel(graph_id);
    graph_model_map_.erase(model_iter);
  } else {
    MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
  }
}
211
GetModelStream(uint32_t graph_id) const212 void *AscendKernelRuntime::GetModelStream(uint32_t graph_id) const {
213 return ModelRunner::Instance().GetModelStream(graph_id);
214 }
215
ClearGlobalIdleMem()216 void AscendKernelRuntime::ClearGlobalIdleMem() {
217 if (mem_manager_ != nullptr) {
218 mem_manager_->ClearGlobalIdleMem();
219 }
220 }
221
NeedDestroyHccl()222 bool AscendKernelRuntime::NeedDestroyHccl() {
223 auto context_ptr = MsContext::GetInstance();
224 MS_EXCEPTION_IF_NULL(context_ptr);
225 if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
226 MS_LOG(INFO) << "Hccl is not enabled";
227 return false;
228 }
229 // Note: make sure hcom_connectivity_detection api never be used.
230 return true;
231 }
232
233 #ifndef ENABLE_SECURITY
AsyncDataDumpUninit()234 void AsyncDataDumpUninit() {
235 if (DumpJsonParser::GetInstance().async_dump_enabled()) {
236 if (AdxDataDumpServerUnInit() != 0) {
237 MS_LOG(ERROR) << "Adx data dump server uninit failed";
238 }
239 }
240 }
241
ReportProfilingData()242 void AscendKernelRuntime::ReportProfilingData() {
243 auto context = MsContext::GetInstance();
244 MS_EXCEPTION_IF_NULL(context);
245 if (ProfilingManager::GetInstance().IsProfiling() &&
246 context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
247 // Save Profiling Framework data
248 OpNameTaskStreamReporter reporter(device_id_, "nonsink", stream_id_task_id_op_name_map_);
249 reporter.ReportData();
250 }
251 }
252 #endif
253
// Finalize the Ascend backend: flush debugger/profiling state, unload models,
// destroy HCCL, free device memory and reset the device.
// Safe to call when Init() never completed (returns early after the debugger
// notification).
void AscendKernelRuntime::ReleaseDeviceRes() {
  MS_LOG(INFO) << "Ascend finalize start";
#ifdef ENABLE_DEBUGGER
  if (debugger_ && debugger_->debugger_enabled()) {
    debugger_->SetTrainingDone(true);
    bool ret = debugger_->SendMetadata(false);
    if (!ret) {
      MS_LOG(ERROR) << "Failed to SendMetadata when finalize";
    }
  }
#endif
  if (!initialized_) {
    return;
  }
  SetCurrentContext();
#ifndef ENABLE_SECURITY
  ReportProfilingData();
#endif
  // release ge runtime
  ClearGraphModelMap();

#ifndef ENABLE_SECURITY
  AsyncDataDumpUninit();
#endif

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);

  // DestroyHccl must be called before FreeDeviceMemory
  (void)DestroyHccl();
  if (mem_manager_ != nullptr) {
    mem_manager_->FreeDeviceMemory();
  }

  // Deregister the task-fail callback installed by Init().
  auto rt_ret = rtRegTaskFailCallbackByModule(kModuleName, nullptr);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret;
  }

  (void)ResetDevice(device_id);
#ifndef ENABLE_SECURITY
  (void)ProfilingManager::GetInstance().StopProfiling();
#endif
  current_graph_ = nullptr;
  // In graph mode without task sink the collective layer was initialized via
  // MPI (see GetRankIdStr) and must be finalized here.
  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
      !context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
    HcclCollectiveGroup::instance().FinalizeCollective();
  }
  MS_LOG(INFO) << "Ascend finalize end";
}
305
306 #ifndef ENABLE_SECURITY
PreInit()307 void AscendKernelRuntime::PreInit() {
308 const auto error_manager_ret = ErrorManager::GetInstance().Init();
309 if (error_manager_ret != 0) {
310 MS_LOG(WARNING) << "Init ErrorManager failed.";
311 }
312 auto ret = ProfilingManager::GetInstance().StartupProfiling(device_id_);
313 if (!ret) {
314 const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
315 if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
316 MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
317 }
318 MS_EXCEPTION(DeviceProcessError) << "StartupProfiling failed.";
319 }
320 }
321 #endif
322
// One-time runtime initialization: device init, device memory pool allocation
// and task-fail callback registration. Idempotent: subsequent calls only
// rebind the runtime context to the calling thread.
// Returns false when InitDevice() fails; rethrows other initialization errors
// after logging the detailed Ascend error message.
bool AscendKernelRuntime::Init() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto execution_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);

  auto profiler_manager = profiler::ProfilerManager::GetInstance();
  MS_EXCEPTION_IF_NULL(profiler_manager);
  auto profiling_flag = profiler_manager->GetProfilingEnableFlag();
  if (execution_mode == kPynativeMode && profiling_flag) {
    pynative_mode_profiling_flag_ = true;
  }
  if (initialized_) {
    SetCurrentContext();
    return true;
  }
  const auto error_manager_ret = ErrorManager::GetInstance().Init();
  if (error_manager_ret != 0) {
    MS_LOG(WARNING) << "Init ErrorManager failed.";
  }
  try {
    // Start up profiling before rtSetDevice
    bool ret = InitDevice();
    if (!ret) {
      return ret;
    }
#ifdef ENABLE_DEBUGGER
    SetDebugger();
#endif
    mem_manager_ = std::make_shared<AscendMemoryManager>();
    MS_EXCEPTION_IF_NULL(mem_manager_);
    mem_manager_->MallocDeviceMemory();

    // Set callback func when exception error
    auto rt_ret = rtRegTaskFailCallbackByModule(kModuleName, TaskFailCallback);
    if (rt_ret != RT_ERROR_NONE) {
      MS_LOG(EXCEPTION) << "Reg SetTaskFailCallback failed, error: " << rt_ret;
    }
  } catch (const std::exception &e) {
    // Attach the detailed Ascend error (when informative) to the exception,
    // otherwise rethrow the original error untouched.
    const string &error_message = ErrorManager::GetInstance().GetErrorMessage();
    if (!error_message.empty() && error_message.find(kUnknowErrorString) == string::npos) {
      MS_LOG(EXCEPTION) << "Ascend error occurred, error message: " << error_message
                        << "\nFirst error scene API: " << e.what();
    }
    throw;
  }

  initialized_ = true;
  return true;
}
372
// Debugger hook: for every graph the debugger tracks, load its outputs,
// parameters and constants for inspection. In non-debugger builds this is a
// no-op that always returns true.
bool AscendKernelRuntime::LoadData(const session::KernelGraph & /* graph */) {
#ifdef ENABLE_DEBUGGER
  MS_LOG(INFO) << "Start load step";
  MS_EXCEPTION_IF_NULL(debugger_);
  for (const auto &graph_ptr : debugger_->GetGraphPtrList()) {
    debugger_->SetGraphPtr(graph_ptr);
    // load output
    debugger_->LoadGraphOutputs();
    // load parameters
    debugger_->LoadParametersAndConst();
  }
#endif
  return true;
}
387
KernelMemNotReuse(const AnfNodePtr & node)388 bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
389 MS_EXCEPTION_IF_NULL(node);
390 bool need_dump = false;
391 #ifndef ENABLE_SECURITY
392 auto &dump_json_parser = DumpJsonParser::GetInstance();
393 if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 1) {
394 auto op_name = node->fullname_with_scope();
395 if (dump_json_parser.NeedDump(op_name)) {
396 need_dump = true;
397 }
398 }
399 #endif
400 return need_dump;
401 }
402
// Factory for Ascend device addresses without a node association.
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                          TypeId type_id) const {
  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}

// Factory for Ascend device addresses tied to a specific (node, output index).
DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                          TypeId type_id, const KernelWithIndex &node_index) const {
  return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id, node_index);
}
412
Load(const session::KernelGraph & graph,bool is_task_sink)413 bool AscendKernelRuntime::Load(const session::KernelGraph &graph, bool is_task_sink) {
414 if (!is_task_sink) {
415 MS_LOG(INFO) << "Graph mode with not task sink";
416 GenKernelEvents(graph);
417 return true;
418 }
419
420 if (!GenTask(graph)) {
421 return false;
422 }
423 if (!LoadTask(graph)) {
424 return false;
425 }
426 return true;
427 }
428
GenDynamicKernel(const session::KernelGraph & graph)429 bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph &graph) {
430 MS_LOG(INFO) << "GenDynamicKernel start";
431 auto cnode_list = graph.execution_order();
432 std::vector<DynamicKernelPtr> dynamic_kernels;
433 for (const auto &cnode : cnode_list) {
434 MS_EXCEPTION_IF_NULL(cnode);
435 MS_LOG(INFO) << "Generate node:" << cnode->fullname_with_scope() << " dynamic kernel";
436 auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
437 MS_EXCEPTION_IF_NULL(kernel_mod);
438 auto dynamic_kernel = kernel_mod->GenDynamicKernel(cnode, stream_);
439 if (dynamic_kernel == nullptr) {
440 MS_LOG(EXCEPTION) << "Dynamic shape is not supported with the operator [" << AnfAlgo::GetCNodeName(cnode) << "].";
441 }
442 dynamic_kernel->Initialize();
443 dynamic_kernels.emplace_back(dynamic_kernel);
444 }
445 graph_dynamic_kernel_map_[graph.graph_id()] = std::move(dynamic_kernels);
446 MS_LOG(INFO) << "GenDynamicKernel end";
447 return true;
448 }
449
// Generate the device task list for a graph (or dynamic kernels for
// dynamic-shape graphs) and wrap it into a DavinciModel stored in
// graph_model_map_. Returns false when task generation fails; raises on a
// duplicate graph id or on unsupported dynamic-shape configurations.
bool AscendKernelRuntime::GenTask(const session::KernelGraph &graph) {
  SetCurrentContext();
  if (graph.is_dynamic_shape()) {
    // Dynamic shape is incompatible with dataset sink mode and async dump.
    if (ConfigManager::GetInstance().dataset_mode() == DS_SINK_MODE && (ConfigManager::GetInstance().iter_num() > 1)) {
      MS_LOG(EXCEPTION) << "Dynamic shape is not supported with dataset_sink_mode.";
    }
#ifndef ENABLE_SECURITY
    if (DumpJsonParser::GetInstance().async_dump_enabled()) {
      MS_LOG(EXCEPTION) << "Dynamic shape is not supported with Asynchronous Dump. Please use Synchronous Dump.";
    }
#endif
    MS_LOG(INFO) << "Dynamic Shape Graph Generate Dynamic kernel";
    return GenDynamicKernel(graph);
  }
  MS_LOG(INFO) << "GenTask start. GraphId:" << graph.graph_id();
#ifndef ENABLE_SECURITY
  DumpJsonParser::GetInstance().UpdateNeedDumpKernels(graph);
#endif
#ifdef MEM_REUSE_DEBUG
  if (!EnvConfigParser::GetInstance().GetSysMemreuse()) {
    // Get normal graph ir for memreuse
    mindspore::memreuse::MemReuseChecker::GetInstance().CheckNormalIR(&graph);
  }
#endif
  vector<std::shared_ptr<TaskInfo>> task_info_list;
  auto anf_node_list = graph.execution_order();
  auto task_generator = TaskGenerator();
  if (!task_generator.GenTasks(anf_node_list, &task_info_list, graph.graph_id())) {
    return false;
  }
  // Store the task_info_list
  auto insert_ret = task_map_.insert(std::make_pair(graph.graph_id(), task_info_list));
  if (!insert_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
  }
  // Graph may have no compute node, such TensorAddGrad.
  if (task_info_list.empty()) {
    MS_LOG(WARNING) << "Graph " << graph.graph_id() << " have no compute node";
    return true;
  }
  AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
  AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
  // the streams' flag not HEAD_STREAM
  std::vector<uint32_t> wait_active_stream_list;
  assign_instance.GetWaitStreams(&wait_active_stream_list);
  std::vector<uint32_t> force_copy_stream_list;
  assign_instance.GetHcomStreams(&force_copy_stream_list);
  MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num()
               << ", total event num:" << resource_manager.get_cur_event_num()
               << ", total label num:" << graph.label_num()
               << ", wait_active_stream_list size:" << wait_active_stream_list.size()
               << ", force_copy_stream_list size:" << force_copy_stream_list.size();
  auto model = std::make_shared<ge::model_runner::DavinciModel>(
    task_info_list, wait_active_stream_list, force_copy_stream_list, 0, 0, 0, 0, 0, 0,
    resource_manager.get_cur_stream_num(), graph.label_num(), resource_manager.get_cur_event_num(), 0);
  auto ret = graph_model_map_.insert(std::make_pair(graph.graph_id(), model));
  if (!ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
  }
  MS_LOG(INFO) << "TaskGenerator GetTaskInfo end...";
  return true;
}
512
// Load the DavinciModel generated by GenTask onto the device, distribute its
// tasks, and hook up profiling and data-dump bookkeeping. Returns false when
// the graph has no generated model; raises if task distribution fails.
bool AscendKernelRuntime::LoadTask(const session::KernelGraph &graph) {
  SetCurrentContext();
  if (graph.is_dynamic_shape()) {
    MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
    return true;
  }

  MS_LOG(INFO) << "LoadTask start. GraphId:" << graph.graph_id();
  if (GraphWithEmptyTaskList(graph)) {
    MS_LOG(WARNING) << "LoadTask end, task list is empty";
    return true;
  }

  auto model_iter = graph_model_map_.find(graph.graph_id());
  if (model_iter == graph_model_map_.end()) {
    MS_LOG(ERROR) << "GraphId:" << graph.graph_id() << " Invalid! Graph LoadTask without GenTask.";
    return false;
  }

  MS_LOG(INFO) << "LoadDavinciModel mode_id:" << model_iter->first;
  ModelRunner::Instance().LoadDavinciModel(device_id_, 0, model_iter->first, model_iter->second);

#ifndef ENABLE_SECURITY
  // Register a data dumper for this graph before tasks are distributed.
  std::function<void *()> model_handle =
    std::bind(&ModelRunner::GetModelHandle, &ModelRunner::Instance(), model_iter->first);
  DistributeDebugTask(graph, NOT_NULL(model_handle));
#endif

  try {
    ModelRunner::Instance().DistributeTask(model_iter->first);
  } catch (const std::exception &e) {
#ifdef ENABLE_DUMP_IR
    // Flush all recorded debug data before aborting.
    mindspore::RDR::TriggerAll();
#endif
    MS_LOG(EXCEPTION) << "Distribute Task Failed, error: " << e.what();
  }

#ifndef ENABLE_SECURITY
  if (ProfilingManager::GetInstance().IsProfiling()) {
    auto task_ids = ModelRunner::Instance().GetTaskIdList(model_iter->first);
    auto stream_ids = ModelRunner::Instance().GetStreamIdList(model_iter->first);
    ProfilingUtils::ReportProfilingData(task_ids, stream_ids, graph);
  }
  LaunchDataDump(graph.graph_id());
#endif

  ModelRunner::Instance().LoadModelComplete(model_iter->first);
  return true;
}
562
563 #ifndef ENABLE_SECURITY
DistributeDebugTask(const session::KernelGraph & graph,const NotNull<std::function<void * ()>> & model_handle)564 void AscendKernelRuntime::DistributeDebugTask(const session::KernelGraph &graph,
565 const NotNull<std::function<void *()>> &model_handle) {
566 if (!DumpJsonParser::GetInstance().async_dump_enabled()) {
567 return;
568 }
569 MS_LOG(INFO) << "Start Distribute Debug Task";
570 auto data_dumper = std::make_shared<DataDumper>(&graph, model_handle);
571 MS_EXCEPTION_IF_NULL(data_dumper);
572 auto ret = graph_data_dumper_.try_emplace(graph.graph_id(), data_dumper);
573 data_dumper->OpDebugRegister();
574 if (!ret.second) {
575 MS_LOG(WARNING) << "[DataDump] Insert graphId:" << graph.graph_id() << " data dumper failed";
576 }
577 }
578
// Push the loaded model's runtime info (task/stream ids) into the graph's data
// dumper and load the dump configuration to the device. Requires that
// DistributeDebugTask registered a dumper for this graph; raises otherwise.
void AscendKernelRuntime::LaunchDataDump(GraphId graph_id) {
  if (!DumpJsonParser::GetInstance().async_dump_enabled()) {
    return;
  }
  MS_LOG(INFO) << "Start Launch Dump Data";
  auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(graph_id);
  if (auto dumper_iter = graph_data_dumper_.find(graph_id); dumper_iter != graph_data_dumper_.end()) {
    auto &data_dumper = dumper_iter->second;
    MS_EXCEPTION_IF_NULL(data_dumper);
    data_dumper->set_runtime_info(runtime_info_map);
    data_dumper->LoadDumpInfo();
  } else {
    MS_LOG(EXCEPTION) << "GraphId:" << graph_id << " not found";
  }
}
594 #endif
595
// Callback registered with the runtime (rtRegTaskFailCallbackByModule) and
// invoked when a device task fails. AICore overflow failures are logged with
// the responsible node, throttled via overflow_tasks_; all other failures are
// queued in task_fail_infoes_ for DumpTaskExceptionInfo.
void AscendKernelRuntime::TaskFailCallback(rtExceptionInfo *task_fail_info) {
  if (task_fail_info == nullptr || current_graph_ == nullptr) {
    MS_LOG(ERROR) << "Execute TaskFailCallback failed. task_fail_info or current_graph_ is nullptr";
    return;
  }

  // Invoked from a runtime thread; guard the shared maps/vector.
  static std::mutex exception_mutex;
  constexpr uint32_t kOverflowThreshold = 5;
  std::lock_guard<std::mutex> lock(exception_mutex);
  if (task_fail_info->retcode == ACL_ERROR_RT_AICORE_OVER_FLOW) {
    auto node = AscendKernelRuntime::GetErrorNodeName(task_fail_info->streamid, task_fail_info->taskid);
    if (!node) {
      MS_LOG(WARNING) << "Node run task overflow, node name is unknown.";
    } else {
      auto key = std::to_string(task_fail_info->streamid) + std::to_string(task_fail_info->taskid) +
                 std::to_string(current_graph_->graph_id());
      // Warn on the first overflow of this (stream, task, graph) and then once
      // every kOverflowThreshold subsequent occurrences; counter resets to 1
      // after each warning.
      if (overflow_tasks_.find(key) == overflow_tasks_.end() || overflow_tasks_[key] == kOverflowThreshold) {
        // print overflow info
        MS_LOG(WARNING) << "Node run task overflow, node name: " << node->fullname_with_scope()
                        << "Task overflow infos task_id: " << task_fail_info->taskid
                        << ", stream_id: " << task_fail_info->streamid << ", tid: " << task_fail_info->tid
                        << ", device_id: " << task_fail_info->deviceid << ", retcode: " << task_fail_info->retcode
                        << " (" << GetErrorMsg(task_fail_info->retcode) << ")"
                        << ", trace: " << trace::DumpSourceLines(node);
        overflow_tasks_[key] = 1;
      } else {
        overflow_tasks_[key]++;
      }
    }
  } else {
    task_fail_infoes_.push_back(*task_fail_info);
  }
}
629
GetErrorNodeName(uint32_t streamid,uint32_t taskid)630 CNodePtr AscendKernelRuntime::GetErrorNodeName(uint32_t streamid, uint32_t taskid) {
631 if (current_graph_ == nullptr) {
632 return nullptr;
633 }
634 auto runtime_info_map = ModelRunner::Instance().GetRuntimeInfoMap(current_graph_->graph_id());
635 for (const auto &iter : runtime_info_map) {
636 MS_EXCEPTION_IF_NULL(iter.second);
637 auto task_id = std::get<kTupleTaskId>(*iter.second);
638 auto stream_id = std::get<kTupleStreamId>(*iter.second);
639 if (task_id == taskid && stream_id == streamid) {
640 auto &execute_node = current_graph_->execution_order();
641 auto node = std::find_if(execute_node.begin(), execute_node.end(), [&iter](const auto &node) {
642 MS_EXCEPTION_IF_NULL(node);
643 return node->UniqueName() == iter.first;
644 });
645 if (node != execute_node.end()) {
646 return *node;
647 }
648 }
649 }
650 return nullptr;
651 }
652
GetDumpPath()653 std::string AscendKernelRuntime::GetDumpPath() {
654 uint32_t rank_id = 0;
655 auto inst = parallel::ParallelContext::GetInstance();
656 MS_EXCEPTION_IF_NULL(inst);
657 if (inst->parallel_mode() != parallel::STAND_ALONE) {
658 if (!CommManager::GetInstance().GetRankID(kHcclWorldGroup, &rank_id)) {
659 MS_LOG(WARNING) << "Get rank id failed.";
660 }
661 }
662
663 auto ms_om_path = common::GetEnv("MS_OM_PATH");
664 std::string path;
665 const auto kSuffix = "/node_dump";
666 if (ms_om_path.empty()) {
667 MS_LOG(WARNING) << "MS_OM_PATH is null, so dump to process local path, as ./rank_id/node_dump/...";
668 path = "./rank_" + std::to_string(rank_id) + kSuffix;
669 } else {
670 path = ms_om_path + "/rank_" + std::to_string(rank_id) + kSuffix;
671 }
672 return path;
673 }
674
675 #ifndef ENABLE_SECURITY
// Log every task failure recorded by TaskFailCallback and e2e-dump the
// input/output data of each failing node under GetDumpPath(). A pre-existing
// dump directory is removed first so the directory only holds this failure.
void AscendKernelRuntime::DumpTaskExceptionInfo(const session::KernelGraph & /* graph */) {
  const std::string path = GetDumpPath();
  if (access(path.c_str(), F_OK) == 0) {
    if (!DeleteDumpDir(path)) {
      MS_LOG(ERROR) << "Delete dump directory " << path << " failed";
    }
  }
  for (const auto &task_fail_info : task_fail_infoes_) {
    MS_LOG(ERROR) << "Task fail infos task_id: " << task_fail_info.taskid << ", stream_id: " << task_fail_info.streamid
                  << ", tid: " << task_fail_info.tid << ", device_id: " << task_fail_info.deviceid
                  << ", retcode: " << task_fail_info.retcode << " (" << GetErrorMsg(task_fail_info.retcode) << ")";
    auto node = AscendKernelRuntime::GetErrorNodeName(task_fail_info.streamid, task_fail_info.taskid);
    // Dump error data in local path
    if (node == nullptr) {
      continue;
    }
    auto full_scope_name = node->fullname_with_scope();
    MS_LOG(ERROR) << "Dump node (" << full_scope_name << ") task error input/output data to: " << path
                  << " trace: " << trace::DumpSourceLines(node);
    E2eDump::DumpInputImpl(node, false, path, &full_scope_name, nullptr);
    E2eDump::DumpOutputImpl(node, false, path, &full_scope_name, nullptr);
  }
}
699 #endif
700
// Execute a graph. Task-sink mode replays the pre-loaded model tasks and logs
// the wall-clock cost; otherwise kernels are launched one by one.
bool AscendKernelRuntime::Run(const session::KernelGraph &graph, bool is_task_sink) {
  const uint64_t kUSecondInSecond = 1000000;
  // Install a SIGINT handler so Ctrl-C shuts the kernel-build client down
  // cleanly before terminating the process.
  SignalGuard sg(IntHandler);
  bool ret = false;

  if (is_task_sink) {
    // Timing uses steady_clock on Windows and gettimeofday elsewhere.
#if defined(_WIN32) || defined(_WIN64)
    auto start_time = std::chrono::steady_clock::now();
#else
    struct timeval start_time {};
    struct timeval end_time {};
    (void)gettimeofday(&start_time, nullptr);
#endif
    ret = RunTask(graph);
#if defined(_WIN32) || defined(_WIN64)
    auto end_time = std::chrono::steady_clock::now();
    std::chrono::duration<double, std::ratio<1, kUSecondInSecond>> cost = end_time - start_time;
    MS_LOG(INFO) << "Call MS Run Success in " << cost.count() << " us";
#else
    (void)gettimeofday(&end_time, nullptr);
    uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
    cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
    MS_LOG(INFO) << "Call MS Run Success in " << cost << " us";
#endif
  } else {
    ret = LaunchKernels(graph);
  }

  return ret;
}
731
// Assign an execution stream to every kernel of a non-task-sink graph:
//  - communication ops get one dedicated stream per HCCL group, created lazily,
//  - independent nodes run on the pre-created independent stream (id 1),
//    except in PyNative mode,
//  - everything else runs on the default stream (id 0).
// Each AtomicAddrClean node is then re-bound to the stream of the kernel that
// follows it. The index of the last kernel on each non-default stream is
// appended to last_stream_nodes (used by GenKernelEvents for sync).
void AscendKernelRuntime::SetKernelModStream(const std::vector<CNodePtr> &kernels,
                                             std::vector<size_t> *last_stream_nodes) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  // Per stream: index of the last kernel scheduled on it.
  std::map<void *, size_t> last_kernel;
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &node = kernels[i];
    auto kernel_mod = AnfAlgo::GetKernelMod(node);
    auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
    MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
    if (AnfAlgo::IsCommunicationOp(node)) {
      auto group = AnfAlgo::GetNodeAttr<std::string>(node, kAttrGroup);
      auto iter = group_stream_id_map_.find(group);
      if (iter == group_stream_id_map_.end()) {
        // First communication op of this group: create its dedicated stream.
        void *stream = nullptr;
        auto ret = rtStreamCreate(&stream, 0);
        if (ret != RT_ERROR_NONE) {
          MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
        }
        auto id = SizeToUint(stream_id_map_.size());
        group_stream_id_map_[group] = id;
        stream_id_map_[id] = stream;
        AnfAlgo::SetStreamId(id, node.get());
        ascend_kernel_mod->SetStream(stream);
        last_kernel[stream] = i;
      } else {
        auto id = iter->second;
        AnfAlgo::SetStreamId(id, node.get());
        ascend_kernel_mod->SetStream(stream_id_map_[id]);
        last_kernel[stream_id_map_[id]] = i;
      }
    } else if (AnfAlgo::IsIndependentNode(node) && mode != kPynativeMode) {
      AnfAlgo::SetStreamId(1, node.get());
      ascend_kernel_mod->SetStream(independent_stream_);
      last_kernel[independent_stream_] = i;
    } else {
      AnfAlgo::SetStreamId(0, node.get());
      ascend_kernel_mod->SetStream(stream_);
    }
  }
  // An AtomicAddrClean node must run on the same stream as the kernel it
  // precedes, so copy the following kernel's stream assignment onto it.
  for (size_t i = 1; i < kernels.size(); ++i) {
    if (AnfAlgo::GetCNodeName(kernels[i - 1]) == kAtomicAddrCleanOpName) {
      auto stream_id = AnfAlgo::GetStreamId(kernels[i]);
      AnfAlgo::SetStreamId(stream_id, kernels[i - 1].get());
      auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i - 1]);
      auto ascend_kernel_mod = dynamic_cast<kernel::AscendKernelMod *>(kernel_mod);
      MS_EXCEPTION_IF_NULL(ascend_kernel_mod);
      ascend_kernel_mod->SetStream(stream_id_map_[stream_id]);
    }
  }
  (void)std::transform(last_kernel.begin(), last_kernel.end(), std::back_inserter(*last_stream_nodes),
                       [](const std::pair<void *, size_t> &item) { return item.second; });
}
786
// Builds per-kernel pre-run/post-run event callbacks so kernels issued on
// different streams are correctly ordered: whenever a kernel consumes the
// output of a kernel on another stream, a device event is recorded on the
// producer's stream and waited on by the consumer's stream.
// Results are cached in graph_kernel_events_map_ keyed by graph id.
void AscendKernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  // Nothing to do for an empty graph; skip graphs whose events already exist.
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  std::vector<size_t> last_stream_nodes;
  // Assign streams to each kernel mod and collect the index of the last
  // kernel launched on each stream.
  SetKernelModStream(kernels, &last_stream_nodes);
  // first = callbacks run before each kernel, second = callbacks run after it.
  auto kernel_events =
    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernels.size());
  kernel_post_run_events.resize(kernels.size());
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    auto curr_stream_id = AnfAlgo::GetStreamId(kernel);
    if (stream_id_map_.find(curr_stream_id) == stream_id_map_.end()) {
      MS_LOG(EXCEPTION) << "Stream " << curr_stream_id << "has not been created";
    }
    auto wait_stream = stream_id_map_[curr_stream_id];
    auto stream_num = stream_id_map_.size();
    // stream_hit[s] marks that a record/wait pair against stream s already
    // exists for this kernel, so at most one event is created per producer stream.
    std::vector<bool> stream_hit(stream_num, false);
    std::vector<AnfNodePtr> used_kernels;
    std::set<AnfNodePtr> visited_kernels;
    // Collect every kernel node this kernel (transitively) reads from.
    AnfAlgo::GetAllVisitedCNode(kernel, &used_kernels, &visited_kernels);
    bool found_depend = false;
    // Walk backwards over earlier kernels looking for cross-stream producers.
    for (int k = SizeToInt(i) - 1; k >= 0; --k) {
      auto pre_cnode = kernels[IntToSize(k)];
      auto pre_cnode_stream_id = AnfAlgo::GetStreamId(pre_cnode);
      if (pre_cnode_stream_id == curr_stream_id) {
        // Same stream: ordering is implicit, no event needed.
        found_depend = true;
        continue;
      }
      for (auto &visited : used_kernels) {
        if (visited == pre_cnode && !stream_hit[pre_cnode_stream_id]) {
          stream_hit[pre_cnode_stream_id] = true;
          found_depend = true;
          auto record_stream = stream_id_map_[pre_cnode_stream_id];
          auto event = CreateDeviceEvent();
          event->set_wait_stream(wait_stream);
          event->set_record_stream(record_stream);
          // Producer records the event after it runs; consumer waits before it runs.
          kernel_post_run_events[IntToSize(k)].emplace_back([event]() { event->RecordEvent(); });
          kernel_pre_run_events[i].emplace_back([event]() { event->WaitEvent(); });
        }
      }
    }
    if (!found_depend && wait_stream != stream_) {
      // No visible producer on any stream: synchronize against the default
      // stream so this kernel cannot start before previously issued work.
      auto pre_event = CreateDeviceEvent();
      pre_event->set_wait_stream(wait_stream);
      pre_event->set_record_stream(stream_);
      kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->RecordEvent(); });
      kernel_pre_run_events[i].emplace_back([pre_event]() { pre_event->WaitEvent(); });
    }
  }
  // Join the tail kernel of every non-default stream back to the default stream.
  ProcessBoundaryEvent(kernels, &kernel_post_run_events, last_stream_nodes);
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}
844
// For every "last kernel on its stream" (indices in last_stream_nodes), check
// whether some later kernel consumes its output. If no later consumer exists,
// that stream would never be joined back, so record an event on its stream and
// make the default stream wait on it right after the kernel runs.
void AscendKernelRuntime::ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
                                               std::vector<std::vector<std::function<void()>>> *kernel_run_events,
                                               const std::vector<size_t> &last_stream_nodes) {
  for (auto &i : last_stream_nodes) {
    if (i >= kernels.size()) {
      MS_LOG(ERROR) << "Node index exceed kernel size.";
      continue;
    }
    auto &kernel = kernels[i];
    MS_EXCEPTION_IF_NULL(kernel);
    bool found_nearest_child = false;
    // Scan kernels executed after `kernel` for one that reads its output.
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      auto input_size = child->inputs().size() - 1;  // inputs()[0] is the primitive, not a data input
      for (size_t k = 0; k < input_size; ++k) {
        // Resolve the real producing kernel behind the input node.
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        break;
      }
    }
    if (!found_nearest_child) {
      // No consumer: explicitly join this stream back to the default stream.
      auto post_event = CreateDeviceEvent();
      MS_EXCEPTION_IF_NULL(post_event);
      auto id = AnfAlgo::GetStreamId(kernel);
      auto record_stream = stream_id_map_[id];
      post_event->set_wait_stream(stream_);
      post_event->set_record_stream(record_stream);
      (*kernel_run_events)[i].emplace_back([post_event]() { post_event->RecordEvent(); });
      (*kernel_run_events)[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
}
883
RunDynamicKernelAsync(const session::KernelGraph & graph)884 bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph &graph) {
885 MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph.graph_id();
886 auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
887 if (iter == graph_dynamic_kernel_map_.end()) {
888 MS_LOG(ERROR) << "GraphId:" << graph.graph_id() << " Not Found! Please generator executor first";
889 return false;
890 }
891
892 auto dynamic_kernels = iter->second;
893 for (const auto &dynamic_kernel : dynamic_kernels) {
894 MS_EXCEPTION_IF_NULL(dynamic_kernel);
895 if (dynamic_kernel->have_depends() || dynamic_kernel->GetKernelType() == KernelType::HCCL_KERNEL) {
896 MS_LOG(INFO) << "Match Dynamic Kernel, Start SyncStream";
897 if (!SyncStream()) {
898 MS_LOG(ERROR) << "SyncStream failed";
899 return false;
900 }
901 }
902
903 if (dynamic_kernel->is_dynamic_shape()) {
904 dynamic_kernel->InferShape();
905 dynamic_kernel->UpdateArgs();
906 }
907
908 dynamic_kernel->Execute();
909 dynamic_kernel->PostExecute();
910 }
911
912 if (!SyncStream()) {
913 MS_LOG(ERROR) << "SyncStream failed";
914 return false;
915 }
916
917 return true;
918 }
919
// Launches the sunk task model of `graph` on the device.
// Dynamic-shape graphs are routed to RunDynamicKernelAsync instead.
// Returns true when there is simply nothing to run; false when the graph was
// never generated or the model run throws.
bool AscendKernelRuntime::RunTask(const session::KernelGraph &graph) {
  current_graph_ = &graph;
  SetCurrentContext();
  if (graph.is_dynamic_shape()) {
    MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
    return RunDynamicKernelAsync(graph);
  }

  MS_LOG(INFO) << "RunTask start. GraphId:" << graph.graph_id();

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // An empty task list is not an error: the graph just has no device work.
  if (GraphWithEmptyTaskList(graph)) {
    MS_LOG(WARNING) << "RunTask end, no task info found";
    return true;
  }

  if (!CheckGraphIdValid(graph.graph_id())) {
    MS_LOG(ERROR) << "GraphId:" << graph.graph_id() << " Invalid! Graph RunTask without GenTask.";
    return false;
  }

  try {
    ModelRunner::Instance().RunModel(graph.graph_id());
  } catch (const std::exception &) {
#ifndef ENABLE_SECURITY
    // Dump device/task state for post-mortem analysis before reporting failure.
    DumpTaskExceptionInfo(graph);
#endif
#ifdef ENABLE_TDTQUE
    // Run task error, we should call TdtHostDestroy to release tdt to avoid DeviceQueueOp hostPush hung
    // case1: cpu usage 100% cause thread/process exit, but some tdt thread remain in backend
    if (!TdtHandle::DestroyHandle()) {
      MS_LOG(WARNING) << "Destroy tdt channel failed.";
    } else {
      MS_LOG(INFO) << "Destroy tdt channel success.";
    }
#endif
    return false;
  }
  // Successful run: drop any task failure info recorded by earlier attempts.
  task_fail_infoes_.clear();
  return true;
}
962
SyncStream()963 bool AscendKernelRuntime::SyncStream() {
964 SetCurrentContext();
965 session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
966 for (auto &iter : stream_id_map_) {
967 if (rtStreamSynchronize(iter.second) != RT_ERROR_NONE) { // o for switch stream
968 MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
969 return false;
970 }
971 }
972 return true;
973 }
974
MemcpyAsync(void * dst,const void * src,uint64_t size,int32_t kind)975 bool AscendKernelRuntime::MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) {
976 SetCurrentContext();
977 if (stream_ == nullptr) {
978 MS_LOG(ERROR) << "MemcpyAsync failed. stream_ is nullptr";
979 return false;
980 }
981
982 auto copy_kind = static_cast<rtMemcpyKind_t>(kind);
983 if (copy_kind != RT_MEMCPY_HOST_TO_DEVICE_EX && copy_kind != RT_MEMCPY_DEVICE_TO_DEVICE) {
984 MS_LOG(EXCEPTION) << "Memory copy async not support cache host buffer in kind: " << kind;
985 }
986 if (RT_ERROR_NONE != rtMemcpyAsync(dst, size, src, size, static_cast<rtMemcpyKind_t>(kind), stream_)) {
987 MS_LOG(ERROR) << "Call runtime rtMemcpyAsync error.";
988 return false;
989 }
990 return true;
991 }
992
CreateContext()993 void AscendKernelRuntime::CreateContext() {
994 if (rt_context_ == nullptr) {
995 auto ret = rtCtxCreate(&rt_context_, 0, UintToInt(device_id_));
996 if (ret != RT_ERROR_NONE) {
997 MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
998 }
999 }
1000 SetCurrentContext();
1001 }
1002
InitDevice()1003 bool AscendKernelRuntime::InitDevice() {
1004 int device_count = 0;
1005 auto ret = rtGetDeviceCount(&device_count);
1006 if (ret != RT_ERROR_NONE) {
1007 MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
1008 }
1009
1010 ret = rtSetDevice(UintToInt(device_id_));
1011 if (ret != RT_ERROR_NONE) {
1012 MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
1013 }
1014
1015 auto context_ptr = MsContext::GetInstance();
1016 MS_EXCEPTION_IF_NULL(context_ptr);
1017 if (context_ptr == nullptr) {
1018 MS_LOG(ERROR) << "Get MsContext instance failed";
1019 return false;
1020 }
1021 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
1022 if (!HcclInit()) {
1023 MS_LOG(ERROR) << "HcclInit init failed";
1024 return false;
1025 }
1026 }
1027
1028 // Context will be created by rtSetDevice
1029 ret = rtCtxGetCurrent(&rt_context_);
1030 if (ret != RT_ERROR_NONE || rt_context_ == nullptr) {
1031 MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]";
1032 return false;
1033 }
1034
1035 ret = rtStreamCreateWithFlags(&stream_, 0, RT_STREAM_HUGE);
1036 if (ret != RT_ERROR_NONE) {
1037 MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
1038 }
1039 ret = rtStreamCreateWithFlags(&independent_stream_, 0, RT_STREAM_HUGE);
1040 if (ret != RT_ERROR_NONE) {
1041 MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
1042 }
1043 ret = rtStreamCreate(&communication_stream_, 0);
1044 if (ret != RT_ERROR_NONE) {
1045 MS_LOG(EXCEPTION) << "create communication stream failed, ret:" << ret;
1046 }
1047 const int kCommunicationStreamID = 2;
1048 stream_id_map_[0] = stream_;
1049 stream_id_map_[1] = independent_stream_;
1050 stream_id_map_[kCommunicationStreamID] = communication_stream_;
1051 group_stream_id_map_[kHcclWorldGroup] = kCommunicationStreamID;
1052 return true;
1053 }
1054
ResetDevice(uint32_t device_id)1055 bool AscendKernelRuntime::ResetDevice(uint32_t device_id) {
1056 SetCurrentContext();
1057 int32_t ret;
1058 for (auto &iter : stream_id_map_) {
1059 ret = rtStreamDestroy(iter.second);
1060 if (ret != RT_ERROR_NONE) {
1061 MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
1062 }
1063 iter.second = nullptr;
1064 }
1065 ret = rtDeviceReset(UintToInt(device_id));
1066 if (ret != RT_ERROR_NONE) {
1067 MS_EXCEPTION(DeviceProcessError) << "Call rtDeviceReset, ret[" << ret << "]";
1068 }
1069 // set to nullptr as its not created, only bounded to existing context
1070 rt_context_ = nullptr;
1071 return true;
1072 }
1073
// Initializes HCCL communication for this process.
// In non-task-sink graph mode HCCL is initialized through the collective
// group helper; otherwise the rank-table file pointed to by
// MINDSPORE_HCCL_CONFIG_PATH (or RANK_TABLE_FILE) is used.
// Returns false on any configuration or initialization failure; throws if the
// TSD dependency is not open.
bool AscendKernelRuntime::HcclInit() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context::IsTsdOpened(context_ptr)) {
    MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
  }
  MS_LOG(INFO) << "Do hcom init";
  bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (!is_task_sink && mode == kGraphMode) {
    // Non-sink graph mode: initialize HCCL and build the world group covering
    // every rank in [0, rank_size).
    (void)hccl::HcclAdapter::GetInstance().InitHccl();
    auto rank_size = HcclCollectiveGroup::instance().GetRankSize();
    std::vector<unsigned int> ranks(rank_size);
    std::iota(std::begin(ranks), std::end(ranks), 0);
    HcclCollectiveGroup::instance().CreateCommGroup(kHcclWorldGroup, ranks);
    return true;
  }

  // Rank-table based initialization: MINDSPORE_HCCL_CONFIG_PATH takes
  // precedence over RANK_TABLE_FILE.
  auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
  if (config_path_str == nullptr) {
    config_path_str = std::getenv("RANK_TABLE_FILE");
    if (config_path_str == nullptr) {
      MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE";
      return false;
    }
  }
  if (strlen(config_path_str) >= kPathMax) {
    MS_LOG(ERROR) << "File path oversize";
    return false;
  }
  std::string rank_id_str = GetRankIdStr();
  // realpath(..., nullptr) returns a malloc()ed buffer; released via free() below.
  auto full_path = realpath(config_path_str, nullptr);
  if (full_path == nullptr) {
    MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";
    return false;
  }
  MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
  bool ret = hccl::HcclAdapter::GetInstance().InitHccl(context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID), rank_id_str,
                                                       full_path, mode == kGraphMode);
  free(full_path);
  if (!ret) {
    MS_LOG(ERROR) << "Hcom init failed.";
    return false;
  }
  return true;
}
1120
DestroyHccl()1121 bool AscendKernelRuntime::DestroyHccl() {
1122 if (!NeedDestroyHccl()) {
1123 MS_LOG(INFO) << "Hccl is not enable, no need to close.";
1124 return true;
1125 }
1126 bool res = hccl::HcclAdapter::GetInstance().FinalizeHccl();
1127 if (!res) {
1128 MS_LOG(ERROR) << "Hccl destroy failed";
1129 return false;
1130 }
1131 MS_LOG(INFO) << "Hccl destroy successful.";
1132 auto context_ptr = MsContext::GetInstance();
1133 MS_EXCEPTION_IF_NULL(context_ptr);
1134 context_ptr->set_param<bool>(MS_CTX_ENABLE_HCCL, false);
1135 return true;
1136 }
1137
GraphWithEmptyTaskList(const session::KernelGraph & graph) const1138 bool AscendKernelRuntime::GraphWithEmptyTaskList(const session::KernelGraph &graph) const {
1139 auto iter = task_map_.find(graph.graph_id());
1140 if (iter == task_map_.end()) {
1141 MS_LOG(EXCEPTION) << "Unknown graph ptr";
1142 }
1143 return iter->second.empty();
1144 }
1145
CheckGraphIdValid(GraphId graph_id) const1146 bool AscendKernelRuntime::CheckGraphIdValid(GraphId graph_id) const {
1147 return task_map_.find(graph_id) != task_map_.end() && graph_model_map_.find(graph_id) != graph_model_map_.end();
1148 }
1149
KernelLaunchProfiling(const std::string & kernel_name)1150 void AscendKernelRuntime::KernelLaunchProfiling(const std::string &kernel_name) {
1151 auto profiler_manager = profiler::ProfilerManager::GetInstance();
1152 MS_EXCEPTION_IF_NULL(profiler_manager);
1153 if (!profiler_manager->GetProfilingEnableFlag()) {
1154 return;
1155 }
1156
1157 // save task info
1158 uint32_t stream_id;
1159 uint32_t task_id;
1160 auto rt_ret = rtGetTaskIdAndStreamID(&task_id, &stream_id);
1161 if (rt_ret != RT_ERROR_NONE) {
1162 MS_LOG(EXCEPTION) << "Profiling get task_id stream_id failed";
1163 }
1164 std::pair<uint32_t, uint32_t> stream_task_pair = {stream_id, task_id};
1165 auto try_emplace_ret = stream_id_task_id_op_name_map_.try_emplace(stream_task_pair, kernel_name);
1166 if (!try_emplace_ret.second) {
1167 MS_LOG(WARNING) << "Profiling duplicate key, task_id:" << stream_task_pair.second
1168 << " stream_id:" << stream_task_pair.first << " name:" << kernel_name;
1169 }
1170 if (stream_id_task_id_op_name_map_.size() > kProfilingMaxTaskIdInStream) {
1171 MS_LOG(EXCEPTION) << "Too many profiling data";
1172 }
1173 }
1174
CreateDeviceEvent()1175 std::shared_ptr<DeviceEvent> AscendKernelRuntime::CreateDeviceEvent() {
1176 auto ascend_event = std::make_shared<AscendEvent>();
1177 MS_EXCEPTION_IF_NULL(ascend_event);
1178 return ascend_event;
1179 }
1180
CreateDeviceTimeEvent()1181 std::shared_ptr<DeviceEvent> AscendKernelRuntime::CreateDeviceTimeEvent() {
1182 auto ascend_time_event = std::make_shared<AscendTimeEvent>();
1183 MS_EXCEPTION_IF_NULL(ascend_time_event);
1184 return ascend_time_event;
1185 }
1186
GetAvailableMemMaxSize() const1187 uint64_t AscendKernelRuntime::GetAvailableMemMaxSize() const {
1188 auto ascend_mem_manager = std::dynamic_pointer_cast<AscendMemoryManager>(mem_manager_);
1189 MS_EXCEPTION_IF_NULL(ascend_mem_manager);
1190 return ascend_mem_manager->GetDeviceMemSize();
1191 }
1192
DeleteDumpDir(const std::string & path)1193 bool AscendKernelRuntime::DeleteDumpDir(const std::string &path) {
1194 string real_path = GetRealPath(path);
1195 if (DeleteDumpFile(real_path) == -1) {
1196 return false;
1197 }
1198 if (rmdir(real_path.c_str()) == -1) {
1199 MS_LOG(WARNING) << "Delete dir " << real_path << " failed!";
1200 }
1201 return true;
1202 }
1203
DeleteDumpFile(std::string path)1204 int AscendKernelRuntime::DeleteDumpFile(std::string path) {
1205 DIR *dir;
1206 struct dirent *dirinfo;
1207 struct stat statbuf;
1208 string filepath;
1209 int result = 0;
1210 lstat(path.c_str(), &statbuf);
1211
1212 if (S_ISREG(statbuf.st_mode)) {
1213 result = remove(path.c_str());
1214 } else if (S_ISDIR(statbuf.st_mode)) {
1215 if ((dir = opendir(path.c_str())) == nullptr) {
1216 return -1;
1217 }
1218
1219 while (!result && (dirinfo = readdir(dir))) {
1220 if (path[path.size() - 1] != '/') {
1221 path = path + "/";
1222 }
1223 MS_EXCEPTION_IF_NULL(dirinfo);
1224 filepath = path + dirinfo->d_name;
1225 if (strcmp(dirinfo->d_name, ".") == 0 || strcmp(dirinfo->d_name, "..") == 0) continue;
1226 result = DeleteDumpFile(filepath);
1227 if (!result) {
1228 if (rmdir(filepath.c_str()) == -1) {
1229 MS_LOG(WARNING) << "Delete dir " << filepath << " failed!";
1230 }
1231 }
1232 }
1233 if (closedir(dir) == -1) {
1234 MS_LOG(WARNING) << "Dump dir " << path << " close failed!";
1235 }
1236 }
1237 return result;
1238 }
1239
GetRealPath(const std::string & path)1240 std::string AscendKernelRuntime::GetRealPath(const std::string &path) {
1241 char real_path_mem[kPathMax] = {0};
1242 char *real_path_ret = realpath(path.c_str(), real_path_mem);
1243 if (real_path_ret == nullptr) {
1244 return "";
1245 }
1246 return std::string(real_path_mem);
1247 }
1248 } // namespace mindspore::device::ascend
1249