/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <algorithm>
#include "utils/hash_map.h"
#include "utils/hash_set.h"

namespace mindspore {
namespace device {
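// Priority of a block of device memory; high-priority memory is expected to stay resident on device
// whenever possible, while low-priority memory is the preferred candidate for offloading.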
enum MemPriority { kMemPriorityLow, kMemPriorityHigh };

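// Kinds of memory events recorded per compute step: initialization/allocation (kInit, kMalloc),
// access (kGet), release (kFree) and swapping between host and device (kSwapIn, kSwapOut).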
enum MemEventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };

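// A single memory event: its type, the compute index at which it occurs, the size of the memory
// involved, and the key identifying the device address.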
template <typename Key>
struct MemEvent {
  MemEvent(const MemEventType &in_type, size_t in_index) : type(in_type), index(in_index) {}

  MemEventType type;
  size_t index{0};
  size_t mem_size{0};
  Key key{nullptr};
};

template <typename Key>
using MemEventPtr = std::shared_ptr<MemEvent<Key>>;
template <typename Key>
using MemEventPtrList = std::vector<MemEventPtr<Key>>;

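// Describes a group of device addresses that must be allocated as one continuous block: whether the
// block holds kernel inputs, its total size, the compute index that uses it, the aligned size of each
// member, and each key's position within the block.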
template <typename Key>
struct ContinuousMemInfo {
  ContinuousMemInfo(bool is_input, size_t total_size, size_t compute_index, std::vector<size_t> align_size_list)
      : is_input_(is_input),
        total_size_(total_size),
        compute_index_(compute_index),
        align_size_list_(std::move(align_size_list)) {}
  bool is_input_;
  size_t total_size_;
  size_t compute_index_;
  const std::vector<size_t> align_size_list_;
  std::map<Key, size_t> key_index_map_;
};

template <typename Key>
using ContinuousMemInfoPtr = std::shared_ptr<ContinuousMemInfo<Key>>;

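// Bookkeeping for continuous memory blocks: maps each device address to the block it belongs to and
// records the compute index at which every block is first allocated.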
template <typename Key>
class ContinuousMemInfoHelper {
 public:
  ContinuousMemInfoHelper() = default;
  ~ContinuousMemInfoHelper() = default;
  void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
                            const std::vector<size_t> &align_size_list, const std::vector<Key> &address_key_list);
  ContinuousMemInfoPtr<Key> GetContinuousMemInfo(Key address_key) const;
  std::vector<ContinuousMemInfoPtr<Key>> GetAllContinuousMemInfo() const;
  bool IsContinuousMem(Key address_key) const;
  bool IsContinuousInputMem(Key address_key) const;

  void AddContinuousMallocIndex(const ContinuousMemInfoPtr<Key> &mem_info, size_t index) {
    (void)first_malloc_index_.emplace(mem_info, index);
    (void)continuous_mem_first_alloc_info_[index].emplace_back(mem_info);
  }

  bool NeedMallocContinuousMem(const ContinuousMemInfoPtr<Key> &mem_info, size_t index) const {
    const auto &iter = first_malloc_index_.find(mem_info);
    return iter != first_malloc_index_.end() && iter->second == index;
  }

  std::vector<ContinuousMemInfoPtr<Key>> GetContinuousMemAllocInfo(size_t index) {
    const auto &iter = continuous_mem_first_alloc_info_.find(index);
    if (iter == continuous_mem_first_alloc_info_.end()) {
      return {};
    }
    return iter->second;
  }

  void ClearContinuousMallocIndex() { first_malloc_index_.clear(); }

  const std::vector<ContinuousMemInfoPtr<Key>> &GetIndexContinuousMemInfo(size_t index) {
    return index_continuous_info_map_[index];
  }

 private:
  std::set<ContinuousMemInfoPtr<Key>> input_continuous_mem_info_;
  std::set<ContinuousMemInfoPtr<Key>> output_continuous_mem_info_;
  std::map<Key, ContinuousMemInfoPtr<Key>> key_continuous_info_map_;
  std::map<ContinuousMemInfoPtr<Key>, size_t> first_malloc_index_;
  std::map<size_t, std::vector<ContinuousMemInfoPtr<Key>>> continuous_mem_first_alloc_info_;
  std::map<size_t, std::vector<ContinuousMemInfoPtr<Key>>> index_continuous_info_map_;
};

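// Singleton recording offload conflicts between device addresses, plus a backlog of addresses that
// are allowed to be offloaded.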
class MemoryOffloadConflict {
 public:
  void AddMemoryOffloadConflict(const HashSet<const void *> &conflict_set);
  const HashSet<const void *> &GetConflictMap(const void *key);
  static MemoryOffloadConflict &GetInstance();
  bool CanBeOffloaded(const void *key) { return offload_backlog_.count(key) != 0; }
  void AddOffloadBacklog(const void *key) { (void)offload_backlog_.insert(key); }

 private:
  MemoryOffloadConflict() = default;
  ~MemoryOffloadConflict() = default;
  HashSet<const void *> offload_backlog_;
  HashMap<const void *, HashSet<const void *>> conflict_map_;
};

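// Memory statistics collected for a graph: the priority and event list of every key, the keys marked
// for manual offload, the continuous-memory information, and the total number of compute steps.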
template <typename Key>
struct GraphMemStatistic {
 public:
  GraphMemStatistic() { continuous_mem_info_helper_ = std::make_shared<device::ContinuousMemInfoHelper<Key>>(); }
  ~GraphMemStatistic() = default;
  void Record(Key key, const MemEventType &event_type, size_t mem_size, MemPriority priority, size_t index);

  std::map<Key, MemPriority> mem_priority_;
  std::map<Key, MemEventPtrList<Key>> mem_events_;
  std::set<Key> manual_offload_keys_;
  std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_helper_;
  size_t total_compute_index_{};
};

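// Builds the swap (offload) plan for a graph: given the recorded memory events and the available
// device memory (set_mem_size), Execute() decides which events need swapping and produces the memory
// events to run before and after each compute step (GetPreComputeEvents / GetPostComputeEvents).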
template <typename Key>
class MemOffloadStrategy {
 public:
  MemOffloadStrategy(const std::map<Key, MemPriority> &mem_priority,
                     const std::map<Key, MemEventPtrList<Key>> &mem_events, const std::set<Key> &manual_offload_keys,
                     size_t total_compute_index,
                     std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_manager)
      : mem_priority_(mem_priority),
        mem_events_(mem_events),
        manual_offload_keys_(manual_offload_keys),
        total_compute_index_(total_compute_index),
        continuous_mem_info_helper_(std::move(continuous_mem_info_manager)) {
    AdjustFirstEventIndex();
  }

  explicit MemOffloadStrategy(const GraphMemStatistic<Key> &mem_statistic)
      : mem_priority_(mem_statistic.mem_priority_),
        mem_events_(mem_statistic.mem_events_),
        manual_offload_keys_(mem_statistic.manual_offload_keys_),
        total_compute_index_(mem_statistic.total_compute_index_),
        continuous_mem_info_helper_(mem_statistic.continuous_mem_info_helper_) {
    AdjustFirstEventIndex();
  }

  virtual ~MemOffloadStrategy() = default;

  virtual void Execute();

  void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; }

  MemEventPtrList<Key> &GetPreComputeEvents(size_t index);

  MemEventPtrList<Key> &GetPostComputeEvents(size_t index);

  void set_mem_size(size_t mem_size) { mem_size_ = mem_size; }

  bool need_swap() const { return need_swap_; }

  std::vector<ContinuousMemInfoPtr<Key>> GetContinuousMemAllocInfo(size_t index) {
    return continuous_mem_info_helper_->GetContinuousMemAllocInfo(index);
  }

 private:
  void AdjustFirstEventIndex();

  bool IsHighPriorityMem(Key key) const;

  void CountMemUsage();

  void CheckMemSize();

  void GenEventSpan();

  void GenSwapEventSet();

  void GenComputeMemEvents();

  void GenFreeEvent(const MemEventPtr<Key> &last_event);

  void AddToSwapEventSetIfOutOfMem(const MemEventPtr<Key> &mem_event, size_t span, std::vector<size_t> *mem_used);

  void GenContinuousMemSwapEvent(const ContinuousMemInfoPtr<Key> &continuous_mem_info, std::vector<size_t> *mem_used,
                                 std::set<MemEventPtr<Key>> *events_no_need_swap);

  size_t GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr<Key> &continuous_mem_info,
                                    const std::vector<size_t> &mem_used) const;

  size_t GetFirstMallocIndex(const ContinuousMemInfoPtr<Key> &continuous_mem_info) const;

  void GenContinuousMemAllocInfo();

  void GenContinuousMemAllocInfo(const ContinuousMemInfoPtr<Key> &continuous_mem_info);

  void CountContinuousMemUsage(std::vector<size_t> *total_mem_used) const;

  size_t GetSpanBetweenMemEvents(size_t pre_index, size_t post_index) const {
    return (post_index + total_compute_index_ - pre_index) % total_compute_index_;
  }

  size_t GetPreMemEventIndex(size_t cur_index, size_t span) const {
    return (cur_index + total_compute_index_ - span) % total_compute_index_;
  }

  const std::map<Key, MemPriority> &mem_priority_;
  const std::map<Key, MemEventPtrList<Key>> &mem_events_;
  const std::set<Key> &manual_offload_keys_;
  const size_t total_compute_index_;
  std::vector<MemEventPtrList<Key>> pre_compute_events_;
  std::vector<MemEventPtrList<Key>> post_compute_events_;

  size_t mem_size_{0};
  std::vector<double> compute_time_;
  bool need_swap_{false};
  std::multimap<size_t, std::pair<MemEventPtr<Key>, size_t>> event_span_;
  std::set<MemEventPtr<Key>> swap_events_;
  std::vector<size_t> min_mem_used_;
  size_t mem_used_without_swap_{0};
  size_t min_mem_needed_{0};
  std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_helper_{nullptr};
};
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_