• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/auto_mem_offload.h"
18 #include <vector>
19 #include "runtime/hardware/device_context.h"
20 #include "runtime/device/memory_offload_strategy.h"
21 
22 namespace mindspore {
23 namespace device {
MallocHost(size_t mem_size)24 void *OffloadedMemPool::MallocHost(size_t mem_size) {
25   auto &mem_que = cached_host_mem_[mem_size];
26   if (!mem_que.empty()) {
27     auto ret = mem_que.front();
28     mem_que.pop();
29     return ret;
30   }
31   auto block = std::make_shared<std::vector<uint8_t>>();
32   try {
33     block->resize(mem_size, 0);
34     auto ptr = block->data();
35     host_mem_block_map_[ptr] = block;
36     return ptr;
37   } catch (const std::exception &e) {
38     MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
39   }
40 }
41 
FreeHost(void * ptr)42 void OffloadedMemPool::FreeHost(void *ptr) {
43   MS_EXCEPTION_IF_NULL(ptr);
44   auto iter = host_mem_block_map_.find(ptr);
45   if (iter == host_mem_block_map_.end()) {
46     MS_LOG(DEBUG) << "Free ptr not be created from here, abort";
47     return;
48   }
49   MS_EXCEPTION_IF_NULL(iter->second);
50   auto mem_size = iter->second->size();
51   (void)cached_host_mem_[mem_size].emplace(iter->first);
52 }
53 
SetInitHostPtr(const void * key,void * host_ptr,size_t mem_size)54 void AutoMemoryOffload::SetInitHostPtr(const void *key, void *host_ptr, size_t mem_size) {
55   (void)init_from_host_keys_.insert(key);
56   init_host_ptr_[key] = host_ptr;
57   mem_size_[key] = mem_size;
58 }
59 
Free(const void * key)60 void AutoMemoryOffload::Free(const void *key) {
61   const auto &iter = mem_result_.find(key);
62   if (iter == mem_result_.end()) {
63     return;
64   }
65   auto ptr = iter->second;
66   MS_EXCEPTION_IF_NULL(mem_handler_);
67   mem_handler_->FreeDevice(ptr);
68   (void)mem_result_.erase(key);
69 }
70 
Get(const void * key,void * stream,const HashSet<const void * > & pinned_memory)71 void *AutoMemoryOffload::Get(const void *key, void *stream, const HashSet<const void *> &pinned_memory) {
72   auto iter = mem_result_.find(key);
73   if (iter != mem_result_.end()) {
74     return iter->second;
75   }
76   if (stream == nullptr) {
77     return nullptr;
78   }
79   void *host_ptr = nullptr;
80   bool from_init = false;
81   GetHostPtr(key, &host_ptr, &from_init);
82   if (host_ptr == nullptr) {
83     return nullptr;
84   }
85   const auto mem_size = GetMemSize(key);
86   auto device_ptr = Malloc(key, mem_size, stream, pinned_memory);
87   if (device_ptr == nullptr) {
88     return nullptr;
89   }
90   MS_EXCEPTION_IF_NULL(mem_handler_);
91   mem_handler_->SwapIn(host_ptr, device_ptr, mem_size, stream);
92   if (!from_init) {
93     (void)swap_host_ptr_.erase(key);
94     mem_handler_->FreeHost(host_ptr);
95   }
96   mem_result_[key] = device_ptr;
97   return device_ptr;
98 }
99 
MallocContinuous(const std::vector<const void * > & keys,const std::vector<size_t> & size_list,void * stream,const HashSet<const void * > & pinned_memory)100 bool AutoMemoryOffload::MallocContinuous(const std::vector<const void *> &keys, const std::vector<size_t> &size_list,
101                                          void *stream, const HashSet<const void *> &pinned_memory) {
102   MS_EXCEPTION_IF_NULL(mem_handler_);
103   const size_t total_size = std::accumulate(size_list.begin(), size_list.end(), static_cast<size_t>(0));
104   using MallocInfo = std::pair<const std::vector<const void *> &, const std::vector<size_t> &>;
105   std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &mem_handler,
106                      HashMap<const void *, void *> *, HashMap<const void *, size_t> *)>
107     malloc_func = [](const MallocInfo &info, const std::shared_ptr<MemHandler> &mem_handler,
108                      HashMap<const void *, void *> *mem_result, HashMap<const void *, size_t> *mem_size) {
109       const auto keys = info.first;
110       const auto size_list = info.second;
111       auto device_ptr = mem_handler->MallocContinuousMemFromMemPool(size_list);
112       if (device_ptr.size() != keys.size()) {
113         return false;
114       }
115       for (size_t i = 0; i < device_ptr.size(); i += 1) {
116         (*mem_result)[keys[i]] = device_ptr[i];
117         (*mem_size)[keys[i]] = size_list[i];
118       }
119       return true;
120     };
121   if (!TryAllocMemory<MallocInfo>(std::make_pair(keys, size_list), total_size, stream, pinned_memory, malloc_func)) {
122     return false;
123   }
124   for (auto key : keys) {
125     (void)continuous_mem_key_.insert(key);
126   }
127   return true;
128 }
129 
Malloc(const void * key,size_t mem_size,void * stream,const HashSet<const void * > & pinned_memory)130 void *AutoMemoryOffload::Malloc(const void *key, size_t mem_size, void *stream,
131                                 const HashSet<const void *> &pinned_memory) {
132   auto iter = mem_result_.find(key);
133   if (iter != mem_result_.end()) {
134     return iter->second;
135   }
136 
137   using MallocInfo = std::pair<const void *, size_t>;
138   std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &mem_handler,
139                      HashMap<const void *, void *> *, HashMap<const void *, size_t> *)>
140     malloc_func = [](const MallocInfo &info, const std::shared_ptr<MemHandler> &mem_handler,
141                      HashMap<const void *, void *> *mem_result, HashMap<const void *, size_t> *mem_size) {
142       MS_EXCEPTION_IF_NULL(mem_handler);
143       const auto key = info.first;
144       const auto size = info.second;
145       auto device_ptr = mem_handler->MallocDevice(size);
146       if (device_ptr == nullptr) {
147         return false;
148       }
149       (*mem_result)[key] = device_ptr;
150       (*mem_size)[key] = size;
151       return true;
152     };
153   return TryAllocMemory<MallocInfo>(std::make_pair(key, mem_size), mem_size, stream, pinned_memory, malloc_func)
154            ? mem_result_[key]
155            : nullptr;
156 }
157 
158 template <typename MallocInfo>
TryAllocMemory(const MallocInfo & info,size_t total_size,void * stream,const HashSet<const void * > & pinned_memory,const std::function<bool (const MallocInfo &,const std::shared_ptr<MemHandler> &,HashMap<const void *,void * > *,HashMap<const void *,size_t> *)> & alloc_func)159 bool AutoMemoryOffload::TryAllocMemory(
160   const MallocInfo &info, size_t total_size, void *stream, const HashSet<const void *> &pinned_memory,
161   const std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &, HashMap<const void *, void *> *,
162                            HashMap<const void *, size_t> *)> &alloc_func) {
163   if (alloc_func(info, mem_handler_, &mem_result_, &mem_size_)) {
164     return true;
165   }
166   if (stream == nullptr) {
167     return false;
168   }
169   using KeySizePair = std::pair<const void *, size_t>;
170   auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
171   std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_offload(less);
172   for (const auto &i : mem_result_) {
173     const auto offload_key = i.first;
174     if (pinned_memory.count(offload_key) != 0) {
175       continue;
176     }
177     const auto device_mem_size = GetMemSize(offload_key);
178     if (device_mem_size >= total_size) {
179       SwapOut(offload_key, stream);
180       Free(offload_key);
181       if (alloc_func(info, mem_handler_, &mem_result_, &mem_size_)) {
182         return true;
183       }
184     }
185     mem_can_offload.push({offload_key, device_mem_size});
186   }
187   while (!mem_can_offload.empty()) {
188     const auto &max_mem_in_device = mem_can_offload.top();
189     const auto offload_mem_key = max_mem_in_device.first;
190     auto offload_device_ptr = mem_result_[offload_mem_key];
191     MS_EXCEPTION_IF_NULL(offload_device_ptr);
192     SwapOut(offload_mem_key, stream);
193     Free(offload_mem_key);
194     if (alloc_func(info, mem_handler_, &mem_result_, &mem_size_)) {
195       return true;
196     }
197     mem_can_offload.pop();
198   }
199   return false;
200 }
201 
SwapOut(const void * key,void * stream)202 void AutoMemoryOffload::SwapOut(const void *key, void *stream) {
203   const auto iter = mem_result_.find(key);
204   void *host_ptr = nullptr;
205   bool from_init = false;
206   const auto mem_size = GetMemSize(key);
207   if (iter == mem_result_.end()) {
208     GetHostPtr(key, &host_ptr, &from_init);
209     if (host_ptr == nullptr) {
210       MS_LOG(EXCEPTION) << "Can not find device ptr for key " << key;
211     }
212     return;
213   }
214   const auto device_ptr = iter->second;
215   GetOrMallocHostPtr(key, mem_size, &host_ptr, &from_init);
216   MS_EXCEPTION_IF_NULL(host_ptr);
217   auto updated_iter = from_init ? updated_device_mem_.find(key) : updated_device_mem_.end();
218   if (!from_init || updated_iter != updated_device_mem_.end()) {
219     mem_handler_->SwapOut(device_ptr, host_ptr, mem_size, stream);
220     if (updated_iter != updated_device_mem_.end()) {
221       (void)updated_device_mem_.erase(updated_iter);
222     }
223   }
224 }
225 
SwapIn(const void * key,void * stream)226 void *AutoMemoryOffload::SwapIn(const void *key, void *stream) {
227   MS_EXCEPTION_IF_NULL(mem_handler_);
228   const size_t mem_size = GetMemSize(key);
229   const auto &iter = mem_result_.find(key);
230   if (iter == mem_result_.end()) {
231     MS_LOG(EXCEPTION) << "Can not find device ptr for key " << key;
232   }
233   bool from_init = true;
234   void *host_ptr = nullptr;
235   GetHostPtr(key, &host_ptr, &from_init);
236   MS_EXCEPTION_IF_NULL(host_ptr);
237   mem_handler_->SwapIn(host_ptr, iter->second, mem_size, stream);
238   if (!from_init) {
239     mem_handler_->FreeHost(host_ptr);
240     (void)swap_host_ptr_.erase(key);
241   }
242   return iter->second;
243 }
244 
GetMemSize(const void * key)245 size_t AutoMemoryOffload::GetMemSize(const void *key) {
246   const auto &iter = mem_size_.find(key);
247   if (iter == mem_size_.end()) {
248     MS_LOG(EXCEPTION) << "Can not find memory size for key " << key;
249   }
250   return iter->second;
251 }
252 
GetOrMallocHostPtr(const void * key,size_t mem_size,void ** host_ptr,bool * from_init)253 void AutoMemoryOffload::GetOrMallocHostPtr(const void *key, size_t mem_size, void **host_ptr, bool *from_init) {
254   MS_EXCEPTION_IF_NULL(host_ptr);
255   MS_EXCEPTION_IF_NULL(mem_handler_);
256   GetHostPtr(key, host_ptr, from_init);
257   if (*host_ptr != nullptr) {
258     return;
259   }
260   *host_ptr = mem_handler_->MallocHost(mem_size);
261   *from_init = false;
262   swap_host_ptr_[key] = *host_ptr;
263 }
264 
GetHostPtr(const void * key,void ** host_ptr,bool * from_init)265 void AutoMemoryOffload::GetHostPtr(const void *key, void **host_ptr, bool *from_init) {
266   *from_init = init_from_host_keys_.count(key) != 0;
267   if (*from_init) {
268     const auto iter = init_host_ptr_.find(key);
269     if (iter == init_host_ptr_.end()) {
270       MS_LOG(EXCEPTION) << "Can not find host ptr for key " << key;
271     }
272     *host_ptr = iter->second;
273   } else {
274     auto iter = swap_host_ptr_.find(key);
275     if (iter != swap_host_ptr_.end()) {
276       *host_ptr = iter->second;
277     }
278   }
279 }
280 
Clear()281 void AutoMemoryOffload::Clear() {
282   if (mem_handler_ == nullptr) {
283     return;
284   }
285   for (auto &item : mem_result_) {
286     mem_handler_->FreeDevice(item.second);
287   }
288   mem_result_.clear();
289   for (const auto &item : swap_host_ptr_) {
290     const auto host_ptr = item.second;
291     if (host_ptr != nullptr) {
292       mem_handler_->FreeHost(host_ptr);
293     }
294   }
295   swap_host_ptr_.clear();
296   init_host_ptr_.clear();
297   init_from_host_keys_.clear();
298 }
299 
UpdateHighPriorityMem(const void * key)300 void AutoMemoryOffload::UpdateHighPriorityMem(const void *key) { (void)updated_device_mem_.insert(key); }
301 
Malloc(DeviceAddress * device_address)302 bool MindRTAutoOffloadAdapter::Malloc(DeviceAddress *device_address) {
303   if (device_address->GetPtr() != nullptr) {
304     return true;
305   }
306   const auto original_size = device_address->GetSize();
307   constexpr size_t kAlignBytes = 32;
308   const size_t align_size = ((original_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize) * kMemAlignSize;
309   const auto &pinned_mem = MemoryOffloadConflict::GetInstance().GetConflictMap(device_address);
310   const auto &device_ptr = Malloc(align_size, pinned_mem);
311   if (device_ptr == nullptr) {
312     return false;
313   }
314   device_address->set_ptr(device_ptr);
315   device_address->set_from_mem_pool(true);
316   std::unique_lock<std::shared_mutex> unq_lock(all_mem_mutex_);
317   (void)all_mem_.insert(device_address);
318   return true;
319 }
320 
Malloc(size_t size,const HashSet<const void * > & pinned_mem)321 void *MindRTAutoOffloadAdapter::Malloc(size_t size, const HashSet<const void *> &pinned_mem) {
322   const auto malloc_func = [](size_t size, DynamicMemPoolBestFit *mem_pool, void **device_ptr) {
323     *device_ptr = mem_pool->AllocTensorMem(size);
324     return *device_ptr != nullptr;
325   };
326   void *device_ptr = nullptr;
327   return TryAllocMemory<size_t, void *>(size, size, pinned_mem, malloc_func, &device_ptr) ? device_ptr : nullptr;
328 }
329 
MallocContinuousMem(const std::vector<size_t> & size_list)330 std::vector<void *> MindRTAutoOffloadAdapter::MallocContinuousMem(const std::vector<size_t> &size_list) {
331   const auto malloc_func = [](const std::vector<size_t> &size_list, DynamicMemPoolBestFit *mem_pool,
332                               std::vector<void *> *ptr_list) {
333     *ptr_list = std::move(mem_pool->AllocContinuousTensorMem(size_list));
334     return !ptr_list->empty();
335   };
336   size_t total_size = std::accumulate(size_list.cbegin(), size_list.cend(), size_t(0));
337   std::vector<void *> ptr_list;
338   if (!TryAllocMemory<const std::vector<size_t> &, std::vector<void *>>(size_list, total_size, {}, malloc_func,
339                                                                         &ptr_list)) {
340     return ptr_list;
341   }
342   if (ptr_list.size() != size_list.size()) {
343     MS_LOG(EXCEPTION) << "Size of ptr list[" << ptr_list.size() << "] and size list[" << size_list.size()
344                       << "] should be same.";
345   }
346   return ptr_list;
347 }
348 
// Run `alloc_func` against the pool; on failure, offload other device
// addresses and retry. Two phases: (1) under the shared lock, offload any
// single address large enough for the whole request; (2) after the lock is
// released, offload the remaining candidates largest-first. Returns true as
// soon as an attempt succeeds, false when all candidates are exhausted.
// NOTE(review): phase 2 uses DeviceAddress pointers collected under the
// lock after it has been released — presumably safe because offloadable
// addresses stay alive in all_mem_; confirm against the owner's lifetime.
template <typename MallocInfo, typename ReturnType>
bool MindRTAutoOffloadAdapter::TryAllocMemory(
  const MallocInfo &info, size_t total_size, const HashSet<const void *> &pinned_mem,
  const std::function<bool(const MallocInfo &, DynamicMemPoolBestFit *, ReturnType *)> &alloc_func, ReturnType *ret) {
  // Fast path: try the allocation before offloading anything.
  if (alloc_func(info, mem_pool_, ret)) {
    return true;
  }
  using KeySizePair = std::pair<DeviceAddress *, size_t>;
  // Max-heap keyed on size: offload the largest resident blocks first.
  auto less = [](const KeySizePair &a, const KeySizePair &b) -> bool { return a.second < b.second; };
  std::priority_queue<KeySizePair, std::vector<KeySizePair>, decltype(less)> mem_can_offload(less);
  {
    // Shared lock: iterate the registered addresses without blocking
    // concurrent registrations in Malloc(DeviceAddress *).
    std::shared_lock<std::shared_mutex> shd_lock(all_mem_mutex_);
    for (const auto &mem : all_mem_) {
      // Skip addresses that cannot or need not be offloaded: conflicting,
      // already offloaded, not resident, or pinned by the current request.
      if (!MemoryOffloadConflict::GetInstance().CanBeOffloaded(mem) || mem->mem_offloaded() ||
          mem->GetPtr() == nullptr || pinned_mem.count(mem) != 0) {
        continue;
      }
      const auto device_mem_size = mem->GetSize();
      if (device_mem_size >= total_size) {
        // This block alone can cover the request: offload and retry now.
        SwapOut(mem);

        if (alloc_func(info, mem_pool_, ret)) {
          return true;
        }
      } else {
        // Too small on its own: queue it for the cumulative phase below.
        mem_can_offload.push({mem, device_mem_size});
      }
    }
  }
  // Phase 2 (lock released): offload queued blocks largest-first, retrying
  // the allocation after each one.
  while (!mem_can_offload.empty()) {
    const auto &max_mem_in_device = mem_can_offload.top();
    const auto offload_mem = max_mem_in_device.first;
    SwapOut(offload_mem);

    if (alloc_func(info, mem_pool_, ret)) {
      return true;
    }
    mem_can_offload.pop();
  }
  return false;
}
390 
SwapOut(DeviceAddress * device_address)391 void MindRTAutoOffloadAdapter::SwapOut(DeviceAddress *device_address) {
392   if (device_address->mem_offloaded()) {
393     return;
394   }
395   if (!device_address->Offload(stream_id_)) {
396     MS_LOG(EXCEPTION) << "Offload failed, size: " << device_address->GetSize() << ", stream id: " << stream_id_;
397   }
398 }
399 }  // namespace device
400 }  // namespace mindspore
401