• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_TRACKER_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_TRACKER_H_
19 #include <mutex>
20 #include <vector>
21 #include <string>
22 #include <map>
23 #include <utility>
24 #include <memory>
25 #include "utils/ms_context.h"
26 #include "utils/ms_utils.h"
27 #include "utils/log_adapter.h"
28 #include "include/backend/visible.h"
29 #include "include/backend/device_address.h"
30 
31 namespace mindspore {
32 namespace device {
33 namespace tracker {
// Category of a tracked memory region.
// NOTE: enumerators take implicit consecutive int values in declaration order;
// when adding one, also add its printable name to MemTypeToStr below.
enum class MemType : int {
  kWeight,
  kConstantValue,
  kKernel,
  kGraphOutput,
  kSomas,
  kInSideSomas,
  kSomasOutput,
  kGeConst,
  kBatchMemory,
  kContinuousMemory,
  kPyNativeInput,
  kPyNativeOutput,
  kGeFeatureMemory,
  kWorkSpace,
  kOther
};

// Printable name for each MemType enumerator (one entry per enumerator).
const std::map<MemType, std::string> MemTypeToStr = {
  {MemType::kWeight, "Weight"},
  {MemType::kConstantValue, "ConstantValue"},
  {MemType::kKernel, "Kernel"},
  {MemType::kGraphOutput, "GraphOutput"},
  {MemType::kSomas, "Somas"},
  {MemType::kInSideSomas, "InSideSomas"},
  {MemType::kSomasOutput, "SomasOutput"},
  {MemType::kGeConst, "GeConst"},
  {MemType::kBatchMemory, "BatchMemory"},
  {MemType::kContinuousMemory, "ContinuousMemory"},
  {MemType::kPyNativeInput, "PyNativeInput"},
  {MemType::kPyNativeOutput, "PyNativeOutput"},
  {MemType::kGeFeatureMemory, "GeFeatureMemory"},
  {MemType::kWorkSpace, "WorkSpace"},
  {MemType::kOther, "Other"},
};

// Device memory and kernel tensors appear in this header only as opaque,
// read-only pointers used as identities / map keys.
using DeviceMemPtr = const void *;
using KernelTensorPtr = const void *;
69 
// One recorded task: which node/graph it belongs to, when it ran (time_stamp),
// and the code location / Python stack that executed it.
// All strings start empty; time_stamp and line_num start at 0.
struct TaskInfo {
  std::string node_name;
  std::string graph_name;
  std::string task_name;
  int64_t time_stamp{0};
  // Code location of the task execution.
  std::string file_name;
  size_t line_num{0};
  std::string python_stack;

  TaskInfo() = default;
};

using TaskInfoPtr = std::shared_ptr<TaskInfo>;
83 
84 struct MemInfo;
85 struct MemBlockInfo {
86   // start and end use the operands of the memory pool
87   int64_t start_time_stamp;
88   int64_t end_time_stamp;
89   DeviceMemPtr device_addr;
90   std::weak_ptr<MemInfo> mem_info;
91   bool is_bind;
92   uint32_t stream_id;
93   size_t actual_peak_memory;
94   size_t size;
95   std::string pool_name;
96 
97   // Record mem info for profiling
98   double real_start_time{-1};
99   double real_end_time{-1};
100   size_t alloc_in_used_size{0};    // Record in used size when allocate mem
101   size_t alloc_total_size{0};      // Record total size when allocate mem
102   size_t release_in_used_size{0};  // Record in used size when release mem
103   size_t release_total_size{0};    // Record total size when release mem
MemBlockInfoMemBlockInfo104   MemBlockInfo()
105       : start_time_stamp(INT64_MAX),
106         end_time_stamp(INT64_MAX),
107         device_addr(nullptr),
108         is_bind(false),
109         stream_id(0),
110         actual_peak_memory(0),
111         size(0),
112         pool_name() {}
113 };
114 
115 using MemBlockInfoPtr = std::shared_ptr<MemBlockInfo>;
116 
117 struct MemInfo {
118   // mem info
119   MemType type;
120   size_t size;
121   KernelTensorPtr kernel_tensor;
122   // producer and user
123   std::vector<TaskInfoPtr> user_tasks;
124   TaskInfoPtr producer_task;
125   // mem block
126   MemBlockInfoPtr mem_block;
127   // Memory application code location
128   std::string file_name;
129   size_t line_num;
MemInfoMemInfo130   MemInfo() : type(MemType::kOther), size(0), kernel_tensor(nullptr), file_name(), line_num(0) {}
131 };
132 
133 using MemInfoPtr = std::shared_ptr<MemInfo>;
134 
// Plain record handed to profiling: one memory block's name/size, its alloc and
// release times, and pool-usage snapshots at both moments.
// Times default to -1, sizes to 0, strings to empty.
struct ProfileMemInfo {
  std::string name;
  size_t size{0};                  // block size, bytes
  double alloc_time{-1};           // allocation time, us
  double release_time{-1};         // release time, us
  size_t alloc_in_used_size{0};    // in-use size at allocation, bytes
  size_t alloc_total_size{0};      // total size at allocation, bytes
  size_t release_in_used_size{0};  // in-use size at release, bytes
  size_t release_total_size{0};    // total size at release, bytes
  std::string device;

  ProfileMemInfo() = default;
};
using ProfileMemInfoPtr = std::shared_ptr<ProfileMemInfo>;
158 
159 class BACKEND_EXPORT MemTracker {
160  public:
161   virtual void AddTask(const std::string &task_name, const std::string &node_name, const std::string &graph_name,
162                        const std::string &file_name, size_t line_num) = 0;
163   virtual void AddMemInfo(const std::string &task_name, MemType type, size_t size, DeviceAddress *device_address,
164                           const std::string &file_name, size_t line_num) = 0;
165   virtual void AddCompileTimeMemInfo(const std::string &task_name, size_t size, DeviceMemPtr device_ptr,
166                                      MemType mem_type, const std::string &file_name, size_t line_num) = 0;
167   virtual void UpdateMemInfo(const DeviceAddress *device_address, MemType mem_type, const std::string &file_name,
168                              size_t line_num) = 0;
169   virtual void AllocMemBlock(DeviceMemPtr device_addr, size_t size, const std::string &pool_name,
170                              size_t actual_peak_memory, size_t in_used_size, size_t total_size, uint32_t stream_id) = 0;
171   virtual void FreeMemBlock(DeviceMemPtr device_addr, size_t in_used_size, size_t total_size) = 0;
172   virtual void UseMemBlock(const std::string &task_name, DeviceMemPtr device_addr, const std::string &file_name,
173                            size_t line_num) = 0;
174   virtual void BindDevicePtr(DeviceAddress *kernel_tensor, DeviceMemPtr device_ptr, const std::string &file_name,
175                              size_t line_num) = 0;
176   virtual void UpdateDevicePtrInfo(DeviceMemPtr device_ptr, MemType mem_type, const std::string &task_name,
177                                    const std::string &file_name, size_t line_num) = 0;
178 
179   virtual void Dump() = 0;
180   virtual void UpdateProfilingPos() = 0;
181   virtual void DumpProfilingMemInfo(const std::string &path, const std::string &file_name) = 0;
182   virtual bool IsEnabled() = 0;
183   virtual ~MemTracker() = default;
184 };
185 
186 class BACKEND_EXPORT MemoryTrackerEnabled : public MemTracker {
187   friend class MemTrackerManager;
188 
189  public:
190   void AddTask(const std::string &task_name, const std::string &node_name, const std::string &graph_name,
191                const std::string &file_name, size_t line_num) override;
192   void AddMemInfo(const std::string &task_name, MemType type, size_t size, DeviceAddress *device_address,
193                   const std::string &file_name, size_t line_num) override;
194   void AddCompileTimeMemInfo(const std::string &task_name, size_t size, DeviceMemPtr device_ptr, MemType mem_type,
195                              const std::string &file_name, size_t line_num) override;
196   void UpdateMemInfo(const DeviceAddress *device_address, MemType mem_type, const std::string &file_name,
197                      size_t line_num) override;
198   void AllocMemBlock(DeviceMemPtr device_addr, size_t size, const std::string &pool_name, size_t actual_peak_memory,
199                      size_t in_used_size, size_t total_size, uint32_t stream_id) override;
200   void FreeMemBlock(DeviceMemPtr device_addr, size_t in_used_size, size_t total_size) override;
201   void UseMemBlock(const std::string &task_name, DeviceMemPtr device_addr, const std::string &file_name,
202                    size_t line_num) override;
203   void BindDevicePtr(DeviceAddress *device_address, DeviceMemPtr device_ptr, const std::string &file_name,
204                      size_t line_num) override;
205   void UpdateDevicePtrInfo(DeviceMemPtr device_ptr, MemType mem_type, const std::string &task_name,
206                            const std::string &file_name, size_t line_num) override;
207   void Dump() override;
208   void UpdateProfilingPos() override;
209   void DumpProfilingMemInfo(const std::string &path, const std::string &file_name) override;
210 
IsEnabled()211   bool IsEnabled() override { return true; }
212   std::pair<std::string, std::string> GetPath();
213   MemoryTrackerEnabled(const MemoryTrackerEnabled &) = delete;
214   MemoryTrackerEnabled &operator=(const MemoryTrackerEnabled &) = delete;
215 
216  private:
217   MemoryTrackerEnabled() = default;
218   ~MemoryTrackerEnabled() override = default;
WithPythonStack()219   bool WithPythonStack() {
220     static bool is_pynative = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
221     // PythonStack is no need in graph mode.
222     return is_pynative;
223   }
224 
225   MemInfoPtr NewMemInfo(const std::string &task_name, MemType type, size_t size, KernelTensorPtr kernel_tensor,
226                         const std::string &file_name, size_t line_num);
227 
228   void AddMemInfoForKernelTensor(const std::string &task_name, MemType type, size_t size, KernelTensorPtr kernel_tensor,
229                                  const std::string &file_name, size_t line_num);
230   std::mutex mutex_;
231   int64_t time_stamp_ = 0;
232   size_t last_profiling_pos_{0};  // Prevent the same data from being dumped.
233   // for dump
234   bool has_dump = false;
235   bool is_init_enable_hccl_ = false;
236   bool enable_hccl_ = false;
237   std::vector<TaskInfoPtr> task_list_;
238   std::vector<MemInfoPtr> mem_info_list_;
239   std::vector<MemBlockInfoPtr> mem_block_list_;
240   // actor name -> task info
241   std::map<std::string, TaskInfoPtr> task_map_;
242   // kernel tensor -> mem info
243   std::map<KernelTensorPtr, MemInfoPtr> kernel_tensor_mem_map;
244   // device address -> mem info
245   std::map<DeviceAddress *, MemInfoPtr> device_address_mem_map;
246   // device addr -> mem block info
247   std::map<DeviceMemPtr, MemBlockInfoPtr> device_mem_block_map;  // for somas
248   std::map<DeviceMemPtr, MemBlockInfoPtr> real_device_mem_block_map;
getInstance()249   static MemoryTrackerEnabled &getInstance() {
250     static MemoryTrackerEnabled instance;
251     return instance;
252   }
253 };
254 
255 class BACKEND_EXPORT MemoryTrackerDisabled : public MemTracker {
256   friend class MemTrackerManager;
257 
258  public:
259   // mock
AddTask(const std::string & task_name,const std::string & node_name,const std::string & graph_name,const std::string & file_name,size_t line_num)260   void AddTask(const std::string &task_name, const std::string &node_name, const std::string &graph_name,
261                const std::string &file_name, size_t line_num) override {}
AddMemInfo(const std::string & task_name,MemType type,size_t size,DeviceAddress * device_address,const std::string & file_name,const size_t line_num)262   void AddMemInfo(const std::string &task_name, MemType type, size_t size, DeviceAddress *device_address,
263                   const std::string &file_name, const size_t line_num) override {}
AddCompileTimeMemInfo(const std::string & task_name,size_t size,DeviceMemPtr device_ptr,MemType mem_type,const std::string & file_name,size_t line_num)264   void AddCompileTimeMemInfo(const std::string &task_name, size_t size, DeviceMemPtr device_ptr, MemType mem_type,
265                              const std::string &file_name, size_t line_num) override {}
UpdateMemInfo(const DeviceAddress * device_address,MemType mem_type,const std::string & file_name,size_t line_num)266   void UpdateMemInfo(const DeviceAddress *device_address, MemType mem_type, const std::string &file_name,
267                      size_t line_num) override {}
AllocMemBlock(DeviceMemPtr device_addr,size_t size,const std::string & pool_name,size_t actual_peak_memory,size_t in_used_size,size_t total_size,uint32_t stream_id)268   void AllocMemBlock(DeviceMemPtr device_addr, size_t size, const std::string &pool_name, size_t actual_peak_memory,
269                      size_t in_used_size, size_t total_size, uint32_t stream_id) override {}
FreeMemBlock(DeviceMemPtr device_addr,size_t in_used_size,size_t total_size)270   void FreeMemBlock(DeviceMemPtr device_addr, size_t in_used_size, size_t total_size) override {}
UseMemBlock(const std::string & task_name,DeviceMemPtr device_addr,const std::string & file_name,size_t line_num)271   void UseMemBlock(const std::string &task_name, DeviceMemPtr device_addr, const std::string &file_name,
272                    size_t line_num) override {}
BindDevicePtr(DeviceAddress * device_address,DeviceMemPtr device_ptr,const std::string & file_name,size_t line_num)273   void BindDevicePtr(DeviceAddress *device_address, DeviceMemPtr device_ptr, const std::string &file_name,
274                      size_t line_num) override {}
UpdateDevicePtrInfo(DeviceMemPtr device_ptr,MemType mem_type,const std::string & task_name,const std::string & file_name,size_t line_num)275   void UpdateDevicePtrInfo(DeviceMemPtr device_ptr, MemType mem_type, const std::string &task_name,
276                            const std::string &file_name, size_t line_num) override {}
Dump()277   void Dump() override {}
UpdateProfilingPos()278   void UpdateProfilingPos() override {}
DumpProfilingMemInfo(const std::string & path,const std::string & file_name)279   void DumpProfilingMemInfo(const std::string &path, const std::string &file_name) {}
IsEnabled()280   bool IsEnabled() override { return false; }
281   MemoryTrackerDisabled(const MemoryTrackerDisabled &) = delete;
282   MemoryTrackerDisabled &operator=(const MemoryTrackerDisabled &) = delete;
283 
284  private:
285   MemoryTrackerDisabled() = default;
286   ~MemoryTrackerDisabled() override = default;
getInstance()287   static MemoryTrackerDisabled &getInstance() {
288     static MemoryTrackerDisabled instance;
289     return instance;
290   }
291 };
292 
293 class BACKEND_EXPORT MemTrackerManager {
294  public:
GetInstance()295   static MemTracker &GetInstance() {
296     static bool enable_trace_mem = common::IsEnableAlllocConfig(common::kAllocMemoryTracker);
297     if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PROF_MEM) || enable_trace_mem) {
298       return MemoryTrackerEnabled::getInstance();
299     } else {
300       return MemoryTrackerDisabled::getInstance();
301     }
302   }
303 };
// Forwards to MemTrackerManager::GetInstance().func(...), appending the caller's code
// location (FILE_NAME, __LINE__) as the last two arguments. FILE_NAME is not defined in
// this header; it is assumed to come from an included header (TODO confirm).
// NOTE: MemTrackerManager is deliberately left unqualified. Macros ignore namespaces, so
// the expansion site must be able to resolve the name: either inside
// mindspore::device::tracker or by prefixing the invocation with a qualifier, e.g.
// `device::tracker::CALL_MEMORY_TRACKER(...)`. Fully qualifying it here would break
// the prefixed form.
#define CALL_MEMORY_TRACKER_WITH_FILE(func, ...) MemTrackerManager::GetInstance().func(__VA_ARGS__, FILE_NAME, __LINE__)
// Same as above, without appending a code location.
#define CALL_MEMORY_TRACKER(func, ...) MemTrackerManager::GetInstance().func(__VA_ARGS__)
306 }  // namespace tracker
307 }  // namespace device
308 }  // namespace mindspore
309 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_TRACKER_H_
310