1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ 19 20 #include <vector> 21 #include "runtime/device/memory_manager.h" 22 #include "graphengine/inc/external/runtime/rt_error_codes.h" 23 namespace mindspore { 24 namespace device { 25 namespace ascend { 26 class AscendMemoryManager : public MemoryManager { 27 public: 28 AscendMemoryManager() = default; 29 ~AscendMemoryManager() override = default; 30 31 void MallocDeviceMemory() override; 32 void FreeDeviceMemory() override; 33 void ResetDynamicMemory() override; 34 void ClearGlobalIdleMem() override; 35 void *MallocMemFromMemPool(size_t size) override; 36 void FreeMemFromMemPool(void *device_ptr) override; 37 uint64_t GetDeviceMemSize(); 38 void MallocSomasDynamicMem(const session::KernelGraph &graph) override; 39 uint8_t *MallocCommunicationMemFromMemPool(size_t size) override; MallocContinuousMemFromMemPool(size_t total_size,std::vector<size_t> size_list)40 std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) override { 41 return AscendMemoryPool::GetInstance().AllocContinuousTensorMem(total_size, size_list); 42 } 43 44 void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override; 45 void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override; 46 size_t GetAvailableMemSize() override; 47 48 protected: 49 uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) override; 50 uint8_t *MallocDynamicMem(size_t size, bool communication_mem) override; 51 52 private: 53 uint8_t *device_mem_pool_base_{nullptr}; 54 uint64_t device_mem_pool_size_{0}; 55 56 uint64_t GetDeviceMemSizeFromContext(); 57 }; 58 } // namespace ascend 59 } // namespace device 60 } // namespace mindspore 61 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_MEMORY_MANAGER_H_ 62