/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h"
#include "utils/ms_utils.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace device {
~DynamicMemPoolBestFit()24 DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
25   global_mem_block_list_.clear();
26   global_idle_mem_buf_map_.clear();
27 }
28 
AllocTensorMem(size_t size)29 DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) {
30   size_t align_size = AlignMemorySize(size);
31   std::lock_guard<std::mutex> locker(mutex_);
32   // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf.
33   DeviceMemPtr device_addr = FindIdleMemBuf(align_size);
34   if (!device_addr) {
35     device_addr = AddMemBlockAndMemBuf(align_size);
36   }
37   return device_addr;
38 }
39 
AllocContinuousTensorMem(size_t total_size,std::vector<size_t> size_list)40 std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size,
41                                                                           std::vector<size_t> size_list) {
42   std::vector<DeviceMemPtr> device_addr_list;
43   // Pre-alloc the one whole piece memory.
44   auto device_addr = AllocTensorMem(total_size);
45   if (!device_addr) {
46     return device_addr_list;
47   }
48   std::lock_guard<std::mutex> locker(mutex_);
49   // Remove the pre-alloc memory.
50   const auto &mem_block = FindMemBlock(device_addr);
51   MS_EXCEPTION_IF_NULL(mem_block);
52   const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
53   if (iter == mem_block->block_all_mem_buf_map_.end()) {
54     MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "].";
55   }
56   auto mem_buf = iter->second;
57   MS_EXCEPTION_IF_NULL(mem_buf);
58   if (mem_buf->size_ < total_size) {
59     MS_LOG(EXCEPTION) << "The size of membuf is less than total_size.";
60   }
61   auto rest_size = mem_buf->size_ - total_size;
62   (void)mem_block->block_all_mem_buf_map_.erase(iter);
63   // Split the pre-alloc memory into continuous memory by the size list.
64   DynamicMemBufPtr continuous_mem_buf;
65   auto buf_addr = device_addr;
66   for (size_t i = 0; i < size_list.size(); i++) {
67     continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, size_list[i]);
68     (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf);
69     device_addr_list.emplace_back(buf_addr);
70     buf_addr = AddressOffset(buf_addr, size_list[i]);
71   }
72   // Update the size of the last memory buf.
73   continuous_mem_buf->size_ += rest_size;
74   return device_addr_list;
75 }
76 
AlignMemorySize(size_t size) const77 size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const {
78   if (size == 0) {
79     return DYNAMIC_MEM_ALIGN_SIZE;
80   }
81   return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
82 }
83 
FindIdleMemBuf(size_t size)84 DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) {
85   const auto &iter = global_idle_mem_buf_map_.lower_bound(size);
86   if (iter != global_idle_mem_buf_map_.end()) {
87     auto mem_buf = iter->second;
88     MS_EXCEPTION_IF_NULL(mem_buf);
89     if (mem_buf->status_ != kMemBufIdle) {
90       MS_LOG(EXCEPTION) << "Find the mem_buf is not idle, alloc_size[" << size << "] mem_buf_size[" << mem_buf->size_
91                         << "] mem_buf_address[" << mem_buf->device_addr_ << "].";
92     }
93     mem_buf->status_ = kMemBufUsed;
94     // Remove map of old idle memory buf
95     (void)global_idle_mem_buf_map_.erase(iter);
96     // Divide memory buf
97     if (IsDivide(size, mem_buf->size_)) {
98       DivideMemBuf(size, mem_buf);
99     }
100     // Memory statistics
101     total_used_mem_statistics_ += mem_buf->size_;
102     if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
103       used_mem_peak_statistics_ = total_used_mem_statistics_;
104     }
105     return mem_buf->device_addr_;
106   }
107   return nullptr;
108 }
109 
AddMemBlockAndMemBuf(size_t size)110 DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) {
111   size_t alloc_mem_size = CalMemBlockAllocSize(size);
112   if (alloc_mem_size == 0) {
113     return nullptr;
114   }
115   // Add new memory block
116   DeviceMemPtr device_addr = nullptr;
117   auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr);
118   if (real_alloc_size < size) {
119     MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size
120                     << "].";
121     return nullptr;
122   }
123   mem_alloc_unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
124   auto mem_block = std::make_shared<DynamicMemBlock>(device_addr, real_alloc_size);
125   MS_EXCEPTION_IF_NULL(mem_block);
126   const auto &iter =
127     std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
128   (void)global_mem_block_list_.insert(iter, mem_block);
129   // Add new memory buf
130   auto mem_buf = std::make_shared<DynamicMemBuf>(device_addr, kMemBufUsed, real_alloc_size);
131   MS_EXCEPTION_IF_NULL(mem_buf);
132   // Add map of new memory buf in the block
133   (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf);
134   // Divide memory buf
135   if (IsDivide(size, mem_buf->size_)) {
136     DivideMemBuf(size, mem_buf);
137   }
138   // Memory statistics
139   total_mem_statistics_ += real_alloc_size;
140   total_used_mem_statistics_ += mem_buf->size_;
141   if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
142     used_mem_peak_statistics_ = total_used_mem_statistics_;
143   }
144   return mem_buf->device_addr_;
145 }
146 
CalMemBlockAllocSize(size_t size)147 size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) {
148   auto device_free_mem_size = free_mem_size();
149   if (device_free_mem_size < size) {
150     MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size
151                     << "] is smaller than required size[" << size << "].";
152     return 0;
153   }
154   auto alloc_mem_size = mem_alloc_unit_size();
155   // Growing at twice of alloc size
156   constexpr size_t kDouble = 2;
157   while (alloc_mem_size < size) {
158     alloc_mem_size = alloc_mem_size * kDouble;
159   }
160   alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
161   return alloc_mem_size;
162 }
163 
IsDivide(size_t tensor_size,size_t mem_buf_size) const164 bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const {
165   return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE;
166 }
167 
DivideMemBuf(size_t size,const DynamicMemBufPtr & mem_buf)168 void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) {
169   MS_EXCEPTION_IF_NULL(mem_buf);
170   const auto &mem_block = FindMemBlock(mem_buf->device_addr_);
171   MS_EXCEPTION_IF_NULL(mem_block);
172   // Divide new memory buf
173   if (mem_buf->size_ < size) {
174     MS_LOG(EXCEPTION) << "The size of membuf is less than size.";
175   }
176   size_t newbuf_size = mem_buf->size_ - size;
177   mem_buf->size_ = size;
178   DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size);
179   auto new_mem_buf = std::make_shared<DynamicMemBuf>(newbuf_addr, kMemBufIdle, newbuf_size);
180   // Add map of new memory buf in the block
181   (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf);
182   // Add map of new idle memory buf
183   (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf);
184 }
185 
CmpMemBlock(const DeviceMemPtr & device_addr,const DynamicMemBlockPtr & mem_block)186 bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
187   MS_EXCEPTION_IF_NULL(device_addr);
188   MS_EXCEPTION_IF_NULL(mem_block);
189   return device_addr < mem_block->device_addr();
190 }
191 
FindMemBlock(const DeviceMemPtr & device_addr)192 DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr) {
193   MS_EXCEPTION_IF_NULL(device_addr);
194   auto &&iter =
195     std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
196   if (iter != global_mem_block_list_.begin()) {
197     return *(--iter);
198   }
199   return nullptr;
200 }
201 
FreeTensorMem(const DeviceMemPtr & device_addr)202 void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr &device_addr) {
203   MS_EXCEPTION_IF_NULL(device_addr);
204   std::lock_guard<std::mutex> locker(mutex_);
205   const auto &mem_block = FindMemBlock(device_addr);
206   if (mem_block == nullptr) {
207     // May be destroy the memory pool first, then destroy the address, so this is normal case.
208     MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
209     return;
210   }
211   CombineMemBuf(mem_block, device_addr);
212 }
213 
CombineMemBuf(const DynamicMemBlockPtr & mem_block,const DeviceMemPtr & device_addr)214 void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr) {
215   MS_EXCEPTION_IF_NULL(mem_block);
216   MS_EXCEPTION_IF_NULL(device_addr);
217   const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
218   if (iter == mem_block->block_all_mem_buf_map_.end()) {
219     MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "].";
220   }
221   auto mem_buf = iter->second;
222   MS_EXCEPTION_IF_NULL(mem_buf);
223   if (mem_buf->status_ != kMemBufUsed) {
224     MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "].";
225   }
226   mem_buf->status_ = kMemBufIdle;
227   if (total_used_mem_statistics_ < mem_buf->size_) {
228     MS_LOG(EXCEPTION) << "The total used mem size is less than the size of membuf.";
229   }
230   total_used_mem_statistics_ -= mem_buf->size_;
231   // Combine backward(combine the next_mem_buf to mem_buf)
232   auto next_iter = iter;
233   (void)next_iter++;
234   if (next_iter != mem_block->block_all_mem_buf_map_.end()) {
235     auto next_mem_buf = next_iter->second;
236     MS_EXCEPTION_IF_NULL(next_mem_buf);
237     if (next_mem_buf->status_ == kMemBufIdle) {
238       mem_buf->size_ += next_mem_buf->size_;
239       EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_);
240       (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
241     }
242   }
243   // Combine forward(combine the mem_buf to prev_mem_buf)
244   bool forward_combine = false;
245   DynamicMemBufPtr prev_mem_buf;
246   if (iter != mem_block->block_all_mem_buf_map_.begin()) {
247     auto prev_iter = iter;
248     (void)prev_iter--;
249     prev_mem_buf = prev_iter->second;
250     MS_EXCEPTION_IF_NULL(prev_mem_buf);
251     if (prev_mem_buf->status_ == kMemBufIdle) {
252       EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_);
253       prev_mem_buf->size_ += mem_buf->size_;
254       (void)mem_block->block_all_mem_buf_map_.erase(iter);
255       forward_combine = true;
256     }
257   }
258   // Add map of new idle memory
259   if (forward_combine) {
260     (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
261   } else {
262     (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
263   }
264 }
265 
EraseIdleMemBuf(size_t size,const DeviceMemPtr & device_addr)266 void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr) {
267   MS_EXCEPTION_IF_NULL(device_addr);
268   auto &&iter = global_idle_mem_buf_map_.equal_range(size);
269   while (iter.first != iter.second) {
270     MS_EXCEPTION_IF_NULL(iter.first->second);
271     // Remove map of the idle memory buf by size and device address
272     if (iter.first->second->device_addr_ == device_addr) {
273       (void)global_idle_mem_buf_map_.erase(iter.first);
274       return;
275     }
276     (void)iter.first++;
277   }
278   MS_LOG(ERROR) << "Can't find the size[" << size << "] and device address[" << device_addr << "] in the idle mem_buf.";
279 }
280 
ReleaseDeviceRes()281 void DynamicMemPoolBestFit::ReleaseDeviceRes() {
282   std::lock_guard<std::mutex> locker(mutex_);
283   MS_LOG(INFO) << "The dynamic memory pool total size is " << total_mem_statistics_ << ", total used size is "
284                << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << ".";
285   for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
286     auto &device_addr = (*iter)->device_addr_base_;
287     if (device_addr != nullptr) {
288       if (!FreeDeviceMem(device_addr)) {
289         MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
290       }
291       device_addr = nullptr;
292     }
293   }
294 
295   global_mem_block_list_.clear();
296   global_idle_mem_buf_map_.clear();
297 }
298 
DumpDynamicMemPoolInfo()299 void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
300   std::lock_guard<std::mutex> locker(mutex_);
301   MS_LOG(INFO) << "Start dump dynamic memory pool info.";
302   DeviceAddrMapMemBuf mem_block_map;
303   DynamicMemBufPtr mem_buf;
304   size_t total_mem = 0;
305   size_t total_used_mem = 0;
306   size_t total_idle_mem1 = 0;
307   size_t total_idle_mem2 = 0;
308   // Dump the memory block info and memory buf info
309   MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "].";
310   for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
311     total_mem += (*iter)->size();
312     mem_block_map = (*iter)->block_all_mem_buf_map_;
313     MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts["
314                  << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
315                  << (*iter)->size() << "].";
316     for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) {
317       mem_buf = iter_mem_buf->second;
318       MS_EXCEPTION_IF_NULL(mem_buf);
319       if (mem_buf->status_ == kMemBufIdle) {
320         total_idle_mem1 += mem_buf->size_;
321       } else {
322         total_used_mem += mem_buf->size_;
323       }
324       MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
325                    << mem_buf->status_ << "].";
326     }
327   }
328   // Dump all the idle memory buf info
329   MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "].";
330   for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) {
331     mem_buf = iter_idle->second;
332     MS_EXCEPTION_IF_NULL(mem_buf);
333     total_idle_mem2 += mem_buf->size_;
334     MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status["
335                  << mem_buf->status_ << "].";
336   }
337   // Dump the memory statistical info
338   MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory["
339                << total_idle_mem1 << "].";
340   if (total_idle_mem1 != total_idle_mem2) {
341     MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory.";
342   }
343   if (total_mem != total_used_mem + total_idle_mem1) {
344     MS_LOG(ERROR) << "Check error: the the total memory is not equal the sum of used memory and idle memory.";
345   }
346   MS_LOG(INFO) << "Finish dump dynamic memory pool info.";
347 }
}  // namespace device
}  // namespace mindspore