1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <stdexcept>
#include <string>
#include "runtime/device/ascend/ascend_memory_manager.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "utils/ms_context.h"
#include "runtime/mem.h"
21 #ifndef ENABLE_SECURITY
22 #include "runtime/device/ascend/profiling/profiling_manager.h"
23 #include "profiler/device/ascend/memory_profiling.h"
24
25 using mindspore::device::ascend::ProfilingManager;
26 using mindspore::profiler::ascend::MemoryProfiling;
27 #endif
28
29 namespace mindspore {
30 namespace device {
31 namespace ascend {
32 namespace {
// Device memory size in GB from which kAscendDeviceMemSize is derived.
// NOTE(review): neither this constant nor kAscendDeviceMemSize appears to be referenced
// elsewhere in this file — confirm before relying on them.
constexpr uint64_t kAscendInitDeviceMemGB{30};
// Bit-shift amount between GB and bytes (1 GB == 1 << 30 bytes):
// `x << kMemSizeGB` converts GB to bytes, `x >> kMemSizeGB` converts bytes to whole GB.
constexpr uint64_t kMemSizeGB{30};
// kAscendInitDeviceMemGB expressed in bytes.
constexpr uint64_t kAscendDeviceMemSize{kAscendInitDeviceMemGB << kMemSizeGB};
36
GetDeviceHBMSize()37 uint64_t GetDeviceHBMSize() {
38 size_t free = 0;
39 size_t total = 0;
40 rtError_t ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free, &total);
41 if (ret != RT_ERROR_NONE || total == 0) {
42 MS_LOG(EXCEPTION) << "Get Device HBM memory size failed, ret = " << ret << ", total = " << total;
43 }
44 return total;
45 }
46
GetDefaultDeviceMemSize()47 uint64_t GetDefaultDeviceMemSize() {
48 auto total = GetDeviceHBMSize();
49 auto ret = total * 15 / 16; // reserved memory is 1/16 of total
50 MS_LOG(INFO) << "The Device HBM memory size is " << total << ", allocate " << ret << " for backend.";
51 return ret;
52 }
53 } // namespace
54
MallocDeviceMemory()55 void AscendMemoryManager::MallocDeviceMemory() {
56 auto context_mem = GetDeviceMemSizeFromContext();
57 device_mem_size_ = context_mem == 0 ? GetDefaultDeviceMemSize() : context_mem;
58 auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
59 if (ret != ACL_RT_SUCCESS) {
60 if (ret == ACL_ERROR_RT_MEMORY_ALLOCATION) {
61 auto context_ptr = MsContext::GetInstance();
62 MS_EXCEPTION_IF_NULL(context_ptr);
63 unsigned int device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
64 MS_LOG(EXCEPTION) << "Malloc device memory failed, size[" << device_mem_size_ << "], ret[" << ret << "], "
65 << "Device " << device_id
66 << " may be other processes occupying this card, check as: ps -ef|grep python";
67 } else {
68 MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
69 }
70 } else {
71 MS_LOG(INFO) << "Call rtMalloc to allocate device memory Success, size : " << device_mem_size_
72 << " bytes , address : " << reinterpret_cast<void *>(device_mem_base_);
73 }
74 AscendMemoryPool::GetInstance().Init(device_mem_base_, device_mem_size_, dynamic_mem_offset_);
75 }
76
GetDeviceMemSize()77 uint64_t AscendMemoryManager::GetDeviceMemSize() {
78 auto mem_size = GetDeviceMemSizeFromContext();
79 return mem_size == 0 ? GetDefaultDeviceMemSize() : mem_size;
80 }
81
GetDeviceMemSizeFromContext()82 uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
83 auto context = MsContext::GetInstance();
84 MS_EXCEPTION_IF_NULL(context);
85 auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
86 if (variable_memory_max_size == "0") {
87 return 0;
88 }
89 MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size;
90 auto pos = variable_memory_max_size.find('*');
91 if (pos == std::string::npos) {
92 MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
93 }
94 auto gb_str = variable_memory_max_size.substr(0, pos);
95 auto gb_var = std::stoull(gb_str);
96 MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
97 auto total_hbm_size_GB = GetDeviceHBMSize() >> kMemSizeGB;
98 auto backend_max_size_GB = total_hbm_size_GB - 1; // reserved 1 GB for other component
99 if (gb_var > backend_max_size_GB || gb_var == 0) {
100 MS_LOG(EXCEPTION) << "The Total Device Memory Size is " << total_hbm_size_GB
101 << " GB, variable_memory_max_size should be in range (0-" << backend_max_size_GB
102 << "]GB, but got " << gb_var
103 << "GB, please set the context key 'variable_memory_max_size' in valid range.";
104 }
105 return gb_var << kMemSizeGB;
106 }
107
FreeDeviceMemory()108 void AscendMemoryManager::FreeDeviceMemory() {
109 if (device_mem_base_ != nullptr) {
110 auto ret = rtFree(device_mem_base_);
111 if (ret != RT_ERROR_NONE) {
112 MS_LOG(ERROR) << "rtFree mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
113 }
114 device_mem_base_ = nullptr;
115 }
116 }
117
ResetDynamicMemory()118 void AscendMemoryManager::ResetDynamicMemory() {
119 total_dynamic_size_ = 0;
120 dynamic_mem_offset_ = 0;
121 AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
122 }
123
ClearGlobalIdleMem()124 void AscendMemoryManager::ClearGlobalIdleMem() { AscendMemoryPool::GetInstance().ResetIdleMemBuf(); }
125
MallocMemFromMemPool(size_t size)126 void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
127 auto align_size = GetCommonAlignSize(size);
128 return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
129 }
130
FreeMemFromMemPool(void * device_ptr)131 void AscendMemoryManager::FreeMemFromMemPool(void *device_ptr) {
132 AscendMemoryPool::GetInstance().FreeTensorMem(device_ptr);
133 }
134
MallocStaticMem(size_t size,bool communication_mem,uint32_t graph_id)135 uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) {
136 size_t align_size = 0;
137 if (communication_mem) {
138 align_size = GetCommunicationAlignSize(size);
139 } else {
140 align_size = GetCommonAlignSize(size);
141 }
142 auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
143 MS_LOG(INFO) << "Malloc Memory for Static: size[" << align_size << "], Memory statistics: total[" << device_mem_size_
144 << "] dynamic [" << total_dynamic_size_ << "] static [" << device_mem_size_ - device_mem_pool_offset
145 << "], Pool statistics: pool total size [" << AscendMemoryPool::GetInstance().total_mem_statistics()
146 << "] used [" << AscendMemoryPool::GetInstance().used_mem_statistics()
147 << "] communication_mem:" << communication_mem;
148 #ifndef ENABLE_SECURITY
149 if (MemoryProfiling::GetInstance().IsMemoryProfilingEnable() && graph_id != kInvalidGraphId) {
150 auto node = MemoryProfiling::GetInstance().GetGraphMemoryNode(graph_id);
151 if (node == nullptr) {
152 node = MemoryProfiling::GetInstance().AddGraphMemoryNode(graph_id);
153 MS_LOG(INFO) << "Add graph memory node for static memory profiling, graph id is " << graph_id;
154 }
155
156 node->AddStaticMemorySize(SizeToUint(align_size));
157 }
158 #endif
159 if (communication_mem) {
160 // create protect area [kMemAlignSize -- data -- kMemAlignSize]
161 uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
162 return alloc_address + kMemAlignSize;
163 } else {
164 return reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
165 }
166 }
167
MallocDynamicMem(size_t size,bool communication_mem)168 uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
169 size_t align_size = 0;
170 if (communication_mem) {
171 align_size = GetCommunicationAlignSize(size);
172 } else {
173 align_size = GetCommonAlignSize(size);
174 }
175
176 auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
177 MS_LOG(INFO) << "Malloc Memory for Dynamic: size[" << align_size << "], Memory statistics: total[" << device_mem_size_
178 << "] dynamic[" << total_dynamic_size_ << "] static[" << device_mem_size_ - device_mem_pool_offset
179 << "] communication_mem: " << communication_mem;
180 auto offset = dynamic_mem_offset_;
181 auto new_offset = dynamic_mem_offset_ + align_size;
182 if (new_offset >= device_mem_pool_offset) {
183 MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
184 << "] memory pool[" << device_mem_size_ - device_mem_pool_offset << "])"
185 << " malloc [" << align_size
186 << "] failed! Please try to reduce 'batch_size' or check whether exists extra large shape. More "
187 "details can be found in mindspore's FAQ";
188 }
189 total_dynamic_size_ += align_size;
190 dynamic_mem_offset_ = new_offset;
191 AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
192 if (communication_mem) {
193 // create protect area [kMemAlignSize -- data -- kMemAlignSize]
194 return device_mem_base_ + offset + kMemAlignSize;
195 } else {
196 return device_mem_base_ + offset;
197 }
198 }
199
MallocSomasDynamicMem(const session::KernelGraph & graph)200 void AscendMemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
201 MemoryManager::MallocSomasDynamicMem(graph);
202 #ifndef ENABLE_SECURITY
203 if (MemoryProfiling::GetInstance().IsMemoryProfilingEnable()) {
204 MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
205 somas_reuse_util_ptr_->ConvertToProfilingNode(graph.graph_id());
206 }
207 #endif
208 }
209
210 // communication memory: [512align_size + data + 512align_size]
211 // return the pointer to the start of data address.
MallocCommunicationMemFromMemPool(size_t size)212 uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
213 auto align_size = GetCommunicationAlignSize(size);
214 uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
215 return base_ptr + kMemAlignSize;
216 }
217
GetAvailableMemSize()218 size_t AscendMemoryManager::GetAvailableMemSize() {
219 auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() +
220 AscendMemoryPool::GetInstance().total_mem_statistics() -
221 AscendMemoryPool::GetInstance().used_mem_statistics();
222 return available_mem_size;
223 }
224
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)225 void AscendMemoryManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
226 if (stream == nullptr) {
227 auto ret_rt_memcpy = rtMemcpy(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE);
228 if (ret_rt_memcpy != RT_ERROR_NONE) {
229 MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpy failed.";
230 }
231 } else {
232 auto ret_rt_memcpy = rtMemcpyAsync(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE, stream);
233 if (ret_rt_memcpy != RT_ERROR_NONE) {
234 MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpyAsync failed.";
235 }
236 if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
237 MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
238 }
239 }
240 }
241
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)242 void AscendMemoryManager::SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
243 if (stream == nullptr) {
244 auto ret_rt_memcpy = rtMemcpy(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST);
245 if (ret_rt_memcpy != RT_ERROR_NONE) {
246 MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpy failed.";
247 }
248 } else {
249 auto ret_rt_memcpy = rtMemcpyAsync(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST, stream);
250 if (ret_rt_memcpy != RT_ERROR_NONE) {
251 MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpyAsync failed.";
252 }
253 if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
254 MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
255 }
256 }
257 }
258 } // namespace ascend
259 } // namespace device
260 } // namespace mindspore
261