/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_
#include <map>
#include <memory>
#include <queue>
#include <set>
#include <utility>
#include <vector>
#include "runtime/device/memory_offload_strategy.h"
#include "runtime/device/auto_mem_offload.h"

namespace mindspore {
namespace device {
30 class MemScheduler {
31  public:
32   MemScheduler() = default;
33   ~MemScheduler() = default;
34 
need_record_event()35   bool need_record_event() const { return need_record_event_; }
36 
set_need_record_event(bool flag)37   void set_need_record_event(bool flag) { need_record_event_ = flag; }
38 
optimized()39   bool optimized() const { return optimized_; }
40 
41   void Update();
42 
SetMemHandler(const std::shared_ptr<MemHandler> & handler)43   void SetMemHandler(const std::shared_ptr<MemHandler> &handler) {
44     mem_handler_ = handler;
45     auto_mem_offload_ = std::make_shared<AutoMemoryOffload>(handler);
46   }
47 
48   void Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority = kMemPriorityLow);
49 
50   void *GetOrMalloc(const void *key, size_t mem_size, MemPriority priority = kMemPriorityLow);
51 
HasDeviceMem(const void * key)52   bool HasDeviceMem(const void *key) const { return auto_mem_offload_->Get(key) != nullptr; }
53 
UpdateHighPriorityMem(const void * key)54   void UpdateHighPriorityMem(const void *key) { auto_mem_offload_->UpdateHighPriorityMem(key); }
55 
SetTotalStep(size_t step)56   void SetTotalStep(size_t step) {
57     total_step_ = step;
58     step_keys_.resize(total_step_);
59   }
60 
Reset()61   void Reset() { current_step_ = 0; }
62 
63   bool PreCompute(void *stream);
64 
65   bool PostCompute(void *stream);
66 
67   bool Optimize();
68 
Clear()69   void Clear() { auto_mem_offload_->Clear(); }
70 
SetOffload(const void * key)71   void SetOffload(const void *key) { (void)manual_offload_keys_.insert(key); }
72 
AddMemNeedInit(const void * key)73   void AddMemNeedInit(const void *key) { (void)high_priority_mem_need_init_.insert(key); }
74 
ClearMemNeedInit()75   void ClearMemNeedInit() { high_priority_mem_need_init_.clear(); }
76 
77   void AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
78                             const std::vector<size_t> &align_size_list,
79                             const std::vector<const void *> &address_key_list);
80 
81  private:
82   void Record(const void *key, const MemEventType &event_type, size_t mem_size = 0);
83 
84   void OptMemUsage(float mem_used_factor = 1.0f);
85 
86   bool Mock();
87 
88   bool PreComputeMock(const MemEventPtr<const void *> &event);
89 
90   bool PreComputeInit(const MemEventPtr<const void *> &event, void *stream);
91 
92   bool PreComputeMalloc(const MemEventPtr<const void *> &event, void *stream);
93 
94   bool PreComputeSwapIn(const MemEventPtr<const void *> &event, void *stream);
95 
96   bool PreComputeGet(const MemEventPtr<const void *> &event, void *stream);
97 
GetNoReuseKeys()98   const HashSet<const void *> &GetNoReuseKeys() const { return step_keys_[current_step_]; }
99 
100   void *Malloc(const MemEventPtr<const void *> &event, void *stream);
101 
102   // Scheduler status
103   bool need_record_event_{true};
104   bool optimized_{false};
105   bool updated_{false};
106   bool record_compute_time_{false};
107   size_t total_step_{0};
108   size_t current_step_{0};
109   // Memory status
110   std::map<const void *, MemEventPtrList<const void *>> mem_events_;
111   std::map<const void *, MemPriority> mem_priority_;
112   std::vector<HashSet<const void *>> step_keys_;
113   std::set<const void *> high_priority_mem_need_init_;
114   std::shared_ptr<ContinuousMemInfoHelper<const void *>> continuous_mem_info_helper_{
115     std::make_shared<ContinuousMemInfoHelper<const void *>>()};
116   std::set<ContinuousMemInfoPtr<const void *>> cur_step_allocated_continuous_mem_;
117   std::set<const void *> manual_offload_keys_;
118   // Compute time
119   std::vector<double> compute_time_;
120   double compute_start_time_{0};
121 
122   std::shared_ptr<AutoMemoryOffload> auto_mem_offload_;
123   std::shared_ptr<MemHandler> mem_handler_{nullptr};
124   std::shared_ptr<MemOffloadStrategy<const void *>> strategy_{nullptr};
125 };
126 
127 class MemSchedulerManager {
128  public:
129   MemSchedulerManager() = default;
130   ~MemSchedulerManager() = default;
GetOrCreateMemScheduler(uint64_t uid)131   std::shared_ptr<MemScheduler> GetOrCreateMemScheduler(uint64_t uid) {
132     auto scheduler = GetMemScheduler(uid);
133     if (scheduler == nullptr) {
134       scheduler = std::make_shared<MemScheduler>();
135       graph_mem_scheduler_map_[uid] = scheduler;
136     }
137     return scheduler;
138   }
139 
GetMemScheduler(uint64_t uid)140   std::shared_ptr<MemScheduler> GetMemScheduler(uint64_t uid) {
141     auto iter = graph_mem_scheduler_map_.find(uid);
142     if (iter != graph_mem_scheduler_map_.end()) {
143       return iter->second;
144     }
145     return nullptr;
146   }
147 
148  private:
149   std::map<uint64_t, std::shared_ptr<MemScheduler>> graph_mem_scheduler_map_;
150 };
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_SCHEDULER_H_