/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
#include <memory>
#include <utility>
#include <vector>
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include "backend/common/mem_reuse/mem_reuse.h"
#include "include/backend/mem_reuse/mem_dynamic_allocator.h"
#include "runtime/device/common_somas_allocator.h"

namespace mindspore {
namespace device {
enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
constexpr int kGetAllOuts = -1;
constexpr uint64_t kMemAlignSize = 512;
constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;
using SomasAllocatorPtr = mindspore::device::CommonSomasAllocatorPtr;

class BACKEND_EXPORT MemoryManager {
 public:
  MemoryManager() = default;
  virtual ~MemoryManager() = default;

  virtual void Initialize() = 0;
  virtual void Finalize() = 0;
  virtual void ResetDynamicMemory() {}
  virtual void ClearGlobalIdleMem() {}

  virtual void MallocSomasDynamicMem(const session::KernelGraph &graph);
  uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
                           const DeviceAddressPtr &address, bool comm_mem);
  uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
  uint8_t *MallocWorkSpaceMem(size_t size);
  virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id);
  virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
    return MallocMem(type, size, address, kInvalidGraphId);
  }
  // The `address` param is the device-specific address object.
  // The `from_persistent_mem` param indicates whether the tensor is a Parameter in PyNative mode.
  virtual bool MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size);
  virtual void *MallocMemFromMemPool(size_t size, bool from_persistent_mem, bool need_recycle = false,
                                     uint32_t stream_id = kDefaultStreamIndex);
  virtual size_t GetMaxUsedMemorySize() const { return 0; }
  virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
  virtual void FreeMemFromMemPool(void *device_ptr);
  virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
                                              std::vector<size_t> size_list, uint32_t stream_id = kDefaultStreamIndex);
  virtual std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list,
                                                             uint32_t stream_id = kDefaultStreamIndex);

  static size_t GetCommonAlignSize(size_t input_size);
  static size_t GetCommunicationAlignSize(size_t input_size);

  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
    MS_LOG(INFO) << "Call default swap in " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
  }
  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
    MS_LOG(INFO) << "Call default swap out " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
  }
  virtual size_t GetAvailableMemSize() {
    MS_LOG(ERROR) << "Return default 0 mem size!";
    return 0;
  }

  // Event interfaces delegate to the underlying memory pool; they return false when no pool is attached.
  bool RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id,
                   const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses,
                   const DeviceEventPtr &event) {
    if (memory_pool_ == nullptr) {
      MS_LOG(WARNING) << "memory_pool_ is nullptr.";
      return false;
    }
    return memory_pool_->RecordEvent(task_id_on_stream, user_stream_id, memory_stream_addresses, event);
  }
  bool WaitEvent(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id) {
    if (memory_pool_ == nullptr) {
      MS_LOG(WARNING) << "memory_pool_ is nullptr.";
      return false;
    }
    return memory_pool_->WaitEvent(task_id_on_stream, user_stream_id, memory_stream_id);
  }
  bool WaitEvent(int64_t task_id_on_stream, uint32_t memory_stream_id) {
    if (memory_pool_ == nullptr) {
      MS_LOG(WARNING) << "memory_pool_ is nullptr.";
      return false;
    }
    return memory_pool_->WaitEvent(task_id_on_stream, memory_stream_id);
  }
  bool SyncAllEvents() {
    if (memory_pool_ == nullptr) {
      MS_LOG(WARNING) << "memory_pool_ is nullptr.";
      return false;
    }
    return memory_pool_->SyncAllEvents();
  }

  DynamicMemPoolBestFit *memory_pool() { return memory_pool_; }

  // Functions for querying and managing memory statistics.
  virtual size_t GetTotalMemStatistics() const { return 0; }
  virtual size_t GetTotalUsedMemStatistics() const { return 0; }
  virtual size_t GetTotalIdleMemStatistics() const { return 0; }
  virtual size_t GetTotalEagerFreeMemStatistics() const { return 0; }
  virtual size_t GetUsedMemPeakStatistics() const { return 0; }
  virtual size_t GetReservedMemPeakStatistics() const { return 0; }
  virtual std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const { return {}; }
  virtual std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const { return {}; }
  virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
  GetCommonMemBlocksInfoStatistics() const {
    return {};
  }
  virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
  GetPersistentMemBlocksInfoStatistics() const {
    return {};
  }
  virtual void ResetMaxMemoryReserved() const {}
  virtual void ResetMaxMemoryAllocated() const {}

 protected:
  virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) = 0;
  virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem) {
    return MallocStaticMem(size, communication_mem, kInvalidGraphId);
  }
  virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
  SomasAllocatorPtr somas_allocator_ptr_{nullptr};

  // Holds the memory pool used for common memory operations.
  DynamicMemPoolBestFit *memory_pool_{nullptr};
};
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
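
// -----------------------------------------------------------------------------
// Illustrative sketch only (kept in a comment so it has no effect on builds, and
// not part of this header's API): one way a device backend might derive from
// MemoryManager. The class name `MyDeviceMemoryManager` and the placeholder
// allocation are hypothetical; a real backend would allocate from its device
// runtime and would typically also override the pool-based and statistics
// interfaces declared above.
//
//   class MyDeviceMemoryManager : public MemoryManager {
//    public:
//     void Initialize() override { /* create or attach the device memory pool */ }
//     void Finalize() override { /* release all device memory */ }
//
//    protected:
//     // MallocStaticMem(size, communication_mem, graph_id) is the only pure
//     // virtual besides Initialize/Finalize; the alignment helpers above can be
//     // used to round the request up before calling into the device runtime.
//     uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) override {
//       size_t align_size =
//         communication_mem ? GetCommunicationAlignSize(size) : GetCommonAlignSize(size);
//       return static_cast<uint8_t *>(/* device-specific allocation of align_size bytes */ nullptr);
//     }
//   };
// -----------------------------------------------------------------------------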