1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/auto_mem_offload.h"
18 #include <vector>
19 #include "runtime/hardware/device_context.h"
20 #include "runtime/device/memory_offload_strategy.h"
21
22 namespace mindspore {
23 namespace device {
MallocHost(size_t mem_size)24 void *OffloadedMemPool::MallocHost(size_t mem_size) {
25 auto &mem_que = cached_host_mem_[mem_size];
26 if (!mem_que.empty()) {
27 auto ret = mem_que.front();
28 mem_que.pop();
29 return ret;
30 }
31 auto block = std::make_shared<std::vector<uint8_t>>();
32 try {
33 block->resize(mem_size, 0);
34 auto ptr = block->data();
35 host_mem_block_map_[ptr] = block;
36 return ptr;
37 } catch (const std::exception &e) {
38 MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
39 }
40 }
41
FreeHost(void * ptr)42 void OffloadedMemPool::FreeHost(void *ptr) {
43 MS_EXCEPTION_IF_NULL(ptr);
44 auto iter = host_mem_block_map_.find(ptr);
45 if (iter == host_mem_block_map_.end()) {
46 MS_LOG(DEBUG) << "Free ptr not be created from here, abort";
47 return;
48 }
49 MS_EXCEPTION_IF_NULL(iter->second);
50 auto mem_size = iter->second->size();
51 (void)cached_host_mem_[mem_size].emplace(iter->first);
52 }
53
SetInitHostPtr(const void * key,void * host_ptr,size_t mem_size)54 void AutoMemoryOffload::SetInitHostPtr(const void *key, void *host_ptr, size_t mem_size) {
55 (void)init_from_host_keys_.insert(key);
56 init_host_ptr_[key] = host_ptr;
57 mem_size_[key] = mem_size;
58 }
59
Free(const void * key)60 void AutoMemoryOffload::Free(const void *key) {
61 const auto &iter = mem_result_.find(key);
62 if (iter == mem_result_.end()) {
63 return;
64 }
65 auto ptr = iter->second;
66 MS_EXCEPTION_IF_NULL(mem_handler_);
67 mem_handler_->FreeDevice(ptr);
68 (void)mem_result_.erase(key);
69 }
70
// Return the device pointer for `key`, swapping its data in from a host copy
// when it is not resident. Returns nullptr when there is no device copy and
// either no stream is available for the swap-in, no host copy exists, or
// device memory cannot be obtained (even after offloading non-pinned keys).
void *AutoMemoryOffload::Get(const void *key, void *stream, const HashSet<const void *> &pinned_memory) {
  auto iter = mem_result_.find(key);
  if (iter != mem_result_.end()) {
    // Already resident on device.
    return iter->second;
  }
  // Without a stream no copy can be issued, so report a miss.
  if (stream == nullptr) {
    return nullptr;
  }
  void *host_ptr = nullptr;
  bool from_init = false;
  GetHostPtr(key, &host_ptr, &from_init);
  if (host_ptr == nullptr) {
    // No host copy either: nothing to swap in.
    return nullptr;
  }
  const auto mem_size = GetMemSize(key);
  // Malloc may itself offload other non-pinned keys to make room.
  auto device_ptr = Malloc(key, mem_size, stream, pinned_memory);
  if (device_ptr == nullptr) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(mem_handler_);
  mem_handler_->SwapIn(host_ptr, device_ptr, mem_size, stream);
  if (!from_init) {
    // Host copy was a temporary swap buffer: drop the record and return the
    // buffer to the handler. Init buffers are caller-owned and kept.
    (void)swap_host_ptr_.erase(key);
    mem_handler_->FreeHost(host_ptr);
  }
  mem_result_[key] = device_ptr;
  return device_ptr;
}
99
MallocContinuous(const std::vector<const void * > & keys,const std::vector<size_t> & size_list,void * stream,const HashSet<const void * > & pinned_memory)100 bool AutoMemoryOffload::MallocContinuous(const std::vector<const void *> &keys, const std::vector<size_t> &size_list,
101 void *stream, const HashSet<const void *> &pinned_memory) {
102 MS_EXCEPTION_IF_NULL(mem_handler_);
103 const size_t total_size = std::accumulate(size_list.begin(), size_list.end(), static_cast<size_t>(0));
104 using MallocInfo = std::pair<const std::vector<const void *> &, const std::vector<size_t> &>;
105 std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &mem_handler,
106 HashMap<const void *, void *> *, HashMap<const void *, size_t> *)>
107 malloc_func = [](const MallocInfo &info, const std::shared_ptr<MemHandler> &mem_handler,
108 HashMap<const void *, void *> *mem_result, HashMap<const void *, size_t> *mem_size) {
109 const auto keys = info.first;
110 const auto size_list = info.second;
111 auto device_ptr = mem_handler->MallocContinuousMemFromMemPool(size_list);
112 if (device_ptr.size() != keys.size()) {
113 return false;
114 }
115 for (size_t i = 0; i < device_ptr.size(); i += 1) {
116 (*mem_result)[keys[i]] = device_ptr[i];
117 (*mem_size)[keys[i]] = size_list[i];
118 }
119 return true;
120 };
121 if (!TryAllocMemory<MallocInfo>(std::make_pair(keys, size_list), total_size, stream, pinned_memory, malloc_func)) {
122 return false;
123 }
124 for (auto key : keys) {
125 (void)continuous_mem_key_.insert(key);
126 }
127 return true;
128 }
129
Malloc(const void * key,size_t mem_size,void * stream,const HashSet<const void * > & pinned_memory)130 void *AutoMemoryOffload::Malloc(const void *key, size_t mem_size, void *stream,
131 const HashSet<const void *> &pinned_memory) {
132 auto iter = mem_result_.find(key);
133 if (iter != mem_result_.end()) {
134 return iter->second;
135 }
136
137 using MallocInfo = std::pair<const void *, size_t>;
138 std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &mem_handler,
139 HashMap<const void *, void *> *, HashMap<const void *, size_t> *)>
140 malloc_func = [](const MallocInfo &info, const std::shared_ptr<MemHandler> &mem_handler,
141 HashMap<const void *, void *> *mem_result, HashMap<const void *, size_t> *mem_size) {
142 MS_EXCEPTION_IF_NULL(mem_handler);
143 const auto key = info.first;
144 const auto size = info.second;
145 auto device_ptr = mem_handler->MallocDevice(size);
146 if (device_ptr == nullptr) {
147 return false;
148 }
149 (*mem_result)[key] = device_ptr;
150 (*mem_size)[key] = size;
151 return true;
152 };
153 return TryAllocMemory<MallocInfo>(std::make_pair(key, mem_size), mem_size, stream, pinned_memory, malloc_func)
154 ? mem_result_[key]
155 : nullptr;
156 }
157
158 template <typename MallocInfo>
TryAllocMemory(const MallocInfo & info,size_t total_size,void * stream,const HashSet<const void * > & pinned_memory,const std::function<bool (const MallocInfo &,const std::shared_ptr<MemHandler> &,HashMap<const void *,void * > *,HashMap<const void *,size_t> *)> & alloc_func)159 bool AutoMemoryOffload::TryAllocMemory(
160 const MallocInfo &info, size_t total_size, void *stream, const HashSet<const void *> &pinned_memory,
161 const std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &, HashMap<const void *, void *> *,
162 HashMap<const void *, size_t> *)> &alloc_func) {
163 if (alloc_func(info, mem_handler_, &mem_result_, &mem_size_)) {
164 return true;
165 }
166 if (stream == nullptr) {
167 return false;
168 }
169 using KeySizePair = std::pair<const void *, size_t>;
170 auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
171 std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_offload(less);
172 for (const auto &i : mem_result_) {
173 const auto offload_key = i.first;
174 if (pinned_memory.count(offload_key) != 0) {
175 continue;
176 }
177 const auto device_mem_size = GetMemSize(offload_key);
178 if (device_mem_size >= total_size) {
179 SwapOut(offload_key, stream);
180 Free(offload_key);
181 if (alloc_func(info, mem_handler_, &mem_result_, &mem_size_)) {
182 return true;
183 }
184 }
185 mem_can_offload.push({offload_key, device_mem_size});
186 }
187 while (!mem_can_offload.empty()) {
188 const auto &max_mem_in_device = mem_can_offload.top();
189 const auto offload_mem_key = max_mem_in_device.first;
190 auto offload_device_ptr = mem_result_[offload_mem_key];
191 MS_EXCEPTION_IF_NULL(offload_device_ptr);
192 SwapOut(offload_mem_key, stream);
193 Free(offload_mem_key);
194 if (alloc_func(info, mem_handler_, &mem_result_, &mem_size_)) {
195 return true;
196 }
197 mem_can_offload.pop();
198 }
199 return false;
200 }
201
SwapOut(const void * key,void * stream)202 void AutoMemoryOffload::SwapOut(const void *key, void *stream) {
203 const auto iter = mem_result_.find(key);
204 void *host_ptr = nullptr;
205 bool from_init = false;
206 const auto mem_size = GetMemSize(key);
207 if (iter == mem_result_.end()) {
208 GetHostPtr(key, &host_ptr, &from_init);
209 if (host_ptr == nullptr) {
210 MS_LOG(EXCEPTION) << "Can not find device ptr for key " << key;
211 }
212 return;
213 }
214 const auto device_ptr = iter->second;
215 GetOrMallocHostPtr(key, mem_size, &host_ptr, &from_init);
216 MS_EXCEPTION_IF_NULL(host_ptr);
217 auto updated_iter = from_init ? updated_device_mem_.find(key) : updated_device_mem_.end();
218 if (!from_init || updated_iter != updated_device_mem_.end()) {
219 mem_handler_->SwapOut(device_ptr, host_ptr, mem_size, stream);
220 if (updated_iter != updated_device_mem_.end()) {
221 (void)updated_device_mem_.erase(updated_iter);
222 }
223 }
224 }
225
SwapIn(const void * key,void * stream)226 void *AutoMemoryOffload::SwapIn(const void *key, void *stream) {
227 MS_EXCEPTION_IF_NULL(mem_handler_);
228 const size_t mem_size = GetMemSize(key);
229 const auto &iter = mem_result_.find(key);
230 if (iter == mem_result_.end()) {
231 MS_LOG(EXCEPTION) << "Can not find device ptr for key " << key;
232 }
233 bool from_init = true;
234 void *host_ptr = nullptr;
235 GetHostPtr(key, &host_ptr, &from_init);
236 MS_EXCEPTION_IF_NULL(host_ptr);
237 mem_handler_->SwapIn(host_ptr, iter->second, mem_size, stream);
238 if (!from_init) {
239 mem_handler_->FreeHost(host_ptr);
240 (void)swap_host_ptr_.erase(key);
241 }
242 return iter->second;
243 }
244
GetMemSize(const void * key)245 size_t AutoMemoryOffload::GetMemSize(const void *key) {
246 const auto &iter = mem_size_.find(key);
247 if (iter == mem_size_.end()) {
248 MS_LOG(EXCEPTION) << "Can not find memory size for key " << key;
249 }
250 return iter->second;
251 }
252
GetOrMallocHostPtr(const void * key,size_t mem_size,void ** host_ptr,bool * from_init)253 void AutoMemoryOffload::GetOrMallocHostPtr(const void *key, size_t mem_size, void **host_ptr, bool *from_init) {
254 MS_EXCEPTION_IF_NULL(host_ptr);
255 MS_EXCEPTION_IF_NULL(mem_handler_);
256 GetHostPtr(key, host_ptr, from_init);
257 if (*host_ptr != nullptr) {
258 return;
259 }
260 *host_ptr = mem_handler_->MallocHost(mem_size);
261 *from_init = false;
262 swap_host_ptr_[key] = *host_ptr;
263 }
264
// Look up the host copy of `key`.
// On return, *from_init is true when the key was registered via
// SetInitHostPtr (buffer is caller-owned), false when the buffer is a
// temporary swap buffer. NOTE: *host_ptr is left UNCHANGED when no host copy
// is found, so callers must pre-initialize it (all callers in this file pass
// a nullptr-initialized pointer). Raises when an init-registered key has no
// init buffer recorded (internal inconsistency).
void AutoMemoryOffload::GetHostPtr(const void *key, void **host_ptr, bool *from_init) {
  *from_init = init_from_host_keys_.count(key) != 0;
  if (*from_init) {
    const auto iter = init_host_ptr_.find(key);
    if (iter == init_host_ptr_.end()) {
      MS_LOG(EXCEPTION) << "Can not find host ptr for key " << key;
    }
    *host_ptr = iter->second;
  } else {
    auto iter = swap_host_ptr_.find(key);
    if (iter != swap_host_ptr_.end()) {
      *host_ptr = iter->second;
    }
  }
}
280
Clear()281 void AutoMemoryOffload::Clear() {
282 if (mem_handler_ == nullptr) {
283 return;
284 }
285 for (auto &item : mem_result_) {
286 mem_handler_->FreeDevice(item.second);
287 }
288 mem_result_.clear();
289 for (const auto &item : swap_host_ptr_) {
290 const auto host_ptr = item.second;
291 if (host_ptr != nullptr) {
292 mem_handler_->FreeHost(host_ptr);
293 }
294 }
295 swap_host_ptr_.clear();
296 init_host_ptr_.clear();
297 init_from_host_keys_.clear();
298 }
299
UpdateHighPriorityMem(const void * key)300 void AutoMemoryOffload::UpdateHighPriorityMem(const void *key) { (void)updated_device_mem_.insert(key); }
301
Malloc(DeviceAddress * device_address)302 bool MindRTAutoOffloadAdapter::Malloc(DeviceAddress *device_address) {
303 if (device_address->GetPtr() != nullptr) {
304 return true;
305 }
306 const auto original_size = device_address->GetSize();
307 constexpr size_t kAlignBytes = 32;
308 const size_t align_size = ((original_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize) * kMemAlignSize;
309 const auto &pinned_mem = MemoryOffloadConflict::GetInstance().GetConflictMap(device_address);
310 const auto &device_ptr = Malloc(align_size, pinned_mem);
311 if (device_ptr == nullptr) {
312 return false;
313 }
314 device_address->set_ptr(device_ptr);
315 device_address->set_from_mem_pool(true);
316 std::unique_lock<std::shared_mutex> unq_lock(all_mem_mutex_);
317 (void)all_mem_.insert(device_address);
318 return true;
319 }
320
Malloc(size_t size,const HashSet<const void * > & pinned_mem)321 void *MindRTAutoOffloadAdapter::Malloc(size_t size, const HashSet<const void *> &pinned_mem) {
322 const auto malloc_func = [](size_t size, DynamicMemPoolBestFit *mem_pool, void **device_ptr) {
323 *device_ptr = mem_pool->AllocTensorMem(size);
324 return *device_ptr != nullptr;
325 };
326 void *device_ptr = nullptr;
327 return TryAllocMemory<size_t, void *>(size, size, pinned_mem, malloc_func, &device_ptr) ? device_ptr : nullptr;
328 }
329
MallocContinuousMem(const std::vector<size_t> & size_list)330 std::vector<void *> MindRTAutoOffloadAdapter::MallocContinuousMem(const std::vector<size_t> &size_list) {
331 const auto malloc_func = [](const std::vector<size_t> &size_list, DynamicMemPoolBestFit *mem_pool,
332 std::vector<void *> *ptr_list) {
333 *ptr_list = std::move(mem_pool->AllocContinuousTensorMem(size_list));
334 return !ptr_list->empty();
335 };
336 size_t total_size = std::accumulate(size_list.cbegin(), size_list.cend(), size_t(0));
337 std::vector<void *> ptr_list;
338 if (!TryAllocMemory<const std::vector<size_t> &, std::vector<void *>>(size_list, total_size, {}, malloc_func,
339 &ptr_list)) {
340 return ptr_list;
341 }
342 if (ptr_list.size() != size_list.size()) {
343 MS_LOG(EXCEPTION) << "Size of ptr list[" << ptr_list.size() << "] and size list[" << size_list.size()
344 << "] should be same.";
345 }
346 return ptr_list;
347 }
348
// Run `alloc_func`; on failure, offload tracked device addresses (any single
// one large enough first, then the rest in descending size) and retry after
// each. Returns true as soon as an attempt succeeds.
template <typename MallocInfo, typename ReturnType>
bool MindRTAutoOffloadAdapter::TryAllocMemory(
  const MallocInfo &info, size_t total_size, const HashSet<const void *> &pinned_mem,
  const std::function<bool(const MallocInfo &, DynamicMemPoolBestFit *, ReturnType *)> &alloc_func, ReturnType *ret) {
  // Fast path: try with the memory currently available.
  if (alloc_func(info, mem_pool_, ret)) {
    return true;
  }
  using KeySizePair = std::pair<DeviceAddress *, size_t>;
  auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
  std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_offload(less);
  {
    // Shared lock: iteration only reads the all_mem_ set itself.
    // NOTE(review): SwapOut and alloc_func are invoked while this shared lock
    // is held — confirm they never take all_mem_mutex_ exclusively (deadlock)
    // and that mutating DeviceAddress state under a shared lock is intended.
    std::shared_lock<std::shared_mutex> shd_lock(all_mem_mutex_);
    for (const auto &mem : all_mem_) {
      // Skip addresses that cannot or need not be offloaded: conflict-locked,
      // already offloaded, never allocated, or pinned by the current request.
      if (!MemoryOffloadConflict::GetInstance().CanBeOffloaded(mem) || mem->mem_offloaded() ||
          mem->GetPtr() == nullptr || pinned_mem.count(mem) != 0) {
        continue;
      }
      const auto device_mem_size = mem->GetSize();
      if (device_mem_size >= total_size) {
        // A single address is large enough: offloading it alone may suffice.
        SwapOut(mem);

        if (alloc_func(info, mem_pool_, ret)) {
          return true;
        }
      } else {
        mem_can_offload.push({mem, device_mem_size});
      }
    }
  }
  // Offload remaining candidates largest-first until the allocation fits.
  // NOTE(review): queued DeviceAddress pointers are used after the shared
  // lock above is released — verify nothing erases/frees all_mem_ entries
  // concurrently, or these pointers could dangle.
  while (!mem_can_offload.empty()) {
    const auto &max_mem_in_device = mem_can_offload.top();
    const auto offload_mem = max_mem_in_device.first;
    SwapOut(offload_mem);

    if (alloc_func(info, mem_pool_, ret)) {
      return true;
    }
    mem_can_offload.pop();
  }
  return false;
}
390
SwapOut(DeviceAddress * device_address)391 void MindRTAutoOffloadAdapter::SwapOut(DeviceAddress *device_address) {
392 if (device_address->mem_offloaded()) {
393 return;
394 }
395 if (!device_address->Offload(stream_id_)) {
396 MS_LOG(EXCEPTION) << "Offload failed, size: " << device_address->GetSize() << ", stream id: " << stream_id_;
397 }
398 }
399 } // namespace device
400 } // namespace mindspore
401