• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_MS_CONTEXT_H_
18 #define MINDSPORE_CORE_UTILS_MS_CONTEXT_H_
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <string>
#include <thread>
#include <typeinfo>
#include <vector>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
30 
31 namespace mindspore {
// Backend selection policy. Enumerators are numbered consecutively from 0 (kMsBackendGeOnly) to
// 6 (kMsBackendUnknown); kMsBackendUnknown is the last value.
enum MsBackendPolicy {
  kMsBackendGeOnly = 0,    // 0
  kMsBackendVmOnly,        // 1
  kMsBackendGePrior,       // 2
  kMsBackendVmPrior,       // 3
  kMsBackendMsPrior,       // 4
  kMsBackendBishengPrior,  // 5
  kMsBackendUnknown,       // 6
};
41 
42 enum DumpLevel : int {
43   kIntroductory = 1,
44   kAdvanced,
45   kFully,
46 };
47 
// How much Python syntax the JIT accepts through its fallback path.
enum JitSyntaxLevel : int {
  kStrict = 0,      // JIT Fallback disabled.
  kCompatible = 1,  // JIT Fallback partial enabled for Python basic type only, such as scalar, dict.
  kLax = 2,         // JIT Fallback fully enabled.
};
53 
// Compilation debug level.
enum DebugLevel : int {
  kLevelRelease = 0,  // Used for deployment scenarios, compile performance will be better.
  kLevelDebug = 1,    // For debugging scenarios, compile performance will decrease.
};
58 
// Cell reuse strategy, stored in MsContext::cell_reuse_level_.
enum class CellReuseLevel {
  kNoCellReuse = 0,
  kNoInline = 1,
  kLazyInline = 2,
};
60 
// Execution modes (values stored in MS_CTX_EXECUTION_MODE).
constexpr int kGraphMode = 0;
constexpr int kPynativeMode = 1;

// Device target names.
constexpr char kDeviceUnDefined[] = "DeviceUnDefined";
constexpr char kCPUDevice[] = "CPU";
constexpr char kGPUDevice[] = "GPU";
constexpr char kAscendDevice[] = "Ascend";
constexpr char kAscendVM[] = "AscendVM";
constexpr char kDavinciInferenceDevice[] = "AscendInference";
constexpr char kDavinciMultiGraphInferenceDevice[] = "AscendMultiGraphInference";
constexpr char kGpuInferenceDevice[] = "GpuInference";
constexpr char kDavinciDevice[] = "Davinci";

constexpr char KNpuLog[] = "_npu_log";
constexpr char kTraining[] = "training";

// Default limits.
constexpr unsigned int MAX_CALL_DEPTH_DEFAULT = 1000;
constexpr unsigned int kOpTimeout = 900;

// Optimization levels.
constexpr int kOptimizeO0 = 0;
constexpr int kOptimizeO1 = 1;

// Ascend SoC version strings.
constexpr auto kAscendVersion910 = "ascend910";
constexpr auto kAscendVersion910b = "ascend910b";
constexpr auto kAscendVersion910c = "ascend910c";

// Device targets recognized in this set: CPU, GPU, Ascend and Davinci.
const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice};
// The default max available device memory is 1024GB.
constexpr float kDefaultMaxDeviceMemory = 1024;
// The default memory pool block size is 1.0G.
constexpr float kDefaultMempoolBlockSize = 1.0;
88 
// Enumeration of every MindSpore context parameter, grouped by value type.
// Each group is delimited by MS_CTX_TYPE_<T>_BEGIN / MS_CTX_TYPE_<T>_END, and each group's BEGIN is aliased to the
// previous group's END, so all parameters share one contiguous numbering. MsContext keeps one storage array per group
// and indexes it with (param - MS_CTX_TYPE_<T>_BEGIN). A parameter must therefore only be accessed through the
// accessor of its own group's type. New parameters go before the matching *_END enumerator.
enum MsCtxParam : unsigned {
  // parameter of type bool
  MS_CTX_TYPE_BOOL_BEGIN,
  MS_CTX_CHECK_BPROP_FLAG = MS_CTX_TYPE_BOOL_BEGIN,
  MS_CTX_ENABLE_DUMP,
  MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
  MS_CTX_ENABLE_GPU_SUMMARY,
  MS_CTX_ENABLE_GRAPH_KERNEL,
  MS_CTX_ENABLE_HCCL,
  MS_CTX_ENABLE_LOOP_SINK,
  MS_CTX_ENABLE_PYNATIVE_HOOK,
  MS_CTX_ENABLE_PYNATIVE_INFER,
  MS_CTX_ENABLE_REDUCE_PRECISION,
  MS_CTX_ENABLE_TASK_SINK,
  MS_CTX_IR_FUSION_FLAG,
  MS_CTX_IS_MULTI_GRAPH_SINK,
  MS_CTX_IS_PYNATIVE_GE_INIT,
  MS_CTX_PRECOMPILE_ONLY,
  MS_CTX_ENABLE_PROFILING,
  MS_CTX_ENABLE_PROF_MEM,
  MS_CTX_ENABLE_PARALLEL_SPLIT,
  MS_CTX_ENABLE_INFER_OPT,
  MS_CTX_GRAD_FOR_SCALAR,
  MS_CTX_ENABLE_MINDRT,
  MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE,
  MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE,
  MS_CTX_ENABLE_MEM_OFFLOAD,
  MS_CTX_ENABLE_RECOVERY,
  MS_CTX_ENABLE_GE_HETEROGENOUS,
  MS_CTX_DISABLE_FORMAT_TRANSFORM,
  MS_CTX_RECOMPUTE_COMM_OVERLAP,
  MS_CTX_GRAD_COMM_OVERLAP,
  MS_CTX_RECOMPUTE_ALLGATHER_OVERLAP_FAGRAD,
  MS_CTX_ENABLE_TASK_OPT,
  MS_CTX_ENABLE_GRAD_COMM_OPT,
  MS_CTX_ENABLE_OPT_SHARD_COMM_OPT,
  MS_CTX_INTERLEAVED_MATMUL_COMM,
  MS_CTX_INTERLEAVED_LAYERNORM_COMM,
  MS_CTX_BIAS_ADD_COMM_SWAP,
  MS_CTX_CONV_ALLOW_TF32,
  MS_CTX_MATMUL_ALLOW_TF32,
  MS_CTX_ENABLE_BEGIN_END_INLINE_OPT,
  MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT,
  MS_CTX_ENABLE_FLASH_ATTENTION_LOAD_BALANCE,
  MS_CTX_ENABLE_FUSED_CAST_ADD_OPT,
  MS_CTX_NEED_CKPT,
  MS_CTX_TYPE_BOOL_END,

