1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_AUTO_MEM_OFFLOAD_H_ 17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_AUTO_MEM_OFFLOAD_H_ 18 19 #include <utility> 20 #include <queue> 21 #include <map> 22 #include <vector> 23 #include <memory> 24 #include <shared_mutex> 25 26 #include "runtime/device/memory_manager.h" 27 #include "include/backend/mem_reuse/mem_dynamic_allocator.h" 28 #include "utils/hash_map.h" 29 #include "utils/hash_set.h" 30 31 namespace mindspore { 32 namespace device { 33 class OffloadedMemPool { 34 public: 35 OffloadedMemPool() = default; 36 ~OffloadedMemPool() = default; 37 void *MallocHost(size_t mem_size); 38 void FreeHost(void *ptr); 39 40 private: 41 std::map<size_t, std::queue<void *>> cached_host_mem_; 42 std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_; 43 }; 44 45 class MemHandler { 46 public: MemHandler(std::shared_ptr<MemoryManager> memory_manager)47 explicit MemHandler(std::shared_ptr<MemoryManager> memory_manager) : memory_manager_(std::move(memory_manager)) { 48 host_mem_cache_ = std::make_shared<OffloadedMemPool>(); 49 } 50 ~MemHandler() = default; GetAvailableMemSize()51 size_t GetAvailableMemSize() { return memory_manager_->GetAvailableMemSize(); } MallocDevice(size_t mem_size)52 void *MallocDevice(size_t mem_size) { return memory_manager_->MallocMemFromMemPool(mem_size, false); } FreeDevice(void * ptr)53 void FreeDevice(void *ptr) { memory_manager_->FreeMemFromMemPool(ptr); } MallocHost(size_t mem_size)54 void *MallocHost(size_t mem_size) { return host_mem_cache_->MallocHost(mem_size); } FreeHost(void * ptr)55 void FreeHost(void *ptr) { host_mem_cache_->FreeHost(ptr); } SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)56 void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) { 57 memory_manager_->SwapIn(host_ptr, device_ptr, mem_size, stream); 58 } SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)59 void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) { 60 memory_manager_->SwapOut(device_ptr, host_ptr, mem_size, stream); 61 } MallocContinuousMemFromMemPool(const std::vector<size_t> & size_list)62 std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list) { 63 return memory_manager_->MallocContinuousMemFromMemPool(size_list); 64 } 65 66 private: 67 std::shared_ptr<OffloadedMemPool> host_mem_cache_; 68 std::shared_ptr<MemoryManager> memory_manager_; 69 }; 70 71 class BACKEND_EXPORT AutoMemoryOffload { 72 public: AutoMemoryOffload(std::shared_ptr<MemHandler> mem_handler)73 explicit AutoMemoryOffload(std::shared_ptr<MemHandler> mem_handler) : mem_handler_(std::move(mem_handler)) {} 74 ~AutoMemoryOffload() = default; 75 void *Get(const void *key, void *stream = nullptr, const HashSet<const void *> &pinned_memory = {}); 76 void *Malloc(const void *key, size_t mem_size, void *stream, const HashSet<const void *> &pinned_memory); 77 bool MallocContinuous(const std::vector<const void *> &keys, const std::vector<size_t> &size_list, void *stream, 78 const HashSet<const void *> &pinned_memory); 79 void Free(const void *key); 80 void Clear(); 81 void SetInitHostPtr(const void *key, void *host_ptr, size_t mem_size); 82 void UpdateHighPriorityMem(const void *key); 83 84 void SwapOut(const void *key, void *stream); 85 // Return the device ptr where the data is copied to 86 void *SwapIn(const void *key, void *stream); 87 88 private: 89 size_t GetMemSize(const void *key); 90 void GetHostPtr(const void *key, void **host_ptr, bool *from_init); 91 void GetOrMallocHostPtr(const void *key, size_t mem_size, void **host_ptr, bool *from_init); 92 template <typename MallocInfo> 93 bool TryAllocMemory( 94 const MallocInfo &info, size_t total_size, void *stream, const HashSet<const void *> &pinned_memory, 95 const std::function<bool(const MallocInfo &, const std::shared_ptr<MemHandler> &, HashMap<const void *, void *> *, 96 HashMap<const void *, size_t> *)> &alloc_func); 97 std::shared_ptr<MemHandler> mem_handler_; 98 HashMap<const void *, void *> mem_result_; 99 HashMap<const void *, size_t> mem_size_; 100 HashSet<const void *> init_from_host_keys_; 101 HashSet<const void *> updated_device_mem_; 102 HashSet<const void *> continuous_mem_key_; 103 HashMap<const void *, void *> init_host_ptr_; 104 HashMap<const void *, void *> swap_host_ptr_; 105 }; 106 107 class BACKEND_EXPORT MindRTAutoOffloadAdapter { 108 public: MindRTAutoOffloadAdapter(DynamicMemPoolBestFit * mem_pool,size_t stream_id)109 MindRTAutoOffloadAdapter(DynamicMemPoolBestFit *mem_pool, size_t stream_id) 110 : mem_pool_(mem_pool), stream_id_(stream_id) {} 111 ~MindRTAutoOffloadAdapter() = default; 112 bool Malloc(DeviceAddress *key); 113 void *Malloc(size_t size, const HashSet<const void *> &pinned_mem = {}); 114 std::vector<void *> MallocContinuousMem(const std::vector<size_t> &size_list); 115 116 private: 117 // Return the host ptr where the data is copied to 118 void SwapOut(DeviceAddress *device_address); 119 template <typename MallocInfo, typename ReturnType> 120 bool TryAllocMemory(const MallocInfo &info, size_t total_size, const HashSet<const void *> &pinned_mem, 121 const std::function<bool(const MallocInfo &, DynamicMemPoolBestFit *, ReturnType *)> &alloc_func, 122 ReturnType *ret); 123 124 DynamicMemPoolBestFit *mem_pool_; 125 size_t stream_id_; 126 // Read/Write lock for all_mem_ map. 127 std::shared_mutex all_mem_mutex_; 128 HashSet<DeviceAddress *> all_mem_; 129 }; 130 } // namespace device 131 } // namespace mindspore 132 133 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_AUTO_MEM_OFFLOAD_H_ 134