• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/ascend/hal/hardware/ge_device_res_manager.h"
18 #ifndef _WIN32
19 #include <dlfcn.h>
20 #include <libgen.h>
21 #endif
22 #include "plugin/device/ascend/hal/hardware/ge_utils.h"
23 #include <utility>
24 #include "plugin/device/cpu/hal/device/cpu_memory_manager.h"
25 #include "plugin/device/ascend/hal/device/ascend_memory_manager.h"
26 #include "plugin/device/ascend/hal/device/ascend_device_address.h"
27 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
28 #include "plugin/device/ascend/hal/device/ascend_device_synchronizer.h"
29 #include "plugin/device/ascend/hal/device/ascend_event.h"
30 #include "plugin/device/ascend/hal/device/ascend_pin_mem_pool.h"
31 #include "plugin/device/cpu/hal/device/cpu_device_synchronizer.h"
32 #include "include/transform/graph_ir/utils.h"
33 #include "graph/types.h"
34 #include "transform/symbol/acl_rt_symbol.h"
35 #include "transform/symbol/symbol_utils.h"
36 #include "transform/acl_ir/op_api_util.h"
37 #include "include/backend/mem_reuse/mem_tracker.h"
38 #include "graph/def_types.h"
39 #include "runtime/device/move_to.h"
40 
41 namespace mindspore {
42 namespace device {
43 namespace ascend {
GetCurrentDir()44 std::string GetCurrentDir() {
45 #ifndef _WIN32
46   Dl_info dl_info;
47   if (dladdr(reinterpret_cast<void *>(GetCurrentDir), &dl_info) == 0) {
48     MS_LOG(WARNING) << "Get dladdr error";
49     return "";
50   }
51   std::string cur_so_path = dl_info.dli_fname;
52   return dirname(cur_so_path.data());
53 #else
54   return "";
55 #endif
56 }
57 
Malloc(size_t size)58 ::ge::MemBlock *GeAllocator::Malloc(size_t size) {
59   auto addr = res_manager_->AllocateMemory(size);
60   MS_LOG(DEBUG) << "GE Allocator malloc addr: " << addr << " size: " << size;
61   auto mem_block = new ::ge::MemBlock(*this, addr, size);
62   return mem_block;
63 }
64 
Free(::ge::MemBlock * block)65 void GeAllocator::Free(::ge::MemBlock *block) {
66   res_manager_->FreeMemory(block->GetAddr());
67   MS_LOG(DEBUG) << "GE Allocator free addr: " << block->GetAddr();
68   delete block;
69 }
70 
Initialize()71 void GeDeviceResManager::Initialize() {
72   auto ms_context = MsContext::GetInstance();
73   MS_EXCEPTION_IF_NULL(ms_context);
74   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
75   runtime_instance_ = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
76   MS_EXCEPTION_IF_NULL(runtime_instance_);
77   if (!runtime_instance_->Init()) {
78     MS_LOG(EXCEPTION) << "Kernel runtime init error.";
79   }
80   mem_manager_ = runtime_instance_->GetMemoryManager();
81   MS_EXCEPTION_IF_NULL(mem_manager_);
82   if (ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD)) {
83     swap_manager_ = std::make_shared<SwapManager>(kDefaultStreamIndex, &AscendMemoryPool::GetInstance(),
84                                                   &AscendPinMemPool::GetInstance());
85   }
86 }
87 
SetCPUMemManager()88 void GeDeviceResManager::SetCPUMemManager() {
89   if (is_use_cpu_memory_) {
90     return;
91   }
92   if (mem_manager_ != nullptr) {
93     mem_manager_->Finalize();
94     mem_manager_ = nullptr;
95   }
96   runtime_instance_ = nullptr;
97   mem_manager_ = std::make_shared<cpu::CPUMemoryManager>();
98   MS_EXCEPTION_IF_NULL(mem_manager_);
99   is_use_cpu_memory_ = true;
100 }
101 
Destroy()102 void GeDeviceResManager::Destroy() {
103   (void)DestroyAllEvents();
104   // Release memory.
105   if (mem_manager_ != nullptr) {
106     mem_manager_->Finalize();
107     mem_manager_ = nullptr;
108   }
109 }
110 
AllocateMemory(DeviceAddress * const & address,uint32_t stream_id) const111 bool GeDeviceResManager::AllocateMemory(DeviceAddress *const &address, uint32_t stream_id) const {
112   MS_EXCEPTION_IF_NULL(address);
113   MS_EXCEPTION_IF_NULL(mem_manager_);
114   auto device_name_in_address = GetDeviceNameByType(static_cast<const DeviceType>(address->GetDeviceType()));
115   if (IsEnableRefMode() && device_name_in_address != device_context_->device_context_key().device_name_) {
116     MS_LOG(EXCEPTION) << "The device address type is wrong: type name in address:" << device_name_in_address
117                       << ", type name in context:" << device_context_->device_context_key().device_name_;
118   }
119 
120   if (address->GetPtr() != nullptr) {
121     MS_LOG(ERROR) << "Memory leak detected!";
122     return false;
123   }
124 
125   if (runtime_instance_ != nullptr) {
126     runtime_instance_->SetContext();
127   }
128   void *device_ptr = nullptr;
129 
130   if (stream_id == UINT32_MAX) {
131     stream_id = address->stream_id();
132   }
133 
134   if (swap_manager_ != nullptr) {
135     device_ptr = swap_manager_->AllocDeviceMemory(address->GetSize(), stream_id);
136   } else {
137     device_ptr = mem_manager_->MallocMemFromMemPool(address->GetSize(), address->from_persistent_mem(),
138                                                     address->need_recycle(), stream_id);
139   }
140 
141   if (!device_ptr) {
142     return false;
143   }
144 
145   address->set_ptr(device_ptr);
146   address->set_from_mem_pool(true);
147   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(BindDevicePtr, address, device_ptr);
148   return true;
149 }
150 
AllocateMemory(size_t size,uint32_t stream_id) const151 void *GeDeviceResManager::AllocateMemory(size_t size, uint32_t stream_id) const {
152   MS_EXCEPTION_IF_NULL(runtime_instance_);
153   runtime_instance_->SetContext();
154   MS_EXCEPTION_IF_NULL(mem_manager_);
155   if (swap_manager_ != nullptr) {
156     return swap_manager_->AllocDeviceMemory(size, stream_id);
157   }
158   return mem_manager_->MallocMemFromMemPool(size, false, false, stream_id);
159 }
160 
GetMaxUsedMemorySize() const161 size_t GeDeviceResManager::GetMaxUsedMemorySize() const {
162   MS_EXCEPTION_IF_NULL(mem_manager_);
163   return mem_manager_->GetMaxUsedMemorySize();
164 }
165 
FreeMemory(void * ptr) const166 void GeDeviceResManager::FreeMemory(void *ptr) const {
167   MS_EXCEPTION_IF_NULL(ptr);
168   MS_EXCEPTION_IF_NULL(mem_manager_);
169   mem_manager_->FreeMemFromMemPool(ptr);
170 }
171 
FreePartMemorys(const std::vector<void * > & free_addrs,const std::vector<void * > & keep_addrs,const std::vector<size_t> & keep_addr_sizes) const172 void GeDeviceResManager::FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs,
173                                          const std::vector<size_t> &keep_addr_sizes) const {
174   AscendMemoryPool::GetInstance().FreePartTensorMems(free_addrs, keep_addrs, keep_addr_sizes);
175 }
176 
DefragMemory()177 void GeDeviceResManager::DefragMemory() { AscendMemoryPool::GetInstance().DefragMemory(); }
178 
179 // Relevant function to manage memory statistics
GetTotalMemStatistics() const180 size_t GeDeviceResManager::GetTotalMemStatistics() const {
181   MS_EXCEPTION_IF_NULL(mem_manager_);
182   return mem_manager_->GetTotalMemStatistics();
183 }
184 
GetTotalUsedMemStatistics() const185 size_t GeDeviceResManager::GetTotalUsedMemStatistics() const {
186   MS_EXCEPTION_IF_NULL(mem_manager_);
187   return mem_manager_->GetTotalUsedMemStatistics();
188 }
189 
GetTotalIdleMemStatistics() const190 size_t GeDeviceResManager::GetTotalIdleMemStatistics() const {
191   MS_EXCEPTION_IF_NULL(mem_manager_);
192   return mem_manager_->GetTotalIdleMemStatistics();
193 }
194 
GetTotalEagerFreeMemStatistics() const195 size_t GeDeviceResManager::GetTotalEagerFreeMemStatistics() const {
196   MS_EXCEPTION_IF_NULL(mem_manager_);
197   return mem_manager_->GetTotalEagerFreeMemStatistics();
198 }
199 
GetUsedMemPeakStatistics() const200 size_t GeDeviceResManager::GetUsedMemPeakStatistics() const {
201   MS_EXCEPTION_IF_NULL(mem_manager_);
202   return mem_manager_->GetUsedMemPeakStatistics();
203 }
204 
GetReservedMemPeakStatistics() const205 size_t GeDeviceResManager::GetReservedMemPeakStatistics() const {
206   MS_EXCEPTION_IF_NULL(mem_manager_);
207   return mem_manager_->GetReservedMemPeakStatistics();
208 }
209 
GetBlockCountsStatistics() const210 std::unordered_map<std::string, std::size_t> GeDeviceResManager::GetBlockCountsStatistics() const {
211   MS_EXCEPTION_IF_NULL(mem_manager_);
212   return mem_manager_->GetBlockCountsStatistics();
213 }
214 
GetBlockUnitSizeStatistics() const215 std::unordered_map<std::string, std::size_t> GeDeviceResManager::GetBlockUnitSizeStatistics() const {
216   MS_EXCEPTION_IF_NULL(mem_manager_);
217   return mem_manager_->GetBlockUnitSizeStatistics();
218 }
219 
220 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
GetCommonMemBlocksInfoStatistics() const221 GeDeviceResManager::GetCommonMemBlocksInfoStatistics() const {
222   MS_EXCEPTION_IF_NULL(mem_manager_);
223   return mem_manager_->GetCommonMemBlocksInfoStatistics();
224 }
225 
226 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
GetPersistentMemBlocksInfoStatistics() const227 GeDeviceResManager::GetPersistentMemBlocksInfoStatistics() const {
228   MS_EXCEPTION_IF_NULL(mem_manager_);
229   return mem_manager_->GetPersistentMemBlocksInfoStatistics();
230 }
231 
ResetMaxMemoryReserved() const232 void GeDeviceResManager::ResetMaxMemoryReserved() const {
233   MS_EXCEPTION_IF_NULL(mem_manager_);
234   mem_manager_->ResetMaxMemoryReserved();
235 }
236 
ResetMaxMemoryAllocated() const237 void GeDeviceResManager::ResetMaxMemoryAllocated() const {
238   MS_EXCEPTION_IF_NULL(mem_manager_);
239   mem_manager_->ResetMaxMemoryAllocated();
240 }
241 
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)242 void GeDeviceResManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
243   (void)mem_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream);
244 }
245 
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)246 void GeDeviceResManager::SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
247   (void)mem_manager_->SwapOut(device_ptr, host_ptr, mem_size, stream);
248 }
249 
AllocateContinuousMemory(const std::vector<size_t> & size_list,uint32_t stream_id) const250 std::vector<void *> GeDeviceResManager::AllocateContinuousMemory(const std::vector<size_t> &size_list,
251                                                                  uint32_t stream_id) const {
252   MS_EXCEPTION_IF_NULL(runtime_instance_);
253   runtime_instance_->SetContext();
254   MS_EXCEPTION_IF_NULL(mem_manager_);
255   std::vector<size_t> aligned_size_list;
256   for (auto size : size_list) {
257     auto align_size = device::MemoryManager::GetCommonAlignSize(size);
258     aligned_size_list.emplace_back(align_size);
259   }
260   if (swap_manager_ != nullptr) {
261     return swap_manager_->AllocDeviceContinuousMem(aligned_size_list, stream_id);
262   }
263   return mem_manager_->MallocContinuousMemFromMemPool(aligned_size_list, stream_id);
264 }
265 
CreateDeviceAddress(const KernelTensorPtr & kernel_tensor) const266 DeviceAddressPtr GeDeviceResManager::CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const {
267   MS_EXCEPTION_IF_NULL(kernel_tensor);
268   if (!is_use_cpu_memory_) {
269     if (kernel_tensor->device_name().empty()) {
270       kernel_tensor->set_device_name(device_context_->device_context_key().device_name_);
271       kernel_tensor->set_device_id(device_context_->device_context_key().device_id_);
272     }
273     auto device_address = std::make_shared<AscendDeviceAddress>(kernel_tensor);
274     device_address->set_device_synchronizer(std::make_shared<AscendDeviceSynchronizer>());
275     return device_address;
276   } else {
277     if (kernel_tensor->device_name().empty()) {
278       kernel_tensor->set_device_name(kCPUDevice);
279       kernel_tensor->set_device_id(0);
280     }
281     auto device_address = std::make_shared<cpu::CPUDeviceAddress>(kernel_tensor);
282     device_address->set_device_synchronizer(std::make_shared<cpu::CPUDeviceSynchronizer>());
283     return device_address;
284   }
285 }
286 
CreateDeviceAddress(void * ptr,size_t size,const ShapeVector & shape_vector,const Format & format,TypeId type_id,const std::string & device_name,uint32_t device_id,uint32_t stream_id) const287 DeviceAddressPtr GeDeviceResManager::CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector,
288                                                          const Format &format, TypeId type_id,
289                                                          const std::string &device_name, uint32_t device_id,
290                                                          uint32_t stream_id) const {
291   if (!is_use_cpu_memory_) {
292     return std::make_shared<AscendDeviceAddress>(ptr, size, shape_vector, format, type_id, device_name, device_id,
293                                                  stream_id);
294   } else {
295     return std::make_shared<cpu::CPUDeviceAddress>(ptr, size, shape_vector, format, type_id, kCPUDevice, device_id,
296                                                    stream_id);
297   }
298 }
299 
GeSetContextOptions(const std::shared_ptr<MsContext> & ms_context_ptr,transform::SessionOptions * options)300 void GeDeviceResManager::GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr,
301                                              transform::SessionOptions *options) {
302   MS_EXCEPTION_IF_NULL(options);
303   if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
304     (*options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
305   }
306 
307   if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
308     (*options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
309   }
310 
311   auto atomic_clean_policy = ms_context_ptr->get_param<std::string>(MS_CTX_ATOMIC_CLEAN_POLICY);
312   if (atomic_clean_policy.empty()) {
313     atomic_clean_policy = "1";
314   }
315   (*options)["ge.exec.atomicCleanPolicy"] = atomic_clean_policy;
316   MS_LOG(INFO) << "Set GE atomic clean policy to " << atomic_clean_policy << ".";
317   (*options)["ge.graphRunMode"] = "1";
318 }
319 
CreateSessionAndGraphRunner()320 void GeDeviceResManager::CreateSessionAndGraphRunner() {
321   std::shared_ptr<::ge::Session> sess = transform::GetGeSession();
322   auto ms_context = MsContext::GetInstance();
323   MS_EXCEPTION_IF_NULL(ms_context);
324   if (sess == nullptr) {
325     transform::SessionOptions options;
326     options["ge.enablePrintOpPass"] = "0";
327     GeSetContextOptions(ms_context, &options);
328     options["ge.constLifecycle"] = "graph";
329 
330     options["ge.exec.formatMode"] = "0";
331     auto format_mode = common::GetEnv("MS_FORMAT_MODE");
332     if (format_mode == "1" || (format_mode.empty() && ms_context->ascend_soc_version() != "ascend910")) {
333       MS_LOG(INFO) << "Set GE option ge.exec.formatMode to 1.";
334       options["ge.exec.formatMode"] = "1";
335     }
336 
337     SetPassthroughGeOptions(false, &options);
338 
339     sess = transform::NewSession(options);
340     transform::SetGeSession(sess);
341   }
342 
343   transform::GraphRunnerOptions options;
344   options.sess_ptr = sess;
345   auto graph_runner = transform::NewGraphRunner(options);
346   transform::SetGraphRunner(graph_runner);
347 }
348 
LoadCollectiveCommLib()349 bool GeDeviceResManager::LoadCollectiveCommLib() {
350   // If this is simulation, load dummy collective communication library.
351   if (!common::GetEnv(kSimulationLevel).empty()) {
352     collective_comm_lib_ = &DummyAscendCollectiveCommLib::GetInstance();
353     return true;
354   }
355   // Ascend backend supports HCCL and LCCL collective communication libraries.
356   if (!common::GetEnv("MS_ENABLE_LCCL").empty()) {
357     std::string lowlatency_comm_lib_name = GetCurrentDir() + "/ascend/liblowlatency_collective.so";
358     auto loader = std::make_shared<CollectiveCommLibLoader>(lowlatency_comm_lib_name);
359     MS_EXCEPTION_IF_NULL(loader);
360     if (!loader->Initialize()) {
361       MS_LOG(EXCEPTION) << "Loading LCCL collective library failed.";
362       return false;
363     }
364     void *collective_comm_lib_handle = loader->collective_comm_lib_ptr();
365     MS_EXCEPTION_IF_NULL(collective_comm_lib_handle);
366 
367     auto instance_func = DlsymFuncObj(communication_lib_instance, collective_comm_lib_handle);
368     collective_comm_lib_ = instance_func();
369     MS_EXCEPTION_IF_NULL(collective_comm_lib_);
370     MS_LOG(WARNING) << "Loading LCCL because env MS_ENABLE_LCCL is set to 1. Pay attention that LCCL only supports "
371                        "single-node-multi-card mode in KernelByKernel for now.";
372   } else {
373     collective_comm_lib_ = &AscendCollectiveCommLib::GetInstance();
374   }
375   return true;
376 }
377 
BindDeviceToCurrentThread(bool force_bind) const378 bool GeDeviceResManager::BindDeviceToCurrentThread(bool force_bind) const {
379   static thread_local std::once_flag is_set;
380   std::call_once(is_set, []() {
381     auto ms_context = MsContext::GetInstance();
382     MS_EXCEPTION_IF_NULL(ms_context);
383     auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
384     auto ret = CALL_ASCEND_API(aclrtSetDevice, static_cast<int32_t>(device_id));
385     if (ret != ACL_ERROR_NONE) {
386       MS_LOG(EXCEPTION) << "Device " << device_id << " call aclrtSetDevice failed, ret:" << static_cast<int>(ret);
387     }
388     transform::AclUtil::SetDeterministic();
389   });
390 
391   if (runtime_instance_ != nullptr) {
392     if (force_bind) {
393       runtime_instance_->SetContextForce();
394     } else {
395       runtime_instance_->SetContext();
396     }
397   }
398   return true;
399 }
400 
ResetStreamAndCtx()401 void GeDeviceResManager::ResetStreamAndCtx() {
402   if (runtime_instance_ != nullptr) {
403     runtime_instance_->ResetStreamAndCtx();
404   }
405 }
406 
CreateStream(size_t * stream_id) const407 bool GeDeviceResManager::CreateStream(size_t *stream_id) const {
408   if (!BindDeviceToCurrentThread(false)) {
409     MS_LOG(ERROR) << "Bind context to current thread failed";
410     return false;
411   }
412   AscendStreamMng::GetInstance().CreateStream(stream_id);
413   return true;
414 }
415 
CreateStreamWithPriority(size_t * stream_id,int32_t priority) const416 bool GeDeviceResManager::CreateStreamWithPriority(size_t *stream_id, int32_t priority) const {
417   if (!BindDeviceToCurrentThread(false)) {
418     MS_LOG(ERROR) << "Bind context to current thread failed";
419     return false;
420   }
421   AscendStreamMng::GetInstance().CreateStreamWithFlags(stream_id, ACL_STREAM_FAST_LAUNCH | ACL_STREAM_FAST_SYNC,
422                                                        IntToUint(priority));
423   return true;
424 }
425 
QueryStreamSize() const426 size_t GeDeviceResManager::QueryStreamSize() const { return AscendStreamMng::GetInstance().QueryStreamSize(); }
427 
GetStreamIds() const428 std::vector<uint32_t> GeDeviceResManager::GetStreamIds() const { return AscendStreamMng::GetInstance().GetStreamIds(); }
429 
single_op_multi_stream_enable() const430 bool GeDeviceResManager::single_op_multi_stream_enable() const {
431   return AscendStreamMng::GetInstance().single_op_multi_stream_enable();
432 }
433 
set_single_op_multi_stream_enable(bool single_op_multi_stream_enable)434 void GeDeviceResManager::set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) {
435   return AscendStreamMng::GetInstance().set_single_op_multi_stream_enable(single_op_multi_stream_enable);
436 }
437 
GetStream(size_t stream_id) const438 void *GeDeviceResManager::GetStream(size_t stream_id) const {
439   if (!BindDeviceToCurrentThread(false)) {
440     MS_LOG(ERROR) << "Bind context to current thread failed";
441     return nullptr;
442   }
443   return AscendStreamMng::GetInstance().GetStream(stream_id);
444 }
445 
SetCurrentStreamId(size_t stream_id)446 void GeDeviceResManager::SetCurrentStreamId(size_t stream_id) {
447   if (!BindDeviceToCurrentThread(false)) {
448     MS_LOG(ERROR) << "Bind context to current thread failed";
449     return;
450   }
451   AscendStreamMng::GetInstance().set_current_stream(stream_id);
452 }
453 
GetCurrentStreamId() const454 size_t GeDeviceResManager::GetCurrentStreamId() const {
455   if (!BindDeviceToCurrentThread(false)) {
456     MS_LOG(ERROR) << "Bind context to current thread failed";
457     return SIZE_MAX;
458   }
459   return AscendStreamMng::GetInstance().current_stream();
460 }
461 
QueryStream(size_t stream_id) const462 bool GeDeviceResManager::QueryStream(size_t stream_id) const {
463   if (!BindDeviceToCurrentThread(false)) {
464     MS_LOG(ERROR) << "Bind context to current thread failed";
465     return false;
466   }
467   return AscendStreamMng::GetInstance().QueryStream(stream_id);
468 }
469 
SyncStream(size_t stream_id) const470 bool GeDeviceResManager::SyncStream(size_t stream_id) const {
471   if (!BindDeviceToCurrentThread(false)) {
472     MS_LOG(ERROR) << "Bind context to current thread failed";
473     return false;
474   }
475   return AscendStreamMng::GetInstance().SyncStream(stream_id);
476 }
477 
SyncAllStreams() const478 bool GeDeviceResManager::SyncAllStreams() const {
479   if (runtime_instance_ == nullptr) {
480     return true;
481   }
482   runtime_instance_->SetContext();
483   return AscendStreamMng::GetInstance().SyncAllStreams();
484 }
485 
SyncNotDefaultStreams() const486 bool GeDeviceResManager::SyncNotDefaultStreams() const {
487   if (!BindDeviceToCurrentThread(false)) {
488     MS_LOG(ERROR) << "Bind context to current thread failed";
489     return false;
490   }
491   return AscendStreamMng::GetInstance().SyncNotDefaultStreams();
492 }
493 
DefaultStream() const494 size_t GeDeviceResManager::DefaultStream() const {
495   if (!BindDeviceToCurrentThread(false)) {
496     MS_LOG(ERROR) << "Bind context to current thread failed";
497     return SIZE_MAX;
498   }
499   return AscendStreamMng::GetInstance().default_stream_id();
500 }
501 
502 // ACL_EVENT_TIME_LINE: indicates that the number of created events is not limited, and the created events can be used
503 //  to compute the elapsed time between events, which may cause lost some performance.
504 // ACL_EVENT_SYNC: indicates that the number of created events is limited, and the created events can be used for
505 //  synchronization between multiple streams.
506 // ACL_EVENT_CAPTURE_STREAM_PROGRESS: indicates that the number of created events is not limited and high performance,
507 //  and the created events can not be used for timing and synchronization.
CreateRuntimeEvent(bool enable_blocking,bool enable_record_wait)508 DeviceEventPtr GeDeviceResManager::CreateRuntimeEvent(bool enable_blocking, bool enable_record_wait) {
509   if (!enable_blocking && !enable_record_wait) {
510     MS_LOG(INTERNAL_EXCEPTION) << "Bad parameters, enable_blocking is false and enable_record_wait is false.";
511   }
512 
513   uint32_t flag = 0;
514   if (enable_blocking) {
515     flag |= ACL_EVENT_SYNC;
516   }
517   if (enable_record_wait) {
518     flag |= ACL_EVENT_CAPTURE_STREAM_PROGRESS;
519   }
520   return std::make_shared<AscendEvent>(flag);
521 }
522 
CreateEventWithFlag(bool enable_timing,bool blocking)523 DeviceEventPtr GeDeviceResManager::CreateEventWithFlag(bool enable_timing, bool blocking) {
524   auto flag = enable_timing ? ACL_EVENT_TIME_LINE : ACL_EVENT_DEFAULT;
525   auto event = std::make_shared<AscendEvent>(flag);
526   MS_EXCEPTION_IF_NULL(event);
527   std::lock_guard<std::mutex> lock(device_events_mutex_);
528   device_events_.push_back(event);
529   return event;
530 }
531 
MoveTo(const tensor::TensorPtr & src_tensor,const tensor::TensorPtr & dst_tensor,const std::string & to,bool blocking,bool * return_self)532 void GeDeviceResManager::MoveTo(const tensor::TensorPtr &src_tensor, const tensor::TensorPtr &dst_tensor,
533                                 const std::string &to, bool blocking, bool *return_self) {
534   device::MoveTo(src_tensor, dst_tensor, to, blocking, return_self);
535 }
536 }  // namespace ascend
537 }  // namespace device
538 }  // namespace mindspore
539