  // parameter of type int
  MS_CTX_TYPE_INT_BEGIN = MS_CTX_TYPE_BOOL_END,
  MS_CTX_EXECUTION_MODE = MS_CTX_TYPE_INT_BEGIN,
  MS_CTX_MEMORY_OPTIMIZE_LEVEL,
  MS_CTX_SAVE_GRAPHS_FLAG,  // int-typed: read/write through the int accessors, not the bool ones.
  MS_CTX_JIT_SYNTAX_LEVEL,
  MS_CTX_CUR_STEP_NUM,
  MS_CTX_SAVE_CKPT_STEPS,
  MS_CTX_LAST_TRIGGERED_STEP,
  MS_CTX_COMPUTE_COMMUNICATE_FUSION_LEVEL,
  MS_CTX_ENABLE_COMPILE_CACHE,
  MS_CTX_DEBUG_LEVEL,
  MS_CTX_TYPE_INT_END,

  // parameter of type uint32
  MS_CTX_TYPE_UINT32_BEGIN = MS_CTX_TYPE_INT_END,
  MS_CTX_DEVICE_ID = MS_CTX_TYPE_UINT32_BEGIN,
  MS_CTX_RUNTIME_NUM_THREADS,
  MS_CTX_INTER_OP_PARALLEL_NUM,
  MS_CTX_GE_REF,
  MS_CTX_MAX_CALL_DEPTH,
  MS_CTX_TSD_REF,
  MS_CTX_OP_TIMEOUT,
  MS_CTX_TYPE_UINT32_END,

  // parameter of type float
  MS_CTX_TYPE_FLOAT_BEGIN = MS_CTX_TYPE_UINT32_END,
  MS_CTX_MAX_DEVICE_MEMORY = MS_CTX_TYPE_FLOAT_BEGIN,
  MS_CTX_MEMPOOL_BLOCK_SIZE,
  MS_CTX_TYPE_FLOAT_END,

