1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "utils/ms_context.h"
#include <string>
#include <stdexcept>
#include <thread>
#include <atomic>
#include <fstream>
#include <algorithm>
#include <utility>
#include "utils/ms_utils.h"
#include "include/common/utils/utils.h"
#include "utils/convert_utils_base.h"
#include "utils/phase.h"
28
29 #if defined(_WIN32) || defined(_WIN64)
30 #include <windows.h>
31 #else
32 #include <dlfcn.h>
33 #endif
34
35 namespace mindspore {
namespace {
// Maps the textual backend policy name to its MsBackendPolicy enumerator.
// The map is not const, so a lookup through operator[] with an unknown name would insert a
// value-initialized entry; use find() when the name may be invalid.
std::map<std::string, MsBackendPolicy> kPolicyMap = {{"ge", kMsBackendGePrior},     {"bisheng", kMsBackendBishengPrior},
                                                     {"vm", kMsBackendVmOnly},      {"ms", kMsBackendMsPrior},
                                                     {"ge_only", kMsBackendGeOnly}, {"vm_prior", kMsBackendVmPrior}};

// Size of the init-function map that SetDefaultDeviceTarget() treats as "CPU plus exactly one other device".
constexpr auto kDeviceTargetSize2 = 2;
}  // namespace
// NOTE(review): not referenced anywhere else in this file; presumably used by another translation unit — confirm.
std::atomic<bool> thread_1_must_end(false);

// Definitions of the static members of MsContext; all start out empty.
MsContext::DeviceSeter MsContext::seter_ = nullptr;
MsContext::LoadPluginError MsContext::load_plugin_error_ = nullptr;
std::shared_ptr<MsContext> MsContext::inst_context_ = nullptr;

// Parameters that CheckReadStatus() refuses to change once they have been read, together with
// the user-facing parameter name used in the error message.
std::map<MsCtxParam, std::string> kUnresetParamCheckList = {
  {MsCtxParam::MS_CTX_DEVICE_ID, "device_id"},
  {MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "variable_memory_max_size"},
  {MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY, "max_device_memory"},
  {MsCtxParam::MS_CTX_MEMPOOL_BLOCK_SIZE, "mempool_block_size"}};
54
MsContext(const std::string & policy,const std::string & target)55 MsContext::MsContext(const std::string &policy, const std::string &target) {
56 #ifndef ENABLE_SECURITY
57 set_param<int>(MS_CTX_SAVE_GRAPHS_FLAG, 0);
58 set_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH, ".");
59 set_param<std::string>(MS_CTX_COMPILE_CACHE_PATH, "");
60 #else
61 // Need set a default value for arrays even if running in the security mode.
62 int_params_[MS_CTX_SAVE_GRAPHS_FLAG - MS_CTX_TYPE_BOOL_BEGIN] = 0;
63 string_params_[MS_CTX_SAVE_GRAPHS_PATH - MS_CTX_TYPE_STRING_BEGIN] = ".";
64 #endif
65 InitBoolTypeDefaultValue();
66 InitStringTypeDefaultValue();
67 InitDigitalTypeDefaultValue();
68 MsContext::SetDeviceId();
69 string_params_[MS_CTX_DEVICE_TARGET - MS_CTX_TYPE_STRING_BEGIN] = target;
70 set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, target == kAscendDevice || target == kDavinciDevice);
71
72 backend_policy_ = kPolicyMap[policy];
73 ascend_soc_version_ = "";
74
75 params_read_status_ = std::vector<bool>(
76 static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
77 MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
78 false);
79 params_write_status_ = std::vector<bool>(
80 static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
81 MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
82 false);
83
84 SetAscendConfig();
85 }
86
GetInstance()87 std::shared_ptr<MsContext> MsContext::GetInstance() {
88 static std::once_flag inst_context_init_flag_ = {};
89 std::call_once(inst_context_init_flag_, [&]() {
90 if (inst_context_ == nullptr) {
91 MS_LOG(DEBUG) << "Create new mindspore context";
92 inst_context_ = std::make_shared<MsContext>("vm", kCPUDevice);
93 }
94 });
95 MS_EXCEPTION_IF_NULL(inst_context_);
96 return inst_context_;
97 }
98
SetDeviceId()99 void MsContext::SetDeviceId() {
100 auto env_device = common::GetEnv("DEVICE_ID");
101 if (!env_device.empty()) {
102 try {
103 uint32_t device_id = UlongToUint(std::stoul(env_device));
104 set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
105 } catch (std::invalid_argument &e) {
106 MS_LOG(WARNING) << "Invalid DEVICE_ID env:" << env_device << ". Please set DEVICE_ID to 0-7";
107 set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
108 }
109 } else {
110 set_param<uint32_t>(MS_CTX_DEVICE_ID, 0);
111 }
112 }
113
// Re-derives the dependent settings: first the execution-mode related flags, then the
// memory-offload related ones.
void MsContext::Refresh() {
  RefreshExecutionMode();
  RefreshMemoryOffload();
}
118
RefreshExecutionMode()119 void MsContext::RefreshExecutionMode() {
120 const std::string &target = get_param<std::string>(MS_CTX_DEVICE_TARGET);
121 if (target == kAscendDevice) {
122 if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
123 set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
124 } else if (IsKByKExecutorMode()) {
125 set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
126 }
127 }
128 }
129
RefreshMemoryOffload()130 void MsContext::RefreshMemoryOffload() {
131 const bool enable_mem_offload = get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD);
132 if (!enable_mem_offload) {
133 return;
134 }
135 const std::string &target = get_param<std::string>(MS_CTX_DEVICE_TARGET);
136 if (target == kCPUDevice) {
137 MS_LOG(WARNING) << "Memory offload is not available on CPU device.";
138 set_param(MS_CTX_ENABLE_MEM_OFFLOAD, false);
139 return;
140 }
141 if (target == kAscendDevice && get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && !IsKByKExecutorMode()) {
142 MS_LOG(WARNING) << "Run graph mode with kernel by kernel because memory offload is ON.";
143 set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
144 return;
145 }
146 if (get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) == kOptimizeO1) {
147 MS_LOG(WARNING) << "Memory offload is not available when memory_optimize_level is set to O1.";
148 set_param(MS_CTX_ENABLE_MEM_OFFLOAD, false);
149 return;
150 }
151 MS_LOG(INFO) << "Set memory pool block size to max device memory size for memory offload.";
152 set_param_inner(MS_CTX_MEMPOOL_BLOCK_SIZE, get_param<float>(MS_CTX_MAX_DEVICE_MEMORY));
153 }
154
set_backend_policy(const std::string & policy)155 bool MsContext::set_backend_policy(const std::string &policy) {
156 auto iter = kPolicyMap.find(policy);
157 if (iter == kPolicyMap.end()) {
158 MS_LOG(ERROR) << "invalid backend policy name: " << policy;
159 return false;
160 }
161 backend_policy_ = iter->second;
162 MS_LOG(INFO) << "ms set context backend policy:" << policy;
163 return true;
164 }
165
backend_policy() const166 std::string MsContext::backend_policy() const {
167 auto res = std::find_if(
168 kPolicyMap.begin(), kPolicyMap.end(),
169 [&, this](const std::pair<std::string, MsBackendPolicy> &item) { return item.second == backend_policy_; });
170 if (res != kPolicyMap.end()) {
171 return res->first;
172 }
173 return "unknown";
174 }
175
// Stores the Ascend SoC name; the value is not validated here.
void MsContext::set_ascend_soc_name(const std::string &soc_name) { ascend_soc_name_ = soc_name; }
177
// Returns a copy of the stored Ascend SoC name.
std::string MsContext::ascend_soc_name() const { return ascend_soc_name_; }
179
// Stores the Ascend SoC version string without validation. Always returns true.
bool MsContext::set_ascend_soc_version(const std::string &soc_version) {
  ascend_soc_version_ = soc_version;
  return true;
}
184
// Returns a copy of the stored Ascend SoC version (an empty string right after construction).
std::string MsContext::ascend_soc_version() const { return ascend_soc_version_; }
186
// Reports whether this binary was compiled with ENABLE_DUMP_IR defined (a compile-time answer).
bool MsContext::enable_dump_ir() const {
#ifdef ENABLE_DUMP_IR
  return true;
#else
  return false;
#endif
}
194
// Returns the registry that maps a device name to its init function. The map is a
// function-local static, so it is created on first use and callers may modify it in place.
std::map<std::string, MsContext::InitDeviceTargetAndPolicy> &MsContext::InitFuncMap() {
  static std::map<std::string, InitDeviceTargetAndPolicy> init_func_map = {};
  return init_func_map;
}
199
// Returns the registry that maps a device name to the path of the library that registered it.
// The map is a function-local static, so it is created on first use.
std::map<std::string, std::string> &MsContext::PluginPathMap() {
  static std::map<std::string, std::string> plugin_path_map = {};
  return plugin_path_map;
}
204
// Registers the init function of a device under `name`, refreshes the default device target and
// records the path of the module that contains `func`.
// Raises through MS_LOG(EXCEPTION) when the module containing `func` cannot be determined.
void MsContext::RegisterInitFunc(const std::string &name, MsContext::InitDeviceTargetAndPolicy func) {
  // emplace() keeps an existing registration for `name` untouched.
  (void)InitFuncMap().emplace(name, func);
  if (GetInstance() != nullptr) {
    GetInstance()->SetDefaultDeviceTarget();
  }
  std::string plugin_path;
#if !defined(_WIN32) && !defined(_WIN64)
  // Resolve the shared object that contains the address of `func`.
  Dl_info dl_info;
  if (dladdr(reinterpret_cast<void *>(func), &dl_info) == 0) {
    MS_LOG(EXCEPTION) << "Get dladdr error for " << name;
  }
  plugin_path = dl_info.dli_fname;
#else
  // Windows: find the module that contains the address of `func`, then query its file name.
  HMODULE h_module = nullptr;
  if (GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT | GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS,
                        (LPCSTR)func, &h_module) == 0) {
    MS_LOG(EXCEPTION) << "Get GetModuleHandleEx failed for " << name;
  }
  char sz_path[MAX_PATH];
  if (GetModuleFileName(h_module, sz_path, sizeof(sz_path)) == 0) {
    MS_LOG(EXCEPTION) << "Get GetModuleFileName failed for " << name;
  }
  plugin_path = std::string(sz_path);
