1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/device/loadable_device_address.h"

#include <atomic>

#include "include/common/debug/common.h"
#include "include/common/utils/offload_context.h"
#include "utils/file_utils.h"
21
22 namespace mindspore {
23 namespace device {
namespace {
// Swap files are created/sized in multiples of this many bytes.
// NOTE(review): presumably matches a direct-I/O / aio alignment requirement of
// the swap backend — confirm against the swap manager implementation.
constexpr size_t kFileAlignSize = 512;
// Suffix appended to every generated swap file name.
constexpr char kSwapFileSuffix[] = ".data";
}  // namespace
28
Offload(size_t stream_id)29 bool LoadableDeviceAddress::Offload(size_t stream_id) {
30 if (loadable_mem_ == nullptr) {
31 loadable_mem_ = std::make_unique<LoadableMember>();
32 }
33 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
34 if (loadable_mem_->mem_offloaded_) {
35 MS_LOG(WARNING) << "Trying to offload an offloaded AscendDeviceAddress.";
36 return true;
37 }
38 MS_EXCEPTION_IF_NULL(GetDevicePtr());
39 auto device_context = GetDeviceContext();
40 MS_EXCEPTION_IF_NULL(device_context);
41 loadable_mem_->offload_ptr_ = device_context->device_res_manager_->AllocateOffloadMemory(GetSize());
42 if (loadable_mem_->offload_ptr_ == nullptr) {
43 MS_LOG(EXCEPTION) << "Alloc host memory for offloading failed, size: " << GetSize() << ".";
44 }
45 if (!AsyncDeviceToHost({}, GetSize(), kTypeUnknown, loadable_mem_->offload_ptr_, stream_id)) {
46 return false;
47 }
48 device_context->device_res_manager_->FreeMemory(GetDevicePtr());
49 SetDevicePtr(nullptr);
50 loadable_mem_->mem_offloaded_ = true;
51 return true;
52 }
53
Load(size_t stream_id)54 bool LoadableDeviceAddress::Load(size_t stream_id) {
55 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
56 if (loadable_mem_ == nullptr || !loadable_mem_->mem_offloaded_) {
57 MS_LOG(DEBUG) << "Trying to load a loaded AscendDeviceAddress.";
58 return true;
59 }
60 MS_EXCEPTION_IF_NULL(loadable_mem_->offload_ptr_);
61 auto device_context = GetDeviceContext();
62 MS_EXCEPTION_IF_NULL(device_context);
63 if (GetDevicePtr() == nullptr && !device_context->device_res_manager_->AllocateMemory(this)) {
64 MS_LOG(EXCEPTION) << "Alloc memory for loading failed, size: " << GetSize() << ".";
65 }
66 MS_EXCEPTION_IF_NULL(GetDevicePtr());
67 if (!AsyncHostToDevice({}, GetSize(), kTypeUnknown, loadable_mem_->offload_ptr_, stream_id)) {
68 return false;
69 }
70 device_context->device_res_manager_->FreeOffloadMemory(loadable_mem_->offload_ptr_);
71 loadable_mem_->offload_ptr_ = nullptr;
72 loadable_mem_->mem_offloaded_ = false;
73 return true;
74 }
75
MoveTo(mindspore::device::StorageType dst,bool async,size_t stream_id)76 bool LoadableDeviceAddress::MoveTo(mindspore::device::StorageType dst, bool async, size_t stream_id) {
77 bool ret = Wait();
78 if (!ret) {
79 MS_LOG(WARNING) << "Wait swapping DeviceAddress failed. Status: " << status_;
80 return false;
81 }
82 if (status_ == DeviceAddressStatus::kInDevice && GetDevicePtr() == nullptr) {
83 MS_LOG(INFO) << "Skip move empty device address.";
84 return true;
85 }
86 if (dst == StorageType::kDevice) {
87 if (!MoveToDevice(async, stream_id)) {
88 MS_LOG(WARNING) << "Move data to device failed.";
89 return false;
90 }
91 } else if (dst == StorageType::kHost) {
92 if (!MoveToHost(async, stream_id)) {
93 MS_LOG(WARNING) << "Move data to host failed.";
94 return false;
95 }
96 } else if (dst == StorageType::kFile) {
97 if (!MoveToFile(async, stream_id)) {
98 MS_LOG(WARNING) << "Move data to file failed.";
99 return false;
100 }
101 }
102 return true;
103 }
104
// Bring the payload to host memory. The source is chosen by status_: a swap
// file (kInFile) or device memory (any other status). Synchronous moves
// release the source immediately; asynchronous moves register the tensor with
// the swap manager and leave status_ in a transitional state
// (kInFileToHost / kInDeviceToHost) that Wait() later resolves to kInHost.
bool LoadableDeviceAddress::MoveToHost(bool async, size_t stream_id) const {
  const auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  const auto swap_manager = device_context->device_res_manager_->swap_manager();
  MS_EXCEPTION_IF_NULL(swap_manager);
  if (loadable_mem_ == nullptr) {
    loadable_mem_ = std::make_unique<LoadableMember>();
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  // Allocate a host buffer unless an immutable (externally owned,
  // host_ptr_mutable_ == false) one is already attached. The buffer is
  // file-align sized so it can later be written to a swap file unchanged.
  // NOTE(review): when host_ptr_ is non-null and mutable it is overwritten
  // here without being freed first — confirm the previous buffer is released
  // elsewhere (e.g. by Wait()/MoveToDevice) before this path is reached.
  if (loadable_mem_->storage_info_.host_ptr_ == nullptr || loadable_mem_->storage_info_.host_ptr_mutable_) {
    loadable_mem_->storage_info_.host_ptr_ = swap_manager->AllocHostMemory(GetFileAlignSize());
    if (loadable_mem_->storage_info_.host_ptr_ == nullptr) {
      MS_LOG(WARNING) << "Allocating host memory failed, size: " << GetSize();
      return false;
    }
  }
  if (status_ == DeviceAddressStatus::kInFile) {
    // File -> host path.
    if (!CopyFileToHost(loadable_mem_->storage_info_.host_ptr_, loadable_mem_->storage_info_.file_name_, GetSize(),
                        async)) {
      MS_LOG(WARNING) << "Copy data from file to host failed.";
      return false;
    }
    if (async) {
      // Completion (and swap-file cleanup) is deferred to Wait().
      swap_manager->AddSwappingTensor(this);
      status_ = DeviceAddressStatus::kInFileToHost;
    } else {
      // Synchronous copy finished: drop the swap file unless it is owned
      // externally (file_name_mutable_ == false).
      if (loadable_mem_->storage_info_.file_name_mutable_) {
        (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
        loadable_mem_->storage_info_.file_name_ = "";
      }
      status_ = DeviceAddressStatus::kInHost;
    }
  } else {
    // Device -> host path.
    if (!CopyDeviceToHost(loadable_mem_->storage_info_.host_ptr_, GetDevicePtr(), GetSize(), async, stream_id)) {
      MS_LOG(WARNING) << "Copy data from device to host failed.";
      return false;
    }
    if (async) {
      // Completion (and device-memory release) is deferred to Wait().
      swap_manager->AddSwappingTensor(this);
      status_ = DeviceAddressStatus::kInDeviceToHost;
    } else {
      // Synchronous copy finished: device memory can be returned to the pool.
      swap_manager->FreeDeviceMemory(GetDevicePtr());
      SetDevicePtr(nullptr);
      status_ = DeviceAddressStatus::kInHost;
    }
  }
  return true;
}
153
// Bring the payload into device memory. Data in a swap file is either moved
// directly file -> device (P2PDMA build) or staged through host memory first.
// Asynchronous moves leave status_ at kInHostToDevice for Wait() to resolve.
bool LoadableDeviceAddress::MoveToDevice(bool async, size_t stream_id) const {
  if (status_ == DeviceAddressStatus::kInDevice) {
    return true;
  }
  const auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  const auto swap_manager = device_context->device_res_manager_->swap_manager();
  MS_EXCEPTION_IF_NULL(swap_manager);
  // Unlike MoveToHost/MoveToFile, loadable_mem_ must already exist here:
  // data can only be off-device after a previous move created it.
  MS_EXCEPTION_IF_NULL(loadable_mem_);
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  if (status_ == DeviceAddressStatus::kInFile) {
#if defined(RT_MEMORY_P2PDMA)
    // Fast path: DMA the file contents straight into device memory.
    if (GetDevicePtr() == nullptr) {
      SetDevicePtr(swap_manager->AllocDeviceMemory(GetSize(), stream_id));
    }
    MS_EXCEPTION_IF_NULL(GetDevicePtr());
    if (FileToDeviceDirectly(GetDevicePtr(), GetSize(), loadable_mem_->storage_info_.file_name_, stream_id)) {
      // Direct transfer succeeded: release the owned (mutable) swap file and
      // host buffer, if any.
      if (loadable_mem_->storage_info_.file_name_mutable_ && !loadable_mem_->storage_info_.file_name_.empty()) {
        (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
        loadable_mem_->storage_info_.file_name_ = "";
      }
      if (loadable_mem_->storage_info_.host_ptr_mutable_) {
        swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
        loadable_mem_->storage_info_.host_ptr_ = nullptr;
      }
      status_ = DeviceAddressStatus::kInDevice;
      return true;
    }
#endif
    // Fallback (or non-P2PDMA build): stage the file into host memory
    // synchronously, then continue with the host -> device copy below.
    if (!MoveToHost(false, stream_id)) {
      return false;
    }
  }
  if (GetDevicePtr() == nullptr) {
    SetDevicePtr(swap_manager->AllocDeviceMemory(GetSize(), stream_id));
    if (GetDevicePtr() == nullptr) {
      MS_LOG(WARNING) << "Allocating device memory failed, size: " << GetSize();
      return false;
    }
  }
  if (!CopyHostToDevice(GetDevicePtr(), loadable_mem_->storage_info_.host_ptr_, GetSize(), async, stream_id)) {
    MS_LOG(WARNING) << "Copy data from host to device failed.";
    return false;
  }
  if (async) {
    // Wait() frees the host buffer and finalizes status_ when the copy ends.
    swap_manager->AddSwappingTensor(this);
    status_ = DeviceAddressStatus::kInHostToDevice;
  } else {
    if (loadable_mem_->storage_info_.host_ptr_mutable_) {
      swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
      loadable_mem_->storage_info_.host_ptr_ = nullptr;
    }

    status_ = DeviceAddressStatus::kInDevice;
  }
  return true;
}
211
// Spill the payload to a swap file. Device-resident data is either written
// directly device -> file (P2PDMA build) or staged through host memory first.
// Asynchronous moves leave status_ at kInHostToFile for Wait() to resolve.
bool LoadableDeviceAddress::MoveToFile(bool async, size_t stream_id) const {
  if (status_ == DeviceAddressStatus::kInFile) {
    return true;
  }
  const auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  const auto swap_manager = device_context->device_res_manager_->swap_manager();
  MS_EXCEPTION_IF_NULL(swap_manager);
  if (loadable_mem_ == nullptr) {
    loadable_mem_ = std::make_unique<LoadableMember>();
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  if (status_ == DeviceAddressStatus::kInDevice) {
#if defined(RT_MEMORY_P2PDMA)
    // Fast path: DMA device memory straight into a (possibly new) swap file.
    // A fresh name is generated unless an immutable (externally owned) file
    // is already attached.
    if (loadable_mem_->storage_info_.file_name_.empty() || loadable_mem_->storage_info_.file_name_mutable_) {
      loadable_mem_->storage_info_.file_name_ = GetSwapFileName();
    }
    if (DeviceToFileDirectly(GetDevicePtr(), GetSize(), loadable_mem_->storage_info_.file_name_, stream_id)) {
      status_ = DeviceAddressStatus::kInFile;
      if (GetDevicePtr() != nullptr) {
        swap_manager->FreeDeviceMemory(GetDevicePtr());
        SetDevicePtr(nullptr);
      }
      if (loadable_mem_->storage_info_.host_ptr_ != nullptr) {
        swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
        loadable_mem_->storage_info_.host_ptr_ = nullptr;
      }
      return true;
    }
#endif
    // Fallback (or non-P2PDMA build): stage device data into host memory
    // synchronously, then continue with the host -> file write below.
    if (!MoveToHost(false, stream_id)) {
      return false;
    }
  }
  // Create a fresh swap file (file-align sized) unless an immutable
  // (externally owned) file name is already attached.
  if (loadable_mem_->storage_info_.file_name_.empty() || loadable_mem_->storage_info_.file_name_mutable_) {
    loadable_mem_->storage_info_.file_name_ = GetSwapFileName();
    if (!swap_manager->CreateFile(loadable_mem_->storage_info_.file_name_, GetFileAlignSize())) {
      MS_LOG(WARNING) << "Create file for swapping failed.";
      return false;
    }
  }
  if (!CopyHostToFile(loadable_mem_->storage_info_.file_name_, loadable_mem_->storage_info_.host_ptr_, GetSize(),
                      async)) {
    MS_LOG(WARNING) << "Copy data from host to file failed.";
    return false;
  }
  if (async) {
    // Wait() frees the host buffer and finalizes status_ when the write ends.
    swap_manager->AddSwappingTensor(this);
    status_ = DeviceAddressStatus::kInHostToFile;
  } else {
    if (loadable_mem_->storage_info_.host_ptr_mutable_) {
      swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
      loadable_mem_->storage_info_.host_ptr_ = nullptr;
    }
    status_ = DeviceAddressStatus::kInFile;
  }
  return true;
}
270
CopyHostToFile(const std::string & dst,const void * src,size_t size,bool async) const271 bool LoadableDeviceAddress::CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const {
272 MS_EXCEPTION_IF_NULL(src);
273 const auto device_context = GetDeviceContext();
274 MS_EXCEPTION_IF_NULL(device_context);
275 const auto swap_manager = device_context->device_res_manager_->swap_manager();
276 MS_EXCEPTION_IF_NULL(swap_manager);
277 AsyncIOToken token;
278 bool ret = swap_manager->HostMemoryToFile(dst, src, size, async, &token);
279 if (!ret) {
280 MS_LOG(WARNING) << "Write data from ddr to file[" << dst << "] failed.";
281 return ret;
282 }
283 if (async) {
284 MS_EXCEPTION_IF_NULL(loadable_mem_);
285 loadable_mem_->swap_event_.aio_token_ = token;
286 }
287 return ret;
288 }
289
CopyFileToHost(void * dst,const std::string & src,size_t size,bool async) const290 bool LoadableDeviceAddress::CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const {
291 MS_EXCEPTION_IF_NULL(dst);
292 const auto device_context = GetDeviceContext();
293 MS_EXCEPTION_IF_NULL(device_context);
294 const auto swap_manager = device_context->device_res_manager_->swap_manager();
295 MS_EXCEPTION_IF_NULL(swap_manager);
296 AsyncIOToken token;
297 bool ret = swap_manager->FileToHostMemory(dst, src, size, async, &token);
298 if (!ret) {
299 MS_LOG(WARNING) << "Read data from file[" << src << "] to ddr failed.";
300 return ret;
301 }
302 if (async) {
303 MS_EXCEPTION_IF_NULL(loadable_mem_);
304 loadable_mem_->swap_event_.aio_token_ = token;
305 }
306 return true;
307 }
308
ReleaseResource()309 void LoadableDeviceAddress::ReleaseResource() {
310 if (loadable_mem_ == nullptr || status_ == DeviceAddressStatus::kInDevice) {
311 return;
312 }
313
314 const bool need_delete_file =
315 !loadable_mem_->storage_info_.file_name_.empty() && loadable_mem_->storage_info_.file_name_mutable_;
316 const bool need_free_host =
317 loadable_mem_->storage_info_.host_ptr_ != nullptr && loadable_mem_->storage_info_.host_ptr_mutable_;
318 if (need_delete_file || need_free_host) {
319 auto device_context = GetDeviceContext();
320 MS_EXCEPTION_IF_NULL(device_context);
321 const auto swap_manager = device_context->device_res_manager_->swap_manager();
322 MS_EXCEPTION_IF_NULL(swap_manager);
323 if (need_delete_file) {
324 (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
325 }
326 if (need_free_host) {
327 swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
328 }
329 }
330 loadable_mem_ = nullptr;
331 }
332
GetFileAlignSize() const333 size_t LoadableDeviceAddress::GetFileAlignSize() const {
334 return (GetSize() + kFileAlignSize - 1) / kFileAlignSize * kFileAlignSize;
335 }
336
GetSwapFileName() const337 std::string LoadableDeviceAddress::GetSwapFileName() const {
338 static size_t swap_file_index = 0;
339 std::string file_dir;
340 const auto &offload_context = OffloadContext::GetInstance();
341 if (offload_context != nullptr) {
342 const auto real_dir = FileUtils::GetRealPath(offload_context->offload_path().c_str());
343 if (!real_dir.has_value()) {
344 MS_LOG(EXCEPTION) << "Invalid offload path[" << offload_context->offload_path()
345 << "]. Please check offload_path configuration.";
346 }
347 file_dir = real_dir.value() + "/";
348 }
349 return file_dir + std::to_string(device_id()) + "_" + std::to_string(swap_file_index++) + "_" +
350 std::to_string(Common::GetTimeStamp()) + kSwapFileSuffix;
351 }
352
SetStorageInfo(const StorageInfo & storage_info)353 void LoadableDeviceAddress::SetStorageInfo(const StorageInfo &storage_info) {
354 if (loadable_mem_ == nullptr) {
355 loadable_mem_ = std::make_unique<LoadableMember>();
356 }
357 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
358 loadable_mem_->storage_info_ = storage_info;
359 if (loadable_mem_->storage_info_.host_ptr_ != nullptr) {
360 status_ = DeviceAddressStatus::kInHost;
361 loadable_mem_->storage_info_.host_ptr_mutable_ = false;
362 } else if (!loadable_mem_->storage_info_.file_name_.empty()) {
363 status_ = DeviceAddressStatus::kInFile;
364 loadable_mem_->storage_info_.file_name_mutable_ = false;
365 } else {
366 status_ = DeviceAddressStatus::kInDevice;
367 }
368 }
369
GetStorageInfo() const370 StorageInfo LoadableDeviceAddress::GetStorageInfo() const {
371 if (loadable_mem_ == nullptr) {
372 loadable_mem_ = std::make_unique<LoadableMember>();
373 }
374 return loadable_mem_->storage_info_;
375 }
376
// Transfer the swap/offload state of this address to `other` (in addition to
// the base-class swap) and reset this address to a clean in-device state so
// both addresses never own the same host buffer, swap file, or offload ptr.
void LoadableDeviceAddress::Swap(mindspore::device::DeviceAddress *other) {
  DeviceAddress::Swap(other);
  if (other == this) {
    return;
  }
  // NOTE(review): reinterpret_cast never yields nullptr for non-null input,
  // so the null check below only guards against other == nullptr. If `other`
  // could be a non-Loadable DeviceAddress, a checked downcast (dynamic_cast)
  // would be safer — confirm what callers actually pass here.
  auto loadable_device_address = reinterpret_cast<LoadableDeviceAddress *>(other);
  if (loadable_device_address != nullptr) {
    if (loadable_mem_ == nullptr) {
      loadable_mem_ = std::make_unique<LoadableMember>();
    }
    if (loadable_device_address->loadable_mem_ == nullptr) {
      loadable_device_address->loadable_mem_ = std::make_unique<LoadableMember>();
    }
    // Hand over storage, status and offload state to `other`.
    loadable_device_address->loadable_mem_->storage_info_ = loadable_mem_->storage_info_;
    loadable_device_address->status_ = status_;
    loadable_device_address->loadable_mem_->offload_ptr_ = loadable_mem_->offload_ptr_;
    loadable_device_address->loadable_mem_->mem_offloaded_ = loadable_mem_->mem_offloaded_;
    // Clear this side; ownership flags go back to mutable (owned-here).
    loadable_mem_->storage_info_.host_ptr_ = nullptr;
    loadable_mem_->storage_info_.file_name_ = "";
    loadable_mem_->storage_info_.host_ptr_mutable_ = true;
    loadable_mem_->storage_info_.file_name_mutable_ = true;
    status_ = DeviceAddressStatus::kInDevice;
    loadable_mem_->offload_ptr_ = nullptr;
    loadable_mem_->mem_offloaded_ = false;
  }
}
403
Wait() const404 bool LoadableDeviceAddress::Wait() const {
405 if (loadable_mem_ == nullptr || !loadable_mem_->swap_event_.NeedWait()) {
406 return true;
407 }
408 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
409 const auto device_context = GetDeviceContext();
410 MS_EXCEPTION_IF_NULL(device_context);
411 const auto swap_manager = device_context->device_res_manager_->swap_manager();
412 MS_EXCEPTION_IF_NULL(swap_manager);
413 if (loadable_mem_->swap_event_.device_event_ != nullptr && loadable_mem_->swap_event_.device_event_->NeedWait()) {
414 loadable_mem_->swap_event_.device_event_->WaitEvent();
415 } else if (loadable_mem_->swap_event_.aio_token_ != kInvalidAsyncIOToken) {
416 if (!swap_manager->WaitAsyncIO(loadable_mem_->swap_event_.aio_token_)) {
417 MS_LOG(WARNING) << "Wait aio failed.";
418 return false;
419 }
420 } else {
421 MS_LOG(WARNING) << "Device address is in moving, but no valid swap event can be found.";
422 }
423 if (status_ == DeviceAddressStatus::kInFileToHost) {
424 if (loadable_mem_->storage_info_.file_name_mutable_) {
425 (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
426 loadable_mem_->storage_info_.file_name_ = "";
427 }
428 status_ = DeviceAddressStatus::kInHost;
429 } else if (status_ == DeviceAddressStatus::kInDeviceToHost) {
430 swap_manager->FreeDeviceMemory(GetDevicePtr());
431 status_ = DeviceAddressStatus::kInHost;
432 } else {
433 if (loadable_mem_->storage_info_.host_ptr_mutable_) {
434 swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
435 loadable_mem_->storage_info_.host_ptr_ = nullptr;
436 }
437 if (status_ == DeviceAddressStatus::kInHostToDevice) {
438 status_ = DeviceAddressStatus::kInHost;
439 } else {
440 status_ = DeviceAddressStatus::kInFile;
441 }
442 }
443 return true;
444 }
445
SetOffloadPtr(void * offload_ptr)446 void LoadableDeviceAddress::SetOffloadPtr(void *offload_ptr) {
447 if (loadable_mem_ == nullptr) {
448 loadable_mem_ = std::make_unique<LoadableMember>();
449 }
450 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
451 loadable_mem_->offload_ptr_ = offload_ptr;
452 loadable_mem_->mem_offloaded_ = (offload_ptr != nullptr);
453 }
454
// Return the host-side offload buffer, or nullptr when nothing is offloaded.
// Requires loadable_mem_ to exist (created by Offload()/SetOffloadPtr());
// raises otherwise.
void *LoadableDeviceAddress::GetOffloadPtr() const {
  MS_EXCEPTION_IF_NULL(loadable_mem_);
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  return loadable_mem_->offload_ptr_;
}
460
461 // Return whether DeviceAddress has a valid ptr.
IsPtrValid() const462 bool LoadableDeviceAddress::IsPtrValid() const {
463 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
464 return GetDevicePtr() != nullptr || (loadable_mem_ != nullptr && (loadable_mem_->offload_ptr_ != nullptr ||
465 loadable_mem_->storage_info_.host_ptr_ != nullptr ||
466 !loadable_mem_->storage_info_.file_name_.empty()));
467 }
468
469 // Load first if data is offloaded and return the device ptr.
GetValidPtr(size_t stream_id)470 void *LoadableDeviceAddress::GetValidPtr(size_t stream_id) {
471 std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
472 if (mem_offloaded() && !Load(stream_id)) {
473 MS_LOG(EXCEPTION) << "Load offloaded memory failed";
474 }
475 if (!MoveToDevice(false)) {
476 MS_LOG(ERROR) << "Move data to device failed.";
477 return nullptr;
478 }
479
480 return DeviceAddress::GetValidPtr(stream_id);
481 }
482 } // namespace device
483 } // namespace mindspore
484