• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
18 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "include/backend/mem_reuse/mem_dynamic_allocator.h"
#include "include/backend/device_address.h"
#include "runtime/device/gsm/io_handle.h"
#include "runtime/device/gsm/pin_mem_pool.h"
#include "include/backend/kernel_info.h"
#include "include/backend/visible.h"
31 
32 namespace mindspore {
33 namespace device {
34 class SwappableTensorCandidates {
35   using CandidateItem = std::pair<std::weak_ptr<DeviceAddress>, DeviceAddress *>;
36 
37  public:
38   class CandidateIter {
39    public:
40     explicit CandidateIter(SwappableTensorCandidates *candidates);
41     bool IsEnd();
42     void Next();
43     DeviceAddressPtr Get();
44 
45    private:
46     size_t current_size_level_{0};
47     size_t current_candidate_idx_{0};
48     std::vector<std::vector<CandidateItem>> &swappable_tensors_;
49     std::vector<std::queue<size_t>> &null_index_;
50     HashSet<DeviceAddress *> &all_swappable_tensors_;
51   };
52   void Init(size_t size_level_num);
53   DeviceAddressPtr GetLowerBoundCandidate(size_t size);
54   CandidateIter Begin();
55   void Add(const DeviceAddressPtr &candidate);
56 
57  private:
58   size_t GetSizeLevel(size_t size) const;
59 
60   size_t size_level_num_;
61   std::vector<std::vector<CandidateItem>> swappable_tensors_;
62   std::vector<std::queue<size_t>> null_index_;
63   HashSet<DeviceAddress *> all_swappable_tensors_;
64 };
65 
66 class BACKEND_EXPORT SwapManager {
67  public:
68   SwapManager(size_t stream_id, DynamicMemPoolBestFit *device_memory_pool, PinMemPool *pin_mem_pool);
69   ~SwapManager() = default;
70   // Device memory
71   void *AllocDeviceMemory(size_t size, uint32_t stream_id = kDefaultStreamIndex);
72   std::vector<void *> AllocDeviceContinuousMem(const std::vector<size_t> &size_list,
73                                                uint32_t stream_id = kDefaultStreamIndex);
74   void FreeDeviceMemory(void *ptr);
75 
76   // Host memory
77   void *AllocHostMemory(size_t size);
78   void FreeHostMemory(void *ptr);
79 
80   // File
81   bool CreateFile(const std::string &file_name, size_t file_size);
82   bool DeleteFile(const std::string &file_name);
83   bool FileToHostMemory(void *host_memory, const std::string &file_name, size_t byte_num, bool async,
84                         AsyncIOToken *sync_token);
85   bool HostMemoryToFile(const std::string &file_name, const void *data, size_t byte_num, bool async,
86                         AsyncIOToken *sync_token);
87   bool WaitAsyncIO(AsyncIOToken sync_token);
88 
89   // Swapping and swappable tensors
90   void AddSwappableTensor(const DeviceAddressPtr &device_address);
91   void AddSwappingTensor(const DeviceAddress *device_address);
92 
93   void SetSwappableBeforeMemAllocate(const std::vector<DeviceAddress *> &inputs,
94                                      const std::vector<DeviceAddress *> &outputs) const;
95   void SetSwappableBeforeMemFree(const std::vector<DeviceAddress *> &inputs,
96                                  const std::vector<DeviceAddress *> &outputs, const KernelInfo *kernel_info) const;
GetPinMemPool()97   PinMemPool *GetPinMemPool() { return pin_mem_pool_; }
98 
99  private:
100   void *AllocDeviceMemorySimply(const size_t &size, uint32_t stream_id = kDefaultStreamIndex);
101   std::vector<void *> AllocDeviceContinuousMemSimply(const std::vector<size_t> &size_list,
102                                                      uint32_t stream_id = kDefaultStreamIndex);
103   void *AllocHostMemorySimply(const size_t &size, uint32_t /*stream_id*/);
104   bool EnoughFileSpace(const size_t &size, uint32_t /*stream_id*/);
105 
106   template <class Input, class Output>
107   bool TryAllocate(std::queue<const DeviceAddress *> queue, const Input &input, uint32_t stream_id,
108                    Output (SwapManager::*allocate_func)(const Input &, uint32_t),
109                    const std::function<bool(Output)> &success, Output *output);
110   template <class Input, class Output>
111   bool SwapOutTemp(const std::pair<DeviceAddressStatus, StorageType> &swap_type, size_t total_size, const Input &input,
112                    uint32_t stream_id, Output (SwapManager::*allocate_func)(const Input &, uint32_t),
113                    const std::function<bool(Output)> &success, Output *output);
114 
115  private:
116   size_t stream_id_;
117   DynamicMemPoolBestFit *device_memory_pool_;
118   PinMemPool *pin_mem_pool_;
119   size_t max_file_size_{0};
120   size_t current_used_file_size_{0};
121   HashMap<std::string, size_t> file_size_;
122   struct compare {
operatorcompare123     bool operator()(const DeviceAddressPtr &l, const DeviceAddressPtr &r) const { return l->GetSize() < r->GetSize(); }
124   };
125   SwappableTensorCandidates candidates_;
126   const size_t size_level_num_{0};
127   std::mutex swapping_tensors_device_mutex_;
128   std::queue<const DeviceAddress *> swapping_tensors_device_;
129   std::mutex swapping_tensors_host_mutex_;
130   std::queue<const DeviceAddress *> swapping_tensors_host_;
131   std::mutex swapping_tensors_file_mutex_;
132   std::queue<const DeviceAddress *> swapping_tensors_file_;
133   IOHandlePtr io_handle_;
134 };
135 }  // namespace device
136 }  // namespace mindspore
137 
138 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_GSM_SWAP_MANAGER_H_
139