1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "plugin/device/ascend/hal/device/ascend_memory_pool.h"
18 #include <algorithm>
19 #include <utility>
20 #include "plugin/device/ascend/hal/device/ascend_memory_adapter.h"
21 #include "plugin/device/ascend/hal/device/ascend_gmem_adapter.h"
22 #include "plugin/device/ascend/hal/device/ascend_vmm_adapter.h"
23 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
24 #include "utils/log_adapter.h"
25 #include "utils/convert_utils_base.h"
26 #include "transform/symbol/acl_rt_symbol.h"
27 #include "transform/symbol/symbol_utils.h"
28
29 namespace mindspore {
30 namespace device {
31 namespace ascend {
// The minimum unit size (8MB) of memory block used for dynamic extend in graph run mode.
static const size_t ASCEND_COMMON_POOL_ALLOC_UNIT_SIZE_FOR_GRAPH_RUN_MODE = 8 << 20;
// Map key for the single shared overflow-detection workspace buffer cached in
// overflow_memory_info_map_ (see AllocOverflowTensorMem).
constexpr char kGlobalOverflowWorkspace[] = "GLOBAL_OVERFLOW_WORKSPACE";
35
AscendMemoryPool()36 AscendMemoryPool::AscendMemoryPool() { SetEnableVmm(AscendVmmAdapter::GetInstance().IsEnabled()); }
37
SetMemPoolBlockSize(size_t available_device_mem_size)38 void AscendMemoryPool::SetMemPoolBlockSize(size_t available_device_mem_size) {
39 auto ms_context = MsContext::GetInstance();
40 MS_EXCEPTION_IF_NULL(ms_context);
41 float mem_block_size = ms_context->get_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE);
42 // set from context configuration
43 if (!common::IsFloatEqual(mem_block_size, kDefaultMempoolBlockSize)) {
44 size_t config_size = FloatToSize(mem_block_size * kGBToByte);
45 if (config_size > available_device_mem_size) {
46 MS_LOG(WARNING) << "Memory pool block size " << config_size
47 << " is bigger than currently available maximum memory " << available_device_mem_size
48 << ", and the actual effective value will be " << available_device_mem_size;
49 }
50 // Reserve 1G for persistent_mem
51 if (available_device_mem_size > kDynamicMemAllocUnitSize) {
52 available_device_mem_size -= kDynamicMemAllocUnitSize;
53 }
54 size_t real_block_size = std::min(config_size, available_device_mem_size);
55 SetMemAllocUintSize(real_block_size, kDynamicMemAllocUnitSize);
56 return;
57 }
58
59 // set by default configuration
60 const auto graph_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
61 const bool is_graph_run_mode = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
62 if (graph_mode && is_graph_run_mode) {
63 SetMemAllocUintSize(ASCEND_COMMON_POOL_ALLOC_UNIT_SIZE_FOR_GRAPH_RUN_MODE,
64 ASCEND_COMMON_POOL_ALLOC_UNIT_SIZE_FOR_GRAPH_RUN_MODE);
65 } else {
66 SetMemAllocUintSize(kDynamicMemAllocUnitSize, kDynamicMemAllocUnitSize);
67 }
68 }
69
70 namespace {
NoAdditionalMemory()71 bool NoAdditionalMemory() {
72 auto context = MsContext::GetInstance();
73 MS_EXCEPTION_IF_NULL(context);
74 const auto is_cell_reuse = context->CellReuseLevel() != CellReuseLevel::kNoCellReuse;
75 const auto is_multi_graph_sink = context->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK);
76 const auto is_task_sink = context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
77 return (is_cell_reuse || is_multi_graph_sink) && is_task_sink;
78 }
79 } // namespace
80
CalMemBlockAllocSize(size_t size,bool from_persistent_mem,bool need_recycle)81 size_t AscendMemoryPool::CalMemBlockAllocSize(size_t size, bool from_persistent_mem, bool need_recycle) {
82 auto device_free_mem_size = free_mem_size();
83 if (device_free_mem_size < size && common::IsNeedProfileMemory()) {
84 device_free_mem_size = size;
85 }
86 if (device_free_mem_size < size) {
87 MS_LOG(INFO) << "The device memory is not enough, the free memory size is " << device_free_mem_size
88 << ", but the alloc size is " << size;
89 MS_LOG(INFO) << "The dynamic memory pool total size is "
90 << device::ascend::AscendMemoryPool::GetInstance().TotalMemStatistics() / kMBToByte
91 << "M, total used size is "
92 << device::ascend::AscendMemoryPool::GetInstance().TotalUsedMemStatistics() / kMBToByte
93 << "M, used peak size is "
94 << device::ascend::AscendMemoryPool::GetInstance().UsedMemPeakStatistics() / kMBToByte << "M.";
95 MS_LOG(INFO) << "Memory Statistics:" << AscendMemAdapter::GetInstance().DevMemStatistics();
96 return 0;
97 }
98
99 size_t alloc_mem_size;
100 SetMemPoolBlockSize(device_free_mem_size);
101 auto alloc_mem_unit_size = MemAllocUnitSize(from_persistent_mem);
102 if (need_recycle) {
103 alloc_mem_unit_size = kDynamicMemAllocUnitSize;
104 }
105 MS_LOG(DEBUG) << "Get unit block size " << alloc_mem_unit_size;
106 alloc_mem_size = alloc_mem_unit_size;
107
108 auto ms_context = MsContext::GetInstance();
109 MS_EXCEPTION_IF_NULL(ms_context);
110 const bool is_graph_run_mode = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
111 if (is_graph_run_mode) {
112 // Growing at adding alloc unit size
113 while (alloc_mem_size < size) {
114 alloc_mem_size = alloc_mem_size + alloc_mem_unit_size;
115 }
116 } else {
117 // Growing at twice of alloc unit size
118 constexpr size_t kDouble = 2;
119 while (alloc_mem_size < size) {
120 alloc_mem_size = alloc_mem_size * kDouble;
121 }
122 }
123
124 alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
125 if (NoAdditionalMemory() && !need_recycle) {
126 alloc_mem_size = std::min(alloc_mem_size, size);
127 }
128 return alloc_mem_size;
129 }
130
// Allocates a block of static device memory to back the pool.
// Writes the allocated address into *addr and returns the granted size.
// Throws (MS_LOG(EXCEPTION)) on a zero-size request or when the adapter
// returns a null address.
size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) {
  MS_LOG(INFO) << "Malloc Memory for Pool, size: " << size;
  if (size == 0) {
    MS_LOG(EXCEPTION) << "Failed to alloc memory pool resource, the size is zero!";
  }
  // Note: *addr is assigned before the null check, so it holds the adapter's
  // result (possibly nullptr) even on the exception path.
  *addr = AscendMemAdapter::GetInstance().MallocStaticDevMem(size);
  if (*addr == nullptr) {
    MS_LOG(EXCEPTION) << "Alloc device memory pool address is nullptr, failed to alloc memory pool resource!";
  }
  return size;
}
142
AllocOverflowTensorMem(size_t size,bool from_persistent_mem)143 DeviceMemPtr AscendMemoryPool::AllocOverflowTensorMem(size_t size, bool from_persistent_mem) {
144 size_t align_size = AlignMemorySize(size);
145 std::lock_guard<std::mutex> locker(mutex_);
146 auto iter = overflow_memory_info_map_.find(kGlobalOverflowWorkspace);
147 if (iter != overflow_memory_info_map_.cend()) {
148 return iter->second;
149 }
150 DeviceMemPtr overflow_memory_ptr = AllocTensorMem(align_size, from_persistent_mem);
151 MS_EXCEPTION_IF_NULL(overflow_memory_ptr);
152 auto acl_ret = CALL_ASCEND_API(aclrtMemset, overflow_memory_ptr, align_size, 0, align_size);
153 if (acl_ret != ACL_RT_SUCCESS) {
154 MS_LOG(EXCEPTION) << "Clear overflow memory failed, aclrtMemset size = " << align_size << ", ret = " << acl_ret;
155 }
156 (void)overflow_memory_info_map_.emplace(kGlobalOverflowWorkspace, overflow_memory_ptr);
157 return overflow_memory_ptr;
158 }
159
GetMaxUsedMemSize() const160 size_t AscendMemoryPool::GetMaxUsedMemSize() const {
161 void *min_used_addr = GetMinUsingMemoryAddr();
162 if (min_used_addr == nullptr) {
163 return 0;
164 }
165 auto max_used_hbm = AscendMemAdapter::GetInstance().GetMsUsedHbmSize();
166 size_t static_offset = reinterpret_cast<uint8_t *>(min_used_addr) - AscendMemAdapter::GetInstance().GetBaseAddr();
167 return LongToSize(max_used_hbm) - static_offset;
168 }
169
IsEnableEagerFree() const170 const bool AscendMemoryPool::IsEnableEagerFree() const {
171 return AscendGmemAdapter::GetInstance().is_eager_free_enabled();
172 }
173
SyncAllStreams()174 const bool AscendMemoryPool::SyncAllStreams() { return AscendStreamMng::GetInstance().SyncAllStreams(); }
175
AllocDeviceMemByEagerFree(size_t size,DeviceMemPtr * addr)176 size_t AscendMemoryPool::AllocDeviceMemByEagerFree(size_t size, DeviceMemPtr *addr) {
177 if (IsEnableVmm()) {
178 return AscendVmmAdapter::GetInstance().AllocDeviceMem(size, addr);
179 } else if (IsEnableEagerFree()) {
180 return AscendGmemAdapter::GetInstance().AllocDeviceMem(size, addr);
181 } else {
182 MS_LOG(EXCEPTION) << "Eager free and VMM are both disabled.";
183 }
184 }
185
FreeDeviceMemByEagerFree(const DeviceMemPtr addr,const size_t size)186 size_t AscendMemoryPool::FreeDeviceMemByEagerFree(const DeviceMemPtr addr, const size_t size) {
187 if (IsEnableVmm()) {
188 return AscendVmmAdapter::GetInstance().EagerFreeDeviceMem(addr, size);
189 } else if (IsEnableEagerFree()) {
190 return AscendGmemAdapter::GetInstance().EagerFreeDeviceMem(addr, size);
191 } else {
192 MS_LOG(EXCEPTION) << "Eager free and VMM are both disabled.";
193 }
194 }
195
MmapDeviceMem(const size_t size,const DeviceMemPtr addr)196 size_t AscendMemoryPool::MmapDeviceMem(const size_t size, const DeviceMemPtr addr) {
197 return AscendVmmAdapter::GetInstance().MmapDeviceMem(size, addr, total_mem_size());
198 }
199
FreeDeviceMem(const DeviceMemPtr & addr)200 bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) {
201 MS_EXCEPTION_IF_NULL(addr);
202 int64_t max_actual = ActualPeakStatistics();
203 MS_LOG(INFO) << "Max actual used memory size is " << max_actual;
204 AscendMemAdapter::GetInstance().UpdateActualPeakMemory(max_actual);
205 int64_t max_peak = UsedMemPeakStatistics();
206 MS_LOG(INFO) << "Max peak used memory size is " << max_peak;
207 AscendMemAdapter::GetInstance().UpdateUsedPeakMemory(max_peak);
208 return AscendMemAdapter::GetInstance().FreeStaticDevMem(addr);
209 }
210
ResetIdleMemBuf() const211 void AscendMemoryPool::ResetIdleMemBuf() const {
212 auto fn = [this](const MemStatusManagerPtr &mem_mng) {
213 MS_EXCEPTION_IF_NULL(mem_mng);
214 if (mem_mng->mem_block_list_.empty()) {
215 return;
216 }
217 const auto &stream_ids = mem_mng->GetStreamIds();
218 for (const auto stream_id : stream_ids) {
219 auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufIdle);
220 const auto &&iter = mem_mng->mem_bufs_.find(key);
221 if (iter == mem_mng->mem_bufs_.end()) {
222 continue;
223 }
224 const auto &mem_buf_map = iter->second;
225 for (auto &&idle_iter = mem_buf_map.begin(); idle_iter != mem_buf_map.end(); idle_iter++) {
226 auto &mem_buf = idle_iter->second;
227 MS_EXCEPTION_IF_NULL(mem_buf);
228 (void)CALL_ASCEND_API(aclrtMemset, mem_buf->device_addr_, mem_buf->size_, 0, mem_buf->size_);
229 }
230 }
231 };
232 fn(persistent_mem());
233 fn(common_mem());
234 }
235
free_mem_size()236 size_t AscendMemoryPool::free_mem_size() { return AscendMemAdapter::GetInstance().FreeDevMemSize(); }
237
total_mem_size() const238 uint64_t AscendMemoryPool::total_mem_size() const {
239 static constexpr uint64_t kMaxHbmSize = 1LL << 40;
240 if (common::IsNeedProfileMemory()) {
241 return kMaxHbmSize;
242 } else {
243 return AscendMemAdapter::GetInstance().MaxHbmSizeForMs();
244 }
245 }
246 } // namespace ascend
247 } // namespace device
248 } // namespace mindspore
249