• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <stdexcept>
#include <string>
#include "runtime/device/ascend/ascend_memory_manager.h"
#include "runtime/device/ascend/ascend_memory_pool.h"
#include "utils/ms_context.h"
#include "runtime/mem.h"
21 #ifndef ENABLE_SECURITY
22 #include "runtime/device/ascend/profiling/profiling_manager.h"
23 #include "profiler/device/ascend/memory_profiling.h"
24 
25 using mindspore::device::ascend::ProfilingManager;
26 using mindspore::profiler::ascend::MemoryProfiling;
27 #endif
28 
29 namespace mindspore {
30 namespace device {
31 namespace ascend {
32 namespace {
// Device memory size, in GB, from which kAscendDeviceMemSize is derived.
constexpr uint64_t kAscendInitDeviceMemGB = 30;
// Shift amount used to convert between bytes and GB (1 GB == 1 << 30 bytes).
// Despite its name this is a bit count, not a size.
constexpr uint64_t kMemSizeGB = 30;
// kAscendInitDeviceMemGB expressed in bytes.
constexpr uint64_t kAscendDeviceMemSize = kAscendInitDeviceMemGB << kMemSizeGB;
36 
GetDeviceHBMSize()37 uint64_t GetDeviceHBMSize() {
38   size_t free = 0;
39   size_t total = 0;
40   rtError_t ret = rtMemGetInfoEx(RT_MEMORYINFO_HBM, &free, &total);
41   if (ret != RT_ERROR_NONE || total == 0) {
42     MS_LOG(EXCEPTION) << "Get Device HBM memory size failed, ret = " << ret << ", total =  " << total;
43   }
44   return total;
45 }
46 
GetDefaultDeviceMemSize()47 uint64_t GetDefaultDeviceMemSize() {
48   auto total = GetDeviceHBMSize();
49   auto ret = total * 15 / 16;  // reserved memory is 1/16 of total
50   MS_LOG(INFO) << "The Device HBM memory size is " << total << ", allocate " << ret << " for backend.";
51   return ret;
52 }
53 }  // namespace
54 
MallocDeviceMemory()55 void AscendMemoryManager::MallocDeviceMemory() {
56   auto context_mem = GetDeviceMemSizeFromContext();
57   device_mem_size_ = context_mem == 0 ? GetDefaultDeviceMemSize() : context_mem;
58   auto ret = rtMalloc(reinterpret_cast<void **>(&device_mem_base_), device_mem_size_, RT_MEMORY_HBM);
59   if (ret != ACL_RT_SUCCESS) {
60     if (ret == ACL_ERROR_RT_MEMORY_ALLOCATION) {
61       auto context_ptr = MsContext::GetInstance();
62       MS_EXCEPTION_IF_NULL(context_ptr);
63       unsigned int device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
64       MS_LOG(EXCEPTION) << "Malloc device memory failed, size[" << device_mem_size_ << "], ret[" << ret << "], "
65                         << "Device " << device_id
66                         << " may be other processes occupying this card, check as: ps -ef|grep python";
67     } else {
68       MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
69     }
70   } else {
71     MS_LOG(INFO) << "Call rtMalloc to allocate device memory Success, size : " << device_mem_size_
72                  << " bytes , address : " << reinterpret_cast<void *>(device_mem_base_);
73   }
74   AscendMemoryPool::GetInstance().Init(device_mem_base_, device_mem_size_, dynamic_mem_offset_);
75 }
76 
GetDeviceMemSize()77 uint64_t AscendMemoryManager::GetDeviceMemSize() {
78   auto mem_size = GetDeviceMemSizeFromContext();
79   return mem_size == 0 ? GetDefaultDeviceMemSize() : mem_size;
80 }
81 
GetDeviceMemSizeFromContext()82 uint64_t AscendMemoryManager::GetDeviceMemSizeFromContext() {
83   auto context = MsContext::GetInstance();
84   MS_EXCEPTION_IF_NULL(context);
85   auto variable_memory_max_size = context->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
86   if (variable_memory_max_size == "0") {
87     return 0;
88   }
89   MS_LOG(INFO) << "context variable_memory_max_size:" << variable_memory_max_size;
90   auto pos = variable_memory_max_size.find('*');
91   if (pos == std::string::npos) {
92     MS_LOG(EXCEPTION) << "Invalid variable_memory_max_size";
93   }
94   auto gb_str = variable_memory_max_size.substr(0, pos);
95   auto gb_var = std::stoull(gb_str);
96   MS_LOG(INFO) << "variable_memory_max_size(GB):" << gb_var;
97   auto total_hbm_size_GB = GetDeviceHBMSize() >> kMemSizeGB;
98   auto backend_max_size_GB = total_hbm_size_GB - 1;  // reserved 1 GB for other component
99   if (gb_var > backend_max_size_GB || gb_var == 0) {
100     MS_LOG(EXCEPTION) << "The Total Device Memory Size is " << total_hbm_size_GB
101                       << " GB, variable_memory_max_size should be in range (0-" << backend_max_size_GB
102                       << "]GB, but got " << gb_var
103                       << "GB, please set the context key 'variable_memory_max_size' in valid range.";
104   }
105   return gb_var << kMemSizeGB;
106 }
107 
FreeDeviceMemory()108 void AscendMemoryManager::FreeDeviceMemory() {
109   if (device_mem_base_ != nullptr) {
110     auto ret = rtFree(device_mem_base_);
111     if (ret != RT_ERROR_NONE) {
112       MS_LOG(ERROR) << "rtFree mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
113     }
114     device_mem_base_ = nullptr;
115   }
116 }
117 
ResetDynamicMemory()118 void AscendMemoryManager::ResetDynamicMemory() {
119   total_dynamic_size_ = 0;
120   dynamic_mem_offset_ = 0;
121   AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
122 }
123 
ClearGlobalIdleMem()124 void AscendMemoryManager::ClearGlobalIdleMem() { AscendMemoryPool::GetInstance().ResetIdleMemBuf(); }
125 
MallocMemFromMemPool(size_t size)126 void *AscendMemoryManager::MallocMemFromMemPool(size_t size) {
127   auto align_size = GetCommonAlignSize(size);
128   return AscendMemoryPool::GetInstance().AllocTensorMem(align_size);
129 }
130 
FreeMemFromMemPool(void * device_ptr)131 void AscendMemoryManager::FreeMemFromMemPool(void *device_ptr) {
132   AscendMemoryPool::GetInstance().FreeTensorMem(device_ptr);
133 }
134 
MallocStaticMem(size_t size,bool communication_mem,uint32_t graph_id)135 uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id) {
136   size_t align_size = 0;
137   if (communication_mem) {
138     align_size = GetCommunicationAlignSize(size);
139   } else {
140     align_size = GetCommonAlignSize(size);
141   }
142   auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
143   MS_LOG(INFO) << "Malloc Memory for Static: size[" << align_size << "], Memory statistics: total[" << device_mem_size_
144                << "] dynamic [" << total_dynamic_size_ << "] static [" << device_mem_size_ - device_mem_pool_offset
145                << "], Pool statistics: pool total size [" << AscendMemoryPool::GetInstance().total_mem_statistics()
146                << "] used [" << AscendMemoryPool::GetInstance().used_mem_statistics()
147                << "] communication_mem:" << communication_mem;
148 #ifndef ENABLE_SECURITY
149   if (MemoryProfiling::GetInstance().IsMemoryProfilingEnable() && graph_id != kInvalidGraphId) {
150     auto node = MemoryProfiling::GetInstance().GetGraphMemoryNode(graph_id);
151     if (node == nullptr) {
152       node = MemoryProfiling::GetInstance().AddGraphMemoryNode(graph_id);
153       MS_LOG(INFO) << "Add graph memory node for static memory profiling, graph id is " << graph_id;
154     }
155 
156     node->AddStaticMemorySize(SizeToUint(align_size));
157   }
158 #endif
159   if (communication_mem) {
160     // create protect area [kMemAlignSize -- data -- kMemAlignSize]
161     uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
162     return alloc_address + kMemAlignSize;
163   } else {
164     return reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
165   }
166 }
167 
MallocDynamicMem(size_t size,bool communication_mem)168 uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
169   size_t align_size = 0;
170   if (communication_mem) {
171     align_size = GetCommunicationAlignSize(size);
172   } else {
173     align_size = GetCommonAlignSize(size);
174   }
175 
176   auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
177   MS_LOG(INFO) << "Malloc Memory for Dynamic: size[" << align_size << "], Memory statistics: total[" << device_mem_size_
178                << "] dynamic[" << total_dynamic_size_ << "] static[" << device_mem_size_ - device_mem_pool_offset
179                << "] communication_mem: " << communication_mem;
180   auto offset = dynamic_mem_offset_;
181   auto new_offset = dynamic_mem_offset_ + align_size;
182   if (new_offset >= device_mem_pool_offset) {
183     MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
184                       << "] memory pool[" << device_mem_size_ - device_mem_pool_offset << "])"
185                       << " malloc [" << align_size
186                       << "] failed! Please try to reduce 'batch_size' or check whether exists extra large shape. More "
187                          "details can be found in mindspore's FAQ";
188   }
189   total_dynamic_size_ += align_size;
190   dynamic_mem_offset_ = new_offset;
191   AscendMemoryPool::GetInstance().set_graph_dynamic_mem_offset(dynamic_mem_offset_);
192   if (communication_mem) {
193     // create protect area [kMemAlignSize -- data -- kMemAlignSize]
194     return device_mem_base_ + offset + kMemAlignSize;
195   } else {
196     return device_mem_base_ + offset;
197   }
198 }
199 
MallocSomasDynamicMem(const session::KernelGraph & graph)200 void AscendMemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
201   MemoryManager::MallocSomasDynamicMem(graph);
202 #ifndef ENABLE_SECURITY
203   if (MemoryProfiling::GetInstance().IsMemoryProfilingEnable()) {
204     MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
205     somas_reuse_util_ptr_->ConvertToProfilingNode(graph.graph_id());
206   }
207 #endif
208 }
209 
210 // communication memory: [512align_size + data + 512align_size]
211 // return the pointer to the start of data address.
MallocCommunicationMemFromMemPool(size_t size)212 uint8_t *AscendMemoryManager::MallocCommunicationMemFromMemPool(size_t size) {
213   auto align_size = GetCommunicationAlignSize(size);
214   uint8_t *base_ptr = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
215   return base_ptr + kMemAlignSize;
216 }
217 
GetAvailableMemSize()218 size_t AscendMemoryManager::GetAvailableMemSize() {
219   auto available_mem_size = AscendMemoryPool::GetInstance().free_mem_size() +
220                             AscendMemoryPool::GetInstance().total_mem_statistics() -
221                             AscendMemoryPool::GetInstance().used_mem_statistics();
222   return available_mem_size;
223 }
224 
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)225 void AscendMemoryManager::SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
226   if (stream == nullptr) {
227     auto ret_rt_memcpy = rtMemcpy(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE);
228     if (ret_rt_memcpy != RT_ERROR_NONE) {
229       MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpy failed.";
230     }
231   } else {
232     auto ret_rt_memcpy = rtMemcpyAsync(device_ptr, mem_size, host_ptr, mem_size, RT_MEMCPY_HOST_TO_DEVICE, stream);
233     if (ret_rt_memcpy != RT_ERROR_NONE) {
234       MS_EXCEPTION(DeviceProcessError) << "SwapIn rtMemcpyAsync failed.";
235     }
236     if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
237       MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
238     }
239   }
240 }
241 
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)242 void AscendMemoryManager::SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
243   if (stream == nullptr) {
244     auto ret_rt_memcpy = rtMemcpy(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST);
245     if (ret_rt_memcpy != RT_ERROR_NONE) {
246       MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpy failed.";
247     }
248   } else {
249     auto ret_rt_memcpy = rtMemcpyAsync(host_ptr, mem_size, device_ptr, mem_size, RT_MEMCPY_DEVICE_TO_HOST, stream);
250     if (ret_rt_memcpy != RT_ERROR_NONE) {
251       MS_EXCEPTION(DeviceProcessError) << "SwapOut rtMemcpyAsync failed.";
252     }
253     if (rtStreamSynchronize(stream) != RT_ERROR_NONE) {
254       MS_LOG(ERROR) << "Call runtime rtStreamSynchronize error.";
255     }
256   }
257 }
258 }  // namespace ascend
259 }  // namespace device
260 }  // namespace mindspore
261