1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "transform/acl_ir/acl_allocator.h"
17 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
18 #include "transform/symbol/acl_rt_allocator_symbol.h"
19 #include "transform/symbol/symbol_utils.h"
20 #include "include/backend/mem_reuse/mem_tracker.h"
21
22 namespace mindspore {
23 namespace transform {
AllocFunc(void * obj,size_t size)24 void *AclAllocator::AllocFunc(void *obj, size_t size) {
25 MS_EXCEPTION_IF_NULL(obj);
26 auto allocator = static_cast<AclAllocator *>(obj);
27 MS_EXCEPTION_IF_NULL(allocator);
28 auto stream_ptr = allocator->stream();
29 auto stream_id = device::ascend::AscendStreamMng::GetInstance().GetStreamId(stream_ptr);
30 MS_EXCEPTION_IF_NULL(allocator->device_context_->device_res_manager_);
31 auto block = allocator->device_context_->device_res_manager_->AllocateMemory(size, stream_id);
32 if (block == nullptr) {
33 MS_LOG(EXCEPTION) << "Malloc Mem From Mem Pool failed, size:" << size;
34 }
35 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "AclWorkspace", "AclWorkspace", "");
36 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddCompileTimeMemInfo, "AclWorkspace", size, block,
37 device::tracker::MemType::kWorkSpace);
38 return block;
39 }
40
AllocAdviseFunc(void * obj,size_t size,void * addr)41 void *AclAllocator::AllocAdviseFunc(void *obj, size_t size, void *addr) {
42 MS_EXCEPTION_IF_NULL(obj);
43 MS_EXCEPTION_IF_NULL(addr);
44 addr = AclAllocator::AllocFunc(obj, size);
45 return addr;
46 }
47
FreeFunc(void * obj,void * block)48 void AclAllocator::FreeFunc(void *obj, void *block) {
49 MS_EXCEPTION_IF_NULL(obj);
50 auto allocator = static_cast<AclAllocator *>(obj);
51 MS_EXCEPTION_IF_NULL(allocator);
52 MS_EXCEPTION_IF_NULL(allocator->device_context_->device_res_manager_);
53 allocator->device_context_->device_res_manager_->FreeMemory(block);
54 }
55
GetAddrFromBlock(void * block)56 void *AclAllocator::GetAddrFromBlock(void *block) {
57 MS_EXCEPTION_IF_NULL(block);
58 return block;
59 }
60
NewAclAllocator(void * stream)61 AclAllocatorPtr AclAllocatorRegister::NewAclAllocator(void *stream) {
62 auto allocator_obj = std::make_shared<AclAllocator>(stream);
63 MS_EXCEPTION_IF_NULL(allocator_obj);
64
65 auto allocator_desc = CALL_ASCEND_API(aclrtAllocatorCreateDesc);
66 MS_EXCEPTION_IF_NULL(allocator_desc);
67 allocator_obj->set_allocator_desc(allocator_desc);
68 (void)CALL_ASCEND_API(aclrtAllocatorSetObjToDesc, allocator_desc, allocator_obj.get());
69 (void)CALL_ASCEND_API(aclrtAllocatorSetAllocFuncToDesc, allocator_desc, AclAllocator::AllocFunc);
70 (void)CALL_ASCEND_API(aclrtAllocatorSetFreeFuncToDesc, allocator_desc, AclAllocator::FreeFunc);
71 (void)CALL_ASCEND_API(aclrtAllocatorSetAllocAdviseFuncToDesc, allocator_desc, AclAllocator::AllocAdviseFunc);
72 (void)CALL_ASCEND_API(aclrtAllocatorSetGetAddrFromBlockFuncToDesc, allocator_desc, AclAllocator::GetAddrFromBlock);
73 return allocator_obj;
74 }
75
FreeAclAllocatorRes(const AclAllocatorPtr & allocator_obj)76 void AclAllocatorRegister::FreeAclAllocatorRes(const AclAllocatorPtr &allocator_obj) {
77 (void)CALL_ASCEND_API(aclrtAllocatorDestroyDesc, allocator_obj->allocator_desc());
78 (void)CALL_ASCEND_API(aclrtAllocatorUnregister, allocator_obj->stream());
79 }
80
~AclAllocatorRegister()81 AclAllocatorRegister::~AclAllocatorRegister() {
82 for (const auto &allocator_iter : allocator_map_) {
83 FreeAclAllocatorRes(allocator_iter.second);
84 }
85 }
86
Instance()87 AclAllocatorRegister &AclAllocatorRegister::Instance() {
88 static AclAllocatorRegister instance;
89 return instance;
90 }
91
RegisterAllocator(void * stream)92 void AclAllocatorRegister::RegisterAllocator(void *stream) {
93 if (allocator_map_.find(stream) == allocator_map_.end()) {
94 const auto &allocator_obj = NewAclAllocator(stream);
95 (void)CALL_ASCEND_API(aclrtAllocatorRegister, stream, allocator_obj->allocator_desc());
96 allocator_map_[stream] = allocator_obj;
97 }
98 }
99 } // namespace transform
100 } // namespace mindspore
101