1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/memory_manager.h"
18 #include <string>
19 #include "include/backend/anf_runtime_algorithm.h"
20 #include "utils/ms_context.h"
21
22 namespace mindspore {
23 namespace device {
24 constexpr size_t kAlignBytes = 32;
25
GetCommonAlignSize(size_t input_size)26 size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
27 return ((input_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize) * kMemAlignSize;
28 }
29
GetCommunicationAlignSize(size_t input_size)30 size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) {
31 return ((input_size + kMemAlignSize - 1) / kMemAlignSize) * kMemAlignSize + kTwiceMemAlignSize;
32 }
33
MallocSomasDynamicMem(const session::KernelGraph & graph)34 void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
35 SomasAllocatorPtr somas_allocator_ptr = std::make_shared<device::CommonSomasAllocator>();
36 MS_EXCEPTION_IF_NULL(somas_allocator_ptr);
37 somas_allocator_ptr_ = somas_allocator_ptr;
38
39 if (!(device::CommonSomasAllocator::Assign(graph))) {
40 MS_LOG(EXCEPTION) << "Somas Allocate Failed.";
41 }
42 size_t total_allocated_size = graph.somas_whole_block_size();
43 MS_LOG(INFO) << "Graph " << graph.graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]";
44 if (total_allocated_size > 0) {
45 auto base_ptr = MallocDynamicMem(total_allocated_size, false);
46 MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
47 << static_cast<void *>(base_ptr + total_allocated_size) << "]";
48 somas_allocator_ptr->set_mem_base_addr(base_ptr);
49 }
50 }
51
MallocOutputMem(const AnfNodePtr & node,size_t index,MemType type,size_t size,const DeviceAddressPtr & address,bool comm_mem)52 uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
53 const DeviceAddressPtr &address, bool comm_mem) {
54 MS_EXCEPTION_IF_NULL(node);
55 MS_EXCEPTION_IF_NULL(address);
56 auto context_ptr = MsContext::GetInstance();
57 MS_EXCEPTION_IF_NULL(context_ptr);
58 uint8_t *ptr = nullptr;
59 if (comm_mem) {
60 bool communication_mem = false;
61 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
62 communication_mem = true;
63 }
64 if (type == kStaticMem) {
65 ptr = MallocStaticMem(size, communication_mem);
66 address->set_from_mem_pool(true);
67 if (communication_mem) {
68 address->set_communication_ptr(ptr - kMemAlignSize);
69 }
70 } else if (type == kSomasReuseDynamicMem) {
71 MS_EXCEPTION_IF_NULL(somas_allocator_ptr_);
72 ptr = somas_allocator_ptr_->GetNodeOutputPtr(node, index);
73 } else {
74 ptr = MallocDynamicMem(size, communication_mem);
75 }
76 address->SetDevicePtr(ptr);
77 return ptr;
78 }
79
80 if (type == kStaticMem) {
81 ptr = MallocStaticMem(size, false);
82 address->set_from_mem_pool(true);
83 } else if (type == kDynamicMem) {
84 ptr = MallocDynamicMem(size, false);
85 } else if (type == kSomasReuseDynamicMem) {
86 MS_EXCEPTION_IF_NULL(somas_allocator_ptr_);
87 ptr = somas_allocator_ptr_->GetNodeOutputPtr(node, index);
88 }
89 address->SetDevicePtr(ptr);
90 return ptr;
91 }
92
MallocWorkSpaceMem(const AnfNodePtr & node,size_t index,MemType type,size_t size)93 uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
94 if (type == kSomasReuseDynamicMem) {
95 MS_EXCEPTION_IF_NULL(somas_allocator_ptr_);
96 return somas_allocator_ptr_->GetNodeWorkSpacePtr(node, index);
97 }
98 return MallocDynamicMem(size, false);
99 }
100
MallocWorkSpaceMem(size_t size)101 uint8_t *MemoryManager::MallocWorkSpaceMem(size_t size) { return MallocDynamicMem(size, false); }
102
MallocMem(MemType type,size_t size,const DeviceAddressPtr & address,uint32_t graph_id)103 uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id) {
104 MS_EXCEPTION_IF_NULL(address);
105 uint8_t *ptr = nullptr;
106 if (type == kStaticMem) {
107 ptr = MallocStaticMem(size, false, graph_id);
108 address->set_from_mem_pool(true);
109 } else if (type == kDynamicMem) {
110 ptr = MallocDynamicMem(size, false);
111 }
112 address->SetDevicePtr(ptr);
113 return ptr;
114 }
115
MallocDynamicMem(size_t size,bool communication_mem)116 uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
117 MS_LOG(INFO) << "Call default dynamic malloc " << size << " v " << communication_mem;
118 return nullptr;
119 }
120
MallocMemFromMemPool(const DeviceAddressPtr & address,size_t size)121 bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size) {
122 MS_EXCEPTION_IF_NULL(address);
123 auto device_ptr = MallocMemFromMemPool(size, address->from_persistent_mem_, false, address->stream_id());
124 if (!device_ptr) {
125 return false;
126 }
127 MS_EXCEPTION_IF_NULL(address);
128 address->SetDevicePtr(device_ptr);
129 address->SetSize(size);
130 address->set_from_mem_pool(true);
131 return true;
132 }
133
MallocMemFromMemPool(size_t size,bool from_persistent_mem,bool,uint32_t stream_id)134 void *MemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem, bool, uint32_t stream_id) {
135 if (size == 0) {
136 MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
137 }
138 return nullptr;
139 }
140
MallocContinuousMemFromMemPool(const DeviceAddressPtrList & addr_list,size_t,std::vector<size_t> size_list,uint32_t stream_id)141 bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t,
142 std::vector<size_t> size_list, uint32_t stream_id) {
143 auto device_ptr_list = MallocContinuousMemFromMemPool(size_list, stream_id);
144 if (device_ptr_list.empty()) {
145 return false;
146 }
147 if (addr_list.size() != device_ptr_list.size()) {
148 MS_LOG(EXCEPTION) << "The size of device list " << addr_list.size() << " is not equal to the size of address list "
149 << device_ptr_list.size();
150 }
151 for (size_t i = 0; i < addr_list.size(); i++) {
152 MS_EXCEPTION_IF_NULL(device_ptr_list[i]);
153 MS_EXCEPTION_IF_NULL(addr_list[i]);
154 addr_list[i]->SetDevicePtr(device_ptr_list[i]);
155 addr_list[i]->SetSize(size_list[i]);
156 addr_list[i]->set_from_mem_pool(true);
157 }
158 return true;
159 }
160
FreeMemFromMemPool(const DeviceAddressPtr address)161 void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
162 MS_EXCEPTION_IF_NULL(address);
163 MS_EXCEPTION_IF_NULL(address->GetDevicePtr());
164 FreeMemFromMemPool(address->GetDevicePtr());
165 address->SetDevicePtr(nullptr);
166 }
167
FreeMemFromMemPool(void * device_ptr)168 void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
169 if (device_ptr == nullptr) {
170 MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
171 }
172 }
173
MallocContinuousMemFromMemPool(const std::vector<size_t> & size_list,uint32_t stream_id)174 std::vector<void *> MemoryManager::MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list,
175 uint32_t stream_id) {
176 if (size_list.empty()) {
177 MS_LOG(ERROR) << "MallocContinuousMemFromMemPool size list's size is 0.";
178 }
179 std::vector<void *> device_ptr_list;
180 for (size_t i = 0; i < size_list.size(); ++i) {
181 (void)device_ptr_list.emplace_back(nullptr);
182 }
183 return device_ptr_list;
184 }
185 } // namespace device
186 } // namespace mindspore
187