1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ 19 20 #include <vector> 21 #include <string> 22 #include <unordered_map> 23 #include "runtime/device/memory_manager.h" 24 #include "plugin/device/ascend/hal/device/ascend_memory_pool.h" 25 26 namespace mindspore { 27 namespace device { 28 namespace ascend { 29 class AscendMemoryManager : public MemoryManager { 30 public: 31 AscendMemoryManager() = default; 32 ~AscendMemoryManager() override = default; 33 34 void Initialize() override; 35 void Finalize() override; 36 void ResetDynamicMemory() override; 37 void ClearGlobalIdleMem() override; 38 void *MallocMemFromMemPool(size_t size, bool from_persistent_mem, bool need_recycle = false, 39 uint32_t stream_id = kDefaultStreamIndex) override; 40 void FreeMemFromMemPool(void *device_ptr) override; 41 size_t GetMaxUsedMemorySize() const override; 42 uint64_t GetMsMaxMemSize() const; 43 bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size, 44 std::vector<size_t> size_list, uint32_t stream_id = kDefaultStreamIndex) override; 45 std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list, 46 uint32_t stream_id = kDefaultStreamIndex) override { 47 return AscendMemoryPool::GetInstance().AllocContinuousTensorMem(size_list, stream_id); 48 } 49 50 void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override; 51 void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override; 52 size_t GetAvailableMemSize() override; 53 uint64_t GetMsUsedHbmSize() const; 54 55 // Relevant function to manage memory statistics 56 size_t GetTotalMemStatistics() const override; 57 size_t GetTotalUsedMemStatistics() const override; 58 size_t GetTotalIdleMemStatistics() const override; 59 size_t GetTotalEagerFreeMemStatistics() const override; 60 size_t GetUsedMemPeakStatistics() const override; 61 size_t GetReservedMemPeakStatistics() const override; 62 std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const override; 63 std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const override; 64 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> GetCommonMemBlocksInfoStatistics() 65 const override; 66 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> 67 GetPersistentMemBlocksInfoStatistics() const override; 68 void ResetMaxMemoryReserved() const override; 69 void ResetMaxMemoryAllocated() const override; 70 71 protected: 72 uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) override; 73 uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override; 74 }; 75 } // namespace ascend 76 } // namespace device 77 } // namespace mindspore 78 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ 79