1 /** 2 * Copyright 2019-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ 18 #define MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ 19 #include <thread> 20 #include <memory> 21 #include <map> 22 #include <set> 23 #include <string> 24 #include <functional> 25 #include <mutex> 26 #include <vector> 27 #include <optional> 28 #include "utils/log_adapter.h" 29 #include "utils/ms_utils.h" 30 31 namespace mindspore { 32 enum MsBackendPolicy { 33 kMsBackendGeOnly = 0, 34 kMsBackendVmOnly = 1, 35 kMsBackendGePrior = 2, 36 kMsBackendVmPrior = 3, 37 kMsBackendMsPrior = 4, 38 kMsBackendBishengPrior = 5, 39 kMsBackendUnknown = 6, 40 }; 41 42 enum DumpLevel : int { 43 kIntroductory = 1, 44 kAdvanced, 45 kFully, 46 }; 47 48 enum JitSyntaxLevel : int { 49 kStrict, // JIT Fallback disabled. 50 kCompatible, // JIT Fallback partial enabled for Python basic type only, such as scalar, dict. 51 kLax, // JIT Fallback fully enabled. 52 }; 53 54 enum DebugLevel : int { 55 kLevelRelease, // Used for deployment scenarios, compile performance will be better. 56 kLevelDebug, // For debugging scenarios, compile performance will decrease. 57 }; 58 59 enum class CellReuseLevel { kNoCellReuse, kNoInline, kLazyInline }; 60 61 const int kGraphMode = 0; 62 const int kPynativeMode = 1; 63 64 const char kDeviceUnDefined[] = "DeviceUnDefined"; 65 const char kCPUDevice[] = "CPU"; 66 const char kGPUDevice[] = "GPU"; 67 const char kAscendDevice[] = "Ascend"; 68 const char kAscendVM[] = "AscendVM"; 69 const char kDavinciInferenceDevice[] = "AscendInference"; 70 const char kDavinciMultiGraphInferenceDevice[] = "AscendMultiGraphInference"; 71 const char kGpuInferenceDevice[] = "GpuInference"; 72 const char kDavinciDevice[] = "Davinci"; 73 const char KNpuLog[] = "_npu_log"; 74 const char kTraining[] = "training"; 75 const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; 76 const unsigned int kOpTimeout = 900; 77 const int kOptimizeO0 = 0; 78 const int kOptimizeO1 = 1; 79 constexpr auto kAscendVersion910 = "ascend910"; 80 constexpr auto kAscendVersion910b = "ascend910b"; 81 constexpr auto kAscendVersion910c = "ascend910c"; 82 83 const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; 84 // The default max available device memory is 1024GB. 85 const float kDefaultMaxDeviceMemory = 1024; 86 // The default memory pool block size is 1.0G. 87 const float kDefaultMempoolBlockSize = 1.0; 88 89 // enum definition for MindSpore Context Parameter 90 enum MsCtxParam : unsigned { 91 // parameter of type bool 92 MS_CTX_TYPE_BOOL_BEGIN, 93 MS_CTX_CHECK_BPROP_FLAG = MS_CTX_TYPE_BOOL_BEGIN, 94 MS_CTX_ENABLE_DUMP, 95 MS_CTX_ENABLE_DYNAMIC_MEM_POOL, 96 MS_CTX_ENABLE_GPU_SUMMARY, 97 MS_CTX_ENABLE_GRAPH_KERNEL, 98 MS_CTX_ENABLE_HCCL, 99 MS_CTX_ENABLE_LOOP_SINK, 100 MS_CTX_ENABLE_PYNATIVE_HOOK, 101 MS_CTX_ENABLE_PYNATIVE_INFER, 102 MS_CTX_ENABLE_REDUCE_PRECISION, 103 MS_CTX_ENABLE_TASK_SINK, 104 MS_CTX_IR_FUSION_FLAG, 105 MS_CTX_IS_MULTI_GRAPH_SINK, 106 MS_CTX_IS_PYNATIVE_GE_INIT, 107 MS_CTX_PRECOMPILE_ONLY, 108 MS_CTX_ENABLE_PROFILING, 109 MS_CTX_ENABLE_PROF_MEM, 110 MS_CTX_ENABLE_PARALLEL_SPLIT, 111 MS_CTX_ENABLE_INFER_OPT, 112 MS_CTX_GRAD_FOR_SCALAR, 113 MS_CTX_ENABLE_MINDRT, 114 MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, 115 MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, 116 MS_CTX_ENABLE_MEM_OFFLOAD, 117 MS_CTX_ENABLE_RECOVERY, 118 MS_CTX_ENABLE_GE_HETEROGENOUS, 119 MS_CTX_DISABLE_FORMAT_TRANSFORM, 120 MS_CTX_RECOMPUTE_COMM_OVERLAP, 121 MS_CTX_GRAD_COMM_OVERLAP, 122 MS_CTX_RECOMPUTE_ALLGATHER_OVERLAP_FAGRAD, 123 MS_CTX_ENABLE_TASK_OPT, 124 MS_CTX_ENABLE_GRAD_COMM_OPT, 125 MS_CTX_ENABLE_OPT_SHARD_COMM_OPT, 126 MS_CTX_INTERLEAVED_MATMUL_COMM, 127 MS_CTX_INTERLEAVED_LAYERNORM_COMM, 128 MS_CTX_BIAS_ADD_COMM_SWAP, 129 MS_CTX_CONV_ALLOW_TF32, 130 MS_CTX_MATMUL_ALLOW_TF32, 131 MS_CTX_ENABLE_BEGIN_END_INLINE_OPT, 132 MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT, 133 MS_CTX_ENABLE_FLASH_ATTENTION_LOAD_BALANCE, 134 MS_CTX_ENABLE_FUSED_CAST_ADD_OPT, 135 MS_CTX_NEED_CKPT, 136 MS_CTX_TYPE_BOOL_END, 137 138 // parameter of type int 139 MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END, 140 MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN, 141 MS_CTX_MEMORY_OPTIMIZE_LEVEL, 142 MS_CTX_SAVE_GRAPHS_FLAG, 143 MS_CTX_JIT_SYNTAX_LEVEL, 144 MS_CTX_CUR_STEP_NUM, 145 MS_CTX_SAVE_CKPT_STEPS, 146 MS_CTX_LAST_TRIGGERED_STEP, 147 MS_CTX_COMPUTE_COMMUNICATE_FUSION_LEVEL, 148 MS_CTX_ENABLE_COMPILE_CACHE, 149 MS_CTX_DEBUG_LEVEL, 150 MS_CTX_TYPE_INT_END, 151 152 // parameter of type uint32 153 MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END, 154 MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN, 155 MS_CTX_RUNTIME_NUM_THREADS, 156 MS_CTX_INTER_OP_PARALLEL_NUM, 157 MS_CTX_GE_REF, 158 MS_CTX_MAX_CALL_DEPTH, 159 MS_CTX_TSD_REF, 160 MS_CTX_OP_TIMEOUT, 161 MS_CTX_TYPE_UINT32_END, 162 163 // parameter of type float 164 MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END, 165 MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN, 166 MS_CTX_MEMPOOL_BLOCK_SIZE, 167 MS_CTX_TYPE_FLOAT_END, 168 169 // parameter of type string 170 MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END, 171 MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN, 172 MS_CTX_GRAPH_MEMORY_MAX_SIZE, 173 MS_CTX_PRINT_FILE_PATH, 174 MS_CTX_PROFILING_OPTIONS, 175 MS_CTX_SAVE_DUMP_PATH, 176 MS_CTX_SAVE_GRAPHS_PATH, 177 MS_CTX_COMPILE_CACHE_PATH, 178 MS_CTX_VARIABLE_MEMORY_MAX_SIZE, 179 MS_CTX_PYTHON_EXE_PATH, 180 MS_CTX_KERNEL_BUILD_SERVER_DIR, 181 MS_CTX_ENV_CONFIG_PATH, 182 MS_CTX_TUNE_MODE, 183 MS_CTX_AOE_TUNE_MODE, 184 MS_CTX_AOE_JOB_TYPE, 185 MS_CTX_GRAPH_KERNEL_FLAGS, 186 MS_CTX_INFER_PRECISION_MODE, // GPU inference precision mode configured by Serving or Unify API. 187 MS_CTX_DETERMINISTIC, 188 MS_CTX_PRECISION_MODE, 189 MS_CTX_ENABLE_JIT_COMPILE, 190 MS_CTX_ATOMIC_CLEAN_POLICY, 191 MS_CTX_MATMUL_ALLOW_HF32, 192 MS_CTX_CONV_ALLOW_HF32, 193 MS_CTX_OP_PRECISION_MODE, 194 MS_CTX_GE_OPTIONS, 195 MS_CTX_CONV_FPROP_ALGO, 196 MS_CTX_CONV_DGRAD_ALGO, 197 MS_CTX_CONV_WGRAD_ALGO, 198 MS_CTX_PROF_MEM_OUTPUT_PATH, 199 MS_CTX_JIT_LEVEL, 200 MS_CTX_INFER_BOOST, 201 MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD, 202 MS_CTX_ENABLE_EXCEPTION_DUMP, 203 MS_CTX_TOPO_ORDER, 204 MS_CTX_OP_DEBUG_OPTION, 205 MS_CTX_TYPE_STRING_END, 206 207 // parameter numbers of each type 208 NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN, 209 NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN, 210 NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN, 211 NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN, 212 NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN 213 }; 214 215 class MS_CORE_API MsContext { 216 public: 217 MsContext(const std::string &policy, const std::string &target); 218 ~MsContext() = default; 219 MsContext(const MsContext &) = delete; 220 MsContext &operator=(const MsContext &) = delete; 221 using DeviceSeter = void (*)(const std::string &device_target); 222 using InitDeviceTargetAndPolicy = void (*)(MsContext *); 223 using LoadPluginError = std::string (*)(); 224 using EnvFunc = std::function<void(const std::string &, const std::string &)>; // device name, library path 225 static std::shared_ptr<MsContext> GetInstance(); 226 227 void SetDeviceId(); 228 void Refresh(); 229 230 bool enable_dump_ir() const; 231 std::string GetSaveGraphsPath() const; 232 int GetSaveGraphsLevel() const; 233 bool CanDump(const DumpLevel &level) const; 234 std::string backend_policy() const; 235 bool set_backend_policy(const std::string &policy); 236 std::string ascend_soc_version() const; 237 bool set_ascend_soc_version(const std::string &soc_version); 238 std::string ascend_soc_name() const; 239 void set_ascend_soc_name(const std::string &soc_name); 240 // _comm_helper.py will try to dlopen libhccl.so, and minddata will try to dlopen libdvpp_utils.so. if load ascend 241 // plugin failed on ascend environment, loading above libraries will crush the process. 242 bool IsAscendPluginLoaded() const; 243 void SetDefaultDeviceTarget(); 244 void SetDeviceTargetFromInner(const std::string &device_target); 245 void SetDeviceTargetFromUser(const std::string &device_target); 246 bool IsDefaultDeviceTarget() const; IsSupportDevice(const std::string & device)247 bool IsSupportDevice(const std::string &device) const { return InitFuncMap().find(device) != InitFuncMap().end(); } 248 249 bool IsEnableInferBoost(); 250 void SetMsInternalEnableCustomKernelList(); 251 const std::set<std::string> &ms_internal_enable_custom_kernel_list() const; 252 253 void RegisterSetEnv(const EnvFunc &func); 254 void RegisterCheckEnv(const EnvFunc &func); 255 256 void SetEnv(const std::string &device); 257 void CheckEnv(const std::string &device); 258 device_seter(const DeviceSeter & device)259 static void device_seter(const DeviceSeter &device) { seter_ = device; } 260 static void RegisterInitFunc(const std::string &name, InitDeviceTargetAndPolicy func); 261 static void ResisterLoadPluginErrorFunc(LoadPluginError func); 262 263 template <typename T> set_param_inner(MsCtxParam,const T &)264 void set_param_inner(MsCtxParam, const T &) { 265 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 266 } 267 268 template <typename T> set_param(MsCtxParam param,const T & value)269 void set_param(MsCtxParam param, const T &value) { 270 CheckReadStatus<T>(param, value); 271 MarkWriteStatus(param); 272 set_param_inner<T>(param, value); 273 } 274 275 template <typename T> get_param(MsCtxParam)276 const T &get_param(MsCtxParam) const { 277 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 278 } 279 280 template <typename T> increase_param(MsCtxParam)281 void increase_param(MsCtxParam) { 282 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 283 } 284 285 template <typename T> decrease_param(MsCtxParam)286 void decrease_param(MsCtxParam) { 287 MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; 288 } 289 290 void ChildAfterFork(); // Reset ms context. Only called in child process after fork occurs. 291 bool EnableAoeOnline() const; 292 bool EnableAoeOffline() const; 293 SetCellReuseLevel(const CellReuseLevel & level)294 void SetCellReuseLevel(const CellReuseLevel &level) { cell_reuse_level_ = level; } CellReuseLevel()295 enum CellReuseLevel CellReuseLevel() const { return cell_reuse_level_; } 296 297 void SetJitLevel(const std::string &jit_level) const; 298 std::string GetJitLevel() const; 299 bool IsKByKExecutorMode() const; 300 GetLoadPluginErrorStr()301 std::string GetLoadPluginErrorStr() const { return load_plugin_error_(); } 302 set_not_convert_jit(bool not_convert_jit)303 void set_not_convert_jit(bool not_convert_jit) { not_convert_jit_ = not_convert_jit; } not_convert_jit()304 bool not_convert_jit() { return not_convert_jit_; } 305 306 private: 307 void RefreshExecutionMode(); 308 void RefreshMemoryOffload(); 309 310 void MarkReadStatus(MsCtxParam param) const; // record status to mutable member params_read_status_ 311 void MarkWriteStatus(MsCtxParam param) const; // record status to mutable member params_write_status_ 312 template <typename T> 313 void CheckReadStatus(MsCtxParam param, const T &value) const; 314 bool CheckWriteStatus(MsCtxParam param) const; 315 void SetAscendConfig(); 316 void InitBoolTypeDefaultValue(); 317 void InitStringTypeDefaultValue(); 318 void InitDigitalTypeDefaultValue(); 319 320 static DeviceSeter seter_; 321 static std::shared_ptr<MsContext> inst_context_; 322 static LoadPluginError load_plugin_error_; 323 324 bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS]; 325 int int_params_[MsCtxParam::NUM_INT_PARAMS]; 326 uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS]; 327 float float_params_[MsCtxParam::NUM_FLOAT_PARAMS]; 328 std::string string_params_[MsCtxParam::NUM_STRING_PARAMS]; 329 330 mutable std::vector<bool> params_read_status_; 331 mutable std::vector<bool> params_write_status_; 332 MsBackendPolicy backend_policy_; 333 std::string ascend_soc_version_; 334 std::string ascend_soc_name_ = "ascend"; 335 bool default_device_target_ = true; 336 337 EnvFunc set_env_ = nullptr; 338 EnvFunc check_env_ = nullptr; 339 340 static std::map<std::string, InitDeviceTargetAndPolicy> &InitFuncMap(); 341 static std::map<std::string, std::string> &PluginPathMap(); 342 enum CellReuseLevel cell_reuse_level_ = CellReuseLevel::kNoCellReuse; 343 bool not_convert_jit_{false}; 344 345 std::optional<bool> enable_infer_boost_ = std::nullopt; 346 std::set<std::string> ms_internal_enable_custom_kernel_list_; 347 }; 348 349 // set method implementation for type bool/int/uint32_t/float/std::string 350 template <> 351 inline void MsContext::set_param_inner<bool>(MsCtxParam param, const bool &value) { 352 #ifdef ENABLE_SECURITY 353 if (param == MS_CTX_SAVE_GRAPHS_FLAG) { 354 MS_EXCEPTION(ValueError) << "The save_graphs is not supported, please without '-s on' and recompile source."; 355 } 356 #endif 357 bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value; 358 } 359 360 template <> 361 inline void MsContext::set_param_inner<int>(MsCtxParam param, const int &value) { 362 int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value; 363 } 364 365 template <> 366 inline void MsContext::set_param_inner<uint32_t>(MsCtxParam param, const uint32_t &value) { 367 uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value; 368 } 369 370 template <> 371 inline void MsContext::set_param_inner<float>(MsCtxParam param, const float &value) { 372 float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value; 373 } 374 375 template <> 376 inline void MsContext::set_param_inner<std::string>(MsCtxParam param, const std::string &value) { 377 #ifdef ENABLE_SECURITY 378 if (param == MS_CTX_SAVE_GRAPHS_PATH) { 379 MS_EXCEPTION(ValueError) << "The save_graphs is not supported, please without '-s on' and recompile source."; 380 } 381 #endif 382 if (param == MS_CTX_DEVICE_TARGET) { 383 SetDeviceTargetFromUser(value); 384 } else { 385 string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value; 386 } 387 } 388 389 // get method implementation for type bool/int/uint32_t/float/std::string 390 template <> 391 inline const bool &MsContext::get_param<bool>(MsCtxParam param) const { 392 MarkReadStatus(param); 393 return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN]; 394 } 395 396 template <> 397 inline const int &MsContext::get_param<int>(MsCtxParam param) const { 398 MarkReadStatus(param); 399 return int_params_[param - MS_CTX_TYPE_INT_BEGIN]; 400 } 401 402 template <> 403 inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const { 404 MarkReadStatus(param); 405 return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]; 406 } 407 408 template <> 409 inline const float &MsContext::get_param<float>(MsCtxParam param) const { 410 MarkReadStatus(param); 411 return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN]; 412 } 413 414 template <> 415 inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const { 416 MarkReadStatus(param); 417 return string_params_[param - MS_CTX_TYPE_STRING_BEGIN]; 418 } 419 420 // increate method implementation for type uint32_t 421 template <> 422 inline void MsContext::increase_param<uint32_t>(MsCtxParam param) { 423 uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++; 424 } 425 426 // decrease method implementation for type uint32_t 427 template <> 428 inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) { 429 uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--; 430 } 431 432 #define MSCONTEXT_REGISTER_INIT_FUNC(name, func) \ 433 class name##InitFuncRegister { \ 434 public: \ 435 name##InitFuncRegister() { MsContext::RegisterInitFunc(name, func); } \ 436 } g_##name##_init_func_register; 437 } // namespace mindspore 438 439 #endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_ 440