/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/memory_scheduler.h"
#include <algorithm>
#include <queue>
#include <set>
#ifdef _MSC_VER
#include <time.h>
#else
#include <sys/time.h>
#endif
#include "utils/log_adapter.h"
#include "utils/convert_utils_base.h"

namespace mindspore {
namespace device {
namespace {
// Upper and lower bounds on the fraction of available device memory the
// offload strategy is allowed to plan with (see MemScheduler::Optimize).
constexpr float kMaxMemReuseFactor = 1.0;
constexpr float kMinMemReuseFactor = 0.5;
// Amount the factor is reduced after each failed optimization attempt.
constexpr float kRetryFactor = 0.1;
// Number of mock executions used to validate a candidate strategy.
constexpr size_t kMockTimes = 5;

// Returns the current wall-clock time in microseconds. Note the MSVC fallback
// is based on time(), so it only has second-level resolution.
double GetCurrentTime() {
#ifdef _MSC_VER
  return time(NULL) * 1.0e6;
#else
  struct timeval tv;
  (void)gettimeofday(&tv, nullptr);
  return tv.tv_sec * 1.0e6 + tv.tv_usec;
#endif
}
}  // namespace

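// Registers a group of addresses that must live in a single contiguous device
// allocation, keyed by the compute step that uses them. Bookkeeping is
// delegated to continuous_mem_info_helper_.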
void MemScheduler::AddContinuousMemInfo(bool is_input, size_t compute_index, size_t total_size,
                                        const std::vector<size_t> &align_size_list,
                                        const std::vector<const void *> &address_key_list) {
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  continuous_mem_info_helper_->AddContinuousMemInfo(is_input, compute_index, total_size, align_size_list,
                                                    address_key_list);
}

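// Records a memory event for `key` at the current step. kGet events are
// additionally indexed by step in step_keys_ so that Mock() can later verify
// the key is resident when the step executes.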
void MemScheduler::Record(const void *key, const MemEventType &event_type, size_t mem_size) {
  if (key == nullptr) {
    return;
  }
  auto event = std::make_shared<MemEvent<const void *>>(event_type, current_step_);
  event->mem_size = mem_size;
  event->key = key;
  (void)mem_events_[key].emplace_back(event);
  if (step_keys_.size() < current_step_ + 1) {
    step_keys_.resize(current_step_ + 1);
  }
  if (event->type == kGet) {
    (void)step_keys_[current_step_].insert(event->key);
  }
}

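// Binds a host pointer to `key` and, during the recording phase, stores the
// key's priority and logs a kInit event.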
void MemScheduler::Init(const void *key, void *host_ptr, size_t mem_size, MemPriority priority) {
  if (need_record_event_) {
    mem_priority_[key] = priority;
    Record(key, kInit, mem_size);
  }
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  auto_mem_offload_->SetInitHostPtr(key, host_ptr, mem_size);
}

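// During the recording phase this only logs kMalloc/kGet events and returns
// nullptr; once a strategy exists it returns the device address for `key`.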
void *MemScheduler::GetOrMalloc(const void *key, size_t mem_size, MemPriority priority) {
  if (need_record_event_) {
    if (mem_priority_.find(key) == mem_priority_.end()) {
      mem_priority_[key] = priority;
      Record(key, kMalloc, mem_size);
    }
    Record(key, kGet, mem_size);
    return nullptr;
  }
  if (strategy_ == nullptr) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  return auto_mem_offload_->Get(key);
}

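// Allocates device memory for a malloc/swap-in event. Keys belonging to a
// continuous-memory block are allocated as a whole the first time the block
// is needed in the current step; all other cases fall back to a plain
// per-key allocation.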
void *MemScheduler::Malloc(const MemEventPtr<const void *> &event, void *stream) {
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(continuous_mem_info_helper_);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  const bool is_continuous_mem = continuous_mem_info_helper_->IsContinuousMem(event->key);
  if (!is_continuous_mem) {
    return auto_mem_offload_->Malloc(event->key, event->mem_size, stream, GetNoReuseKeys());
  }
  const auto &continuous_mem_info = continuous_mem_info_helper_->GetContinuousMemInfo(event->key);
  MS_EXCEPTION_IF_NULL(continuous_mem_info);
  if (!continuous_mem_info_helper_->NeedMallocContinuousMem(continuous_mem_info, current_step_) ||
      cur_step_allocated_continuous_mem_.count(continuous_mem_info) != 0) {
    return auto_mem_offload_->Malloc(event->key, event->mem_size, stream, GetNoReuseKeys());
  }
  std::vector<const void *> keys(continuous_mem_info->key_index_map_.size(), nullptr);
  for (const auto &iter : continuous_mem_info->key_index_map_) {
    if (auto_mem_offload_->Get(iter.first, stream, GetNoReuseKeys()) != nullptr) {
      MS_LOG(EXCEPTION) << "Device memory is allocated before the first continuous memory alloc event, event key: "
                        << event->key << ", continuous memory compute index: " << continuous_mem_info->compute_index_;
    }
    keys[iter.second] = iter.first;
  }
  if (!auto_mem_offload_->MallocContinuous(keys, continuous_mem_info->align_size_list_, stream, GetNoReuseKeys())) {
    MS_LOG(WARNING) << "Failed to allocate continuous memory, size: " << continuous_mem_info->total_size_;
    return nullptr;
  }
  (void)cur_step_allocated_continuous_mem_.insert(continuous_mem_info);
  return auto_mem_offload_->Get(event->key);
}

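// Mock version of the pre-compute phase: only checks that the key is already
// resident or can be allocated, without touching host data or streams.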
bool MemScheduler::PreComputeMock(const MemEventPtr<const void *> &event) {
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  if (auto_mem_offload_->Get(event->key) != nullptr) {
    return true;
  }
  return Malloc(event, nullptr) != nullptr;
}

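// Handles a kInit event before a step runs: makes sure device memory is
// allocated, then swaps the data in from host when the memory was freshly
// allocated or the key is listed in high_priority_mem_need_init_.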
bool MemScheduler::PreComputeInit(const MemEventPtr<const void *> &event, void *stream) {
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  auto device_ptr = auto_mem_offload_->Get(event->key);
  const bool new_malloc = device_ptr == nullptr;
  if (new_malloc) {
    device_ptr = Malloc(event, stream);
  }
  if (device_ptr == nullptr) {
    return false;
  }
  if (new_malloc || high_priority_mem_need_init_.count(event->key) != 0) {
    MS_LOG(DEBUG) << "Init input data from host, key: " << event->key;
    (void)auto_mem_offload_->SwapIn(event->key, stream);
  }
  return true;
}

// Handles a kMalloc event: the key only needs device memory, no host copy.
bool MemScheduler::PreComputeMalloc(const MemEventPtr<const void *> &event, void *stream) {
  return Malloc(event, stream) != nullptr;
}

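// Handles a kSwapIn event: allocates device memory and copies the previously
// offloaded data back in from host.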
bool MemScheduler::PreComputeSwapIn(const MemEventPtr<const void *> &event, void *stream) {
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  if (Malloc(event, stream) == nullptr) {
    return false;
  }
  return auto_mem_offload_->SwapIn(event->key, stream) != nullptr;
}

// Handles a kGet event: the key must already be resident on the device.
bool MemScheduler::PreComputeGet(const MemEventPtr<const void *> &event, void *stream) {
  MS_EXCEPTION_IF_NULL(event);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  return auto_mem_offload_->Get(event->key, stream, GetNoReuseKeys()) != nullptr;
}

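// Executes all pre-compute events scheduled for the current step. Before the
// strategy has been validated (optimized_), every event is handled by
// PreComputeMock; when compute-time recording is armed (see Update), the step
// start is also timestamped so measured durations can be fed back into the
// strategy.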
bool MemScheduler::PreCompute(void *stream) {
  if (strategy_ == nullptr) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(mem_handler_);
  auto &events = strategy_->GetPreComputeEvents(current_step_);
  for (auto &event : events) {
    MS_EXCEPTION_IF_NULL(event);
    MS_LOG(DEBUG) << "Pre compute " << current_step_ << ": " << event->key << ", event type: " << event->type;
    bool ret = true;
    if (!optimized_) {
      ret = PreComputeMock(event);
    } else if (event->type == kInit) {
      ret = PreComputeInit(event, stream);
    } else if (event->type == kMalloc) {
      ret = PreComputeMalloc(event, stream);
    } else if (event->type == kSwapIn) {
      ret = PreComputeSwapIn(event, stream);
    } else if (event->type == kGet) {
      ret = PreComputeGet(event, stream);
    }
    if (!ret) {
      cur_step_allocated_continuous_mem_.clear();
      return false;
    }
  }
  if (record_compute_time_ && !updated_) {
    compute_start_time_ = GetCurrentTime();
  }
  cur_step_allocated_continuous_mem_.clear();
  return true;
}

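// Executes all post-compute events of the current step: records the measured
// compute time if armed, swaps out keys the strategy marked for offload (only
// on real, optimized runs), frees their device memory, and advances the step.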
bool MemScheduler::PostCompute(void *stream) {
  if (strategy_ == nullptr) {
    ++current_step_;
    return true;
  }
  if (record_compute_time_ && !updated_ && current_step_ < compute_time_.size()) {
    compute_time_[current_step_] = GetCurrentTime() - compute_start_time_;
  }
  auto &events = strategy_->GetPostComputeEvents(current_step_);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  for (auto &event : events) {
    MS_EXCEPTION_IF_NULL(event);
    MS_LOG(DEBUG) << "Post compute " << current_step_ << ": " << event->key << ", event type: " << event->type;
    if (event->type == kSwapOut && optimized_) {
      auto_mem_offload_->SwapOut(event->key, stream);
    }
    auto_mem_offload_->Free(event->key);
  }
  ++current_step_;
  return true;
}

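// Builds the offload strategy on first use and (re)runs it against the
// available device memory scaled by mem_used_factor. When manual offload keys
// exist, profiling-based refinement is skipped (updated_ is set immediately).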
void MemScheduler::OptMemUsage(float mem_used_factor) {
  MS_EXCEPTION_IF_NULL(mem_handler_);
  MS_EXCEPTION_IF_NULL(auto_mem_offload_);
  if (strategy_ == nullptr) {
    strategy_ = std::make_shared<MemOffloadStrategy<const void *>>(mem_priority_, mem_events_, manual_offload_keys_,
                                                                   total_step_, continuous_mem_info_helper_);
    if (manual_offload_keys_.empty()) {
      compute_time_.resize(total_step_);
    } else {
      updated_ = true;
    }
  }

  auto available_mem_size = mem_handler_->GetAvailableMemSize();
  available_mem_size = FloatToSize(available_mem_size * mem_used_factor);
  strategy_->set_mem_size(available_mem_size);
  strategy_->Execute();
}

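// Searches for a feasible memory plan by repeatedly shrinking the usable
// memory fraction. Roughly:
//
//   for factor in {1.0, 0.9, 0.8, 0.7, 0.6, 0.5}:
//     run the strategy with factor * available_mem
//     mock-execute all steps kMockTimes times
//     if every mock pass succeeds: mark optimized and stop
//
// Returns false when even kMinMemReuseFactor cannot produce a working plan.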
bool MemScheduler::Optimize() {
  float mem_used_factor = kMaxMemReuseFactor;
  while (mem_used_factor >= kMinMemReuseFactor) {
    bool ret = true;
    OptMemUsage(mem_used_factor);
    for (size_t mock_time = 0; mock_time < kMockTimes; ++mock_time) {
      ret = Mock();
      if (!ret) {
        break;
      }
    }
    if (ret) {
      optimized_ = true;
      return true;
    }
    Clear();
    mem_used_factor -= kRetryFactor;
  }
  return false;
}

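// Dry-runs the whole schedule without real device work: for every step it
// replays the pre-compute events, checks that each key read by the step is
// resident, and replays the post-compute events.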
bool MemScheduler::Mock() {
  current_step_ = 0;
  for (size_t step = 0; step < total_step_; ++step) {
    bool ret = PreCompute(nullptr);
    if (!ret) {
      return false;
    }
    auto &step_keys = step_keys_[step];
    for (auto &key : step_keys) {
      auto ptr = GetOrMalloc(key, 0);
      if (ptr == nullptr) {
        return false;
      }
    }
    ret = PostCompute(nullptr);
    if (!ret) {
      return false;
    }
  }
  return true;
}

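// Refines the strategy with measured compute times. The first call after a
// successful Optimize() only arms timing (record_compute_time_); once a full
// profiled run has filled compute_time_, the strategy is re-executed with the
// measurements and marked as updated.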
void MemScheduler::Update() {
  if (!optimized_) {
    return;
  }

  if (strategy_ == nullptr || !strategy_->need_swap()) {
    return;
  }

  if (updated_) {
    return;
  }

  if (!record_compute_time_) {
    record_compute_time_ = true;
    return;
  }

  strategy_->SetComputeTime(compute_time_);
  strategy_->Execute();
  updated_ = true;
}
}  // namespace device
}  // namespace mindspore