  // parameter of type string
  MS_CTX_TYPE_STRING_BEGIN = MS_CTX_TYPE_FLOAT_END,
  MS_CTX_DEVICE_TARGET = MS_CTX_TYPE_STRING_BEGIN,
  MS_CTX_GRAPH_MEMORY_MAX_SIZE,
  MS_CTX_PRINT_FILE_PATH,
  MS_CTX_PROFILING_OPTIONS,
  MS_CTX_SAVE_DUMP_PATH,
  MS_CTX_SAVE_GRAPHS_PATH,
  MS_CTX_COMPILE_CACHE_PATH,
  MS_CTX_VARIABLE_MEMORY_MAX_SIZE,
  MS_CTX_PYTHON_EXE_PATH,
  MS_CTX_KERNEL_BUILD_SERVER_DIR,
  MS_CTX_ENV_CONFIG_PATH,
  MS_CTX_TUNE_MODE,
  MS_CTX_AOE_TUNE_MODE,
  MS_CTX_AOE_JOB_TYPE,
  MS_CTX_GRAPH_KERNEL_FLAGS,
  MS_CTX_INFER_PRECISION_MODE,  // GPU inference precision mode configured by Serving or Unify API.
  MS_CTX_DETERMINISTIC,
  MS_CTX_PRECISION_MODE,
  MS_CTX_ENABLE_JIT_COMPILE,
  MS_CTX_ATOMIC_CLEAN_POLICY,
  MS_CTX_MATMUL_ALLOW_HF32,
  MS_CTX_CONV_ALLOW_HF32,
  MS_CTX_OP_PRECISION_MODE,
  MS_CTX_GE_OPTIONS,
  MS_CTX_CONV_FPROP_ALGO,
  MS_CTX_CONV_DGRAD_ALGO,
  MS_CTX_CONV_WGRAD_ALGO,
  MS_CTX_PROF_MEM_OUTPUT_PATH,
  MS_CTX_JIT_LEVEL,
  MS_CTX_INFER_BOOST,
  MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD,
  MS_CTX_ENABLE_EXCEPTION_DUMP,
  MS_CTX_TOPO_ORDER,
  MS_CTX_OP_DEBUG_OPTION,
  MS_CTX_TYPE_STRING_END,

