/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/memory_scheduler.h"
#include <algorithm>
#include "utils/log_adapter.h"

namespace mindspore {
namespace device {
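// Frees all cached high-priority device memory and clears the cache.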
void MemScheduler::Clear() {
  if (mem_handler_ == nullptr) {
    return;
  }
  for (auto &item : high_priority_device_ptr_) {
    mem_handler_->FreeDevice(item.second);
  }
  high_priority_device_ptr_.clear();
}

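// Returns true if the memory identified by key was registered with high priority.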
bool MemScheduler::IsHighPriorityMem(const void *key) {
  auto iter = mem_priority_.find(key);
  if (iter != mem_priority_.end()) {
    return iter->second == kMemPriorityHigh;
  }
  return false;
}

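// Registers or updates the priority of the memory identified by key.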
void MemScheduler::SetMemPriority(const void *key, MemPriority priority) { mem_priority_[key] = priority; }

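// Appends a memory event of the given type for key at the current compute step.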
void MemScheduler::Record(const void *key, const EventType &event_type, size_t mem_size) {
  if (key == nullptr) {
    return;
  }
  auto event = std::make_shared<Event>(event_type, compute_index_);
  event->mem_size = mem_size;
  event->key = key;
  (void)mem_events_[key].emplace_back(event);
}

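// Recording pass: registers the priority and records an init event for key.
// Execution pass: remembers the host pointer holding the initial data for key.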
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
  if (need_record_event_) {
    mem_priority_[key] = priority;
    Record(key, kInit, mem_size);
  } else {
    init_host_ptr_[key] = host_ptr;
  }
}

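// Recording pass: records a malloc event on the first use of key (or a get event on reuse) and returns nullptr.
// Execution pass: returns the device pointer prepared for key by PreCompute.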
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
  if (need_record_event_) {
    if (mem_priority_.find(key) == mem_priority_.end()) {
      mem_priority_[key] = priority;
      Record(key, kMalloc, mem_size);
    } else {
      Record(key, kGet, mem_size);
    }
    return nullptr;
  }
  auto iter = mem_result_.find(key);
  if (iter != mem_result_.end()) {
    auto ptr = iter->second;
    MS_EXCEPTION_IF_NULL(ptr);
    return ptr;
  } else {
    MS_LOG_EXCEPTION << "Memory scheduler got a nullptr result!";
  }
}

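// Executes the pre-compute events of the current step: allocates device memory, swaps initial or
// previously swapped-out data onto the device, and publishes the device pointers in mem_result_.
// Returns false when a device allocation fails.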
bool MemScheduler::PreCompute(void *stream) {
  if (need_record_event_) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(mem_handler_);
  if (pre_compute_events_.size() <= compute_index_) {
    MS_LOG_EXCEPTION << "Index out of pre event range, index:" << compute_index_
                     << ", event size:" << pre_compute_events_.size();
  }
  auto &events = pre_compute_events_[compute_index_];
  for (auto &event : events) {
    MS_EXCEPTION_IF_NULL(event);
    MS_LOG(DEBUG) << "Pre compute " << compute_index_ << ": " << event->key << " v " << event->type;
    if (event->type == kInit) {
      auto host_ptr = init_host_ptr_[event->key];
      MS_EXCEPTION_IF_NULL(host_ptr);
      auto priority = mem_priority_[event->key];
      auto iter = high_priority_device_ptr_.find(event->key);
      if (priority != kMemPriorityLow && iter != high_priority_device_ptr_.end()) {
        MS_EXCEPTION_IF_NULL(iter->second);
        mem_result_[event->key] = iter->second;
        if (priority == kMemPriorityMedium) {
          mem_handler_->SwapIn(host_ptr, iter->second, event->mem_size, stream);
        }
        continue;
      }
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      if (priority != kMemPriorityLow) {
        high_priority_device_ptr_[event->key] = device_ptr;
      }
      mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
      mem_result_[event->key] = device_ptr;
    } else if (event->type == kMalloc) {
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      mem_result_[event->key] = device_ptr;
    } else if (event->type == kSwapIn) {
      bool from_init = true;
      auto host_ptr = init_host_ptr_[event->key];
      if (host_ptr == nullptr) {
        host_ptr = swap_host_ptr_[event->key];
        from_init = false;
      }
      auto device_ptr = mem_handler_->MallocDevice(event->mem_size);
      if (device_ptr == nullptr) {
        return false;
      }
      MS_EXCEPTION_IF_NULL(host_ptr);
      mem_handler_->SwapIn(host_ptr, device_ptr, event->mem_size, stream);
      mem_result_[event->key] = device_ptr;
      if (!from_init) {
        mem_handler_->FreeHost(host_ptr);
        (void)swap_host_ptr_.erase(event->key);
      }
    }
  }
  return true;
}

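// Executes the post-compute events of the current step: frees device memory after its last use and
// swaps out data that will be needed again later, then advances compute_index_.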
bool MemScheduler::PostCompute(void *stream) {
  if (need_record_event_) {
    ++compute_index_;
    return true;
  }
  MS_EXCEPTION_IF_NULL(mem_handler_);
  if (post_compute_events_.size() <= compute_index_) {
    MS_LOG_EXCEPTION << "Index out of post event range, index:" << compute_index_
                     << ", event size:" << post_compute_events_.size();
  }
  auto &events = post_compute_events_[compute_index_];
  for (auto &event : events) {
    MS_EXCEPTION_IF_NULL(event);
    MS_LOG(DEBUG) << "Post compute " << compute_index_ << ": " << event->key << " v " << event->type;
    if (event->type == kFree) {
      auto ptr = mem_result_[event->key];
      if (ptr == nullptr) {
        return false;
      }
      mem_handler_->FreeDevice(ptr);
      (void)mem_result_.erase(event->key);
    } else if (event->type == kSwapOut) {
      auto device_ptr = mem_result_[event->key];
      if (device_ptr == nullptr) {
        return false;
      }
      auto host_ptr = init_host_ptr_[event->key];
      if (host_ptr == nullptr) {
        host_ptr = mem_handler_->MallocHost(event->mem_size);
        swap_host_ptr_[event->key] = host_ptr;
      }
      MS_EXCEPTION_IF_NULL(host_ptr);
      mem_handler_->SwapOut(device_ptr, host_ptr, event->mem_size, stream);
      mem_handler_->FreeDevice(device_ptr);
      // Erase by key: mem_result_ is keyed by the tensor key, not by the device pointer.
      (void)mem_result_.erase(event->key);
    }
  }
  ++compute_index_;
  return true;
}

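// Finishes the recording pass and generates the per-step event schedule; swap events are only
// generated when the graph does not fit into the available device memory.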
void MemScheduler::OptMemUsage() {
  need_record_event_ = false;
  if (optimized_) {
    return;
  }
  CountMemUsage();
  CheckMemSize();
  if (need_swap_) {
    GenEventSpan();
    GenNoSwapEventSet();
  }
  GenEvents();
}

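// Computes, for every compute step, the minimum memory required (only the tensors accessed at that
// step) and the total memory required without swapping (tensors resident from first to last use).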
void MemScheduler::CountMemUsage() {
  if (!min_mem_used_.empty()) {
    return;
  }
  min_mem_used_.resize(compute_index_, 0);
  std::vector<size_t> total_mem_used(compute_index_, 0);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    size_t i = 0;
    if (first_event->type == kInit && mem_events.size() > 1) {
      first_event = mem_events[1];
      i = 1;
    }
    auto last_event = mem_events[mem_events.size() - 1];
    for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
      if (start_index < compute_index_) {
        total_mem_used[start_index] += first_event->mem_size;
      } else {
        MS_LOG(ERROR) << "Error mem event index " << start_index;
      }
    }
    for (; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (event->index < compute_index_) {
        min_mem_used_[event->index] += first_event->mem_size;
      } else {
        MS_LOG(ERROR) << "Error mem event index " << event->index;
      }
    }
  }
  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end()));
}

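// Compares the memory requirements with the available device memory and decides whether swapping is needed.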
void MemScheduler::CheckMemSize() {
  MS_EXCEPTION_IF_NULL(mem_handler_);
  auto available_mem_size = mem_handler_->GetAvailableMemSize();
  if (available_mem_size < min_mem_needed_) {
    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << available_mem_size
                      << " while graph needs at least " << min_mem_needed_;
  }
  if (mem_used_without_swap_ > available_mem_size) {
    need_swap_ = true;
  }
  MS_LOG(INFO) << "Available mem size: " << available_mem_size << ", graph needs mem size: " << mem_used_without_swap_
               << " without swap, and needs at least " << min_mem_needed_ << " with swap.";
}

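// Collects the events whose distance to the previous use of the same memory is larger than one
// step; these spans are the candidates for swapping.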
void MemScheduler::GenEventSpan() {
  if (!event_span_.empty()) {
    return;
  }
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    size_t i = 0;
    if (first_event->type == kInit && mem_events.size() > 1) {
      first_event = mem_events[1];
      i = 1;
    }
    size_t last_index = first_event->index;
    for (; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      auto span = event->index - last_index;
      if (span > 1) {
        (void)event_span_.emplace(std::pair<size_t, std::shared_ptr<Event>>(span, event));
      }
      last_index = event->index;
    }
  }
}

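// Walks the swap candidates in event_span_ order and marks as no-swap those that can stay resident
// without pushing any step's memory usage above available_mem_size * mem_used_factor_.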
void MemScheduler::GenNoSwapEventSet() {
  MS_EXCEPTION_IF_NULL(mem_handler_);
  auto available_mem_size = mem_handler_->GetAvailableMemSize();
  auto threshold = available_mem_size * mem_used_factor_;
  no_swap_events_.clear();
  std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());
  for (auto iter = event_span_.begin(); iter != event_span_.end(); ++iter) {
    auto span = iter->first;
    auto &event = iter->second;
    auto start_index = event->index - span + 1;
    bool revert = false;
    for (size_t i = start_index; i < event->index; ++i) {
      cur_mem_used[i] += event->mem_size;
      if (cur_mem_used[i] > threshold) {
        revert = true;
      }
    }
    if (revert) {
      for (size_t i = start_index; i < event->index; ++i) {
        cur_mem_used[i] -= event->mem_size;
      }
    } else {
      (void)no_swap_events_.emplace(event);
    }
  }
}

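// Builds pre_compute_events_ and post_compute_events_ from the recorded events: inserts swap-out /
// swap-in pairs for low-priority memory with long idle spans and a free event after the last use.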
void MemScheduler::GenEvents() {
  pre_compute_events_.resize(compute_index_);
  post_compute_events_.resize(compute_index_);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    if (first_event->type == kInit) {
      if (mem_events.size() > 1) {
        auto &second_event = mem_events[1];
        MS_EXCEPTION_IF_NULL(second_event);
        first_event->index = second_event->index;
      } else {
        continue;
      }
    }
    if ((first_event->type == kInit || first_event->type == kMalloc) &&
        first_event->index < pre_compute_events_.size()) {
      (void)pre_compute_events_[first_event->index].emplace_back(first_event);
    } else {
      MS_LOG_EXCEPTION << "First event should be init or malloc!";
    }
    MemPriority priority = kMemPriorityLow;
    auto iter = mem_priority_.find(first_event->key);
    if (iter != mem_priority_.end()) {
      priority = iter->second;
    }
    size_t pre_index = first_event->index;
    for (size_t i = 1; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (need_swap_ && event->index - pre_index > 1 && priority == kMemPriorityLow &&
          no_swap_events_.find(event) == no_swap_events_.end()) {
        auto swap_out_event = std::make_shared<Event>(kSwapOut, pre_index);
        swap_out_event->key = item.first;
        swap_out_event->mem_size = first_event->mem_size;
        (void)post_compute_events_[pre_index].emplace_back(swap_out_event);
        auto swap_in_event = std::make_shared<Event>(kSwapIn, event->index);
        swap_in_event->key = item.first;
        swap_in_event->mem_size = first_event->mem_size;
        (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
      }
      if (event->index < pre_compute_events_.size()) {
        (void)pre_compute_events_[event->index].emplace_back(event);
      }
      pre_index = event->index;
    }
    if (priority != kMemPriorityLow) {
      continue;
    }
    auto &last_event = mem_events[mem_events.size() - 1];
    MS_EXCEPTION_IF_NULL(last_event);
    auto free_event = std::make_shared<Event>(kFree, last_event->index);
    free_event->key = item.first;
    if (last_event->index < post_compute_events_.size()) {
      (void)post_compute_events_[last_event->index].emplace_back(free_event);
    }
  }
}
}  // namespace device
}  // namespace mindspore