1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEVICE_RES_MANAGER_H_ 17 #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEVICE_RES_MANAGER_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <string> 22 #include <map> 23 #include <unordered_map> 24 #include "runtime/hardware/device_context.h" 25 #include "utils/ms_context.h" 26 #include "include/transform/graph_ir/types.h" 27 #include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h" 28 #include "plugin/device/ascend/hal/hardware/dummy_ascend_collective_comm_lib.h" 29 #ifdef ENABLE_INTERNAL_KERNELS 30 #include "plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h" 31 #endif 32 #include "plugin/device/cpu/hal/device/cpu_device_address.h" 33 #include "runtime/device/kernel_runtime_manager.h" 34 35 namespace mindspore { 36 namespace device { 37 namespace ascend { 38 class GeHostAddress : public cpu::CPUDeviceAddress { 39 public: GeHostAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const std::string & device_name,uint32_t device_id)40 GeHostAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const std::string &device_name, 41 uint32_t device_id) 42 : CPUDeviceAddress(ptr, size, format, type_id, device_name, device_id) {} GeHostAddress(const KernelTensorPtr & kernel_tensor)43 explicit GeHostAddress(const KernelTensorPtr &kernel_tensor) : CPUDeviceAddress(kernel_tensor) {} GetDeviceType()44 DeviceType GetDeviceType() const override { return DeviceType::kAscend; } 45 }; 46 47 class GeDeviceResManager; 48 class GeAllocator : public ::ge::Allocator { 49 public: GeAllocator(GeDeviceResManager * res_manager)50 explicit GeAllocator(GeDeviceResManager *res_manager) : res_manager_(res_manager) {} ~GeAllocator()51 ~GeAllocator() { res_manager_ = nullptr; } 52 GeAllocator(const GeAllocator &) = delete; 53 GeAllocator &operator=(const GeAllocator &) = delete; 54 ::ge::MemBlock *Malloc(size_t size) override; 55 void Free(::ge::MemBlock *block) override; 56 57 private: 58 GeDeviceResManager *res_manager_{nullptr}; 59 }; 60 61 class GeDeviceResManager : public DeviceResManager { 62 public: GeDeviceResManager()63 GeDeviceResManager() {} 64 ~GeDeviceResManager() override = default; 65 66 void Initialize() override; 67 68 void Destroy() override; 69 70 std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list, 71 uint32_t stream_id = kDefaultStreamIndex) const override; 72 73 DeviceAddressPtr CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const override; 74 DeviceAddressPtr CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format, 75 TypeId type_id, const std::string &device_name, uint32_t device_id, 76 uint32_t stream_id) const override; 77 78 static void CreateSessionAndGraphRunner(); 79 void MoveTo(const tensor::TensorPtr &src_tensor, const tensor::TensorPtr &dst_tensor, const std::string &to, 80 bool blocking, bool *return_self) override; 81 82 bool LoadCollectiveCommLib() override; 83 84 void ResetStreamAndCtx() override; 85 bool BindDeviceToCurrentThread(bool force_bind) const override; GetStream()86 void *GetStream() const override { 87 MS_EXCEPTION_IF_NULL(runtime_instance_); 88 return runtime_instance_->compute_stream(); 89 } GetCopyDataStream()90 void *GetCopyDataStream() const { 91 MS_EXCEPTION_IF_NULL(runtime_instance_); 92 return runtime_instance_->copy_data_stream(); 93 } 94 95 // Relevant function to allocate and free device memory of raw ptr. 96 bool AllocateMemory(DeviceAddress *const &address, uint32_t stream_id = UINT32_MAX) const override; 97 void *AllocateMemory(size_t size, uint32_t stream_id = kDefaultStreamIndex) const override; 98 void FreeMemory(void *ptr) const override; 99 void FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs, 100 const std::vector<size_t> &keep_addr_sizes) const override; 101 void DefragMemory() override; 102 103 size_t GetMaxUsedMemorySize() const override; 104 105 // Relevant function to manage memory statistics 106 size_t GetTotalMemStatistics() const override; 107 size_t GetTotalUsedMemStatistics() const override; 108 size_t GetTotalIdleMemStatistics() const override; 109 size_t GetTotalEagerFreeMemStatistics() const override; 110 size_t GetUsedMemPeakStatistics() const override; 111 size_t GetReservedMemPeakStatistics() const override; 112 std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const override; 113 std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const override; 114 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> GetCommonMemBlocksInfoStatistics() 115 const override; 116 std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> 117 GetPersistentMemBlocksInfoStatistics() const override; 118 void ResetMaxMemoryReserved() const override; 119 void ResetMaxMemoryAllocated() const override; 120 GetAllocator()121 transform::GeAllocatorPtr GetAllocator() { return std::make_shared<GeAllocator>(this); } 122 123 void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override; 124 void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override; 125 126 bool CreateStream(size_t *stream_id) const override; 127 bool CreateStreamWithPriority(size_t *stream_id, int32_t priority) const override; 128 size_t QueryStreamSize() const override; 129 std::vector<uint32_t> GetStreamIds() const override; 130 void *GetStream(size_t stream_id) const override; 131 void SetCurrentStreamId(size_t stream_id) override; 132 size_t GetCurrentStreamId() const override; 133 bool QueryStream(size_t stream_id) const override; 134 bool SyncStream(size_t stream_id = 0) const override; 135 bool SyncAllStreams() const override; 136 bool SyncNotDefaultStreams() const override; 137 size_t DefaultStream() const override; 138 139 DeviceEventPtr CreateRuntimeEvent(bool enable_blocking, bool enable_record_wait); 140 DeviceEventPtr CreateEventWithFlag(bool enable_timing, bool blocking) override; 141 142 bool single_op_multi_stream_enable() const override; 143 void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) override; 144 // Only used in graph_mode with MS_DISABLE_REF_MODE, delete it when delete MS_DISABLE_REF_MODEF 145 void SetCPUMemManager(); 146 147 private: 148 friend class GeGraphExecutor; 149 static void GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr, transform::SessionOptions *options); 150 static void GeSetReuseOptions(const std::string &key, size_t num, transform::SessionOptions *options); 151 KernelRuntime *runtime_instance_ = nullptr; 152 // Only used in graph_mode with MS_DISABLE_REF_MODE, delete it when delete MS_DISABLE_REF_MODE 153 bool is_use_cpu_memory_ = false; 154 }; 155 } // namespace ascend 156 } // namespace device 157 } // namespace mindspore 158 #endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEVICE_RES_MANAGER_H_ 159