• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "transform/acl_ir/acl_allocator.h"
17 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
18 #include "transform/symbol/acl_rt_allocator_symbol.h"
19 #include "transform/symbol/symbol_utils.h"
20 #include "include/backend/mem_reuse/mem_tracker.h"
21 
22 namespace mindspore {
23 namespace transform {
AllocFunc(void * obj,size_t size)24 void *AclAllocator::AllocFunc(void *obj, size_t size) {
25   MS_EXCEPTION_IF_NULL(obj);
26   auto allocator = static_cast<AclAllocator *>(obj);
27   MS_EXCEPTION_IF_NULL(allocator);
28   auto stream_ptr = allocator->stream();
29   auto stream_id = device::ascend::AscendStreamMng::GetInstance().GetStreamId(stream_ptr);
30   MS_EXCEPTION_IF_NULL(allocator->device_context_->device_res_manager_);
31   auto block = allocator->device_context_->device_res_manager_->AllocateMemory(size, stream_id);
32   if (block == nullptr) {
33     MS_LOG(EXCEPTION) << "Malloc Mem From Mem Pool failed, size:" << size;
34   }
35   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "AclWorkspace", "AclWorkspace", "");
36   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddCompileTimeMemInfo, "AclWorkspace", size, block,
37                                                  device::tracker::MemType::kWorkSpace);
38   return block;
39 }
40 
AllocAdviseFunc(void * obj,size_t size,void * addr)41 void *AclAllocator::AllocAdviseFunc(void *obj, size_t size, void *addr) {
42   MS_EXCEPTION_IF_NULL(obj);
43   MS_EXCEPTION_IF_NULL(addr);
44   addr = AclAllocator::AllocFunc(obj, size);
45   return addr;
46 }
47 
FreeFunc(void * obj,void * block)48 void AclAllocator::FreeFunc(void *obj, void *block) {
49   MS_EXCEPTION_IF_NULL(obj);
50   auto allocator = static_cast<AclAllocator *>(obj);
51   MS_EXCEPTION_IF_NULL(allocator);
52   MS_EXCEPTION_IF_NULL(allocator->device_context_->device_res_manager_);
53   allocator->device_context_->device_res_manager_->FreeMemory(block);
54 }
55 
GetAddrFromBlock(void * block)56 void *AclAllocator::GetAddrFromBlock(void *block) {
57   MS_EXCEPTION_IF_NULL(block);
58   return block;
59 }
60 
NewAclAllocator(void * stream)61 AclAllocatorPtr AclAllocatorRegister::NewAclAllocator(void *stream) {
62   auto allocator_obj = std::make_shared<AclAllocator>(stream);
63   MS_EXCEPTION_IF_NULL(allocator_obj);
64 
65   auto allocator_desc = CALL_ASCEND_API(aclrtAllocatorCreateDesc);
66   MS_EXCEPTION_IF_NULL(allocator_desc);
67   allocator_obj->set_allocator_desc(allocator_desc);
68   (void)CALL_ASCEND_API(aclrtAllocatorSetObjToDesc, allocator_desc, allocator_obj.get());
69   (void)CALL_ASCEND_API(aclrtAllocatorSetAllocFuncToDesc, allocator_desc, AclAllocator::AllocFunc);
70   (void)CALL_ASCEND_API(aclrtAllocatorSetFreeFuncToDesc, allocator_desc, AclAllocator::FreeFunc);
71   (void)CALL_ASCEND_API(aclrtAllocatorSetAllocAdviseFuncToDesc, allocator_desc, AclAllocator::AllocAdviseFunc);
72   (void)CALL_ASCEND_API(aclrtAllocatorSetGetAddrFromBlockFuncToDesc, allocator_desc, AclAllocator::GetAddrFromBlock);
73   return allocator_obj;
74 }
75 
FreeAclAllocatorRes(const AclAllocatorPtr & allocator_obj)76 void AclAllocatorRegister::FreeAclAllocatorRes(const AclAllocatorPtr &allocator_obj) {
77   (void)CALL_ASCEND_API(aclrtAllocatorDestroyDesc, allocator_obj->allocator_desc());
78   (void)CALL_ASCEND_API(aclrtAllocatorUnregister, allocator_obj->stream());
79 }
80 
~AclAllocatorRegister()81 AclAllocatorRegister::~AclAllocatorRegister() {
82   for (const auto &allocator_iter : allocator_map_) {
83     FreeAclAllocatorRes(allocator_iter.second);
84   }
85 }
86 
Instance()87 AclAllocatorRegister &AclAllocatorRegister::Instance() {
88   static AclAllocatorRegister instance;
89   return instance;
90 }
91 
RegisterAllocator(void * stream)92 void AclAllocatorRegister::RegisterAllocator(void *stream) {
93   if (allocator_map_.find(stream) == allocator_map_.end()) {
94     const auto &allocator_obj = NewAclAllocator(stream);
95     (void)CALL_ASCEND_API(aclrtAllocatorRegister, stream, allocator_obj->allocator_desc());
96     allocator_map_[stream] = allocator_obj;
97   }
98 }
99 }  // namespace transform
100 }  // namespace mindspore
101