1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_ 19 #include <vector> 20 #include <map> 21 #include <set> 22 #include <memory> 23 #include <utility> 24 #include <algorithm> 25 #include "utils/hash_map.h" 26 #include "utils/hash_set.h" 27 28 namespace mindspore { 29 namespace device { 30 enum MemPriority { kMemPriorityLow, kMemPriorityHigh }; 31 32 enum MemEventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut }; 33 34 template <typename Key> 35 struct MemEvent { MemEventMemEvent36 MemEvent(const MemEventType &in_type, size_t in_index) : type(in_type), index(in_index) {} 37 38 MemEventType type; 39 size_t index{0}; 40 size_t mem_size{0}; 41 Key key{nullptr}; 42 }; 43 44 template <typename Key> 45 using MemEventPtr = std::shared_ptr<MemEvent<Key>>; 46 template <typename Key> 47 using MemEventPtrList = std::vector<MemEventPtr<Key>>; 48 49 template <typename Key> 50 struct ContinuousMemInfo { ContinuousMemInfoContinuousMemInfo51 ContinuousMemInfo(bool is_input, size_t total_size, size_t compute_index, std::vector<size_t> align_size_list) 52 : is_input_(is_input), 53 total_size_(total_size), 54 compute_index_(compute_index), 55 align_size_list_(std::move(align_size_list)) {} 56 bool is_input_; 57 size_t total_size_; 58 size_t compute_index_; 59 const std::vector<size_t> align_size_list_; 60 std::map<Key, size_t> key_index_map_; 61 }; 62 63 template <typename Key> 64 using ContinuousMemInfoPtr = std::shared_ptr<ContinuousMemInfo<Key>>; 65 66 template <typename Key> 67 class ContinuousMemInfoHelper { 68 public: 69 ContinuousMemInfoHelper() = default; 70 ~ContinuousMemInfoHelper() = default; 71 void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size, 72 const std::vector<size_t> &align_size_list, const std::vector<Key> &address_key_list); 73 ContinuousMemInfoPtr<Key> GetContinuousMemInfo(Key address_key) const; 74 std::vector<ContinuousMemInfoPtr<Key>> GetAllContinuousMemInfo() const; 75 bool IsContinuousMem(Key address_key) const; 76 bool IsContinuousInputMem(Key address_key) const; 77 AddContinuousMallocIndex(const ContinuousMemInfoPtr<Key> & mem_info,size_t index)78 void AddContinuousMallocIndex(const ContinuousMemInfoPtr<Key> &mem_info, size_t index) { 79 (void)first_malloc_index_.emplace(mem_info, index); 80 (void)continuous_mem_first_alloc_info_[index].emplace_back(mem_info); 81 } 82 NeedMallocContinuousMem(const ContinuousMemInfoPtr<Key> & mem_info,size_t index)83 bool NeedMallocContinuousMem(const ContinuousMemInfoPtr<Key> &mem_info, size_t index) const { 84 const auto &iter = first_malloc_index_.find(mem_info); 85 return iter != first_malloc_index_.end() && iter->second == index; 86 } 87 GetContinuousMemAllocInfo(size_t index)88 std::vector<ContinuousMemInfoPtr<Key>> GetContinuousMemAllocInfo(size_t index) { 89 const auto &iter = continuous_mem_first_alloc_info_.find(index); 90 if (iter == continuous_mem_first_alloc_info_.end()) { 91 return {}; 92 } 93 return iter->second; 94 } 95 ClearContinuousMallocIndex()96 void ClearContinuousMallocIndex() { first_malloc_index_.clear(); } 97 GetIndexContinuousMemInfo(size_t index)98 const std::vector<ContinuousMemInfoPtr<Key>> &GetIndexContinuousMemInfo(size_t index) { 99 return index_continuous_info_map_[index]; 100 } 101 102 private: 103 std::set<ContinuousMemInfoPtr<Key>> input_continuous_mem_info_; 104 std::set<ContinuousMemInfoPtr<Key>> output_continuous_mem_info_; 105 std::map<Key, ContinuousMemInfoPtr<Key>> key_continuous_info_map_; 106 std::map<ContinuousMemInfoPtr<Key>, size_t> first_malloc_index_; 107 std::map<size_t, std::vector<ContinuousMemInfoPtr<Key>>> continuous_mem_first_alloc_info_; 108 std::map<size_t, std::vector<ContinuousMemInfoPtr<Key>>> index_continuous_info_map_; 109 }; 110 111 class MemoryOffloadConflict { 112 public: 113 void AddMemoryOffloadConflict(const HashSet<const void *> &conflict_set); 114 const HashSet<const void *> &GetConflictMap(const void *key); 115 static MemoryOffloadConflict &GetInstance(); CanBeOffloaded(const void * key)116 bool CanBeOffloaded(const void *key) { return offload_backlog_.count(key) != 0; } AddOffloadBacklog(const void * key)117 void AddOffloadBacklog(const void *key) { (void)offload_backlog_.insert(key); } 118 119 private: 120 MemoryOffloadConflict() = default; 121 ~MemoryOffloadConflict() = default; 122 HashSet<const void *> offload_backlog_; 123 HashMap<const void *, HashSet<const void *>> conflict_map_; 124 }; 125 126 template <typename Key> 127 struct GraphMemStatistic { 128 public: GraphMemStatisticGraphMemStatistic129 GraphMemStatistic() { continuous_mem_info_helper_ = std::make_shared<device::ContinuousMemInfoHelper<Key>>(); } 130 ~GraphMemStatistic() = default; 131 void Record(Key key, const MemEventType &event_type, size_t mem_size, MemPriority priority, size_t index); 132 133 std::map<Key, MemPriority> mem_priority_; 134 std::map<Key, MemEventPtrList<Key>> mem_events_; 135 std::set<Key> manual_offload_keys_; 136 std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_helper_; 137 size_t total_compute_index_{}; 138 }; 139 140 template <typename Key> 141 class MemOffloadStrategy { 142 public: MemOffloadStrategy(const std::map<Key,MemPriority> & mem_priority,const std::map<Key,MemEventPtrList<Key>> & mem_events,const std::set<Key> & manual_offload_keys,size_t total_compute_index,std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_manager)143 MemOffloadStrategy(const std::map<Key, MemPriority> &mem_priority, 144 const std::map<Key, MemEventPtrList<Key>> &mem_events, const std::set<Key> &manual_offload_keys, 145 size_t total_compute_index, 146 std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_manager) 147 : mem_priority_(mem_priority), 148 mem_events_(mem_events), 149 manual_offload_keys_(manual_offload_keys), 150 total_compute_index_(total_compute_index), 151 continuous_mem_info_helper_(std::move(continuous_mem_info_manager)) { 152 AdjustFirstEventIndex(); 153 } 154 MemOffloadStrategy(const GraphMemStatistic<Key> & mem_statistic)155 explicit MemOffloadStrategy(const GraphMemStatistic<Key> &mem_statistic) 156 : mem_priority_(mem_statistic.mem_priority_), 157 mem_events_(mem_statistic.mem_events_), 158 manual_offload_keys_(mem_statistic.manual_offload_keys_), 159 total_compute_index_(mem_statistic.total_compute_index_), 160 continuous_mem_info_helper_(mem_statistic.continuous_mem_info_helper_) { 161 AdjustFirstEventIndex(); 162 } 163 164 virtual ~MemOffloadStrategy() = default; 165 166 virtual void Execute(); 167 SetComputeTime(const std::vector<double> & compute_time)168 void SetComputeTime(const std::vector<double> &compute_time) { compute_time_ = compute_time; } 169 170 MemEventPtrList<Key> &GetPreComputeEvents(size_t index); 171 172 MemEventPtrList<Key> &GetPostComputeEvents(size_t index); 173 set_mem_size(size_t mem_size)174 void set_mem_size(size_t mem_size) { mem_size_ = mem_size; } 175 need_swap()176 bool need_swap() const { return need_swap_; } 177 GetContinuousMemAllocInfo(size_t index)178 std::vector<ContinuousMemInfoPtr<Key>> GetContinuousMemAllocInfo(size_t index) { 179 return continuous_mem_info_helper_->GetContinuousMemAllocInfo(index); 180 } 181 182 private: 183 void AdjustFirstEventIndex(); 184 185 bool IsHighPriorityMem(Key key) const; 186 187 void CountMemUsage(); 188 189 void CheckMemSize(); 190 191 void GenEventSpan(); 192 193 void GenSwapEventSet(); 194 195 void GenComputeMemEvents(); 196 197 void GenFreeEvent(const MemEventPtr<Key> &last_event); 198 199 void AddToSwapEventSetIfOutOfMem(const MemEventPtr<Key> &mem_event, size_t span, std::vector<size_t> *mem_used); 200 201 void GenContinuousMemSwapEvent(const ContinuousMemInfoPtr<Key> &continuous_mem_info, std::vector<size_t> *mem_used, 202 std::set<MemEventPtr<Key>> *events_no_need_swap); 203 204 size_t GetMaxSpanForContinuousMem(const ContinuousMemInfoPtr<Key> &continuous_mem_info, 205 const std::vector<size_t> &mem_used) const; 206 207 size_t GetFirstMallocIndex(const ContinuousMemInfoPtr<Key> &continuous_mem_info) const; 208 209 void GenContinuousMemAllocInfo(); 210 211 void GenContinuousMemAllocInfo(const ContinuousMemInfoPtr<Key> &continuous_mem_info); 212 213 void CountContinuousMemUsage(std::vector<size_t> *total_mem_used) const; 214 GetSpanBetweenMemEvents(size_t pre_index,size_t post_index)215 size_t GetSpanBetweenMemEvents(size_t pre_index, size_t post_index) const { 216 return (post_index + total_compute_index_ - pre_index) % total_compute_index_; 217 } 218 GetPreMemEventIndex(size_t cur_index,size_t span)219 size_t GetPreMemEventIndex(size_t cur_index, size_t span) const { 220 return (cur_index + total_compute_index_ - span) % total_compute_index_; 221 } 222 223 const std::map<Key, MemPriority> &mem_priority_; 224 const std::map<Key, MemEventPtrList<Key>> &mem_events_; 225 const std::set<Key> &manual_offload_keys_; 226 const size_t total_compute_index_; 227 std::vector<MemEventPtrList<Key>> pre_compute_events_; 228 std::vector<MemEventPtrList<Key>> post_compute_events_; 229 230 size_t mem_size_{0}; 231 std::vector<double> compute_time_; 232 bool need_swap_{false}; 233 std::multimap<size_t, std::pair<MemEventPtr<Key>, size_t>> event_span_; 234 std::set<MemEventPtr<Key>> swap_events_; 235 std::vector<size_t> min_mem_used_; 236 size_t mem_used_without_swap_{0}; 237 size_t min_mem_needed_{0}; 238 std::shared_ptr<ContinuousMemInfoHelper<Key>> continuous_mem_info_helper_{nullptr}; 239 }; 240 } // namespace device 241 } // namespace mindspore 242 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_OFFLOAD_STRATEGY_H_ 243