#endif
  // As above, an existing path entry for `name` is not overwritten.
  (void)PluginPathMap().emplace(name, plugin_path);
}
231
// Installs the callback that SetDeviceTargetFromInner() queries for plugin load error text.
// (The "Resister" spelling is part of the existing interface.)
void MsContext::ResisterLoadPluginErrorFunc(MsContext::LoadPluginError func) { load_plugin_error_ = func; }
233
// Reports whether an init function is registered under the name "Ascend".
// Builds without WITH_BACKEND always return true.
bool MsContext::IsAscendPluginLoaded() const {
#ifdef WITH_BACKEND
  return InitFuncMap().find("Ascend") != InitFuncMap().end();
#else
  // for ut test
  return true;
#endif
}
242
SetDefaultDeviceTarget()243 void MsContext::SetDefaultDeviceTarget() {
244 auto cpu_iter = InitFuncMap().find(kCPUDevice);
245 if (cpu_iter == InitFuncMap().end()) {
246 return;
247 }
248 if (InitFuncMap().size() == 1) {
249 // when only cpu in map
250 cpu_iter->second(inst_context_.get());
251 } else if (InitFuncMap().size() == kDeviceTargetSize2) {
252 // when cpu and another in map
253 for (auto [name, func] : InitFuncMap()) {
254 if (name != kCPUDevice) {
255 inst_context_ = std::make_shared<MsContext>("ms", name);
256 func(inst_context_.get());
257 }
258 }
259 } else {
260 cpu_iter->second(inst_context_.get());
261 }
262 default_device_target_ = true;
263 }
264
SetDeviceTargetFromInner(const std::string & device_target)265 void MsContext::SetDeviceTargetFromInner(const std::string &device_target) {
266 if (seter_ != nullptr) {
267 if (!InitFuncMap().empty()) {
268 if (auto iter = InitFuncMap().find(device_target); iter == InitFuncMap().end()) {
269 CheckEnv(device_target);
270 std::string device_list = "[";
271 for (auto citer = InitFuncMap().cbegin(); citer != InitFuncMap().cend(); ++citer) {
272 if (device_list == "[") {
273 device_list += "\'" + citer->first + "\'";
274 } else {
275 device_list += ", \'" + citer->first + "\'";
276 }
277 }
278 device_list += "]";
279 if (load_plugin_error_ != nullptr) {
280 auto load_plugin_error_str = load_plugin_error_();
281 if (!load_plugin_error_str.empty()) {
282 MS_EXCEPTION(RuntimeError) << "Unsupported device target " << device_target
283 << ". This process only supports one of the " << device_list
284 << ". Please check whether the " << device_target
285 << " environment is installed and configured correctly, and check whether "
286 "current mindspore wheel package was built with \"-e "
287 << device_target
288 << "\". For details, please refer to \"Device load error message\"." << std::endl
289 << "#umsg#Device load error message:#umsg#" << load_plugin_error_str;
290 }
291 }
292 MS_EXCEPTION(RuntimeError) << "Unsupported device target " << device_target
293 << ". This process only supports one of the " << device_list
294 << ". Please check whether the " << device_target
295 << " environment is installed and configured correctly, and check whether "
296 "current mindspore wheel package was built with \"-e "
297 << device_target << "\".";
298 } else {
299 iter->second(this);
300 SetEnv(device_target);
301 }
302 }
303 MS_LOG(INFO) << "ms set context device target:" << device_target;
304 seter_(device_target);
305 }
306 if (device_target == "Ascend" && !CheckWriteStatus(MS_CTX_MEMORY_OPTIMIZE_LEVEL)) {
307 MS_LOG(INFO) << "Set memory_optimize_level to O1 as default on ascend";
308 int_params_[MS_CTX_MEMORY_OPTIMIZE_LEVEL - MS_CTX_TYPE_INT_BEGIN] = kOptimizeO1;
309 } else if (!CheckWriteStatus(MS_CTX_MEMORY_OPTIMIZE_LEVEL)) {
310 MS_LOG(INFO) << "Set memory_optimize_level to O0 as default on other device";
311 int_params_[MS_CTX_MEMORY_OPTIMIZE_LEVEL - MS_CTX_TYPE_INT_BEGIN] = kOptimizeO0;
312 }
313 string_params_[MS_CTX_DEVICE_TARGET - MS_CTX_TYPE_STRING_BEGIN] = device_target;
314 }
315
// Switches to `device_target` through SetDeviceTargetFromInner() and records that the target is
// no longer the default one.
void MsContext::SetDeviceTargetFromUser(const std::string &device_target) {
  SetDeviceTargetFromInner(device_target);
  default_device_target_ = false;
}
320
// Returns the flag set by SetDefaultDeviceTarget() and cleared by SetDeviceTargetFromUser().
bool MsContext::IsDefaultDeviceTarget() const { return default_device_target_; }
322
// Installs the callback invoked by SetEnv() with the device name and its plugin library path.
void MsContext::RegisterSetEnv(const EnvFunc &func) { set_env_ = func; }
// Installs the callback invoked by CheckEnv() with the device name and an empty path.
void MsContext::RegisterCheckEnv(const EnvFunc &func) { check_env_ = func; }
325
SetEnv(const std::string & device)326 void MsContext::SetEnv(const std::string &device) {
327 if (set_env_ == nullptr) {
328 return;
329 }
330
331 if (auto iter = PluginPathMap().find(device); iter != PluginPathMap().end()) {
332 const auto &library_path = iter->second;
333 set_env_(device, library_path);
334 }
335 }
336
CheckEnv(const std::string & device)337 void MsContext::CheckEnv(const std::string &device) {
338 if (check_env_ == nullptr) {
339 return;
340 }
341
342 check_env_(device, "");
343 }
344
GetSaveGraphsPath() const345 std::string MsContext::GetSaveGraphsPath() const {
346 std::string path = common::GetEnv("MS_DEV_SAVE_GRAPHS_PATH");
347 if (!path.empty()) {
348 return path;
349 } else {
350 return MsContext::GetInstance()->get_param<std::string>(MS_CTX_SAVE_GRAPHS_PATH);
351 }
352 }
353
GetSaveGraphsLevel() const354 int MsContext::GetSaveGraphsLevel() const {
355 static std::string save_env = common::GetEnv("MS_DEV_SAVE_GRAPHS");
356 if (save_env.size() == 1) {
357 int save_graphs_by_env = -1;
358 try {
359 save_graphs_by_env = std::stoi(save_env);
360 } catch (const std::invalid_argument &ia) {
361 MS_LOG(EXCEPTION) << "Invalid argument: " << ia.what() << " when parse " << save_env;
362 }
363 if (save_graphs_by_env < 0 || save_graphs_by_env > kFully) {
364 MS_LOG(EXCEPTION) << "Dump level can only be from 0 to 3";
365 }
366 return save_graphs_by_env;
367 } else if (save_env.size() > 1) {
368 MS_LOG(EXCEPTION) << "MS_DEV_SAVE_GRAPHS should be a single number with one digit.";
369 }
370 return MsContext::GetInstance()->get_param<int>(MS_CTX_SAVE_GRAPHS_FLAG);
371 }
372
// Returns true when the effective save-graphs level is at least `level`.
bool MsContext::CanDump(const DumpLevel &level) const { return GetSaveGraphsLevel() >= level; }
374
MarkReadStatus(MsCtxParam param) const375 void MsContext::MarkReadStatus(MsCtxParam param) const {
376 #if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES) || defined(BUILD_LITE))
377 // unit tests will set device_id many times in one process
378 if (static_cast<size_t>(param) < params_read_status_.size()) {
379 params_read_status_[static_cast<size_t>(param)] = true;
380 }
381 #endif
382 }
383
MarkWriteStatus(MsCtxParam param) const384 void MsContext::MarkWriteStatus(MsCtxParam param) const {
385 if (static_cast<size_t>(param) < params_write_status_.size()) {
386 params_write_status_[static_cast<size_t>(param)] = true;
387 }
388 }
389
// Rejects changing a parameter listed in kUnresetParamCheckList to a different value once it has
// been read. Parameters outside the status table or outside that list are never rejected.
// Raises a TypeError through MS_EXCEPTION when the check fails. Compiled out for test and lite
// builds.
template <typename T>
void MsContext::CheckReadStatus(MsCtxParam param, const T &value) const {
#if !(defined(ENABLE_TEST) || defined(ENABLE_TESTCASES) || defined(BUILD_LITE))
  // unit tests will set device_id many times in one process
  if (static_cast<size_t>(param) >= params_read_status_.size()) {
    return;
  }
  auto iter = kUnresetParamCheckList.find(param);
  if (iter == kUnresetParamCheckList.end()) {
    return;
  }
  // Snapshot the read flags around get_param() and restore them afterwards, so that fetching the
  // current value for this comparison does not itself count as a read.
  // NOTE(review): presumably get_param() marks the read status — confirm in the header.
  auto origin_status = params_read_status_;
  T origin_value = get_param<T>(param);
  params_read_status_ = origin_status;
  if (params_read_status_[static_cast<size_t>(param)] && value != origin_value) {
    MS_EXCEPTION(TypeError) << "For 'set_context', the parameter " << iter->second
                            << " can not be set repeatedly, origin value [" << origin_value << "] has been in effect."
                            << " Maybe 'mindspore.communication.init()' has been called before 'set_context()'.";
  }
#endif
}
411
CheckWriteStatus(MsCtxParam param) const412 bool MsContext::CheckWriteStatus(MsCtxParam param) const {
413 if (static_cast<size_t>(param) >= params_write_status_.size()) {
414 return false;
415 }
416 return params_write_status_[static_cast<size_t>(param)];
417 }
418
419 // Reset ms context. Only called in child process after fork occurs.
ChildAfterFork()420 void MsContext::ChildAfterFork() {
421 MS_LOG(DEBUG) << "Reset context after fork.";
422 // configs can be modified again.
423 params_read_status_ = std::vector<bool>(
424 static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
425 MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
426 false);
427 params_write_status_ = std::vector<bool>(
428 static_cast<size_t>(MsCtxParam::NUM_BOOL_PARAMS + MsCtxParam::NUM_UINT32_PARAMS + MsCtxParam::NUM_INT_PARAMS +
429 MsCtxParam::NUM_FLOAT_PARAMS + MsCtxParam::NUM_STRING_PARAMS),
430 false);
431 std::string device_target_ = get_param<std::string>(MS_CTX_DEVICE_TARGET);
432 if (device_target_ != kCPUDevice) {
433 // set device_target to 'CPU' as default.
434 MS_LOG(INFO) << "Process " << getpid() << " config changed: 'device_target' is reset to 'CPU'.";
435 SetDeviceTargetFromUser("CPU");
436 }
437 }
438
EnableAoeOnline() const439 bool MsContext::EnableAoeOnline() const {
440 std::string aoe_tune_mode = MsContext::GetInstance()->get_param<std::string>(MS_CTX_AOE_TUNE_MODE);
441 return aoe_tune_mode == "online";
442 }
443
EnableAoeOffline() const444 bool MsContext::EnableAoeOffline() const {
445 std::string aoe_tune_mode = MsContext::GetInstance()->get_param<std::string>(MS_CTX_AOE_TUNE_MODE);
446 return aoe_tune_mode == "offline";
447 }
448
449 namespace {
PrintJitLevelAndExecMode(bool is_jit_level_changed,const std::string & jit_level,const std::string & exec_mode)450 void PrintJitLevelAndExecMode(bool is_jit_level_changed, const std::string &jit_level, const std::string &exec_mode) {
451 if (!is_jit_level_changed) {
452 return;
453 }
454
455 MS_LOG(INFO) << "The jit_level is: " << jit_level << ", and " << exec_mode;
456 static std::string is_enable_runtime_cfg = common::GetEnv("MS_DEV_RUNTIME_CONF");
457 if (!is_enable_runtime_cfg.empty()) {
458 std::cout << "[MS_RUNTIME_PROF]The jit_level is: " << jit_level << ", and " << exec_mode << std::endl;
459 }
460 }
461 } // namespace
462
SetJitLevel(const std::string & jit_level) const463 void MsContext::SetJitLevel(const std::string &jit_level) const {
464 if (jit_level.empty()) {
465 return;
466 }
467 std::map<std::string, std::string> jit_config = PhaseManager::GetInstance().jit_config();
468 jit_config["jit_level"] = jit_level;
469 PhaseManager::GetInstance().set_jit_config(jit_config);
470 }
471
// Resolves the effective jit level.
// Precedence: the "jit_level" entry of the PhaseManager jit config, then the MS_CTX_JIT_LEVEL
// context parameter, then a default (on Ascend in graph mode: O2 when the SoC version equals
// kAscendVersion910, otherwise O0; everywhere else O0).
// Two overrides follow: PyNative mode downgrades O2 to O0, and the rank-table startup method
// forces O2. The override warnings are only logged on the first call.
std::string MsContext::GetJitLevel() const {
  // Shared by all calls; limits the warnings below to the first invocation.
  static bool first_call = true;
  const auto &jit_config = PhaseManager::GetInstance().jit_config();
  std::string jit_level = "";
  auto iter = jit_config.find("jit_level");
  if (iter != jit_config.end()) {
    jit_level = iter->second;
  }

  auto global_jit_level = get_param<std::string>(MS_CTX_JIT_LEVEL);
  auto device_target = get_param<std::string>(MS_CTX_DEVICE_TARGET);
  auto mode = get_param<int>(MS_CTX_EXECUTION_MODE);
  // No level in the jit config: fall back to the global setting, then to the default.
  if (jit_level.empty()) {
    if (!global_jit_level.empty()) {
      jit_level = global_jit_level;
    } else if (device_target == kAscendDevice && mode == kGraphMode) {
      jit_level = ascend_soc_version() == kAscendVersion910 ? kAttrJitLevelO2 : kAttrJitLevelO0;
    } else {
      jit_level = kAttrJitLevelO0;
    }
  }

  if (mode == kPynativeMode && jit_level == kAttrJitLevelO2) {
    if (first_call) {
      MS_LOG(WARNING) << "Pynative mode can not set jit_level to O2, use O0 instead.";
    }
    jit_level = kAttrJitLevelO0;
  }

  // If use rank table startup method, set jit level to O2.
  // This runs after the PyNative downgrade above, so it can raise the level to O2 again.
  if (!common::UseDynamicCluster() && !common::GetEnv("RANK_TABLE_FILE").empty() && jit_level != kAttrJitLevelO2) {
    if (first_call) {
      MS_LOG(WARNING) << "Set jit level to O2 for rank table startup method.";
    }
    jit_level = kAttrJitLevelO2;
  }
  first_call = false;

  return jit_level;
}
512
IsKByKExecutorMode() const513 bool MsContext::IsKByKExecutorMode() const {
514 // Get jit level.
515 std::string jit_level = GetJitLevel();
516 static std::string jit_level_log = "";
517 bool is_jit_level_changed = false;
518 auto mode = get_param<int>(MS_CTX_EXECUTION_MODE);
519 if (jit_level_log != jit_level) {
520 is_jit_level_changed = true;
521 jit_level_log = jit_level;
522 }
523
524 if (get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD)) {
525 PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable kernelbykernel executor by mem offload.");
526 return true;
527 }
528
529 if (mode == kPynativeMode) {
530 if (jit_level == kAttrJitLevelO2) {
531 PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable graph_sink executor in the PYNATIVE mode.");
532 return false;
533 }
534 PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable kernelbykernel executor in the PYNATIVE mode.");
535 return true;
536 }
537
538 if (mode == kGraphMode) {
539 if (jit_level == kAttrJitLevelO0 || jit_level == kAttrJitLevelO1) {
540 PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable kernelbykernel executor in the GRAPH mode.");
541 return true;
542 }
543 PrintJitLevelAndExecMode(is_jit_level_changed, jit_level, "enable graph_sink executor in the GRAPH mode.");
544 return false;
545 }
546
547 MS_LOG(ERROR) << "No valid executor mode.";
548 return false;
549 }
550
SetAscendConfig()551 void MsContext::SetAscendConfig() {
552 set_param<std::string>(MS_CTX_PRECISION_MODE, "");
553 set_param<std::string>(MS_CTX_ENABLE_JIT_COMPILE, "");
554 set_param<std::string>(MS_CTX_ATOMIC_CLEAN_POLICY, "");
555 set_param<std::string>(MS_CTX_MATMUL_ALLOW_HF32, "");
556 set_param<std::string>(MS_CTX_CONV_ALLOW_HF32, "");
557 set_param<std::string>(MS_CTX_OP_PRECISION_MODE, "");
558 set_param<std::string>(MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD, "");
559 set_param<std::string>(MS_CTX_GE_OPTIONS, "");
560 }
561
InitBoolTypeDefaultValue()562 void MsContext::InitBoolTypeDefaultValue() {
563 set_param<bool>(MS_CTX_ENABLE_DUMP, false);
564 set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
565 set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
566 set_param<bool>(MS_CTX_ENABLE_REDUCE_PRECISION, true);
567 set_param<bool>(MS_CTX_ENABLE_TASK_SINK, true);
568 set_param<bool>(MS_CTX_IR_FUSION_FLAG, true);
569 set_param<bool>(MS_CTX_ENABLE_HCCL, false);
570 set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
571 set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
572 set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
573 set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
574 set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
575 set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
576 set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
577 set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
578 set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
579 set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
580 set_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE, false);
581 set_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE, true);
582 set_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD, false);
583 set_param<bool>(MS_CTX_ENABLE_PROF_MEM, false);
584 set_param<bool>(MS_CTX_ENABLE_RECOVERY, false);
585 set_param<bool>(MS_CTX_ENABLE_GE_HETEROGENOUS, false);
586 set_param<bool>(MS_CTX_DISABLE_FORMAT_TRANSFORM, false);
587 set_param<bool>(MS_CTX_RECOMPUTE_COMM_OVERLAP, false);
588 set_param<bool>(MS_CTX_GRAD_COMM_OVERLAP, false);
589 set_param<bool>(MS_CTX_ENABLE_OPT_SHARD_COMM_OPT, false);
590 set_param<bool>(MS_CTX_ENABLE_TASK_OPT, false);
591 set_param<bool>(MS_CTX_ENABLE_GRAD_COMM_OPT, false);
592 set_param<bool>(MS_CTX_INTERLEAVED_MATMUL_COMM, false);
593 set_param<bool>(MS_CTX_INTERLEAVED_LAYERNORM_COMM, false);
594 set_param<bool>(MS_CTX_BIAS_ADD_COMM_SWAP, false);
595 set_param<bool>(MS_CTX_ENABLE_BEGIN_END_INLINE_OPT, false);
596 set_param<bool>(MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT, false);
597 set_param<bool>(MS_CTX_ENABLE_FUSED_CAST_ADD_OPT, false);
598 set_param<bool>(MS_CTX_ENABLE_PROFILING, false);
599 set_param<bool>(MS_CTX_CHECK_BPROP_FLAG, false);
600 set_param<bool>(MS_CTX_CONV_ALLOW_TF32, true);
601 set_param<bool>(MS_CTX_MATMUL_ALLOW_TF32, false);
602 set_param<bool>(MS_CTX_NEED_CKPT, false);
603 set_param<bool>(MS_CTX_RECOMPUTE_ALLGATHER_OVERLAP_FAGRAD, false);
604 set_param<bool>(MS_CTX_ENABLE_FLASH_ATTENTION_LOAD_BALANCE, false);
605 }
606
InitStringTypeDefaultValue()607 void MsContext::InitStringTypeDefaultValue() {
608 set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, "python");
609 set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, "");
610 set_param<std::string>(MS_CTX_SAVE_DUMP_PATH, ".");
611 set_param<std::string>(MS_CTX_DETERMINISTIC, "OFF");
612 set_param<std::string>(MS_CTX_ENV_CONFIG_PATH, "");
613 set_param<std::string>(MS_CTX_TUNE_MODE, "NO_TUNE");
614 set_param<std::string>(MS_CTX_AOE_TUNE_MODE, "");
615 set_param<std::string>(MS_CTX_AOE_JOB_TYPE, "2");
616 set_param<std::string>(MS_CTX_GRAPH_KERNEL_FLAGS, "");
617 set_param<std::string>(MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD, "");
618 set_param<std::string>(MS_CTX_ENABLE_EXCEPTION_DUMP, "2");
619 set_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE, "0");
620 set_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE, "0");
621 set_param<std::string>(MS_CTX_PROFILING_OPTIONS, "training_trace");
622 set_param<std::string>(MS_CTX_PRINT_FILE_PATH, "");
623 set_param<std::string>(MS_CTX_CONV_FPROP_ALGO, "normal");
624 set_param<std::string>(MS_CTX_CONV_DGRAD_ALGO, "normal");
625 set_param<std::string>(MS_CTX_CONV_WGRAD_ALGO, "normal");
626 set_param<std::string>(MS_CTX_JIT_LEVEL, "");
627 set_param<std::string>(MS_CTX_INFER_BOOST, "off");
628 set_param<std::string>(MS_CTX_PROF_MEM_OUTPUT_PATH, "");
629 }
630
InitDigitalTypeDefaultValue()631 void MsContext::InitDigitalTypeDefaultValue() {
632 set_param<int>(MS_CTX_EXECUTION_MODE, kPynativeMode);
633 set_param<int>(MS_CTX_JIT_SYNTAX_LEVEL, kLax);
634 set_param<int>(MS_CTX_CUR_STEP_NUM, 0);
635 set_param<int>(MS_CTX_SAVE_CKPT_STEPS, 0);
636 set_param<int>(MS_CTX_LAST_TRIGGERED_STEP, 0);
637 set_param<int>(MS_CTX_COMPUTE_COMMUNICATE_FUSION_LEVEL, 0);
638 set_param<int>(MS_CTX_ENABLE_COMPILE_CACHE, -1);
639 set_param<int>(MS_CTX_DEBUG_LEVEL, kLevelRelease);
640 set_param<float>(MS_CTX_MAX_DEVICE_MEMORY, kDefaultMaxDeviceMemory);
641 set_param<float>(MS_CTX_MEMPOOL_BLOCK_SIZE, kDefaultMempoolBlockSize);
642 //
643 uint32_t kDefaultInterOpParallelThreads = 0;
644 uint32_t kDefaultRuntimeNumThreads = 30;
645 uint32_t cpu_core_num = std::thread::hardware_concurrency();
646 uint32_t runtime_num_threads_default = std::min(cpu_core_num, kDefaultRuntimeNumThreads);
647 uint32_t inter_op_parallel_num_default = std::min(cpu_core_num, kDefaultInterOpParallelThreads);
648 set_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS, runtime_num_threads_default);
649 set_param<uint32_t>(MS_CTX_INTER_OP_PARALLEL_NUM, inter_op_parallel_num_default);
650 //
651 set_param<uint32_t>(MS_CTX_TSD_REF, 0);
652 set_param<uint32_t>(MS_CTX_GE_REF, 0);
653 set_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH, MAX_CALL_DEPTH_DEFAULT);
654 set_param<uint32_t>(MS_CTX_OP_TIMEOUT, kOpTimeout);
655 }
656
// Splits `str` at every `delim` and inserts each non-empty piece into `output_list`.
// Existing elements of `output_list` are kept; `output_list` must not be null.
inline void SplitString(const std::string &str, char delim, std::set<std::string> *output_list) {
  std::string::size_type piece_begin = 0;
  while (piece_begin < str.size()) {
    std::string::size_type piece_end = str.find(delim, piece_begin);
    if (piece_end == std::string::npos) {
      piece_end = str.size();
    }
    // Consecutive, leading or trailing delimiters produce empty pieces, which are skipped.
    if (piece_end > piece_begin) {
      output_list->emplace(str.substr(piece_begin, piece_end - piece_begin));
    }
    piece_begin = piece_end + 1;
  }
}
666
// Joins the names in `kernel_list` (in set order) with ", " between consecutive names.
// Returns an empty string for an empty set; no separator is appended after the last name.
inline std::string SetToString(const std::set<std::string> &kernel_list) {
  std::string out;
  bool is_first = true;
  for (const auto &name : kernel_list) {
    if (!is_first) {
      out.append(", ");
    }
    out.append(name);
    is_first = false;
  }
  return out;
}
674
SetMsInternalEnableCustomKernelList()675 void MsContext::SetMsInternalEnableCustomKernelList() {
676 const std::string kDefaultEnabledOpList =
677 "MatMul,RmsNorm,Add,Sub,FlashAttentionScore,PagedAttention,AddRmsNorm,AddLayerNorm,MatMulAllReduce,"
678 "InferenceMatmulSplit";
679 auto internal_op_boost_env = common::GetEnv("MS_ENABLE_INTERNAL_BOOST");
680 bool is_enable_internal_op = true;
681 if (internal_op_boost_env == "off") {
682 is_enable_internal_op = false;
683 }
684
685 std::set<std::string> enable_fusion_list;
686 if (is_enable_internal_op) {
687 SplitString(kDefaultEnabledOpList, ',', &enable_fusion_list);
688 }
689
690 std::string env = common::GetEnv("MS_INTERNAL_ENABLE_CUSTOM_KERNEL_LIST");
691 if (!env.empty()) {
692 SplitString(env, ',', &enable_fusion_list);
693 }
694
695 std::set<std::string> disable_fusion_list;
696 env = common::GetEnv("MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST");
697 if (!env.empty()) {
698 SplitString(env, ',', &disable_fusion_list);
699 }
700
701 ms_internal_enable_custom_kernel_list_.clear();
702 for (const auto &fusion_name : enable_fusion_list) {
703 if (disable_fusion_list.find(fusion_name) == disable_fusion_list.end()) {
704 ms_internal_enable_custom_kernel_list_.emplace(fusion_name);
705 }
706 }
707
708 MS_LOG(INFO) << "Enable internal kernel list: " << SetToString(ms_internal_enable_custom_kernel_list_);
709 }
710
IsEnableInferBoost()711 bool MsContext::IsEnableInferBoost() {
712 if (enable_infer_boost_.has_value()) {
713 return enable_infer_boost_.value();
714 }
715
716 const auto &jit_config = PhaseManager::GetInstance().jit_config();
717 auto iter = jit_config.find("infer_boost");
718 if (iter != jit_config.end() && iter->second == "on") {
719 enable_infer_boost_ = true;
720 MS_LOG(INFO) << "MSContext enable ms infer boost from JitConfig";
721 SetMsInternalEnableCustomKernelList();
722 return enable_infer_boost_.value();
723 }
724
725 auto global_infer_boost = get_param<std::string>(MS_CTX_INFER_BOOST);
726 if (global_infer_boost == "on") {
727 enable_infer_boost_ = true;
728 MS_LOG(INFO) << "MSContext enable ms infer boost from Global Context JitConfig";
729 SetMsInternalEnableCustomKernelList();
730 return enable_infer_boost_.value();
731 }
732
733 if (common::GetEnv("MS_ENABLE_INTERNAL_KERNELS") == "on") {
734 enable_infer_boost_ = true;
735 MS_LOG(INFO) << "MSContext enable ms infer boost from Env";
736 SetMsInternalEnableCustomKernelList();
737 } else {
738 enable_infer_boost_ = false;
739 }
740
741 return enable_infer_boost_.value();
742 }
743
// Returns a reference to the enabled internal kernel names built by
// SetMsInternalEnableCustomKernelList(); the reference refers to a member of this context.
const std::set<std::string> &MsContext::ms_internal_enable_custom_kernel_list() const {
  return ms_internal_enable_custom_kernel_list_;
}
747
// Explicit instantiations of CheckReadStatus for the value types bool, uint32_t, int, float and
// std::string. The template body is defined in this file, so these are the instantiations this
// translation unit provides.
template MS_CORE_API void MsContext::CheckReadStatus<bool>(MsCtxParam, const bool &) const;
template MS_CORE_API void MsContext::CheckReadStatus<uint32_t>(MsCtxParam, const uint32_t &) const;
template MS_CORE_API void MsContext::CheckReadStatus<int>(MsCtxParam, const int &) const;
template MS_CORE_API void MsContext::CheckReadStatus<float>(MsCtxParam, const float &) const;
template MS_CORE_API void MsContext::CheckReadStatus<std::string>(MsCtxParam, const std::string &) const;
753 } // namespace mindspore
754