• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_AUTO_MEM_OFFLOAD_H_
17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_AUTO_MEM_OFFLOAD_H_
18 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <queue>
#include <shared_mutex>
#include <utility>
#include <vector>
25 
26 #include "runtime/device/memory_manager.h"
27 #include "include/backend/mem_reuse/mem_dynamic_allocator.h"
28 #include "utils/hash_map.h"
29 #include "utils/hash_set.h"
30 
31 namespace mindspore {
32 namespace device {
// Host-side memory pool with a MallocHost/FreeHost pair. Both functions are only declared here; their
// definitions live outside this header.
// NOTE(review): the member names suggest that released host blocks are cached per size and reused -
// confirm against the implementation file.
class OffloadedMemPool {
 public:
  OffloadedMemPool() = default;
  ~OffloadedMemPool() = default;

  // Presumably returns a host pointer able to hold `mem_size` bytes - TODO confirm in the definition.
  void *MallocHost(size_t mem_size);
  // Presumably releases a pointer previously obtained from MallocHost - TODO confirm in the definition.
  void FreeHost(void *ptr);

 private:
  using HostPtrQueue = std::queue<void *>;
  using HostBlock = std::shared_ptr<std::vector<uint8_t>>;

  // Queues of raw host pointers keyed by a size_t (presumably the block size in bytes).
  std::map<size_t, HostPtrQueue> cached_host_mem_;
  // Raw host pointer -> shared byte vector (presumably the storage that backs the pointer).
  std::map<void *, HostBlock> host_mem_block_map_;
};
44 
45 class MemHandler {
46  public:
MemHandler(std::shared_ptr<MemoryManager> memory_manager)47   explicit MemHandler(std::shared_ptr<MemoryManager> memory_manager) : memory_manager_(std::move(memory_manager)) {
48     host_mem_cache_ = std::make_shared<OffloadedMemPool>();
49   }
50   ~MemHandler() = default;
GetAvailableMemSize()51   size_t GetAvailableMemSize() { return memory_manager_->GetAvailableMemSize(); }
MallocDevice(size_t mem_size)52   void *MallocDevice(size_t mem_size) { return memory_manager_->MallocMemFromMemPool(mem_size, false); }
FreeDevice(void * ptr)53   void FreeDevice(void *ptr) { memory_manager_->FreeMemFromMemPool(ptr); }
MallocHost(size_t mem_size)54   void *MallocHost(size_t mem_size) { return host_mem_cache_->MallocHost(mem_size); }
FreeHost(void * ptr)55   void FreeHost(void *ptr) { host_mem_cache_->FreeHost(ptr); }
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)56   void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
57     memory_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream);
58   }
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)59   void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
60     memory_manager_->SwapOut(device_ptr, host_ptr, mem_size, stream);
61   }
MallocContinuousMemFromMemPool(const std::vector<size_t> & size_list)62   std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list) {
63     return memory_manager_->MallocContinuousMemFromMemPool(size_list);
64   }
65 
66  private:
67   std::shared_ptr<OffloadedMemPool> host_mem_cache_;
68   std::shared_ptr<MemoryManager> memory_manager_;
69 };
70 
71 class BACKEND_EXPORT AutoMemoryOffload {
72  public:
AutoMemoryOffload(std::shared_ptr<MemHandler> mem_handler)73   explicit AutoMemoryOffload(std::shared_ptr<MemHandler> mem_handler) : mem_handler_(std::move(mem_handler)) {}
74   ~AutoMemoryOffload() = default;
75   void *Get(const void *key, void *stream = nullptr, const HashSet<const void *> &pinned_memory = {});
76   void *Malloc(const void *key, size_t mem_size, void *stream, const HashSet<const void *> &pinned_memory);
77   bool MallocContinuous(const std::vector<const void *> &keys, const std::vector<size_t> &size_list, void *stream,
78                         const HashSet<const void *> &pinned_memory);
79   void Free(const void *key);
80   void Clear();
81   void SetInitHostPtr(const void *key, void *host_ptr, size_t mem_size);
82   void UpdateHighPriorityMem(const void *key);
83 
84   void SwapOut(const void *key, void *stream);
85   // Return the device ptr where the data is copied to
86   void *SwapIn(const void *key, void *stream);
87 
88  private:
89   size_t GetMemSize(const void *key);
90   void GetHostPtr(const void *key, void **host_ptr, bool *from_init);
91   void GetOrMallocHostPtr(const void *key, size_t mem_size, void **host_ptr, bool *from_init);
92   template <typename MallocInfo>
93   bool TryAllocMemory(
94     const MallocInfo &info, size_t total_size, void *stream, const HashSet<const void *> &pinned_memory,
95     const std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &, HashMap<const void *, void *> *,
96                              HashMap<const void *, size_t> *)> &alloc_func);
97   std::shared_ptr<MemHandler> mem_handler_;
98   HashMap<const void *, void *> mem_result_;
99   HashMap<const void *, size_t> mem_size_;
100   HashSet<const void *> init_from_host_keys_;
101   HashSet<const void *> updated_device_mem_;
102   HashSet<const void *> continuous_mem_key_;
103   HashMap<const void *, void *> init_host_ptr_;
104   HashMap<const void *, void *> swap_host_ptr_;
105 };
106 
107 class BACKEND_EXPORT MindRTAutoOffloadAdapter {
108  public:
MindRTAutoOffloadAdapter(DynamicMemPoolBestFit * mem_pool,size_t stream_id)109   MindRTAutoOffloadAdapter(DynamicMemPoolBestFit *mem_pool, size_t stream_id)
110       : mem_pool_(mem_pool), stream_id_(stream_id) {}
111   ~MindRTAutoOffloadAdapter() = default;
112   bool Malloc(DeviceAddress *key);
113   void *Malloc(size_t size, const HashSet<const void *> &pinned_mem = {});
114   std::vector<void *> MallocContinuousMem(const std::vector<size_t> &size_list);
115 
116  private:
117   // Return the host ptr where the data is copied to
118   void SwapOut(DeviceAddress *device_address);
119   template <typename MallocInfo, typename ReturnType>
120   bool TryAllocMemory(const MallocInfo &info, size_t total_size, const HashSet<const void *> &pinned_mem,
121                       const std::function<bool(const MallocInfo &, DynamicMemPoolBestFit *, ReturnType *)> &alloc_func,
122                       ReturnType *ret);
123 
124   DynamicMemPoolBestFit *mem_pool_;
125   size_t stream_id_;
126   // Read/Write lock for all_mem_ map.
127   std::shared_mutex all_mem_mutex_;
128   HashSet<DeviceAddress *> all_mem_;
129 };
130 }  // namespace device
131 }  // namespace mindspore
132 
133 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_AUTO_MEM_OFFLOAD_H_
134