1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "runtime/device/memory_scheduler.h"
18 #include <algorithm>
19 #include "utils/log_adapter.h"
20
21 namespace mindspore {
22 namespace device {
Clear()23 void MemScheduler::Clear() {
24 if (mem_handler_ == nullptr) {
25 return;
26 }
27 for (auto &item : high_priority_device_ptr_) {
28 mem_handler_->FreeDevice(item.second);
29 }
30 high_priority_device_ptr_.clear();
31 }
32
IsHighPriorityMem(const void * key)33 bool MemScheduler::IsHighPriorityMem(const void *key) {
34 auto iter = mem_priority_.find(key);
35 if (iter != mem_priority_.end()) {
36 return iter->second == kMemPriorityHigh;
37 }
38 return false;
39 }
40
SetMemPriority(const void * key,MemPriority priority)41 void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }
42
Record(const void * key,const EventType & event_type,size_t mem_size)43 void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
44 if (key == nullptr) {
45 return;
46 }
47 auto event = std::make_shared<Event>(event_type, compute_index_);
48 event->mem_size = mem_size;
49 event->key = key;
50 (void)mem_events_[key].emplace_back(event);
51 }
52
Init(const void * key,void * host_ptr,size_t mem_size,MemPriority priority)53 void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
54 if (need_record_event_) {
55 mem_priority_[key] = priority;
56 Record(key, kInit, mem_size);
57 } else {
58 init_host_ptr_[key] = host_ptr;
59 }
60 }
61
GetOrMalloc(const void * key,size_t mem_size,MemPriority priority)62 void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
63 if (need_record_event_) {
64 if (mem_priority_.find(key) == mem_priority_.end()) {
65 mem_priority_[key] = priority;
66 Record(key, kMalloc, mem_size);
67 } else {
68 Record(key, kGet, mem_size);
69 }
70 return nullptr;
71 }
72 auto iter = mem_result_.find(key);
73 if (iter != mem_result_.end()) {
74 auto ptr = iter->second;
75 MS_EXCEPTION_IF_NULL(ptr);
76 return ptr;
77 } else {
78 MS_LOG_EXCEPTION << "Mem extender get nullptr result!";
79 }
80 }
81
// Replay the memory events scheduled before compute step `compute_index_`:
// allocate device buffers, copy initial tensors to device, and swap tensors
// back in. Returns false when a device allocation fails (caller may retry).
bool MemScheduler::PreCompute(void *stream) {
  if (need_record_event_) {
    // Recording pass: events are still being collected, nothing to replay.
    return true;
  }
  MS_EXCEPTION_IF_NULL(mem_handler_);
  if (pre_compute_events_.size() <= compute_index_) {
    MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
                     << ", event size:" << pre_compute_events_.size();
  }
  auto &events = pre_compute_events_[compute_index_];
  for (auto &event : events) {
    MS_EXCEPTION_IF_NULL(event);
    MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
    if (event->type == kInit) {
      // First use of an initial tensor: its host data must reach the device.
      auto host_ptr = init_host_ptr_[event->key];
      MS_EXCEPTION_IF_NULL(host_ptr);
      auto priority = mem_priority_[event->key];
      auto iter = high_priority_device_ptr_.find(event->key);
      if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
        // Non-low priority buffers stay resident across steps; reuse the pointer.
        MS_EXCEPTION_IF_NULL(iter->second);
        mem_result_[event->key] = iter->second;
        if (priority == kMemPriorityMedium) {
          // Medium priority keeps the buffer but refreshes its contents each time.
          mem_handler_->SwapIn(host_ptr, iter->second, event->mem_size, stream);
        }
        continue;
      }
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      if (priority != kMemPriorityLow) {
        // Remember the pointer so later steps can reuse it without reallocating.
        high_priority_device_ptr_[event->key] = device_ptr;
      }
      mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
      mem_result_[event->key] = device_ptr;
    } else if (event->type == kMalloc) {
      // Plain allocation for an intermediate tensor; contents are produced later.
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      mem_result_[event->key] = device_ptr;
    } else if (event->type == kSwapIn) {
      // Bring back a tensor that PostCompute swapped out. The host copy lives
      // either in init_host_ptr_ (initial tensors) or swap_host_ptr_ (temporaries).
      bool from_init = true;
      auto host_ptr = init_host_ptr_[event->key];
      if (host_ptr == nullptr) {
        host_ptr = swap_host_ptr_[event->key];
        from_init = false;
      }
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      MS_EXCEPTION_IF_NULL(host_ptr);
      mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
      mem_result_[event->key] = device_ptr;
      if (!from_init) {
        // Temporary host staging buffer is no longer needed once data is on device.
        mem_handler_->FreeHost(host_ptr);
        (void)swap_host_ptr_.erase(event->key);
      }
    }
  }
  return true;
}
145
PostCompute(void * stream)146 bool MemScheduler::PostCompute(void *stream) {
147 if (need_record_event_) {
148 ++compute_index_;
149 return true;
150 }
151 MS_EXCEPTION_IF_NULL(mem_handler_);
152 if (post_compute_events_.size() <= compute_index_) {
153 MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
154 << ", event size:" << post_compute_events_.size();
155 }
156 auto &events = post_compute_events_[compute_index_];
157 for (auto &event : events) {
158 MS_EXCEPTION_IF_NULL(event);
159 MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
160 if (event->type == kFree) {
161 auto ptr = mem_result_[event->key];
162 if (ptr == nullptr) {
163 return false;
164 }
165 mem_handler_->FreeDevice(ptr);
166 (void)mem_result_.erase(event->key);
167 } else if (event->type == kSwapOut) {
168 auto device_ptr = mem_result_[event->key];
169 if (device_ptr == nullptr) {
170 return false;
171 }
172 auto host_ptr = init_host_ptr_[event->key];
173 if (host_ptr == nullptr) {
174 host_ptr = mem_handler_->MallocHost(event->mem_size);
175 swap_host_ptr_[event->key] = host_ptr;
176 }
177 MS_EXCEPTION_IF_NULL(host_ptr);
178 mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
179 mem_handler_->FreeDevice(device_ptr);
180 (void)mem_result_.erase(device_ptr);
181 }
182 }
183 ++compute_index_;
184 return true;
185 }
186
OptMemUsage()187 void MemScheduler::OptMemUsage() {
188 need_record_event_ = false;
189 if (optimized_) {
190 return;
191 }
192 CountMemUsage();
193 CheckMemSize();
194 if (need_swap_) {
195 GenEventSpan();
196 GenNoSwapEventSet();
197 }
198 GenEvents();
199 }
200
// Compute two per-step memory profiles from the recorded events:
//  - total_mem_used: memory needed if every buffer stays resident from its
//    first to its last access (no swapping);
//  - min_mem_used_: memory needed if a buffer only occupies device memory at
//    the steps that actually touch it (maximal swapping).
// Their maxima become mem_used_without_swap_ and min_mem_needed_.
void MemScheduler::CountMemUsage() {
  if (!min_mem_used_.empty()) {
    // Already counted on a previous call.
    return;
  }
  min_mem_used_.resize(compute_index_, 0);
  std::vector<size_t> total_mem_used(compute_index_, 0);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    size_t i = 0;
    if (first_event->type == kInit && mem_events.size() > 1) {
      // Skip the kInit record: the lifetime starts at the first real access.
      first_event = mem_events[1];
      i = 1;
    }
    auto last_event = mem_events[mem_events.size() - 1];
    // NOTE(review): first_event->mem_size is charged for every step/event of
    // this key — assumes all events of one key describe the same-sized buffer;
    // confirm against Record() call sites.
    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
      if (start_index < compute_index_) {
        total_mem_used[start_index] += first_event->mem_size;
      } else {
        MS_LOG(ERROR) << "Error mem event index " << start_index;
      }
    }
    for (; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (event->index < compute_index_) {
        min_mem_used_[event->index] += first_event->mem_size;
      } else {
        MS_LOG(ERROR) << "Error mem event index " << event->index;
      }
    }
  }
  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
}
240
CheckMemSize()241 void MemScheduler::CheckMemSize() {
242 MS_EXCEPTION_IF_NULL(mem_handler_);
243 auto available_mem_size = mem_handler_->GetAvailableMemSize();
244 if (available_mem_size < min_mem_needed_) {
245 MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
246 << " while graph needs at least " << min_mem_needed_;
247 }
248 if (mem_used_without_swap_ > available_mem_size) {
249 need_swap_ = true;
250 }
251 MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size:" << mem_used_without_swap_
252 << "without swap, and needs at least " << min_mem_needed_ << " with swap.";
253 }
254
// For every pair of consecutive accesses to the same buffer that are more
// than one compute step apart, record the gap length (span) together with the
// later event. These candidates are examined by GenNoSwapEventSet to decide
// which buffers may stay resident instead of being swapped.
void MemScheduler::GenEventSpan() {
  if (!event_span_.empty()) {
    // Already generated on a previous call.
    return;
  }
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    size_t i = 0;
    if (first_event->type == kInit && mem_events.size() > 1) {
      // The lifetime effectively begins at the first real access, not at kInit.
      first_event = mem_events[1];
      i = 1;
    }
    size_t last_index = first_event->index;
    for (; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      auto span = event->index - last_index;
      if (span > 1) {
        // Only gaps longer than one step are candidates for swapping.
        (void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
      }
      last_index = event->index;
    }
  }
}
283
// Greedily choose which gap events can keep their buffer resident: starting
// from the per-step minimum usage, add each candidate's footprint across its
// idle gap and accept it only if no step exceeds mem_used_factor_ of the
// available device memory. Accepted events are exempted from swapping.
// NOTE(review): the iteration order over event_span_ decides which candidates
// win the budget — presumably the container is ordered by span; confirm its
// declaration in the header.
void MemScheduler::GenNoSwapEventSet() {
  MS_EXCEPTION_IF_NULL(mem_handler_);
  auto available_mem_size = mem_handler_->GetAvailableMemSize();
  auto threshold = available_mem_size * mem_used_factor_;
  no_swap_events_.clear();
  // Working copy of the per-step usage, updated as candidates are accepted.
  std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
  for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
    auto span = iter->first;
    auto &event = iter->second;
    auto start_index = event->index - span + 1;
    bool revert = false;
    for (size_t i = start_index; i < event->index; ++i) {
      cur_mem_used[i] += event->mem_size;
      if (cur_mem_used[i] > threshold) {
        revert = true;
      }
    }
    if (revert) {
      // Budget exceeded somewhere in the gap: undo and let this event be swapped.
      for (size_t i = start_index; i < event->index; ++i) {
        cur_mem_used[i] -= event->mem_size;
      }
    } else {
      (void)no_swap_events_.emplace(event);
    }
  }
}
310
// Translate the recorded per-key access events into per-step schedules:
// pre_compute_events_[i] holds allocations/inits/swap-ins executed before
// step i, and post_compute_events_[i] holds swap-outs/frees executed after
// step i.
void MemScheduler::GenEvents() {
  pre_compute_events_.resize(compute_index_);
  post_compute_events_.resize(compute_index_);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    if (first_event->type == kInit) {
      if (mem_events.size() > 1) {
        // Defer the init to the step of the first real access.
        auto &second_event = mem_events[1];
        MS_EXCEPTION_IF_NULL(second_event);
        first_event->index = second_event->index;
      } else {
        // Initialized but never accessed: nothing to schedule.
        continue;
      }
    }
    if ((first_event->type == kInit || first_event->type == kMalloc) &&
        first_event->index < pre_compute_events_.size()) {
      (void)pre_compute_events_[first_event->index].emplace_back(first_event);
    } else {
      MS_LOG_EXCEPTION << "First event should be init or malloc!";
    }
    // Unregistered keys default to low priority (eligible for swap and free).
    MemPriority priority = kMemPriorityLow;
    auto iter = mem_priority_.find(first_event->key);
    if (iter != mem_priority_.end()) {
      priority = iter->second;
    }
    size_t pre_index = first_event->index;
    for (size_t i = 1; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      // Wrap idle gaps longer than one step in a swap-out/swap-in pair, but
      // only for low-priority buffers not exempted by GenNoSwapEventSet.
      if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
          no_swap_events_.find(event) == no_swap_events_.end()) {
        auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
        swap_out_event->key = item.first;
        swap_out_event->mem_size = first_event->mem_size;
        (void)post_compute_events_[pre_index].emplace_back(swap_out_event);
        auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
        swap_in_event->key = item.first;
        swap_in_event->mem_size = first_event->mem_size;
        (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
      }
      if (event->index < pre_compute_events_.size()) {
        (void)pre_compute_events_[event->index].emplace_back(event);
      }
      pre_index = event->index;
    }
    if (priority != kMemPriorityLow) {
      // High/medium priority buffers stay resident; no free at end of lifetime.
      continue;
    }
    auto &last_event = mem_events[mem_events.size() - 1];
    MS_EXCEPTION_IF_NULL(last_event);
    auto free_event = std::make_shared<Event>(kFree, last_event->index);
    free_event->key = item.first;
    if (last_event->index < post_compute_events_.size()) {
      (void)post_compute_events_[last_event->index].emplace_back(free_event);
    }
  }
}
373 } // namespace device
374 } // namespace mindspore
375