• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "runtime/device/loadable_device_address.h"

#include <atomic>

#include "include/common/debug/common.h"
#include "include/common/utils/offload_context.h"
#include "utils/file_utils.h"
21 
22 namespace mindspore {
23 namespace device {
24 namespace {
25 constexpr size_t kFileAlignSize = 512;
26 constexpr char kSwapFileSuffix[] = ".data";
27 }  // namespace
28 
Offload(size_t stream_id)29 bool LoadableDeviceAddress::Offload(size_t stream_id) {
30   if (loadable_mem_ == nullptr) {
31     loadable_mem_ = std::make_unique<LoadableMember>();
32   }
33   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
34   if (loadable_mem_->mem_offloaded_) {
35     MS_LOG(WARNING) << "Trying to offload an offloaded AscendDeviceAddress.";
36     return true;
37   }
38   MS_EXCEPTION_IF_NULL(GetDevicePtr());
39   auto device_context = GetDeviceContext();
40   MS_EXCEPTION_IF_NULL(device_context);
41   loadable_mem_->offload_ptr_ = device_context->device_res_manager_->AllocateOffloadMemory(GetSize());
42   if (loadable_mem_->offload_ptr_ == nullptr) {
43     MS_LOG(EXCEPTION) << "Alloc host memory for offloading failed, size: " << GetSize() << ".";
44   }
45   if (!AsyncDeviceToHost({}, GetSize(), kTypeUnknown, loadable_mem_->offload_ptr_, stream_id)) {
46     return false;
47   }
48   device_context->device_res_manager_->FreeMemory(GetDevicePtr());
49   SetDevicePtr(nullptr);
50   loadable_mem_->mem_offloaded_ = true;
51   return true;
52 }
53 
Load(size_t stream_id)54 bool LoadableDeviceAddress::Load(size_t stream_id) {
55   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
56   if (loadable_mem_ == nullptr || !loadable_mem_->mem_offloaded_) {
57     MS_LOG(DEBUG) << "Trying to load a loaded AscendDeviceAddress.";
58     return true;
59   }
60   MS_EXCEPTION_IF_NULL(loadable_mem_->offload_ptr_);
61   auto device_context = GetDeviceContext();
62   MS_EXCEPTION_IF_NULL(device_context);
63   if (GetDevicePtr() == nullptr && !device_context->device_res_manager_->AllocateMemory(this)) {
64     MS_LOG(EXCEPTION) << "Alloc memory for loading failed, size: " << GetSize() << ".";
65   }
66   MS_EXCEPTION_IF_NULL(GetDevicePtr());
67   if (!AsyncHostToDevice({}, GetSize(), kTypeUnknown, loadable_mem_->offload_ptr_, stream_id)) {
68     return false;
69   }
70   device_context->device_res_manager_->FreeOffloadMemory(loadable_mem_->offload_ptr_);
71   loadable_mem_->offload_ptr_ = nullptr;
72   loadable_mem_->mem_offloaded_ = false;
73   return true;
74 }
75 
MoveTo(mindspore::device::StorageType dst,bool async,size_t stream_id)76 bool LoadableDeviceAddress::MoveTo(mindspore::device::StorageType dst, bool async, size_t stream_id) {
77   bool ret = Wait();
78   if (!ret) {
79     MS_LOG(WARNING) << "Wait swapping DeviceAddress failed. Status: " << status_;
80     return false;
81   }
82   if (status_ == DeviceAddressStatus::kInDevice && GetDevicePtr() == nullptr) {
83     MS_LOG(INFO) << "Skip move empty device address.";
84     return true;
85   }
86   if (dst == StorageType::kDevice) {
87     if (!MoveToDevice(async, stream_id)) {
88       MS_LOG(WARNING) << "Move data to device failed.";
89       return false;
90     }
91   } else if (dst == StorageType::kHost) {
92     if (!MoveToHost(async, stream_id)) {
93       MS_LOG(WARNING) << "Move data to host failed.";
94       return false;
95     }
96   } else if (dst == StorageType::kFile) {
97     if (!MoveToFile(async, stream_id)) {
98       MS_LOG(WARNING) << "Move data to file failed.";
99       return false;
100     }
101   }
102   return true;
103 }
104 
// Move the payload of this address into host memory. The source is either the
// swap file (status_ == kInFile) or device memory. When `async` is true the
// copy is only queued: the tensor is registered as "swapping" and Wait() later
// finalizes the status transition and releases the source storage; when false
// the source storage is released immediately.
bool LoadableDeviceAddress::MoveToHost(bool async, size_t stream_id) const {
  const auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  const auto swap_manager = device_context->device_res_manager_->swap_manager();
  MS_EXCEPTION_IF_NULL(swap_manager);
  // NOTE(review): loadable_mem_ is lazily created before ptr_mutex_ is taken;
  // confirm callers serialize the first touch, otherwise two threads may race.
  if (loadable_mem_ == nullptr) {
    loadable_mem_ = std::make_unique<LoadableMember>();
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  // (Re)allocate the host buffer when absent, or when the current one is
  // mutable, i.e. owned by this address rather than provided externally.
  // NOTE(review): a non-null mutable host_ptr_ is overwritten here without
  // being freed first — verify this cannot leak host memory.
  if (loadable_mem_->storage_info_.host_ptr_ == nullptr || loadable_mem_->storage_info_.host_ptr_mutable_) {
    loadable_mem_->storage_info_.host_ptr_ = swap_manager->AllocHostMemory(GetFileAlignSize());
    if (loadable_mem_->storage_info_.host_ptr_ == nullptr) {
      MS_LOG(WARNING) << "Allocating host memory failed, size: " << GetSize();
      return false;
    }
  }
  if (status_ == DeviceAddressStatus::kInFile) {
    // File -> host path.
    if (!CopyFileToHost(loadable_mem_->storage_info_.host_ptr_, loadable_mem_->storage_info_.file_name_, GetSize(),
                        async)) {
      MS_LOG(WARNING) << "Copy data from file to host failed.";
      return false;
    }
    if (async) {
      // Completion (file deletion + final status) is deferred to Wait().
      swap_manager->AddSwappingTensor(this);
      status_ = DeviceAddressStatus::kInFileToHost;
    } else {
      // Only delete the swap file when this address owns it.
      if (loadable_mem_->storage_info_.file_name_mutable_) {
        (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
        loadable_mem_->storage_info_.file_name_ = "";
      }
      status_ = DeviceAddressStatus::kInHost;
    }
  } else {
    // Device -> host path.
    if (!CopyDeviceToHost(loadable_mem_->storage_info_.host_ptr_, GetDevicePtr(), GetSize(), async, stream_id)) {
      MS_LOG(WARNING) << "Copy data from device to host failed.";
      return false;
    }
    if (async) {
      swap_manager->AddSwappingTensor(this);
      status_ = DeviceAddressStatus::kInDeviceToHost;
    } else {
      // Synchronous copy finished: the device buffer can be released now.
      swap_manager->FreeDeviceMemory(GetDevicePtr());
      SetDevicePtr(nullptr);
      status_ = DeviceAddressStatus::kInHost;
    }
  }
  return true;
}
153 
// Move the payload of this address into device memory, from host memory or
// from the swap file (via the P2PDMA direct path when compiled in, otherwise
// via an intermediate synchronous move to host).
bool LoadableDeviceAddress::MoveToDevice(bool async, size_t stream_id) const {
  if (status_ == DeviceAddressStatus::kInDevice) {
    // Already resident on device; nothing to do.
    return true;
  }
  const auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  const auto swap_manager = device_context->device_res_manager_->swap_manager();
  MS_EXCEPTION_IF_NULL(swap_manager);
  MS_EXCEPTION_IF_NULL(loadable_mem_);
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  if (status_ == DeviceAddressStatus::kInFile) {
#if defined(RT_MEMORY_P2PDMA)
    // Fast path: DMA the file content straight into device memory, skipping
    // the intermediate host buffer. Falls through to the host path on failure.
    if (GetDevicePtr() == nullptr) {
      SetDevicePtr(swap_manager->AllocDeviceMemory(GetSize(), stream_id));
    }
    MS_EXCEPTION_IF_NULL(GetDevicePtr());
    if (FileToDeviceDirectly(GetDevicePtr(), GetSize(), loadable_mem_->storage_info_.file_name_, stream_id)) {
      // Data is on device: release file/host storage that this address owns.
      if (loadable_mem_->storage_info_.file_name_mutable_ && !loadable_mem_->storage_info_.file_name_.empty()) {
        (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
        loadable_mem_->storage_info_.file_name_ = "";
      }
      if (loadable_mem_->storage_info_.host_ptr_mutable_) {
        swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
        loadable_mem_->storage_info_.host_ptr_ = nullptr;
      }
      status_ = DeviceAddressStatus::kInDevice;
      return true;
    }
#endif
    // File -> device goes through host memory, synchronously.
    if (!MoveToHost(false, stream_id)) {
      return false;
    }
  }
  if (GetDevicePtr() == nullptr) {
    SetDevicePtr(swap_manager->AllocDeviceMemory(GetSize(), stream_id));
    if (GetDevicePtr() == nullptr) {
      MS_LOG(WARNING) << "Allocating device memory failed, size: " << GetSize();
      return false;
    }
  }
  if (!CopyHostToDevice(GetDevicePtr(), loadable_mem_->storage_info_.host_ptr_, GetSize(), async, stream_id)) {
    MS_LOG(WARNING) << "Copy data from host to device failed.";
    return false;
  }
  if (async) {
    // Host buffer release and final status are deferred to Wait().
    swap_manager->AddSwappingTensor(this);
    status_ = DeviceAddressStatus::kInHostToDevice;
  } else {
    // Free the host buffer only when this address owns it.
    if (loadable_mem_->storage_info_.host_ptr_mutable_) {
      swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
      loadable_mem_->storage_info_.host_ptr_ = nullptr;
    }

    status_ = DeviceAddressStatus::kInDevice;
  }
  return true;
}
211 
// Move the payload of this address into its swap file, from device memory
// (P2PDMA direct path when compiled in, otherwise via an intermediate
// synchronous move to host) or from host memory.
bool LoadableDeviceAddress::MoveToFile(bool async, size_t stream_id) const {
  if (status_ == DeviceAddressStatus::kInFile) {
    // Already on file; nothing to do.
    return true;
  }
  const auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  const auto swap_manager = device_context->device_res_manager_->swap_manager();
  MS_EXCEPTION_IF_NULL(swap_manager);
  // NOTE(review): loadable_mem_ is lazily created before ptr_mutex_ is taken;
  // confirm callers serialize the first touch.
  if (loadable_mem_ == nullptr) {
    loadable_mem_ = std::make_unique<LoadableMember>();
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  if (status_ == DeviceAddressStatus::kInDevice) {
#if defined(RT_MEMORY_P2PDMA)
    // Fast path: DMA device memory straight into the file, skipping the
    // intermediate host buffer. Falls through to the host path on failure.
    if (loadable_mem_->storage_info_.file_name_.empty() || loadable_mem_->storage_info_.file_name_mutable_) {
      loadable_mem_->storage_info_.file_name_ = GetSwapFileName();
    }
    if (DeviceToFileDirectly(GetDevicePtr(), GetSize(), loadable_mem_->storage_info_.file_name_, stream_id)) {
      status_ = DeviceAddressStatus::kInFile;
      if (GetDevicePtr() != nullptr) {
        swap_manager->FreeDeviceMemory(GetDevicePtr());
        SetDevicePtr(nullptr);
      }
      if (loadable_mem_->storage_info_.host_ptr_ != nullptr) {
        swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
        loadable_mem_->storage_info_.host_ptr_ = nullptr;
      }
      return true;
    }
#endif
    // Device -> file goes through host memory, synchronously.
    if (!MoveToHost(false, stream_id)) {
      return false;
    }
  }
  // Create a fresh swap file unless an immutable (externally-provided)
  // file name must be reused.
  if (loadable_mem_->storage_info_.file_name_.empty() || loadable_mem_->storage_info_.file_name_mutable_) {
    loadable_mem_->storage_info_.file_name_ = GetSwapFileName();
    if (!swap_manager->CreateFile(loadable_mem_->storage_info_.file_name_, GetFileAlignSize())) {
      MS_LOG(WARNING) << "Create file for swapping failed.";
      return false;
    }
  }
  if (!CopyHostToFile(loadable_mem_->storage_info_.file_name_, loadable_mem_->storage_info_.host_ptr_, GetSize(),
                      async)) {
    MS_LOG(WARNING) << "Copy data from host to file failed.";
    return false;
  }
  if (async) {
    // Host buffer release and final status are deferred to Wait().
    swap_manager->AddSwappingTensor(this);
    status_ = DeviceAddressStatus::kInHostToFile;
  } else {
    // Free the host buffer only when this address owns it.
    if (loadable_mem_->storage_info_.host_ptr_mutable_) {
      swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
      loadable_mem_->storage_info_.host_ptr_ = nullptr;
    }
    status_ = DeviceAddressStatus::kInFile;
  }
  return true;
}
270 
CopyHostToFile(const std::string & dst,const void * src,size_t size,bool async) const271 bool LoadableDeviceAddress::CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const {
272   MS_EXCEPTION_IF_NULL(src);
273   const auto device_context = GetDeviceContext();
274   MS_EXCEPTION_IF_NULL(device_context);
275   const auto swap_manager = device_context->device_res_manager_->swap_manager();
276   MS_EXCEPTION_IF_NULL(swap_manager);
277   AsyncIOToken token;
278   bool ret = swap_manager->HostMemoryToFile(dst, src, size, async, &token);
279   if (!ret) {
280     MS_LOG(WARNING) << "Write data from ddr to file[" << dst << "] failed.";
281     return ret;
282   }
283   if (async) {
284     MS_EXCEPTION_IF_NULL(loadable_mem_);
285     loadable_mem_->swap_event_.aio_token_ = token;
286   }
287   return ret;
288 }
289 
CopyFileToHost(void * dst,const std::string & src,size_t size,bool async) const290 bool LoadableDeviceAddress::CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const {
291   MS_EXCEPTION_IF_NULL(dst);
292   const auto device_context = GetDeviceContext();
293   MS_EXCEPTION_IF_NULL(device_context);
294   const auto swap_manager = device_context->device_res_manager_->swap_manager();
295   MS_EXCEPTION_IF_NULL(swap_manager);
296   AsyncIOToken token;
297   bool ret = swap_manager->FileToHostMemory(dst, src, size, async, &token);
298   if (!ret) {
299     MS_LOG(WARNING) << "Read data from file[" << src << "] to ddr failed.";
300     return ret;
301   }
302   if (async) {
303     MS_EXCEPTION_IF_NULL(loadable_mem_);
304     loadable_mem_->swap_event_.aio_token_ = token;
305   }
306   return true;
307 }
308 
ReleaseResource()309 void LoadableDeviceAddress::ReleaseResource() {
310   if (loadable_mem_ == nullptr || status_ == DeviceAddressStatus::kInDevice) {
311     return;
312   }
313 
314   const bool need_delete_file =
315     !loadable_mem_->storage_info_.file_name_.empty() && loadable_mem_->storage_info_.file_name_mutable_;
316   const bool need_free_host =
317     loadable_mem_->storage_info_.host_ptr_ != nullptr && loadable_mem_->storage_info_.host_ptr_mutable_;
318   if (need_delete_file || need_free_host) {
319     auto device_context = GetDeviceContext();
320     MS_EXCEPTION_IF_NULL(device_context);
321     const auto swap_manager = device_context->device_res_manager_->swap_manager();
322     MS_EXCEPTION_IF_NULL(swap_manager);
323     if (need_delete_file) {
324       (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
325     }
326     if (need_free_host) {
327       swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
328     }
329   }
330   loadable_mem_ = nullptr;
331 }
332 
GetFileAlignSize() const333 size_t LoadableDeviceAddress::GetFileAlignSize() const {
334   return (GetSize() + kFileAlignSize - 1) / kFileAlignSize * kFileAlignSize;
335 }
336 
GetSwapFileName() const337 std::string LoadableDeviceAddress::GetSwapFileName() const {
338   static size_t swap_file_index = 0;
339   std::string file_dir;
340   const auto &offload_context = OffloadContext::GetInstance();
341   if (offload_context != nullptr) {
342     const auto real_dir = FileUtils::GetRealPath(offload_context->offload_path().c_str());
343     if (!real_dir.has_value()) {
344       MS_LOG(EXCEPTION) << "Invalid offload path[" << offload_context->offload_path()
345                         << "]. Please check offload_path configuration.";
346     }
347     file_dir = real_dir.value() + "/";
348   }
349   return file_dir + std::to_string(device_id()) + "_" + std::to_string(swap_file_index++) + "_" +
350          std::to_string(Common::GetTimeStamp()) + kSwapFileSuffix;
351 }
352 
SetStorageInfo(const StorageInfo & storage_info)353 void LoadableDeviceAddress::SetStorageInfo(const StorageInfo &storage_info) {
354   if (loadable_mem_ == nullptr) {
355     loadable_mem_ = std::make_unique<LoadableMember>();
356   }
357   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
358   loadable_mem_->storage_info_ = storage_info;
359   if (loadable_mem_->storage_info_.host_ptr_ != nullptr) {
360     status_ = DeviceAddressStatus::kInHost;
361     loadable_mem_->storage_info_.host_ptr_mutable_ = false;
362   } else if (!loadable_mem_->storage_info_.file_name_.empty()) {
363     status_ = DeviceAddressStatus::kInFile;
364     loadable_mem_->storage_info_.file_name_mutable_ = false;
365   } else {
366     status_ = DeviceAddressStatus::kInDevice;
367   }
368 }
369 
GetStorageInfo() const370 StorageInfo LoadableDeviceAddress::GetStorageInfo() const {
371   if (loadable_mem_ == nullptr) {
372     loadable_mem_ = std::make_unique<LoadableMember>();
373   }
374   return loadable_mem_->storage_info_;
375 }
376 
Swap(mindspore::device::DeviceAddress * other)377 void LoadableDeviceAddress::Swap(mindspore::device::DeviceAddress *other) {
378   DeviceAddress::Swap(other);
379   if (other == this) {
380     return;
381   }
382   auto loadable_device_address = reinterpret_cast<LoadableDeviceAddress *>(other);
383   if (loadable_device_address != nullptr) {
384     if (loadable_mem_ == nullptr) {
385       loadable_mem_ = std::make_unique<LoadableMember>();
386     }
387     if (loadable_device_address->loadable_mem_ == nullptr) {
388       loadable_device_address->loadable_mem_ = std::make_unique<LoadableMember>();
389     }
390     loadable_device_address->loadable_mem_->storage_info_ = loadable_mem_->storage_info_;
391     loadable_device_address->status_ = status_;
392     loadable_device_address->loadable_mem_->offload_ptr_ = loadable_mem_->offload_ptr_;
393     loadable_device_address->loadable_mem_->mem_offloaded_ = loadable_mem_->mem_offloaded_;
394     loadable_mem_->storage_info_.host_ptr_ = nullptr;
395     loadable_mem_->storage_info_.file_name_ = "";
396     loadable_mem_->storage_info_.host_ptr_mutable_ = true;
397     loadable_mem_->storage_info_.file_name_mutable_ = true;
398     status_ = DeviceAddressStatus::kInDevice;
399     loadable_mem_->offload_ptr_ = nullptr;
400     loadable_mem_->mem_offloaded_ = false;
401   }
402 }
403 
Wait() const404 bool LoadableDeviceAddress::Wait() const {
405   if (loadable_mem_ == nullptr || !loadable_mem_->swap_event_.NeedWait()) {
406     return true;
407   }
408   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
409   const auto device_context = GetDeviceContext();
410   MS_EXCEPTION_IF_NULL(device_context);
411   const auto swap_manager = device_context->device_res_manager_->swap_manager();
412   MS_EXCEPTION_IF_NULL(swap_manager);
413   if (loadable_mem_->swap_event_.device_event_ != nullptr && loadable_mem_->swap_event_.device_event_->NeedWait()) {
414     loadable_mem_->swap_event_.device_event_->WaitEvent();
415   } else if (loadable_mem_->swap_event_.aio_token_ != kInvalidAsyncIOToken) {
416     if (!swap_manager->WaitAsyncIO(loadable_mem_->swap_event_.aio_token_)) {
417       MS_LOG(WARNING) << "Wait aio failed.";
418       return false;
419     }
420   } else {
421     MS_LOG(WARNING) << "Device address is in moving, but no valid swap event can be found.";
422   }
423   if (status_ == DeviceAddressStatus::kInFileToHost) {
424     if (loadable_mem_->storage_info_.file_name_mutable_) {
425       (void)swap_manager->DeleteFile(loadable_mem_->storage_info_.file_name_);
426       loadable_mem_->storage_info_.file_name_ = "";
427     }
428     status_ = DeviceAddressStatus::kInHost;
429   } else if (status_ == DeviceAddressStatus::kInDeviceToHost) {
430     swap_manager->FreeDeviceMemory(GetDevicePtr());
431     status_ = DeviceAddressStatus::kInHost;
432   } else {
433     if (loadable_mem_->storage_info_.host_ptr_mutable_) {
434       swap_manager->FreeHostMemory(loadable_mem_->storage_info_.host_ptr_);
435       loadable_mem_->storage_info_.host_ptr_ = nullptr;
436     }
437     if (status_ == DeviceAddressStatus::kInHostToDevice) {
438       status_ = DeviceAddressStatus::kInHost;
439     } else {
440       status_ = DeviceAddressStatus::kInFile;
441     }
442   }
443   return true;
444 }
445 
SetOffloadPtr(void * offload_ptr)446 void LoadableDeviceAddress::SetOffloadPtr(void *offload_ptr) {
447   if (loadable_mem_ == nullptr) {
448     loadable_mem_ = std::make_unique<LoadableMember>();
449   }
450   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
451   loadable_mem_->offload_ptr_ = offload_ptr;
452   loadable_mem_->mem_offloaded_ = (offload_ptr != nullptr);
453 }
454 
GetOffloadPtr() const455 void *LoadableDeviceAddress::GetOffloadPtr() const {
456   MS_EXCEPTION_IF_NULL(loadable_mem_);
457   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
458   return loadable_mem_->offload_ptr_;
459 }
460 
461 // Return whether DeviceAddress has a valid ptr.
IsPtrValid() const462 bool LoadableDeviceAddress::IsPtrValid() const {
463   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
464   return GetDevicePtr() != nullptr || (loadable_mem_ != nullptr && (loadable_mem_->offload_ptr_ != nullptr ||
465                                                                     loadable_mem_->storage_info_.host_ptr_ != nullptr ||
466                                                                     !loadable_mem_->storage_info_.file_name_.empty()));
467 }
468 
469 // Load first if data is offloaded and return the device ptr.
GetValidPtr(size_t stream_id)470 void *LoadableDeviceAddress::GetValidPtr(size_t stream_id) {
471   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
472   if (mem_offloaded() && !Load(stream_id)) {
473     MS_LOG(EXCEPTION) << "Load offloaded memory failed";
474   }
475   if (!MoveToDevice(false)) {
476     MS_LOG(ERROR) << "Move data to device failed.";
477     return nullptr;
478   }
479 
480   return DeviceAddress::GetValidPtr(stream_id);
481 }
482 }  // namespace device
483 }  // namespace mindspore
484