/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/extendrt/dynamic_mem_manager.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"
#include "src/common/common.h"

using mindspore::numa::NUMAAdapter;
using mindspore::numa::MemoryInfo;

namespace mindspore {
namespace {
// Allocated memory is aligned to 64 bytes.
static constexpr size_t kMemAlignSize = 64;

// The default unit size (256 MB) of the memory blocks used for dynamic extension.
static constexpr size_t kAllocUnitSize = 256 * 1024 * 1024;
// The minimum unit size (64 MB) of the memory blocks used for dynamic extension.
static constexpr size_t kMinimumAllocUnitSize = 64 * 1024 * 1024;
// Threshold of total system memory (16 GB) below which the smaller unit size is used.
static constexpr size_t kMinimumSysMemory = 16ULL * 1024 * 1024 * 1024;
// Initial capacity and growth step of the block-descriptor vector.
static constexpr auto kBlockSize = 2048;
// invalid block index
static constexpr int kInvalidIndex = -1;
// invalid numa node id
static constexpr int kInvalidNodeId = -1;
static constexpr int kInvalidRefCount = -1;
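// Round size up to the next multiple of kMemAlignSize (a power of two);
// e.g. Rounded(1) == 64, Rounded(64) == 64, Rounded(65) == 128.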
size_t Rounded(size_t size) { return (size + kMemAlignSize - 1) & (~(kMemAlignSize - 1)); }
}  // namespace

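// Allocate a new chunk of at least rounded_size bytes, preferring a full
// allocation unit. On NUMA systems the chunk is bound to node_id; otherwise it
// is 64-byte aligned via posix_memalign (_aligned_malloc on Windows). If the
// unit-sized request fails, fall back to requesting exactly rounded_size.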
void *MemOperator::Allocate(size_t rounded_size, int node_id, size_t *allocate_size) {
  static const auto kMaxMallocSize = lite::GetMaxMallocSize();
  static const auto unit_size = kMaxMallocSize < kMinimumSysMemory ? kMinimumAllocUnitSize : kAllocUnitSize;
  auto allocate_tmp_size = rounded_size < unit_size ? unit_size : rounded_size;
  if (allocate_tmp_size > kMaxMallocSize) {
    MS_LOG(ERROR) << "request invalid memory size " << allocate_tmp_size << ", total: " << kMaxMallocSize;
    return nullptr;
  }

  *allocate_size = allocate_tmp_size;
  void *data = nullptr;
#ifdef _WIN32
  data = _aligned_malloc(allocate_tmp_size, kMemAlignSize);
#else
  if (node_id >= 0) {
    data = numa_instance_->Malloc(node_id, static_cast<size_t>(allocate_tmp_size));
    if (MS_UNLIKELY(data == nullptr && allocate_tmp_size > rounded_size)) {
      MS_LOG(WARNING) << "Malloc memory(" << allocate_tmp_size << ") failed! malloc rounded_size(" << rounded_size
                      << ") memory again. node_id: " << node_id;
      allocate_tmp_size = rounded_size;
      *allocate_size = rounded_size;
      data = numa_instance_->Malloc(node_id, rounded_size);
    }
  } else {
    auto ret = posix_memalign(&data, kMemAlignSize, static_cast<size_t>(allocate_tmp_size));
    if (MS_UNLIKELY(ret == ENOMEM && allocate_tmp_size > rounded_size)) {
      MS_LOG(WARNING) << "Malloc memory(" << allocate_tmp_size << ") failed! malloc rounded_size(" << rounded_size
                      << ") memory again.";
      allocate_tmp_size = rounded_size;
      *allocate_size = rounded_size;
      ret = posix_memalign(&data, kMemAlignSize, rounded_size);
    }
    if (MS_UNLIKELY(ret != 0)) {
      MS_LOG(ERROR) << "posix_memalign failed! ret: " << ret << ", node_id: " << node_id
                    << ", request: " << allocate_tmp_size;
      return nullptr;
    }
  }
#endif
  if (MS_UNLIKELY(data == nullptr)) {
    MS_LOG(ERROR) << "malloc data failed! node_id: " << node_id << ", request: " << allocate_tmp_size;
    return nullptr;
  }

  return data;
}

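// Fetch a free block descriptor, recycling one from the garbage list when
// possible; otherwise take (and if necessary grow) the next slot in blocks_.
// Note: growing blocks_ may reallocate the vector and invalidate any raw
// Block pointers held by the caller.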
Block *MemOperator::GetBlock() {
  Block *block;
  if (garbage_block_ != kInvalidIndex) {
    block = &blocks_[garbage_block_];
    garbage_block_ = blocks_[garbage_block_].next_index_;
  } else {
    if (block_count_ >= blocks_.size()) {
      blocks_.resize(blocks_.size() + kBlockSize);
    }
    blocks_[block_count_].index_ = static_cast<int64_t>(block_count_);
    block = &blocks_[block_count_++];
  }
  block->used_ = false;
  block->ref_count_ = 0;
  block->pre_index_ = kInvalidIndex;
  block->next_index_ = kInvalidIndex;
  return block;
}

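// Return a block descriptor to the garbage list so GetBlock() can reuse it.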
void MemOperator::AddGarbageBlock(const int64_t index) {
  blocks_[index].next_index_ = garbage_block_;
  garbage_block_ = index;
}

// Allocate memory for data storage: best-fit search of the free list, with a
// new unit-sized chunk allocated when no free block is large enough. Oversized
// blocks are split and the remainder is returned to the free list.
void *MemOperator::Malloc(size_t size) {
  auto rounded_size = size == 0 ? 1 : Rounded(size);
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = free_blocks_.lower_bound(rounded_size);
  if (iter != free_blocks_.end()) {
    auto index = iter->second;
    free_blocks_.erase(iter);
    blocks_[index].used_ = true;
    auto data = blocks_[index].data_;
    datas_.emplace(data, index);
    if (blocks_[index].size_ > rounded_size) {
      // Split: keep the front rounded_size bytes, free-list the remainder.
      Block *block_next = GetBlock();
      // GetBlock() may have resized blocks_, so re-fetch the block pointer.
      auto *block = &blocks_[index];
      block_next->size_ = block->size_ - rounded_size;
      block->size_ = rounded_size;
      block_next->data_ = static_cast<int8_t *>(block->data_) + rounded_size;
      block_next->pre_index_ = index;
      auto next_index = block->next_index_;
      block_next->next_index_ = next_index;
      if (next_index != kInvalidIndex) {
        blocks_[next_index].pre_index_ = block_next->index_;
      }
      block->next_index_ = block_next->index_;
      free_blocks_.emplace(block_next->size_, block_next->index_);
    }
    return data;
  }
  // TODO: kAllocUnitSize could be made configurable.
  size_t allocate_size;
  void *data = Allocate(rounded_size, node_id_, &allocate_size);
  if (MS_UNLIKELY(data == nullptr)) {
    return nullptr;
  }
  all_datas_.emplace(data, allocate_size);
  Block *block = GetBlock();
  auto block_index = block->index_;
  block->size_ = rounded_size;
  block->data_ = data;
  block->used_ = true;
  datas_.emplace(data, block_index);
  if (allocate_size > rounded_size) {
    Block *block_next = GetBlock();
    // GetBlock() may have resized blocks_ and invalidated block; re-fetch it.
    block = &blocks_[block_index];
    block_next->data_ = static_cast<int8_t *>(data) + rounded_size;
    block_next->size_ = allocate_size - rounded_size;
    block_next->pre_index_ = block->index_;
    block->next_index_ = block_next->index_;
    free_blocks_.emplace(block_next->size_, block_next->index_);
  }
  return data;
}

// Return memory to the pool, coalescing the freed block with free neighbors
// to limit fragmentation.
void MemOperator::Free(void *ptr) {
  if (MS_UNLIKELY(ptr == nullptr)) {
    return;
  }
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter == datas_.end()) {
    return;
  }

  auto index = iter->second;
  datas_.erase(iter);
  Block *block = &blocks_[index];
  // Merge with the next block if it is free.
  auto next_index = block->next_index_;
  if (next_index != kInvalidIndex && !blocks_[next_index].used_) {
    EraseFreeBlock(next_index);
    block->size_ += blocks_[next_index].size_;
    auto next_next_index = blocks_[next_index].next_index_;
    if (next_next_index != kInvalidIndex) {
      blocks_[next_next_index].pre_index_ = block->index_;
    }
    block->next_index_ = next_next_index;
    block->used_ = false;
    block->ref_count_ = 0;
    free_blocks_.emplace(block->size_, block->index_);
    AddGarbageBlock(next_index);
  }
  // Merge with the previous block if it is free.
  auto pre_index = block->pre_index_;
  if (pre_index != kInvalidIndex && !blocks_[pre_index].used_) {
    EraseFreeBlock(pre_index);
    if (!block->used_) {
      EraseFreeBlock(index);
    }
    blocks_[pre_index].size_ += block->size_;
    next_index = block->next_index_;
    blocks_[pre_index].next_index_ = next_index;
    if (next_index != kInvalidIndex) {
      blocks_[next_index].pre_index_ = pre_index;
    }
    block->used_ = false;
    block->ref_count_ = 0;
    free_blocks_.emplace(blocks_[pre_index].size_, pre_index);
    AddGarbageBlock(index);
  }
  // No merge happened: just mark the block free.
  if (block->used_) {
    block->used_ = false;
    block->ref_count_ = 0;
    free_blocks_.emplace(block->size_, block->index_);
  }
}

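// Remove a specific block from the size-keyed free-list multimap: walk the
// equal_range for its size until the matching index is found.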
void MemOperator::EraseFreeBlock(const int64_t index) {
  auto range = free_blocks_.equal_range(blocks_[index].size_);
  for (auto item = range.first; item != range.second; ++item) {
    if (item->second == index) {
      free_blocks_.erase(item);
      break;
    }
  }
}

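// Bind the pool to a NUMA node when one is requested and available, then
// pre-allocate the first allocation unit as a single free block.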
MemOperator::MemOperator(int node_id) {
  numa_instance_ = NUMAAdapter::GetInstance();
  if (node_id >= 0 && numa_instance_->Available()) {
    node_id_ = node_id;
  }

  blocks_.resize(kBlockSize);
  garbage_block_ = kInvalidIndex;
  auto *block = GetBlock();
  size_t allocate_size;
  block->data_ = Allocate(kAllocUnitSize, node_id, &allocate_size);
  if (MS_UNLIKELY(block->data_ == nullptr)) {
    return;
  }
  all_datas_.emplace(block->data_, allocate_size);
  block->size_ = allocate_size;
  free_blocks_.emplace(allocate_size, block->index_);
}

MemOperator::~MemOperator() {
  MS_LOG(DEBUG) << "~MemOperator() begin.";
  for (auto &&data : all_datas_) {
#ifdef _WIN32
    _aligned_free(data.first);
#else
    if (node_id_ >= 0) {
      numa_instance_->Free(data.first, data.second);
    } else {
      free(data.first);
    }
#endif
  }
  free_blocks_.clear();
  all_datas_.clear();
  blocks_.clear();
  MS_LOG(DEBUG) << "~MemOperator() end.";
}

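// Reference-count helpers: each looks up the block that owns ptr and returns
// the resulting count, or kInvalidRefCount if ptr is not managed by this pool.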
int MemOperator::SetRefCount(void *ptr, int ref_count) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    auto index = iter->second;
    blocks_[index].ref_count_ = ref_count;
    return ref_count;
  }
  return kInvalidRefCount;
}

int MemOperator::IncRefCount(void *ptr, int ref_count) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    auto index = iter->second;
    blocks_[index].ref_count_ += ref_count;
    return static_cast<int>(blocks_[index].ref_count_);
  }
  return kInvalidRefCount;
}

int MemOperator::DecRefCount(void *ptr, int ref_count) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    auto index = iter->second;
    blocks_[index].ref_count_ -= ref_count;
    return static_cast<int>(blocks_[index].ref_count_);
  }
  return kInvalidRefCount;
}

int MemOperator::RefCount(void *ptr) {
  std::lock_guard<std::mutex> locker(mutex_);
  auto iter = datas_.find(ptr);
  if (iter != datas_.end()) {
    return static_cast<int>(blocks_[iter->second].ref_count_);
  }
  return kInvalidRefCount;
}

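// Return the MemOperator for a NUMA node, creating and caching it on first
// use; all negative node ids share the kInvalidNodeId (non-NUMA) pool.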
std::shared_ptr<MemOperator> DynamicMemManager::GetMemOperator(const int node_id) {
  int numa_node_id = node_id;
  if (numa_node_id < 0) {
    numa_node_id = kInvalidNodeId;
  }

  std::lock_guard<std::mutex> locker(mutex_);
  std::shared_ptr<MemOperator> mem_oper = nullptr;
  auto iter = nodes_mem_.find(numa_node_id);
  if (iter == nodes_mem_.end()) {
    mem_oper = std::make_shared<MemOperator>(numa_node_id);
    if (MS_UNLIKELY(mem_oper == nullptr)) {
      MS_LOG(ERROR) << "make_shared MemOperator failed!";
      return nullptr;
    }
    nodes_mem_.insert({numa_node_id, mem_oper});
  } else {
    mem_oper = iter->second;
  }
  return mem_oper;
}
}  // namespace mindspore
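
// Usage sketch (illustrative only; assumes the class declarations in
// src/extendrt/dynamic_mem_manager.h and is not part of this translation unit):
//
//   mindspore::DynamicMemManager manager;
//   auto mem_oper = manager.GetMemOperator(/*node_id=*/-1);  // non-NUMA pool
//   if (mem_oper != nullptr) {
//     void *buf = mem_oper->Malloc(1024);  // rounded up to a 64-byte multiple
//     mem_oper->SetRefCount(buf, 1);
//     mem_oper->Free(buf);                 // coalesced back into the pool
//   }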