• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
19 #include <vector>
20 #include <map>
21 #include <set>
22 #include <memory>
23 #include <utility>
24 
25 namespace mindspore {
26 namespace device {
/**
 * Abstract interface to the device/host memory backend used by MemScheduler.
 *
 * Implementations are held through std::shared_ptr<MemHandler>. The destructor is virtual so that
 * destroying an implementation through a MemHandler pointer (raw delete, unique_ptr<MemHandler>)
 * runs the derived destructor instead of being undefined behaviour.
 */
class MemHandler {
 public:
  virtual ~MemHandler() = default;

  // Returns the amount of device memory currently available (units presumably bytes — confirm in impls).
  virtual size_t GetAvailableMemSize() = 0;
  // Allocates mem_size of device memory and returns the pointer.
  virtual void *MallocDevice(size_t mem_size) = 0;
  // Releases device memory previously returned by MallocDevice.
  virtual void FreeDevice(void *ptr) = 0;
  // Allocates mem_size of host memory and returns the pointer.
  virtual void *MallocHost(size_t mem_size) = 0;
  // Releases host memory previously returned by MallocHost.
  virtual void FreeHost(void *ptr) = 0;
  // Host -> device transfer: host_ptr is the (read-only) source, device_ptr the destination.
  // NOTE(review): whether the copy is asynchronous on `stream` is implementation-defined — confirm.
  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) = 0;
  // Device -> host transfer: device_ptr is the (read-only) source, host_ptr the destination.
  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) = 0;
};
37 
// Priority tag that callers attach to a memory key (see MemScheduler::Init / GetOrMalloc /
// SetMemPriority). Enumerators are listed from lowest to highest; values are spelled out explicitly.
enum MemPriority {
  kMemPriorityLow = 0,
  kMemPriorityMedium = 1,
  kMemPriorityHigh = 2
};
39 
40 class MemScheduler {
41   enum EventType { kInit, kMalloc, kGet, kFree, kSwapIn, kSwapOut };
42 
43   struct Event {
EventEvent44     Event(const EventType &in_type, size_t in_index) {
45       type = in_type;
46       index = in_index;
47     }
48 
49     EventType type;
50     size_t index{0};
51     size_t mem_size{0};
52     const void *key{nullptr};
53   };
54 
55  public:
56   MemScheduler() = default;
57   ~MemScheduler() = default;
58 
need_record_event()59   bool need_record_event() const { return need_record_event_; }
60 
optimized()61   bool optimized() const { return optimized_; }
62 
SetOptimized(bool flag)63   void SetOptimized(bool flag) { optimized_ = flag; }
64 
SetMemHandler(const std::shared_ptr<MemHandler> & handler)65   void SetMemHandler(const std::shared_ptr<MemHandler> &handler) { mem_handler_ = handler; }
66 
67   void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);
68 
69   void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
70 
RecordMemUsage()71   void RecordMemUsage() { compute_index_ = 0; }
72 
73   bool PreCompute(void *stream);
74 
75   bool PostCompute(void *stream);
76 
77   void OptMemUsage();
78 
79   void Clear();
80 
81   bool IsHighPriorityMem(const void *key);
82 
83   void SetMemPriority(const void *key, MemPriority priority);
84 
SetMemUsedFactor(float factor)85   void SetMemUsedFactor(float factor) { mem_used_factor_ = factor; }
86 
SetNeedSwap(bool flag)87   void SetNeedSwap(bool flag) { need_swap_ = flag; }
88 
89  private:
90   void Record(const void *key, const EventType &event_type, size_t mem_size = 0);
91   void GenEvents();
92   void CheckMemSize();
93   void CountMemUsage();
94   void GenEventSpan();
95   void GenNoSwapEventSet();
96   std::map<const void *, MemPriority> mem_priority_;
97   std::map<const void *, std::vector<std::shared_ptr<Event>>> mem_events_;
98   std::vector<std::vector<std::shared_ptr<Event>>> pre_compute_events_;
99   std::vector<std::vector<std::shared_ptr<Event>>> post_compute_events_;
100   std::map<const void *, void *> mem_result_;
101   std::map<const void *, void *> init_host_ptr_;
102   std::map<const void *, void *> swap_host_ptr_;
103   std::map<const void *, void *> high_priority_device_ptr_;
104   size_t compute_index_{0};
105   bool need_record_event_{true};
106   bool optimized_{false};
107   std::shared_ptr<MemHandler> mem_handler_{nullptr};
108   bool need_swap_{false};
109   std::multimap<size_t, std::shared_ptr<Event>> event_span_;
110   std::set<std::shared_ptr<Event>> no_swap_events_;
111   std::vector<size_t> min_mem_used_;
112   size_t mem_used_without_swap_{0};
113   size_t min_mem_needed_{0};
114   float mem_used_factor_{0.9};
115 };
116 
117 class MemSchedulerManager {
118  public:
119   MemSchedulerManager() = default;
120   ~MemSchedulerManager() = default;
GetOrCreateMemScheduler(uint64_t uid)121   std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
122     auto scheduler = GetMemScheduler(uid);
123     if (scheduler == nullptr) {
124       scheduler = std::make_shared<MemScheduler>();
125       graph_mem_scheduler_map_[uid] = scheduler;
126     }
127     return scheduler;
128   }
129 
GetMemScheduler(uint64_t uid)130   std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
131     auto iter = graph_mem_scheduler_map_.find(uid);
132     if (iter != graph_mem_scheduler_map_.end()) {
133       return iter->second;
134     }
135     return nullptr;
136   }
137 
138  private:
139   std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
140 };
141 }  // namespace device
142 }  // namespace mindspore
143 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
144