1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_ 19 20 #include <memory> 21 #include <string> 22 #include "include/backend/device_address.h" 23 #include "runtime/hardware/device_context.h" 24 #include "runtime/hardware/device_context_manager.h" 25 26 namespace mindspore { 27 namespace device { 28 struct SwapEvent { NeedWaitSwapEvent29 bool NeedWait() const { 30 return aio_token_ != kInvalidAsyncIOToken || (device_event_ != nullptr && device_event_->NeedWait()); 31 } 32 AsyncIOToken aio_token_{kInvalidAsyncIOToken}; 33 std::shared_ptr<DeviceEvent> device_event_{nullptr}; 34 }; 35 using SwapEventPtr = std::shared_ptr<SwapEvent>; 36 struct LoadableMember { 37 bool mem_offloaded_{false}; 38 void *offload_ptr_{nullptr}; 39 mutable SwapEvent swap_event_; 40 mutable StorageInfo storage_info_{nullptr}; 41 bool swappable_{false}; 42 }; 43 using LoadableMemberPtr = std::unique_ptr<LoadableMember>; 44 45 // LoadableDeviceAddress provide the ability to offload data on device to ddr or disk and load it back later. 46 class BACKEND_EXPORT LoadableDeviceAddress : public DeviceAddress { 47 public: LoadableDeviceAddress(const KernelTensorPtr & kernel_tensor)48 explicit LoadableDeviceAddress(const KernelTensorPtr &kernel_tensor) : DeviceAddress(kernel_tensor) {} LoadableDeviceAddress(void * ptr,size_t size)49 LoadableDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} LoadableDeviceAddress(void * ptr,size_t size,const string & format,TypeId type_id)50 LoadableDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) 51 : DeviceAddress(ptr, size, format, type_id) {} LoadableDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const KernelWithIndex & node_index)52 LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, 53 const KernelWithIndex &node_index) 54 : DeviceAddress(ptr, size, format, type_id, node_index) {} LoadableDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const std::string & device_name,uint32_t device_id)55 LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, 56 const std::string &device_name, uint32_t device_id) 57 : DeviceAddress(ptr, size, format, type_id, device_name, device_id) {} LoadableDeviceAddress(void * ptr,size_t size,const ShapeVector & shape_vector,const Format & format,TypeId type_id,const std::string & device_name,uint32_t device_id,uint32_t stream_id)58 LoadableDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format, TypeId type_id, 59 const std::string &device_name, uint32_t device_id, uint32_t stream_id) 60 : DeviceAddress(ptr, size, shape_vector, format, type_id, device_name, device_id, stream_id) {} LoadableDeviceAddress(void * ptr,size_t size,const std::string & device_name,uint32_t device_id)61 LoadableDeviceAddress(void *ptr, size_t size, const std::string &device_name, uint32_t device_id) 62 : DeviceAddress(ptr, size, device_name, device_id) {} LoadableDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const KernelWithIndex & node_index,const std::string & device_name,uint32_t device_id)63 LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, 64 const KernelWithIndex &node_index, const std::string &device_name, uint32_t device_id) 65 : DeviceAddress(ptr, size, format, type_id, node_index, device_name, device_id) {} 66 mem_offloaded()67 bool mem_offloaded() const final { 68 if (loadable_mem_ == nullptr) { 69 return false; 70 } 71 return loadable_mem_->mem_offloaded_; 72 } 73 74 // Offload data from device to host and free device memory 75 bool Offload(size_t stream_id) final; 76 77 // Load data from host to device and free host memory 78 bool Load(size_t stream_id) final; 79 80 // Move data to destination hardware and free resource on source hardware 81 bool MoveTo(StorageType dest, bool async, size_t stream_id) override; 82 83 bool Wait() const override; 84 85 void SetStorageInfo(const StorageInfo &storage_info) final; 86 StorageInfo GetStorageInfo() const final; 87 88 // Set host ptr data offloaded to 89 void SetOffloadPtr(void *offload_ptr) final; 90 // Get offloaded host ptr 91 void *GetOffloadPtr() const final; 92 93 // Return whether DeviceAddress has a valid ptr. 94 bool IsPtrValid() const final; 95 96 // Load first if data is offloaded and return the device ptr. 97 void *GetValidPtr(size_t stream_id) final; 98 99 void Swap(DeviceAddress *other) override; 100 DeviceToFileDirectly(void * ptr,size_t size,const std::string & file_name,size_t stream_id)101 virtual bool DeviceToFileDirectly(void *ptr, size_t size, const std::string &file_name, size_t stream_id) const { 102 return false; 103 } 104 FileToDeviceDirectly(void * ptr,size_t size,const std::string & file_name,size_t stream_id)105 virtual bool FileToDeviceDirectly(void *ptr, size_t size, const std::string &file_name, size_t stream_id) const { 106 return false; 107 } 108 set_swappable(bool swappable)109 void set_swappable(bool swappable) override { 110 if (loadable_mem_ == nullptr) { 111 loadable_mem_ = std::make_unique<LoadableMember>(); 112 } 113 loadable_mem_->swappable_ = swappable; 114 } swappable()115 bool swappable() override { 116 auto swappable = loadable_mem_ == nullptr ? false : loadable_mem_->swappable_; 117 return swappable && !(status_ == DeviceAddressStatus::kInDevice && GetDevicePtr() == nullptr); 118 } 119 120 protected: GetDeviceContext()121 DeviceContext *GetDeviceContext() const { 122 DeviceContext *device_context = nullptr; 123 device_context = DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name(), device_id()}); 124 return device_context; 125 } 126 127 bool MoveToDevice(bool async, size_t stream_id = kDefaultStreamIndex) const; 128 bool MoveToHost(bool async, size_t stream_id = kDefaultStreamIndex) const; 129 bool MoveToFile(bool async, size_t stream_id = kDefaultStreamIndex) const; 130 CopyDeviceToHost(void * dst,const void * src,size_t size,bool async,size_t stream_id)131 virtual bool CopyDeviceToHost(void *dst, const void *src, size_t size, bool async, size_t stream_id) const { 132 return false; 133 } CopyHostToDevice(void * dst,const void * src,size_t size,bool async,size_t stream_id)134 virtual bool CopyHostToDevice(void *dst, const void *src, size_t size, bool async, size_t stream_id) const { 135 return false; 136 } 137 virtual bool CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const; 138 virtual bool CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const; 139 140 void ReleaseResource(); 141 142 std::string GetSwapFileName() const; 143 size_t GetFileAlignSize() const; 144 145 mutable LoadableMemberPtr loadable_mem_{nullptr}; 146 }; 147 } // namespace device 148 } // namespace mindspore 149 150 #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_ 151