• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/memory_manager.h"
18 #include <string>
19 #include "include/backend/anf_runtime_algorithm.h"
20 #include "utils/ms_context.h"
21 
22 namespace mindspore {
23 namespace device {
24 constexpr size_t kAlignBytes = 32;
25 
GetCommonAlignSize(size_t input_size)26 size_t MemoryManager::GetCommonAlignSize(size_t input_size) {
27   return ((input_size + kMemAlignSize + kAlignBytes - 1) / kMemAlignSize) * kMemAlignSize;
28 }
29 
GetCommunicationAlignSize(size_t input_size)30 size_t MemoryManager::GetCommunicationAlignSize(size_t input_size) {
31   return ((input_size + kMemAlignSize - 1) / kMemAlignSize) * kMemAlignSize + kTwiceMemAlignSize;
32 }
33 
MallocSomasDynamicMem(const session::KernelGraph & graph)34 void MemoryManager::MallocSomasDynamicMem(const session::KernelGraph &graph) {
35   SomasAllocatorPtr somas_allocator_ptr = std::make_shared<device::CommonSomasAllocator>();
36   MS_EXCEPTION_IF_NULL(somas_allocator_ptr);
37   somas_allocator_ptr_ = somas_allocator_ptr;
38 
39   if (!(device::CommonSomasAllocator::Assign(graph))) {
40     MS_LOG(EXCEPTION) << "Somas Allocate Failed.";
41   }
42   size_t total_allocated_size = graph.somas_whole_block_size();
43   MS_LOG(INFO) << "Graph " << graph.graph_id() << ": TotalSomasReuseDynamicSize [" << total_allocated_size << "]";
44   if (total_allocated_size > 0) {
45     auto base_ptr = MallocDynamicMem(total_allocated_size, false);
46     MS_LOG(INFO) << "Somas Reuse Memory Base Address [" << static_cast<void *>(base_ptr) << "], End Address ["
47                  << static_cast<void *>(base_ptr + total_allocated_size) << "]";
48     somas_allocator_ptr->set_mem_base_addr(base_ptr);
49   }
50 }
51 
MallocOutputMem(const AnfNodePtr & node,size_t index,MemType type,size_t size,const DeviceAddressPtr & address,bool comm_mem)52 uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
53                                         const DeviceAddressPtr &address, bool comm_mem) {
54   MS_EXCEPTION_IF_NULL(node);
55   MS_EXCEPTION_IF_NULL(address);
56   auto context_ptr = MsContext::GetInstance();
57   MS_EXCEPTION_IF_NULL(context_ptr);
58   uint8_t *ptr = nullptr;
59   if (comm_mem) {
60     bool communication_mem = false;
61     if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
62       communication_mem = true;
63     }
64     if (type == kStaticMem) {
65       ptr = MallocStaticMem(size, communication_mem);
66       address->set_from_mem_pool(true);
67       if (communication_mem) {
68         address->set_communication_ptr(ptr - kMemAlignSize);
69       }
70     } else if (type == kSomasReuseDynamicMem) {
71       MS_EXCEPTION_IF_NULL(somas_allocator_ptr_);
72       ptr = somas_allocator_ptr_->GetNodeOutputPtr(node, index);
73     } else {
74       ptr = MallocDynamicMem(size, communication_mem);
75     }
76     address->SetDevicePtr(ptr);
77     return ptr;
78   }
79 
80   if (type == kStaticMem) {
81     ptr = MallocStaticMem(size, false);
82     address->set_from_mem_pool(true);
83   } else if (type == kDynamicMem) {
84     ptr = MallocDynamicMem(size, false);
85   } else if (type == kSomasReuseDynamicMem) {
86     MS_EXCEPTION_IF_NULL(somas_allocator_ptr_);
87     ptr = somas_allocator_ptr_->GetNodeOutputPtr(node, index);
88   }
89   address->SetDevicePtr(ptr);
90   return ptr;
91 }
92 
MallocWorkSpaceMem(const AnfNodePtr & node,size_t index,MemType type,size_t size)93 uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size) {
94   if (type == kSomasReuseDynamicMem) {
95     MS_EXCEPTION_IF_NULL(somas_allocator_ptr_);
96     return somas_allocator_ptr_->GetNodeWorkSpacePtr(node, index);
97   }
98   return MallocDynamicMem(size, false);
99 }
100 
MallocWorkSpaceMem(size_t size)101 uint8_t *MemoryManager::MallocWorkSpaceMem(size_t size) { return MallocDynamicMem(size, false); }
102 
MallocMem(MemType type,size_t size,const DeviceAddressPtr & address,uint32_t graph_id)103 uint8_t *MemoryManager::MallocMem(MemType type, size_t size, const DeviceAddressPtr &address, uint32_t graph_id) {
104   MS_EXCEPTION_IF_NULL(address);
105   uint8_t *ptr = nullptr;
106   if (type == kStaticMem) {
107     ptr = MallocStaticMem(size, false, graph_id);
108     address->set_from_mem_pool(true);
109   } else if (type == kDynamicMem) {
110     ptr = MallocDynamicMem(size, false);
111   }
112   address->SetDevicePtr(ptr);
113   return ptr;
114 }
115 
MallocDynamicMem(size_t size,bool communication_mem)116 uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
117   MS_LOG(INFO) << "Call default dynamic malloc " << size << " v " << communication_mem;
118   return nullptr;
119 }
120 
MallocMemFromMemPool(const DeviceAddressPtr & address,size_t size)121 bool MemoryManager::MallocMemFromMemPool(const DeviceAddressPtr &address, size_t size) {
122   MS_EXCEPTION_IF_NULL(address);
123   auto device_ptr = MallocMemFromMemPool(size, address->from_persistent_mem_, false, address->stream_id());
124   if (!device_ptr) {
125     return false;
126   }
127   MS_EXCEPTION_IF_NULL(address);
128   address->SetDevicePtr(device_ptr);
129   address->SetSize(size);
130   address->set_from_mem_pool(true);
131   return true;
132 }
133 
MallocMemFromMemPool(size_t size,bool from_persistent_mem,bool,uint32_t stream_id)134 void *MemoryManager::MallocMemFromMemPool(size_t size, bool from_persistent_mem, bool, uint32_t stream_id) {
135   if (size == 0) {
136     MS_LOG(ERROR) << "MallocMemFromMemPool size is 0.";
137   }
138   return nullptr;
139 }
140 
MallocContinuousMemFromMemPool(const DeviceAddressPtrList & addr_list,size_t,std::vector<size_t> size_list,uint32_t stream_id)141 bool MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t,
142                                                    std::vector<size_t> size_list, uint32_t stream_id) {
143   auto device_ptr_list = MallocContinuousMemFromMemPool(size_list, stream_id);
144   if (device_ptr_list.empty()) {
145     return false;
146   }
147   if (addr_list.size() != device_ptr_list.size()) {
148     MS_LOG(EXCEPTION) << "The size of device list " << addr_list.size() << " is not equal to the size of address list "
149                       << device_ptr_list.size();
150   }
151   for (size_t i = 0; i < addr_list.size(); i++) {
152     MS_EXCEPTION_IF_NULL(device_ptr_list[i]);
153     MS_EXCEPTION_IF_NULL(addr_list[i]);
154     addr_list[i]->SetDevicePtr(device_ptr_list[i]);
155     addr_list[i]->SetSize(size_list[i]);
156     addr_list[i]->set_from_mem_pool(true);
157   }
158   return true;
159 }
160 
FreeMemFromMemPool(const DeviceAddressPtr address)161 void MemoryManager::FreeMemFromMemPool(const DeviceAddressPtr address) {
162   MS_EXCEPTION_IF_NULL(address);
163   MS_EXCEPTION_IF_NULL(address->GetDevicePtr());
164   FreeMemFromMemPool(address->GetDevicePtr());
165   address->SetDevicePtr(nullptr);
166 }
167 
FreeMemFromMemPool(void * device_ptr)168 void MemoryManager::FreeMemFromMemPool(void *device_ptr) {
169   if (device_ptr == nullptr) {
170     MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null.";
171   }
172 }
173 
MallocContinuousMemFromMemPool(const std::vector<size_t> & size_list,uint32_t stream_id)174 std::vector<void *> MemoryManager::MallocContinuousMemFromMemPool(const std::vector<size_t> &size_list,
175                                                                   uint32_t stream_id) {
176   if (size_list.empty()) {
177     MS_LOG(ERROR) << "MallocContinuousMemFromMemPool size list's size is 0.";
178   }
179   std::vector<void *> device_ptr_list;
180   for (size_t i = 0; i < size_list.size(); ++i) {
181     (void)device_ptr_list.emplace_back(nullptr);
182   }
183   return device_ptr_list;
184 }
185 }  // namespace device
186 }  // namespace mindspore
187