1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_TRACKER_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_TRACKER_H_ 19 #include <mutex> 20 #include <vector> 21 #include <string> 22 #include <map> 23 #include <utility> 24 #include <memory> 25 #include "utils/ms_context.h" 26 #include "utils/ms_utils.h" 27 #include "utils/log_adapter.h" 28 #include "include/backend/visible.h" 29 #include "include/backend/device_address.h" 30 31 namespace mindspore { 32 namespace device { 33 namespace tracker { 34 enum class MemType : int { 35 kWeight, 36 kConstantValue, 37 kKernel, 38 kGraphOutput, 39 kSomas, 40 kInSideSomas, 41 kSomasOutput, 42 kGeConst, 43 kBatchMemory, 44 kContinuousMemory, 45 kPyNativeInput, 46 kPyNativeOutput, 47 kGeFeatureMemory, 48 kWorkSpace, 49 kOther 50 }; 51 52 const std::map<MemType, std::string> MemTypeToStr = {{MemType::kWeight, "Weight"}, 53 {MemType::kConstantValue, "ConstantValue"}, 54 {MemType::kKernel, "Kernel"}, 55 {MemType::kGraphOutput, "GraphOutput"}, 56 {MemType::kSomas, "Somas"}, 57 {MemType::kInSideSomas, "InSideSomas"}, 58 {MemType::kSomasOutput, "SomasOutput"}, 59 {MemType::kGeConst, "GeConst"}, 60 {MemType::kBatchMemory, "BatchMemory"}, 61 {MemType::kContinuousMemory, "ContinuousMemory"}, 62 {MemType::kPyNativeInput, "PyNativeInput"}, 63 {MemType::kPyNativeOutput, "PyNativeOutput"}, 64 {MemType::kGeFeatureMemory, "GeFeatureMemory"}, 65 {MemType::kWorkSpace, "WorkSpace"}, 66 {MemType::kOther, "Other"}}; 67 using DeviceMemPtr = const void *; 68 using KernelTensorPtr = const void *; 69 70 struct TaskInfo { 71 std::string node_name; 72 std::string graph_name; 73 std::string task_name; 74 int64_t time_stamp; 75 // The code location of task execution 76 std::string file_name; 77 size_t line_num; 78 std::string python_stack; TaskInfoTaskInfo79 TaskInfo() : node_name(), graph_name(), task_name(), time_stamp(0), file_name(), line_num(0) {} 80 }; 81 82 using TaskInfoPtr = std::shared_ptr<TaskInfo>; 83 84 struct MemInfo; 85 struct MemBlockInfo { 86 // start and end use the operands of the memory pool 87 int64_t start_time_stamp; 88 int64_t end_time_stamp; 89 DeviceMemPtr device_addr; 90 std::weak_ptr<MemInfo> mem_info; 91 bool is_bind; 92 uint32_t stream_id; 93 size_t actual_peak_memory; 94 size_t size; 95 std::string pool_name; 96 97 // Record mem info for profiling 98 double real_start_time{-1}; 99 double real_end_time{-1}; 100 size_t alloc_in_used_size{0}; // Record in used size when allocate mem 101 size_t alloc_total_size{0}; // Record total size when allocate mem 102 size_t release_in_used_size{0}; // Record in used size when release mem 103 size_t release_total_size{0}; // Record total size when release mem MemBlockInfoMemBlockInfo104 MemBlockInfo() 105 : start_time_stamp(INT64_MAX), 106 end_time_stamp(INT64_MAX), 107 device_addr(nullptr), 108 is_bind(false), 109 stream_id(0), 110 actual_peak_memory(0), 111 size(0), 112 pool_name() {} 113 }; 114 115 using MemBlockInfoPtr = std::shared_ptr<MemBlockInfo>; 116 117 struct MemInfo { 118 // mem info 119 MemType type; 120 size_t size; 121 KernelTensorPtr kernel_tensor; 122 // producer and user 123 std::vector<TaskInfoPtr> user_tasks; 124 TaskInfoPtr producer_task; 125 // mem block 126 MemBlockInfoPtr mem_block; 127 // Memory application code location 128 std::string file_name; 129 size_t line_num; MemInfoMemInfo130 MemInfo() : type(MemType::kOther), size(0), kernel_tensor(nullptr), file_name(), line_num(0) {} 131 }; 132 133 using MemInfoPtr = std::shared_ptr<MemInfo>; 134 135 // Struct for interaction with profiling 136 struct ProfileMemInfo { 137 std::string name; 138 size_t size; // size of block, B 139 double alloc_time; // alloc time, us 140 double release_time; // release time, us 141 size_t alloc_in_used_size; // Record in used size when allocate mem, B 142 size_t alloc_total_size; // Record total size when allocate mem, B 143 size_t release_in_used_size; // Record in used size when release mem, B 144 size_t release_total_size; // Record total size when release mem, B 145 std::string device; ProfileMemInfoProfileMemInfo146 ProfileMemInfo() 147 : name(), 148 size(0), 149 alloc_time(-1), 150 release_time(-1), 151 alloc_in_used_size(0), 152 alloc_total_size(0), 153 release_in_used_size(0), 154 release_total_size(0), 155 device() {} 156 }; 157 using ProfileMemInfoPtr = std::shared_ptr<ProfileMemInfo>; 158 159 class BACKEND_EXPORT MemTracker { 160 public: 161 virtual void AddTask(const std::string &task_name, const std::string &node_name, const std::string &graph_name, 162 const std::string &file_name, size_t line_num) = 0; 163 virtual void AddMemInfo(const std::string &task_name, MemType type, size_t size, DeviceAddress *device_address, 164 const std::string &file_name, size_t line_num) = 0; 165 virtual void AddCompileTimeMemInfo(const std::string &task_name, size_t size, DeviceMemPtr device_ptr, 166 MemType mem_type, const std::string &file_name, size_t line_num) = 0; 167 virtual void UpdateMemInfo(const DeviceAddress *device_address, MemType mem_type, const std::string &file_name, 168 size_t line_num) = 0; 169 virtual void AllocMemBlock(DeviceMemPtr device_addr, size_t size, const std::string &pool_name, 170 size_t actual_peak_memory, size_t in_used_size, size_t total_size, uint32_t stream_id) = 0; 171 virtual void FreeMemBlock(DeviceMemPtr device_addr, size_t in_used_size, size_t total_size) = 0; 172 virtual void UseMemBlock(const std::string &task_name, DeviceMemPtr device_addr, const std::string &file_name, 173 size_t line_num) = 0; 174 virtual void BindDevicePtr(DeviceAddress *kernel_tensor, DeviceMemPtr device_ptr, const std::string &file_name, 175 size_t line_num) = 0; 176 virtual void UpdateDevicePtrInfo(DeviceMemPtr device_ptr, MemType mem_type, const std::string &task_name, 177 const std::string &file_name, size_t line_num) = 0; 178 179 virtual void Dump() = 0; 180 virtual void UpdateProfilingPos() = 0; 181 virtual void DumpProfilingMemInfo(const std::string &path, const std::string &file_name) = 0; 182 virtual bool IsEnabled() = 0; 183 virtual ~MemTracker() = default; 184 }; 185 186 class BACKEND_EXPORT MemoryTrackerEnabled : public MemTracker { 187 friend class MemTrackerManager; 188 189 public: 190 void AddTask(const std::string &task_name, const std::string &node_name, const std::string &graph_name, 191 const std::string &file_name, size_t line_num) override; 192 void AddMemInfo(const std::string &task_name, MemType type, size_t size, DeviceAddress *device_address, 193 const std::string &file_name, size_t line_num) override; 194 void AddCompileTimeMemInfo(const std::string &task_name, size_t size, DeviceMemPtr device_ptr, MemType mem_type, 195 const std::string &file_name, size_t line_num) override; 196 void UpdateMemInfo(const DeviceAddress *device_address, MemType mem_type, const std::string &file_name, 197 size_t line_num) override; 198 void AllocMemBlock(DeviceMemPtr device_addr, size_t size, const std::string &pool_name, size_t actual_peak_memory, 199 size_t in_used_size, size_t total_size, uint32_t stream_id) override; 200 void FreeMemBlock(DeviceMemPtr device_addr, size_t in_used_size, size_t total_size) override; 201 void UseMemBlock(const std::string &task_name, DeviceMemPtr device_addr, const std::string &file_name, 202 size_t line_num) override; 203 void BindDevicePtr(DeviceAddress *device_address, DeviceMemPtr device_ptr, const std::string &file_name, 204 size_t line_num) override; 205 void UpdateDevicePtrInfo(DeviceMemPtr device_ptr, MemType mem_type, const std::string &task_name, 206 const std::string &file_name, size_t line_num) override; 207 void Dump() override; 208 void UpdateProfilingPos() override; 209 void DumpProfilingMemInfo(const std::string &path, const std::string &file_name) override; 210 IsEnabled()211 bool IsEnabled() override { return true; } 212 std::pair<std::string, std::string> GetPath(); 213 MemoryTrackerEnabled(const MemoryTrackerEnabled &) = delete; 214 MemoryTrackerEnabled &operator=(const MemoryTrackerEnabled &) = delete; 215 216 private: 217 MemoryTrackerEnabled() = default; 218 ~MemoryTrackerEnabled() override = default; WithPythonStack()219 bool WithPythonStack() { 220 static bool is_pynative = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode; 221 // PythonStack is no need in graph mode. 222 return is_pynative; 223 } 224 225 MemInfoPtr NewMemInfo(const std::string &task_name, MemType type, size_t size, KernelTensorPtr kernel_tensor, 226 const std::string &file_name, size_t line_num); 227 228 void AddMemInfoForKernelTensor(const std::string &task_name, MemType type, size_t size, KernelTensorPtr kernel_tensor, 229 const std::string &file_name, size_t line_num); 230 std::mutex mutex_; 231 int64_t time_stamp_ = 0; 232 size_t last_profiling_pos_{0}; // Prevent the same data from being dumped. 233 // for dump 234 bool has_dump = false; 235 bool is_init_enable_hccl_ = false; 236 bool enable_hccl_ = false; 237 std::vector<TaskInfoPtr> task_list_; 238 std::vector<MemInfoPtr> mem_info_list_; 239 std::vector<MemBlockInfoPtr> mem_block_list_; 240 // actor name -> task info 241 std::map<std::string, TaskInfoPtr> task_map_; 242 // kernel tensor -> mem info 243 std::map<KernelTensorPtr, MemInfoPtr> kernel_tensor_mem_map; 244 // device address -> mem info 245 std::map<DeviceAddress *, MemInfoPtr> device_address_mem_map; 246 // device addr -> mem block info 247 std::map<DeviceMemPtr, MemBlockInfoPtr> device_mem_block_map; // for somas 248 std::map<DeviceMemPtr, MemBlockInfoPtr> real_device_mem_block_map; getInstance()249 static MemoryTrackerEnabled &getInstance() { 250 static MemoryTrackerEnabled instance; 251 return instance; 252 } 253 }; 254 255 class BACKEND_EXPORT MemoryTrackerDisabled : public MemTracker { 256 friend class MemTrackerManager; 257 258 public: 259 // mock AddTask(const std::string & task_name,const std::string & node_name,const std::string & graph_name,const std::string & file_name,size_t line_num)260 void AddTask(const std::string &task_name, const std::string &node_name, const std::string &graph_name, 261 const std::string &file_name, size_t line_num) override {} AddMemInfo(const std::string & task_name,MemType type,size_t size,DeviceAddress * device_address,const std::string & file_name,const size_t line_num)262 void AddMemInfo(const std::string &task_name, MemType type, size_t size, DeviceAddress *device_address, 263 const std::string &file_name, const size_t line_num) override {} AddCompileTimeMemInfo(const std::string & task_name,size_t size,DeviceMemPtr device_ptr,MemType mem_type,const std::string & file_name,size_t line_num)264 void AddCompileTimeMemInfo(const std::string &task_name, size_t size, DeviceMemPtr device_ptr, MemType mem_type, 265 const std::string &file_name, size_t line_num) override {} UpdateMemInfo(const DeviceAddress * device_address,MemType mem_type,const std::string & file_name,size_t line_num)266 void UpdateMemInfo(const DeviceAddress *device_address, MemType mem_type, const std::string &file_name, 267 size_t line_num) override {} AllocMemBlock(DeviceMemPtr device_addr,size_t size,const std::string & pool_name,size_t actual_peak_memory,size_t in_used_size,size_t total_size,uint32_t stream_id)268 void AllocMemBlock(DeviceMemPtr device_addr, size_t size, const std::string &pool_name, size_t actual_peak_memory, 269 size_t in_used_size, size_t total_size, uint32_t stream_id) override {} FreeMemBlock(DeviceMemPtr device_addr,size_t in_used_size,size_t total_size)270 void FreeMemBlock(DeviceMemPtr device_addr, size_t in_used_size, size_t total_size) override {} UseMemBlock(const std::string & task_name,DeviceMemPtr device_addr,const std::string & file_name,size_t line_num)271 void UseMemBlock(const std::string &task_name, DeviceMemPtr device_addr, const std::string &file_name, 272 size_t line_num) override {} BindDevicePtr(DeviceAddress * device_address,DeviceMemPtr device_ptr,const std::string & file_name,size_t line_num)273 void BindDevicePtr(DeviceAddress *device_address, DeviceMemPtr device_ptr, const std::string &file_name, 274 size_t line_num) override {} UpdateDevicePtrInfo(DeviceMemPtr device_ptr,MemType mem_type,const std::string & task_name,const std::string & file_name,size_t line_num)275 void UpdateDevicePtrInfo(DeviceMemPtr device_ptr, MemType mem_type, const std::string &task_name, 276 const std::string &file_name, size_t line_num) override {} Dump()277 void Dump() override {} UpdateProfilingPos()278 void UpdateProfilingPos() override {} DumpProfilingMemInfo(const std::string & path,const std::string & file_name)279 void DumpProfilingMemInfo(const std::string &path, const std::string &file_name) {} IsEnabled()280 bool IsEnabled() override { return false; } 281 MemoryTrackerDisabled(const MemoryTrackerDisabled &) = delete; 282 MemoryTrackerDisabled &operator=(const MemoryTrackerDisabled &) = delete; 283 284 private: 285 MemoryTrackerDisabled() = default; 286 ~MemoryTrackerDisabled() override = default; getInstance()287 static MemoryTrackerDisabled &getInstance() { 288 static MemoryTrackerDisabled instance; 289 return instance; 290 } 291 }; 292 293 class BACKEND_EXPORT MemTrackerManager { 294 public: GetInstance()295 static MemTracker &GetInstance() { 296 static bool enable_trace_mem = common::IsEnableAlllocConfig(common::kAllocMemoryTracker); 297 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_PROF_MEM) || enable_trace_mem) { 298 return MemoryTrackerEnabled::getInstance(); 299 } else { 300 return MemoryTrackerDisabled::getInstance(); 301 } 302 } 303 }; 304 #define CALL_MEMORY_TRACKER_WITH_FILE(func, ...) MemTrackerManager::GetInstance().func(__VA_ARGS__, FILE_NAME, __LINE__) 305 #define CALL_MEMORY_TRACKER(func, ...) MemTrackerManager::GetInstance().func(__VA_ARGS__) 306 } // namespace tracker 307 } // namespace device 308 } // namespace mindspore 309 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_MEM_REUSE_MEM_TRACKER_H_ 310