1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/optimizer/mem_reuse/mem_dynamic_allocator.h"
18 #include "utils/ms_utils.h"
19 #include "utils/convert_utils.h"
20 #include "utils/log_adapter.h"
21
22 namespace mindspore {
23 namespace device {
~DynamicMemPoolBestFit()24 DynamicMemPoolBestFit::~DynamicMemPoolBestFit() {
25 global_mem_block_list_.clear();
26 global_idle_mem_buf_map_.clear();
27 }
28
AllocTensorMem(size_t size)29 DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) {
30 size_t align_size = AlignMemorySize(size);
31 std::lock_guard<std::mutex> locker(mutex_);
32 // Find the idle memory buf by tensor size, if not find, then add new memory block and memory buf.
33 DeviceMemPtr device_addr = FindIdleMemBuf(align_size);
34 if (!device_addr) {
35 device_addr = AddMemBlockAndMemBuf(align_size);
36 }
37 return device_addr;
38 }
39
AllocContinuousTensorMem(size_t total_size,std::vector<size_t> size_list)40 std::vector<DeviceMemPtr> DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size,
41 std::vector<size_t> size_list) {
42 std::vector<DeviceMemPtr> device_addr_list;
43 // Pre-alloc the one whole piece memory.
44 auto device_addr = AllocTensorMem(total_size);
45 if (!device_addr) {
46 return device_addr_list;
47 }
48 std::lock_guard<std::mutex> locker(mutex_);
49 // Remove the pre-alloc memory.
50 const auto &mem_block = FindMemBlock(device_addr);
51 MS_EXCEPTION_IF_NULL(mem_block);
52 const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
53 if (iter == mem_block->block_all_mem_buf_map_.end()) {
54 MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "].";
55 }
56 auto mem_buf = iter->second;
57 MS_EXCEPTION_IF_NULL(mem_buf);
58 if (mem_buf->size_ < total_size) {
59 MS_LOG(EXCEPTION) << "The size of membuf is less than total_size.";
60 }
61 auto rest_size = mem_buf->size_ - total_size;
62 (void)mem_block->block_all_mem_buf_map_.erase(iter);
63 // Split the pre-alloc memory into continuous memory by the size list.
64 DynamicMemBufPtr continuous_mem_buf;
65 auto buf_addr = device_addr;
66 for (size_t i = 0; i < size_list.size(); i++) {
67 continuous_mem_buf = std::make_shared<DynamicMemBuf>(buf_addr, kMemBufUsed, size_list[i]);
68 (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf);
69 device_addr_list.emplace_back(buf_addr);
70 buf_addr = AddressOffset(buf_addr, size_list[i]);
71 }
72 // Update the size of the last memory buf.
73 continuous_mem_buf->size_ += rest_size;
74 return device_addr_list;
75 }
76
AlignMemorySize(size_t size) const77 size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const {
78 if (size == 0) {
79 return DYNAMIC_MEM_ALIGN_SIZE;
80 }
81 return ((size + DYNAMIC_MEM_ALIGN_SIZE - 1) / DYNAMIC_MEM_ALIGN_SIZE) * DYNAMIC_MEM_ALIGN_SIZE;
82 }
83
FindIdleMemBuf(size_t size)84 DeviceMemPtr DynamicMemPoolBestFit::FindIdleMemBuf(size_t size) {
85 const auto &iter = global_idle_mem_buf_map_.lower_bound(size);
86 if (iter != global_idle_mem_buf_map_.end()) {
87 auto mem_buf = iter->second;
88 MS_EXCEPTION_IF_NULL(mem_buf);
89 if (mem_buf->status_ != kMemBufIdle) {
90 MS_LOG(EXCEPTION) << "Find the mem_buf is not idle, alloc_size[" << size << "] mem_buf_size[" << mem_buf->size_
91 << "] mem_buf_address[" << mem_buf->device_addr_ << "].";
92 }
93 mem_buf->status_ = kMemBufUsed;
94 // Remove map of old idle memory buf
95 (void)global_idle_mem_buf_map_.erase(iter);
96 // Divide memory buf
97 if (IsDivide(size, mem_buf->size_)) {
98 DivideMemBuf(size, mem_buf);
99 }
100 // Memory statistics
101 total_used_mem_statistics_ += mem_buf->size_;
102 if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
103 used_mem_peak_statistics_ = total_used_mem_statistics_;
104 }
105 return mem_buf->device_addr_;
106 }
107 return nullptr;
108 }
109
AddMemBlockAndMemBuf(size_t size)110 DeviceMemPtr DynamicMemPoolBestFit::AddMemBlockAndMemBuf(size_t size) {
111 size_t alloc_mem_size = CalMemBlockAllocSize(size);
112 if (alloc_mem_size == 0) {
113 return nullptr;
114 }
115 // Add new memory block
116 DeviceMemPtr device_addr = nullptr;
117 auto real_alloc_size = AllocDeviceMem(alloc_mem_size, &device_addr);
118 if (real_alloc_size < size) {
119 MS_LOG(WARNING) << "Memory not enough: alloc size[" << real_alloc_size << "] is smaller than required size[" << size
120 << "].";
121 return nullptr;
122 }
123 mem_alloc_unit_size_ = DYNAMIC_MEM_ALLOC_UNIT_SIZE;
124 auto mem_block = std::make_shared<DynamicMemBlock>(device_addr, real_alloc_size);
125 MS_EXCEPTION_IF_NULL(mem_block);
126 const auto &iter =
127 std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
128 (void)global_mem_block_list_.insert(iter, mem_block);
129 // Add new memory buf
130 auto mem_buf = std::make_shared<DynamicMemBuf>(device_addr, kMemBufUsed, real_alloc_size);
131 MS_EXCEPTION_IF_NULL(mem_buf);
132 // Add map of new memory buf in the block
133 (void)mem_block->block_all_mem_buf_map_.emplace(device_addr, mem_buf);
134 // Divide memory buf
135 if (IsDivide(size, mem_buf->size_)) {
136 DivideMemBuf(size, mem_buf);
137 }
138 // Memory statistics
139 total_mem_statistics_ += real_alloc_size;
140 total_used_mem_statistics_ += mem_buf->size_;
141 if (total_used_mem_statistics_ > used_mem_peak_statistics_) {
142 used_mem_peak_statistics_ = total_used_mem_statistics_;
143 }
144 return mem_buf->device_addr_;
145 }
146
CalMemBlockAllocSize(size_t size)147 size_t DynamicMemPoolBestFit::CalMemBlockAllocSize(size_t size) {
148 auto device_free_mem_size = free_mem_size();
149 if (device_free_mem_size < size) {
150 MS_LOG(WARNING) << "Memory not enough: current free memory size[" << device_free_mem_size
151 << "] is smaller than required size[" << size << "].";
152 return 0;
153 }
154 auto alloc_mem_size = mem_alloc_unit_size();
155 // Growing at twice of alloc size
156 constexpr size_t kDouble = 2;
157 while (alloc_mem_size < size) {
158 alloc_mem_size = alloc_mem_size * kDouble;
159 }
160 alloc_mem_size = std::min(alloc_mem_size, device_free_mem_size);
161 return alloc_mem_size;
162 }
163
IsDivide(size_t tensor_size,size_t mem_buf_size) const164 bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) const {
165 return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE;
166 }
167
DivideMemBuf(size_t size,const DynamicMemBufPtr & mem_buf)168 void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) {
169 MS_EXCEPTION_IF_NULL(mem_buf);
170 const auto &mem_block = FindMemBlock(mem_buf->device_addr_);
171 MS_EXCEPTION_IF_NULL(mem_block);
172 // Divide new memory buf
173 if (mem_buf->size_ < size) {
174 MS_LOG(EXCEPTION) << "The size of membuf is less than size.";
175 }
176 size_t newbuf_size = mem_buf->size_ - size;
177 mem_buf->size_ = size;
178 DeviceMemPtr newbuf_addr = AddressOffset(mem_buf->device_addr_, size);
179 auto new_mem_buf = std::make_shared<DynamicMemBuf>(newbuf_addr, kMemBufIdle, newbuf_size);
180 // Add map of new memory buf in the block
181 (void)mem_block->block_all_mem_buf_map_.emplace(newbuf_addr, new_mem_buf);
182 // Add map of new idle memory buf
183 (void)global_idle_mem_buf_map_.emplace(newbuf_size, new_mem_buf);
184 }
185
CmpMemBlock(const DeviceMemPtr & device_addr,const DynamicMemBlockPtr & mem_block)186 bool DynamicMemPoolBestFit::CmpMemBlock(const DeviceMemPtr &device_addr, const DynamicMemBlockPtr &mem_block) {
187 MS_EXCEPTION_IF_NULL(device_addr);
188 MS_EXCEPTION_IF_NULL(mem_block);
189 return device_addr < mem_block->device_addr();
190 }
191
FindMemBlock(const DeviceMemPtr & device_addr)192 DynamicMemBlockPtr DynamicMemPoolBestFit::FindMemBlock(const DeviceMemPtr &device_addr) {
193 MS_EXCEPTION_IF_NULL(device_addr);
194 auto &&iter =
195 std::upper_bound(global_mem_block_list_.begin(), global_mem_block_list_.end(), device_addr, CmpMemBlock);
196 if (iter != global_mem_block_list_.begin()) {
197 return *(--iter);
198 }
199 return nullptr;
200 }
201
FreeTensorMem(const DeviceMemPtr & device_addr)202 void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr &device_addr) {
203 MS_EXCEPTION_IF_NULL(device_addr);
204 std::lock_guard<std::mutex> locker(mutex_);
205 const auto &mem_block = FindMemBlock(device_addr);
206 if (mem_block == nullptr) {
207 // May be destroy the memory pool first, then destroy the address, so this is normal case.
208 MS_LOG(DEBUG) << "Can't find the mem_block of the device address[" << device_addr << "].";
209 return;
210 }
211 CombineMemBuf(mem_block, device_addr);
212 }
213
CombineMemBuf(const DynamicMemBlockPtr & mem_block,const DeviceMemPtr & device_addr)214 void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr &device_addr) {
215 MS_EXCEPTION_IF_NULL(mem_block);
216 MS_EXCEPTION_IF_NULL(device_addr);
217 const auto &iter = mem_block->block_all_mem_buf_map_.find(device_addr);
218 if (iter == mem_block->block_all_mem_buf_map_.end()) {
219 MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "].";
220 }
221 auto mem_buf = iter->second;
222 MS_EXCEPTION_IF_NULL(mem_buf);
223 if (mem_buf->status_ != kMemBufUsed) {
224 MS_LOG(EXCEPTION) << "Find the mem_buf is not used, mem_buf_address[" << mem_buf->device_addr_ << "].";
225 }
226 mem_buf->status_ = kMemBufIdle;
227 if (total_used_mem_statistics_ < mem_buf->size_) {
228 MS_LOG(EXCEPTION) << "The total used mem size is less than the size of membuf.";
229 }
230 total_used_mem_statistics_ -= mem_buf->size_;
231 // Combine backward(combine the next_mem_buf to mem_buf)
232 auto next_iter = iter;
233 (void)next_iter++;
234 if (next_iter != mem_block->block_all_mem_buf_map_.end()) {
235 auto next_mem_buf = next_iter->second;
236 MS_EXCEPTION_IF_NULL(next_mem_buf);
237 if (next_mem_buf->status_ == kMemBufIdle) {
238 mem_buf->size_ += next_mem_buf->size_;
239 EraseIdleMemBuf(next_mem_buf->size_, next_mem_buf->device_addr_);
240 (void)mem_block->block_all_mem_buf_map_.erase(next_iter);
241 }
242 }
243 // Combine forward(combine the mem_buf to prev_mem_buf)
244 bool forward_combine = false;
245 DynamicMemBufPtr prev_mem_buf;
246 if (iter != mem_block->block_all_mem_buf_map_.begin()) {
247 auto prev_iter = iter;
248 (void)prev_iter--;
249 prev_mem_buf = prev_iter->second;
250 MS_EXCEPTION_IF_NULL(prev_mem_buf);
251 if (prev_mem_buf->status_ == kMemBufIdle) {
252 EraseIdleMemBuf(prev_mem_buf->size_, prev_mem_buf->device_addr_);
253 prev_mem_buf->size_ += mem_buf->size_;
254 (void)mem_block->block_all_mem_buf_map_.erase(iter);
255 forward_combine = true;
256 }
257 }
258 // Add map of new idle memory
259 if (forward_combine) {
260 (void)global_idle_mem_buf_map_.emplace(prev_mem_buf->size_, prev_mem_buf);
261 } else {
262 (void)global_idle_mem_buf_map_.emplace(mem_buf->size_, mem_buf);
263 }
264 }
265
EraseIdleMemBuf(size_t size,const DeviceMemPtr & device_addr)266 void DynamicMemPoolBestFit::EraseIdleMemBuf(size_t size, const DeviceMemPtr &device_addr) {
267 MS_EXCEPTION_IF_NULL(device_addr);
268 auto &&iter = global_idle_mem_buf_map_.equal_range(size);
269 while (iter.first != iter.second) {
270 MS_EXCEPTION_IF_NULL(iter.first->second);
271 // Remove map of the idle memory buf by size and device address
272 if (iter.first->second->device_addr_ == device_addr) {
273 (void)global_idle_mem_buf_map_.erase(iter.first);
274 return;
275 }
276 (void)iter.first++;
277 }
278 MS_LOG(ERROR) << "Can't find the size[" << size << "] and device address[" << device_addr << "] in the idle mem_buf.";
279 }
280
ReleaseDeviceRes()281 void DynamicMemPoolBestFit::ReleaseDeviceRes() {
282 std::lock_guard<std::mutex> locker(mutex_);
283 MS_LOG(INFO) << "The dynamic memory pool total size is " << total_mem_statistics_ << ", total used size is "
284 << total_used_mem_statistics_ << ", used peak size is " << used_mem_peak_statistics_ << ".";
285 for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
286 auto &device_addr = (*iter)->device_addr_base_;
287 if (device_addr != nullptr) {
288 if (!FreeDeviceMem(device_addr)) {
289 MS_LOG(EXCEPTION) << "Free device memory[" << device_addr << "] error.";
290 }
291 device_addr = nullptr;
292 }
293 }
294
295 global_mem_block_list_.clear();
296 global_idle_mem_buf_map_.clear();
297 }
298
DumpDynamicMemPoolInfo()299 void DynamicMemPoolBestFit::DumpDynamicMemPoolInfo() {
300 std::lock_guard<std::mutex> locker(mutex_);
301 MS_LOG(INFO) << "Start dump dynamic memory pool info.";
302 DeviceAddrMapMemBuf mem_block_map;
303 DynamicMemBufPtr mem_buf;
304 size_t total_mem = 0;
305 size_t total_used_mem = 0;
306 size_t total_idle_mem1 = 0;
307 size_t total_idle_mem2 = 0;
308 // Dump the memory block info and memory buf info
309 MS_LOG(INFO) << "Dump all mem_block info: counts[" << global_mem_block_list_.size() << "].";
310 for (auto iter = global_mem_block_list_.begin(); iter != global_mem_block_list_.end(); ++iter) {
311 total_mem += (*iter)->size();
312 mem_block_map = (*iter)->block_all_mem_buf_map_;
313 MS_LOG(INFO) << "MemBlock info: number[" << iter - global_mem_block_list_.begin() << "] mem_buf_counts["
314 << mem_block_map.size() << "] base_address[" << (*iter)->device_addr() << "] block_size["
315 << (*iter)->size() << "].";
316 for (auto iter_mem_buf = mem_block_map.begin(); iter_mem_buf != mem_block_map.end(); ++iter_mem_buf) {
317 mem_buf = iter_mem_buf->second;
318 MS_EXCEPTION_IF_NULL(mem_buf);
319 if (mem_buf->status_ == kMemBufIdle) {
320 total_idle_mem1 += mem_buf->size_;
321 } else {
322 total_used_mem += mem_buf->size_;
323 }
324 MS_LOG(INFO) << "MemBuf info: address[" << mem_buf->device_addr_ << "] size[" << mem_buf->size_ << "] status["
325 << mem_buf->status_ << "].";
326 }
327 }
328 // Dump all the idle memory buf info
329 MS_LOG(INFO) << "Dump all idle mem_buf info: counts[" << global_idle_mem_buf_map_.size() << "].";
330 for (auto iter_idle = global_idle_mem_buf_map_.begin(); iter_idle != global_idle_mem_buf_map_.end(); ++iter_idle) {
331 mem_buf = iter_idle->second;
332 MS_EXCEPTION_IF_NULL(mem_buf);
333 total_idle_mem2 += mem_buf->size_;
334 MS_LOG(INFO) << "Idle mem_buf info: size[" << mem_buf->size_ << "] address[" << mem_buf->device_addr_ << "] status["
335 << mem_buf->status_ << "].";
336 }
337 // Dump the memory statistical info
338 MS_LOG(INFO) << "Total allocated memory[" << total_mem << "], used memory[" << total_used_mem << "], idle memory["
339 << total_idle_mem1 << "].";
340 if (total_idle_mem1 != total_idle_mem2) {
341 MS_LOG(ERROR) << "Check error: the idle memory in the mem_block is not equal the global idle memory.";
342 }
343 if (total_mem != total_used_mem + total_idle_mem1) {
344 MS_LOG(ERROR) << "Check error: the the total memory is not equal the sum of used memory and idle memory.";
345 }
346 MS_LOG(INFO) << "Finish dump dynamic memory pool info.";
347 }
348 } // namespace device
349 } // namespace mindspore
350