• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
19 
20 #include <memory>
21 #include <string>
22 #include "include/backend/device_address.h"
23 #include "runtime/hardware/device_context.h"
24 #include "runtime/hardware/device_context_manager.h"
25 
26 namespace mindspore {
27 namespace device {
28 struct SwapEvent {
NeedWaitSwapEvent29   bool NeedWait() const {
30     return aio_token_ != kInvalidAsyncIOToken || (device_event_ != nullptr && device_event_->NeedWait());
31   }
32   AsyncIOToken aio_token_{kInvalidAsyncIOToken};
33   std::shared_ptr<DeviceEvent> device_event_{nullptr};
34 };
35 using SwapEventPtr = std::shared_ptr<SwapEvent>;
36 struct LoadableMember {
37   bool mem_offloaded_{false};
38   void *offload_ptr_{nullptr};
39   mutable SwapEvent swap_event_;
40   mutable StorageInfo storage_info_{nullptr};
41   bool swappable_{false};
42 };
43 using LoadableMemberPtr = std::unique_ptr<LoadableMember>;
44 
45 // LoadableDeviceAddress provide the ability to offload data on device to ddr or disk and load it back later.
46 class BACKEND_EXPORT LoadableDeviceAddress : public DeviceAddress {
47  public:
LoadableDeviceAddress(const KernelTensorPtr & kernel_tensor)48   explicit LoadableDeviceAddress(const KernelTensorPtr &kernel_tensor) : DeviceAddress(kernel_tensor) {}
LoadableDeviceAddress(void * ptr,size_t size)49   LoadableDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
LoadableDeviceAddress(void * ptr,size_t size,const string & format,TypeId type_id)50   LoadableDeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id)
51       : DeviceAddress(ptr, size, format, type_id) {}
LoadableDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const KernelWithIndex & node_index)52   LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
53                         const KernelWithIndex &node_index)
54       : DeviceAddress(ptr, size, format, type_id, node_index) {}
LoadableDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const std::string & device_name,uint32_t device_id)55   LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
56                         const std::string &device_name, uint32_t device_id)
57       : DeviceAddress(ptr, size, format, type_id, device_name, device_id) {}
LoadableDeviceAddress(void * ptr,size_t size,const ShapeVector & shape_vector,const Format & format,TypeId type_id,const std::string & device_name,uint32_t device_id,uint32_t stream_id)58   LoadableDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format, TypeId type_id,
59                         const std::string &device_name, uint32_t device_id, uint32_t stream_id)
60       : DeviceAddress(ptr, size, shape_vector, format, type_id, device_name, device_id, stream_id) {}
LoadableDeviceAddress(void * ptr,size_t size,const std::string & device_name,uint32_t device_id)61   LoadableDeviceAddress(void *ptr, size_t size, const std::string &device_name, uint32_t device_id)
62       : DeviceAddress(ptr, size, device_name, device_id) {}
LoadableDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const KernelWithIndex & node_index,const std::string & device_name,uint32_t device_id)63   LoadableDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
64                         const KernelWithIndex &node_index, const std::string &device_name, uint32_t device_id)
65       : DeviceAddress(ptr, size, format, type_id, node_index, device_name, device_id) {}
66 
mem_offloaded()67   bool mem_offloaded() const final {
68     if (loadable_mem_ == nullptr) {
69       return false;
70     }
71     return loadable_mem_->mem_offloaded_;
72   }
73 
74   // Offload data from device to host and free device memory
75   bool Offload(size_t stream_id) final;
76 
77   // Load data from host to device and free host memory
78   bool Load(size_t stream_id) final;
79 
80   // Move data to destination hardware and free resource on source hardware
81   bool MoveTo(StorageType dest, bool async, size_t stream_id) override;
82 
83   bool Wait() const override;
84 
85   void SetStorageInfo(const StorageInfo &storage_info) final;
86   StorageInfo GetStorageInfo() const final;
87 
88   // Set host ptr data offloaded to
89   void SetOffloadPtr(void *offload_ptr) final;
90   // Get offloaded host ptr
91   void *GetOffloadPtr() const final;
92 
93   // Return whether DeviceAddress has a valid ptr.
94   bool IsPtrValid() const final;
95 
96   // Load first if data is offloaded and return the device ptr.
97   void *GetValidPtr(size_t stream_id) final;
98 
99   void Swap(DeviceAddress *other) override;
100 
DeviceToFileDirectly(void * ptr,size_t size,const std::string & file_name,size_t stream_id)101   virtual bool DeviceToFileDirectly(void *ptr, size_t size, const std::string &file_name, size_t stream_id) const {
102     return false;
103   }
104 
FileToDeviceDirectly(void * ptr,size_t size,const std::string & file_name,size_t stream_id)105   virtual bool FileToDeviceDirectly(void *ptr, size_t size, const std::string &file_name, size_t stream_id) const {
106     return false;
107   }
108 
set_swappable(bool swappable)109   void set_swappable(bool swappable) override {
110     if (loadable_mem_ == nullptr) {
111       loadable_mem_ = std::make_unique<LoadableMember>();
112     }
113     loadable_mem_->swappable_ = swappable;
114   }
swappable()115   bool swappable() override {
116     auto swappable = loadable_mem_ == nullptr ? false : loadable_mem_->swappable_;
117     return swappable && !(status_ == DeviceAddressStatus::kInDevice && GetDevicePtr() == nullptr);
118   }
119 
120  protected:
GetDeviceContext()121   DeviceContext *GetDeviceContext() const {
122     DeviceContext *device_context = nullptr;
123     device_context = DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name(), device_id()});
124     return device_context;
125   }
126 
127   bool MoveToDevice(bool async, size_t stream_id = kDefaultStreamIndex) const;
128   bool MoveToHost(bool async, size_t stream_id = kDefaultStreamIndex) const;
129   bool MoveToFile(bool async, size_t stream_id = kDefaultStreamIndex) const;
130 
CopyDeviceToHost(void * dst,const void * src,size_t size,bool async,size_t stream_id)131   virtual bool CopyDeviceToHost(void *dst, const void *src, size_t size, bool async, size_t stream_id) const {
132     return false;
133   }
CopyHostToDevice(void * dst,const void * src,size_t size,bool async,size_t stream_id)134   virtual bool CopyHostToDevice(void *dst, const void *src, size_t size, bool async, size_t stream_id) const {
135     return false;
136   }
137   virtual bool CopyHostToFile(const std::string &dst, const void *src, size_t size, bool async) const;
138   virtual bool CopyFileToHost(void *dst, const std::string &src, size_t size, bool async) const;
139 
140   void ReleaseResource();
141 
142   std::string GetSwapFileName() const;
143   size_t GetFileAlignSize() const;
144 
145   mutable LoadableMemberPtr loadable_mem_{nullptr};
146 };
147 }  // namespace device
148 }  // namespace mindspore
149 
150 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_LOADABLE_DEVICE_ADDRESS_H_
151