/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <utility>

namespace mindspore {
namespace device {
// Abstract interface the scheduler uses to query available device memory, allocate/free
// device and host memory, and copy data between host and device (swap in/out) on a stream.
class MemHandler {
 public:
  virtual size_t GetAvailableMemSize() = 0;
  virtual void *MallocDevice(size_t mem_size) = 0;
  virtual void FreeDevice(void *ptr) = 0;
  virtual void *MallocHost(size_t mem_size) = 0;
  virtual void FreeHost(void *ptr) = 0;
  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};

// Priority of a memory block tracked by the scheduler.
enum MemPriority { kMemPriorityLow, kMemPriorityMedium, kMemPriorityHigh };

// Tracks per-key memory usage events across compute steps and, when device memory is
// insufficient, plans host/device swaps that are applied around each step via
// PreCompute/PostCompute.
class MemScheduler {
  enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };

  // One recorded memory event for the block identified by key, at compute step index.
  struct Event {
    Event(const EventType &in_type, size_t in_index) {
      type = in_type;
      index = in_index;
    }

    EventType type;
    size_t index{0};
    size_t mem_size{0};
    const void *key{nullptr};
  };

 public:
  MemScheduler() = default;
  ~MemScheduler() = default;

  bool need_record_event() const { return need_record_event_; }

  bool optimized() const { return optimized_; }

  void SetOptimized(bool flag) { optimized_ = flag; }

  void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }

  void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);

  void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);

  // Reset the compute step counter so usage recording (or replay) starts from the first step.
  void RecordMemUsage() { compute_index_ = 0; }

  bool PreCompute(void *stream);

  bool PostCompute(void *stream);

  void OptMemUsage();

  void Clear();

  bool IsHighPriorityMem(const void *key);

  void SetMemPriority(const void *key, MemPriority priority);

  void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }

  void SetNeedSwap(bool flag) { need_swap_ = flag; }

 private:
  void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
  void GenEvents();
  void CheckMemSize();
  void CountMemUsage();
  void GenEventSpan();
  void GenNoSwapEventSet();
  std::map<const void *, MemPriority> mem_priority_;
  std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
  std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
  std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
  std::map<const void *, void *> mem_result_;
  std::map<const void *, void *> init_host_ptr_;
  std::map<const void *, void *> swap_host_ptr_;
  std::map<const void *, void *> high_priority_device_ptr_;
  size_t compute_index_{0};
  bool need_record_event_{true};
  bool optimized_{false};
  std::shared_ptr<MemHandler> mem_handler_{nullptr};
  bool need_swap_{false};
  std::multimap<size_t, std::shared_ptr<Event>> event_span_;
  std::set<std::shared_ptr<Event>> no_swap_events_;
  std::vector<size_t> min_mem_used_;
  size_t mem_used_without_swap_{0};
  size_t min_mem_needed_{0};
  float mem_used_factor_{0.9};
};

// Keeps one MemScheduler per graph, keyed by the graph uid.
class MemSchedulerManager {
 public:
  MemSchedulerManager() = default;
  ~MemSchedulerManager() = default;

  std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
    auto scheduler = GetMemScheduler(uid);
    if (scheduler == nullptr) {
      scheduler = std::make_shared<MemScheduler>();
      graph_mem_scheduler_map_[uid] = scheduler;
    }
    return scheduler;
  }

  std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
    auto iter = graph_mem_scheduler_map_.find(uid);
    if (iter != graph_mem_scheduler_map_.end()) {
      return iter->second;
    }
    return nullptr;
  }

 private:
  std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
};
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
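
// ----------------------------------------------------------------------------------------------
// Usage sketch (illustrative only, not taken from the MindSpore sources): the snippet below
// assumes a hypothetical MemHandler subclass named DemoMemHandler and hypothetical graph_id,
// stream, kernels, and tensor.{key,size,priority} objects; the call order is inferred from the
// declarations above. It is kept in a comment so the header itself is unchanged for compilation.
//
//   MemSchedulerManager mem_scheduler_manager;
//   std::shared_ptr<MemScheduler> scheduler = mem_scheduler_manager.GetOrCreateMemScheduler(graph_id);
//   scheduler->SetMemHandler(std::make_shared<DemoMemHandler>());
//
//   // One pass over the graph: while need_record_event() is true, GetOrMalloc records which
//   // keys are touched at which compute step.
//   scheduler->RecordMemUsage();
//   for (const auto &kernel : kernels) {
//     scheduler->PreCompute(stream);
//     for (const auto &tensor : kernel.tensors) {
//       void *device_ptr = scheduler->GetOrMalloc(tensor.key, tensor.size, tensor.priority);
//       // ... bind device_ptr to the kernel launch ...
//     }
//     scheduler->PostCompute(stream);
//   }
//
//   // After recording, derive a swap plan from the recorded events and the available device
//   // memory, then replay the same PreCompute/GetOrMalloc/PostCompute sequence on later steps.
//   scheduler->OptMemUsage();
// ----------------------------------------------------------------------------------------------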