/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "runtime/device/memory_offload_strategy.h"
#include <vector>
#include <map>
#include <memory>
#include <utility>
#include <algorithm>
#include "utils/log_adapter.h"
#include "include/backend/device_address.h"

namespace mindspore {
namespace device {
constexpr size_t kFirstGetMemEventIndex = 1;
constexpr size_t kInitOrMallocMemEventIndex = 0;

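// Returns the process-wide singleton used to record which device memories conflict for offload.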
MemoryOffloadConflict &MemoryOffloadConflict::GetInstance() {
  static MemoryOffloadConflict instance = MemoryOffloadConflict();
  return instance;
}

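// Marks every key in conflict_set as conflicting with all the others: each key's conflict set is
// extended with the whole group.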
void MemoryOffloadConflict::AddMemoryOffloadConflict(const HashSet<const void *> &conflict_set) {
  for (const auto &key : conflict_set) {
    (void)conflict_map_[key].insert(conflict_set.cbegin(), conflict_set.cend());
  }
}

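// Returns the set of keys conflicting with key, creating an empty entry if none was recorded.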
const HashSet<const void *> &MemoryOffloadConflict::GetConflictMap(const void *key) { return conflict_map_[key]; }

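// Records a memory event for key. When the first event seen for a key is a kGet (the memory was never
// explicitly allocated), a synthetic kMalloc event is inserted first so that every event list starts
// with an allocation event.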
template <typename Key>
void GraphMemStatistic<Key>::Record(Key key, const MemEventType &event_type, size_t mem_size, MemPriority priority,
                                    size_t index) {
  if (mem_priority_.count(key) == 0) {
    mem_priority_[key] = priority;
    if (event_type == kGet) {
      auto event = std::make_shared<MemEvent<Key>>(kMalloc, index);
      event->mem_size = mem_size;
      event->key = key;
      (void)mem_events_[key].emplace_back(event);
    }
  }
  auto event = std::make_shared<MemEvent<Key>>(event_type, index);
  event->mem_size = mem_size;
  event->key = key;
  (void)mem_events_[key].emplace_back(event);
}

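// Bounds-checked access to the memory events that must be handled before computing the kernel at index.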
template <typename Key>
MemEventPtrList<Key> &MemOffloadStrategy<Key>::GetPreComputeEvents(size_t index) {
  if (pre_compute_events_.size() <= index) {
    MS_LOG_EXCEPTION << "Index out of pre event range, index:" << index
                     << ", event size:" << pre_compute_events_.size();
  }
  return pre_compute_events_[index];
}

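// Bounds-checked access to the memory events that must be handled after computing the kernel at index.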
70 template <typename Key>
GetPostComputeEvents(size_t index)71 MemEventPtrList<Key> &MemOffloadStrategy<Key>::GetPostComputeEvents(size_t index) {
72 if (post_compute_events_.size() <= index) {
73 MS_LOG_EXCEPTION << "Index out of post event range, index:" << index
74 << ", event size:" << post_compute_events_.size();
75 }
76 return post_compute_events_[index];
77 }
78
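// Entry point of the strategy: count memory usage, decide whether swapping is needed, and generate the
// per-kernel pre-/post-compute memory events.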
template <typename Key>
void MemOffloadStrategy<Key>::Execute() {
  CountMemUsage();
  CheckMemSize();
  if (need_swap_) {
    GenEventSpan();
    GenSwapEventSet();
  } else {
    GenContinuousMemAllocInfo();
  }
  GenComputeMemEvents();
}

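// Computes, per compute index, the minimum memory needed to execute the kernel (min_mem_used_) and the
// peak memory needed when nothing is swapped (mem_used_without_swap_). High-priority memory stays
// resident for the whole graph, so its size is accumulated separately and added to the peak.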
template <typename Key>
void MemOffloadStrategy<Key>::CountMemUsage() {
  if (!min_mem_used_.empty()) {
    return;
  }
  if (mem_events_.empty() || total_compute_index_ == 0) {
    return;
  }
  min_mem_used_.resize(total_compute_index_, 0);
  std::vector<size_t> total_mem_used(total_compute_index_, 0);
  size_t high_priority_mem_size = 0;
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto first_event = mem_events[kInitOrMallocMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_event);
    const bool is_high_priority = IsHighPriorityMem(item.first);
    if (continuous_mem_info_helper_->IsContinuousInputMem(item.first)) {
      continue;
    } else if (is_high_priority) {
      high_priority_mem_size += first_event->mem_size;
    } else {
      auto last_event = mem_events[mem_events.size() - 1];
      MS_EXCEPTION_IF_NULL(last_event);
      for (size_t start_index = first_event->index; start_index <= last_event->index; ++start_index) {
        total_mem_used[start_index] += first_event->mem_size;
      }
    }

    // Calculate the minimum memory size for kernel execution.
    for (const auto &event : mem_events) {
      MS_EXCEPTION_IF_NULL(event);
      if (event->type != kGet) {
        continue;
      }
      min_mem_used_[event->index] += first_event->mem_size;
    }
  }
  CountContinuousMemUsage(&total_mem_used);
  min_mem_needed_ = *(std::max_element(min_mem_used_.begin(), min_mem_used_.end()));
  mem_used_without_swap_ = *(std::max_element(total_mem_used.begin(), total_mem_used.end())) + high_priority_mem_size;
  if (mem_size_ < min_mem_needed_) {
    MS_LOG(EXCEPTION) << "Out of memory, as available mem size is " << mem_size_ << " while graph needs at least "
                      << min_mem_needed_;
  }
}

template <typename Key>
bool MemOffloadStrategy<Key>::IsHighPriorityMem(Key key) const {
  auto iter = mem_priority_.find(key);
  if (iter != mem_priority_.end()) {
    return iter->second == kMemPriorityHigh;
  }
  return false;
}

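// Swapping is required when the graph cannot fit into device memory without it, or when the user has
// manually marked keys for offload.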
template <typename Key>
void MemOffloadStrategy<Key>::CheckMemSize() {
  if (mem_size_ < mem_used_without_swap_ || !manual_offload_keys_.empty()) {
    need_swap_ = true;
  }
  MS_LOG(INFO) << "Available mem size: " << mem_size_ << ", graph needs mem size: " << mem_used_without_swap_
               << " without swap, and needs at least " << min_mem_needed_ << " with swap.";
}

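// For every kGet event, computes the span in compute steps since the previous use of the same memory
// (wrapping around the whole step for high-priority memory) and indexes the event by
// (span - 1) * mem_size, so the swap-set generation can weigh events by how much memory they hold and
// for how long.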
template <typename Key>
void MemOffloadStrategy<Key>::GenEventSpan() {
  if (!event_span_.empty()) {
    return;
  }
  for (auto &item : mem_events_) {
    auto &tensor_events = item.second;
    if (tensor_events.size() <= 1) {
      continue;
    }
    const bool is_high_priority = IsHighPriorityMem(item.first);
    for (size_t i = kFirstGetMemEventIndex; i < tensor_events.size(); ++i) {
      auto &event = tensor_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (event->type != kGet) {
        MS_LOG(EXCEPTION) << "Every event except the first one should be a kGet event.";
      }
      auto latest_get_event = tensor_events[i - 1];
      if (i == kFirstGetMemEventIndex && is_high_priority) {
        latest_get_event = tensor_events[tensor_events.size() - 1];
      }
      MS_EXCEPTION_IF_NULL(latest_get_event);
      auto span = GetSpanBetweenMemEvents(latest_get_event->index, event->index);
      // High-priority memory that is used only once in one step spans the whole step.
      if (is_high_priority && span == 0 && latest_get_event == event) {
        span = total_compute_index_;
      }
      if (span > 1) {
        const size_t span_mul_size = (span - 1) * event->mem_size;
        (void)event_span_.emplace(span_mul_size, std::make_pair(event, span));
      }
    }
  }
}

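// Selects the events whose memory will be swapped out and in. When manual offload keys are given, only
// those keys are swapped. Otherwise continuous-memory blocks are processed first (smallest total size
// first), and the remaining events are added greedily whenever device memory would overflow.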
template <typename Key>
void MemOffloadStrategy<Key>::GenSwapEventSet() {
  swap_events_.clear();
  // Manual offload strategy.
  if (!manual_offload_keys_.empty()) {
    for (const auto &iter : event_span_) {
      auto &event = iter.second.first;
      MS_EXCEPTION_IF_NULL(event);
      if (manual_offload_keys_.find(event->key) != manual_offload_keys_.end()) {
        (void)swap_events_.emplace(event);
      }
    }
    return;
  }
  // Greedy span filter.
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  continuous_mem_info_helper_->ClearContinuousMallocIndex();
  std::vector<size_t> cur_mem_used(min_mem_used_.begin(), min_mem_used_.end());

  auto compare_total_size = [](const ContinuousMemInfoPtr<Key> &l, const ContinuousMemInfoPtr<Key> &r) -> bool {
    MS_EXCEPTION_IF_NULL(l);
    MS_EXCEPTION_IF_NULL(r);
    return l->total_size_ < r->total_size_;
  };
  auto all_continuous_mem_info = continuous_mem_info_helper_->GetAllContinuousMemInfo();
  std::sort(all_continuous_mem_info.begin(), all_continuous_mem_info.end(), compare_total_size);
  std::set<MemEventPtr<Key>> events_no_need_swap;
  for (const auto &continuous_mem_info : all_continuous_mem_info) {
    GenContinuousMemSwapEvent(continuous_mem_info, &cur_mem_used, &events_no_need_swap);
  }
  for (const auto &iter : event_span_) {
    const auto &event = iter.second.first;
    if (events_no_need_swap.count(event) > 0) {
      continue;
    }
    auto span = iter.second.second;
    AddToSwapEventSetIfOutOfMem(event, span, &cur_mem_used);
  }
}

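// Tentatively charges event->mem_size to every compute step covered by the event's span; if any step
// would exceed the device memory budget, the charge is reverted and the event is added to swap_events_.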
template <typename Key>
void MemOffloadStrategy<Key>::AddToSwapEventSetIfOutOfMem(const MemEventPtr<Key> &event, size_t span,
                                                          std::vector<size_t> *mem_used) {
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(mem_used);
  const auto start_index = (GetPreMemEventIndex(event->index, span) + 1) % total_compute_index_;
  bool revert = false;
  size_t cur_index = start_index;
  while (cur_index != event->index) {
    (*mem_used)[cur_index] += event->mem_size;
    if (mem_used->at(cur_index) > mem_size_) {
      revert = true;
    }
    cur_index += 1;
    if (cur_index >= total_compute_index_) {
      cur_index = 0;
    }
  }
  if (revert) {
    cur_index = start_index;
    while (cur_index != event->index) {
      (*mem_used)[cur_index] -= event->mem_size;
      cur_index += 1;
      if (cur_index >= total_compute_index_) {
        cur_index = 0;
      }
    }
    (void)swap_events_.emplace(event);
  }
}

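// Generates swap events for one continuous-memory block. Output blocks are simply allocated at their
// compute index. For each member of an input block, the latest kGet event whose span still fits within
// max_span_mem_in_device is chosen as the swap-in point and earlier kGet events need no extra swap;
// mem_used is then charged with the block size over the steps before the block is consumed.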
template <typename Key>
void MemOffloadStrategy<Key>::GenContinuousMemSwapEvent(const ContinuousMemInfoPtr<Key> &continuous_mem_info,
                                                        std::vector<size_t> *mem_used,
                                                        std::set<MemEventPtr<Key>> *events_no_need_swap) {
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  MS_EXCEPTION_IF_NULL(mem_used);
  MS_EXCEPTION_IF_NULL(events_no_need_swap);
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  if (continuous_mem_info->key_index_map_.empty()) {
    return;
  }
  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
  if (!continuous_mem_info->is_input_) {
    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
    return;
  }
  const auto max_span_mem_in_device = GetMaxSpanForContinuousMem(continuous_mem_info, *mem_used);
  size_t first_malloc_span = 0;
  size_t first_malloc_size_dup = 0;
  for (const auto &key_index : continuous_mem_info->key_index_map_) {
    const auto &events_iter = mem_events_.find(key_index.first);
    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
      MS_LOG(EXCEPTION) << "Cannot find events for continuous input memory, device address key: " << key_index.first;
    }
    size_t swap_in_event_index = kFirstGetMemEventIndex;
    size_t swap_in_span = 0;
    const bool is_high_priority = IsHighPriorityMem(key_index.first);
    for (size_t i = kFirstGetMemEventIndex; i < events_iter->second.size(); ++i) {
      const auto &mem_event = events_iter->second[i];
      MS_EXCEPTION_IF_NULL(mem_event);
      if (!is_high_priority && mem_event->index > continuous_mem_used_index) {
        continue;
      }
      const size_t span = GetSpanBetweenMemEvents(mem_event->index, continuous_mem_used_index);
      // Find the max span that is less than or equal to max_span_mem_in_device.
      if (span <= max_span_mem_in_device) {
        if (span >= swap_in_span) {
          swap_in_span = span;
          swap_in_event_index = i;
        }
        (void)events_no_need_swap->insert(mem_event);
      }
    }
    if (swap_in_event_index != kFirstGetMemEventIndex || is_high_priority) {
      (void)swap_events_.insert(events_iter->second[swap_in_event_index]);
    }
    // Find the earliest index at which the continuous memory should be allocated.
    if (swap_in_span > first_malloc_span) {
      first_malloc_span = swap_in_span;
      first_malloc_size_dup = events_iter->second[swap_in_event_index]->mem_size;
    } else if (swap_in_span == first_malloc_span) {
      // Accumulate the memory size that has already been added to mem_used.
      first_malloc_size_dup += events_iter->second[swap_in_event_index]->mem_size;
    }
  }
  for (size_t span = 1; span <= first_malloc_span; ++span) {
    size_t index = GetPreMemEventIndex(continuous_mem_used_index, span);
    (*mem_used)[index] += continuous_mem_info->total_size_;
  }
  size_t index = GetPreMemEventIndex(continuous_mem_used_index, first_malloc_span);
  (*mem_used)[index] -= first_malloc_size_dup;
  continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, index);
}

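// Returns the largest number of compute steps before compute_index_ during which the whole continuous
// block still fits into the remaining device memory.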
template <typename Key>
size_t MemOffloadStrategy<Key>::GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr<Key> &continuous_mem_info,
                                                           const std::vector<size_t> &mem_used) const {
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  const size_t continuous_mem_used_index = continuous_mem_info->compute_index_;
  size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
  size_t max_span_mem_in_device = GetSpanBetweenMemEvents(earliest_malloc_index, continuous_mem_used_index);

  for (size_t span = 1; span <= max_span_mem_in_device; ++span) {
    size_t cur_index = GetPreMemEventIndex(continuous_mem_used_index, span);
    if (mem_used[cur_index] + continuous_mem_info->total_size_ > mem_size_) {
      max_span_mem_in_device = span - 1;
      break;
    }
  }
  return max_span_mem_in_device;
}

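// Returns the earliest first-event index among all members of the continuous block; the block cannot be
// allocated before its first member is.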
template <typename Key>
size_t MemOffloadStrategy<Key>::GetFirstMallocIndex(const ContinuousMemInfoPtr<Key> &continuous_mem_info) const {
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  size_t earliest_malloc_index = continuous_mem_info->compute_index_;
  for (const auto &key_index : continuous_mem_info->key_index_map_) {
    const auto &events_iter = mem_events_.find(key_index.first);
    if (events_iter == mem_events_.end() || events_iter->second.empty()) {
      MS_LOG(EXCEPTION) << "Cannot find events for continuous input memory, device address key: " << key_index.first;
    }
    const auto &first_event = events_iter->second[kInitOrMallocMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_event);
    if (first_event->index < earliest_malloc_index) {
      earliest_malloc_index = first_event->index;
    }
  }
  return earliest_malloc_index;
}

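// Registers an allocation index for every continuous block; used when no swapping is needed.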
template <typename Key>
void MemOffloadStrategy<Key>::GenContinuousMemAllocInfo() {
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  for (const auto &continuous_mem_info : continuous_mem_info_helper_->GetAllContinuousMemInfo()) {
    GenContinuousMemAllocInfo(continuous_mem_info);
  }
}

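// Delays the kInit event of non-high-priority memory to the index of its first use, so that the memory
// is not allocated earlier than necessary.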
template <typename Key>
void MemOffloadStrategy<Key>::AdjustFirstEventIndex() {
  for (const auto &item : mem_events_) {
    const auto &mem_events = item.second;
    if (mem_events.empty()) {
      continue;
    }
    auto &first_event = mem_events[0];
    MS_EXCEPTION_IF_NULL(first_event);
    const auto &priority_iter = mem_priority_.find(item.first);
    const bool is_high_priority = (priority_iter != mem_priority_.end() && priority_iter->second == kMemPriorityHigh);
    if (first_event->type == kInit && !is_high_priority && mem_events.size() > 1) {
      const auto &second_event = mem_events[1];
      MS_EXCEPTION_IF_NULL(second_event);
      first_event->index = second_event->index;
    }
  }
}

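// Output blocks are allocated at their own compute index; input blocks are allocated at the earliest
// first-event index among their members.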
template <typename Key>
void MemOffloadStrategy<Key>::GenContinuousMemAllocInfo(const ContinuousMemInfoPtr<Key> &continuous_mem_info) {
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  if (!continuous_mem_info->is_input_) {
    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, continuous_mem_info->compute_index_);
  } else {
    const size_t earliest_malloc_index = GetFirstMallocIndex(continuous_mem_info);
    continuous_mem_info_helper_->AddContinuousMallocIndex(continuous_mem_info, earliest_malloc_index);
  }
}

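// Translates the recorded events and the swap decisions into per-index event lists: malloc/init before
// the first use, a swap-out after the previous use and a swap-in before the next use for swapped events,
// and a free after the last use of non-high-priority memory.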
template <typename Key>
void MemOffloadStrategy<Key>::GenComputeMemEvents() {
  pre_compute_events_.clear();
  post_compute_events_.clear();
  pre_compute_events_.resize(total_compute_index_);
  post_compute_events_.resize(total_compute_index_);
  for (auto &item : mem_events_) {
    auto &mem_events = item.second;
    // No need to generate events for memory that has only one event, which means it is never used by any kernel.
    if (mem_events.size() <= 1) {
      continue;
    }

    const bool is_high_priority = IsHighPriorityMem(item.first);
    auto first_event = mem_events[kInitOrMallocMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_event);
    const auto &first_get_event = mem_events[kFirstGetMemEventIndex];
    MS_EXCEPTION_IF_NULL(first_get_event);
    if (is_high_priority && swap_events_.find(first_get_event) != swap_events_.end()) {
      first_event->index = first_get_event->index;
    }
    if ((first_event->type == kInit || first_event->type == kMalloc) && first_event->index < total_compute_index_) {
      (void)pre_compute_events_[first_event->index].emplace_back(first_event);
    } else {
      MS_LOG_EXCEPTION << "First event should be init or malloc!";
    }

    const auto &last_event = mem_events[mem_events.size() - 1];
    MS_EXCEPTION_IF_NULL(last_event);
    size_t pre_index = is_high_priority ? last_event->index : first_event->index;
    for (size_t i = kFirstGetMemEventIndex; i < mem_events.size(); ++i) {
      auto &event = mem_events[i];
      MS_EXCEPTION_IF_NULL(event);
      if (need_swap_ && swap_events_.find(event) != swap_events_.end()) {
        auto swap_out_event = std::make_shared<MemEvent<Key>>(kSwapOut, pre_index);
        swap_out_event->key = item.first;
        swap_out_event->mem_size = first_event->mem_size;
        (void)post_compute_events_[pre_index].emplace_back(swap_out_event);
        // Avoid generating a swap-in event right after an init event.
        if (i != kFirstGetMemEventIndex || first_event->type != kInit) {
          auto swap_in_event = std::make_shared<MemEvent<Key>>(kSwapIn, event->index);
          swap_in_event->key = item.first;
          swap_in_event->mem_size = first_event->mem_size;
          (void)pre_compute_events_[event->index].emplace_back(swap_in_event);
        }
      }
      if (event->index < pre_compute_events_.size()) {
        (void)pre_compute_events_[event->index].emplace_back(event);
      }
      pre_index = event->index;
    }
    if (!is_high_priority) {
      GenFreeEvent(last_event);
    }
  }
}

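// Appends a kFree event after the last use of the memory.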
template <typename Key>
void MemOffloadStrategy<Key>::GenFreeEvent(const MemEventPtr<Key> &last_event) {
  MS_EXCEPTION_IF_NULL(last_event);
  auto free_event = std::make_shared<MemEvent<Key>>(kFree, last_event->index);
  free_event->key = last_event->key;
  if (last_event->index < post_compute_events_.size()) {
    (void)post_compute_events_[last_event->index].emplace_back(free_event);
  }
}

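// Returns the continuous-memory info that a device address belongs to, or nullptr if it is not part of
// any continuous block.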
template <typename Key>
ContinuousMemInfoPtr<Key> ContinuousMemInfoHelper<Key>::GetContinuousMemInfo(Key address_key) const {
  const auto &continuous_info_iter = key_continuous_info_map_.find(address_key);
  return continuous_info_iter == key_continuous_info_map_.end() ? nullptr : continuous_info_iter->second;
}

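// Concatenates the input and output continuous-memory infos into a single vector, inputs first.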
template <typename Key>
std::vector<ContinuousMemInfoPtr<Key>> ContinuousMemInfoHelper<Key>::GetAllContinuousMemInfo() const {
  std::vector<ContinuousMemInfoPtr<Key>> all_continuous_mem_info(input_continuous_mem_info_.size() +
                                                                 output_continuous_mem_info_.size());
  (void)std::copy(input_continuous_mem_info_.begin(), input_continuous_mem_info_.end(),
                  all_continuous_mem_info.begin());
  (void)std::copy_backward(output_continuous_mem_info_.begin(), output_continuous_mem_info_.end(),
                           all_continuous_mem_info.end());
  return all_continuous_mem_info;
}

template <typename Key>
bool ContinuousMemInfoHelper<Key>::IsContinuousMem(Key address_key) const {
  const auto continuous_mem_info = GetContinuousMemInfo(address_key);
  return (continuous_mem_info != nullptr);
}

template <typename Key>
bool ContinuousMemInfoHelper<Key>::IsContinuousInputMem(Key address_key) const {
  const auto continuous_mem_info = GetContinuousMemInfo(address_key);
  return (continuous_mem_info != nullptr && continuous_mem_info->is_input_);
}

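// Registers a continuous-memory block: checks that the align-size list matches the address list, then
// indexes the block by each member key, by direction (input/output), and by compute index.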
template <typename Key>
void ContinuousMemInfoHelper<Key>::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
                                                        const std::vector<size_t> &align_size_list,
                                                        const std::vector<Key> &address_key_list) {
  if (align_size_list.size() != address_key_list.size()) {
    MS_LOG(EXCEPTION) << "Number of align size[" << align_size_list.size()
                      << "] is supposed to be equal to number of address[" << address_key_list.size() << "]";
  }
  ContinuousMemInfoPtr<Key> continuous_mem_info =
    std::make_shared<ContinuousMemInfo<Key>>(is_input, total_size, compute_index, align_size_list);
  for (size_t i = 0; i < address_key_list.size(); i += 1) {
    auto key = address_key_list[i];
    MS_EXCEPTION_IF_NULL(key);
    (void)continuous_mem_info->key_index_map_.emplace(key, i);
    (void)key_continuous_info_map_.emplace(key, continuous_mem_info);
  }
  if (is_input) {
    (void)input_continuous_mem_info_.insert(continuous_mem_info);
  } else {
    (void)output_continuous_mem_info_.insert(continuous_mem_info);
  }
  (void)index_continuous_info_map_[compute_index].emplace_back(continuous_mem_info);
}

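// Adds the memory usage of continuous input blocks to total_mem_used: each member's own size while it
// stays alive after the block is consumed, plus the whole block size from the earliest member allocation
// up to the compute index that consumes the block.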
template <typename Key>
void MemOffloadStrategy<Key>::CountContinuousMemUsage(std::vector<size_t> *total_mem_used) const {
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  const auto &all_continuous_mem_info = continuous_mem_info_helper_->GetAllContinuousMemInfo();
  for (const auto &continuous_mem_info : all_continuous_mem_info) {
    MS_EXCEPTION_IF_NULL(continuous_mem_info);
    if (!continuous_mem_info->is_input_ || continuous_mem_info->key_index_map_.empty()) {
      continue;
    }
    const auto &compute_index = continuous_mem_info->compute_index_;
    size_t earliest_malloc_index = SIZE_MAX;
    for (const auto &key_index : continuous_mem_info->key_index_map_) {
      const auto &key = key_index.first;
      const auto &events_iter = mem_events_.find(key);
      if (events_iter == mem_events_.end() || events_iter->second.empty()) {
        MS_LOG(EXCEPTION) << "Cannot find memory events of continuous input memory, device address key: " << key;
      }
      const auto &mem_events = events_iter->second;
      const auto &first_event = mem_events[kInitOrMallocMemEventIndex];
      MS_EXCEPTION_IF_NULL(first_event);
      if (first_event->index < earliest_malloc_index) {
        earliest_malloc_index = first_event->index;
      }
      const auto &last_event = mem_events[mem_events.size() - 1];
      MS_EXCEPTION_IF_NULL(last_event);
      const auto end_index = IsHighPriorityMem(key) ? total_compute_index_ - 1 : last_event->index;
      const auto mem_size = last_event->mem_size;
      for (size_t start_index = compute_index + 1; start_index <= end_index; start_index += 1) {
        (*total_mem_used)[start_index] += mem_size;
      }
    }
    for (size_t start_index = earliest_malloc_index; start_index <= compute_index; ++start_index) {
      (*total_mem_used)[start_index] += continuous_mem_info->total_size_;
    }
  }
}

template class MemOffloadStrategy<const void *>;
template class ContinuousMemInfoHelper<const void *>;
template class MemOffloadStrategy<DeviceAddress *>;
template class ContinuousMemInfoHelper<DeviceAddress *>;
template struct GraphMemStatistic<DeviceAddress *>;
}  // namespace device
}  // namespace mindspore