  // Number of parameters in each type group; these size MsContext's per-type storage arrays.
  NUM_BOOL_PARAMS = MS_CTX_TYPE_BOOL_END - MS_CTX_TYPE_BOOL_BEGIN,
  NUM_INT_PARAMS = MS_CTX_TYPE_INT_END - MS_CTX_TYPE_INT_BEGIN,
  NUM_UINT32_PARAMS = MS_CTX_TYPE_UINT32_END - MS_CTX_TYPE_UINT32_BEGIN,
  NUM_FLOAT_PARAMS = MS_CTX_TYPE_FLOAT_END - MS_CTX_TYPE_FLOAT_BEGIN,
  NUM_STRING_PARAMS = MS_CTX_TYPE_STRING_END - MS_CTX_TYPE_STRING_BEGIN
};
214 
215 class MS_CORE_API MsContext {
216  public:
217   MsContext(const std::string &policy, const std::string &target);
218   ~MsContext() = default;
219   MsContext(const MsContext &) = delete;
220   MsContext &operator=(const MsContext &) = delete;
221   using DeviceSeter = void (*)(const std::string &device_target);
222   using InitDeviceTargetAndPolicy = void (*)(MsContext *);
223   using LoadPluginError = std::string (*)();
224   using EnvFunc = std::function<void(const std::string &, const std::string &)>;  // device name, library path
225   static std::shared_ptr<MsContext> GetInstance();
226 
227   void SetDeviceId();
228   void Refresh();
229 
230   bool enable_dump_ir() const;
231   std::string GetSaveGraphsPath() const;
232   int GetSaveGraphsLevel() const;
233   bool CanDump(const DumpLevel &level) const;
234   std::string backend_policy() const;
235   bool set_backend_policy(const std::string &policy);
236   std::string ascend_soc_version() const;
237   bool set_ascend_soc_version(const std::string &soc_version);
238   std::string ascend_soc_name() const;
239   void set_ascend_soc_name(const std::string &soc_name);
240   // _comm_helper.py will try to dlopen libhccl.so, and minddata will try to dlopen libdvpp_utils.so. if load ascend
241   // plugin failed on ascend environment, loading above libraries will crush the process.
242   bool IsAscendPluginLoaded() const;
243   void SetDefaultDeviceTarget();
244   void SetDeviceTargetFromInner(const std::string &device_target);
245   void SetDeviceTargetFromUser(const std::string &device_target);
246   bool IsDefaultDeviceTarget() const;
IsSupportDevice(const std::string & device)247   bool IsSupportDevice(const std::string &device) const { return InitFuncMap().find(device) != InitFuncMap().end(); }
248 
249   bool IsEnableInferBoost();
250   void SetMsInternalEnableCustomKernelList();
251   const std::set<std::string> &ms_internal_enable_custom_kernel_list() const;
252 
253   void RegisterSetEnv(const EnvFunc &func);
254   void RegisterCheckEnv(const EnvFunc &func);
255 
256   void SetEnv(const std::string &device);
257   void CheckEnv(const std::string &device);
258 
device_seter(const DeviceSeter & device)259   static void device_seter(const DeviceSeter &device) { seter_ = device; }
260   static void RegisterInitFunc(const std::string &name, InitDeviceTargetAndPolicy func);
261   static void ResisterLoadPluginErrorFunc(LoadPluginError func);
262 
263   template <typename T>
set_param_inner(MsCtxParam,const T &)264   void set_param_inner(MsCtxParam, const T &) {
265     MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
266   }
267 
268   template <typename T>
set_param(MsCtxParam param,const T & value)269   void set_param(MsCtxParam param, const T &value) {
270     CheckReadStatus<T>(param, value);
271     MarkWriteStatus(param);
272     set_param_inner<T>(param, value);
273   }
274 
275   template <typename T>
get_param(MsCtxParam)276   const T &get_param(MsCtxParam) const {
277     MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
278   }
279 
280   template <typename T>
increase_param(MsCtxParam)281   void increase_param(MsCtxParam) {
282     MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
283   }
284 
285   template <typename T>
decrease_param(MsCtxParam)286   void decrease_param(MsCtxParam) {
287     MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
288   }
289 
290   void ChildAfterFork();  // Reset ms context. Only called in child process after fork occurs.
291   bool EnableAoeOnline() const;
292   bool EnableAoeOffline() const;
293 
SetCellReuseLevel(const CellReuseLevel & level)294   void SetCellReuseLevel(const CellReuseLevel &level) { cell_reuse_level_ = level; }
CellReuseLevel()295   enum CellReuseLevel CellReuseLevel() const { return cell_reuse_level_; }
296 
297   void SetJitLevel(const std::string &jit_level) const;
298   std::string GetJitLevel() const;
299   bool IsKByKExecutorMode() const;
300 
GetLoadPluginErrorStr()301   std::string GetLoadPluginErrorStr() const { return load_plugin_error_(); }
302 
set_not_convert_jit(bool not_convert_jit)303   void set_not_convert_jit(bool not_convert_jit) { not_convert_jit_ = not_convert_jit; }
not_convert_jit()304   bool not_convert_jit() { return not_convert_jit_; }
305 
306  private:
307   void RefreshExecutionMode();
308   void RefreshMemoryOffload();
309 
310   void MarkReadStatus(MsCtxParam param) const;   // record status to mutable member params_read_status_
311   void MarkWriteStatus(MsCtxParam param) const;  // record status to mutable member params_write_status_
312   template <typename T>
313   void CheckReadStatus(MsCtxParam param, const T &value) const;
314   bool CheckWriteStatus(MsCtxParam param) const;
315   void SetAscendConfig();
316   void InitBoolTypeDefaultValue();
317   void InitStringTypeDefaultValue();
318   void InitDigitalTypeDefaultValue();
319 
320   static DeviceSeter seter_;
321   static std::shared_ptr<MsContext> inst_context_;
322   static LoadPluginError load_plugin_error_;
323 
324   bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS];
325   int int_params_[MsCtxParam::NUM_INT_PARAMS];
326   uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
327   float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
328   std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
329 
330   mutable std::vector<bool> params_read_status_;
331   mutable std::vector<bool> params_write_status_;
332   MsBackendPolicy backend_policy_;
333   std::string ascend_soc_version_;
334   std::string ascend_soc_name_ = "ascend";
335   bool default_device_target_ = true;
336 
337   EnvFunc set_env_ = nullptr;
338   EnvFunc check_env_ = nullptr;
339 
340   static std::map<std::string, InitDeviceTargetAndPolicy> &InitFuncMap();
341   static std::map<std::string, std::string> &PluginPathMap();
342   enum CellReuseLevel cell_reuse_level_ = CellReuseLevel::kNoCellReuse;
343   bool not_convert_jit_{false};
344 
345   std::optional<bool> enable_infer_boost_ = std::nullopt;
346   std::set<std::string> ms_internal_enable_custom_kernel_list_;
347 };
348 
349 // set method implementation for type bool/int/uint32_t/float/std::string
350 template <>
351 inline void MsContext::set_param_inner<bool>(MsCtxParam param, const bool &value) {
352 #ifdef ENABLE_SECURITY
353   if (param == MS_CTX_SAVE_GRAPHS_FLAG) {
354     MS_EXCEPTION(ValueError) << "The save_graphs is not supported, please without '-s on' and recompile source.";
355   }
356 #endif
357   bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value;
358 }
359 
360 template <>
361 inline void MsContext::set_param_inner<int>(MsCtxParam param, const int &value) {
362   int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value;
363 }
364 
365 template <>
366 inline void MsContext::set_param_inner<uint32_t>(MsCtxParam param, const uint32_t &value) {
367   uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value;
368 }
369 
370 template <>
371 inline void MsContext::set_param_inner<float>(MsCtxParam param, const float &value) {
372   float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value;
373 }
374 
375 template <>
376 inline void MsContext::set_param_inner<std::string>(MsCtxParam param, const std::string &value) {
377 #ifdef ENABLE_SECURITY
378   if (param == MS_CTX_SAVE_GRAPHS_PATH) {
379     MS_EXCEPTION(ValueError) << "The save_graphs is not supported, please without '-s on' and recompile source.";
380   }
381 #endif
382   if (param == MS_CTX_DEVICE_TARGET) {
383     SetDeviceTargetFromUser(value);
384   } else {
385     string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value;
386   }
387 }
388 
389 // get method implementation for type bool/int/uint32_t/float/std::string
390 template <>
391 inline const bool &MsContext::get_param<bool>(MsCtxParam param) const {
392   MarkReadStatus(param);
393   return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN];
394 }
395 
396 template <>
397 inline const int &MsContext::get_param<int>(MsCtxParam param) const {
398   MarkReadStatus(param);
399   return int_params_[param - MS_CTX_TYPE_INT_BEGIN];
400 }
401 
402 template <>
403 inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const {
404   MarkReadStatus(param);
405   return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN];
406 }
407 
408 template <>
409 inline const float &MsContext::get_param<float>(MsCtxParam param) const {
410   MarkReadStatus(param);
411   return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN];
412 }
413 
414 template <>
415 inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const {
416   MarkReadStatus(param);
417   return string_params_[param - MS_CTX_TYPE_STRING_BEGIN];
418 }
419 
420 // increate method implementation for type uint32_t
421 template <>
422 inline void MsContext::increase_param<uint32_t>(MsCtxParam param) {
423   uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++;
424 }
425 
426 // decrease method implementation for type uint32_t
427 template <>
428 inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) {
429   uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--;
430 }
431 
// Registers a device init function with MsContext by defining a registrar class and one object of it at the point of
// expansion. The object's constructor calls MsContext::RegisterInitFunc(name, func). When the macro is expanded at
// namespace scope, the call therefore happens during that translation unit's static-storage initialization.
// `name` is used twice: pasted as a token to form the identifiers name##InitFuncRegister and
// g_##name##_init_func_register, and passed as the value of RegisterInitFunc's first argument. It must therefore be an
// identifier that names a string constant (for example kAscendDevice).
// `func` must be convertible to MsContext::InitDeviceTargetAndPolicy.
#define MSCONTEXT_REGISTER_INIT_FUNC(name, func)                          \
  class name##InitFuncRegister {                                          \
   public:                                                                \
    name##InitFuncRegister() { MsContext::RegisterInitFunc(name, func); } \
  } g_##name##_init_func_register;
437 }  // namespace mindspore
438 
439 #endif  // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_
440