• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <map>
23 #include <queue>
24 #include <string>
25 #include <unordered_map>
26 #include "backend/common/mem_reuse/mem_reuse.h"
27 #include "include/backend/mem_reuse/mem_dynamic_allocator.h"
28 #include "runtime/device/common_somas_allocator.h"
29 
30 namespace mindspore {
31 namespace device {
// Memory categories accepted by the MemoryManager allocation entry points.
// The enumerators keep their implicit values 0, 1 and 2.
enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
// NOTE(review): presumably a sentinel index meaning "all outputs"; it is not used in this header — confirm with callers.
constexpr int kGetAllOuts = -1;
// Alignment granularity for device memory (presumably in bytes — TODO confirm against GetCommonAlignSize).
constexpr uint64_t kMemAlignSize = 512;
// Exactly two alignment units (1024).
constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;
36 using SomasAllocatorPtr = mindspore::device::CommonSomasAllocatorPtr;
37 
38 class BACKEND_EXPORT MemoryManager {
39  public:
40   MemoryManager() = default;
41   virtual ~MemoryManager() = default;
42 
43   virtual void Initialize() = 0;
44   virtual void Finalize() = 0;
ResetDynamicMemory()45   virtual void ResetDynamicMemory() {}
ClearGlobalIdleMem()46   virtual void ClearGlobalIdleMem() {}
47 
48   virtual void MallocSomasDynamicMem(const session::KernelGraph &graph);
49   uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
50                            const DeviceAddressPtr &address, bool comm_mem);
51   uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
52   uint8_t *MallocWorkSpaceMem(size_t size);
53   virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id);
MallocMem(MemType type,size_t size,const DeviceAddressPtr & address)54   virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
55     return MallocMem(type, size, address, kInvalidGraphId);
56   }
57   // param address is the address type of each device
58   // param from_persistent_mem shows whether the tensor is a parameter in Pynative mode
59   virtual bool MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size);
60   virtual void *MallocMemFromMemPool(size_t size, bool from_persistent_mem, bool need_recycle = false,
61                                      uint32_t stream_id = kDefaultStreamIndex);
GetMaxUsedMemorySize()62   virtual size_t GetMaxUsedMemorySize() const { return 0; }
63   virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
64   virtual void FreeMemFromMemPool(void *device_ptr);
65   virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
66                                               std::vector<size_t> size_list, uint32_t stream_id = kDefaultStreamIndex);
67   virtual std::vector<void *> MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list,
68                                                              uint32_t stream_id = kDefaultStreamIndex);
69 
70   static size_t GetCommonAlignSize(size_t input_size);
71   static size_t GetCommunicationAlignSize(size_t input_size);
72 
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)73   virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
74     MS_LOG(INFO) << "Call default swap in " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
75   }
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)76   virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
77     MS_LOG(INFO) << "Call default swap out " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
78   }
GetAvailableMemSize()79   virtual size_t GetAvailableMemSize() {
80     MS_LOG(ERROR) << "Return default 0 mem size!";
81     return 0;
82   }
83 
RecordEvent(int64_t task_id_on_stream,uint32_t user_stream_id,const std::vector<std::pair<uint32_t,DeviceMemPtr>> & memory_stream_addresses,const DeviceEventPtr & event)84   bool RecordEvent(int64_t task_id_on_stream, uint32_t user_stream_id,
85                    const std::vector<std::pair<uint32_t, DeviceMemPtr>> &memory_stream_addresses,
86                    const DeviceEventPtr &event) {
87     if (memory_pool_ == nullptr) {
88       MS_LOG(WARNING) << "memory_pool_ is nullptr.";
89       return false;
90     }
91     return memory_pool_->RecordEvent(task_id_on_stream, user_stream_id, memory_stream_addresses, event);
92   }
WaitEvent(int64_t task_id_on_stream,uint32_t user_stream_id,uint32_t memory_stream_id)93   bool WaitEvent(int64_t task_id_on_stream, uint32_t user_stream_id, uint32_t memory_stream_id) {
94     if (memory_pool_ == nullptr) {
95       MS_LOG(WARNING) << "memory_pool_ is nullptr.";
96       return false;
97     }
98     return memory_pool_->WaitEvent(task_id_on_stream, user_stream_id, memory_stream_id);
99   }
WaitEvent(int64_t task_id_on_stream,uint32_t memory_stream_id)100   bool WaitEvent(int64_t task_id_on_stream, uint32_t memory_stream_id) {
101     if (memory_pool_ == nullptr) {
102       MS_LOG(WARNING) << "memory_pool_ is nullptr.";
103       return false;
104     }
105     return memory_pool_->WaitEvent(task_id_on_stream, memory_stream_id);
106   }
SyncAllEvents()107   bool SyncAllEvents() {
108     if (memory_pool_ == nullptr) {
109       MS_LOG(WARNING) << "memory_pool_ is nullptr.";
110       return false;
111     }
112     return memory_pool_->SyncAllEvents();
113   }
114 
memory_pool()115   DynamicMemPoolBestFit *memory_pool() { return memory_pool_; }
116 
117   // Relevant function to manage memory statistics
GetTotalMemStatistics()118   virtual size_t GetTotalMemStatistics() const { return 0; }
GetTotalUsedMemStatistics()119   virtual size_t GetTotalUsedMemStatistics() const { return 0; }
GetTotalIdleMemStatistics()120   virtual size_t GetTotalIdleMemStatistics() const { return 0; }
GetTotalEagerFreeMemStatistics()121   virtual size_t GetTotalEagerFreeMemStatistics() const { return 0; }
GetUsedMemPeakStatistics()122   virtual size_t GetUsedMemPeakStatistics() const { return 0; }
GetReservedMemPeakStatistics()123   virtual size_t GetReservedMemPeakStatistics() const { return 0; }
GetBlockCountsStatistics()124   virtual std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const { return {}; }
GetBlockUnitSizeStatistics()125   virtual std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const { return {}; }
126   virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
GetCommonMemBlocksInfoStatistics()127   GetCommonMemBlocksInfoStatistics() const {
128     return {};
129   }
130   virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
GetPersistentMemBlocksInfoStatistics()131   GetPersistentMemBlocksInfoStatistics() const {
132     return {};
133   }
ResetMaxMemoryReserved()134   virtual void ResetMaxMemoryReserved() const {}
ResetMaxMemoryAllocated()135   virtual void ResetMaxMemoryAllocated() const {}
136 
137  protected:
138   virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) = 0;
MallocStaticMem(size_t size,bool communication_mem)139   virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem) {
140     return MallocStaticMem(size, communication_mem, kInvalidGraphId);
141   }
142   virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
143   SomasAllocatorPtr somas_allocator_ptr_{nullptr};
144 
145   // Hold memory pool for common operations on memory.
146   DynamicMemPoolBestFit *memory_pool_{nullptr};
147 };
148 }  // namespace device
149 }  // namespace mindspore
150 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
151