/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/backend/mem_reuse/mem_dynamic_allocator.h"
#include <algorithm>
#include <numeric>
#include <ostream>
#include <utility>
#include <string>
#include "include/backend/mem_reuse/mem_tracker.h"
#include "include/common/utils/utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/convert_utils_base.h"
#include "utils/ms_utils.h"
#ifdef ENABLE_DEBUGGER
#include "plugin/device/cpu/hal/profiler/cpu_profiling.h"
#endif

namespace mindspore {
namespace device {
static const char kPersistentParamMem[] = "Persistent mem";
static const char kCommonMem[] = "Common mem";
constexpr size_t kGBToByte = 1024 << 20;
// The smallest memory request size; if a request is smaller than this, the device memory request may fail.
// Set by experience to 10 MB.
const size_t kMinimumAllocMem = 10 << 20;
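
// Note on the shift arithmetic above: 1024 << 20 == 1 << 30 == 1 GiB in bytes,
// and 10 << 20 == 10 MiB in bytes.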

thread_local AllocatorDebugInfo DynamicMemAllocatorDebugInfo::debug_info_;

const char kBlockMemorySize[] = "block_memory_size";
const char kBlockStreamId[] = "block_stream_id";
const char kCommonMemPoolType[] = "common_mem_pool";
const char kPersistentMemPoolType[] = "persistent_mem_pool";

static const std::map<DynamicMemBufStatus, std::string> kBufStatusString = {
  {DynamicMemBufStatus::kMemBufIdle, "idle"},
  {DynamicMemBufStatus::kMemBufUsed, "used"},
  {DynamicMemBufStatus::kMemBufEagerFree, "eager_free"},
  {DynamicMemBufStatus::kMemBufUsedByEvent, "used_by_event"}};

static const std::map<AllocatorType, std::string> kAllocatorTypeString = {
  {AllocatorType::kWeight, "weight"},
  {AllocatorType::kConstantValue, "constant value"},
  {AllocatorType::kKernelOutput, "kernel output"},
  {AllocatorType::kGraphOutput, "graph output"},
  {AllocatorType::kWorkspace, "workspace"},
  {AllocatorType::kOther, "other"},
};

DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
  persistent_mem_->Clear();
  common_mem_->Clear();
  stream_pair_addresses_.clear();
}

void DynamicMemBlock::update_border_addr(DeviceMemPtr left_addr, DeviceMemPtr right_addr) {
  if (min_addr_ == nullptr) {
    min_addr_ = left_addr;
  } else {
    min_addr_ = std::min(min_addr_, left_addr);
  }
  if (max_addr_ == nullptr) {
    max_addr_ = right_addr;
  } else {
    max_addr_ = std::max(max_addr_, right_addr);
  }
}

size_t DynamicMemBlock::get_actual_peak() {
  if (min_addr_ == nullptr || max_addr_ == nullptr) {
    return 0;
  }
  int64_t actual_memory = reinterpret_cast<uint8_t *>(max_addr_) - reinterpret_cast<uint8_t *>(min_addr_);
  return static_cast<size_t>(actual_memory);
}

DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size, bool from_persistent_mem, bool need_recycle,
                                                   uint32_t stream_id) {
  if (stream_id == UINT32_MAX) {
    MS_LOG(DEBUG) << "Rewrite stream id from UINT32_MAX to the default stream index.";
    stream_id = kDefaultStreamIndex;
  }
  size_t align_size = AlignMemorySize(size);
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  // Find a memory buf by tensor size; if none is found, add a new memory block and memory buf.
  DeviceMemPtr device_addr = FindAvailableMemBuf(align_size, from_persistent_mem, stream_id);
  static bool init_recycle_memory = false;
  if (need_recycle && !init_recycle_memory) {
    // Force persistent memory to be reserved when recycle memory is allocated for the first time.
    init_recycle_memory = true;
    MS_LOG(INFO) << "Init recycle memory.";
    device_addr = nullptr;
  }
  if (device_addr == nullptr) {
    device_addr = AddMemBlockAndMemBuf(align_size, from_persistent_mem, need_recycle, stream_id);

    if (device_addr == nullptr) {
      MS_LOG(INFO) << "Alloc tensor mem failed; try to sync all events to release memory.";
      SyncAllEventsInner();
      device_addr = FindAvailableMemBuf(align_size, from_persistent_mem, stream_id);
    }

    // Alloc memory failed; dump the info.
    if (!device_addr) {
      DumpDynamicMemPoolStateInfo();
    }
  }

// Report memory data to the profiler.
#ifdef ENABLE_DEBUGGER
  static auto profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
  MS_EXCEPTION_IF_NULL(profiler_inst);
  if (profiler_inst->GetEnableFlag() && profiler_inst->GetProfileMemoryFlag()) {
    profiler_inst->RecordMemoryPoolInfo(TotalUsedMemStatistics(), TotalMemStatistics(),
                                        TotalUsedByEventMemStatistics());
  }
#endif

  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, Memory pool alloc, total mem: " << TotalMemStatistics()
                    << ", peak mem: " << UsedMemPeakStatistics() << ", in use mem: " << TotalUsedMemStatistics()
                    << ", used by event mem: " << TotalUsedByEventMemStatistics()
                    << ", device address addr: " << device_addr << ", size: " << size
                    << ", from persistent mem: " << from_persistent_mem << ", need recycle: " << need_recycle;
  }
  if (device_addr != nullptr) {
    if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
      device::tracker::CALL_MEMORY_TRACKER(AllocMemBlock, device_addr, align_size, GetMemoryPoolType(),
                                           ActualPeakStatistics(), TotalUsedMemStatistics(), TotalMemStatistics(),
                                           stream_id);
    }
    if (IsMemoryPoolRecycle()) {
      (void)mem_bufs_.insert(device_addr);
    }
  }
  MS_LOG(DEBUG) << "Alloc memory details, name:" << DynamicMemAllocatorDebugInfo::GetDebugInfo().name_
                << ", persistent_mem:" << from_persistent_mem << ", stream id: " << stream_id
                << ", address:" << device_addr << ", size:" << size << "B, total allocated mem:" << TotalMemStatistics()
                << "B, peak used mem:" << UsedMemPeakStatistics() << "B, in used mem:" << TotalUsedMemStatistics()
                << "B, used by event mem:" << TotalUsedByEventMemStatistics()
                << "B, actual peak used mem:" << ActualPeakStatistics()
                << "B, total idle mem:" << TotalIdleMemStatistics() << "B.";
  return device_addr;
}
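
// A hypothetical caller sketch (illustration only; the real call sites live in
// the device runtime, and a concrete pool subclass is assumed):
//   DynamicMemPoolBestFit *pool = ...;
//   DeviceMemPtr addr = pool->AllocTensorMem(/*size=*/4096, /*from_persistent_mem=*/false,
//                                            /*need_recycle=*/false, /*stream_id=*/0);
//   if (addr != nullptr) {
//     // ... launch kernels using addr ...
//     pool->FreeTensorMem(addr);
//   }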

std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(const std::vector<size_t> &size_list,
                                                                          uint32_t stream_id) {
  std::vector<DeviceMemPtr> device_addr_list;
  size_t total_size = std::accumulate(size_list.begin(), size_list.end(), IntToSize(0));
  // Pre-allocate one whole piece of memory.
  auto device_addr = AllocTensorMem(total_size, false, false, stream_id);
  if (!device_addr) {
    return device_addr_list;
  }
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  // Remove the pre-allocated memory.
  auto mem_block = FindMemBlock(device_addr, common_mem_);
  if (mem_block == nullptr) {
    mem_block = FindMemBlock(device_addr, persistent_mem_);
  }
  MS_EXCEPTION_IF_NULL(mem_block);
  const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
  if (iter == mem_block->block_all_mem_buf_map_.end()) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(INTERNAL_EXCEPTION) << "Can't find the device address[" << device_addr << "].";
  }
  auto mem_buf = iter->second;
  MS_EXCEPTION_IF_NULL(mem_buf);
  if (mem_buf->size_ < total_size) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "The size of membuf is less than total_size.";
  }
  auto rest_size = mem_buf->size_ - total_size;
  (void)mem_block->block_all_mem_buf_map_.erase(iter);
  // Split the pre-allocated memory into continuous memory by the size list.
  DynamicMemBufPtr continuous_mem_buf;
  auto buf_addr = device_addr;
  for (size_t i : size_list) {
    continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, DynamicMemBufStatus::kMemBufUsed, i, stream_id,
                                                         DynamicMemAllocatorDebugInfo::GetDebugInfo().name_,
                                                         DynamicMemAllocatorDebugInfo::GetDebugInfo().type_);
    MS_EXCEPTION_IF_NULL(continuous_mem_buf);
    (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf);
    mem_block->update_border_addr(mem_buf->device_addr_, AddressOffset(mem_buf->device_addr_, mem_buf->size_));
    device_addr_list.emplace_back(buf_addr);
    buf_addr = AddressOffset(buf_addr, i);
  }
  // Update the size of the last memory buf.
  continuous_mem_buf->size_ += rest_size;
  return device_addr_list;
}
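
// Illustrative example: for size_list {512, 1024, 512} one 2048-byte buf is
// allocated at some base address A, then used bufs are registered at A, A + 512
// and A + 1536; any rest_size left by the split is folded into the last buf.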

size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const {
  if (size == 0) {
    return kDynamicMemAlignSize;
  }
  return ((size + kDynamicMemAlignSize - 1) / kDynamicMemAlignSize) * kDynamicMemAlignSize;
}
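
// Worked example of the round-up above (assuming the header defines
// kDynamicMemAlignSize as 512):
//   AlignMemorySize(0)   -> 512 (a zero-sized request still takes one aligned unit)
//   AlignMemorySize(1)   -> 512
//   AlignMemorySize(512) -> 512
//   AlignMemorySize(513) -> 1024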

DeviceMemPtr DynamicMemPoolBestFit::FindAvailableMemBuf(size_t size, bool from_persistent_mem, uint32_t stream_id) {
  auto addr = FindMemBufByStatus(size, from_persistent_mem, DynamicMemBufStatus::kMemBufIdle, stream_id);
  if (addr == nullptr && is_trigger_eager_free_) {
    MS_LOG(DEBUG) << "Find idle mem buf failed and eager free is enabled, try to search in eager free bufs.";
    // Check the max used-memory limit, since the really occupied memory size equals the used size plus the idle size.
    // Eager free memory may still occupy some device memory, so total_mem_size needs to be multiplied by a factor.
    float threshold_factor = 0.8f;
    size_t threshold = IsEnableVmm() ? total_mem_size() : static_cast<size_t>(total_mem_size() * threshold_factor);
    if (TotalUsedMemStatistics() + TotalUsedByEventMemStatistics() + TotalIdleMemStatistics() + size <= threshold) {
      addr = FindMemBufByStatus(size, from_persistent_mem, DynamicMemBufStatus::kMemBufEagerFree, stream_id);
    }
  }
  return addr;
}

DeviceMemPtr DynamicMemPoolBestFit::FindMemBufByStatus(size_t size, bool from_persistent_mem,
                                                       DynamicMemBufStatus target_status, uint32_t stream_id) {
  auto addr = FindMemBufInSpecifiedMng(size, from_persistent_mem, target_status, stream_id);
  if (addr == nullptr && !IsEnableVmm()) {
    if (from_persistent_mem && !persistent_mem_->mem_block_list_.empty()) {
      MS_LOG(DEBUG) << "Finding a mem buf in the current pool failed, try to find one in the other pool.";
      addr = FindMemBufInSpecifiedMng(size, !from_persistent_mem, target_status, stream_id);
    }
  }
  return addr;
}

DeviceMemPtr DynamicMemPoolBestFit::FindMemBufInSpecifiedMng(size_t size, bool from_persistent_mem,
                                                             DynamicMemBufStatus target_status, uint32_t stream_id) {
  auto &mem_mng = from_persistent_mem ? persistent_mem_ : common_mem_;
  auto &mem_buf_map = mem_mng->GetOrCreateMemBufMap(stream_id, target_status);
  auto iter = mem_buf_map.lower_bound(size);
  if (iter != mem_buf_map.end()) {
    if (IsMemoryPoolRecycle()) {
      // Ensure that the addresses corresponding to the same tensor are consistent across steps, making the memory
      // pool recycling function more stable.
      auto find_size = iter->first;
      // Can be optimized in the future.
      auto [lb, ub] = mem_buf_map.equal_range(find_size);
      for (auto i = lb; i != ub; ++i) {
        if (i->second->device_addr_ > iter->second->device_addr_) {
          iter = i;
        }
      }
    }
    auto mem_buf = iter->second;
    MS_EXCEPTION_IF_NULL(mem_buf);
    if (mem_buf->status_ != target_status) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "Mem_buf is not " << target_status << ", alloc_size[" << size << "] mem_buf_size["
                        << mem_buf->size_ << "] mem_buf_address[" << mem_buf->device_addr_ << "].";
    }
    mem_buf->allocator_name_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().name_;
    mem_buf->allocator_type_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().type_;
    if (mem_buf->status_ == DynamicMemBufStatus::kMemBufEagerFree && IsEnableVmm()) {
      MS_LOG(DEBUG) << "Find eager free memory, mem_buf_size[" << mem_buf->size_ << "] mem_buf_address["
                    << mem_buf->device_addr_ << "], need size: " << size;
      auto ret = MmapDeviceMem(size, mem_buf->device_addr_);
      if (ret != size) {
        return nullptr;
      }
    }
    // Remove the old memory buf from the map.
    (void)mem_buf_map.erase(iter);
    // Divide the memory buf.
    if (IsSplit(size, mem_buf->size_)) {
      SplitMemBuf(size, mem_buf, mem_mng, stream_id);
    }
    auto mem_block = FindMemBlock(mem_buf->device_addr_, mem_mng);
    MS_EXCEPTION_IF_NULL(mem_block);
    mem_block->update_border_addr(mem_buf->device_addr_, AddressOffset(mem_buf->device_addr_, mem_buf->size_));
    mem_buf->status_ = DynamicMemBufStatus::kMemBufUsed;
    // Memory statistics.
    mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
    mem_mng->mps_.UpdatePeakSize();
    if (target_status == DynamicMemBufStatus::kMemBufIdle) {
      mem_mng->mps_.total_idle_mem_size_ -= mem_buf->size_;
    } else if (target_status == DynamicMemBufStatus::kMemBufEagerFree) {
      mem_mng->mps_.total_eager_free_mem_size_ -= mem_buf->size_;
    }
    return mem_buf->device_addr_;
  }
  return nullptr;
}
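
// Best-fit search sketch: the bufs of a given (stream id, status) pair are kept
// in a size-ordered multimap, so lower_bound(size) yields the smallest buf whose
// size is >= the request; splitting then returns the unused tail to the map.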

size_t DynamicMemPoolBestFit::MemAllocUnitSize(bool from_persistent_mem) const {
  return from_persistent_mem ? persistent_mem_->unit_size_ : common_mem_->unit_size_;
}

void DynamicMemPoolBestFit::SetMemAllocUintSize(size_t common_size, size_t persist_size) {
  persistent_mem_->unit_size_ = persist_size;
  common_mem_->unit_size_ = common_size;
  config_unit_size_ = common_size;
  MS_LOG(DEBUG) << "Set mem alloc unit size, common " << common_size << " persistent " << persist_size;
}

void *DynamicMemPoolBestFit::GetMinUsingMemoryAddr() const {
  if (mem_bufs_.empty()) {
    return nullptr;
  }
  return *(mem_bufs_.begin());
}

void DynamicMemPoolBestFit::SetMemPoolBlockSize(size_t available_device_mem_size) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  float mem_block_size = ms_context->get_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE);
  if (mem_block_size == kDefaultMempoolBlockSize) {
    return;
  }

  size_t config_size = FloatToSize(mem_block_size * kGBToByte);
  if (config_size > available_device_mem_size) {
    MS_LOG(WARNING) << "Memory pool block size " << config_size << " is bigger than the currently available maximum memory "
                    << available_device_mem_size << ", and the actual effective value will be "
                    << available_device_mem_size;
  }
  // Reserve 1 GB for persistent memory.
  if (available_device_mem_size > kGBToByte) {
    available_device_mem_size -= kGBToByte;
  }
  size_t real_block_size = std::min(config_size, available_device_mem_size);
  SetMemAllocUintSize(real_block_size);
}
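
// Illustrative arithmetic: with MS_CTX_MEMPOOL_BLOCK_SIZE set to 1.5, the config
// size becomes FloatToSize(1.5 * kGBToByte) = 1.5 * 2^30 = 1610612736 bytes, and
// the effective block size is clamped to the available device memory after the
// 1 GB persistent reserve is subtracted.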

DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size, bool from_persistent_mem, bool need_recycle,
                                                         uint32_t stream_id) {
  if (from_persistent_mem && !need_recycle && !persistent_mem_->Empty()) {
    from_persistent_mem = false;
  }

  // Try the eager free routine.
  if (IsEnableVmm() || is_trigger_eager_free_) {
    is_trigger_eager_free_ = true;
    return AddMemBlockAndMemBufByEagerFree(size, from_persistent_mem, stream_id);
  }

  size_t alloc_mem_size = CalMemBlockAllocSize(size, from_persistent_mem, need_recycle);
  MS_LOG(DEBUG) << "CalMemBlockAllocSize for size : " << size << ", alloc_mem_size : " << alloc_mem_size;
  if (alloc_mem_size == 0) {
    if (auto device_addr = FindAvailableMemBuf(size, !from_persistent_mem, stream_id)) {
      return device_addr;
    }
    if (IsEnableEagerFree()) {
      is_trigger_eager_free_ = true;
      return AddMemBlockAndMemBufByEagerFree(size, from_persistent_mem, stream_id);
    }
    return nullptr;
  }

  // Add a new memory block.
  DeviceMemPtr device_addr = nullptr;
  auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr);
  if (real_alloc_size < size) {
    MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size
                    << "].";
    return nullptr;
  }
  // If unit_size was changed by another function (not by the context), change it back.
  MS_EXCEPTION_IF_NULL(common_mem_);
  common_mem_->unit_size_ = config_unit_size_;

  return CreateMemBlockAndMemBuf(size, from_persistent_mem, device_addr, real_alloc_size,
                                 DynamicMemBufStatus::kMemBufIdle, stream_id);
}

DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBufByEagerFree(size_t size, bool from_persistent_mem,
                                                                    uint32_t stream_id) {
  // Check the max used-memory limit.
  if (TotalUsedMemStatistics() + TotalUsedByEventMemStatistics() + size > total_mem_size()) {
    MS_LOG(ERROR) << "TotalUsedMemStatistics : " << TotalUsedMemStatistics()
                  << " plus TotalUsedByEventMemStatistics : " << TotalUsedByEventMemStatistics()
                  << " plus alloc size : " << size << " is more than total mem size : " << total_mem_size() << ".";
    return nullptr;
  }

  MS_LOG(DEBUG) << "Try to eager free memory.";
  if (!SyncAllStreams()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Sync all streams failed.";
  }
  FreeIdleMemsByEagerFree();
  auto mem_addr = FindMemBufByStatus(size, from_persistent_mem, DynamicMemBufStatus::kMemBufEagerFree, stream_id);
  if (mem_addr != nullptr) {
    MS_LOG(DEBUG) << "Find eager free memory success, mem_addr : " << mem_addr << ".";
    return mem_addr;
  }

  auto alloc_size = std::max(size, static_cast<size_t>(total_mem_size()));
  MS_LOG(INFO) << "Try to alloc eager free mem block, size : " << size << ", alloc_size : " << alloc_size << ".";
  DeviceMemPtr device_addr = nullptr;
  auto real_alloc_size = AllocDeviceMemByEagerFree(alloc_size, &device_addr);
  if (real_alloc_size < alloc_size) {
    MS_LOG(ERROR) << "AllocDeviceMemByEagerFree failed, alloc_size : " << real_alloc_size << ".";
    return nullptr;
  }
  return CreateMemBlockAndMemBuf(size, from_persistent_mem, device_addr, real_alloc_size,
                                 DynamicMemBufStatus::kMemBufEagerFree, stream_id);
}

DeviceMemPtr DynamicMemPoolBestFit::CreateMemBlockAndMemBuf(size_t size, bool from_persistent_mem,
                                                            DeviceMemPtr source_addr, size_t source_size,
                                                            DynamicMemBufStatus mem_buf_status, uint32_t stream_id) {
  auto mem_block = std::make_shared<DynamicMemBlock>(source_addr, source_size, stream_id);
  auto mem_mng = from_persistent_mem ? persistent_mem_ : common_mem_;
  mem_mng->AddMemBlock(mem_block, stream_id);
  // Add a new memory buf.
  auto mem_buf = std::make_shared<DynamicMemBuf>(mem_block->device_addr(), mem_buf_status, mem_block->size(), stream_id,
                                                 DynamicMemAllocatorDebugInfo::GetDebugInfo().name_,
                                                 DynamicMemAllocatorDebugInfo::GetDebugInfo().type_);
  if (mem_buf->status_ == DynamicMemBufStatus::kMemBufEagerFree && IsEnableVmm()) {
    MS_LOG(DEBUG) << "Find eager free memory, mem_buf_size[" << mem_buf->size_ << "] mem_buf_address["
                  << mem_buf->device_addr_ << "], need size: " << size;
    auto ret = MmapDeviceMem(size, mem_buf->device_addr_);
    if (ret != size) {
      return nullptr;
    }
  }
  // Add the new memory buf to the block's map.
  (void)mem_block->block_all_mem_buf_map_.emplace(mem_block->device_addr(), mem_buf);
  // Split the memory buf.
  if (IsSplit(size, mem_buf->size_)) {
    SplitMemBuf(size, mem_buf, mem_mng, stream_id);
  }
  mem_block->update_border_addr(mem_buf->device_addr_, AddressOffset(mem_buf->device_addr_, mem_buf->size_));
  mem_buf->status_ = DynamicMemBufStatus::kMemBufUsed;
  // Memory statistics.
  mem_mng->mps_.total_mem_size_ += mem_block->size();
  mem_mng->mps_.total_used_mem_size_ += mem_buf->size_;
  mem_mng->mps_.UpdatePeakSize();
  if (mem_buf_status == DynamicMemBufStatus::kMemBufIdle) {
    mem_mng->mps_.total_idle_mem_size_ += source_size - mem_buf->size_;
  } else if (mem_buf_status == DynamicMemBufStatus::kMemBufEagerFree) {
    mem_mng->mps_.total_eager_free_mem_size_ += source_size - mem_buf->size_;
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "Unsupported mem_buf_status : " << mem_buf_status << ".";
  }
  MS_LOG(DEBUG) << "Usage: used size : " << TotalUsedMemStatistics()
                << ", used by event size : " << TotalUsedByEventMemStatistics()
                << ", idle size : " << TotalIdleMemStatistics()
                << ", eager free size : " << TotalEagerFreeMemStatistics() << ".";
  return mem_buf->device_addr_;
}

size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size, bool from_persistent_mem, bool) {
  auto device_free_mem_size = free_mem_size();
  if (device_free_mem_size < size && common::IsNeedProfileMemory()) {
    device_free_mem_size = size;
  }
  if (device_free_mem_size < size) {
    MS_LOG(INFO) << "Memory not enough: current free memory size[" << device_free_mem_size
                 << "] is smaller than required size[" << size << "].";
    return 0;
  }
  // The device memory is too small, which may cause the new allocation to fail.
  if (device_free_mem_size < kMinimumAllocMem) {
    MS_LOG(INFO) << "Device memory size [" << device_free_mem_size << "] is smaller than minimum alloc size ["
                 << kMinimumAllocMem << "].";
    return 0;
  }
  auto alloc_mem_size = MemAllocUnitSize(from_persistent_mem);
  // Grow at twice the alloc size.
  constexpr size_t kDouble = 2;
  while (alloc_mem_size < size) {
    alloc_mem_size = alloc_mem_size * kDouble;
  }
  alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
  return alloc_mem_size;
}
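
// Example of the doubling growth: with a unit size of 1 GB and a 3 GB request,
// alloc_mem_size doubles 1 GB -> 2 GB -> 4 GB and is then clamped to the free
// device memory, so at most min(4 GB, free) is requested for the new block.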

const size_t DynamicMemPoolBestFit::FreeIdleMemsByEagerFree() {
  eager_free_count_++;

  auto eager_free_mem_func = [&](MemStatusManagerPtr &mem_mng) {
    const auto &stream_ids = mem_mng->GetStreamIds();
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufIdle);
      auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      auto &mem_buf_map = iter->second;
      for (auto &size_mem_buf : mem_buf_map) {
        auto &mem_buf = size_mem_buf.second;
        // Renamed bindings to avoid shadowing the lambda's mem_mng and the outer iter.
        auto [mem_block, addr_iter, addr_mem_mng] = FindByStrictAddr(mem_buf->device_addr_);
        if (PreCombineMemBuf(mem_buf, addr_mem_mng)) {
          CombineMemBuf(mem_block, addr_iter, addr_mem_mng, DynamicMemBufStatus::kMemBufIdle,
                        DynamicMemBufStatus::kMemBufEagerFree);
        }
      }
      mem_mng->mem_bufs_.erase(iter);
    }
    // After freeing the idle memory, do the eager free.
    size_t free_size = 0;
    size_t real_free_size = 0;
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufEagerFree);
      auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      auto &mem_buf_map = iter->second;
      for (auto &size_mem_buf : mem_buf_map) {
        auto &mem_buf = size_mem_buf.second;
        free_size += mem_buf->size_;
        MS_LOG(DEBUG) << "Eager free address : " << mem_buf->device_addr_ << ".";
        real_free_size += FreeDeviceMemByEagerFree(mem_buf->device_addr_, mem_buf->size_);
      }
    }

    return std::make_pair(free_size, real_free_size);
  };

  const auto [persistent_free_size, persistent_real_free_size] = eager_free_mem_func(persistent_mem_);
  const auto [common_free_size, common_real_free_size] = eager_free_mem_func(common_mem_);
  auto free_size = persistent_free_size + common_free_size;
  auto real_free_size = persistent_real_free_size + common_real_free_size;
  static bool is_enable_memory_statistics = common::IsEnableRuntimeConfig(common::kRuntimeMemoryStat);
  if (is_enable_memory_statistics) {
    std::cout << "Total eager free memory : " << free_size << ", real free : " << real_free_size
              << ", not free size: " << (free_size - real_free_size) << "." << std::endl;
  }
  MS_LOG(INFO) << "Eager free count : " << eager_free_count_ << ", free memory : " << free_size
               << ", real free : " << real_free_size << ", not free size: " << (free_size - real_free_size) << ".";
  return real_free_size;
}

void DynamicMemPoolBestFit::DefragMemory() {
  MS_LOG(DEBUG) << "Start defrag memory.";
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif

  // The eager free count initializes to 0 and is increased by initializing the persistent pool and the common pool.
  if (eager_free_count_ <= 2L) {
    MS_LOG(DEBUG) << "Exit defrag memory since eager free has not been triggered beyond pool initialization.";
    return;
  }
  if (last_eager_free_count_ == eager_free_count_) {
    MS_LOG(DEBUG) << "Exit defrag memory since the last eager free count equals the current eager free count : "
                  << last_eager_free_count_ << ".";
    return;
  }

  MS_LOG(INFO) << "Try to defrag memory.";
  if (!SyncAllStreams()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Sync all streams failed.";
  }
  FreeIdleMemsByEagerFree();
  last_eager_free_count_ = eager_free_count_;
}

bool DynamicMemPoolBestFit::IsSplit(size_t tensor_size, size_t mem_buf_size) const {
  return mem_buf_size - tensor_size >= kDynamicMemAlignSize;
}

void DynamicMemPoolBestFit::SplitMemBuf(size_t size, const DynamicMemBufPtr &mem_buf,
                                        const MemStatusManagerPtr &mem_mng, uint32_t stream_id) {
  MS_EXCEPTION_IF_NULL(mem_buf);
  MS_EXCEPTION_IF_NULL(mem_mng);
  const auto &mem_block = FindMemBlock(mem_buf->device_addr_, mem_mng);
  MS_EXCEPTION_IF_NULL(mem_block);
  // Divide the new memory buf.
  if (mem_buf->size_ < size) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "The size of membuf is less than size.";
  }
  size_t newbuf_size = mem_buf->size_ - size;
  mem_buf->size_ = size;
  DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size);
  auto new_mem_buf = std::make_shared<DynamicMemBuf>(newbuf_addr, mem_buf->status_, newbuf_size, stream_id);
  // Add the new memory buf to the block's map.
  (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf);
  mem_mng->AddMemBuf(new_mem_buf);
}
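
// Split sketch (IsSplit/SplitMemBuf): a buf is split only when the remainder is
// at least kDynamicMemAlignSize, so carving `size` out of a larger buf yields
//   [ original buf ................................. ]
//   [ kept: size | new buf (same status): old - size ]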

bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
  MS_EXCEPTION_IF_NULL(device_addr);
  MS_EXCEPTION_IF_NULL(mem_block);
  return device_addr < mem_block->device_addr();
}

DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr,
                                                       const MemStatusManagerPtr &mem_mng) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  MS_EXCEPTION_IF_NULL(mem_mng);
  auto &&iter =
    std::upper_bound(mem_mng->mem_block_list_.begin(), mem_mng->mem_block_list_.end(), device_addr, CmpMemBlock);
  if (iter != mem_mng->mem_block_list_.begin()) {
    return *(--iter);
  }
  return nullptr;
}
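
// Lookup sketch: mem_block_list_ is sorted by base address, so upper_bound()
// finds the first block starting beyond device_addr, and stepping back one gives
// the candidate block whose base is <= device_addr (callers verify containment).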

void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr &device_addr) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  FreeTensorMemInner(device_addr);
}

void DynamicMemPoolBestFit::FreeTensorMemInner(const DeviceMemPtr &device_addr) {
  auto [mem_block, iter, mem_mng] = FindByStrictAddr(device_addr);
  if (mem_block == nullptr) {
    // The memory pool may be destroyed before the address is, so this is a normal case.
    MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
    return;
  }
  auto mem_buf = iter->second;
  MS_EXCEPTION_IF_NULL(mem_buf);
  if (PreCombineMemBuf(mem_buf, mem_mng)) {
    CombineMemBuf(mem_block, iter, mem_mng, mem_buf->status_, DynamicMemBufStatus::kMemBufIdle);
    if (IsMemoryPoolRecycle()) {
      (void)mem_bufs_.erase(device_addr);
    }
    MS_LOG(DEBUG) << "Free memory details, name:" << DynamicMemAllocatorDebugInfo::GetDebugInfo().name_
                  << ", address:" << device_addr << ", total allocated mem:" << TotalMemStatistics()
                  << "B, peak used mem:" << UsedMemPeakStatistics() << "B, in used mem:" << TotalUsedMemStatistics()
                  << "B, used by event mem:" << TotalUsedByEventMemStatistics()
                  << "B, actual peak used mem:" << ActualPeakStatistics()
                  << "B, total idle mem:" << TotalIdleMemStatistics() << "B.";
  }
}

// PreCombineMemBuf judges whether the mem buf can be combined or not.
// If there are no events recorded on the mem buf, it returns true so the mem buf can be released.
// If there are events recorded on the mem buf, it changes the status to kMemBufUsedByEvent and returns false.
// Note: before releasing a mem buf by event, make sure that the status of the mem buf is kMemBufUsedByEvent,
// or the wait event may release the mem buf incorrectly.
bool DynamicMemPoolBestFit::PreCombineMemBuf(const DynamicMemBufPtr &mem_buf, const MemStatusManagerPtr &mem_mng) {
  auto device_addr = mem_buf->device_addr_;
  if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsed && !mem_buf->IsEventNotUsed()) {
    mem_buf->status_ = DynamicMemBufStatus::kMemBufUsedByEvent;
    mem_mng->mps_.total_used_mem_size_ -= mem_buf->size_;
    mem_mng->mps_.total_used_by_event_mem_size_ += mem_buf->size_;
    MS_LOG(DEBUG) << "Combine mem buf exit since mem buf is used by event, device_addr : " << device_addr
                  << ", used by event mem size : " << mem_mng->mps_.total_used_by_event_mem_size_ << ".";
    return false;
  }

  if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsedByEvent && !mem_buf->IsEventNotUsed()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Combine mem buf failed as mem buf can not be freed, device_addr : " << device_addr
                               << ".";
  }

  MS_LOG(DEBUG) << "Pre combine mem buf address : " << mem_buf->device_addr_ << " success.";
  return true;
}
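
// Buf lifetime sketch around events: a used buf with unconsumed recorded events
// moves to kMemBufUsedByEvent instead of being freed; once all its events are
// consumed (IsEventNotUsed() becomes true), it can be combined and freed.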

void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block,
                                          const DeviceAddrMapMemBuf::iterator &iter, const MemStatusManagerPtr &mem_mng,
                                          DynamicMemBufStatus origin_status, DynamicMemBufStatus target_status) {
  const auto &mem_buf = iter->second;
  MS_LOG(DEBUG) << "Combine mem buf release mem buf, device_addr : " << mem_buf->device_addr_ << ".";

// Report memory data to the profiler.
#ifdef ENABLE_DEBUGGER
  static auto profiler_inst = profiler::cpu::CPUProfiler::GetInstance();
  MS_EXCEPTION_IF_NULL(profiler_inst);
  if (profiler_inst->GetEnableFlag() && profiler_inst->GetProfileMemoryFlag()) {
    profiler_inst->RecordMemoryPoolInfo(TotalUsedMemStatistics(), TotalMemStatistics(),
                                        TotalUsedByEventMemStatistics());
  }
#endif

  if (common::IsNeedProfileMemory()) {
    MS_LOG(WARNING) << "Need Profile Memory, Memory pool free, total mem: " << TotalMemStatistics()
                    << ", peak mem: " << UsedMemPeakStatistics() << ", in use mem: " << TotalUsedMemStatistics()
                    << ", used by event mem: " << TotalUsedByEventMemStatistics()
                    << ", device address addr: " << mem_buf->device_addr_ << ", size: " << mem_buf->size_;
  }
  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled() &&
      target_status == DynamicMemBufStatus::kMemBufIdle) {
    device::tracker::CALL_MEMORY_TRACKER(FreeMemBlock, mem_buf->device_addr_, TotalUsedMemStatistics(),
                                         TotalMemStatistics());
  }

  if (mem_buf->status_ != origin_status) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "Find the mem_buf status : " << mem_buf->status_
                      << " is not equal to origin status : " << origin_status << ", mem_buf_address["
                      << mem_buf->device_addr_ << "].";
  }
  mem_buf->status_ = target_status;
  if (origin_status == DynamicMemBufStatus::kMemBufUsed) {
    if (mem_mng->mps_.total_used_mem_size_ < mem_buf->size_) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "The total used mem size : " << mem_mng->mps_.total_used_mem_size_
                        << " is less than the size of membuf : " << mem_buf->size_ << ".";
    }
    mem_mng->mps_.total_used_mem_size_ -= mem_buf->size_;
  } else if (origin_status == DynamicMemBufStatus::kMemBufUsedByEvent) {
    if (mem_mng->mps_.total_used_by_event_mem_size_ < mem_buf->size_) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "The total used by event mem size : " << mem_mng->mps_.total_used_by_event_mem_size_
                        << " is less than the size of membuf : " << mem_buf->size_ << ".";
    }
    mem_mng->mps_.total_used_by_event_mem_size_ -= mem_buf->size_;
    MS_LOG(DEBUG) << "Combine mem buf for addr : " << mem_buf->device_addr_
                  << ", used by event mem size : " << mem_mng->mps_.total_used_by_event_mem_size_ << ".";
  } else if (origin_status == DynamicMemBufStatus::kMemBufIdle) {
    if (mem_mng->mps_.total_idle_mem_size_ < mem_buf->size_) {
      DumpDynamicMemPoolDebugInfo();
      MS_LOG(EXCEPTION) << "The total idle mem size : " << mem_mng->mps_.total_idle_mem_size_
                        << " is less than the size of membuf : " << mem_buf->size_ << ".";
    }
    mem_mng->mps_.total_idle_mem_size_ -= mem_buf->size_;
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "Unsupported origin status : " << origin_status << ".";
  }
  if (target_status == DynamicMemBufStatus::kMemBufIdle) {
    mem_mng->mps_.total_idle_mem_size_ += mem_buf->size_;
  } else if (target_status == DynamicMemBufStatus::kMemBufEagerFree) {
    mem_mng->mps_.total_eager_free_mem_size_ += mem_buf->size_;
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "Unsupported target status : " << target_status << ".";
  }
  // Combine backward (combine the next_mem_buf into mem_buf).
  auto next_iter = iter;
  (void)next_iter++;
  if (next_iter != mem_block->block_all_mem_buf_map_.end()) {
    auto next_mem_buf = next_iter->second;
    MS_EXCEPTION_IF_NULL(next_mem_buf);
    if (next_mem_buf->status_ == target_status) {
      mem_buf->size_ += next_mem_buf->size_;
      mem_mng->RemoveMemBuf(next_mem_buf);

      (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
    }
  }
  // Combine forward (combine mem_buf into prev_mem_buf).
  bool forward_combine = false;
  DynamicMemBufPtr prev_mem_buf;
  if (iter != mem_block->block_all_mem_buf_map_.begin()) {
    auto prev_iter = iter;
    (void)prev_iter--;
    prev_mem_buf = prev_iter->second;
    MS_EXCEPTION_IF_NULL(prev_mem_buf);
    if (prev_mem_buf->status_ == target_status) {
      mem_mng->RemoveMemBuf(prev_mem_buf);
      prev_mem_buf->size_ += mem_buf->size_;
      (void)mem_block->block_all_mem_buf_map_.erase(iter);
      forward_combine = true;
    }
  }

  if (forward_combine) {
    mem_mng->AddMemBuf(prev_mem_buf);
  } else {
    mem_mng->AddMemBuf(mem_buf);
  }
}
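
// Coalescing sketch: once a buf reaches target_status it merges with same-status
// neighbors, so the address map never keeps two adjacent bufs in that state:
//   [idle A][buf B being freed][idle C]  ->  [idle A + B + C]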

std::tuple<DynamicMemBlockPtr, DeviceAddrMapMemBuf::iterator, MemStatusManagerPtr>
DynamicMemPoolBestFit::FindByStrictAddr(const DeviceMemPtr &device_addr) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  // Find in the common pool.
  auto mem_block = FindMemBlock(device_addr, common_mem_);
  if (mem_block != nullptr) {
    const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
    if (iter != mem_block->block_all_mem_buf_map_.end()) {
      return std::make_tuple(mem_block, iter, common_mem_);
    }
  }

  // Find in the persistent pool.
  mem_block = FindMemBlock(device_addr, persistent_mem_);
  if (mem_block != nullptr) {
    const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
    if (iter != mem_block->block_all_mem_buf_map_.end()) {
      return std::make_tuple(mem_block, iter, persistent_mem_);
    }
  }

  // Not found: return a null block; callers must check it before using the iterator.
  DeviceAddrMapMemBuf empty_map;
  return std::make_tuple(nullptr, empty_map.end(), common_mem_);
}

void DynamicMemPoolBestFit::FreePartTensorMems(const std::vector<DeviceMemPtr> &free_addrs,
                                               const std::vector<DeviceMemPtr> &keep_addrs,
                                               const std::vector<size_t> &keep_addr_sizes) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif

  for (auto &free_addr : free_addrs) {
    FreeTensorMemInner(free_addr);
  }

  MS_EXCEPTION_IF_CHECK_FAIL((keep_addrs.size() == keep_addr_sizes.size()), "The keep addrs size is wrong.");
  for (size_t i = 0; i < keep_addrs.size(); ++i) {
    KeepTensorMemByAddr(keep_addrs[i], keep_addr_sizes[i]);
  }
}

void DynamicMemPoolBestFit::KeepTensorMemByAddr(const DeviceMemPtr &device_addr, size_t size) {
  MS_EXCEPTION_IF_NULL(device_addr);
  // Fetch the memblock and membuf by the device address.
  auto [mem_block, mem_buf, mem_mng] = FindByKeepAddr(device_addr);
  // Check the results before the tracker call below dereferences mem_block.
  MS_EXCEPTION_IF_NULL(mem_block);
  MS_EXCEPTION_IF_NULL(mem_buf);
  MS_EXCEPTION_IF_NULL(mem_mng);
  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    device::tracker::CALL_MEMORY_TRACKER(AllocMemBlock, device_addr, size, GetMemoryPoolType(), ActualPeakStatistics(),
                                         TotalUsedMemStatistics(), TotalMemStatistics(), mem_block->stream_id_);
  }
  if (mem_buf->status_ != DynamicMemBufStatus::kMemBufIdle) {
    DumpDynamicMemPoolDebugInfo();
    MS_LOG(EXCEPTION) << "The membuf status isn't idle for addr:" << device_addr << ", size:" << size
                      << ", find the mem buf addr:" << mem_buf->device_addr_ << ", size:" << mem_buf->size_;
  }

  // Calculate the sizes of the left and right split membufs.
  size_t split_left_size = CalAddressOffset(device_addr, mem_buf->device_addr_);
  MS_EXCEPTION_IF_CHECK_FAIL((mem_buf->size_ >= (split_left_size + size)), "The split size is wrong.");
  size_t split_right_size = mem_buf->size_ - split_left_size - size;

  // Split the left membuf.
  mem_mng->RemoveMemBuf(mem_buf);
  if (split_left_size == 0) {
    mem_buf->status_ = DynamicMemBufStatus::kMemBufUsed;
    mem_buf->size_ = size;
    mem_buf->allocator_name_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().name_;
    mem_buf->allocator_type_ = DynamicMemAllocatorDebugInfo::GetDebugInfo().type_;
  } else {
    mem_buf->size_ = split_left_size;
    mem_mng->AddMemBuf(mem_buf);

    auto used_mem_buf = std::make_shared<DynamicMemBuf>(
      device_addr, DynamicMemBufStatus::kMemBufUsed, size, mem_block->stream_id_,
      DynamicMemAllocatorDebugInfo::GetDebugInfo().name_, DynamicMemAllocatorDebugInfo::GetDebugInfo().type_);
    (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, used_mem_buf);
  }

  // Split the right membuf.
  if (split_right_size > 0) {
    DeviceMemPtr right_buf_addr = AddressOffset(device_addr, size);
    auto right_mem_buf = std::make_shared<DynamicMemBuf>(right_buf_addr, DynamicMemBufStatus::kMemBufIdle,
                                                         split_right_size, mem_block->stream_id_);
    (void)mem_block->block_all_mem_buf_map_.emplace(right_buf_addr, right_mem_buf);
    mem_mng->AddMemBuf(right_mem_buf);
  }

  // Memory statistics.
  mem_mng->mps_.total_used_mem_size_ += size;
  mem_mng->mps_.UpdatePeakSize();
  mem_mng->mps_.total_idle_mem_size_ -= size;
  MS_LOG(DEBUG) << "Keep memory details, name:" << DynamicMemAllocatorDebugInfo::GetDebugInfo().name_
                << ", address:" << device_addr << ", size:" << size << "B, total allocated mem:" << TotalMemStatistics()
                << "B, peak used mem:" << UsedMemPeakStatistics() << "B, in used mem:" << TotalUsedMemStatistics()
                << "B, used by event mem:" << TotalUsedByEventMemStatistics()
                << "B, actual peak used mem:" << ActualPeakStatistics()
                << "B, total idle mem:" << TotalIdleMemStatistics() << "B.";
}
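
// Keep sketch: keeping [device_addr, device_addr + size) out of a larger idle
// buf splits it into up to three pieces, with zero-sized side pieces omitted:
//   [ idle: split_left_size | used: size | idle: split_right_size ]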

DynamicMemBufPtr DynamicMemPoolBestFit::FindMemBufByKeepAddr(const DeviceMemPtr &device_addr,
                                                             const DynamicMemBlockPtr &mem_block) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  MS_EXCEPTION_IF_NULL(mem_block);
  auto &&iter = mem_block->block_all_mem_buf_map_.upper_bound(device_addr);
  if (iter != mem_block->block_all_mem_buf_map_.begin()) {
    return (--iter)->second;
  }
  return nullptr;
}

std::tuple<DynamicMemBlockPtr, DynamicMemBufPtr, MemStatusManagerPtr> DynamicMemPoolBestFit::FindByKeepAddr(
  const DeviceMemPtr &device_addr) const {
  MS_EXCEPTION_IF_NULL(device_addr);
  auto is_addr_in_membuf = [](const DeviceMemPtr &device_addr, const DynamicMemBufPtr &mem_buf) {
    return (mem_buf != nullptr) && (device_addr >= mem_buf->device_addr_) &&
           (mem_buf->size_ >= CalAddressOffset(device_addr, mem_buf->device_addr_));
  };

  // Find in the common pool.
  auto mem_block = FindMemBlock(device_addr, common_mem_);
  if (mem_block != nullptr) {
    auto mem_buf = FindMemBufByKeepAddr(device_addr, mem_block);
    if (is_addr_in_membuf(device_addr, mem_buf)) {
      return std::make_tuple(mem_block, mem_buf, common_mem_);
    }
  }

  // Find in the persistent pool.
  mem_block = FindMemBlock(device_addr, persistent_mem_);
  if (mem_block != nullptr) {
    auto mem_buf = FindMemBufByKeepAddr(device_addr, mem_block);
    if (is_addr_in_membuf(device_addr, mem_buf)) {
      return std::make_tuple(mem_block, mem_buf, persistent_mem_);
    }
  }

  return std::make_tuple(nullptr, nullptr, common_mem_);
}

void DynamicMemPoolBestFit::ReleaseDeviceRes() {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  DumpDynamicMemPoolStateInfo();

  auto fn = [this](const MemStatusManagerPtr &mem_mng) {
    MS_EXCEPTION_IF_NULL(mem_mng);
    for (auto &iter : mem_mng->mem_block_list_) {
      MS_EXCEPTION_IF_NULL(iter);
      auto &device_addr = iter->device_addr_base_;
      if (device_addr != nullptr) {
        if (!FreeDeviceMem(device_addr)) {
          MS_LOG(ERROR) << "Free device memory[" << device_addr << "] error.";
        }
        device_addr = nullptr;
      }
    }
    mem_mng->Clear();
  };
  fn(common_mem_);
  fn(persistent_mem_);

  tracker::MemTrackerManager::GetInstance().Dump();
}

void DynamicMemPoolBestFit::DumpDynamicMemPoolStateInfo() {
  size_t total_used_size_list[kAllocatorTypeNum] = {0};
  static bool is_enable_memory_statistics = common::IsEnableRuntimeConfig(common::kRuntimeMemoryStat) ||
                                            common::IsEnableRuntimeConfig(common::kRuntimeMemoryTrack);
  auto fn = [&](const MemStatusManagerPtr &mem_mng, const std::string &mem_type) {
    MS_EXCEPTION_IF_NULL(mem_mng);
    if (mem_mng->Empty()) {
      return;
    }

    std::ostringstream buf;
    for (size_t i = 0; i < mem_mng->mem_block_list_.size(); ++i) {
      size_t mem_block_used_size = 0;
      MS_EXCEPTION_IF_NULL(mem_mng->mem_block_list_[i]);
      for (auto mb = mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.begin();
           mb != mem_mng->mem_block_list_[i]->block_all_mem_buf_map_.end(); ++mb) {
        if (mb->second->status_ == DynamicMemBufStatus::kMemBufUsed) {
          mem_block_used_size += mb->second->size_;
          MS_EXCEPTION_IF_CHECK_FAIL((static_cast<int>(mb->second->allocator_type_) < kAllocatorTypeNum),
                                     "Allocator type is out of range.");
          total_used_size_list[static_cast<int>(mb->second->allocator_type_)] += mb->second->size_;
        }
      }
      buf << ", block[" << i << "] stream id:" << mem_mng->mem_block_list_[i]->stream_id_
          << " block size:" << mem_mng->mem_block_list_[i]->mem_block_size_ / kMBToByte
          << "M idle size:" << (mem_mng->mem_block_list_[i]->mem_block_size_ - mem_block_used_size) / kMBToByte
          << "M actual size: " << (mem_mng->mem_block_list_[i]->get_actual_peak()) / kMBToByte << "M.";
    }
    std::ostringstream oss_buf;
    // Dump all the memory buf info.
    oss_buf << mem_type << " pool info: Total allocated mem:" << mem_mng->mps_.total_mem_size_ / kMBToByte
            << "M, peak used mem:" << mem_mng->mps_.used_mem_peak_size_ / kMBToByte
            << "M, in used mem:" << mem_mng->mps_.total_used_mem_size_ / kMBToByte
            << "M, total used by event mem:" << mem_mng->mps_.total_used_by_event_mem_size_ / kMBToByte
            << "M, total idle mem:" << mem_mng->mps_.total_idle_mem_size_ / kMBToByte
            << "M. Block unit size:" << mem_mng->unit_size_ / kMBToByte
            << "M, block counts:" << mem_mng->mem_block_list_.size() << buf.str();
    if (is_enable_memory_statistics) {
      std::cout << "[MS_RUNTIME_PROF]" << oss_buf.str() << std::endl;
    }
    MS_LOG(INFO) << oss_buf.str();
  };

  fn(common_mem_, std::string(kCommonMem));
  fn(persistent_mem_, std::string(kPersistentParamMem));
  std::ostringstream oss_mem;
  oss_mem << "The dynamic memory pool total allocated mem:" << TotalMemStatistics() / kMBToByte
          << "M, min addr :" << GetMinUsingMemoryAddr()
          << ", max addr: " << (mem_bufs_.empty() ? nullptr : *(--mem_bufs_.end()))
          << ", peak used mem:" << UsedMemPeakStatistics() / kMBToByte
          << "M, actual peak used mem:" << ActualPeakStatistics() / kMBToByte
          << "M, in used mem:" << TotalUsedMemStatistics() / kMBToByte
          << "M, total used by event mem:" << TotalUsedByEventMemStatistics() / kMBToByte
          << "M, total idle mem:" << TotalIdleMemStatistics() / kMBToByte
          << "M, total eager free mem:" << TotalEagerFreeMemStatistics() / kMBToByte
          << "M. Weight used size:" << total_used_size_list[static_cast<int>(AllocatorType::kWeight)] / kMBToByte
          << "M, constant value used size:"
          << total_used_size_list[static_cast<int>(AllocatorType::kConstantValue)] / kMBToByte
          << "M, kernel output used size:"
          << total_used_size_list[static_cast<int>(AllocatorType::kKernelOutput)] / kMBToByte
          << "M, other used size:" << total_used_size_list[static_cast<int>(AllocatorType::kOther)] / kMBToByte << "M.";
  if (is_enable_memory_statistics) {
    std::cout << "[MS_RUNTIME_PROF]" << oss_mem.str() << std::endl;
  }
  MS_LOG(INFO) << oss_mem.str();
}

void DynamicMemPoolBestFit::DumpDynamicMemPoolDebugInfo() {
  auto fn = [](const MemStatusManagerPtr &mem_mng, const std::string &mem_type) {
    const auto &device_state = mem_mng->DumpMemBlockDebugInfo(mem_type);
    const auto &stream_ids = mem_mng->GetStreamIds();
    // Dump all the idle memory buf info.
    size_t total_idle_mem_in_mem_mng = 0;
    MS_LOG(WARNING) << mem_type << " all idle mem_bufs info: counts[" << stream_ids.size() << "].";
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufIdle);
      const auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      const auto &mem_buf_map = iter->second;
      MS_LOG(WARNING) << "  stream id : " << stream_id << ", idle mem buf info : count[" << mem_buf_map.size() << "].";
      for (auto &&idle_iter = mem_buf_map.begin(); idle_iter != mem_buf_map.end(); idle_iter++) {
        auto &mem_buf = idle_iter->second;
        MS_EXCEPTION_IF_NULL(mem_buf);
        total_idle_mem_in_mem_mng += mem_buf->size_;
        MS_LOG(INFO) << " Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_
                     << "] status[" << kBufStatusString.at(mem_buf->status_) << "] stream id[" << mem_buf->stream_id_
                     << "].";
      }
    }
    // Dump all the eager free memory buf info.
    size_t total_eager_free_mem_in_mem_mng = 0;
    MS_LOG(WARNING) << mem_type << " all eager free mem_buf info: counts[" << stream_ids.size() << "].";
    for (const auto &stream_id : stream_ids) {
      auto key = std::make_pair(stream_id, DynamicMemBufStatus::kMemBufEagerFree);
      const auto &&iter = mem_mng->mem_bufs_.find(key);
      if (iter == mem_mng->mem_bufs_.end()) {
        continue;
      }
      const auto &mem_buf_map = iter->second;
      MS_LOG(WARNING) << "  stream id : " << stream_id << ", eager free mem buf info : count[" << mem_buf_map.size()
                      << "].";
      for (auto &&idle_iter = mem_buf_map.begin(); idle_iter != mem_buf_map.end(); idle_iter++) {
        auto &mem_buf = idle_iter->second;
        MS_EXCEPTION_IF_NULL(mem_buf);
        total_eager_free_mem_in_mem_mng += mem_buf->size_;
        MS_LOG(INFO) << " Eager free mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_
                     << "] status[" << kBufStatusString.at(mem_buf->status_) << "] stream id[" << mem_buf->stream_id_
                     << "].";
      }
    }
    // Dump the memory statistics.
    MS_LOG(WARNING) << mem_type << " total allocated memory[" << device_state.total_mem_size_ << "], used memory["
                    << device_state.total_used_mem_size_ << "], used by event memory["
                    << device_state.total_used_by_event_mem_size_ << "], idle memory["
                    << device_state.total_idle_mem_size_ << "].";
    if (device_state.total_idle_mem_size_ != total_idle_mem_in_mem_mng) {
      MS_LOG(ERROR) << "Check error: the idle memory in the mem blocks is not equal to the global idle memory.";
    }
    if (device_state.total_used_by_event_mem_size_ != mem_mng->mps_.total_used_by_event_mem_size_) {
      MS_LOG(ERROR) << "Check error: the used-by-event memory in the mem blocks is not equal to the global used-by-event memory.";
    }
    if (device_state.total_eager_free_mem_size_ != total_eager_free_mem_in_mem_mng) {
      MS_LOG(ERROR) << "Check error: the eager free memory in the mem blocks is not equal to the global eager free memory.";
    }
    if (device_state.total_mem_size_ != device_state.total_used_mem_size_ + device_state.total_used_by_event_mem_size_ +
                                          device_state.total_idle_mem_size_ + device_state.total_eager_free_mem_size_) {
      MS_LOG(ERROR) << "Check error: the total memory : " << device_state.total_mem_size_
                    << " is not equal to the sum of used memory : " << device_state.total_used_mem_size_
                    << ", used by event memory : " << device_state.total_used_by_event_mem_size_
                    << ", idle memory : " << device_state.total_idle_mem_size_
                    << " and eager free memory : " << device_state.total_eager_free_mem_size_ << ".";
    }
  };

  MS_LOG(WARNING) << "Start dump dynamic memory pool debug info.";
  fn(common_mem_, std::string(kCommonMem));
  fn(persistent_mem_, std::string(kPersistentParamMem));
  MS_LOG(WARNING) << "Finish dump dynamic memory pool debug info.";
}

// Each element in the vector is a (memory_stream_id, address) pair.
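// RecordEvent binds an event to every mem_buf that is still tracked by the pool, keyed by the
// (user_stream_id, memory_stream_id) pair, so that a later WaitEvent on the same pair can release the bufs.
// A minimal usage sketch (the ids, address and event object here are illustrative, not taken from real call sites):
//   DeviceEventPtr event = ...;                          // event recorded on user stream 1
//   pool->RecordEvent(task_id, 1, {{0, addr}}, event);   // addr was allocated on memory stream 0
//   ...
//   pool->WaitEvent(task_id, 1, 0);                      // once the event completes, addr can be freed for reuse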
bool DynamicMemPoolBestFit::RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id,
                                        const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses,
                                        const DeviceEventPtr &event) {
  MS_LOG(DEBUG) << "Record event for task_id_on_stream : " << task_id_on_stream
                << ", user_stream_id : " << user_stream_id
                << ", memory_stream_addresses size : " << memory_stream_addresses.size() << ", event : " << event.get()
                << ".";
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  for (auto &[memory_stream_id, address] : memory_stream_addresses) {
    auto &&mem_buf_tuple = FindByStrictAddr(address);
    auto mem_block = std::get<0>(mem_buf_tuple);
    // The output of a somas sub graph may be consumed by an inner node of the same sub graph, so its address may not
    // be kept in the memory pool.
    if (mem_block == nullptr) {
      MS_LOG(DEBUG) << "Can't find mem block by address in memory pool.";
      continue;
    }
    auto mem_buf = (std::get<1>(mem_buf_tuple))->second;
    (void)mem_buf->RecordEvent(task_id_on_stream, user_stream_id, event);
    (void)stream_pair_addresses_[std::make_pair(user_stream_id, memory_stream_id)].emplace(mem_buf);
  }
  return true;
}

bool DynamicMemPoolBestFit::WaitEvent(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  auto key = std::make_pair(user_stream_id, memory_stream_id);
  auto iter = stream_pair_addresses_.find(key);
  if (iter == stream_pair_addresses_.end()) {
    return false;
  }

  // Iterate over a copy, since releasing an address may erase it from iter->second.
  auto addresses = iter->second;
  for (const auto &address : addresses) {
    address->WaitEvent(task_id_on_stream, user_stream_id);
    // Remove the event and try to free the memory.
    if (address->IsEventNotUsed()) {
      iter->second.erase(address);
      if (address->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
        FreeTensorMemInner(address->device_addr_);
      }
    }
  }
  MS_LOG(DEBUG) << "After release, bound addresses size : " << iter->second.size()
                << ", used by event size : " << TotalUsedByEventMemStatistics() << ".";
  return true;
}

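// Releases events for every user stream that is paired with memory_stream_id, freeing bufs whose events have all
// completed.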
// WaitEvent is called before sync stream, so performance may not be the issue.
bool DynamicMemPoolBestFit::WaitEvent(int64_t task_id_on_stream, uint32_t memory_stream_id) {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  for (auto &stream_pair_addresses : stream_pair_addresses_) {
    const auto &[user_stream, memory_stream] = stream_pair_addresses.first;
    if (memory_stream != memory_stream_id) {
      continue;
    }
    // Iterate over a copy, since releasing an address may erase it from the underlying set.
    auto addresses = stream_pair_addresses.second;
    for (const auto &address : addresses) {
      address->WaitEvent(task_id_on_stream, user_stream);
      // Remove the event and try to free the memory.
      if (address->IsEventNotUsed()) {
        stream_pair_addresses.second.erase(address);
        if (address->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
          FreeTensorMemInner(address->device_addr_);
        }
      }
    }
  }
  MS_LOG(DEBUG) << "After release events, task_id_on_stream : " << task_id_on_stream
                << ", memory_stream_id : " << memory_stream_id
                << ", used by event size : " << TotalUsedByEventMemStatistics() << ".";
  return true;
}

bool DynamicMemPoolBestFit::SyncAllEvents() {
#ifdef __APPLE__
  std::lock_guard<SpinLock> spin_lock(spin_lock_);
#else
  std::lock_guard<std::mutex> locker(mutex_);
#endif
  return SyncAllEventsInner();
}

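// Collects the bufs from all stream pairs into a set first, so that a buf referenced by several
// (user_stream, memory_stream) pairs is synced and freed only once.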
bool DynamicMemPoolBestFit::SyncAllEventsInner() {
  MS_LOG(DEBUG) << "Sync all events, stream_pair_addresses_ size : " << stream_pair_addresses_.size() << ".";
  if (stream_pair_addresses_.empty()) {
    return false;
  }

  std::set<DynamicMemBufPtr> carry_event_addresses;
  for (const auto &stream_pair_address : stream_pair_addresses_) {
    for (const auto &address : stream_pair_address.second) {
      (void)carry_event_addresses.emplace(address);
    }
  }
  for (auto &address : carry_event_addresses) {
    if (address->SyncAllEvents() && address->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
      FreeTensorMemInner(address->device_addr_);
    }
  }

  stream_pair_addresses_.clear();
  return true;
}

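// Builds a {block base address -> {block_memory_size, block_stream_id}} map for statistics reporting.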
std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
DynamicMemPoolBestFit::ExtractBlocksListInfo(const MemStatusManagerPtr &mem_mng) const {
  std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> blocks_list_info;
  for (auto iter = mem_mng->mem_block_list_.begin(); iter != mem_mng->mem_block_list_.end(); ++iter) {
    std::unordered_map<std::string, size_t> block_info;
    block_info[kBlockMemorySize] = (*iter)->size();
    block_info[kBlockStreamId] = (*iter)->stream_id_;
    blocks_list_info[(*iter)->device_addr()] = block_info;
  }
  return blocks_list_info;
}

// The statistics information.
size_t DynamicMemPoolBestFit::TotalMemStatistics() const {
  return common_mem_->mps_.total_mem_size_ + persistent_mem_->mps_.total_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalUsedMemStatistics() const {
  return common_mem_->mps_.total_used_mem_size_ + persistent_mem_->mps_.total_used_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalUsedByEventMemStatistics() const {
  return common_mem_->mps_.total_used_by_event_mem_size_ + persistent_mem_->mps_.total_used_by_event_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalIdleMemStatistics() const {
  return common_mem_->mps_.total_idle_mem_size_ + persistent_mem_->mps_.total_idle_mem_size_;
}
size_t DynamicMemPoolBestFit::TotalEagerFreeMemStatistics() const {
  return common_mem_->mps_.total_eager_free_mem_size_ + persistent_mem_->mps_.total_eager_free_mem_size_;
}
size_t DynamicMemPoolBestFit::UsedMemPeakStatistics() const {
  return common_mem_->mps_.used_mem_peak_size_ + persistent_mem_->mps_.used_mem_peak_size_;
}
size_t DynamicMemPoolBestFit::MaxMemAllocatedStatistics() const {
  return common_mem_->mps_.temp_used_mem_peak_size_ + persistent_mem_->mps_.temp_used_mem_peak_size_;
}
size_t DynamicMemPoolBestFit::MaxMemReservedStatistics() const {
  return common_mem_->mps_.total_mem_size_ + persistent_mem_->mps_.total_mem_size_ -
         common_mem_->mps_.temp_total_mem_size_ - persistent_mem_->mps_.temp_total_mem_size_;
}
size_t DynamicMemPoolBestFit::ActualPeakStatistics() const {
  return common_mem_->CalActualPeak() + persistent_mem_->CalActualPeak();
}
std::unordered_map<std::string, std::size_t> DynamicMemPoolBestFit::BlockCountsStatistics() const {
  size_t common_mem_block_counts = common_mem_->mem_block_list_.size();
  size_t persistent_mem_block_counts = persistent_mem_->mem_block_list_.size();
  std::unordered_map<std::string, std::size_t> block_count_stats;
  block_count_stats[kCommonMemPoolType] = common_mem_block_counts;
  block_count_stats[kPersistentMemPoolType] = persistent_mem_block_counts;
  return block_count_stats;
}
std::unordered_map<std::string, std::size_t> DynamicMemPoolBestFit::BlockUnitSizeStatistics() const {
  size_t common_mem_block_unit_size = common_mem_->unit_size_;
  size_t persistent_mem_block_unit_size = persistent_mem_->unit_size_;
  std::unordered_map<std::string, std::size_t> block_unit_size_stats;
  block_unit_size_stats[kCommonMemPoolType] = common_mem_block_unit_size;
  block_unit_size_stats[kPersistentMemPoolType] = persistent_mem_block_unit_size;
  return block_unit_size_stats;
}
std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
DynamicMemPoolBestFit::CommonMemBlocksInfoStatistics() const {
  return ExtractBlocksListInfo(common_mem_);
}
std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
DynamicMemPoolBestFit::PersistentMemBlocksInfoStatistics() const {
  return ExtractBlocksListInfo(persistent_mem_);
}
void DynamicMemPoolBestFit::ResetMaxMemReserved() const {
  common_mem_->mps_.temp_total_mem_size_ = common_mem_->mps_.total_mem_size_;
  persistent_mem_->mps_.temp_total_mem_size_ = persistent_mem_->mps_.total_mem_size_;
}
void DynamicMemPoolBestFit::ResetMaxMemAllocated() const {
  common_mem_->mps_.temp_total_used_mem_size_ = common_mem_->mps_.total_used_mem_size_;
  persistent_mem_->mps_.temp_total_used_mem_size_ = persistent_mem_->mps_.total_used_mem_size_;
  common_mem_->mps_.temp_total_used_by_event_mem_size_ = common_mem_->mps_.total_used_by_event_mem_size_;
  persistent_mem_->mps_.temp_total_used_by_event_mem_size_ = persistent_mem_->mps_.total_used_by_event_mem_size_;
  common_mem_->mps_.temp_used_mem_peak_size_ = 0;
  persistent_mem_->mps_.temp_used_mem_peak_size_ = 0;
}

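// The actual peak counts every block except the newest one at full size; for the newest block only the span between
// its lowest and highest touched addresses is counted. For example (illustrative numbers), with two 1 GB blocks where
// only 512 MB of the newest block has been touched: 2 GB - 1 GB + 512 MB = 1.5 GB.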
size_t MemStatusManager::CalActualPeak() {
  if (mem_block_insertion_order_.empty()) {
    return 0;
  }
  size_t actual_peak = total_block_size_;
  const auto &end_block = mem_block_insertion_order_.back();
  MS_EXCEPTION_IF_NULL(end_block);
  actual_peak -= end_block->size();
  actual_peak += end_block->get_actual_peak();
  return actual_peak;
}

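// The events_ map is created lazily on the first record: user_stream_id -> list of (task_id_on_stream, event) pairs,
// appended in recording order.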
bool DynamicMemBuf::RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id, const DeviceEventPtr &event) {
  MS_EXCEPTION_IF_NULL(event);
  if (events_ == nullptr) {
    events_ = std::make_shared<std::unordered_map<uint32_t, std::shared_ptr<std::list<TaskIdOnStreamEvent>>>>();
  }
  std::shared_ptr<std::list<TaskIdOnStreamEvent>> event_list = nullptr;
  auto iter = events_->find(user_stream_id);
  if (iter == events_->end()) {
    event_list = std::make_shared<std::list<TaskIdOnStreamEvent>>();
    (void)events_->emplace(user_stream_id, event_list);
  } else {
    event_list = iter->second;
    MS_EXCEPTION_IF_NULL(event_list);
  }
  (void)event_list->emplace_back(task_id_on_stream, event);
  return true;
}

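// Drops every event on user_stream_id whose task id is not bigger than task_id_on_stream. This assumes task ids on a
// stream are issued in increasing order, so the list front always holds the oldest outstanding event.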
bool DynamicMemBuf::WaitEvent(uint32_t task_id_on_stream, uint32_t user_stream_id) {
  if (events_ == nullptr) {
    return false;
  }
  auto iter = events_->find(user_stream_id);
  if (iter == events_->end()) {
    return false;
  }
  auto &event_list = iter->second;
  MS_EXCEPTION_IF_NULL(event_list);
  // Pop all elements in the list whose task id is not bigger than task_id_on_stream.
  while (!event_list->empty() && event_list->front().first <= task_id_on_stream) {
    event_list->pop_front();
  }
  // Remove the list from the map if it is empty.
  if (event_list->empty()) {
    events_->erase(iter);
  }
  return true;
}

bool DynamicMemBuf::IsEventNotUsed() { return events_ == nullptr ? true : events_->empty(); }

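// Blocks until every recorded event has completed, erasing events as they are consumed. Returns true only when the
// events_ map ends up empty.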
bool DynamicMemBuf::SyncAllEvents() {
  if (IsEventNotUsed()) {
    return false;
  }

  for (auto iter = events_->begin(); iter != events_->end();) {
    auto &event_list = iter->second;
    MS_EXCEPTION_IF_NULL(event_list);
    for (auto list_iter = event_list->begin(); list_iter != event_list->end();) {
      auto &event = list_iter->second;
      // Sync the event if it has not arrived yet.
      if (!event->QueryEvent()) {
        event->SyncEvent();
      }
      list_iter = event_list->erase(list_iter);
    }
    if (event_list->empty()) {
      // The list is empty, erase it from the map.
      iter = events_->erase(iter);
    } else {
      MS_LOG(INTERNAL_EXCEPTION) << "Event list is not empty.";
    }
  }
  return events_->empty();
}

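// Registers a new block both in the per-stream index and in the global address-ordered list, and records it in
// insertion order for the actual-peak calculation.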
void MemStatusManager::AddMemBlock(const DynamicMemBlockPtr &mem_block, uint32_t stream_id) {
  auto iter = mem_blocks_.find(stream_id);
  if (iter != mem_blocks_.end()) {
    DoAddMemBlock(mem_block, &iter->second);
  } else {
    (void)mem_blocks_.emplace(stream_id, std::vector<DynamicMemBlockPtr>{mem_block});
  }

  DoAddMemBlock(mem_block, &mem_block_list_);
  mem_block_insertion_order_.emplace_back(mem_block);
  total_block_size_ += mem_block->size();
}

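// Keeps mem_block_list sorted by device address: std::upper_bound finds the first block whose base address is greater
// than the new block's, and inserting before it preserves the order needed for address lookups.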
void MemStatusManager::DoAddMemBlock(const DynamicMemBlockPtr &mem_block,
                                     std::vector<DynamicMemBlockPtr> *mem_block_list) {
  auto iter = std::upper_bound(mem_block_list->begin(), mem_block_list->end(), mem_block->device_addr(),
                               [](const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
                                 return device_addr < mem_block->device_addr();
                               });
  (void)mem_block_list->insert(iter, mem_block);
}

SizeMapMemBuf &MemStatusManager::GetOrCreateMemBufMap(uint32_t stream_id, DynamicMemBufStatus status) {
  return mem_bufs_[std::make_pair(stream_id, status)];
}

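// mem_bufs_ is keyed by (stream_id, status); each entry is a multimap ordered by buf size, so a best-fit search can
// pick the smallest idle buf that is still large enough for a request.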
void MemStatusManager::AddMemBuf(const DynamicMemBufPtr &mem_buf) {
  auto key = std::make_pair(mem_buf->stream_id_, mem_buf->status_);
  auto &mem_buf_map = mem_bufs_[key];
  (void)mem_buf_map.emplace(mem_buf->size_, mem_buf);
}

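// Several bufs may share one size key, so equal_range narrows the search to that size and the device address
// disambiguates which entry to erase.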
void MemStatusManager::RemoveMemBuf(const DynamicMemBufPtr &mem_buf) {
  auto key = std::make_pair(mem_buf->stream_id_, mem_buf->status_);
  auto &mem_buf_map = mem_bufs_[key];
  auto &&iter = mem_buf_map.equal_range(mem_buf->size_);
  while (iter.first != iter.second) {
    if (iter.first->second->device_addr_ == mem_buf->device_addr_) {
      (void)mem_buf_map.erase(iter.first);
      return;
    }
    (void)iter.first++;
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Remove mem buf failed, address : " << mem_buf->device_addr_ << ".";
}

void MemStatusManager::Clear() noexcept {
  mem_blocks_.clear();
  mem_block_list_.clear();
  mem_bufs_.clear();
}

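// Walks every block's buf map, logging each buf and accumulating a DeviceState whose per-status totals are later
// cross-checked in DumpDynamicMemPoolDebugInfo.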
const DeviceState MemStatusManager::DumpMemBlockDebugInfo(const std::string &mem_type) {
  DeviceState device_state;
  // Dump the memory block info and memory buf info.
  MS_LOG(WARNING) << mem_type << " all mem_block info: counts[" << mem_block_list_.size() << "].";
  for (auto iter = mem_block_list_.begin(); iter != mem_block_list_.end(); ++iter) {
    device_state.total_mem_size_ += (*iter)->size();
    auto mem_buf_map = (*iter)->block_all_mem_buf_map_;
    MS_LOG(WARNING) << " MemBlock info: number[" << iter - mem_block_list_.begin() << "] mem_buf_counts["
                    << mem_buf_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
                    << (*iter)->size() << "] stream id[" << (*iter)->stream_id_ << "].";
    for (auto iter_mem_buf = mem_buf_map.begin(); iter_mem_buf != mem_buf_map.end(); ++iter_mem_buf) {
      auto mem_buf = iter_mem_buf->second;
      MS_EXCEPTION_IF_NULL(mem_buf);
      if (mem_buf->status_ == DynamicMemBufStatus::kMemBufIdle) {
        device_state.total_idle_mem_size_ += mem_buf->size_;
      } else if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsed) {
        device_state.total_used_mem_size_ += mem_buf->size_;
      } else if (mem_buf->status_ == DynamicMemBufStatus::kMemBufEagerFree) {
        device_state.total_eager_free_mem_size_ += mem_buf->size_;
      } else if (mem_buf->status_ == DynamicMemBufStatus::kMemBufUsedByEvent) {
        device_state.total_used_by_event_mem_size_ += mem_buf->size_;
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "Unknown mem buf status : " << mem_buf->status_ << ".";
      }
      MS_LOG(INFO) << "  MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
                   << kBufStatusString.at(mem_buf->status_) << "] name["
                   << (mem_buf->allocator_name_.empty() ? "Unknown" : mem_buf->allocator_name_) << "] type["
                   << kAllocatorTypeString.at(mem_buf->allocator_type_) << "] stream id[" << mem_buf->stream_id_
                   << "].";
    }
  }
  return device_state;
}
}  // namespace device
}  // namespace mindspore