1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_ 19 #include <vector> 20 #include <map> 21 #include <set> 22 #include <memory> 23 #include <queue> 24 #include <utility> 25 #include "runtime/device/memory_offload_strategy.h" 26 #include "runtime/device/auto_mem_offload.h" 27 28 namespace mindspore { 29 namespace device { 30 class MemScheduler { 31 public: 32 MemScheduler() = default; 33 ~MemScheduler() = default; 34 need_record_event()35 bool need_record_event() const { return need_record_event_; } 36 set_need_record_event(bool flag)37 void set_need_record_event(bool flag) { need_record_event_ = flag; } 38 optimized()39 bool optimized() const { return optimized_; } 40 41 void Update(); 42 SetMemHandler(const std::shared_ptr<MemHandler> & handler)43 void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { 44 mem_handler_ = handler; 45 auto_mem_offload_ = std::make_shared<AutoMemoryOffload>(handler); 46 } 47 48 void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow); 49 50 void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow); 51 HasDeviceMem(const void * key)52 bool HasDeviceMem(const void *key) const { return auto_mem_offload_->Get(key) != nullptr; } 53 UpdateHighPriorityMem(const void * key)54 void UpdateHighPriorityMem(const void *key) { auto_mem_offload_->UpdateHighPriorityMem(key); } 55 SetTotalStep(size_t step)56 void SetTotalStep(size_t step) { 57 total_step_ = step; 58 step_keys_.resize(total_step_); 59 } 60 Reset()61 void Reset() { current_step_ = 0; } 62 63 bool PreCompute(void *stream); 64 65 bool PostCompute(void *stream); 66 67 bool Optimize(); 68 Clear()69 void Clear() { auto_mem_offload_->Clear(); } 70 SetOffload(const void * key)71 void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); } 72 AddMemNeedInit(const void * key)73 void AddMemNeedInit(const void *key) { (void)high_priority_mem_need_init_.insert(key); } 74 ClearMemNeedInit()75 void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); } 76 77 void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size, 78 const std::vector<size_t> &align_size_list, 79 const std::vector<const void *> &address_key_list); 80 81 private: 82 void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0); 83 84 void OptMemUsage(float mem_used_factor = 1.0f); 85 86 bool Mock(); 87 88 bool PreComputeMock(const MemEventPtr<const void *> &event); 89 90 bool PreComputeInit(const MemEventPtr<const void *> &event, void *stream); 91 92 bool PreComputeMalloc(const MemEventPtr<const void *> &event, void *stream); 93 94 bool PreComputeSwapIn(const MemEventPtr<const void *> &event, void *stream); 95 96 bool PreComputeGet(const MemEventPtr<const void *> &event, void *stream); 97 GetNoReuseKeys()98 const HashSet<const void *> &GetNoReuseKeys() const { return step_keys_[current_step_]; } 99 100 void *Malloc(const MemEventPtr<const void *> &event, void *stream); 101 102 // Scheduler status 103 bool need_record_event_{true}; 104 bool optimized_{false}; 105 bool updated_{false}; 106 bool record_compute_time_{false}; 107 size_t total_step_{0}; 108 size_t current_step_{0}; 109 // Memory status 110 std::map<const void *, MemEventPtrList<const void *>> mem_events_; 111 std::map<const void *, MemPriority> mem_priority_; 112 std::vector<HashSet<const void *>> step_keys_; 113 std::set<const void *> high_priority_mem_need_init_; 114 std::shared_ptr<ContinuousMemInfoHelper<const void *>> continuous_mem_info_helper_{ 115 std::make_shared<ContinuousMemInfoHelper<const void *>>()}; 116 std::set<ContinuousMemInfoPtr<const void *>> cur_step_allocated_continuous_mem_; 117 std::set<const void *> manual_offload_keys_; 118 // Compute time 119 std::vector<double> compute_time_; 120 double compute_start_time_{0}; 121 122 std::shared_ptr<AutoMemoryOffload> auto_mem_offload_; 123 std::shared_ptr<MemHandler> mem_handler_{nullptr}; 124 std::shared_ptr<MemOffloadStrategy<const void *>> strategy_{nullptr}; 125 }; 126 127 class MemSchedulerManager { 128 public: 129 MemSchedulerManager() = default; 130 ~MemSchedulerManager() = default; GetOrCreateMemScheduler(uint64_t uid)131 std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) { 132 auto scheduler = GetMemScheduler(uid); 133 if (scheduler == nullptr) { 134 scheduler = std::make_shared<MemScheduler>(); 135 graph_mem_scheduler_map_[uid] = scheduler; 136 } 137 return scheduler; 138 } 139 GetMemScheduler(uint64_t uid)140 std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) { 141 auto iter = graph_mem_scheduler_map_.find(uid); 142 if (iter != graph_mem_scheduler_map_.end()) { 143 return iter->second; 144 } 145 return nullptr; 146 } 147 148 private: 149 std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_; 150 }; 151 } // namespace device 152 } // namespace mindspore 153 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_ 154