• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEVICE_RES_MANAGER_H_
17 #define MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEVICE_RES_MANAGER_H_
18 
19 #include <vector>
20 #include <memory>
21 #include <string>
22 #include <map>
23 #include <unordered_map>
24 #include "runtime/hardware/device_context.h"
25 #include "utils/ms_context.h"
26 #include "include/transform/graph_ir/types.h"
27 #include "plugin/device/ascend/hal/hardware/ascend_collective_comm_lib.h"
28 #include "plugin/device/ascend/hal/hardware/dummy_ascend_collective_comm_lib.h"
29 #ifdef ENABLE_INTERNAL_KERNELS
30 #include "plugin/device/ascend/hal/hardware/lowlatency_collective_comm_lib.h"
31 #endif
32 #include "plugin/device/cpu/hal/device/cpu_device_address.h"
33 #include "runtime/device/kernel_runtime_manager.h"
34 
35 namespace mindspore {
36 namespace device {
37 namespace ascend {
38 class GeHostAddress : public cpu::CPUDeviceAddress {
39  public:
GeHostAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const std::string & device_name,uint32_t device_id)40   GeHostAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const std::string &device_name,
41                 uint32_t device_id)
42       : CPUDeviceAddress(ptr, size, format, type_id, device_name, device_id) {}
GeHostAddress(const KernelTensorPtr & kernel_tensor)43   explicit GeHostAddress(const KernelTensorPtr &kernel_tensor) : CPUDeviceAddress(kernel_tensor) {}
GetDeviceType()44   DeviceType GetDeviceType() const override { return DeviceType::kAscend; }
45 };
46 
47 class GeDeviceResManager;
48 class GeAllocator : public ::ge::Allocator {
49  public:
GeAllocator(GeDeviceResManager * res_manager)50   explicit GeAllocator(GeDeviceResManager *res_manager) : res_manager_(res_manager) {}
~GeAllocator()51   ~GeAllocator() { res_manager_ = nullptr; }
52   GeAllocator(const GeAllocator &) = delete;
53   GeAllocator &operator=(const GeAllocator &) = delete;
54   ::ge::MemBlock *Malloc(size_t size) override;
55   void Free(::ge::MemBlock *block) override;
56 
57  private:
58   GeDeviceResManager *res_manager_{nullptr};
59 };
60 
61 class GeDeviceResManager : public DeviceResManager {
62  public:
GeDeviceResManager()63   GeDeviceResManager() {}
64   ~GeDeviceResManager() override = default;
65 
66   void Initialize() override;
67 
68   void Destroy() override;
69 
70   std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list,
71                                                uint32_t stream_id = kDefaultStreamIndex) const override;
72 
73   DeviceAddressPtr CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const override;
74   DeviceAddressPtr CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format,
75                                        TypeId type_id, const std::string &device_name, uint32_t device_id,
76                                        uint32_t stream_id) const override;
77 
78   static void CreateSessionAndGraphRunner();
79   void MoveTo(const tensor::TensorPtr &src_tensor, const tensor::TensorPtr &dst_tensor, const std::string &to,
80               bool blocking, bool *return_self) override;
81 
82   bool LoadCollectiveCommLib() override;
83 
84   void ResetStreamAndCtx() override;
85   bool BindDeviceToCurrentThread(bool force_bind) const override;
GetStream()86   void *GetStream() const override {
87     MS_EXCEPTION_IF_NULL(runtime_instance_);
88     return runtime_instance_->compute_stream();
89   }
GetCopyDataStream()90   void *GetCopyDataStream() const {
91     MS_EXCEPTION_IF_NULL(runtime_instance_);
92     return runtime_instance_->copy_data_stream();
93   }
94 
95   // Relevant function to allocate and free device memory of raw ptr.
96   bool AllocateMemory(DeviceAddress *const &address, uint32_t stream_id = UINT32_MAX) const override;
97   void *AllocateMemory(size_t size, uint32_t stream_id = kDefaultStreamIndex) const override;
98   void FreeMemory(void *ptr) const override;
99   void FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs,
100                        const std::vector<size_t> &keep_addr_sizes) const override;
101   void DefragMemory() override;
102 
103   size_t GetMaxUsedMemorySize() const override;
104 
105   // Relevant function to manage memory statistics
106   size_t GetTotalMemStatistics() const override;
107   size_t GetTotalUsedMemStatistics() const override;
108   size_t GetTotalIdleMemStatistics() const override;
109   size_t GetTotalEagerFreeMemStatistics() const override;
110   size_t GetUsedMemPeakStatistics() const override;
111   size_t GetReservedMemPeakStatistics() const override;
112   std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const override;
113   std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const override;
114   std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>> GetCommonMemBlocksInfoStatistics()
115     const override;
116   std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
117   GetPersistentMemBlocksInfoStatistics() const override;
118   void ResetMaxMemoryReserved() const override;
119   void ResetMaxMemoryAllocated() const override;
120 
GetAllocator()121   transform::GeAllocatorPtr GetAllocator() { return std::make_shared<GeAllocator>(this); }
122 
123   void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) override;
124   void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) override;
125 
126   bool CreateStream(size_t *stream_id) const override;
127   bool CreateStreamWithPriority(size_t *stream_id, int32_t priority) const override;
128   size_t QueryStreamSize() const override;
129   std::vector<uint32_t> GetStreamIds() const override;
130   void *GetStream(size_t stream_id) const override;
131   void SetCurrentStreamId(size_t stream_id) override;
132   size_t GetCurrentStreamId() const override;
133   bool QueryStream(size_t stream_id) const override;
134   bool SyncStream(size_t stream_id = 0) const override;
135   bool SyncAllStreams() const override;
136   bool SyncNotDefaultStreams() const override;
137   size_t DefaultStream() const override;
138 
139   DeviceEventPtr CreateRuntimeEvent(bool enable_blocking, bool enable_record_wait);
140   DeviceEventPtr CreateEventWithFlag(bool enable_timing, bool blocking) override;
141 
142   bool single_op_multi_stream_enable() const override;
143   void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) override;
144   // Only used in graph_mode with MS_DISABLE_REF_MODE, delete it when delete MS_DISABLE_REF_MODEF
145   void SetCPUMemManager();
146 
147  private:
148   friend class GeGraphExecutor;
149   static void GeSetContextOptions(const std::shared_ptr<MsContext> &ms_context_ptr, transform::SessionOptions *options);
150   static void GeSetReuseOptions(const std::string &key, size_t num, transform::SessionOptions *options);
151   KernelRuntime *runtime_instance_ = nullptr;
152   // Only used in graph_mode with MS_DISABLE_REF_MODE, delete it when delete MS_DISABLE_REF_MODE
153   bool is_use_cpu_memory_ = false;
154 };
155 }  // namespace ascend
156 }  // namespace device
157 }  // namespace mindspore
158 #endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_ASCEND_GE_DEVICE_RES_MANAGER_H_
159