• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/memory_manager.h"
18 #include <string>
19 #include "backend/session/anf_runtime_algorithm.h"
20 #include "debug/common.h"
21 #ifdef ENABLE_DUMP_IR
22 #include "debug/rdr/running_data_recorder.h"
23 #endif
24 #include "utils/ms_context.h"
25 
26 namespace mindspore {
27 namespace device {
28 constexpr size_t kAlignBytes = 32;
29 
GetCommonAlignSize(size_t input_size)30 size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
31   return (input_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize * kMemAlignSize;
32 }
33 
GetCommunicationAlignSize(size_t input_size)34 size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) {
35   return (input_size + kMemAlignSize - 1) / kMemAlignSize * kMemAlignSize + kTwiceMemAlignSize;
36 }
37 
MallocSomasDynamicMem(const session::KernelGraph & graph)38 void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
39   SomasPtr somas_reuse_util_ptr = std::make_shared<somas::Somas>();
40   MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr);
41   somas_reuse_util_ptr_ = somas_reuse_util_ptr;
42 
43   if (!(somas_reuse_util_ptr->Allocate(&graph))) {
44     MS_LOG(EXCEPTION) << "Somas Allocate Failed.";
45   }
46 
47   size_t total_allocated_size = somas_reuse_util_ptr->GetTotalMemSize();
48   MS_LOG(INFO) << "Graph " << graph.graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]";
49   if (total_allocated_size > 0) {
50     auto base_ptr = MallocDynamicMem(total_allocated_size, false);
51     MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
52                  << static_cast<void *>(base_ptr + total_allocated_size) << "]";
53     somas_reuse_util_ptr->set_mem_base_addr(base_ptr);
54   }
55 
56   auto context_ptr = MsContext::GetInstance();
57   MS_EXCEPTION_IF_NULL(context_ptr);
58 #ifdef ENABLE_DUMP_IR
59   SubModuleId module = SubModuleId::SM_OPTIMIZER;
60 
61   std::string name = "somas_allocate_info." + std::to_string(graph.graph_id());
62   (void)mindspore::RDR::RecordString(module, name, somas_reuse_util_ptr_->SomasInfo());
63 
64   name = "somas_mem_info." + std::to_string(graph.graph_id());
65   (void)mindspore::RDR::RecordString(module, name, somas_reuse_util_ptr_->SomasMemory());
66 #endif
67   bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
68   if (save_graphs) {
69     std::string file_path = GetSaveGraphsPathName("somas_allocate_info_" + std::to_string(graph.graph_id()) + ".ir");
70     somas_reuse_util_ptr_->DumpSomasInfoIR(file_path);
71 
72     std::string mem_file_path = GetSaveGraphsPathName("somas_mem_info_" + std::to_string(graph.graph_id()) + ".ir");
73     somas_reuse_util_ptr_->DumpSomasMemoryIR(mem_file_path);
74   }
75 }
76 
MallocOutputMem(const AnfNodePtr & node,size_t index,MemType type,size_t size,const DeviceAddressPtr & address,bool comm_mem)77 uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
78                                         const DeviceAddressPtr &address, bool comm_mem) {
79   MS_EXCEPTION_IF_NULL(node);
80   MS_EXCEPTION_IF_NULL(address);
81   auto context_ptr = MsContext::GetInstance();
82   MS_EXCEPTION_IF_NULL(context_ptr);
83   uint8_t *ptr = nullptr;
84   if (comm_mem) {
85     bool communication_mem = false;
86     if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
87       communication_mem = true;
88     }
89     if (type == kStaticMem) {
90       ptr = MallocStaticMem(size, communication_mem);
91       address->from_mem_pool_ = true;
92       if (communication_mem) {
93         address->communication_ptr_ = ptr - kMemAlignSize;
94       }
95     } else if (type == kSomasReuseDynamicMem) {
96       MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
97       ptr = somas_reuse_util_ptr_->GetNodeOutputPtr(node, index);
98     } else {
99       ptr = MallocDynamicMem(size, communication_mem);
100     }
101     address->ptr_ = ptr;
102     return ptr;
103   }
104 
105   if (type == kStaticMem) {
106     ptr = MallocStaticMem(size, false);
107     address->from_mem_pool_ = true;
108   } else if (type == kDynamicMem) {
109     ptr = MallocDynamicMem(size, false);
110   } else if (type == kSomasReuseDynamicMem) {
111     MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
112     ptr = somas_reuse_util_ptr_->GetNodeOutputPtr(node, index);
113   }
114   address->ptr_ = ptr;
115   return ptr;
116 }
117 
MallocWorkSpaceMem(const AnfNodePtr & node,size_t index,MemType type,size_t size)118 uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
119   if (type == kSomasReuseDynamicMem) {
120     MS_EXCEPTION_IF_NULL(somas_reuse_util_ptr_);
121     return somas_reuse_util_ptr_->GetNodeWorkSpacePtr(node, index);
122   }
123   return MallocDynamicMem(size, false);
124 }
125 
MallocMem(MemType type,size_t size,const DeviceAddressPtr & address,uint32_t graph_id)126 uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id) {
127   MS_EXCEPTION_IF_NULL(address);
128   uint8_t *ptr = nullptr;
129   if (type == kStaticMem) {
130     ptr = MallocStaticMem(size, false, graph_id);
131     address->from_mem_pool_ = true;
132   } else if (type == kDynamicMem) {
133     ptr = MallocDynamicMem(size, false);
134   }
135   address->ptr_ = ptr;
136   return ptr;
137 }
138 
MallocDynamicMem(size_t size,bool communication_mem)139 uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
140   MS_LOG(INFO) << "Call default dynamic malloc " << size << " v " << communication_mem;
141   return nullptr;
142 }
143 
MallocMemFromMemPool(const DeviceAddressPtr address,size_t size)144 bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr address, size_t size) {
145   MS_EXCEPTION_IF_NULL(address);
146   auto device_ptr = MallocMemFromMemPool(size);
147   if (!device_ptr) {
148     return false;
149   }
150   MS_EXCEPTION_IF_NULL(address);
151   address->ptr_ = device_ptr;
152   address->size_ = size;
153   address->from_mem_pool_ = true;
154   return true;
155 }
156 
MallocMemFromMemPool(size_t size)157 void *MemoryManager::MallocMemFromMemPool(size_t size) {
158   if (size == 0) {
159     MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
160   }
161   return nullptr;
162 }
163 
FreeMemFromMemPool(const DeviceAddressPtr address)164 void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
165   MS_EXCEPTION_IF_NULL(address);
166   MS_EXCEPTION_IF_NULL(address->ptr_);
167   FreeMemFromMemPool(address->ptr_);
168   address->ptr_ = nullptr;
169 }
170 
FreeMemFromMemPool(void * device_ptr)171 void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
172   if (device_ptr == nullptr) {
173     MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
174   }
175 }
176 
MallocContinuousMemFromMemPool(const DeviceAddressPtrList & addr_list,size_t total_size,std::vector<size_t> size_list)177 bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
178                                                    std::vector<size_t> size_list) {
179   auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list);
180   if (device_ptr_list.empty()) {
181     return false;
182   }
183   if (addr_list.size() != device_ptr_list.size()) {
184     MS_LOG(EXCEPTION) << "The size of device list " << addr_list.size() << " is not equal to the size of address list "
185                       << device_ptr_list.size();
186   }
187   for (size_t i = 0; i < addr_list.size(); i++) {
188     MS_EXCEPTION_IF_NULL(device_ptr_list[i]);
189     MS_EXCEPTION_IF_NULL(addr_list[i]);
190     addr_list[i]->ptr_ = device_ptr_list[i];
191     addr_list[i]->size_ = size_list[i];
192     addr_list[i]->from_mem_pool_ = true;
193   }
194   return true;
195 }
196 
MallocContinuousMemFromMemPool(size_t total_size,std::vector<size_t> size_list)197 std::vector<void *> MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list) {
198   if (total_size == 0) {
199     MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0.";
200   }
201   std::vector<void *> device_ptr_list;
202   for (size_t i = 0; i < size_list.size(); ++i) {
203     device_ptr_list.emplace_back(nullptr);
204   }
205   return device_ptr_list;
206 }
207 }  // namespace device
208 }  // namespace mindspore
209