1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/memory_manager.h"
18 #include <string>
19 #include "backend/session/anf_runtime_algorithm.h"
20 #include "debug/common.h"
21 #ifdef ENABLE_DUMP_IR
22 #include "debug/rdr/running_data_recorder.h"
23 #endif
24 #include "utils/ms_context.h"
25
26 namespace mindspore {
27 namespace device {
28 constexpr size_t kAlignBytes = 32;
29
GetCommonAlignSize(size_t input_size)30 size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
31 return (input_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize * kMemAlignSize;
32 }
33
GetCommunicationAlignSize(size_t input_size)34 size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) {
35 return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + kTwiceMemAlignSize;
36 }
37
MallocSomasDynamicMem(const session::KernelGraph & graph)38 void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
39 SomasPtr somas_reuse_util_ptr = std::make_shared<somas::Somas>();
40 MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr);
41 somas_reuse_util_ptr_ = somas_reuse_util_ptr;
42
43 if (!(somas_reuse_util_ptr->Allocate(&graph))) {
44 MS_LOG(EXCEPTION) << "Somas Allocate Failed.";
45 }
46
47 size_t total_allocated_size = somas_reuse_util_ptr->GetTotalMemSize();
48 MS_LOG(INFO) << "Graph " << graph.graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]";
49 if (total_allocated_size > 0) {
50 auto base_ptr = MallocDynamicMem(total_allocated_size, false);
51 MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
52 << static_cast<void *>(base_ptr + total_allocated_size) << "]";
53 somas_reuse_util_ptr->set_mem_base_addr(base_ptr);
54 }
55
56 auto context_ptr = MsContext::GetInstance();
57 MS_EXCEPTION_IF_NULL(context_ptr);
58 #ifdef ENABLE_DUMP_IR
59 SubModuleId module = SubModuleId::SM_OPTIMIZER;
60
61 std::string name = "somas_allocate_info." + std::to_string(graph.graph_id());
62 (void)mindspore::RDR::RecordString(module, name, somas_reuse_util_ptr_->SomasInfo());
63
64 name = "somas_mem_info." + std::to_string(graph.graph_id());
65 (void)mindspore::RDR::RecordString(module, name, somas_reuse_util_ptr_->SomasMemory());
66 #endif
67 bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
68 if (save_graphs) {
69 std::string file_path = GetSaveGraphsPathName("somas_allocate_info_" + std::to_string(graph.graph_id()) + ".ir");
70 somas_reuse_util_ptr_->DumpSomasInfoIR(file_path);
71
72 std::string mem_file_path = GetSaveGraphsPathName("somas_mem_info_" + std::to_string(graph.graph_id()) + ".ir");
73 somas_reuse_util_ptr_->DumpSomasMemoryIR(mem_file_path);
74 }
75 }
76
MallocOutputMem(const AnfNodePtr & node,size_t index,MemType type,size_t size,const DeviceAddressPtr & address,bool comm_mem)77 uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
78 const DeviceAddressPtr &address, bool comm_mem) {
79 MS_EXCEPTION_IF_NULL(node);
80 MS_EXCEPTION_IF_NULL(address);
81 auto context_ptr = MsContext::GetInstance();
82 MS_EXCEPTION_IF_NULL(context_ptr);
83 uint8_t *ptr = nullptr;
84 if (comm_mem) {
85 bool communication_mem = false;
86 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
87 communication_mem = true;
88 }
89 if (type == kStaticMem) {
90 ptr = MallocStaticMem(size, communication_mem);
91 address->from_mem_pool_ = true;
92 if (communication_mem) {
93 address->communication_ptr_ = ptr - kMemAlignSize;
94 }
95 } else if (type == kSomasReuseDynamicMem) {
96 MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
97 ptr = somas_reuse_util_ptr_->GetNodeOutputPtr(node, index);
98 } else {
99 ptr = MallocDynamicMem(size, communication_mem);
100 }
101 address->ptr_ = ptr;
102 return ptr;
103 }
104
105 if (type == kStaticMem) {
106 ptr = MallocStaticMem(size, false);
107 address->from_mem_pool_ = true;
108 } else if (type == kDynamicMem) {
109 ptr = MallocDynamicMem(size, false);
110 } else if (type == kSomasReuseDynamicMem) {
111 MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
112 ptr = somas_reuse_util_ptr_->GetNodeOutputPtr(node, index);
113 }
114 address->ptr_ = ptr;
115 return ptr;
116 }
117
MallocWorkSpaceMem(const AnfNodePtr & node,size_t index,MemType type,size_t size)118 uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
119 if (type == kSomasReuseDynamicMem) {
120 MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
121 return somas_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
122 }
123 return MallocDynamicMem(size, false);
124 }
125
MallocMem(MemType type,size_t size,const DeviceAddressPtr & address,uint32_t graph_id)126 uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id) {
127 MS_EXCEPTION_IF_NULL(address);
128 uint8_t *ptr = nullptr;
129 if (type == kStaticMem) {
130 ptr = MallocStaticMem(size, false, graph_id);
131 address->from_mem_pool_ = true;
132 } else if (type == kDynamicMem) {
133 ptr = MallocDynamicMem(size, false);
134 }
135 address->ptr_ = ptr;
136 return ptr;
137 }
138
MallocDynamicMem(size_t size,bool communication_mem)139 uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
140 MS_LOG(INFO) << "Call default dynamic malloc " << size << " v " << communication_mem;
141 return nullptr;
142 }
143
MallocMemFromMemPool(const DeviceAddressPtr address,size_t size)144 bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
145 MS_EXCEPTION_IF_NULL(address);
146 auto device_ptr = MallocMemFromMemPool(size);
147 if (!device_ptr) {
148 return false;
149 }
150 MS_EXCEPTION_IF_NULL(address);
151 address->ptr_ = device_ptr;
152 address->size_ = size;
153 address->from_mem_pool_ = true;
154 return true;
155 }
156
MallocMemFromMemPool(size_t size)157 void *MemoryManager::MallocMemFromMemPool(size_t size) {
158 if (size == 0) {
159 MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
160 }
161 return nullptr;
162 }
163
FreeMemFromMemPool(const DeviceAddressPtr address)164 void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
165 MS_EXCEPTION_IF_NULL(address);
166 MS_EXCEPTION_IF_NULL(address->ptr_);
167 FreeMemFromMemPool(address->ptr_);
168 address->ptr_ = nullptr;
169 }
170
FreeMemFromMemPool(void * device_ptr)171 void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
172 if (device_ptr == nullptr) {
173 MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
174 }
175 }
176
MallocContinuousMemFromMemPool(const DeviceAddressPtrList & addr_list,size_t total_size,std::vector<size_t> size_list)177 bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
178 std::vector<size_t> size_list) {
179 auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list);
180 if (device_ptr_list.empty()) {
181 return false;
182 }
183 if (addr_list.size() != device_ptr_list.size()) {
184 MS_LOG(EXCEPTION) << "The size of device list " << addr_list.size() << " is not equal to the size of address list "
185 << device_ptr_list.size();
186 }
187 for (size_t i = 0; i < addr_list.size(); i++) {
188 MS_EXCEPTION_IF_NULL(device_ptr_list[i]);
189 MS_EXCEPTION_IF_NULL(addr_list[i]);
190 addr_list[i]->ptr_ = device_ptr_list[i];
191 addr_list[i]->size_ = size_list[i];
192 addr_list[i]->from_mem_pool_ = true;
193 }
194 return true;
195 }
196
MallocContinuousMemFromMemPool(size_t total_size,std::vector<size_t> size_list)197 std::vector<void *> MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) {
198 if (total_size == 0) {
199 MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0.";
200 }
201 std::vector<void *> device_ptr_list;
202 for (size_t i = 0; i < size_list.size(); ++i) {
203 device_ptr_list.emplace_back(nullptr);
204 }
205 return device_ptr_list;
206 }
207 } // namespace device
208 } // namespace mindspore
209