• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <map>
23 #include <queue>
24 #include "backend/optimizer/mem_reuse/mem_reuse.h"
25 #include "backend/optimizer/somas/somas.h"
26 #include "runtime/device/memory_scheduler.h"
27 namespace mindspore {
28 namespace device {
29 enum MemType { kStaticMem, kDynamicMem, kSomasReuseDynamicMem };
30 constexpr int kGetAllOuts = -1;
31 constexpr uint64_t kMemAlignSize = 512;
32 constexpr uint64_t kTwiceMemAlignSize = kMemAlignSize << 1;
33 using SomasPtr = mindspore::somas::SomasPtr;
34 
35 class MemoryManager : public MemHandler {
36  public:
37   MemoryManager() = default;
38   virtual ~MemoryManager() = default;
39 
40   virtual void MallocDeviceMemory() = 0;
41   virtual void FreeDeviceMemory() = 0;
ResetDynamicMemory()42   virtual void ResetDynamicMemory() {
43     total_dynamic_size_ = 0;
44     dynamic_mem_offset_ = 0;
45   }
ClearGlobalIdleMem()46   virtual void ClearGlobalIdleMem() {}
47 
48   virtual void MallocSomasDynamicMem(const session::KernelGraph &graph);
49   uint8_t *MallocOutputMem(const AnfNodePtr &node, size_t index, MemType type, size_t size,
50                            const DeviceAddressPtr &address, bool comm_mem);
51   uint8_t *MallocWorkSpaceMem(const AnfNodePtr &node, size_t index, MemType type, size_t size);
52   virtual uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address,
53                              uint32_t graph_id = kInvalidGraphId);
54 
55   virtual bool MallocMemFromMemPool(const DeviceAddressPtr address, size_t size);
56   virtual void *MallocMemFromMemPool(size_t size);
MallocCommunicationMemFromMemPool(size_t size)57   virtual uint8_t *MallocCommunicationMemFromMemPool(size_t size) { return nullptr; }
58   virtual void FreeMemFromMemPool(const DeviceAddressPtr address);
59   virtual void FreeMemFromMemPool(void *device_ptr);
60   virtual bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
61                                               std::vector<size_t> size_list);
62   virtual std::vector<void *> MallocContinuousMemFromMemPool(size_t total_size, std::vector<size_t> size_list);
63 
64   static size_t GetCommonAlignSize(size_t input_size);
65   static size_t GetCommunicationAlignSize(size_t input_size);
66 
67   // swap manager interface
MallocDevice(size_t mem_size)68   void *MallocDevice(size_t mem_size) override { return MallocMemFromMemPool(mem_size); }
FreeDevice(void * ptr)69   void FreeDevice(void *ptr) override {
70     MS_EXCEPTION_IF_NULL(ptr);
71     FreeMemFromMemPool(ptr);
72   }
MallocHost(size_t mem_size)73   void *MallocHost(size_t mem_size) override {
74     auto &mem_que = cached_host_mem_[mem_size];
75     if (!mem_que.empty()) {
76       auto ret = mem_que.front();
77       mem_que.pop();
78       return ret;
79     }
80     auto block = std::make_shared<std::vector<uint8_t>>();
81     try {
82       block->resize(mem_size, 0);
83       auto ptr = block->data();
84       host_mem_block_map_[ptr] = block;
85       return ptr;
86     } catch (const std::exception &e) {
87       MS_LOG(EXCEPTION) << "Malloc memory failed: size " << mem_size;
88     }
89   }
FreeHost(void * ptr)90   void FreeHost(void *ptr) override {
91     MS_EXCEPTION_IF_NULL(ptr);
92     auto iter = host_mem_block_map_.find(ptr);
93     if (iter == host_mem_block_map_.end()) {
94       MS_LOG(ERROR) << "Free ptr not be created from manager!";
95     }
96     auto mem_size = iter->second->size();
97     cached_host_mem_[mem_size].emplace(iter->first);
98   }
SwapIn(const void * host_ptr,void * device_ptr,size_t mem_size,void * stream)99   void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override {
100     MS_LOG(INFO) << "Call default swap in " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
101   }
SwapOut(const void * device_ptr,void * host_ptr,size_t mem_size,void * stream)102   void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override {
103     MS_LOG(INFO) << "Call default swap out " << host_ptr << "," << device_ptr << "," << mem_size << "," << stream;
104   }
GetAvailableMemSize()105   size_t GetAvailableMemSize() override {
106     MS_LOG(ERROR) << "Return default 0 mem size!";
107     return 0;
108   }
109 
110  protected:
111   virtual uint8_t *MallocStaticMem(size_t size, bool communication_mem, uint32_t graph_id = kInvalidGraphId) = 0;
112   virtual uint8_t *MallocDynamicMem(size_t size, bool communication_mem);
113   uint8_t *device_mem_base_{nullptr};
114   uint64_t device_mem_size_{0};
115   uint64_t dynamic_mem_offset_{0};
116   uint64_t static_mem_offset_{0};
117   size_t total_static_size_ = 0;
118   size_t total_dynamic_size_ = 0;
119   SomasPtr somas_reuse_util_ptr_{nullptr};
120   std::map<size_t, std::queue<void *>> cached_host_mem_;
121   std::map<void *, std::shared_ptr<std::vector<uint8_t>>> host_mem_block_map_;
122 };
123 }  // namespace device
124 }  // namespace mindspore
125 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_MEMORY_MANAGER_H_
126