/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/backend/mem_reuse/mem_dynamic_allocator.h"
#include <algorithm>
#include <numeric>
#include <ostream>
#include <utility>
#include <string>
#include "include/backend/mem_reuse/mem_tracker.h"
#include "include/common/utils/utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_utils.h"
#ifdef ENABLE_DEBUGGER
#include "plugin/device/cpu/hal/profiler/cpu_profiling.h"
#endif

namespace mindspore {
namespace device {
static const char kPersistentParamMem[] = "Persistent mem";
static const char kCommonMem[] = "Common mem";
constexpr size_t kGBToByte = 1024 << 20;
// The smallest memory request size; if a request is smaller than this, the device memory request may fail.
// Set to an empirical value of 10M.
const size_t kMinimumAllocMem = 10 << 20;
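// A quick sanity check on the shift arithmetic above: 1 << 20 == 1048576 bytes (1 MiB),
// so 1024 << 20 == 1 GiB and 10 << 20 == 10 MiB, matching the names kGBToByte and
// kMinimumAllocMem ("10M") used throughout this file.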

thread_local AllocatorDebugInfo DynamicMemAllocatorDebugInfo::debug_info_;

const char kBlockMemorySize[] = "block_memory_size";
const char kBlockStreamId[] = "block_stream_id";
const char kCommonMemPoolType[] = "common_mem_pool";
const char kPersistentMemPoolType[] = "persistent_mem_pool";

static const std::map<DynamicMemBufStatus, std::string> kBufStatusString = {
  {DynamicMemBufStatus::kMemBufIdle, "idle"},
  {DynamicMemBufStatus::kMemBufUsed, "used"},
  {DynamicMemBufStatus::kMemBufEagerFree, "eager_free"},
  {DynamicMemBufStatus::kMemBufUsedByEvent, "used_by_event"}};

static const std::map<AllocatorType, std::string> kAllocatorTypeString = {
  {AllocatorType::kWeight, "weight"},
  {AllocatorType::kConstantValue, "constant value"},
  {AllocatorType::kKernelOutput, "kernel output"},
  {AllocatorType::kGraphOutput, "graph output"},
  {AllocatorType::kWorkspace, "workspace"},
  {AllocatorType::kOther, "other"},
};

DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
  persistent_mem_->Clear();
  common_mem_->Clear();
  stream_pair_addresses_.clear();
}

void DynamicMemBlock::update_border_addr(DeviceMemPtr left_addr, DeviceMemPtr right_addr) {
  if (min_addr_ == nullptr) {
    min_addr_ = left_addr;
  } else {
    min_addr_ = std::min(min_addr_, left_addr);
  }
  if (max_addr_ == nullptr) {
    max_addr_ = right_addr;
  } else {
    max_addr_ = std::max(max_addr_, right_addr);
  }
}

size_t DynamicMemBlock::get_actual_peak() {
  if (min_addr_ == nullptr || max_addr_ == nullptr) {
    return 0;
  }
  int64_t actual_memory = reinterpret_cast<uint8_t *>(max_addr_) - reinterpret_cast<uint8_t *>(min_addr_);
  return static_cast<size_t>(actual_memory);
}

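// Allocation entry point. The flow below: normalize the stream id, align the size, then
// (1) try to find an existing buf (FindAvailableMemBuf), (2) fall back to adding a new
// block/buf (AddMemBlockAndMemBuf), (3) as a last resort sync all recorded events to
// release bufs held by events and retry, and dump the pool state if everything fails.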
DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size, bool from_persistent_mem, bool need_recycle,
                                                   uint32_t stream_id) {
  if (stream_id == UINT32_MAX) {
    MS_LOG(DEBUG) << "Rewrite stream id from UINT32_MAX to the default stream index.";
    stream_id = kDefaultStreamIndex;
  }
  size_t align_size = AlignMemorySize(size);
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  // Find a memory buf by tensor size; if none is found, add a new memory block and memory buf.
  DeviceMemPtr device_addr = FindAvailableMemBuf(align_size, from_persistent_mem, stream_id);
  static bool init_recycle_memory = false;
  if (need_recycle && !init_recycle_memory) {
    // Force persistent memory to be reserved when recycle memory is allocated for the first time.
    init_recycle_memory = true;
    MS_LOG(INFO) << "Init Recycle Memory";
    device_addr = nullptr;
  }
  if (device_addr == nullptr) {
    device_addr = AddMemBlockAndMemBuf(align_size, from_persistent_mem, need_recycle, stream_id);

    if (device_addr == nullptr) {
      MS_LOG(INFO) << "Alloc tensor mem failed and try to sync all events to release memory.";
      SyncAllEventsInner();
      device_addr = FindAvailableMemBuf(align_size, from_persistent_mem, stream_id);
    }

    // Alloc memory failed and dump the info.
    if (!device_addr) {
      DumpDynamicMemPoolStateInfo();
    }
  }

// Report memory data to the profiler.
#ifdef ENABLE_DEBUGGER
  static auto profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
  MS_EXCEPTION_IF_NULL(profiler_inst);
  if (profiler_inst->GetEnableFlag() && profiler_inst->GetProfileMemoryFlag()) {
    profiler_inst->RecordMemoryPoolInfo(TotalUsedMemStatistics(), TotalMemStatistics(),
                                        TotalUsedByEventMemStatistics());
  }
#endif

  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, Memory pool alloc, total mem: " << TotalMemStatistics()
                    << ", peak mem: " << UsedMemPeakStatistics() << ", in use mem: " << TotalUsedMemStatistics()
                    << ", used by event mem: " << TotalUsedByEventMemStatistics()
                    << ", device address addr: " << device_addr << ", size: " << size
                    << ", from persistent mem: " << from_persistent_mem << ", need recycle: " << need_recycle;
  }
  if (device_addr != nullptr) {
    if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
      device::tracker::CALL_MEMORY_TRACKER(AllocMemBlock, device_addr, align_size, GetMemoryPoolType(),
                                           ActualPeakStatistics(), TotalUsedMemStatistics(), TotalMemStatistics(),
                                           stream_id);
    }
    if (IsMemoryPoolRecycle()) {
      (void)mem_bufs_.insert(device_addr);
    }
  }
  MS_LOG(DEBUG) << "Alloc memory details, name:" << DynamicMemAllocatorDebugInfo::GetDebugInfo().name_
                << ", persistent_mem:" << from_persistent_mem << ", stream id: " << stream_id
                << ", address:" << device_addr << ", size:" << size << "B, total allocated mem:" << TotalMemStatistics()
                << "B, peak used mem:" << UsedMemPeakStatistics() << "B, in used mem:" << TotalUsedMemStatistics()
                << "B, used by event mem:" << TotalUsedByEventMemStatistics()
                << "B, actual peak used mem:" << ActualPeakStatistics()
                << "B, total idle mem:" << TotalIdleMemStatistics() << "B.";
  return device_addr;
}

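// Continuous allocation below works by allocating one buf of the summed size and then
// re-registering it as a chain of per-tensor bufs at successive offsets; any tail left in
// the pre-allocated buf (e.g. from alignment) is absorbed into the last buf so the whole
// region stays accounted for.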
std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(const std::vector<size_t> &size_list,
                                                                          uint32_t stream_id) {
  std::vector<DeviceMemPtr> device_addr_list;
  size_t total_size = std::accumulate(size_list.begin(), size_list.end(), IntToSize(0));
  // Pre-alloc the one whole piece memory.
  auto device_addr = AllocTensorMem(total_size, false, false, stream_id);
  if (!device_addr) {
    return device_addr_list;
  }
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  // Remove the pre-alloc memory.
  auto mem_block = FindMemBlock(device_addr, common_mem_);
  if (mem_block == nullptr) {
    mem_block = FindMemBlock(device_addr, persistent_mem_);
  }
  MS_EXCEPTION_IF_NULL(mem_block);
  const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
  if (iter == mem_block->block_all_mem_buf_map_.end()) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(INTERNAL_EXCEPTION) << "Can't find the device address[" << device_addr << "].";
  }
  auto mem_buf = iter->second;
  MS_EXCEPTION_IF_NULL(mem_buf);
  if (mem_buf->size_ < total_size) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "The size of membuf is less than total_size.";
  }
  auto rest_size = mem_buf->size_ - total_size;
  (void)mem_block->block_all_mem_buf_map_.erase(iter);
  // Split the pre-alloc memory into continuous memory by the size list.
  DynamicMemBufPtr continuous_mem_buf;
  auto buf_addr = device_addr;
  for (size_t i : size_list) {
    continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, DynamicMemBufStatus::kMemBufUsed, i, stream_id,
                                                         DynamicMemAllocatorDebugInfo::GetDebugInfo().name_,
                                                         DynamicMemAllocatorDebugInfo::GetDebugInfo().type_);
    MS_EXCEPTION_IF_NULL(continuous_mem_buf);
    (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf);
    mem_block->update_border_addr(mem_buf->device_addr_, AddressOffset(mem_buf->device_addr_, mem_buf->size_));
    device_addr_list.emplace_back(buf_addr);
    buf_addr = AddressOffset(buf_addr, i);
  }
  // Update the size of the last memory buf.
  continuous_mem_buf->size_ += rest_size;
  return device_addr_list;
}

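// AlignMemorySize rounds every request up to a multiple of kDynamicMemAlignSize, which is
// declared in the allocator header. Assuming a 512-byte unit (the typical value; the header
// is authoritative): AlignMemorySize(0) == 512, AlignMemorySize(1) == 512, and
// AlignMemorySize(513) == 1024.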
size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const {
  if (size == 0) {
    return kDynamicMemAlignSize;
  }
  return ((size + kDynamicMemAlignSize - 1) / kDynamicMemAlignSize) * kDynamicMemAlignSize;
}

DeviceMemPtr DynamicMemPoolBestFit::FindAvailableMemBuf(size_t size, bool from_persistent_mem, uint32_t stream_id) {
  auto addr = FindMemBufByStatus(size, from_persistent_mem, DynamicMemBufStatus::kMemBufIdle, stream_id);
  if (addr == nullptr && is_trigger_eager_free_) {
    MS_LOG(DEBUG) << "Find idle mem buf failed and eager free is enabled, try to search in eager free bufs.";
    // Check the total used max memory limit, since the real occupied memory size equals the used mem size plus the
    // idle mem size. Eager free mem may occupy some memory, so total_mem_size needs to be multiplied by a factor.
    float threshold_factor = 0.8f;
    size_t threshold = IsEnableVmm() ? total_mem_size() : static_cast<size_t>(total_mem_size() * threshold_factor);
    if (TotalUsedMemStatistics() + TotalUsedByEventMemStatistics() + TotalIdleMemStatistics() + size <= threshold) {
      addr = FindMemBufByStatus(size, from_persistent_mem, DynamicMemBufStatus::kMemBufEagerFree, stream_id);
    }
  }
  return addr;
}

DeviceMemPtr DynamicMemPoolBestFit::FindMemBufByStatus(size_t size, bool from_persistent_mem,
                                                       DynamicMemBufStatus target_status, uint32_t stream_id) {
  auto addr = FindMemBufInSpecifiedMng(size, from_persistent_mem, target_status, stream_id);
  if (addr == nullptr && !IsEnableVmm()) {
    if (from_persistent_mem && !persistent_mem_->mem_block_list_.empty()) {
      MS_LOG(DEBUG) << "Find mem buf in current pool failed, try to find in another one.";
      addr = FindMemBufInSpecifiedMng(size, !from_persistent_mem, target_status, stream_id);
    }
  }
  return addr;
}

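// Best-fit search: each (stream id, status) bucket keeps its bufs in a map ordered by size,
// so lower_bound(size) lands on the smallest buf that still fits the request. When the
// memory-pool-recycle mode is on, the highest address among equally-sized candidates is
// chosen so that the same tensor tends to receive the same address on every step.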
DeviceMemPtr DynamicMemPoolBestFit::FindMemBufInSpecifiedMng(size_t size, bool from_persistent_mem,
                                                             DynamicMemBufStatus target_status, uint32_t stream_id) {
  auto &mem_mng = from_persistent_mem ? persistent_mem_ : common_mem_;
  auto &mem_buf_map = mem_mng->GetOrCreateMemBufMap(stream_id, target_status);
  auto iter = mem_buf_map.lower_bound(size);
  if (iter != mem_buf_map.end()) {
    if (IsMemoryPoolRecycle()) {
      // Ensure that the addresses corresponding to the same Tensor for each step are consistent, making the memory
      // pool recycling function more stable.
      auto find_size = iter->first;
      // Can be optimized in the future.
      auto [lb, ub] = mem_buf_map.equal_range(find_size);
      for (auto i = lb; i != ub; ++i) {
        if (i->second->device_addr_ > iter->second->device_addr_) {
          iter = i;
        }
      }
    }
    auto mem_buf = iter->second;
    MS_EXCEPTION_IF_NULL(mem_buf);
    if (mem_buf->status_ != target_status) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "Mem_buf is not " << target_status << ", alloc_size[" << size << "] mem_buf_size["
                        << mem_buf->size_ << "] mem_buf_address[" << mem_buf->device_addr_ << "].";
    }
    mem_buf->allocator_name_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().name_;
    mem_buf->allocator_type_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().type_;
    if (mem_buf->status_ == DynamicMemBufStatus::kMemBufEagerFree && IsEnableVmm()) {
      MS_LOG(DEBUG) << "Find eager free memory, mem_buf_size[" << mem_buf->size_ << "] mem_buf_address["
                    << mem_buf->device_addr_ << "], need size: " << size;
      auto ret = MmapDeviceMem(size, mem_buf->device_addr_);
      if (ret != size) {
        return nullptr;
      }
    }
    // Remove map of the old idle memory buf.
    (void)mem_buf_map.erase(iter);
    // Divide the memory buf.
    if (IsSplit(size, mem_buf->size_)) {
      SplitMemBuf(size, mem_buf, mem_mng, stream_id);
    }
    auto mem_block = FindMemBlock(mem_buf->device_addr_, mem_mng);
    MS_EXCEPTION_IF_NULL(mem_block);
    mem_block->update_border_addr(mem_buf->device_addr_, AddressOffset(mem_buf->device_addr_, mem_buf->size_));
    mem_buf->status_ = DynamicMemBufStatus::kMemBufUsed;
    // Memory statistics.
    mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
    mem_mng->mps_.UpdatePeakSize();
    if (target_status == DynamicMemBufStatus::kMemBufIdle) {
      mem_mng->mps_.total_idle_mem_size_ -= mem_buf->size_;
    } else if (target_status == DynamicMemBufStatus::kMemBufEagerFree) {
      mem_mng->mps_.total_eager_free_mem_size_ -= mem_buf->size_;
    }
    return mem_buf->device_addr_;
  }
  return nullptr;
}

size_t DynamicMemPoolBestFit::MemAllocUnitSize(bool from_persistent_mem) const {
  return from_persistent_mem ? persistent_mem_->unit_size_ : common_mem_->unit_size_;
}

void DynamicMemPoolBestFit::SetMemAllocUintSize(size_t common_size, size_t persist_size) {
  persistent_mem_->unit_size_ = persist_size;
  common_mem_->unit_size_ = common_size;
  config_unit_size_ = common_size;
  MS_LOG(DEBUG) << "Set mem alloc unit size, common " << common_size << " persistent " << persist_size;
}

void *DynamicMemPoolBestFit::GetMinUsingMemoryAddr() const {
  if (mem_bufs_.empty()) {
    return nullptr;
  }
  return *(mem_bufs_.begin());
}

void DynamicMemPoolBestFit::SetMemPoolBlockSize(size_t available_device_mem_size) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  float mem_block_size = ms_context->get_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE);
  if (mem_block_size == kDefaultMempoolBlockSize) {
    return;
  }

  size_t config_size = FloatToSize(mem_block_size * kGBToByte);
  if (config_size > available_device_mem_size) {
    MS_LOG(WARNING) << "Memory pool block size " << config_size << " is bigger than currently available maximum memory "
                    << available_device_mem_size << ", and the actual effective value will be "
                    << available_device_mem_size;
  }
  // Reserve 1G for persistent_mem.
  if (available_device_mem_size > kGBToByte) {
    available_device_mem_size -= kGBToByte;
  }
  size_t real_block_size = std::min(config_size, available_device_mem_size);
  SetMemAllocUintSize(real_block_size);
}

DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size, bool from_persistent_mem, bool need_recycle,
                                                         uint32_t stream_id) {
  if (from_persistent_mem && !need_recycle && !persistent_mem_->Empty()) {
    from_persistent_mem = false;
  }

  // Try the eager free routine.
  if (IsEnableVmm() || is_trigger_eager_free_) {
    is_trigger_eager_free_ = true;
    return AddMemBlockAndMemBufByEagerFree(size, from_persistent_mem, stream_id);
  }

  size_t alloc_mem_size = CalMemBlockAllocSize(size, from_persistent_mem, need_recycle);
  MS_LOG(DEBUG) << "CalMemBlockAllocSize return : " << size << ", alloc_mem_size : " << alloc_mem_size;
  if (alloc_mem_size == 0) {
    if (auto device_addr = FindAvailableMemBuf(size, !from_persistent_mem, stream_id)) {
      return device_addr;
    }
    if (IsEnableEagerFree()) {
      is_trigger_eager_free_ = true;
      return AddMemBlockAndMemBufByEagerFree(size, from_persistent_mem, stream_id);
    }
    return nullptr;
  }

  // Add a new memory block.
  DeviceMemPtr device_addr = nullptr;
  auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr);
  if (real_alloc_size < size) {
    MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size
                    << "].";
    return nullptr;
  }
  // If unit_size was changed by another function (not the context), change unit_size back.
  MS_EXCEPTION_IF_NULL(common_mem_);
  common_mem_->unit_size_ = config_unit_size_;

  return CreateMemBlockAndMemBuf(size, from_persistent_mem, device_addr, real_alloc_size,
                                 DynamicMemBufStatus::kMemBufIdle, stream_id);
}

DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBufByEagerFree(size_t size, bool from_persistent_mem,
                                                                    uint32_t stream_id) {
  // Check used max memory limits.
  if (TotalUsedMemStatistics() + TotalUsedByEventMemStatistics() + size > total_mem_size()) {
    MS_LOG(ERROR) << "TotalUsedMemStatistics : " << TotalUsedMemStatistics()
                  << " plus TotalUsedByEventMemStatistics : " << TotalUsedByEventMemStatistics()
                  << " and plus alloc size : " << size << " is more than total mem size : " << total_mem_size() << ".";
    return nullptr;
  }

  MS_LOG(DEBUG) << "Try to eager free memory.";
  if (!SyncAllStreams()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Sync all streams failed.";
  }
  FreeIdleMemsByEagerFree();
  auto mem_addr = FindMemBufByStatus(size, from_persistent_mem, DynamicMemBufStatus::kMemBufEagerFree, stream_id);
  if (mem_addr != nullptr) {
    MS_LOG(DEBUG) << "Find eager free memory success, mem_addr : " << mem_addr << ".";
    return mem_addr;
  }

  auto alloc_size = std::max(size, static_cast<size_t>(total_mem_size()));
  MS_LOG(INFO) << "Try to alloc eager free mem block, size : " << size << ", alloc_size : " << alloc_size << ".";
  DeviceMemPtr device_addr = nullptr;
  auto real_alloc_size = AllocDeviceMemByEagerFree(alloc_size, &device_addr);
  if (real_alloc_size < alloc_size) {
    MS_LOG(ERROR) << "AllocDeviceMemByEagerFree failed, alloc_size : " << real_alloc_size << ".";
    return nullptr;
  }
  return CreateMemBlockAndMemBuf(size, from_persistent_mem, device_addr, real_alloc_size,
                                 DynamicMemBufStatus::kMemBufEagerFree, stream_id);
}

DeviceMemPtr DynamicMemPoolBestFit::CreateMemBlockAndMemBuf(size_t size, bool from_persistent_mem,
                                                            DeviceMemPtr source_addr, size_t source_size,
                                                            DynamicMemBufStatus mem_buf_status, uint32_t stream_id) {
  auto mem_block = std::make_shared<DynamicMemBlock>(source_addr, source_size, stream_id);
  auto mem_mng = from_persistent_mem ? persistent_mem_ : common_mem_;
  mem_mng->AddMemBlock(mem_block, stream_id);
  // Add new memory buf.
  auto mem_buf = std::make_shared<DynamicMemBuf>(mem_block->device_addr(), mem_buf_status, mem_block->size(), stream_id,
                                                 DynamicMemAllocatorDebugInfo::GetDebugInfo().name_,
                                                 DynamicMemAllocatorDebugInfo::GetDebugInfo().type_);
  if (mem_buf->status_ == DynamicMemBufStatus::kMemBufEagerFree && IsEnableVmm()) {
    MS_LOG(DEBUG) << "Find eager free memory, mem_buf_size[" << mem_buf->size_ << "] mem_buf_address["
                  << mem_buf->device_addr_ << "], need size: " << size;
    auto ret = MmapDeviceMem(size, mem_buf->device_addr_);
    if (ret != size) {
      return nullptr;
    }
  }
  // Add map of the new memory buf in the block.
  (void)mem_block->block_all_mem_buf_map_.emplace(mem_block->device_addr(), mem_buf);
  // Split the memory buf.
  if (IsSplit(size, mem_buf->size_)) {
    SplitMemBuf(size, mem_buf, mem_mng, stream_id);
  }
  mem_block->update_border_addr(mem_buf->device_addr_, AddressOffset(mem_buf->device_addr_, mem_buf->size_));
  mem_buf->status_ = DynamicMemBufStatus::kMemBufUsed;
  // Memory statistics.
  mem_mng->mps_.total_mem_size_ += mem_block->size();
  mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
  mem_mng->mps_.UpdatePeakSize();
  if (mem_buf_status == DynamicMemBufStatus::kMemBufIdle) {
    mem_mng->mps_.total_idle_mem_size_ += source_size - mem_buf->size_;
  } else if (mem_buf_status == DynamicMemBufStatus::kMemBufEagerFree) {
    mem_mng->mps_.total_eager_free_mem_size_ += source_size - mem_buf->size_;
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "Unsupported mem_buf_status : " << mem_buf_status << ".";
  }
  MS_LOG(DEBUG) << "Usage: used size : " << TotalUsedMemStatistics()
                << ", used by event size : " << TotalUsedByEventMemStatistics()
                << ", idle size : " << TotalIdleMemStatistics()
                << ", eager free size : " << TotalEagerFreeMemStatistics() << ".";
  return mem_buf->device_addr_;
}

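// Block sizing below grows geometrically: start from the pool's unit size and double until
// the request fits, then clamp to the free device memory. E.g. with a 1 GiB unit, a 3 GiB
// request yields a 4 GiB block (1 -> 2 -> 4 GiB), unless less memory is free on the device.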
size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size, bool from_persistent_mem, bool) {
  auto device_free_mem_size = free_mem_size();
  if (device_free_mem_size < size && common::IsNeedProfileMemory()) {
    device_free_mem_size = size;
  }
  if (device_free_mem_size < size) {
    MS_LOG(INFO) << "Memory not enough: current free memory size[" << device_free_mem_size
                 << "] is smaller than required size[" << size << "].";
    return 0;
  }
  // The remaining device memory is too small, which may cause a new allocation to fail.
  if (device_free_mem_size < kMinimumAllocMem) {
    MS_LOG(INFO) << "Device memory size [" << device_free_mem_size << "] is smaller than minimum alloc size ["
                 << kMinimumAllocMem << "].";
    return 0;
  }
  auto alloc_mem_size = MemAllocUnitSize(from_persistent_mem);
  // Grow by doubling the alloc size.
  constexpr size_t kDouble = 2;
  while (alloc_mem_size < size) {
    alloc_mem_size = alloc_mem_size * kDouble;
  }
  alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
  return alloc_mem_size;
}

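// Eager free runs in two passes per pool: first every idle buf is combined into the
// eager-free state (coalescing neighbors along the way), then each eager-free buf is handed
// to FreeDeviceMemByEagerFree. The lambda returns the pair (requested free size, size the
// backend actually freed); the function reports both and returns the real freed size.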
const size_t DynamicMemPoolBestFit::FreeIdleMemsByEagerFree() {
  eager_free_count_++;

  auto eager_free_mem_func = [&](MemStatusManagerPtr &mem_mng) {
    const auto &stream_ids = mem_mng->GetStreamIds();
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufIdle);
      auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      auto &mem_buf_map = iter->second;
      for (auto &size_mem_buf : mem_buf_map) {
        auto &mem_buf = size_mem_buf.second;
        auto [mem_block, iter, mem_mng] = FindByStrictAddr(mem_buf->device_addr_);
        if (PreCombineMemBuf(mem_buf, mem_mng)) {
          CombineMemBuf(mem_block, iter, mem_mng, DynamicMemBufStatus::kMemBufIdle,
                        DynamicMemBufStatus::kMemBufEagerFree);
        }
      }
      mem_mng->mem_bufs_.erase(iter);
    }
    // After freeing the idle memory, do the eager free.
    size_t free_size = 0;
    size_t real_free_size = 0;
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufEagerFree);
      auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      auto &mem_buf_map = iter->second;
      for (auto &size_mem_buf : mem_buf_map) {
        auto &mem_buf = size_mem_buf.second;
        free_size += mem_buf->size_;
        MS_LOG(DEBUG) << "Eager free address : " << mem_buf->device_addr_ << ".";
        real_free_size += FreeDeviceMemByEagerFree(mem_buf->device_addr_, mem_buf->size_);
      }
    }

    return std::make_pair(free_size, real_free_size);
  };

  const auto [persistent_free_size, persistent_real_free_size] = eager_free_mem_func(persistent_mem_);
  const auto [common_free_size, common_real_free_size] = eager_free_mem_func(common_mem_);
  auto free_size = persistent_free_size + common_free_size;
  auto real_free_size = persistent_real_free_size + common_real_free_size;
  static bool is_enable_memory_statistics = common::IsEnableRuntimeConfig(common::kRuntimeMemoryStat);
  if (is_enable_memory_statistics) {
    std::cout << "Total eager free memory : " << free_size << ", real free : " << real_free_size
              << ", not free size: " << (free_size - real_free_size) << "." << std::endl;
  }
  MS_LOG(INFO) << "Eager free count : " << eager_free_count_ << ", free memory : " << free_size
               << ", real free : " << real_free_size << ", not free size: " << (free_size - real_free_size) << ".";
  return real_free_size;
}

void DynamicMemPoolBestFit::DefragMemory() {
  MS_LOG(DEBUG) << "Start defrag memory.";
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif

  // The eager free count initializes with 0, and is increased by initializing the persistent pool and the common pool.
  if (eager_free_count_ <= 2L) {
    MS_LOG(DEBUG) << "Exit defrag memory since eager free count : " << eager_free_count_
                  << " is not more than the initialization count.";
    return;
  }
  if (last_eager_free_count_ == eager_free_count_) {
    MS_LOG(DEBUG) << "Exit defrag memory since last eager free count equals to eager free count : "
                  << last_eager_free_count_ << ".";
    return;
  }

  MS_LOG(INFO) << "Try to defrag memory.";
  if (!SyncAllStreams()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Sync all streams failed.";
  }
  FreeIdleMemsByEagerFree();
  last_eager_free_count_ = eager_free_count_;
}

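// A buf is split only when the leftover would be at least one alignment unit, so the pool
// never creates a fragment smaller than kDynamicMemAlignSize; smaller leftovers simply stay
// attached to the allocated buf.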
bool DynamicMemPoolBestFit::IsSplit(size_t tensor_size, size_t mem_buf_size) const {
  return mem_buf_size - tensor_size >= kDynamicMemAlignSize;
}

void DynamicMemPoolBestFit::SplitMemBuf(size_t size, const DynamicMemBufPtr &mem_buf,
                                        const MemStatusManagerPtr &mem_mng, uint32_t stream_id) {
  MS_EXCEPTION_IF_NULL(mem_buf);
  MS_EXCEPTION_IF_NULL(mem_mng);
  const auto &mem_block = FindMemBlock(mem_buf->device_addr_, mem_mng);
  MS_EXCEPTION_IF_NULL(mem_block);
  // Divide the new memory buf.
  if (mem_buf->size_ < size) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "The size of membuf is less than size.";
  }
  size_t newbuf_size = mem_buf->size_ - size;
  mem_buf->size_ = size;
  DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size);
  auto new_mem_buf = std::make_shared<DynamicMemBuf>(newbuf_addr, mem_buf->status_, newbuf_size, stream_id);
  // Add map of the new memory buf in the block.
  (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf);
  mem_mng->AddMemBuf(new_mem_buf);
}

bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
  MS_EXCEPTION_IF_NULL(device_addr);
  MS_EXCEPTION_IF_NULL(mem_block);
  return device_addr < mem_block->device_addr();
}

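// mem_block_list_ is kept sorted by block base address, so upper_bound locates the first
// block whose base is above device_addr; stepping back one entry gives the only block that
// could contain the address (the caller still has to check the block's buf map for a hit).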
DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr,
                                                       const MemStatusManagerPtr &mem_mng) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  MS_EXCEPTION_IF_NULL(mem_mng);
  auto &&iter =
    std::upper_bound(mem_mng->mem_block_list_.begin(), mem_mng->mem_block_list_.end(), device_addr, CmpMemBlock);
  if (iter != mem_mng->mem_block_list_.begin()) {
    return *(--iter);
  }
  return nullptr;
}

void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr &device_addr) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  FreeTensorMemInner(device_addr);
}

void DynamicMemPoolBestFit::FreeTensorMemInner(const DeviceMemPtr &device_addr) {
  auto [mem_block, iter, mem_mng] = FindByStrictAddr(device_addr);
  if (mem_block == nullptr) {
    // Maybe the memory pool is destroyed first, and then the address is destroyed, so this is a normal case.
    MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
    return;
  }
  auto mem_buf = iter->second;
  MS_EXCEPTION_IF_NULL(mem_buf);
  if (PreCombineMemBuf(mem_buf, mem_mng)) {
    CombineMemBuf(mem_block, iter, mem_mng, mem_buf->status_, DynamicMemBufStatus::kMemBufIdle);
    if (IsMemoryPoolRecycle()) {
      (void)mem_bufs_.erase(device_addr);
    }
    MS_LOG(DEBUG) << "Free memory details, name:" << DynamicMemAllocatorDebugInfo::GetDebugInfo().name_
                  << ", address:" << device_addr << ", total allocated mem:" << TotalMemStatistics()
                  << "B, peak used mem:" << UsedMemPeakStatistics() << "B, in used mem:" << TotalUsedMemStatistics()
                  << "B, used by event mem:" << TotalUsedByEventMemStatistics()
                  << "B, actual peak used mem:" << ActualPeakStatistics()
                  << "B, total idle mem:" << TotalIdleMemStatistics() << "B.";
  }
}

// PreCombineMemBuf judges by status whether the mem buf can be combined or not.
// If there are no events recorded on the mem buf, return true to release the mem buf.
// If there are events recorded on the mem buf, change its status to kMemBufUsedByEvent and return false.
// Note: Before releasing a mem buf by event, make sure that the status of the mem buf is kMemBufUsedByEvent,
// or the wait event may release the mem buf incorrectly.
bool DynamicMemPoolBestFit::PreCombineMemBuf(const DynamicMemBufPtr &mem_buf, const MemStatusManagerPtr &mem_mng) {
  auto device_addr = mem_buf->device_addr_;
  if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsed && !mem_buf->IsEventNotUsed()) {
    mem_buf->status_ = DynamicMemBufStatus::kMemBufUsedByEvent;
    mem_mng->mps_.total_used_mem_size_ -= mem_buf->size_;
    mem_mng->mps_.total_used_by_event_mem_size_ += mem_buf->size_;
    MS_LOG(DEBUG) << "Combine mem buf exit since mem buf is used by event, device_addr : " << device_addr
                  << ", used by event mem size : " << mem_mng->mps_.total_used_by_event_mem_size_ << ".";
    return false;
  }

  if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsedByEvent && !mem_buf->IsEventNotUsed()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Combine mem buf failed as mem buf can not be freed, device_addr : " << device_addr
                               << ".";
  }

  MS_LOG(DEBUG) << "Pre combine mem buf address : " << mem_buf->device_addr_ << " success.";
  return true;
}

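// CombineMemBuf moves a buf from origin_status to target_status, updating the per-status
// byte counters, then coalesces it with its address-adjacent neighbors in the block when
// they already carry the same target status: it first absorbs the following buf into this
// one, and then, if the preceding buf also matches, absorbs this buf into the preceding one.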
void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block,
                                          const DeviceAddrMapMemBuf::iterator &iter, const MemStatusManagerPtr &mem_mng,
                                          DynamicMemBufStatus origin_status, DynamicMemBufStatus target_status) {
  const auto &mem_buf = iter->second;
  MS_LOG(DEBUG) << "Combine mem buf release mem buf, device_addr : " << mem_buf->device_addr_ << ".";

// Report memory data to the profiler.
#ifdef ENABLE_DEBUGGER
  static auto profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
  MS_EXCEPTION_IF_NULL(profiler_inst);
  if (profiler_inst->GetEnableFlag() && profiler_inst->GetProfileMemoryFlag()) {
    profiler_inst->RecordMemoryPoolInfo(TotalUsedMemStatistics(), TotalMemStatistics(),
                                        TotalUsedByEventMemStatistics());
  }
#endif

  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, Memory pool free, total mem: " << TotalMemStatistics()
                    << ", peak mem: " << UsedMemPeakStatistics() << ", in use mem: " << TotalUsedMemStatistics()
                    << ", used by event mem: " << TotalUsedByEventMemStatistics()
                    << ", device address addr: " << mem_buf->device_addr_ << ", size: " << mem_buf->size_;
  }
  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled() &&
      target_status == DynamicMemBufStatus::kMemBufIdle) {
    device::tracker::CALL_MEMORY_TRACKER(FreeMemBlock, mem_buf->device_addr_, TotalUsedMemStatistics(),
                                         TotalMemStatistics());
  }

  if (mem_buf->status_ != origin_status) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "Find the mem_buf status : " << mem_buf->status_
                      << " is not equal to origin status : " << origin_status << ", mem_buf_address["
                      << mem_buf->device_addr_ << "].";
  }
  mem_buf->status_ = target_status;
  if (origin_status == DynamicMemBufStatus::kMemBufUsed) {
    if (mem_mng->mps_.total_used_mem_size_ < mem_buf->size_) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "The total used mem size : " << mem_mng->mps_.total_used_mem_size_
                        << " is less than the size of membuf : " << mem_buf->size_ << ".";
    }
    mem_mng->mps_.total_used_mem_size_ -= mem_buf->size_;
  } else if (origin_status == DynamicMemBufStatus::kMemBufUsedByEvent) {
    if (mem_mng->mps_.total_used_by_event_mem_size_ < mem_buf->size_) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "The total used by event mem size : " << mem_mng->mps_.total_used_by_event_mem_size_
                        << " is less than the size of membuf : " << mem_buf->size_ << ".";
    }
    mem_mng->mps_.total_used_by_event_mem_size_ -= mem_buf->size_;
    MS_LOG(DEBUG) << "Combine mem buf for addr : " << mem_buf->device_addr_
                  << ", used by event mem size : " << mem_mng->mps_.total_used_by_event_mem_size_ << ".";
  } else if (origin_status == DynamicMemBufStatus::kMemBufIdle) {
    if (mem_mng->mps_.total_idle_mem_size_ < mem_buf->size_) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "The total idle mem size : " << mem_mng->mps_.total_idle_mem_size_
                        << " is less than the size of membuf : " << mem_buf->size_ << ".";
    }
    mem_mng->mps_.total_idle_mem_size_ -= mem_buf->size_;
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "Unsupported origin status : " << origin_status << ".";
  }
  if (target_status == DynamicMemBufStatus::kMemBufIdle) {
    mem_mng->mps_.total_idle_mem_size_ += mem_buf->size_;
  } else if (target_status == DynamicMemBufStatus::kMemBufEagerFree) {
    mem_mng->mps_.total_eager_free_mem_size_ += mem_buf->size_;
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "Unsupported target status : " << target_status << ".";
  }
  // Combine backward (combine the next_mem_buf into mem_buf).
  auto next_iter = iter;
  (void)next_iter++;
  if (next_iter != mem_block->block_all_mem_buf_map_.end()) {
    auto next_mem_buf = next_iter->second;
    MS_EXCEPTION_IF_NULL(next_mem_buf);
    if (next_mem_buf->status_ == target_status) {
      mem_buf->size_ += next_mem_buf->size_;
      mem_mng->RemoveMemBuf(next_mem_buf);

      (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
    }
  }
  // Combine forward (combine mem_buf into prev_mem_buf).
  bool forward_combine = false;
  DynamicMemBufPtr prev_mem_buf;
  if (iter != mem_block->block_all_mem_buf_map_.begin()) {
    auto prev_iter = iter;
    (void)prev_iter--;
    prev_mem_buf = prev_iter->second;
    MS_EXCEPTION_IF_NULL(prev_mem_buf);
    if (prev_mem_buf->status_ == target_status) {
      mem_mng->RemoveMemBuf(prev_mem_buf);
      prev_mem_buf->size_ += mem_buf->size_;
      (void)mem_block->block_all_mem_buf_map_.erase(iter);
      forward_combine = true;
    }
  }

  if (forward_combine) {
    mem_mng->AddMemBuf(prev_mem_buf);
  } else {
    mem_mng->AddMemBuf(mem_buf);
  }
}

std::tuple<DynamicMemBlockPtr, DeviceAddrMapMemBuf::iterator, MemStatusManagerPtr>
DynamicMemPoolBestFit::FindByStrictAddr(const DeviceMemPtr &device_addr) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  // Find in the common pool.
  auto mem_block = FindMemBlock(device_addr, common_mem_);
  if (mem_block != nullptr) {
    const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
    if (iter != mem_block->block_all_mem_buf_map_.end()) {
      return std::make_tuple(mem_block, iter, common_mem_);
    }
  }

  // Find in the persistent pool.
  mem_block = FindMemBlock(device_addr, persistent_mem_);
  if (mem_block != nullptr) {
    const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
    if (iter != mem_block->block_all_mem_buf_map_.end()) {
      return std::make_tuple(mem_block, iter, persistent_mem_);
    }
  }

  DeviceAddrMapMemBuf empty_map;
  return std::make_tuple(nullptr, empty_map.end(), common_mem_);
}

void DynamicMemPoolBestFit::FreePartTensorMems(const std::vector<DeviceMemPtr> &free_addrs,
                                               const std::vector<DeviceMemPtr> &keep_addrs,
                                               const std::vector<size_t> &keep_addr_sizes) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif

  for (auto &free_addr : free_addrs) {
    FreeTensorMemInner(free_addr);
  }

  MS_EXCEPTION_IF_CHECK_FAIL((keep_addrs.size() == keep_addr_sizes.size()), "The keep addrs size is wrong.");
  for (size_t i = 0; i < keep_addrs.size(); ++i) {
    KeepTensorMemByAddr(keep_addrs[i], keep_addr_sizes[i]);
  }
}

void DynamicMemPoolBestFit::KeepTensorMemByAddr(const DeviceMemPtr &device_addr, size_t size) {
  MS_EXCEPTION_IF_NULL(device_addr);
  // Fetch the memblock and membuf by the device address.
  auto [mem_block, mem_buf, mem_mng] = FindByKeepAddr(device_addr);
  // Check the results before the tracker call dereferences mem_block.
  MS_EXCEPTION_IF_NULL(mem_block);
  MS_EXCEPTION_IF_NULL(mem_buf);
  MS_EXCEPTION_IF_NULL(mem_mng);
  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    device::tracker::CALL_MEMORY_TRACKER(AllocMemBlock, device_addr, size, GetMemoryPoolType(), ActualPeakStatistics(),
                                         TotalUsedMemStatistics(), TotalMemStatistics(), mem_block->stream_id_);
  }
  if (mem_buf->status_ != DynamicMemBufStatus::kMemBufIdle) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "The membuf status isn't idle for addr:" << device_addr << ", size:" << size
                      << ", find the mem buf addr:" << mem_buf->device_addr_ << ", size:" << mem_buf->size_;
  }

  // Calculate the sizes of the left and right split membufs.
  size_t split_left_size = CalAddressOffset(device_addr, mem_buf->device_addr_);
  MS_EXCEPTION_IF_CHECK_FAIL((mem_buf->size_ >= (split_left_size + size)), "The split size is wrong.");
  size_t split_right_size = mem_buf->size_ - split_left_size - size;

  // Split the left membuf.
  mem_mng->RemoveMemBuf(mem_buf);
  if (split_left_size == 0) {
    mem_buf->status_ = DynamicMemBufStatus::kMemBufUsed;
    mem_buf->size_ = size;
    mem_buf->allocator_name_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().name_;
    mem_buf->allocator_type_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().type_;
  } else {
    mem_buf->size_ = split_left_size;
    mem_mng->AddMemBuf(mem_buf);

    auto used_mem_buf = std::make_shared<DynamicMemBuf>(
      device_addr, DynamicMemBufStatus::kMemBufUsed, size, mem_block->stream_id_,
      DynamicMemAllocatorDebugInfo::GetDebugInfo().name_, DynamicMemAllocatorDebugInfo::GetDebugInfo().type_);
    (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, used_mem_buf);
  }

  // Split the right membuf.
  if (split_right_size > 0) {
    DeviceMemPtr right_buf_addr = AddressOffset(device_addr, size);
    auto right_mem_buf = std::make_shared<DynamicMemBuf>(right_buf_addr, DynamicMemBufStatus::kMemBufIdle,
                                                         split_right_size, mem_block->stream_id_);
    (void)mem_block->block_all_mem_buf_map_.emplace(right_buf_addr, right_mem_buf);
    mem_mng->AddMemBuf(right_mem_buf);
  }

  // Memory statistics.
  mem_mng->mps_.total_used_mem_size_ += size;
  mem_mng->mps_.UpdatePeakSize();
  mem_mng->mps_.total_idle_mem_size_ -= size;
  MS_LOG(DEBUG) << "Keep memory details, name:" << DynamicMemAllocatorDebugInfo::GetDebugInfo().name_
                << ", address:" << device_addr << ", size:" << size << "B, total allocated mem:" << TotalMemStatistics()
                << "B, peak used mem:" << UsedMemPeakStatistics() << "B, in used mem:" << TotalUsedMemStatistics()
                << "B, used by event mem:" << TotalUsedByEventMemStatistics()
                << "B, actual peak used mem:" << ActualPeakStatistics()
                << "B, total idle mem:" << TotalIdleMemStatistics() << "B.";
}

DynamicMemBufPtr DynamicMemPoolBestFit::FindMemBufByKeepAddr(const DeviceMemPtr &device_addr,
                                                             const DynamicMemBlockPtr &mem_block) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  MS_EXCEPTION_IF_NULL(mem_block);
  auto &&iter = mem_block->block_all_mem_buf_map_.upper_bound(device_addr);
  if (iter != mem_block->block_all_mem_buf_map_.begin()) {
    return (--iter)->second;
  }
  return nullptr;
}

std::tuple<DynamicMemBlockPtr, DynamicMemBufPtr, MemStatusManagerPtr> DynamicMemPoolBestFit::FindByKeepAddr(
  const DeviceMemPtr &device_addr) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  auto is_addr_in_membuf = [](const DeviceMemPtr &device_addr, const DynamicMemBufPtr &mem_buf) {
    return (mem_buf != nullptr) && (device_addr >= mem_buf->device_addr_) &&
           (mem_buf->size_ >= CalAddressOffset(device_addr, mem_buf->device_addr_));
  };

  // Find in the common pool.
  auto mem_block = FindMemBlock(device_addr, common_mem_);
  if (mem_block != nullptr) {
    auto mem_buf = FindMemBufByKeepAddr(device_addr, mem_block);
    if (is_addr_in_membuf(device_addr, mem_buf)) {
      return std::make_tuple(mem_block, mem_buf, common_mem_);
    }
  }

  // Find in the persistent pool.
  mem_block = FindMemBlock(device_addr, persistent_mem_);
  if (mem_block != nullptr) {
    auto mem_buf = FindMemBufByKeepAddr(device_addr, mem_block);
    if (is_addr_in_membuf(device_addr, mem_buf)) {
      return std::make_tuple(mem_block, mem_buf, persistent_mem_);
    }
  }

  return std::make_tuple(nullptr, nullptr, common_mem_);
}

void DynamicMemPoolBestFit::ReleaseDeviceRes() {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  DumpDynamicMemPoolStateInfo();

  auto fn = [this](const MemStatusManagerPtr &mem_mng) {
    MS_EXCEPTION_IF_NULL(mem_mng);
    for (auto &iter : mem_mng->mem_block_list_) {
      MS_EXCEPTION_IF_NULL(iter);
      auto &device_addr = iter->device_addr_base_;
      if (device_addr != nullptr) {
        if (!FreeDeviceMem(device_addr)) {
          MS_LOG(ERROR) << "Free device memory[" << device_addr << "] error.";
        }
        device_addr = nullptr;
      }
    }
    mem_mng->Clear();
  };
  fn(common_mem_);
  fn(persistent_mem_);

  tracker::MemTrackerManager::GetInstance().Dump();
}

void DynamicMemPoolBestFit::DumpDynamicMemPoolStateInfo() {
  size_t total_used_size_list[kAllocatorTypeNum] = {0};
  static bool is_enable_memory_statistics = common::IsEnableRuntimeConfig(common::kRuntimeMemoryStat) ||
                                            common::IsEnableRuntimeConfig(common::kRuntimeMemoryTrack);
  auto fn = [&](const MemStatusManagerPtr &mem_mng, const std::string &mem_type) {
    MS_EXCEPTION_IF_NULL(mem_mng);
    if (mem_mng->Empty()) {
      return;
    }

    std::ostringstream buf;
    for (size_t i = 0; i < mem_mng->mem_block_list_.size(); ++i) {
      size_t mem_block_used_size = 0;
      MS_EXCEPTION_IF_NULL(mem_mng->mem_block_list_[i]);
      for (auto mb = mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.begin();
           mb != mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.end(); ++mb) {
        if (mb->second->status_ == DynamicMemBufStatus::kMemBufUsed) {
          mem_block_used_size += mb->second->size_;
          MS_EXCEPTION_IF_CHECK_FAIL((static_cast<int>(mb->second->allocator_type_) < kAllocatorTypeNum),
                                     "Allocator type is out of range.");
          total_used_size_list[static_cast<int>(mb->second->allocator_type_)] += mb->second->size_;
        }
      }
      buf << ", block[" << i << "] stream id:" << mem_mng->mem_block_list_[i]->stream_id_
          << " block size:" << mem_mng->mem_block_list_[i]->mem_block_size_ / kMBToByte
          << "M idle size:" << (mem_mng->mem_block_list_[i]->mem_block_size_ - mem_block_used_size) / kMBToByte
          << "M actual size: " << (mem_mng->mem_block_list_[i]->get_actual_peak()) / kMBToByte << "M.";
    }
    std::ostringstream oss_buf;
    // Dump all the memory buf info.
    oss_buf << mem_type << " pool info: Total allocated mem:" << mem_mng->mps_.total_mem_size_ / kMBToByte
            << "M, peak used mem:" << mem_mng->mps_.used_mem_peak_size_ / kMBToByte
            << "M, in used mem:" << mem_mng->mps_.total_used_mem_size_ / kMBToByte
            << "M, total use by event mem:" << mem_mng->mps_.total_used_by_event_mem_size_ / kMBToByte
            << "M, total idle mem:" << mem_mng->mps_.total_idle_mem_size_ / kMBToByte
            << "M. Block unit size:" << mem_mng->unit_size_ / kMBToByte
            << "M, block counts:" << mem_mng->mem_block_list_.size() << buf.str();
    if (is_enable_memory_statistics) {
      std::cout << "[MS_RUNTIME_PROF]" << oss_buf.str() << std::endl;
    }
    MS_LOG(INFO) << oss_buf.str();
  };

  fn(common_mem_, std::string(kCommonMem));
  fn(persistent_mem_, std::string(kPersistentParamMem));
  std::ostringstream oss_mem;
  oss_mem << "The dynamic memory pool total allocated mem:" << TotalMemStatistics() / kMBToByte
          << "M, min addr :" << GetMinUsingMemoryAddr()
          << ", max addr: " << (mem_bufs_.empty() ? nullptr : *(--mem_bufs_.end()))
          << ", peak used mem:" << UsedMemPeakStatistics() / kMBToByte
          << "M, actual peak used mem:" << ActualPeakStatistics() / kMBToByte
          << "M, in used mem:" << TotalUsedMemStatistics() / kMBToByte
          << "M, total used by event mem:" << TotalUsedByEventMemStatistics() / kMBToByte
          << "M, total idle mem:" << TotalIdleMemStatistics() / kMBToByte
          << "M, total eager free mem:" << TotalEagerFreeMemStatistics() / kMBToByte
          << "M. Weight used size:" << total_used_size_list[static_cast<int>(AllocatorType::kWeight)] / kMBToByte
          << "M, constant value used size:"
          << total_used_size_list[static_cast<int>(AllocatorType::kConstantValue)] / kMBToByte
          << "M, kernel output used size:"
          << total_used_size_list[static_cast<int>(AllocatorType::kKernelOutput)] / kMBToByte
          << "M, other used size:" << total_used_size_list[static_cast<int>(AllocatorType::kOther)] / kMBToByte << "M.";
  if (is_enable_memory_statistics) {
    std::cout << "[MS_RUNTIME_PROF]" << oss_mem.str() << std::endl;
  }
  MS_LOG(INFO) << oss_mem.str();
}

void DynamicMemPoolBestFit::DumpDynamicMemPoolDebugInfo() {
  auto fn = [](const MemStatusManagerPtr &mem_mng, const std::string &mem_type) {
    const auto &device_state = mem_mng->DumpMemBlockDebugInfo(mem_type);
    const auto &stream_ids = mem_mng->GetStreamIds();
    // Dump all the idle memory buf info.
    size_t total_idle_mem_in_mem_mng = 0;
    MS_LOG(WARNING) << mem_type << " all idle_mem_bufs info: counts[" << stream_ids.size() << "].";
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufIdle);
      const auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      const auto &mem_buf_map = iter->second;
      MS_LOG(WARNING) << " stream id : " << stream_id << ", idle mem buf info : count[" << mem_buf_map.size() << "].";
      for (auto &&idle_iter = mem_buf_map.begin(); idle_iter != mem_buf_map.end(); idle_iter++) {
        auto &mem_buf = idle_iter->second;
        MS_EXCEPTION_IF_NULL(mem_buf);
        total_idle_mem_in_mem_mng += mem_buf->size_;
        MS_LOG(INFO) << " Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_
                     << "] status[" << kBufStatusString.at(mem_buf->status_) << "] stream id[" << mem_buf->stream_id_
                     << "].";
      }
    }
    // Dump all the eager free memory buf info.
    size_t total_eager_free_mem_in_mem_mng = 0;
    MS_LOG(WARNING) << mem_type << " all eager free mem_buf info: counts[" << stream_ids.size() << "].";
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufEagerFree);
      const auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      const auto &mem_buf_map = iter->second;
      MS_LOG(WARNING) << " stream id : " << stream_id << ", eager free mem buf info : count[" << mem_buf_map.size()
                      << "].";
      for (auto &&idle_iter = mem_buf_map.begin(); idle_iter != mem_buf_map.end(); idle_iter++) {
        auto &mem_buf = idle_iter->second;
        MS_EXCEPTION_IF_NULL(mem_buf);
        total_eager_free_mem_in_mem_mng += mem_buf->size_;
        MS_LOG(INFO) << " Eager free mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_
                     << "] status[" << kBufStatusString.at(mem_buf->status_) << "] stream id[" << mem_buf->stream_id_
                     << "].";
      }
    }
    // Dump the memory statistical info.
    MS_LOG(WARNING) << mem_type << " total allocated memory[" << device_state.total_mem_size_ << "], used memory["
                    << device_state.total_used_mem_size_ << "], used by event memory["
                    << device_state.total_used_by_event_mem_size_ << "], idle memory["
                    << device_state.total_idle_mem_size_ << "].";
    if (device_state.total_idle_mem_size_ != total_idle_mem_in_mem_mng) {
      MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal to the global idle memory.";
    }
    if (device_state.total_used_by_event_mem_size_ != mem_mng->mps_.total_used_by_event_mem_size_) {
      MS_LOG(ERROR) << "Check error: the used by event memory in the mem_block is not equal to the global used by "
                       "event memory.";
    }
    if (device_state.total_eager_free_mem_size_ != total_eager_free_mem_in_mem_mng) {
      MS_LOG(ERROR) << "Check error: the eager free memory in the mem_block is not equal to the global eager free "
                       "memory.";
    }
    if (device_state.total_mem_size_ != device_state.total_used_mem_size_ + device_state.total_used_by_event_mem_size_ +
                                          device_state.total_idle_mem_size_ + device_state.total_eager_free_mem_size_) {
      MS_LOG(ERROR) << "Check error: the total memory : " << device_state.total_mem_size_
                    << " is not equal to the sum of used memory : " << device_state.total_used_mem_size_
                    << ", used by event memory : " << device_state.total_used_by_event_mem_size_
                    << ", idle memory : " << device_state.total_idle_mem_size_
                    << " and eager free memory : " << device_state.total_eager_free_mem_size_ << ".";
    }
  };

  MS_LOG(WARNING) << "Start dump dynamic memory pool debug info.";
  fn(common_mem_, std::string(kCommonMem));
  fn(persistent_mem_, std::string(kPersistentParamMem));
  MS_LOG(WARNING) << "Finish dump dynamic memory pool debug info.";
}

// Element in vector : memory_stream_id, address
bool DynamicMemPoolBestFit::RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id,
                                        const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses,
                                        const DeviceEventPtr &event) {
  MS_LOG(DEBUG) << "Record event, task_id_on_stream : " << task_id_on_stream
                << ", user_stream_id : " << user_stream_id
                << ", memory_stream_addresses size : " << memory_stream_addresses.size() << ", event : " << event.get()
                << ".";
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  for (auto &[memory_stream_id, address] : memory_stream_addresses) {
    auto &&mem_buf_tuple = FindByStrictAddr(address);
    auto mem_block = std::get<0>(mem_buf_tuple);
    // The output of a somas sub graph may be used by an inner node of the somas sub graph, so the address may not be
    // kept in the mem pool.
    if (mem_block == nullptr) {
      MS_LOG(DEBUG) << "Can't find memblock by address in memory pool.";
      continue;
    }
    auto mem_buf = (std::get<1>(mem_buf_tuple))->second;
    (void)mem_buf->RecordEvent(task_id_on_stream, user_stream_id, event);
    (void)stream_pair_addresses_[std::make_pair(user_stream_id, memory_stream_id)].emplace(mem_buf);
  }
  return true;
}

bool DynamicMemPoolBestFit::WaitEvent(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  auto key = std::make_pair(user_stream_id, memory_stream_id);
  auto iter = stream_pair_addresses_.find(key);
  if (iter == stream_pair_addresses_.end()) {
    return false;
  }

  // Iterate over a copy, since elements may be erased from the original set below.
  auto addresses = iter->second;
  for (const auto &address : addresses) {
    address->WaitEvent(task_id_on_stream, user_stream_id);
    // Remove the event and try to free memory.
    if (address->IsEventNotUsed()) {
      iter->second.erase(address);
      if (address->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
        FreeTensorMemInner(address->device_addr_);
      }
    }
  }
  MS_LOG(DEBUG) << "After release, bounded addresses size : " << iter->second.size()
                << ", used by event size : " << TotalUsedByEventMemStatistics() << ".";
  return true;
}

1148 // WaitEvent is called before sync stream, so performance may not be the issue.
WaitEvent(int64_t task_id_on_stream,uint32_t memory_stream_id)1149 bool DynamicMemPoolBestFit::WaitEvent(int64_t task_id_on_stream, uint32_t memory_stream_id) {
1150 #ifdef __APPLE__
1151 std::lock_guard<SpinLock> spin_lock(spin_lock_);
1152 #else
1153 std::lock_guard<std::mutex> locker(mutex_);
1154 #endif
1155 for (auto &stream_pair_addresses : stream_pair_addresses_) {
1156 const auto &[user_stream, memory_stream] = stream_pair_addresses.first;
1157 if (memory_stream != memory_stream_id) {
1158 continue;
1159 }
1160 auto addresses = stream_pair_addresses.second;
1161 for (const auto &address : addresses) {
1162 address->WaitEvent(task_id_on_stream, user_stream);
1163 // Remove event and try to free memory.
1164 if (address->IsEventNotUsed()) {
1165 stream_pair_addresses.second.erase(address);
1166 if (address->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
1167 FreeTensorMemInner(address->device_addr_);
1168 }
1169 }
1170 }
1171 }
1172 MS_LOG(DEBUG) << "After release events, task_id_on_stream : " << task_id_on_stream
1173 << ", memory_stream_id : " << memory_stream_id
1174 << ", used by event size : " << TotalUsedByEventMemStatistics() << ".";
1175 return true;
1176 }
1177
SyncAllEvents()1178 bool DynamicMemPoolBestFit::SyncAllEvents() {
1179 #ifdef __APPLE__
1180 std::lock_guard<SpinLock> spin_lock(spin_lock_);
1181 #else
1182 std::lock_guard<std::mutex> locker(mutex_);
1183 #endif
1184 return SyncAllEventsInner();
1185 }
1186
bool DynamicMemPoolBestFit::SyncAllEventsInner() {
  MS_LOG(DEBUG) << "Sync all events, stream_pair_addresses_ size : " << stream_pair_addresses_.size() << ".";
  if (stream_pair_addresses_.empty()) {
    return false;
  }

  // A mem buf may be bound to several stream pairs, so deduplicate before syncing.
  std::set<DynamicMemBufPtr> carry_event_addresses;
  for (const auto &stream_pair_address : stream_pair_addresses_) {
    for (const auto &address : stream_pair_address.second) {
      (void)carry_event_addresses.emplace(address);
    }
  }
  for (auto &address : carry_event_addresses) {
    if (address->SyncAllEvents() && address->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
      FreeTensorMemInner(address->device_addr_);
    }
  }

  stream_pair_addresses_.clear();
  return true;
}

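// Collects per-block info (size and stream id), keyed by the block base address, for the statistics interfaces below.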
std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
DynamicMemPoolBestFit::ExtractBlocksListInfo(const MemStatusManagerPtr &mem_mng) const {
  std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> blocks_list_info;
  for (const auto &mem_block : mem_mng->mem_block_list_) {
    std::unordered_map<std::string, size_t> block_info;
    block_info[kBlockMemorySize] = mem_block->size();
    block_info[kBlockStreamId] = mem_block->stream_id_;
    // The map key type is DeviceMemPtr (void *), so the block base address is used directly.
    blocks_list_info[mem_block->device_addr()] = block_info;
  }
  return blocks_list_info;
}

// The statistics information.
size_t DynamicMemPoolBestFit::TotalMemStatistics() const {
  return common_mem_->mps_.total_mem_size_ + persistent_mem_->mps_.total_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalUsedMemStatistics() const {
  return common_mem_->mps_.total_used_mem_size_ + persistent_mem_->mps_.total_used_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalUsedByEventMemStatistics() const {
  return common_mem_->mps_.total_used_by_event_mem_size_ + persistent_mem_->mps_.total_used_by_event_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalIdleMemStatistics() const {
  return common_mem_->mps_.total_idle_mem_size_ + persistent_mem_->mps_.total_idle_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalEagerFreeMemStatistics() const {
  return common_mem_->mps_.total_eager_free_mem_size_ + persistent_mem_->mps_.total_eager_free_mem_size_;
}
size_t DynamicMemPoolBestFit::UsedMemPeakStatistics() const {
  return common_mem_->mps_.used_mem_peak_size_ + persistent_mem_->mps_.used_mem_peak_size_;
}
size_t DynamicMemPoolBestFit::MaxMemAllocatedStatistics() const {
  return common_mem_->mps_.temp_used_mem_peak_size_ + persistent_mem_->mps_.temp_used_mem_peak_size_;
}
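// Growth of reserved memory since the last ResetMaxMemReserved call, which snapshots total_mem_size_ into
// temp_total_mem_size_.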
size_t DynamicMemPoolBestFit::MaxMemReservedStatistics() const {
  return common_mem_->mps_.total_mem_size_ + persistent_mem_->mps_.total_mem_size_ -
         common_mem_->mps_.temp_total_mem_size_ - persistent_mem_->mps_.temp_total_mem_size_;
}
size_t DynamicMemPoolBestFit::ActualPeakStatistics() const {
  return common_mem_->CalActualPeak() + persistent_mem_->CalActualPeak();
}
std::unordered_map<std::string, std::size_t> DynamicMemPoolBestFit::BlockCountsStatistics() const {
  size_t common_mem_block_counts = common_mem_->mem_block_list_.size();
  size_t persistent_mem_block_counts = persistent_mem_->mem_block_list_.size();
  std::unordered_map<std::string, std::size_t> block_count_stats;
  block_count_stats[kCommonMemPoolType] = common_mem_block_counts;
  block_count_stats[kPersistentMemPoolType] = persistent_mem_block_counts;
  return block_count_stats;
}
std::unordered_map<std::string, std::size_t> DynamicMemPoolBestFit::BlockUnitSizeStatistics() const {
  size_t common_mem_block_unit_size = common_mem_->unit_size_;
  size_t persistent_mem_block_unit_size = persistent_mem_->unit_size_;
  std::unordered_map<std::string, std::size_t> block_unit_size_stats;
  block_unit_size_stats[kCommonMemPoolType] = common_mem_block_unit_size;
  block_unit_size_stats[kPersistentMemPoolType] = persistent_mem_block_unit_size;
  return block_unit_size_stats;
}
std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
DynamicMemPoolBestFit::CommonMemBlocksInfoStatistics() const {
  return ExtractBlocksListInfo(common_mem_);
}
std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
DynamicMemPoolBestFit::PersistentMemBlocksInfoStatistics() const {
  return ExtractBlocksListInfo(persistent_mem_);
}
void DynamicMemPoolBestFit::ResetMaxMemReserved() const {
  common_mem_->mps_.temp_total_mem_size_ = common_mem_->mps_.total_mem_size_;
  persistent_mem_->mps_.temp_total_mem_size_ = persistent_mem_->mps_.total_mem_size_;
}
void DynamicMemPoolBestFit::ResetMaxMemAllocated() const {
  common_mem_->mps_.temp_total_used_mem_size_ = common_mem_->mps_.total_used_mem_size_;
  persistent_mem_->mps_.temp_total_used_mem_size_ = persistent_mem_->mps_.total_used_mem_size_;
  common_mem_->mps_.temp_total_used_by_event_mem_size_ = common_mem_->mps_.total_used_by_event_mem_size_;
  persistent_mem_->mps_.temp_total_used_by_event_mem_size_ = persistent_mem_->mps_.total_used_by_event_mem_size_;
  common_mem_->mps_.temp_used_mem_peak_size_ = 0;
  persistent_mem_->mps_.temp_used_mem_peak_size_ = 0;
}

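// Every block except the most recently added one is counted in full; the newest block only contributes the address
// range it has actually handed out (see DynamicMemBlock::get_actual_peak).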
size_t MemStatusManager::CalActualPeak() {
  if (mem_block_insertion_order_.empty()) {
    return 0;
  }
  size_t actual_peak = total_block_size_;
  const auto &end_block = mem_block_insertion_order_.back();
  MS_EXCEPTION_IF_NULL(end_block);
  actual_peak -= end_block->size();
  actual_peak += end_block->get_actual_peak();
  return actual_peak;
}

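// Appends (task_id_on_stream, event) to the per-user-stream event list. Task ids on a stream are assumed to be
// monotonically increasing, which is what allows WaitEvent below to pop finished events from the front.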
bool DynamicMemBuf::RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id, const DeviceEventPtr &event) {
  MS_EXCEPTION_IF_NULL(event);
  if (events_ == nullptr) {
    // Allocate the event map lazily: most mem bufs never carry events.
    events_ = std::make_shared<std::unordered_map<uint32_t, std::shared_ptr<std::list<TaskIdOnStreamEvent>>>>();
  }
  std::shared_ptr<std::list<TaskIdOnStreamEvent>> event_list = nullptr;
  auto iter = events_->find(user_stream_id);
  if (iter == events_->end()) {
    event_list = std::make_shared<std::list<TaskIdOnStreamEvent>>();
    (void)events_->emplace(user_stream_id, event_list);
  } else {
    event_list = iter->second;
    MS_EXCEPTION_IF_NULL(event_list);
  }
  (void)event_list->emplace_back(task_id_on_stream, event);
  return true;
}

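// Drops every event recorded for user_stream_id whose task id is not greater than task_id_on_stream. Returns false
// when nothing was recorded for that stream.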
bool DynamicMemBuf::WaitEvent(int64_t task_id_on_stream, uint32_t user_stream_id) {
  if (events_ == nullptr) {
    return false;
  }
  auto iter = events_->find(user_stream_id);
  if (iter == events_->end()) {
    return false;
  }
  auto &event_list = iter->second;
  MS_EXCEPTION_IF_NULL(event_list);
  // Pop all elements in the list whose task id is not greater than task_id_on_stream.
  while (!event_list->empty() && event_list->front().first <= task_id_on_stream) {
    event_list->pop_front();
  }
  // Remove the list from the map if it has become empty.
  if (event_list->empty()) {
    events_->erase(iter);
  }
  return true;
}

bool DynamicMemBuf::IsEventNotUsed() { return events_ == nullptr ? true : events_->empty(); }

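// Blocks on every outstanding event of this mem buf and clears the bookkeeping. Returns false when the buf carries
// no events at all.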
bool DynamicMemBuf::SyncAllEvents() {
  if (IsEventNotUsed()) {
    return false;
  }

  for (auto iter = events_->begin(); iter != events_->end();) {
    auto &event_list = iter->second;
    MS_EXCEPTION_IF_NULL(event_list);
    for (auto list_iter = event_list->begin(); list_iter != event_list->end();) {
      auto &event = list_iter->second;
      // Sync the event if it has not arrived yet.
      if (!event->QueryEvent()) {
        event->SyncEvent();
      }
      list_iter = event_list->erase(list_iter);
    }
    if (event_list->empty()) {
      // The list is empty, erase it from the map.
      iter = events_->erase(iter);
    } else {
      // The inner loop erases every element, so a non-empty list here indicates a logic error.
      MS_LOG(INTERNAL_EXCEPTION) << "Event list is not empty.";
    }
  }
  return events_->empty();
}

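// A block is tracked three ways: per stream (mem_blocks_), sorted by device address (mem_block_list_), and in
// insertion order (mem_block_insertion_order_, consumed by CalActualPeak).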
void MemStatusManager::AddMemBlock(const DynamicMemBlockPtr &mem_block, uint32_t stream_id) {
  auto iter = mem_blocks_.find(stream_id);
  if (iter != mem_blocks_.end()) {
    DoAddMemBlock(mem_block, &iter->second);
  } else {
    (void)mem_blocks_.emplace(stream_id, std::vector<DynamicMemBlockPtr>{mem_block});
  }

  DoAddMemBlock(mem_block, &mem_block_list_);
  mem_block_insertion_order_.emplace_back(mem_block);
  total_block_size_ += mem_block->size();
}

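// Inserts the block while keeping the list sorted by device address, so address lookups can binary-search it.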
void MemStatusManager::DoAddMemBlock(const DynamicMemBlockPtr &mem_block,
                                     std::vector<DynamicMemBlockPtr> *mem_block_list) {
  auto iter = std::upper_bound(mem_block_list->begin(), mem_block_list->end(), mem_block->device_addr(),
                               [](const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
                                 return device_addr < mem_block->device_addr();
                               });
  (void)mem_block_list->insert(iter, mem_block);
}

SizeMapMemBuf &MemStatusManager::GetOrCreateMemBufMap(uint32_t stream_id, DynamicMemBufStatus status) {
  return mem_bufs_[std::make_pair(stream_id, status)];
}

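// Bufs are indexed by (stream id, status) and then by size in a size-ordered multimap, so a best-fit search can
// take the smallest buf that is large enough via a lower_bound on the requested size.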
void MemStatusManager::AddMemBuf(const DynamicMemBufPtr &mem_buf) {
  auto key = std::make_pair(mem_buf->stream_id_, mem_buf->status_);
  auto &mem_buf_map = mem_bufs_[key];
  (void)mem_buf_map.emplace(mem_buf->size_, mem_buf);
}

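// Several bufs may share one size, so equal_range narrows by size and the device address identifies the exact
// entry; failing to find it indicates corrupted bookkeeping, hence the internal exception.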
void MemStatusManager::RemoveMemBuf(const DynamicMemBufPtr &mem_buf) {
  auto key = std::make_pair(mem_buf->stream_id_, mem_buf->status_);
  auto &mem_buf_map = mem_bufs_[key];
  auto &&iter = mem_buf_map.equal_range(mem_buf->size_);
  while (iter.first != iter.second) {
    if (iter.first->second->device_addr_ == mem_buf->device_addr_) {
      (void)mem_buf_map.erase(iter.first);
      return;
    }
    (void)iter.first++;
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Remove mem buf failed, address : " << mem_buf->device_addr_ << ".";
}

void MemStatusManager::Clear() noexcept {
  mem_blocks_.clear();
  mem_block_list_.clear();
  // Reset the insertion-order bookkeeping as well, so CalActualPeak cannot read stale blocks after a clear.
  mem_block_insertion_order_.clear();
  total_block_size_ = 0;
  mem_bufs_.clear();
}

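// Walks every block and every buf, logging the layout and aggregating per-status sizes into a DeviceState snapshot.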
const DeviceState MemStatusManager::DumpMemBlockDebugInfo(const std::string &mem_type) {
  DeviceState device_state;
  // Dump the memory block info and memory buf info.
  MS_LOG(WARNING) << mem_type << " all mem_block info: counts[" << mem_block_list_.size() << "].";
  for (auto iter = mem_block_list_.begin(); iter != mem_block_list_.end(); ++iter) {
    device_state.total_mem_size_ += (*iter)->size();
    // Take a reference to avoid copying the whole mem buf map of the block.
    const auto &mem_buf_map = (*iter)->block_all_mem_buf_map_;
    MS_LOG(WARNING) << " MemBlock info: number[" << iter - mem_block_list_.begin() << "] mem_buf_counts["
                    << mem_buf_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
                    << (*iter)->size() << "] stream id[" << (*iter)->stream_id_ << "].";
    for (auto iter_mem_buf = mem_buf_map.begin(); iter_mem_buf != mem_buf_map.end(); ++iter_mem_buf) {
      auto mem_buf = iter_mem_buf->second;
      MS_EXCEPTION_IF_NULL(mem_buf);
      if (mem_buf->status_ == DynamicMemBufStatus::kMemBufIdle) {
        device_state.total_idle_mem_size_ += mem_buf->size_;
      } else if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsed) {
        device_state.total_used_mem_size_ += mem_buf->size_;
      } else if (mem_buf->status_ == DynamicMemBufStatus::kMemBufEagerFree) {
        device_state.total_eager_free_mem_size_ += mem_buf->size_;
      } else if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
        device_state.total_used_by_event_mem_size_ += mem_buf->size_;
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "Unknown mem buf status : " << mem_buf->status_ << ".";
      }
      MS_LOG(INFO) << " MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
                   << kBufStatusString.at(mem_buf->status_) << "] name["
                   << (mem_buf->allocator_name_.empty() ? "Unknown" : mem_buf->allocator_name_) << "] type["
                   << kAllocatorTypeString.at(mem_buf->allocator_type_) << "] stream id[" << mem_buf->stream_id_
                   << "].";
    }
  }
  return device_state;
}
}  // namespace device
}  // namespace mindspore