• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/device/memory_scheduler.h"
18 #include <algorithm>
19 #include <queue>
20 #include <set>
21 #ifdef _MSC_VER
22 #include <time.h>
23 #else
24 #include <sys/time.h>
25 #endif
26 #include "utils/log_adapter.h"
27 #include "utils/convert_utils_base.h"
28 
29 namespace mindspore {
30 namespace device {
namespace {
// Bounds and decay step for the memory-reuse factor searched in
// MemScheduler::Optimize().
constexpr float kMaxMemReuseFactor = 1.0;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
// Number of consecutive successful Mock() runs required to accept a strategy.
constexpr size_t kMockTimes = 5;

// Returns a timestamp in microseconds. It is used exclusively to measure
// per-step compute durations (differences of two calls), never as wall-clock
// time.
// Fix: the previous _MSC_VER branch returned time(NULL) * 1.0e6, whose
// one-second resolution made every sub-second step measure as zero on
// Windows. A single std::chrono::steady_clock implementation is monotonic
// (immune to system clock adjustments) and has microsecond-or-better
// resolution on all supported platforms.
double GetCurrentTime() {
  const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
  return static_cast<double>(std::chrono::duration_cast<std::chrono::microseconds>(since_epoch).count());
}
}  // namespace
47 
// Registers one continuous-memory requirement: a group of address keys that
// must be allocated as a single contiguous device block of total_size bytes
// at compute step compute_index.
//   is_input:         whether the block is used as kernel input memory.
//   align_size_list:  per-member aligned sizes inside the block.
//   address_key_list: keys identifying each member address.
// All bookkeeping is delegated to continuous_mem_info_helper_.
void MemScheduler::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
                                        const std::vector<size_t> &align_size_list,
                                        const std::vector<const void *> &address_key_list) {
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  continuous_mem_info_helper_->AddContinuousMemInfo(is_input, compute_index, total_size, align_size_list,
                                                    address_key_list);
}
55 
Record(const void * key,const MemEventType & event_type,size_t mem_size)56 void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
57   if (key == nullptr) {
58     return;
59   }
60   auto event = std::make_shared<MemEvent<const void *>>(event_type, current_step_);
61   event->mem_size = mem_size;
62   event->key = key;
63   (void)mem_events_[key].emplace_back(event);
64   if (step_keys_.size() < current_step_ + 1) {
65     step_keys_.resize(current_step_ + 1);
66   }
67   if (event->type == kGet) {
68     (void)step_keys_[current_step_].insert(event->key);
69   }
70 }
71 
// Registers a memory block's host pointer with the scheduler.
// During the recording phase this also stores the block's priority and logs a
// kInit event; in every phase the host pointer is handed to the offload
// helper so the data can later be swapped in from host memory.
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
  if (need_record_event_) {
    mem_priority_[key] = priority;
    Record(key, kInit, mem_size);
  }
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  auto_mem_offload_->SetInitHostPtr(key, host_ptr, mem_size);
}
80 
GetOrMalloc(const void * key,size_t mem_size,MemPriority priority)81 void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
82   if (need_record_event_) {
83     if (mem_priority_.find(key) == mem_priority_.end()) {
84       mem_priority_[key] = priority;
85       Record(key, kMalloc, mem_size);
86     }
87     Record(key, kGet, mem_size);
88     return nullptr;
89   }
90   if (strategy_ == nullptr) {
91     return nullptr;
92   }
93   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
94   return auto_mem_offload_->Get(key);
95 }
96 
Malloc(const MemEventPtr<const void * > & event,void * stream)97 void *MemScheduler::Malloc(const MemEventPtr<const void *> &event, void *stream) {
98   MS_EXCEPTION_IF_NULL(event);
99   MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
100   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
101   const bool is_continuous_mem = continuous_mem_info_helper_->IsContinuousMem(event->key);
102   if (!is_continuous_mem) {
103     return auto_mem_offload_->Malloc(event->key, event->mem_size, stream, GetNoReuseKeys());
104   }
105   const auto &continuous_mem_info = continuous_mem_info_helper_->GetContinuousMemInfo(event->key);
106   MS_EXCEPTION_IF_NULL(continuous_mem_info);
107   if (!continuous_mem_info_helper_->NeedMallocContinuousMem(continuous_mem_info, current_step_) ||
108       cur_step_allocated_continuous_mem_.count(continuous_mem_info) != 0) {
109     return auto_mem_offload_->Malloc(event->key, event->mem_size, stream, GetNoReuseKeys());
110   }
111   std::vector<const void *> keys(continuous_mem_info->key_index_map_.size(), nullptr);
112   for (const auto &iter : continuous_mem_info->key_index_map_) {
113     if (auto_mem_offload_->Get(iter.first, stream, GetNoReuseKeys()) != nullptr) {
114       MS_LOG(EXCEPTION) << "Device memory is allocated before first continuous memory alloc event, event key: "
115                         << event->key << ", continuous memory used index: " << continuous_mem_info->compute_index_;
116     }
117     keys[iter.second] = iter.first;
118   }
119   if (!auto_mem_offload_->MallocContinuous(keys, continuous_mem_info->align_size_list_, stream, GetNoReuseKeys())) {
120     MS_LOG(WARNING) << "Alloc continuous memory failed, size: " << continuous_mem_info->total_size_;
121     return nullptr;
122   }
123   (void)cur_step_allocated_continuous_mem_.insert(continuous_mem_info);
124   return auto_mem_offload_->Get(event->key);
125 }
126 
PreComputeMock(const MemEventPtr<const void * > & event)127 bool MemScheduler::PreComputeMock(const MemEventPtr<const void *> &event) {
128   MS_EXCEPTION_IF_NULL(event);
129   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
130   void *device_ptr = nullptr;
131   if (auto_mem_offload_->Get(event->key) != nullptr) {
132     return true;
133   } else {
134     device_ptr = Malloc(event, nullptr);
135   }
136   return device_ptr != nullptr;
137 }
138 
PreComputeInit(const MemEventPtr<const void * > & event,void * stream)139 bool MemScheduler::PreComputeInit(const MemEventPtr<const void *> &event, void *stream) {
140   MS_EXCEPTION_IF_NULL(event);
141   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
142   auto device_ptr = auto_mem_offload_->Get(event->key);
143   const bool new_malloc = device_ptr == nullptr;
144   if (new_malloc) {
145     device_ptr = Malloc(event, stream);
146   }
147   if (device_ptr == nullptr) {
148     return false;
149   }
150   if (new_malloc || high_priority_mem_need_init_.count(event->key) != 0) {
151     MS_LOG(DEBUG) << "Init input data from host, key: " << event->key;
152     (void)auto_mem_offload_->SwapIn(event->key, stream);
153   }
154   return true;
155 }
156 
PreComputeMalloc(const MemEventPtr<const void * > & event,void * stream)157 bool MemScheduler::PreComputeMalloc(const MemEventPtr<const void *> &event, void *stream) {
158   return Malloc(event, stream) != nullptr;
159 }
160 
PreComputeSwapIn(const MemEventPtr<const void * > & event,void * stream)161 bool MemScheduler::PreComputeSwapIn(const MemEventPtr<const void *> &event, void *stream) {
162   MS_EXCEPTION_IF_NULL(event);
163   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
164   if (Malloc(event, stream) == nullptr) {
165     return false;
166   }
167   return auto_mem_offload_->SwapIn(event->key, stream) != nullptr;
168 }
169 
PreComputeGet(const MemEventPtr<const void * > & event,void * stream)170 bool MemScheduler::PreComputeGet(const MemEventPtr<const void *> &event, void *stream) {
171   MS_EXCEPTION_IF_NULL(event);
172   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
173   return auto_mem_offload_->Get(event->key, stream, GetNoReuseKeys()) != nullptr;
174 }
175 
PreCompute(void * stream)176 bool MemScheduler::PreCompute(void *stream) {
177   if (strategy_ == nullptr) {
178     return true;
179   }
180   MS_EXCEPTION_IF_NULL(mem_handler_);
181   auto &events = strategy_->GetPreComputeEvents(current_step_);
182   for (auto &event : events) {
183     MS_EXCEPTION_IF_NULL(event);
184     MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << " v " << event->type;
185     bool ret = true;
186     if (!optimized_) {
187       ret = PreComputeMock(event);
188     } else if (event->type == kInit) {
189       ret = PreComputeInit(event, stream);
190     } else if (event->type == kMalloc) {
191       ret = PreComputeMalloc(event, stream);
192     } else if (event->type == kSwapIn) {
193       ret = PreComputeSwapIn(event, stream);
194     } else if (event->type == kGet) {
195       ret = PreComputeGet(event, stream);
196     }
197     if (!ret) {
198       cur_step_allocated_continuous_mem_.clear();
199       return false;
200     }
201   }
202   if (record_compute_time_ && !updated_) {
203     compute_start_time_ = GetCurrentTime();
204   }
205   cur_step_allocated_continuous_mem_.clear();
206   return true;
207 }
208 
PostCompute(void * stream)209 bool MemScheduler::PostCompute(void *stream) {
210   if (strategy_ == nullptr) {
211     ++current_step_;
212     return true;
213   }
214   if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
215     compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
216   }
217   auto &events = strategy_->GetPostComputeEvents(current_step_);
218   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
219   for (auto &event : events) {
220     MS_EXCEPTION_IF_NULL(event);
221     MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << " v " << event->type;
222     if (event->type == kSwapOut && optimized_) {
223       auto_mem_offload_->SwapOut(event->key, stream);
224     }
225     auto_mem_offload_->Free(event->key);
226   }
227   ++current_step_;
228   return true;
229 }
230 
OptMemUsage(float mem_used_factor)231 void MemScheduler::OptMemUsage(float mem_used_factor) {
232   MS_EXCEPTION_IF_NULL(mem_handler_);
233   MS_EXCEPTION_IF_NULL(auto_mem_offload_);
234   if (strategy_ == nullptr) {
235     strategy_ = std::make_shared<MemOffloadStrategy<const void *>>(mem_priority_, mem_events_, manual_offload_keys_,
236                                                                    total_step_, continuous_mem_info_helper_);
237     if (manual_offload_keys_.empty()) {
238       compute_time_.resize(total_step_);
239     } else {
240       updated_ = true;
241     }
242   }
243 
244   auto available_mem_size = mem_handler_->GetAvailableMemSize();
245   available_mem_size = FloatToSize(available_mem_size * mem_used_factor);
246   strategy_->set_mem_size(available_mem_size);
247   strategy_->Execute();
248 }
249 
Optimize()250 bool MemScheduler::Optimize() {
251   float mem_used_factor = kMaxMemReuseFactor;
252   while (mem_used_factor >= kMinMemReuseFactor) {
253     bool ret = true;
254     OptMemUsage(mem_used_factor);
255     for (size_t mock_time = 0; mock_time < kMockTimes; ++mock_time) {
256       ret = Mock();
257       if (!ret) {
258         break;
259       }
260     }
261     if (ret) {
262       optimized_ = true;
263       return true;
264     }
265     Clear();
266     mem_used_factor -= kRetryFactor;
267   }
268   return false;
269 }
270 
Mock()271 bool MemScheduler::Mock() {
272   current_step_ = 0;
273   for (size_t step = 0; step < total_step_; ++step) {
274     bool ret = PreCompute(nullptr);
275     if (!ret) {
276       return false;
277     }
278     auto &step_keys = step_keys_[step];
279     for (auto &key : step_keys) {
280       auto ptr = GetOrMalloc(key, 0);
281       if (ptr == nullptr) {
282         return false;
283       }
284     }
285     ret = PostCompute(nullptr);
286     if (!ret) {
287       return false;
288     }
289   }
290   return true;
291 }
292 
Update()293 void MemScheduler::Update() {
294   if (!optimized_) {
295     return;
296   }
297 
298   if (strategy_ == nullptr || !strategy_->need_swap()) {
299     return;
300   }
301 
302   if (updated_) {
303     return;
304   }
305 
306   if (!record_compute_time_) {
307     record_compute_time_ = true;
308     return;
309   }
310 
311   strategy_->SetComputeTime(compute_time_);
312   strategy_->Execute();
313   updated_ = true;
314 }
315 }  // namespace device
316 }  // namespace mindspore
317