• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "runtime/device/memory_offload_strategy.h"
17 #include <vector>
18 #include <map>
19 #include <memory>
20 #include <utility>
21 #include <algorithm>
22 #include "utils/log_adapter.h"
23 #include "include/backend/device_address.h"
24 
25 namespace mindspore {
26 namespace device {
27 constexpr size_t kFirstGetMemEventIndex = 1;
28 constexpr size_t kInitOrMallocMemEventIndex = 0;
29 
GetInstance()30 MemoryOffloadConflict &MemoryOffloadConflict::GetInstance() {
31   static MemoryOffloadConflict instance = MemoryOffloadConflict();
32   return instance;
33 }
34 
AddMemoryOffloadConflict(const HashSet<const void * > & conflict_set)35 void MemoryOffloadConflict::AddMemoryOffloadConflict(const HashSet<const void *> &conflict_set) {
36   for (const auto &key : conflict_set) {
37     (void)conflict_map_[key].insert(conflict_set.cbegin(), conflict_set.cend());
38   }
39 }
40 
// Returns the set of keys that conflict with `key`. NOTE: operator[] inserts an
// empty set when `key` is absent, so this lookup may mutate conflict_map_.
const HashSet<const void *> &MemoryOffloadConflict::GetConflictMap(const void *key) { return conflict_map_[key]; }
42 
43 template <typename Key>
Record(Key key,const MemEventType & event_type,size_t mem_size,MemPriority priority,size_t index)44 void GraphMemStatistic<Key>::Record(Key key, const MemEventType &event_type, size_t mem_size, MemPriority priority,
45                                     size_t index) {
46   if (mem_priority_.count(key) == 0) {
47     mem_priority_[key] = priority;
48     if (event_type == kGet) {
49       auto event = std::make_shared<MemEvent<Key>>(kMalloc, index);
50       event->mem_size = mem_size;
51       event->key = key;
52       (void)mem_events_[key].emplace_back(event);
53     }
54   }
55   auto event = std::make_shared<MemEvent<Key>>(event_type, index);
56   event->mem_size = mem_size;
57   event->key = key;
58   (void)mem_events_[key].emplace_back(event);
59 }
60 
61 template <typename Key>
GetPreComputeEvents(size_t index)62 MemEventPtrList<Key> &MemOffloadStrategy<Key>::GetPreComputeEvents(size_t index) {
63   if (pre_compute_events_.size() <= index) {
64     MS_LOG_EXCEPTION << "Index out of pre event range, index:" << index
65                      << ", event size:" << pre_compute_events_.size();
66   }
67   return pre_compute_events_[index];
68 }
69 
70 template <typename Key>
GetPostComputeEvents(size_t index)71 MemEventPtrList<Key> &MemOffloadStrategy<Key>::GetPostComputeEvents(size_t index) {
72   if (post_compute_events_.size() <= index) {
73     MS_LOG_EXCEPTION << "Index out of post event range, index:" << index
74                      << ", event size:" << post_compute_events_.size();
75   }
76   return post_compute_events_[index];
77 }
78 
79 template <typename Key>
Execute()80 void MemOffloadStrategy<Key>::Execute() {
81   CountMemUsage();
82   CheckMemSize();
83   if (need_swap_) {
84     GenEventSpan();
85     GenSwapEventSet();
86   } else {
87     GenContinuousMemAllocInfo();
88   }
89   GenComputeMemEvents();
90 }
91 
template <typename Key>
void MemOffloadStrategy<Key>::CountMemUsage() {
  // Computes, per compute index, the minimum memory required by the kernels
  // (min_mem_used_) and the memory required if nothing is swapped
  // (mem_used_without_swap_). Results are cached: a non-empty min_mem_used_
  // means this has already run.
  if (!min_mem_used_.empty()) {
    return;
  }
  if (mem_events_.empty() || total_compute_index_ == 0) {
    return;
  }
  min_mem_used_.resize(total_compute_index_, 0);
  std::vector<size_t> total_mem_used(total_compute_index_, 0);
  // High priority memory stays resident for the whole graph, so its size is
  // accumulated once and added to the peak at the end, not per index.
  size_t high_priority_mem_size = 0;
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[kInitOrMallocMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_event);
    const bool is_high_priority = IsHighPriorityMem(item.first);
    if (continuous_mem_info_helper_->IsContinuousInputMem(item.first)) {
      // Continuous input memory is accounted separately in CountContinuousMemUsage.
      continue;
    } else if (is_high_priority) {
      high_priority_mem_size += first_event->mem_size;
    } else {
      // Ordinary memory is live from its first (init/malloc) to its last event.
      auto last_event = mem_events[mem_events.size() - 1];
      MS_EXCEPTION_IF_NULL(last_event);
      for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
        total_mem_used[start_index] += first_event->mem_size;
      }
    }

    // Calculate the minimum memory size for kernel execution.
    for (const auto &event : mem_events) {
      MS_EXCEPTION_IF_NULL(event);
      if (event->type != kGet) {
        continue;
      }
      min_mem_used_[event->index] += first_event->mem_size;
    }
  }
  CountContinuousMemUsage(&total_mem_used);
  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
  // Even with swapping there is a hard per-kernel lower bound on device memory.
  if (mem_size_ < min_mem_needed_) {
    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
                      << min_mem_needed_;
  }
}
141 
142 template <typename Key>
IsHighPriorityMem(Key key) const143 bool MemOffloadStrategy<Key>::IsHighPriorityMem(Key key) const {
144   auto iter = mem_priority_.find(key);
145   if (iter != mem_priority_.end()) {
146     return iter->second == kMemPriorityHigh;
147   }
148   return false;
149 }
150 
151 template <typename Key>
CheckMemSize()152 void MemOffloadStrategy<Key>::CheckMemSize() {
153   if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
154     need_swap_ = true;
155   }
156   MS_LOG(INFO) << "Available mem size: " << mem_size_ << ", graph needs mem size: " << mem_used_without_swap_
157                << " without swap, and needs at least " << min_mem_needed_ << " with swap.";
158 }
159 
160 template <typename Key>
GenEventSpan()161 void MemOffloadStrategy<Key>::GenEventSpan() {
162   if (!event_span_.empty()) {
163     return;
164   }
165   for (auto &item : mem_events_) {
166     auto &tensor_events = item.second;
167     if (tensor_events.size() <= 1) {
168       continue;
169     }
170     const bool is_high_priority = IsHighPriorityMem(item.first);
171     for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
172       auto &event = tensor_events[i];
173       MS_EXCEPTION_IF_NULL(event);
174       if (event->type != kGet) {
175         MS_LOG(EXCEPTION) << "Event should be Get except fist event.";
176       }
177       auto latest_get_event = tensor_events[i - 1];
178       if (i == kFirstGetMemEventIndex && is_high_priority) {
179         latest_get_event = tensor_events[tensor_events.size() - 1];
180       }
181       MS_EXCEPTION_IF_NULL(latest_get_event);
182       auto span = GetSpanBetweenMemEvents(latest_get_event->index, event->index);
183       // High priority memory that is only used once in one step
184       if (is_high_priority && span == 0 && latest_get_event == event) {
185         span = total_compute_index_;
186       }
187       if (span > 1) {
188         const size_t span_mul_size = (span - 1) * event->mem_size;
189         (void)event_span_.emplace(span_mul_size, std::make_pair(event, span));
190       }
191     }
192   }
193 }
194 
template <typename Key>
void MemOffloadStrategy<Key>::GenSwapEventSet() {
  // Selects which Get events need a swap-in (with a matching swap-out after the
  // previous use) so that memory usage fits within mem_size_.
  swap_events_.clear();
  // manual offload strategy: swap exactly the events whose key was marked by the user.
  if (!manual_offload_keys_.empty()) {
    for (const auto &iter : event_span_) {
      auto &event = iter.second.first;
      MS_EXCEPTION_IF_NULL(event);
      if (manual_offload_keys_.find(event->key) != manual_offload_keys_.end()) {
        (void)swap_events_.emplace(event);
      }
    }
    return;
  }
  // greedy span filter
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  continuous_mem_info_helper_->ClearContinuousMallocIndex();
  // Start from the irreducible per-index usage and greedily add tensors back.
  std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());

  auto compare_total_size = [](const ContinuousMemInfoPtr<Key> &l, const ContinuousMemInfoPtr<Key> &r) -> bool {
    MS_EXCEPTION_IF_NULL(l);
    MS_EXCEPTION_IF_NULL(r);
    return l->total_size_ < r->total_size_;
  };
  // Handle continuous blocks first, smallest total size first, since all the
  // members of a block must be resident together.
  auto all_continuous_mem_info = continuous_mem_info_helper_->GetAllContinuousMemInfo();
  std::sort(all_continuous_mem_info.begin(), all_continuous_mem_info.end(), compare_total_size);
  std::set<MemEventPtr<Key>> events_no_need_swap;
  for (const auto &continuous_mem_info : all_continuous_mem_info) {
    GenContinuousMemSwapEvent(continuous_mem_info, &cur_mem_used, &events_no_need_swap);
  }
  // Remaining candidates are processed in event_span_'s key order; events
  // already covered by a continuous block's resident window are skipped.
  for (const auto &iter : event_span_) {
    const auto &event = iter.second.first;
    if (events_no_need_swap.count(event) > 0) {
      continue;
    }
    auto span = iter.second.second;
    AddToSwapEventSetIfOutOfMem(event, span, &cur_mem_used);
  }
}
234 
template <typename Key>
void MemOffloadStrategy<Key>::AddToSwapEventSetIfOutOfMem(const MemEventPtr<Key> &event, size_t span,
                                                          std::vector<size_t> *mem_used) {
  // Tentatively keeps `event`'s memory resident over the `span` compute steps
  // preceding event->index (indices wrap around the step boundary). If that
  // would exceed mem_size_ at any step, the accounting is rolled back and the
  // event is added to swap_events_ instead.
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(mem_used);
  // First resident index: one step after the previous use, modulo wrap-around.
  const auto start_index = (GetPreMemEventIndex(event->index, span) + 1) % total_compute_index_;
  bool revert = false;
  size_t cur_index = start_index;
  while (cur_index != event->index) {
    (*mem_used)[cur_index] += event->mem_size;
    if (mem_used->at(cur_index) > mem_size_) {
      // Keep walking to finish the additions so the revert loop below can undo
      // exactly the same range.
      revert = true;
    }
    cur_index += 1;
    if (cur_index >= total_compute_index_) {
      cur_index = 0;
    }
  }
  if (revert) {
    // Undo the tentative accounting over the identical wrapped range.
    cur_index = start_index;
    while (cur_index != event->index) {
      (*mem_used)[cur_index] -= event->mem_size;
      cur_index += 1;
      if (cur_index >= total_compute_index_) {
        cur_index = 0;
      }
    }
    (void)swap_events_.emplace(event);
  }
}
265 
template <typename Key>
void MemOffloadStrategy<Key>::GenContinuousMemSwapEvent(const ContinuousMemInfoPtr<Key> &continuous_mem_info,
                                                        std::vector<size_t> *mem_used,
                                                        std::set<MemEventPtr<Key>> *events_no_need_swap) {
  // Decides swap-in events for the member tensors of one continuous memory
  // block and registers the compute index at which the block is allocated.
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  MS_EXCEPTION_IF_NULL(mem_used);
  MS_EXCEPTION_IF_NULL(events_no_need_swap);
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  if (continuous_mem_info->key_index_map_.empty()) {
    return;
  }
  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
  // Output continuous memory is simply allocated at its own compute index.
  if (!continuous_mem_info->is_input_) {
    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
    return;
  }
  // Largest span before the use index over which the whole block still fits.
  const auto max_span_mem_in_device = GetMaxSpanForContinuousMem(continuous_mem_info, *mem_used);
  size_t first_malloc_span = 0;
  size_t first_malloc_size_dup = 0;
  for (const auto &key_index : continuous_mem_info->key_index_map_) {
    const auto &events_iter = mem_events_.find(key_index.first);
    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
      MS_LOG(EXCEPTION) << "Can not find events for continuous input memory, device address key: " << key_index.first;
    }
    size_t swap_in_event_index = kFirstGetMemEventIndex;
    size_t swap_in_span = 0;
    const bool is_high_priority = IsHighPriorityMem(key_index.first);
    for (size_t i = kFirstGetMemEventIndex; i < events_iter->second.size(); ++i) {
      const auto &mem_event = events_iter->second[i];
      MS_EXCEPTION_IF_NULL(mem_event);
      if (!is_high_priority && mem_event->index > continuous_mem_used_index) {
        continue;
      }
      const size_t span = GetSpanBetweenMemEvents(mem_event->index, continuous_mem_used_index);
      // Find the max span that is less than or equal to max_span_mem_in_device.
      if (span <= max_span_mem_in_device) {
        if (span >= swap_in_span) {
          swap_in_span = span;
          swap_in_event_index = i;
        }
        // Events inside the block's resident window never need their own swap.
        (void)events_no_need_swap->insert(mem_event);
      }
    }
    if (swap_in_event_index != kFirstGetMemEventIndex || is_high_priority) {
      (void)swap_events_.insert(events_iter->second[swap_in_event_index]);
    }
    // Find the earliest index that continuous memory should be allocated
    if (swap_in_span > first_malloc_span) {
      first_malloc_span = swap_in_span;
      first_malloc_size_dup = events_iter->second[swap_in_event_index]->mem_size;
    } else if (swap_in_span == first_malloc_span) {
      // Accumulate the memory size that already added to mem_used.
      first_malloc_size_dup += events_iter->second[swap_in_event_index]->mem_size;
    }
  }
  // Account the whole block as resident from the first malloc step up to the
  // use index.
  for (size_t span = 1; span <= first_malloc_span; ++span) {
    size_t index = GetPreMemEventIndex(continuous_mem_used_index, span);
    (*mem_used)[index] += continuous_mem_info->total_size_;
  }
  // Subtract the member sizes already counted at the malloc step to avoid
  // double accounting, then record the block's allocation index.
  size_t index = GetPreMemEventIndex(continuous_mem_used_index, first_malloc_span);
  (*mem_used)[index] -= first_malloc_size_dup;
  continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, index);
}
329 
template <typename Key>
size_t MemOffloadStrategy<Key>::GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr<Key> &continuous_mem_info,
                                                           const std::vector<size_t> &mem_used) const {
  // Returns the largest number of compute steps before the block's use index
  // over which the whole continuous block fits in memory at every step.
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
  size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
  // Upper bound: can never be resident before the earliest member malloc.
  size_t max_span_mem_in_device = GetSpanBetweenMemEvents(earliest_malloc_index, continuous_mem_used_index);

  // Walk backwards from the use index; shrink the span to just before the
  // first step where adding the whole block would overflow mem_size_.
  for (size_t span = 1; span <= max_span_mem_in_device; ++span) {
    size_t cur_index = GetPreMemEventIndex(continuous_mem_used_index, span);
    if (mem_used[cur_index] + continuous_mem_info->total_size_ > mem_size_) {
      max_span_mem_in_device = span - 1;
      break;
    }
  }
  return max_span_mem_in_device;
}
347 
348 template <typename Key>
GetFirstMallocIndex(const ContinuousMemInfoPtr<Key> & continuous_mem_info) const349 size_t MemOffloadStrategy<Key>::GetFirstMallocIndex(const ContinuousMemInfoPtr<Key> &continuous_mem_info) const {
350   MS_EXCEPTION_IF_NULL(continuous_mem_info);
351   size_t earliest_malloc_index = continuous_mem_info->compute_index_;
352   for (const auto &key_index : continuous_mem_info->key_index_map_) {
353     const auto &events_iter = mem_events_.find(key_index.first);
354     if (events_iter == mem_events_.end() || events_iter->second.empty()) {
355       MS_LOG(EXCEPTION) << "Can not find events for continuous input memory, device address key: " << key_index.first;
356     }
357     const auto &first_event = events_iter->second[kInitOrMallocMemEventIndex];
358     MS_EXCEPTION_IF_NULL(first_event);
359     if (first_event->index < earliest_malloc_index) {
360       earliest_malloc_index = first_event->index;
361     }
362   }
363   return earliest_malloc_index;
364 }
365 
366 template <typename Key>
GenContinuousMemAllocInfo()367 void MemOffloadStrategy<Key>::GenContinuousMemAllocInfo() {
368   MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
369   for (const auto &continuous_mem_info : continuous_mem_info_helper_->GetAllContinuousMemInfo()) {
370     GenContinuousMemAllocInfo(continuous_mem_info);
371   }
372 }
373 
374 template <typename Key>
AdjustFirstEventIndex()375 void MemOffloadStrategy<Key>::AdjustFirstEventIndex() {
376   for (const auto &item : mem_events_) {
377     const auto &mem_events = item.second;
378     if (mem_events.empty()) {
379       continue;
380     }
381     auto &first_event = mem_events[0];
382     MS_EXCEPTION_IF_NULL(first_event);
383     const auto &priority_iter = mem_priority_.find(item.first);
384     const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
385     if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
386       const auto &second_event = mem_events[1];
387       MS_EXCEPTION_IF_NULL(second_event);
388       first_event->index = second_event->index;
389     }
390   }
391 }
392 
393 template <typename Key>
GenContinuousMemAllocInfo(const ContinuousMemInfoPtr<Key> & continuous_mem_info)394 void MemOffloadStrategy<Key>::GenContinuousMemAllocInfo(const ContinuousMemInfoPtr<Key> &continuous_mem_info) {
395   MS_EXCEPTION_IF_NULL(continuous_mem_info);
396   MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
397   if (!continuous_mem_info->is_input_) {
398     continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
399   } else {
400     const size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
401     continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, earliest_malloc_index);
402   }
403 }
404 
template <typename Key>
void MemOffloadStrategy<Key>::GenComputeMemEvents() {
  // Builds the final per-compute-index schedules: pre_compute_events_ holds
  // events executed before a kernel (init/malloc/swap-in/get), and
  // post_compute_events_ holds events executed after it (swap-out/free).
  pre_compute_events_.clear();
  post_compute_events_.clear();
  pre_compute_events_.resize(total_compute_index_);
  post_compute_events_.resize(total_compute_index_);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    // No need to generate events for memory that has only one event, which means it is never used by any kernel.
    if (mem_events.size() <= 1) {
      continue;
    }

    const bool is_high_priority = IsHighPriorityMem(item.first);
    auto first_event = mem_events[kInitOrMallocMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_event);
    const auto &first_get_event = mem_events[kFirstGetMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_get_event);
    // High priority memory that will be swapped in before its first use can be
    // allocated lazily at that use index instead of at the original index.
    if (is_high_priority && swap_events_.find(first_get_event) != swap_events_.end()) {
      first_event->index = first_get_event->index;
    }
    if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_compute_index_) {
      (void)pre_compute_events_[first_event->index].emplace_back(first_event);
    } else {
      MS_LOG_EXCEPTION << "First event should be init or malloc!";
    }

    const auto &last_event = mem_events[mem_events.size() - 1];
    MS_EXCEPTION_IF_NULL(last_event);
    // pre_index tracks the step of the previous use; for high priority memory
    // the "previous use" of the first Get is the last use of the prior step.
    size_t pre_index = is_high_priority ? last_event->index : first_event->index;
    for (size_t i = kFirstGetMemEventIndex; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
        // Swap out after the previous use, swap back in before this use.
        auto swap_out_event = std::make_shared<MemEvent<Key>>(kSwapOut, pre_index);
        swap_out_event->key = item.first;
        swap_out_event->mem_size = first_event->mem_size;
        (void)post_compute_events_[pre_index].emplace_back(swap_out_event);
        // avoid swap-in-event follow init-event
        if (i != kFirstGetMemEventIndex || first_event->type != kInit) {
          auto swap_in_event = std::make_shared<MemEvent<Key>>(kSwapIn, event->index);
          swap_in_event->key = item.first;
          swap_in_event->mem_size = first_event->mem_size;
          (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
        }
      }
      if (event->index < pre_compute_events_.size()) {
        (void)pre_compute_events_[event->index].emplace_back(event);
      }
      pre_index = event->index;
    }
    // Low priority memory is freed after its last use; high priority memory
    // stays resident across steps and is never freed here.
    if (!is_high_priority) {
      GenFreeEvent(last_event);
    }
  }
}
461 
462 template <typename Key>
GenFreeEvent(const MemEventPtr<Key> & last_event)463 void MemOffloadStrategy<Key>::GenFreeEvent(const MemEventPtr<Key> &last_event) {
464   MS_EXCEPTION_IF_NULL(last_event);
465   auto free_event = std::make_shared<MemEvent<Key>>(kFree, last_event->index);
466   free_event->key = last_event->key;
467   if (last_event->index < post_compute_events_.size()) {
468     (void)post_compute_events_[last_event->index].emplace_back(free_event);
469   }
470 }
471 
472 template <typename Key>
GetContinuousMemInfo(Key address_key) const473 ContinuousMemInfoPtr<Key> ContinuousMemInfoHelper<Key>::GetContinuousMemInfo(Key address_key) const {
474   const auto &continuous_info_iter = key_continuous_info_map_.find(address_key);
475   return continuous_info_iter == key_continuous_info_map_.end() ? nullptr : continuous_info_iter->second;
476 }
477 
478 template <typename Key>
GetAllContinuousMemInfo() const479 std::vector<ContinuousMemInfoPtr<Key>> ContinuousMemInfoHelper<Key>::GetAllContinuousMemInfo() const {
480   std::vector<ContinuousMemInfoPtr<Key>> all_continuous_mem_info(input_continuous_mem_info_.size() +
481                                                                  output_continuous_mem_info_.size());
482   (void)std::copy(input_continuous_mem_info_.begin(), input_continuous_mem_info_.end(),
483                   all_continuous_mem_info.begin());
484   (void)std::copy_backward(output_continuous_mem_info_.begin(), output_continuous_mem_info_.end(),
485                            all_continuous_mem_info.end());
486   return all_continuous_mem_info;
487 }
488 
489 template <typename Key>
IsContinuousMem(Key address_key) const490 bool ContinuousMemInfoHelper<Key>::IsContinuousMem(Key address_key) const {
491   const auto continuous_mem_info = GetContinuousMemInfo(address_key);
492   return (continuous_mem_info != nullptr);
493 }
494 
495 template <typename Key>
IsContinuousInputMem(Key address_key) const496 bool ContinuousMemInfoHelper<Key>::IsContinuousInputMem(Key address_key) const {
497   const auto continuous_mem_info = GetContinuousMemInfo(address_key);
498   return (continuous_mem_info != nullptr && continuous_mem_info->is_input_);
499 }
500 
501 template <typename Key>
AddContinuousMemInfo(bool is_input,size_t compute_index,size_t total_size,const std::vector<size_t> & align_size_list,const std::vector<Key> & address_key_list)502 void ContinuousMemInfoHelper<Key>::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
503                                                         const std::vector<size_t> &align_size_list,
504                                                         const std::vector<Key> &address_key_list) {
505   if (align_size_list.size() != address_key_list.size()) {
506     MS_LOG(EXCEPTION) << "Number of align size[" << align_size_list.size()
507                       << "] is supposed to be equal to number of address[" << address_key_list.size() << "]";
508   }
509   ContinuousMemInfoPtr<Key> continuous_mem_info =
510     std::make_shared<ContinuousMemInfo<Key>>(is_input, total_size, compute_index, align_size_list);
511   for (size_t i = 0; i < address_key_list.size(); i += 1) {
512     auto key = address_key_list[i];
513     MS_EXCEPTION_IF_NULL(key);
514     (void)continuous_mem_info->key_index_map_.emplace(key, i);
515     (void)key_continuous_info_map_.emplace(key, continuous_mem_info);
516   }
517   if (is_input) {
518     (void)input_continuous_mem_info_.insert(continuous_mem_info);
519   } else {
520     (void)output_continuous_mem_info_.insert(continuous_mem_info);
521   }
522   (void)index_continuous_info_map_[compute_index].emplace_back(continuous_mem_info);
523 }
524 
template <typename Key>
void MemOffloadStrategy<Key>::CountContinuousMemUsage(std::vector<size_t> *total_mem_used) const {
  // Adds the footprint of input continuous memory blocks (which CountMemUsage
  // deliberately skips) into total_mem_used.
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  // NOTE(review): despite the name, this local holds *all* continuous mem info
  // (inputs and outputs); non-input entries are filtered in the loop below.
  const auto &input_continuous_mem_info_ = continuous_mem_info_helper_->GetAllContinuousMemInfo();
  for (const auto &continuous_mem_info : input_continuous_mem_info_) {
    MS_EXCEPTION_IF_NULL(continuous_mem_info);
    if (!continuous_mem_info->is_input_ || continuous_mem_info->key_index_map_.empty()) {
      continue;
    }
    const auto &compute_index = continuous_mem_info->compute_index_;
    size_t earliest_malloc_index = SIZE_MAX;
    for (const auto &key_index : continuous_mem_info->key_index_map_) {
      const auto &key = key_index.first;
      const auto &events_iter = mem_events_.find(key);
      if (events_iter == mem_events_.end() || events_iter->second.empty()) {
        MS_LOG(EXCEPTION) << "Can not find memory events of continuous input memory, device address key: " << key;
      }
      const auto &mem_events = events_iter->second;
      const auto &first_event = mem_events[kInitOrMallocMemEventIndex];
      MS_EXCEPTION_IF_NULL(first_event);
      // Track the earliest member malloc; the whole block must exist from there.
      if (first_event->index < earliest_malloc_index) {
        earliest_malloc_index = first_event->index;
      }
      const auto &last_events = mem_events[mem_events.size() - 1];
      MS_EXCEPTION_IF_NULL(last_events);
      // High priority members stay resident until the end of the step.
      const auto end_index = IsHighPriorityMem(key) ? total_compute_index_ - 1 : last_events->index;
      const auto mem_size = last_events->mem_size;
      // After the block's compute index each member is counted individually.
      for (size_t start_index = compute_index + 1; start_index <= end_index; start_index += 1) {
        (*total_mem_used)[start_index] += mem_size;
      }
    }
    // From the earliest member malloc up to the compute index the whole block
    // is resident, so count its total size once per step.
    for (size_t start_index = earliest_malloc_index; start_index <= compute_index; ++start_index) {
      (*total_mem_used)[start_index] += continuous_mem_info->total_size_;
    }
  }
}
561 
562 template class MemOffloadStrategy<const void *>;
563 template class ContinuousMemInfoHelper<const void *>;
564 template class MemOffloadStrategy<DeviceAddress *>;
565 template class ContinuousMemInfoHelper<DeviceAddress *>;
566 template struct GraphMemStatistic<DeviceAddress *>;
567 }  // namespace device
568 }  // namespace mindspore
569