1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "plugin/device/ascend/hal/hardware/ge_device_res_manager.h"
18 #ifndef _WIN32
19 #include <dlfcn.h>
20 #include <libgen.h>
21 #endif
22 #include "plugin/device/ascend/hal/hardware/ge_utils.h"
23 #include <utility>
24 #include "plugin/device/cpu/hal/device/cpu_memory_manager.h"
25 #include "plugin/device/ascend/hal/device/ascend_memory_manager.h"
26 #include "plugin/device/ascend/hal/device/ascend_device_address.h"
27 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
28 #include "plugin/device/ascend/hal/device/ascend_device_synchronizer.h"
29 #include "plugin/device/ascend/hal/device/ascend_event.h"
30 #include "plugin/device/ascend/hal/device/ascend_pin_mem_pool.h"
31 #include "plugin/device/cpu/hal/device/cpu_device_synchronizer.h"
32 #include "include/transform/graph_ir/utils.h"
33 #include "graph/types.h"
34 #include "transform/symbol/acl_rt_symbol.h"
35 #include "transform/symbol/symbol_utils.h"
36 #include "transform/acl_ir/op_api_util.h"
37 #include "include/backend/mem_reuse/mem_tracker.h"
38 #include "graph/def_types.h"
39 #include "runtime/device/move_to.h"
40
41 namespace mindspore {
42 namespace device {
43 namespace ascend {
GetCurrentDir()44 std::string GetCurrentDir() {
45 #ifndef _WIN32
46 Dl_info dl_info;
47 if (dladdr(reinterpret_cast<void *>(GetCurrentDir), &dl_info) == 0) {
48 MS_LOG(WARNING) << "Get dladdr error";
49 return "";
50 }
51 std::string cur_so_path = dl_info.dli_fname;
52 return dirname(cur_so_path.data());
53 #else
54 return "";
55 #endif
56 }
57
Malloc(size_t size)58 ::ge::MemBlock *GeAllocator::Malloc(size_t size) {
59 auto addr = res_manager_->AllocateMemory(size);
60 MS_LOG(DEBUG) << "GE Allocator malloc addr: " << addr << " size: " << size;
61 auto mem_block = new ::ge::MemBlock(*this, addr, size);
62 return mem_block;
63 }
64
Free(::ge::MemBlock * block)65 void GeAllocator::Free(::ge::MemBlock *block) {
66 res_manager_->FreeMemory(block->GetAddr());
67 MS_LOG(DEBUG) << "GE Allocator free addr: " << block->GetAddr();
68 delete block;
69 }
70
Initialize()71 void GeDeviceResManager::Initialize() {
72 auto ms_context = MsContext::GetInstance();
73 MS_EXCEPTION_IF_NULL(ms_context);
74 auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
75 runtime_instance_ = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
76 MS_EXCEPTION_IF_NULL(runtime_instance_);
77 if (!runtime_instance_->Init()) {
78 MS_LOG(EXCEPTION) << "Kernel runtime init error.";
79 }
80 mem_manager_ = runtime_instance_->GetMemoryManager();
81 MS_EXCEPTION_IF_NULL(mem_manager_);
82 if (ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD)) {
83 swap_manager_ = std::make_shared<SwapManager>(kDefaultStreamIndex, &AscendMemoryPool::GetInstance(),
84 &AscendPinMemPool::GetInstance());
85 }
86 }
87
SetCPUMemManager()88 void GeDeviceResManager::SetCPUMemManager() {
89 if (is_use_cpu_memory_) {
90 return;
91 }
92 if (mem_manager_ != nullptr) {
93 mem_manager_->Finalize();
94 mem_manager_ = nullptr;
95 }
96 runtime_instance_ = nullptr;
97 mem_manager_ = std::make_shared<cpu::CPUMemoryManager>();
98 MS_EXCEPTION_IF_NULL(mem_manager_);
99 is_use_cpu_memory_ = true;
100 }
101
Destroy()102 void GeDeviceResManager::Destroy() {
103 (void)DestroyAllEvents();
104 // Release memory.
105 if (mem_manager_ != nullptr) {
106 mem_manager_->Finalize();
107 mem_manager_ = nullptr;
108 }
109 }
110
AllocateMemory(DeviceAddress * const & address,uint32_t stream_id) const111 bool GeDeviceResManager::AllocateMemory(DeviceAddress *const &address, uint32_t stream_id) const {
112 MS_EXCEPTION_IF_NULL(address);
113 MS_EXCEPTION_IF_NULL(mem_manager_);
114 auto device_name_in_address = GetDeviceNameByType(static_cast<const DeviceType>(address->GetDeviceType()));
115 if (IsEnableRefMode() && device_name_in_address != device_context_->device_context_key().device_name_) {
116 MS_LOG(EXCEPTION) << "The device address type is wrong: type name in address:" << device_name_in_address
117 << ", type name in context:" << device_context_->device_context_key().device_name_;
118 }
119
120 if (address->GetPtr() != nullptr) {
121 MS_LOG(ERROR) << "Memory leak detected!";
122 return false;
123 }
124
125 if (runtime_instance_ != nullptr) {
126 runtime_instance_->SetContext();
127 }
128 void *device_ptr = nullptr;
129
130 if (stream_id == UINT32_MAX) {
131 stream_id = address->stream_id();
132 }
133
134 if (swap_manager_ != nullptr) {
135 device_ptr = swap_manager_->AllocDeviceMemory(address->GetSize(), stream_id);
136 } else {
137 device_ptr = mem_manager_->MallocMemFromMemPool(address->GetSize(), address->from_persistent_mem(),
138 address->need_recycle(), stream_id);
139 }
140
141 if (!device_ptr) {
142 return false;
143 }
144
145 address->set_ptr(device_ptr);
146 address->set_from_mem_pool(true);
147 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(BindDevicePtr, address, device_ptr);
148 return true;
149 }
150
AllocateMemory(size_t size,uint32_t stream_id) const151 void *GeDeviceResManager::AllocateMemory(size_t size, uint32_t stream_id) const {
152 MS_EXCEPTION_IF_NULL(runtime_instance_);
153 runtime_instance_->SetContext();
154 MS_EXCEPTION_IF_NULL(mem_manager_);
155 if (swap_manager_ != nullptr) {
156 return swap_manager_->AllocDeviceMemory(size, stream_id);
157 }
158 return mem_manager_->MallocMemFromMemPool(size, false, false, stream_id);
159 }
160
GetMaxUsedMemorySize() const161 size_t GeDeviceResManager::GetMaxUsedMemorySize() const {
162 MS_EXCEPTION_IF_NULL(mem_manager_);
163 return mem_manager_->GetMaxUsedMemorySize();
164 }
165
FreeMemory(void * ptr) const166 void GeDeviceResManager::FreeMemory(void *ptr) const {
167 MS_EXCEPTION_IF_NULL(ptr);
168 MS_EXCEPTION_IF_NULL(mem_manager_);
169 mem_manager_->FreeMemFromMemPool(ptr);
170 }
171
FreePartMemorys(const std::vector<void * > & free_addrs,const std::vector<void * > & keep_addrs,const std::vector<size_t> & keep_addr_sizes) const172 void GeDeviceResManager::FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs,
173 const std::vector<size_t> &keep_addr_sizes) const {
174 AscendMemoryPool::GetInstance().FreePartTensorMems(free_addrs, keep_addrs, keep_addr_sizes);
175 }
176
DefragMemory()177 void GeDeviceResManager::DefragMemory() { AscendMemoryPool::GetInstance().DefragMemory(); }
178
// Functions for querying and resetting memory statistics.
GetTotalMemStatistics() const180 size_t GeDeviceResManager::GetTotalMemStatistics() const {
181 MS_EXCEPTION_IF_NULL(mem_manager_);
182 return mem_manager_->GetTotalMemStatistics();
183 }
184
GetTotalUsedMemStatistics() const185 size_t GeDeviceResManager::GetTotalUsedMemStatistics() const {
186 MS_EXCEPTION_IF_NULL(mem_manager_);
187 return mem_manager_->GetTotalUsedMemStatistics();
188 }
189
GetTotalIdleMemStatistics() const190 size_t GeDeviceResManager::GetTotalIdleMemStatistics() const {
191 MS_EXCEPTION_IF_NULL(mem_manager_);
192 return mem_manager_->GetTotalIdleMemStatistics();
193 }
194
GetTotalEagerFreeMemStatistics() const195 size_t GeDeviceResManager::GetTotalEagerFreeMemStatistics() const {
196 MS_EXCEPTION_IF_NULL(mem_manager_);
197 return mem_manager_->GetTotalEagerFreeMemStatistics();
198 }
199
GetUsedMemPeakStatistics() const200 size_t GeDeviceResManager::GetUsedMemPeakStatistics() const {
201 MS_EXCEPTION_IF_NULL(mem_manager_);
202 return mem_manager_->GetUsedMemPeakStatistics();
203 }
204
GetReservedMemPeakStatistics() const205 size_t GeDeviceResManager::GetReservedMemPeakStatistics() const {
206 MS_EXCEPTION_IF_NULL(mem_manager_);
207 return mem_manager_->GetReservedMemPeakStatistics();
208 }
209
GetBlockCountsStatistics() const210 std::unordered_map<std::string, std::size_t> GeDeviceResManager::GetBlockCountsStatistics() const {
211 MS_EXCEPTION_IF_NULL(mem_manager_);
212 return mem_manager_->GetBlockCountsStatistics();
213 }
214
GetBlockUnitSizeStatistics() const215 std::unordered_map<std::string, std::size_t> GeDeviceResManager::GetBlockUnitSizeStatistics() const {
216 MS_EXCEPTION_IF_NULL(mem_manager_);
217 return mem_manager_->GetBlockUnitSizeStatistics();
218 }
219
220 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
GetCommonMemBlocksInfoStatistics() const221 GeDeviceResManager::GetCommonMemBlocksInfoStatistics() const {
222 MS_EXCEPTION_IF_NULL(mem_manager_);
223 return mem_manager_->GetCommonMemBlocksInfoStatistics();
224 }
225
226 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
GetPersistentMemBlocksInfoStatistics() const227 GeDeviceResManager::GetPersistentMemBlocksInfoStatistics() const {
228 MS_EXCEPTION_IF_NULL(mem_manager_);
229 return mem_manager_->GetPersistentMemBlocksInfoStatistics();
230 }
231
ResetMaxMemoryReserved() const232 void GeDeviceResManager::ResetMaxMemoryReserved() const {
233 MS_EXCEPTION_IF_NULL(mem_manager_);
234 mem_manager_->ResetMaxMemoryReserved();
235 }
236
ResetMaxMemoryAllocated() const237 void GeDeviceResManager::ResetMaxMemoryAllocated() const {
238 MS_EXCEPTION_IF_NULL(mem_manager_);
239 mem_manager_->ResetMaxMemoryAllocated();
240 }
241
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)242 void GeDeviceResManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
243 (void)mem_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream);
244 }
245
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)246 void GeDeviceResManager::SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
247 (void)mem_manager_->SwapOut(device_ptr, host_ptr, mem_size, stream);
248 }
249
AllocateContinuousMemory(const std::vector<size_t> & size_list,uint32_t stream_id) const250 std::vector<void *> GeDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list,
251 uint32_t stream_id) const {
252 MS_EXCEPTION_IF_NULL(runtime_instance_);
253 runtime_instance_->SetContext();
254 MS_EXCEPTION_IF_NULL(mem_manager_);
255 std::vector<size_t> aligned_size_list;
256 for (auto size : size_list) {
257 auto align_size = device::MemoryManager::GetCommonAlignSize(size);
258 aligned_size_list.emplace_back(align_size);
259 }
260 if (swap_manager_ != nullptr) {
261 return swap_manager_->AllocDeviceContinuousMem(aligned_size_list, stream_id);
262 }
263 return mem_manager_->MallocContinuousMemFromMemPool(aligned_size_list, stream_id);
264 }
265
CreateDeviceAddress(const KernelTensorPtr & kernel_tensor) const266 DeviceAddressPtr GeDeviceResManager::CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const {
267 MS_EXCEPTION_IF_NULL(kernel_tensor);
268 if (!is_use_cpu_memory_) {
269 if (kernel_tensor->device_name().empty()) {
270 kernel_tensor->set_device_name(device_context_->device_context_key().device_name_);
271 kernel_tensor->set_device_id(device_context_->device_context_key().device_id_);
272 }
273 auto device_address = std::make_shared<AscendDeviceAddress>(kernel_tensor);
274 device_address->set_device_synchronizer(std::make_shared<AscendDeviceSynchronizer>());
275 return device_address;
276 } else {
277 if (kernel_tensor->device_name().empty()) {
278 kernel_tensor->set_device_name(kCPUDevice);
279 kernel_tensor->set_device_id(0);
280 }
281 auto device_address = std::make_shared<cpu::CPUDeviceAddress>(kernel_tensor);
282 device_address->set_device_synchronizer(std::make_shared<cpu::CPUDeviceSynchronizer>());
283 return device_address;
284 }
285 }
286
CreateDeviceAddress(void * ptr,size_t size,const ShapeVector & shape_vector,const Format & format,TypeId type_id,const std::string & device_name,uint32_t device_id,uint32_t stream_id) const287 DeviceAddressPtr GeDeviceResManager::CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector,
288 const Format &format, TypeId type_id,
289 const std::string &device_name, uint32_t device_id,
290 uint32_t stream_id) const {
291 if (!is_use_cpu_memory_) {
292 return std::make_shared<AscendDeviceAddress>(ptr, size, shape_vector, format, type_id, device_name, device_id,
293 stream_id);
294 } else {
295 return std::make_shared<cpu::CPUDeviceAddress>(ptr, size, shape_vector, format, type_id, kCPUDevice, device_id,
296 stream_id);
297 }
298 }
299
GeSetContextOptions(const std::shared_ptr<MsContext> & ms_context_ptr,transform::SessionOptions * options)300 void GeDeviceResManager::GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
301 transform::SessionOptions *options) {
302 MS_EXCEPTION_IF_NULL(options);
303 if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
304 (*options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
305 }
306
307 if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
308 (*options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
309 }
310
311 auto atomic_clean_policy = ms_context_ptr->get_param<std::string>(MS_CTX_ATOMIC_CLEAN_POLICY);
312 if (atomic_clean_policy.empty()) {
313 atomic_clean_policy = "1";
314 }
315 (*options)["ge.exec.atomicCleanPolicy"] = atomic_clean_policy;
316 MS_LOG(INFO) << "Set GE atomic clean policy to " << atomic_clean_policy << ".";
317 (*options)["ge.graphRunMode"] = "1";
318 }
319
CreateSessionAndGraphRunner()320 void GeDeviceResManager::CreateSessionAndGraphRunner() {
321 std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
322 auto ms_context = MsContext::GetInstance();
323 MS_EXCEPTION_IF_NULL(ms_context);
324 if (sess == nullptr) {
325 transform::SessionOptions options;
326 options["ge.enablePrintOpPass"] = "0";
327 GeSetContextOptions(ms_context, &options);
328 options["ge.constLifecycle"] = "graph";
329
330 options["ge.exec.formatMode"] = "0";
331 auto format_mode = common::GetEnv("MS_FORMAT_MODE");
332 if (format_mode == "1" || (format_mode.empty() && ms_context->ascend_soc_version() != "ascend910")) {
333 MS_LOG(INFO) << "Set GE option ge.exec.formatMode to 1.";
334 options["ge.exec.formatMode"] = "1";
335 }
336
337 SetPassthroughGeOptions(false, &options);
338
339 sess = transform::NewSession(options);
340 transform::SetGeSession(sess);
341 }
342
343 transform::GraphRunnerOptions options;
344 options.sess_ptr = sess;
345 auto graph_runner = transform::NewGraphRunner(options);
346 transform::SetGraphRunner(graph_runner);
347 }
348
LoadCollectiveCommLib()349 bool GeDeviceResManager::LoadCollectiveCommLib() {
350 // If this is simulation, load dummy collective communication library.
351 if (!common::GetEnv(kSimulationLevel).empty()) {
352 collective_comm_lib_ = &DummyAscendCollectiveCommLib::GetInstance();
353 return true;
354 }
355 // Ascend backend supports HCCL and LCCL collective communication libraries.
356 if (!common::GetEnv("MS_ENABLE_LCCL").empty()) {
357 std::string lowlatency_comm_lib_name = GetCurrentDir() + "/ascend/liblowlatency_collective.so";
358 auto loader = std::make_shared<CollectiveCommLibLoader>(lowlatency_comm_lib_name);
359 MS_EXCEPTION_IF_NULL(loader);
360 if (!loader->Initialize()) {
361 MS_LOG(EXCEPTION) << "Loading LCCL collective library failed.";
362 return false;
363 }
364 void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
365 MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
366
367 auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
368 collective_comm_lib_ = instance_func();
369 MS_EXCEPTION_IF_NULL(collective_comm_lib_);
370 MS_LOG(WARNING) << "Loading LCCL because env MS_ENABLE_LCCL is set to 1. Pay attention that LCCL only supports "
371 "single-node-multi-card mode in KernelByKernel for now.";
372 } else {
373 collective_comm_lib_ = &AscendCollectiveCommLib::GetInstance();
374 }
375 return true;
376 }
377
BindDeviceToCurrentThread(bool force_bind) const378 bool GeDeviceResManager::BindDeviceToCurrentThread(bool force_bind) const {
379 static thread_local std::once_flag is_set;
380 std::call_once(is_set, []() {
381 auto ms_context = MsContext::GetInstance();
382 MS_EXCEPTION_IF_NULL(ms_context);
383 auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
384 auto ret = CALL_ASCEND_API(aclrtSetDevice, static_cast<int32_t>(device_id));
385 if (ret != ACL_ERROR_NONE) {
386 MS_LOG(EXCEPTION) << "Device " << device_id << " call aclrtSetDevice failed, ret:" << static_cast<int>(ret);
387 }
388 transform::AclUtil::SetDeterministic();
389 });
390
391 if (runtime_instance_ != nullptr) {
392 if (force_bind) {
393 runtime_instance_->SetContextForce();
394 } else {
395 runtime_instance_->SetContext();
396 }
397 }
398 return true;
399 }
400
ResetStreamAndCtx()401 void GeDeviceResManager::ResetStreamAndCtx() {
402 if (runtime_instance_ != nullptr) {
403 runtime_instance_->ResetStreamAndCtx();
404 }
405 }
406
CreateStream(size_t * stream_id) const407 bool GeDeviceResManager::CreateStream(size_t *stream_id) const {
408 if (!BindDeviceToCurrentThread(false)) {
409 MS_LOG(ERROR) << "Bind context to current thread failed";
410 return false;
411 }
412 AscendStreamMng::GetInstance().CreateStream(stream_id);
413 return true;
414 }
415
CreateStreamWithPriority(size_t * stream_id,int32_t priority) const416 bool GeDeviceResManager::CreateStreamWithPriority(size_t *stream_id, int32_t priority) const {
417 if (!BindDeviceToCurrentThread(false)) {
418 MS_LOG(ERROR) << "Bind context to current thread failed";
419 return false;
420 }
421 AscendStreamMng::GetInstance().CreateStreamWithFlags(stream_id, ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC,
422 IntToUint(priority));
423 return true;
424 }
425
QueryStreamSize() const426 size_t GeDeviceResManager::QueryStreamSize() const { return AscendStreamMng::GetInstance().QueryStreamSize(); }
427
GetStreamIds() const428 std::vector<uint32_t> GeDeviceResManager::GetStreamIds() const { return AscendStreamMng::GetInstance().GetStreamIds(); }
429
single_op_multi_stream_enable() const430 bool GeDeviceResManager::single_op_multi_stream_enable() const {
431 return AscendStreamMng::GetInstance().single_op_multi_stream_enable();
432 }
433
set_single_op_multi_stream_enable(bool single_op_multi_stream_enable)434 void GeDeviceResManager::set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) {
435 return AscendStreamMng::GetInstance().set_single_op_multi_stream_enable(single_op_multi_stream_enable);
436 }
437
GetStream(size_t stream_id) const438 void *GeDeviceResManager::GetStream(size_t stream_id) const {
439 if (!BindDeviceToCurrentThread(false)) {
440 MS_LOG(ERROR) << "Bind context to current thread failed";
441 return nullptr;
442 }
443 return AscendStreamMng::GetInstance().GetStream(stream_id);
444 }
445
SetCurrentStreamId(size_t stream_id)446 void GeDeviceResManager::SetCurrentStreamId(size_t stream_id) {
447 if (!BindDeviceToCurrentThread(false)) {
448 MS_LOG(ERROR) << "Bind context to current thread failed";
449 return;
450 }
451 AscendStreamMng::GetInstance().set_current_stream(stream_id);
452 }
453
GetCurrentStreamId() const454 size_t GeDeviceResManager::GetCurrentStreamId() const {
455 if (!BindDeviceToCurrentThread(false)) {
456 MS_LOG(ERROR) << "Bind context to current thread failed";
457 return SIZE_MAX;
458 }
459 return AscendStreamMng::GetInstance().current_stream();
460 }
461
QueryStream(size_t stream_id) const462 bool GeDeviceResManager::QueryStream(size_t stream_id) const {
463 if (!BindDeviceToCurrentThread(false)) {
464 MS_LOG(ERROR) << "Bind context to current thread failed";
465 return false;
466 }
467 return AscendStreamMng::GetInstance().QueryStream(stream_id);
468 }
469
SyncStream(size_t stream_id) const470 bool GeDeviceResManager::SyncStream(size_t stream_id) const {
471 if (!BindDeviceToCurrentThread(false)) {
472 MS_LOG(ERROR) << "Bind context to current thread failed";
473 return false;
474 }
475 return AscendStreamMng::GetInstance().SyncStream(stream_id);
476 }
477
SyncAllStreams() const478 bool GeDeviceResManager::SyncAllStreams() const {
479 if (runtime_instance_ == nullptr) {
480 return true;
481 }
482 runtime_instance_->SetContext();
483 return AscendStreamMng::GetInstance().SyncAllStreams();
484 }
485
SyncNotDefaultStreams() const486 bool GeDeviceResManager::SyncNotDefaultStreams() const {
487 if (!BindDeviceToCurrentThread(false)) {
488 MS_LOG(ERROR) << "Bind context to current thread failed";
489 return false;
490 }
491 return AscendStreamMng::GetInstance().SyncNotDefaultStreams();
492 }
493
DefaultStream() const494 size_t GeDeviceResManager::DefaultStream() const {
495 if (!BindDeviceToCurrentThread(false)) {
496 MS_LOG(ERROR) << "Bind context to current thread failed";
497 return SIZE_MAX;
498 }
499 return AscendStreamMng::GetInstance().default_stream_id();
500 }
501
// ACL_EVENT_TIME_LINE: indicates that the number of created events is not limited, and the created events can be used
// to compute the elapsed time between events, which may cost some performance.
// ACL_EVENT_SYNC: indicates that the number of created events is limited, and the created events can be used for
// synchronization between multiple streams.
// ACL_EVENT_CAPTURE_STREAM_PROGRESS: indicates that the number of created events is not limited and that they offer
// high performance, but the created events cannot be used for timing or synchronization.
CreateRuntimeEvent(bool enable_blocking,bool enable_record_wait)508 DeviceEventPtr GeDeviceResManager::CreateRuntimeEvent(bool enable_blocking, bool enable_record_wait) {
509 if (!enable_blocking && !enable_record_wait) {
510 MS_LOG(INTERNAL_EXCEPTION) << "Bad parameters, enable_blocking is false and enable_record_wait is false.";
511 }
512
513 uint32_t flag = 0;
514 if (enable_blocking) {
515 flag |= ACL_EVENT_SYNC;
516 }
517 if (enable_record_wait) {
518 flag |= ACL_EVENT_CAPTURE_STREAM_PROGRESS;
519 }
520 return std::make_shared<AscendEvent>(flag);
521 }
522
CreateEventWithFlag(bool enable_timing,bool blocking)523 DeviceEventPtr GeDeviceResManager::CreateEventWithFlag(bool enable_timing, bool blocking) {
524 auto flag = enable_timing ? ACL_EVENT_TIME_LINE : ACL_EVENT_DEFAULT;
525 auto event = std::make_shared<AscendEvent>(flag);
526 MS_EXCEPTION_IF_NULL(event);
527 std::lock_guard<std::mutex> lock(device_events_mutex_);
528 device_events_.push_back(event);
529 return event;
530 }
531
MoveTo(const tensor::TensorPtr & src_tensor,const tensor::TensorPtr & dst_tensor,const std::string & to,bool blocking,bool * return_self)532 void GeDeviceResManager::MoveTo(const tensor::TensorPtr &src_tensor, const tensor::TensorPtr &dst_tensor,
533 const std::string &to, bool blocking, bool *return_self) {
534 device::MoveTo(src_tensor, dst_tensor, to, blocking, return_self);
535 }
536 } // namespace ascend
537 } // namespace device
538 } // namespace mindspore
539