/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/extendrt/dynamic_mem_manager.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "src/common/common.h"

using mindspore::numa::NUMAAdapter;
using mindspore::numa::MemoryInfo;

namespace mindspore {
namespace {
// Memory allocations are aligned to 64 bytes.
static constexpr size_t kMemAlginSize = 64;

// The default unit size (256 MB) of the memory blocks used for dynamic extension.
static constexpr size_t kAllocUnitSize = 256 * 1024 * 1024;
// The minimum unit size (64 MB) of the memory blocks used for dynamic extension.
static constexpr size_t kMinimumAllocUnitSize = 64 * 1024 * 1024;
// Threshold of total system memory (16 GB); below it the minimum allocation unit is used.
static constexpr size_t kMinimumSysMemory = 17179869184;
// The number of block descriptors reserved whenever the blocks_ vector grows.
static constexpr auto kBlockSize = 2048;
// Invalid block index.
static constexpr int kInvalidIndex = -1;
// Invalid NUMA node id.
static constexpr int kInvalidNodeId = -1;
// Returned by the ref-count helpers when the pointer is not managed by the pool.
static constexpr int kInvalidRefCount = -1;
size_t Rounded(size_t size) { return (size + kMemAlginSize - 1) & (~(kMemAlginSize - 1)); }
}  // namespace

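// MemOperator keeps its bookkeeping in four structures: blocks_ stores block
// descriptors linked through pre_index_/next_index_, free_blocks_ is a multimap
// from block size to block index used for best-fit lookup, datas_ maps the user
// pointers of in-use blocks to their block index, and all_datas_ records the raw
// chunks obtained from the OS (or a NUMA node) together with their sizes.
//
// Allocate() requests a raw chunk: it asks for a full allocation unit when the
// rounded size is smaller than the unit, and retries with the exact rounded size
// if the larger request fails.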
void *MemOperator::Allocate(size_t rounded_size, int node_id, size_t *allocate_size) {
  static const auto kMaxMallocSize = lite::GetMaxMallocSize();
  static const auto unit_size = kMaxMallocSize < kMinimumSysMemory ? kMinimumAllocUnitSize : kAllocUnitSize;
  auto allocate_tmp_size = rounded_size < unit_size ? unit_size : rounded_size;
  if (allocate_tmp_size > kMaxMallocSize) {
    MS_LOG(ERROR) << "request invalid memory size " << allocate_tmp_size << ", total: " << kMaxMallocSize;
    return nullptr;
  }

  *allocate_size = allocate_tmp_size;
  void *data = nullptr;
#ifdef _WIN32
  data = _aligned_malloc(allocate_tmp_size, kMemAlginSize);
#else
  if (node_id >= 0) {
    data = numa_instance_->Malloc(node_id, static_cast<size_t>(allocate_tmp_size));
    if (MS_UNLIKELY((data == nullptr && allocate_tmp_size > rounded_size))) {
      MS_LOG(WARNING) << "Malloc memory(" << allocate_tmp_size << ") failed! malloc rounded_size(" << rounded_size
                      << ") memory again. node_id: " << node_id;
      allocate_tmp_size = rounded_size;
      *allocate_size = rounded_size;
      data = numa_instance_->Malloc(node_id, rounded_size);
    }
  } else {
    auto ret = posix_memalign(&data, kMemAlginSize, static_cast<size_t>(allocate_tmp_size));
    if (MS_UNLIKELY(ret == ENOMEM && allocate_tmp_size > rounded_size)) {
      MS_LOG(WARNING) << "Malloc memory(" << allocate_tmp_size << ") failed! malloc rounded_size(" << rounded_size
                      << ") memory again.";
      allocate_tmp_size = rounded_size;
      *allocate_size = rounded_size;
      ret = posix_memalign(&data, kMemAlginSize, rounded_size);
    }
    if (MS_UNLIKELY(ret != 0)) {
      MS_LOG(ERROR) << "posix_memalign failed! ret: " << ret << ", node_id: " << node_id
                    << ", request: " << allocate_tmp_size;
      return nullptr;
    }
  }
#endif
  if (MS_UNLIKELY(data == nullptr)) {
    MS_LOG(ERROR) << "malloc data failed! node_id: " << node_id << ", request: " << allocate_tmp_size;
    return nullptr;
  }

  return data;
}

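// GetBlock() hands out a block descriptor: it first recycles a slot from the
// garbage list maintained by AddGarbageBlock(); otherwise it takes the next
// unused slot in blocks_, growing the vector by kBlockSize entries when needed.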
Block *MemOperator::GetBlock() {
  Block *block;
  if (garbage_block_ != kInvalidIndex) {
    block = &blocks_[garbage_block_];
    garbage_block_ = blocks_[garbage_block_].next_index_;
  } else {
    if (block_count_ >= blocks_.size()) {
      blocks_.resize(blocks_.size() + kBlockSize);
    }
    blocks_[block_count_].index_ = static_cast<int64_t>(block_count_);
    block = &blocks_[block_count_++];
  }
  block->used_ = false;
  block->ref_count_ = 0;
  block->pre_index_ = kInvalidIndex;
  block->next_index_ = kInvalidIndex;
  return block;
}

void MemOperator::AddGarbageBlock(const int64_t index) {
  blocks_[index].next_index_ = garbage_block_;
  garbage_block_ = index;
}

// Malloc memory for data storage.
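// The smallest free block that fits the rounded size is taken from free_blocks_
// (lower_bound on size). If that block is larger than needed, it is split and the
// remainder is re-inserted as a new free block; if no free block fits, a fresh
// chunk is obtained from Allocate() and split the same way.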
void *MemOperator::Malloc(size_t size) {
  auto rounded_size = size == 0 ? 1 : Rounded(size);
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = free_blocks_.lower_bound(rounded_size);
  if (iter != free_blocks_.end()) {
    auto index = iter->second;
    free_blocks_.erase(iter);
    blocks_[index].used_ = true;
    auto data = blocks_[index].data_;
    datas_.emplace(data, index);
    if (blocks_[index].size_ > rounded_size) {
      Block *block_next = GetBlock();
      auto *block = &blocks_[index];
      block_next->size_ = block->size_ - rounded_size;
      block->size_ = rounded_size;
      block_next->data_ = static_cast<int8_t *>(block->data_) + rounded_size;
      block_next->pre_index_ = index;
      auto next_index = block->next_index_;
      block_next->next_index_ = next_index;
      if (next_index != kInvalidIndex) {
        blocks_[next_index].pre_index_ = block_next->index_;
      }
      block->next_index_ = block_next->index_;
      free_blocks_.emplace(block_next->size_, block_next->index_);
    }
    return data;
  }
  // TODO: kAllocUnitSize can be replaced by a configurable value.
  size_t allocate_size;
  void *data = Allocate(rounded_size, node_id_, &allocate_size);
  if (MS_UNLIKELY(data == nullptr)) {
    return nullptr;
  }
  all_datas_.emplace(data, allocate_size);
  Block *block = GetBlock();
  auto block_index = block->index_;
  block->size_ = rounded_size;
  block->data_ = data;
  block->used_ = true;
  datas_.emplace(data, block_index);
  if (allocate_size > rounded_size) {
    Block *block_next = GetBlock();
    // GetBlock() may grow blocks_ and invalidate earlier Block pointers, so re-acquire block.
    block = &blocks_[block_index];
    block_next->data_ = static_cast<int8_t *>(data) + rounded_size;
    block_next->size_ = allocate_size - rounded_size;
    block_next->pre_index_ = block_index;
    block->next_index_ = block_next->index_;
    free_blocks_.emplace(block_next->size_, block_next->index_);
  }
  return data;
}

// Return memory to the memory pool.
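// A freed block is first merged with its next neighbour if that neighbour is
// free, then with its previous neighbour; descriptors of merged-away blocks are
// recycled through AddGarbageBlock(). Pointers not owned by this pool are ignored.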
void MemOperator::Free(void *ptr) {
  if (MS_UNLIKELY(ptr == nullptr)) {
    return;
  }
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter == datas_.end()) {
    return;
  }

  auto index = iter->second;
  datas_.erase(iter);
  Block *block = &blocks_[index];
  auto next_index = block->next_index_;
  if (next_index != kInvalidIndex && !blocks_[next_index].used_) {
    EraseFreeBlock(next_index);
    block->size_ += blocks_[next_index].size_;
    auto next_next_index = blocks_[next_index].next_index_;
    if (next_next_index != kInvalidIndex) {
      blocks_[next_next_index].pre_index_ = block->index_;
    }
    block->next_index_ = next_next_index;
    block->used_ = false;
    block->ref_count_ = 0;
    free_blocks_.emplace(block->size_, block->index_);
    AddGarbageBlock(next_index);
  }
  auto pre_index = block->pre_index_;
  if (pre_index != kInvalidIndex && !blocks_[pre_index].used_) {
    EraseFreeBlock(pre_index);
    if (!block->used_) {
      EraseFreeBlock(index);
    }
    blocks_[pre_index].size_ += block->size_;
    next_index = block->next_index_;
    blocks_[pre_index].next_index_ = next_index;
    if (next_index != kInvalidIndex) {
      blocks_[next_index].pre_index_ = pre_index;
    }
    block->used_ = false;
    block->ref_count_ = 0;
    free_blocks_.emplace(blocks_[pre_index].size_, pre_index);
    AddGarbageBlock(index);
  }
  if (block->used_) {
    block->used_ = false;
    block->ref_count_ = 0;
    free_blocks_.emplace(block->size_, block->index_);
  }
}

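// Remove the entry for a specific block index from free_blocks_; the multimap is
// keyed by size, so several blocks may share the same key.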
void MemOperator::EraseFreeBlock(const int64_t index) {
  auto range = free_blocks_.equal_range(blocks_[index].size_);
  for (auto item = range.first; item != range.second; ++item) {
    if (item->second == index) {
      free_blocks_.erase(item);
      break;
    }
  }
}

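// The constructor binds the pool to the given NUMA node when NUMA is available
// and pre-allocates one allocation unit so that the first Malloc() can be served
// from the free list.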
MemOperator::MemOperator(int node_id) {
  numa_instance_ = NUMAAdapter::GetInstance();
  if (node_id >= 0 && numa_instance_->Available()) {
    node_id_ = node_id;
  }

  blocks_.resize(kBlockSize);
  garbage_block_ = kInvalidIndex;
  auto *block = GetBlock();
  size_t allocate_size;
  block->data_ = Allocate(kAllocUnitSize, node_id, &allocate_size);
  if (MS_UNLIKELY(block->data_ == nullptr)) {
    return;
  }
  all_datas_.emplace(block->data_, allocate_size);
  block->size_ = allocate_size;
  free_blocks_.emplace(allocate_size, block->index_);
}

MemOperator::~MemOperator() {
  MS_LOG(DEBUG) << "~MemOperator() begin.";
  for (auto &&data : all_datas_) {
#ifdef _WIN32
    _aligned_free(data.first);
#else
    if (node_id_ >= 0) {
      numa_instance_->Free(data.first, data.second);
    } else {
      free(data.first);
    }
#endif
  }
  free_blocks_.clear();
  all_datas_.clear();
  blocks_.clear();
  MS_LOG(DEBUG) << "~MemOperator() end.";
}

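// Reference-count helpers: each block carries a ref_count_ that callers can set,
// increment, decrement, or query; kInvalidRefCount is returned for pointers that
// are not managed by this pool.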
int MemOperator::SetRefCount(void *ptr, int ref_count) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    auto index = iter->second;
    blocks_[index].ref_count_ = ref_count;
    return ref_count;
  }
  return kInvalidRefCount;
}

int MemOperator::IncRefCount(void *ptr, int ref_count) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    auto index = iter->second;
    blocks_[index].ref_count_ += ref_count;
    return static_cast<int>(blocks_[index].ref_count_);
  }
  return kInvalidRefCount;
}

int MemOperator::DecRefCount(void *ptr, int ref_count) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    auto index = iter->second;
    blocks_[index].ref_count_ -= ref_count;
    return static_cast<int>(blocks_[index].ref_count_);
  }
  return kInvalidRefCount;
}

int MemOperator::RefCount(void *ptr) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    return static_cast<int>(blocks_[iter->second].ref_count_);
  }
  return kInvalidRefCount;
}

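// GetMemOperator() caches one MemOperator per NUMA node; all negative ids are
// mapped to kInvalidNodeId so that non-NUMA callers share a single pool.
//
// Illustrative usage sketch (assuming a DynamicMemManager instance named `manager`):
//   auto mem_oper = manager.GetMemOperator(/*node_id=*/-1);
//   void *buf = mem_oper->Malloc(1024);
//   mem_oper->Free(buf);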
std::shared_ptr<MemOperator> DynamicMemManager::GetMemOperator(const int node_id) {
  std::map<int, std::shared_ptr<MemOperator>>::iterator iter;
  int numa_node_id = node_id;
  if (numa_node_id < 0) {
    numa_node_id = kInvalidNodeId;
  }

  std::lock_guard<std::mutex> locker(mutex_);
  std::shared_ptr<MemOperator> mem_oper = nullptr;
  iter = nodes_mem_.find(numa_node_id);
  if (iter == nodes_mem_.end()) {
    mem_oper = std::make_shared<MemOperator>(numa_node_id);
    if (MS_UNLIKELY(mem_oper == nullptr)) {
      MS_LOG(ERROR) << "make_shared MemOperator failed!";
      return nullptr;
    }
    nodes_mem_.insert({numa_node_id, mem_oper});
  } else {
    mem_oper = iter->second;
  }
  return mem_oper;
}
}  // namespace mindspore