1# Copyright 2020-2024 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15""" 16The context of mindspore, used to configure the current execution environment, 17includes the execution mode, execution backend and other feature switches. 18""" 19from __future__ import absolute_import 20 21import json 22import os 23import time 24import threading 25from collections import namedtuple 26from types import FunctionType 27 28from mindspore import log as logger 29from mindspore._c_expression import MSContext, ms_ctx_param 30from mindspore import _checkparam as Validator 31from mindspore._checkparam import args_type_check 32from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ 33 _reset_auto_parallel_context 34from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context, \ 35 _need_reset_device_target_for_ps 36from mindspore.parallel._offload_context import _set_offload_context, _get_offload_context 37from mindspore.hal.device import is_initialized 38 39__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'STRICT', 'COMPATIBLE', 'LAX', 'set_context', 'get_context', 40 'set_auto_parallel_context', 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 41 'set_ps_context', 'get_ps_context', 'reset_ps_context', 'set_offload_context', 'get_offload_context'] 42 43GRAPH_MODE = 0 44PYNATIVE_MODE = 1 45_DEVICE_APP_MEMORY_SIZE = 31 # The max memory size of graph plus variable. 46_RE_PATTERN = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' 47K_CONTEXT = None 48 49# Enumerate for the property 'jit_syntax_level'. 50STRICT = 0 51COMPATIBLE = 1 52LAX = 2 53 54# Enumerate for the property 'debug_level'. 55RELEASE = 0 56DEBUG = 1 57 58 59def _make_directory(path): 60 """Make directory.""" 61 if path is None or not isinstance(path, str) or path.strip() == "": 62 raise ValueError(f"For 'context.set_context', the 'save_graphs_path' or the 'print_file_path' is invalid " 63 f"type, it should be Non-empty string, but got '{path}'.") 64 65 path = os.path.realpath(path) 66 logger.debug("The absolute path is %r", path) 67 68 if not os.path.exists(path): 69 logger.debug("The directory(%s) doesn't exist, will create it", path) 70 try: 71 os.makedirs(path) 72 except FileExistsError: 73 logger.debug("The directory(%s) already exist.", path) 74 except PermissionError as e: 75 logger.critical(f"No write permission on the directory '{path}'', error = {e}") 76 raise ValueError(e.__str__() + f"\nNo write permission on the directory '{path}'.") 77 return path 78 79 80def _get_print_file_name(file_name): 81 """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds).""" 82 time_second = str(int(time.time())) 83 file_name = file_name + "." + time_second 84 if os.path.exists(file_name): 85 raise ValueError("For 'context.set_context', the argument 'print_file_path' {} already exists, " 86 "please check it".format(file_name)) 87 return file_name 88 89 90class _ThreadLocalInfo(threading.local): 91 """ 92 Thread local Info used for store thread local attributes. 93 """ 94 95 def __init__(self): 96 super(_ThreadLocalInfo, self).__init__() 97 self._reserve_class_name_in_scope = True 98 self.debug_runtime = False 99 100 @property 101 def reserve_class_name_in_scope(self): 102 """Get whether to save the network class name in the scope.""" 103 return self._reserve_class_name_in_scope 104 105 @reserve_class_name_in_scope.setter 106 def reserve_class_name_in_scope(self, reserve_class_name_in_scope): 107 """Set whether to save the network class name in the scope.""" 108 self._reserve_class_name_in_scope = reserve_class_name_in_scope 109 110 111_ContextRecord = namedtuple( 112 "_ContextRecord", ["is_pynative_mode", "switch_context_fn"]) 113 114 115class _ContextSwitchInfo(threading.local): 116 """ 117 Record of context switch information. 118 119 Args: 120 is_pynative (bool): Whether to adopt the PyNative mode. 121 """ 122 123 def __init__(self, is_pynative): 124 super(_ContextSwitchInfo, self).__init__() 125 self.context_stack = [] 126 if is_pynative: 127 self.push(True, None) 128 129 def push(self, is_pynative, switch_context_fn): 130 """ 131 Push a context switch record onto the stack. 132 133 Args: 134 is_pynative (bool): Whether context switch to PyNative mode. 135 switch_context_fn (Function): A callable that executes the context switch. 136 """ 137 if isinstance(switch_context_fn, FunctionType): 138 switch_context_fn() 139 self.context_stack.append( 140 _ContextRecord(is_pynative, switch_context_fn)) 141 142 def pop(self): 143 self.context_stack.pop() 144 145 146class _Context: 147 """ 148 _Context is the environment in which operations are executed 149 150 Note: 151 Create a context through instantiating Context object is not recommended. 152 should use context() to get the context since Context is a singleton. 153 """ 154 _instance = None 155 _instance_lock = threading.Lock() 156 157 def __new__(cls, *args, **kwargs): 158 if cls._instance is None: 159 cls._instance_lock.acquire() 160 cls._instance = object.__new__(cls) 161 cls._instance_lock.release() 162 return cls._instance 163 164 def __init__(self): 165 self._thread_local_info = _ThreadLocalInfo() 166 self._context_switches = _ContextSwitchInfo(False) 167 self._context_handle = MSContext.get_instance() 168 self._support_binary = False 169 self.enable_compile_cache = None 170 self._mode = PYNATIVE_MODE 171 self._jit_config = {} 172 173 def __getattribute__(self, attr): 174 value = object.__getattribute__(self, attr) 175 if attr == "_context_handle" and value is None: 176 raise ValueError("Get {} failed, please check whether 'env_config_path' is correct.".format(attr)) 177 return value 178 179 def get_param(self, param): 180 return self._context_handle.get_param(param) 181 182 def set_param(self, param, value): 183 self._context_handle.set_param(param, value) 184 185 def get_mode(self): 186 """Get current mode.""" 187 return self._mode 188 189 def get_jit_config(self): 190 """Get current jit_config.""" 191 return self._jit_config 192 193 def set_mode(self, mode): 194 """ 195 Switch between Graph mode and PyNative mode. 196 197 Args: 198 mode (int): GRAPH_MODE or PYNATIVE_MODE. 199 """ 200 if mode == PYNATIVE_MODE: 201 if self.enable_debug_runtime: 202 self.set_backend_policy("vm") 203 parallel_mode = _get_auto_parallel_context("parallel_mode") 204 if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE, ParallelMode.AUTO_PARALLEL): 205 raise ValueError(f"Got {parallel_mode}, when the user enabled SEMI_AUTO_PARALELL, " 206 f"pynative mode dose not support, you should set either " 207 f"context.set_auto_parallel_context(parallel_mode='data_parallel'), " 208 f"context.set_auto_parallel_context(parallel_mode='stand_alone') " 209 f"or context.set_auto_parallel_context(parallel_mode='auto_parallel').") 210 self._context_switches.push(True, None) 211 elif mode == GRAPH_MODE: 212 if self.enable_debug_runtime: 213 self.set_backend_policy("ge") 214 self._context_switches.push(False, None) 215 else: 216 raise ValueError(f"For 'context.set_context', the argument 'mode' should be context.GRAPH_MODE (0) " 217 f"or context.PYNATIVE_MODE (1), but got {mode}.") 218 self.set_param(ms_ctx_param.mode, mode) 219 self._mode = mode 220 221 def set_jit_syntax_level(self, level): 222 """"Set the JIT syntax level for graph compiling""" 223 if level != STRICT and level != COMPATIBLE and level != LAX: 224 raise ValueError(f"For 'context.set_jit_syntax_level', the argument 'level' should be context.STRICT " 225 f"or context.LAX, but got {level}.") 226 self.set_param(ms_ctx_param.jit_syntax_level, level) 227 228 def set_debug_level(self, level): 229 """"Set the debug level for graph compiling""" 230 if level != RELEASE and level != DEBUG: 231 raise ValueError(f"For 'context.set_debug_level', the argument 'level' should be context.RELEASE " 232 f"or context.DEBUG, but got {level}.") 233 self.set_param(ms_ctx_param.debug_level, level) 234 235 def set_memory_optimize_level(self, memory_optimize_level): 236 """ 237 The memory optimize level, support "O0", "O1". 238 239 Args: 240 target (str): "O0", "O1" 241 """ 242 memory_optimize_levels = ["O0", "O1"] 243 if memory_optimize_level not in memory_optimize_levels: 244 raise ValueError(f"For 'context.set_context', the argument 'memory_optimize_level' must be one of " 245 f"{memory_optimize_levels}, but got {memory_optimize_level}.") 246 if memory_optimize_level == "O0": 247 self.set_param(ms_ctx_param.memory_optimize_level, 0) 248 else: 249 self.set_param(ms_ctx_param.memory_optimize_level, 1) 250 251 def set_memory_offload(self, memory_offload): 252 """ 253 Enable memory offload or not, support "ON", "OFF". 254 255 Args: 256 memory_offload (str): "ON", "OFF" 257 """ 258 memory_offload_options = ["ON", "OFF"] 259 if memory_offload not in memory_offload_options: 260 raise ValueError(f"For 'context.set_context', the argument 'memory_offload' must be one of " 261 f"{memory_offload_options}, but got {memory_offload}.") 262 if memory_offload == "ON": 263 self.set_param(ms_ctx_param.memory_offload, True) 264 else: 265 self.set_param(ms_ctx_param.memory_offload, False) 266 267 def set_deterministic(self, deterministic): 268 """ 269 Enable model run in deterministic, and support the values "ON" and "OFF". 270 271 Args: 272 deterministic (str): "ON", "OFF" 273 """ 274 deterministic_options = ["ON", "OFF"] 275 if deterministic not in deterministic_options: 276 raise ValueError(f"For 'context.set_context', the argument 'deterministic' must be one of " 277 f"{deterministic_options}, but got {deterministic}.") 278 self.set_param(ms_ctx_param.deterministic, deterministic) 279 280 def set_ascend_config(self, ascend_config): 281 """ 282 Enable ascend config. 283 284 Args: 285 ascend_config (dict): 286 - precision_mode (str): "force_fp16", "allow_fp32_to_fp16", "allow_mix_precision", 287 "must_keep_origin_dtype", "force_fp32", "allow_fp32_to_bf16", 288 "allow_mix_precision_fp16" and "allow_mix_precision_bf16". 289 - jit_compile (bool): ``False`` and ``True``. 290 - atomic_clean_policy (int): ``0`` and ``1``. Default: ``1`` . 291 - op_precision_mode (str): precision mode config file path. 292 - op_debug_option (str): Enable debugging options for Ascend operators, 293 default not enabled, only supports ``"oom"`` currently. 294 ``"oom"``: Detect memory out of bounds. 295 - ge_options (dict): Global or session CANN options. 296 - exception_dump (str): Enable exception dump for Ascend operators. ``"0"`` , ``"1"`` and ``"2"``. 297 Default: ``"2"`` . 298 - parallel_speed_up_json_path(Union[str, None]): The path to the parallel speed up json file. 299 If its value is None or '', it does not take effect. Default None. 300 - host_scheduling_max_threshold(int): The host scheduling max threshold. 301 """ 302 ascend_cfg_modes = { 303 'precision_mode': ["force_fp16", "allow_fp32_to_fp16", "allow_mix_precision", "must_keep_origin_dtype", 304 "force_fp32", "allow_fp32_to_bf16", "allow_mix_precision_fp16", 305 "allow_mix_precision_bf16"], 306 'jit_compile': [True, False], 307 'atomic_clean_policy': [0, 1], 308 'matmul_allow_hf32': [True, False], 309 'conv_allow_hf32': [True, False], 310 'exception_dump': ["0", "1", "2"], 311 'op_precision_mode': (str,), 312 'ge_options': (dict,), 313 'parallel_speed_up_json_path': (str, None), 314 'host_scheduling_max_threshold': (int,), 315 'cur_step_num': (int,), 316 'save_checkpoint_steps': (int,), 317 'need_ckpt': (bool,), 318 'last_triggered_step': (int,), 319 'topo_order': (dict,), 320 'op_debug_option': (str, None), 321 } 322 ascend_cfg_setters = { 323 'precision_mode': self._get_ascend_config_setter('precision_mode'), 324 'jit_compile': self._get_ascend_config_setter('jit_compile', lambda v: "1" if v else "0"), 325 'atomic_clean_policy': self._get_ascend_config_setter('atomic_clean_policy', str), 326 'matmul_allow_hf32': self._get_ascend_config_setter('matmul_allow_hf32', lambda v: "1" if v else "0"), 327 'conv_allow_hf32': self._get_ascend_config_setter('conv_allow_hf32', lambda v: "1" if v else "0"), 328 'exception_dump': self._get_ascend_config_setter('exception_dump'), 329 'op_debug_option': self._set_op_debug_option, 330 'op_precision_mode': self._set_op_precision_mode, 331 'ge_options': self._set_ge_options, 332 'parallel_speed_up_json_path': self._set_speedup_config_path, 333 'host_scheduling_max_threshold': self._get_ascend_config_setter('host_scheduling_max_threshold', str), 334 'cur_step_num': self._set_cur_step_num, 335 'save_checkpoint_steps': self._set_save_checkpoint_steps, 336 'need_ckpt': self._set_need_ckpt, 337 'last_triggered_step': self._set_last_triggered_step, 338 'topo_order': self._set_topo_order 339 } 340 ascend_cfg_set = tuple(ascend_cfg_modes.keys()) 341 for ascend_key, ascend_value in ascend_config.items(): 342 if ascend_key not in ascend_cfg_set: 343 raise ValueError(f"For 'context.set_context', the key of argument 'ascend_config' must be one of " 344 f"{ascend_cfg_set}, but got {ascend_key}.") 345 supported_modes = ascend_cfg_modes.get(ascend_key) 346 if isinstance(supported_modes, list) and ascend_value not in supported_modes: 347 raise ValueError(f"For 'ascend_config', the value of argument {ascend_key} must be one of " 348 f"{supported_modes}, but got {ascend_value}.") 349 if isinstance(supported_modes, tuple) and not isinstance(ascend_value, supported_modes): 350 raise TypeError(f"For 'ascend_config', the type of argument {ascend_key} must be one of " 351 f"{supported_modes}, but got {type(ascend_value)}.") 352 cfg_setter = ascend_cfg_setters.get(ascend_key) 353 cfg_setter(ascend_value) 354 355 def set_gpu_config(self, gpu_config): 356 """ 357 Enable gpu config. 358 359 Args: 360 gpu_config (dict): 361 362 - conv_fprop_algo (str): "normal", "performance" or user specifies conv forward algorithm directly. 363 - conv_dgrad_algo (str): "normal", "performance" or user specifies conv data grad algorithm directly. 364 - conv_wgrad_algo (str): "normal", "performance" or user specifies conv weight grad algorithm directly. 365 - conv_allow_tf32 (bool): ``False`` and ``True``. 366 - matmul_allow_tf32 (bool): ``False`` and ``True``. 367 """ 368 369 gpu_cfgs = {'conv_fprop_algo': ["normal", "performance", "implicit_gemm", "precomp_gemm", "gemm", "direct", 370 "fft", "fft_tiling", "winograd", "winograd_nonfused"], 371 'conv_dgrad_algo': ["normal", "performance", "algo_0", "algo_1", "fft", "fft_tiling", "winograd", 372 "winograd_nonfused"], 373 'conv_wgrad_algo': ["normal", "performance", "algo_0", "algo_1", "fft", "algo_3", "fft_tiling", 374 "winograd_nonfused"], 375 'conv_allow_tf32': [True, False], 376 'matmul_allow_tf32': [True, False]} 377 for gpu_key in gpu_config: 378 if gpu_key not in gpu_cfgs: 379 raise ValueError(f"For 'context.set_context', the key of argument 'gpu_config' must be one of " 380 f"{gpu_cfgs}, but got {gpu_key}.") 381 supported_value = gpu_cfgs.get(gpu_key) 382 if gpu_config[gpu_key] not in supported_value: 383 raise ValueError(f"For 'gpu_config', the value of argument {gpu_key} must be one of " 384 f"{supported_value}, but got {gpu_config[gpu_key]}.") 385 if gpu_key == 'conv_fprop_algo': 386 self.set_param(ms_ctx_param.conv_fprop_algo, gpu_config[gpu_key]) 387 if gpu_key == 'conv_dgrad_algo': 388 self.set_param(ms_ctx_param.conv_dgrad_algo, gpu_config[gpu_key]) 389 if gpu_key == 'conv_wgrad_algo': 390 self.set_param(ms_ctx_param.conv_wgrad_algo, gpu_config[gpu_key]) 391 if gpu_key == 'conv_allow_tf32': 392 self.set_param(ms_ctx_param.conv_allow_tf32, gpu_config[gpu_key]) 393 if gpu_key == 'matmul_allow_tf32': 394 self.set_param(ms_ctx_param.matmul_allow_tf32, gpu_config[gpu_key]) 395 396 def set_jit_config(self, jit_config): 397 """ 398 Enable jit config. 399 400 Args: 401 jit_config (dict): 402 403 - jit_level (str): "O0", "O1" or "O2" to control the compilation optimization level. 404 """ 405 jit_cfgs = {'jit_level': ["O0", "O1", "O2"], 'infer_boost': ["on", "off"]} 406 key_args_map = {'jit_level': ms_ctx_param.jit_level, 'infer_boost': ms_ctx_param.infer_boost} 407 for jit_key in jit_config: 408 if jit_key not in jit_cfgs: 409 raise ValueError(f"For 'context.set_context', the key of argument 'jit_config' must be one of " 410 f"{jit_cfgs}, but got {jit_key}.") 411 supported_value = jit_cfgs.get(jit_key) 412 if jit_config[jit_key] not in supported_value: 413 raise ValueError(f"For 'jit_cfgs', the value of argument {jit_key} must be one of " 414 f"{supported_value}, but got {jit_config[jit_key]}.") 415 self._jit_config = jit_config 416 self.set_param(key_args_map[jit_key], jit_config[jit_key]) 417 418 if 'infer_boost' in jit_config and jit_config['infer_boost'] == "on" and jit_config['jit_level'] != "O0": 419 raise ValueError(f"Only jit_level set O0 can set infer_boost to on.") 420 421 def set_backend_policy(self, policy): 422 success = self._context_handle.set_backend_policy(policy) 423 if not success: 424 raise RuntimeError("Backend policy must be one of values in ['ge', 'vm', 'ms']. " 425 "But got {}.".format(policy)) 426 427 def set_save_graphs_path(self, save_graphs_path): 428 self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) 429 430 def set_device_target(self, target): 431 """ 432 The target device to run, support "Ascend", "GPU", and "CPU". 433 434 Args: 435 target (str): "Ascend", "GPU", and "CPU". 436 """ 437 valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] 438 if target not in valid_targets: 439 raise ValueError(f"For 'context.set_context', the argument 'device_target' must be one of " 440 f"{valid_targets}, but got {target}.") 441 if target == "Davinci": 442 target = "Ascend" 443 logger.warning("The device 'Davinci' is deprecated and will be removed in the next version. " 444 "For 'context.set_context', please set the argument 'device_target' " 445 "to 'CPU', 'GPU' or 'Ascend',if you set it to 'Davinci', it will be automatically " 446 "changed to 'Ascend'.") 447 # If in Parameter Server mode, Ascend card should not be used by server and scheduler. 448 if _need_reset_device_target_for_ps(target): 449 logger.info("Reset device target to CPU when set_device_target.") 450 target = "CPU" 451 self.set_param(ms_ctx_param.device_target, target) 452 if self.enable_debug_runtime and target == "CPU": 453 self.set_backend_policy("vm") 454 455 def set_aoe_tune_mode(self, tune_mode): 456 """ 457 Set aoe tune mode, support "online" and "offline". 458 459 Args: 460 tune_mode (str): "online" and "offline". 461 """ 462 candidate = ["online", "offline"] 463 if tune_mode in candidate: 464 self.set_param(ms_ctx_param.aoe_tune_mode, tune_mode) 465 else: 466 raise ValueError(f"For 'context.set_context', the argument 'aoe_tune_mode' must be in " 467 f"['online', 'offline'], but got {tune_mode}.") 468 469 def set_aoe_config(self, aoe_config): 470 """ 471 Enable aoe config. 472 473 Args: 474 aoe_config (dict): 475 - job_type (str): ``"1"``, ``"2"``. Default: ``"2"`` . 476 - ``"1"``: subgraph tuning. 477 - ``"2"``: operator tuning. 478 """ 479 480 aoe_cfgs = {'job_type': ["1", "2"]} 481 for aoe_config_key in aoe_config: 482 if aoe_config_key not in aoe_cfgs: 483 raise ValueError(f"For 'context.set_context', the key of argument 'aoe_config' must be one of " 484 f"{aoe_cfgs}, but got {aoe_config_key}.") 485 supported_value = aoe_cfgs.get(aoe_config_key) 486 if aoe_config[aoe_config_key] not in supported_value: 487 raise ValueError(f"For 'aoe_config', the value of argument {aoe_config_key} must be one of " 488 f"{supported_value}, but got {aoe_config[aoe_config_key]}.") 489 if aoe_config_key == 'job_type': 490 self.set_param(ms_ctx_param.aoe_job_type, aoe_config[aoe_config_key]) 491 492 def set_device_id(self, device_id): 493 if device_id < 0 or device_id > 4095: 494 raise ValueError(f"For 'context.set_context', the argument 'device_id' must be in range [0, 4095], " 495 f"but got {device_id}.") 496 self.set_param(ms_ctx_param.device_id, device_id) 497 498 def set_max_call_depth(self, max_call_depth): 499 if max_call_depth <= 0: 500 raise ValueError(f"For 'context.set_context', the argument 'max_call_depth' must be greater than 0, " 501 f"but got {max_call_depth}.") 502 self.set_param(ms_ctx_param.max_call_depth, max_call_depth) 503 504 def set_profiling_options(self, option): 505 if not isinstance(option, str): 506 raise TypeError("For 'context.set_context', the argument 'profiling_option' must be string, " 507 "but got {}.".format(type(option))) 508 self.set_param(ms_ctx_param.profiling_options, option) 509 510 def set_variable_memory_max_size(self, variable_memory_max_size): 511 """set values of variable_memory_max_size and graph_memory_max_size""" 512 logger.warning("For 'context.set_context', the parameter 'variable_memory_max_size' is deprecated, " 513 "and will be removed in a future " 514 "version. Please use parameter 'max_device_memory' instead.") 515 if not Validator.check_str_by_regular(variable_memory_max_size, _RE_PATTERN): 516 raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should be in correct" 517 " format! It must be a string ending with 'GB', in addition to that, it must contain " 518 "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}GB." 519 .format(variable_memory_max_size)) 520 if float(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE: 521 raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should not be " 522 "greater than 31GB, but got {}GB.".format(variable_memory_max_size)) 523 variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" 524 graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) 525 graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" 526 self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) 527 self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_) 528 529 def set_max_device_memory(self, max_device_memory): 530 if not Validator.check_str_by_regular(max_device_memory, _RE_PATTERN): 531 raise ValueError("For 'context.set_context', the argument 'max_device_memory' should be in correct " 532 " format! It must be a string ending with 'GB', in addition to that, it must contain " 533 "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}." 534 .format(max_device_memory)) 535 max_device_memory_value = float(max_device_memory[:-2]) 536 if max_device_memory_value == 0: 537 raise ValueError("For 'context.set_context', the argument 'max_device_memory' should not be \"0GB\".") 538 self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) 539 540 def set_mempool_block_size(self, mempool_block_size): 541 """Set the block size of memory pool.""" 542 global_jit_config = get_jit_config() 543 is_force_kbk = False 544 if global_jit_config: 545 is_force_kbk = global_jit_config.get('jit_level') == "O0" or global_jit_config.get('jit_level') == "O1" 546 if _get_mode() == GRAPH_MODE and not is_force_kbk: 547 logger.warning("Graph mode doesn't support to set parameter 'mempool_block_size' of context currently, " 548 "you can use context.set_context to set pynative mode or set jit_level=O0/O1.") 549 return 550 if not Validator.check_str_by_regular(mempool_block_size, _RE_PATTERN): 551 raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in " 552 "correct format! Such as \"10GB\", " 553 "but got {}".format(mempool_block_size)) 554 mempool_block_size_value = float(mempool_block_size[:-2]) 555 if mempool_block_size_value < 1.0: 556 raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be " 557 "greater or equal to \"1GB\", " 558 "but got {}GB".format(float(mempool_block_size[:-2]))) 559 self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value) 560 561 def set_print_file_path(self, file_path): 562 """Add timestamp suffix to file name. Sets print file path.""" 563 print_file_path = os.path.realpath(file_path) 564 if os.path.isdir(print_file_path): 565 raise IOError("For 'context.set_context', the argument 'print_file_path' should be file path, " 566 "but got directory {}.".format(file_path)) 567 568 if os.path.exists(print_file_path): 569 _path, _file_name = os.path.split(print_file_path) 570 path = _make_directory(_path) 571 file_name = _get_print_file_name(_file_name) 572 full_file_name = os.path.join(path, file_name) 573 else: 574 full_file_name = print_file_path 575 self.set_param(ms_ctx_param.print_file_path, full_file_name) 576 577 def set_env_config_path(self, env_config_path): 578 """Check and set env_config_path.""" 579 if not self._context_handle.enable_dump_ir(): 580 raise ValueError("For 'context.set_context', the argument 'env_config_path' is not supported, please " 581 "enable ENABLE_DUMP_IR with '-D on' and recompile source firstly.") 582 env_config_path = os.path.realpath(env_config_path) 583 if not os.path.isfile(env_config_path): 584 raise ValueError("For 'context.set_context', the 'env_config_path' file %r is not exists, " 585 "please check whether 'env_config_path' is correct." % env_config_path) 586 try: 587 with open(env_config_path, 'r') as f: 588 json.load(f) 589 except (TypeError, ValueError) as exo: 590 raise ValueError(str(exo) + "\nFor 'context.set_context', open or load the 'env_config_path' file {} " 591 "failed, please check whether 'env_config_path' is json file and correct, " 592 "or may not have permission to read it.".format(env_config_path)) from exo 593 self.set_param(ms_ctx_param.env_config_path, env_config_path) 594 595 def set_runtime_num_threads(self, runtime_num_threads): 596 """Check and set runtime_num_threads.""" 597 if runtime_num_threads < 0: 598 raise ValueError("The num of thread must bigger than or equal to 0.") 599 self.set_param(ms_ctx_param.runtime_num_threads, runtime_num_threads) 600 601 def set_op_timeout(self, op_timeout): 602 """Set the maximum duration of executing an operator in seconds.""" 603 if op_timeout < 0: 604 raise ValueError("The num of op exe timeout must bigger than or equal to 0.") 605 self.set_param(ms_ctx_param.op_timeout, op_timeout) 606 607 def set_inter_op_parallel_num(self, inter_op_parallel_num): 608 """Check and set inter_op_parallel_num.""" 609 if inter_op_parallel_num < 0: 610 raise ValueError("The num of parallel thread must bigger than or equal to 0.") 611 self.set_param(ms_ctx_param.inter_op_parallel_num, inter_op_parallel_num) 612 613 setters = { 614 'mode': set_mode, 615 'save_graphs_path': set_save_graphs_path, 616 'device_target': set_device_target, 617 'aoe_tune_mode': set_aoe_tune_mode, 618 'device_id': set_device_id, 619 'max_call_depth': set_max_call_depth, 620 'profiling_options': set_profiling_options, 621 'variable_memory_max_size': set_variable_memory_max_size, 622 'max_device_memory': set_max_device_memory, 623 'mempool_block_size': set_mempool_block_size, 624 'print_file_path': set_print_file_path, 625 'env_config_path': set_env_config_path, 626 'inter_op_parallel_num': set_inter_op_parallel_num, 627 'runtime_num_threads': set_runtime_num_threads, 628 'memory_optimize_level': set_memory_optimize_level, 629 'op_timeout': set_op_timeout, 630 'memory_offload': set_memory_offload, 631 'deterministic': set_deterministic, 632 'ascend_config': set_ascend_config, 633 'jit_syntax_level': set_jit_syntax_level, 634 'debug_level': set_debug_level, 635 'gpu_config': set_gpu_config, 636 'aoe_config': set_aoe_config, 637 'jit_config': set_jit_config, 638 } 639 640 @property 641 def reserve_class_name_in_scope(self): 642 """Get whether to save the network class name in the scope.""" 643 return self._thread_local_info.reserve_class_name_in_scope 644 645 @reserve_class_name_in_scope.setter 646 def reserve_class_name_in_scope(self, reserve_class_name_in_scope): 647 """Set whether to save the network class name in the scope.""" 648 if not isinstance(reserve_class_name_in_scope, bool): 649 raise ValueError("For 'context.set_context', the type of the property 'reserve_class_name_in_scope' must " 650 "be bool, but got {}.".format(type(reserve_class_name_in_scope))) 651 self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope 652 653 @property 654 def enable_ge(self): 655 return self._context_handle.get_backend_policy() == 'ge' 656 657 @property 658 def enable_debug_runtime(self): 659 return self._thread_local_info.debug_runtime 660 661 @enable_debug_runtime.setter 662 def enable_debug_runtime(self, enable): 663 thread_info = self._thread_local_info 664 thread_info.debug_runtime = enable 665 666 @property 667 def support_binary(self): 668 """Whether support run .pyc or .so in graph mode.""" 669 return self._support_binary 670 671 @support_binary.setter 672 def support_binary(self, support: bool): 673 if not isinstance(support, bool): 674 raise TypeError(f"The attribute 'support_binary' should be a bool, but got {type(support)}.") 675 self._support_binary = support 676 677 def _get_ascend_config_setter(self, ascend_key, trans_fn=None): 678 def _config_setter(ascend_value): 679 self.set_param(ms_ctx_param.__members__[ascend_key], trans_fn(ascend_value)) 680 681 if trans_fn is None: 682 trans_fn = lambda x: x 683 return _config_setter 684 685 def _set_op_debug_option(self, option_value): 686 valid_order = {'oom'} 687 if not isinstance(option_value, str): 688 raise TypeError(f"For 'ascend_config', the type of 'op_debug_option' must be str, " 689 f"but got {type(option_value)}.") 690 if option_value not in valid_order: 691 raise ValueError(f"For 'ascend_config', the 'op_debug_option' supports being set to 'oom' currently, " 692 f"but got {option_value}.") 693 self.set_param(ms_ctx_param.op_debug_option, option_value) 694 695 def _set_op_precision_mode(self, ascend_value): 696 op_precision_path = ascend_value 697 real_path = os.path.realpath(op_precision_path) 698 if not os.path.exists(real_path): 699 raise ValueError(f"For 'ascend_config', the 'op_precision_mode' is invalid path, " 700 f"got '{op_precision_path}'.") 701 self.set_param(ms_ctx_param.op_precision_mode, ascend_value) 702 703 def _set_ge_options(self, ge_options): 704 """Set ge options.""" 705 for level, options in ge_options.items(): 706 if level not in ['global', 'session']: 707 raise ValueError(f"For 'ascend_config', the key of ge_options must be one of " 708 f"('global', 'session'), but got {level}.") 709 710 if not isinstance(options, dict): 711 raise TypeError(f"For 'ge_options', the type of {level} options must be dict, " 712 f"but got {type(options)}. The error options: {options}.") 713 714 for key, value in options.items(): 715 if not isinstance(key, str): 716 raise TypeError(f"For 'ge_options', the type of key and value must be str, " 717 f"but got {type(key)}. The error key is {key}.") 718 if not isinstance(value, str): 719 raise TypeError(f"For 'ge_options', the type of key and value must be str, " 720 f"but got {type(value)}. The error value is {value}") 721 722 options_str = json.dumps(ge_options) 723 self.set_param(ms_ctx_param.ge_options, options_str) 724 725 def _set_topo_order(self, topo_order): 726 """ 727 Set topo order. 728 729 Args: 730 topo_order (dict): 731 key: str, the name of the graph. 732 value: str, the topo order of the graph, should be one of 'dfs', 'bfs', 'rdfs'. 733 """ 734 valid_order = {'dfs', 'bfs', 'rdfs'} 735 if not isinstance(topo_order, dict): 736 raise TypeError(f"For 'ascend_config', the 'topo_order' should be a dict, " 737 f"got '{type(topo_order)}'.") 738 for k, v in topo_order.items(): 739 if not isinstance(k, str): 740 raise TypeError("key {} is not a str".format(k)) 741 if v not in valid_order: 742 raise ValueError("value {} should be one of {}.".format(v, valid_order)) 743 744 options_str = json.dumps(topo_order) 745 self.set_param(ms_ctx_param.topo_order, options_str) 746 747 def _set_need_ckpt(self, need_ckpt): 748 """Set need ckpt flag""" 749 if not isinstance(need_ckpt, bool): 750 raise TypeError(f"For step num, the value type should be int, but got {type(need_ckpt)}, {need_ckpt}") 751 self.set_param(ms_ctx_param.need_ckpt, need_ckpt) 752 753 def _set_cur_step_num(self, step_num): 754 """set current step num at every step begin""" 755 if not isinstance(step_num, int): 756 raise TypeError(f"For step num, the value type should be int, but got {type(step_num)}, {step_num}") 757 self.set_param(ms_ctx_param.cur_step_num, step_num) 758 759 def _set_save_checkpoint_steps(self, steps): 760 """set save checkpoint steps before run""" 761 if not isinstance(steps, int): 762 raise TypeError(f"For step num, the value type should be int, but got {type(steps)}, {steps}") 763 self.set_param(ms_ctx_param.save_checkpoint_steps, steps) 764 765 def _set_last_triggered_step(self, step): 766 """set last triggered save ckpt steps before run""" 767 if not isinstance(step, int): 768 raise TypeError(f"For step num, the value type should be int, but got {type(step)}, {step}") 769 self.set_param(ms_ctx_param.last_triggered_step, step) 770 771 def _set_speedup_config_path(self, speedup_config_path): 772 """"Check and set speedup config for auto parallel.""" 773 if speedup_config_path is None or speedup_config_path == "": 774 return 775 speedup_config_real_path = os.path.abspath(speedup_config_path) 776 if not os.path.exists(speedup_config_real_path): 777 raise ValueError(f"For 'ascend_config', the path to parallel_speed_up_json: " 778 f"{speedup_config_real_path} does not exist, please check whether the " 779 f"'parallel_speed_up_json_path' is correct.") 780 try: 781 valid_option = {"recompute_comm_overlap": (ms_ctx_param.recompute_comm_overlap, bool), 782 "matmul_grad_comm_overlap": (ms_ctx_param.matmul_grad_comm_overlap, bool), 783 "enable_task_opt": (ms_ctx_param.enable_task_opt, bool), 784 "enable_grad_comm_opt": (ms_ctx_param.enable_grad_comm_opt, bool), 785 "recompute_allgather_overlap_fagrad": 786 (ms_ctx_param.recompute_allgather_overlap_fagrad, bool), 787 "interleaved_matmul_comm": (ms_ctx_param.interleaved_matmul_comm, bool), 788 "bias_add_comm_swap": (ms_ctx_param.bias_add_comm_swap, bool), 789 "enable_opt_shard_comm_opt": (ms_ctx_param.enable_opt_shard_comm_opt, bool), 790 "enable_begin_end_inline_opt": (ms_ctx_param.enable_begin_end_inline_opt, bool), 791 "enable_concat_eliminate_opt": (ms_ctx_param.enable_concat_eliminate_opt, bool), 792 "interleaved_layernorm_comm": (ms_ctx_param.interleaved_layernorm_comm, bool), 793 "compute_communicate_fusion_level": 794 (ms_ctx_param.compute_communicate_fusion_level, int), 795 "enable_flash_attention_load_balance": 796 (ms_ctx_param.enable_flash_attention_load_balance, bool)} 797 with open(speedup_config_real_path, 'r') as f: 798 speedup_config = json.load(f) 799 for key, value in speedup_config.items(): 800 if not isinstance(key, str): 801 raise TypeError("key {} is not a str".format(key)) 802 if key not in valid_option: 803 raise ValueError("key {} should be one of {}.".format(key, valid_option.keys())) 804 set_func, valid_type = valid_option.get(key) 805 if not isinstance(value, valid_type): 806 raise TypeError(f"The value type of {key} must be {valid_type}, " 807 f"but got value is {value} and type is {type(value)}.") 808 self.set_param(set_func, value) 809 except (TypeError, ValueError) as exo: 810 raise ValueError(str(exo) + "\nFor 'context.set_context', " 811 "open or load the 'speedup_config_path' file {} " 812 "failed, please check whether 'speedup_config_path' is json file and correct, " 813 "or may not have permission to read it.".format(speedup_config_real_path)) \ 814 from exo 815 816 817def _context(): 818 """ 819 Get the global _context, if context is not created, create a new one. 820 821 Returns: 822 _Context, the global context in PyNative mode. 823 """ 824 global K_CONTEXT 825 if K_CONTEXT is None: 826 default_backend = 'debug' 827 try: 828 from mindspore import default_config 829 default_backend = default_config.__backend__ 830 except ImportError: 831 logger.error("import default config fail") 832 K_CONTEXT = _Context() 833 K_CONTEXT.enable_debug_runtime = False 834 if default_backend == 'debug': 835 K_CONTEXT.enable_debug_runtime = True 836 default_backend = 'vm' 837 K_CONTEXT.set_backend_policy(default_backend) 838 return K_CONTEXT 839 840 841@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str, 842 auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, 843 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, enable_alltoall=bool, 844 all_reduce_fusion_config=list, pipeline_stages=int, pipeline_segments=int, 845 pipeline_result_broadcast=bool, parallel_optimizer_config=dict, 846 pipeline_config=dict, 847 comm_fusion=dict, strategy_ckpt_config=dict, force_fp32_communication=bool) 848def set_auto_parallel_context(**kwargs): 849 r""" 850 Set auto parallel context, only data parallel supported on CPU. 851 852 Note: 853 Attribute name is required for setting attributes. 854 If a program has tasks on different parallel modes, before setting a new parallel mode for the 855 next task, interface :func:`mindspore.reset_auto_parallel_context` should be called to reset 856 the configuration. 857 Setting or changing parallel modes must be called before creating any Initializer, otherwise, 858 it may have RuntimeError when compiling the network. 859 860 Some configurations are parallel mode specific, see the below table for details: 861 862 =========================== =========================== 863 Common AUTO_PARALLEL 864 =========================== =========================== 865 device_num gradient_fp32_sync 866 global_rank loss_repeated_mean 867 gradients_mean search_mode 868 parallel_mode parameter_broadcast 869 all_reduce_fusion_config strategy_ckpt_load_file 870 enable_parallel_optimizer strategy_ckpt_save_file 871 parallel_optimizer_config dataset_strategy 872 enable_alltoall pipeline_stages 873 pipeline_config auto_parallel_search_mode 874 force_fp32_communication pipeline_result_broadcast 875 \ comm_fusion 876 \ strategy_ckpt_config 877 \ group_ckpt_save_file 878 \ auto_pipeline 879 =========================== =========================== 880 881 Args: 882 device_num (int): Available device number, the value must be in [1, 4096]. Default: ``1`` . 883 global_rank (int): Global rank id, the value must be in [0, 4095]. Default: ``0`` . 884 gradients_mean (bool): Whether to perform mean operator after allreduce of gradients. 885 "stand_alone" do not support gradients_mean. Default: ``False`` . 886 gradient_fp32_sync (bool): Run allreduce of gradients in fp32. "stand_alone", "data_parallel" 887 and "hybrid_parallel" do not support gradient_fp32_sync. Default: ``True`` . 888 loss_repeated_mean (bool) - Indicates whether the mean operator is executed backwards when the 889 calculation is repeated. Default: ``True`` . 890 parallel_mode (str): There are five kinds of parallel modes, ``"stand_alone"`` , ``"data_parallel"`` , 891 ``"hybrid_parallel"`` , ``"semi_auto_parallel"`` and ``"auto_parallel"`` . Note the pynative mode 892 only supports the ``"stand_alone"`` and ``"data_parallel"`` mode. Default: ``"stand_alone"`` . 893 894 - stand_alone: Only one processor is working. 895 896 - data_parallel: Distributes the data across different processors. 897 898 - hybrid_parallel: Achieves data parallelism and model parallelism manually. 899 900 - semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies. 901 902 - auto_parallel: Achieving parallelism automatically. 903 search_mode (str): There are three kinds of shard strategy search modes: ``"recursive_programming"`` , 904 ``"sharding_propagation"`` and ``"dynamic_programming"`` (Not recommended). 905 Default: ``"recursive_programming"`` . 906 907 - recursive_programming: Recursive programming search mode. In order to obtain optimal performance, 908 it is recommended that users set the batch size to be greater than or equal to the product of 909 the number of devices and the number of multi-copy parallelism. 910 911 - sharding_propagation: Propagate shardings from configured ops to non-configured ops. 912 913 - dynamic_programming: Dynamic programming search mode. 914 auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is 915 for forward compatibility, and this attribute will be deleted in a future MindSpore version. 916 parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have 917 the same network initialization parameter values for all devices, broadcast the parameters 918 on device 0 to other devices. Parameter broadcasting in different parallel modes is different, 919 ``data_parallel`` mode, all parameters are broadcast except for the parameter whose attribute 920 layerwise_parallel is ``True`` . ``Hybrid_parallel`` , ``semi_auto_parallel`` and 921 ``auto_parallel mode`` , the segmented parameters do not participate in broadcasting. 922 Default: ``False`` . 923 strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. The parameter is not to be 924 recommended currently, it is better using 'strategy_ckpt_config' to replace it. Default: ``''`` 925 strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. The parameter is not to be 926 recommended currently, it is better using 'strategy_ckpt_config' to replace it. Default: ``''`` 927 full_batch (bool): If you load whole batch datasets in ``auto_parallel`` mode, this parameter 928 should be set as ``True`` . Default: ``False`` . The interface is not to be recommended 929 currently, it is better using 'dataset_strategy' to replace it. 930 dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: ``"data_parallel"`` . 931 dataset_strategy="data_parallel" is equal to full_batch=False, dataset_strategy="full_batch" is 932 equal to full_batch=True. For execution mode is 'GRAPH_MODE' and dataset load into net by model 933 parallel strategy likes ds_stra ((1, 8), (1, 8)), it requires using 934 set_auto_parallel_context(dataset_strategy=ds_stra). 935 enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for 936 data parallel training in the benefit of time and memory saving. Currently, auto and semi auto 937 parallel mode support all optimizers in both Ascend and GPU. Data parallel mode only supports 938 `Lamb` and `AdamWeightDecay` in Ascend . Default: ``False`` . 939 force_fp32_communication (bool): A switch that determines whether reduce operators (AllReduce, ReduceScatter) 940 are forced to use the fp32 data type for communication during communication. True is the enable 941 switch. Default: ``False`` . 942 enable_alltoall (bool): A switch that allows AllToAll operators to be generated during communication. If its 943 value is ``False`` , there will be a combination of operators such as AllGather, Split and 944 Concat instead of AllToAll. Default: ``False`` . 945 all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM 946 and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed. 947 pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how the devices are 948 distributed alone in the pipeline. The total devices will be divided into 'pipeline_stags' 949 stages. 950 Default: ``1`` . 951 pipeline_result_broadcast (bool): A switch that broadcast the last stage result to all other stage in pipeline 952 parallel inference. Default: ``False`` . 953 pipeline_config (dict): A dict contains the keys and values for setting the pipeline parallelism configuration. 954 It supports the following keys: 955 956 - pipeline_interleave(bool): Indicates whether to enable the interleaved execution mode. 957 - pipeline_scheduler(str): Indicates the scheduling mode for pipeline parallelism. Only support 958 ``gpipe/1f1b``. 959 parallel_optimizer_config (dict): A dict contains the keys and values for setting the parallel optimizer 960 configure. The configure provides more detailed behavior control about parallel training 961 when parallel optimizer is enabled. The configure will be effective when we use 962 mindspore.set_auto_parallel_context(enable_parallel_optimizer=True). 963 It supports the following keys. 964 965 - gradient_accumulation_shard(bool): If ``true`` , the accumulation gradient parameters will be 966 sharded across the data parallel devices. This will 967 introduce additional communication(ReduceScatter) at 968 each step when accumulate the gradients, but saves a 969 lot of device memories, thus can make model be trained 970 with larger batch size. This configure is effective only 971 when the model runs on pipeline training or gradient 972 accumulation with data parallel. Default ``False`` . 973 974 - parallel_optimizer_threshold(int): Set the threshold of parallel optimizer. When parallel 975 optimizer is enabled, parameters with size smaller than this threshold will not be sharded 976 across the devices. Parameter size = shape[0] \* ... \* shape[n] \* size(dtype). Non-negative. 977 Unit: KB. Default: ``64`` . 978 979 - optimizer_weight_shard_size(int): Set the optimizer weight shard group size, if you want to 980 specific the maximum group size across devices when the parallel optimizer is enabled. 981 The numerical range can be (0, device_num]. If pipeline parallel is enabled, the numerical 982 range is (0, device_num/stage]. If the size of data parallel communication domain 983 of the parameter cannot be divided by `optimizer_weight_shard_size`, then the specified 984 communication group size will not take effect. Default value is ``-1`` , which means the 985 optimizer weight shard group size will be the size of data parallel group of each parameter. 986 987 comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. each 988 communication fusion config has two keys: "mode" and "config". 989 It supports following communication fusion types and configurations: 990 991 - openstate: Whether turn on the communication fusion or not. If `openstate` is ``True`` , 992 turn on the communication fusion, otherwise, turn off the communication fusion. 993 Default: ``True`` . 994 995 - allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size` 996 and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default 997 fusion threshold is `64` MB. In 'size' mode, AllReduce fusion is configured by gradients size 998 manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as 999 `all_reduce_fusion_config`. 1000 1001 - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`. 1002 In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion 1003 threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size 1004 manually, and the fusion threshold must be larger than `0` MB. 1005 1006 - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto` 1007 and `size`. Config is same as `allgather`. 1008 1009 strategy_ckpt_config (dict): A dict contains the configurations for setting the parallel strategy file. This 1010 interface contains the functions of parameter `strategy_ckpt_load_file` and 1011 `strategy_ckpt_save_file`, it is recommonded to use this parameter to replace those two 1012 parameters. 1013 It contains following configurations: 1014 1015 - load_file (str): The path to load parallel strategy checkpoint. If the file name extension is 1016 `.json`, the file is loaded in JSON format. Otherwise, the file is loaded in ProtoBuf 1017 format. 1018 Default: ``''`` 1019 1020 - save_file (str): The path to save parallel strategy checkpoint. If the file name extension is 1021 `.json`, the file is saved in JSON format. Otherwise, the file is saved in ProtoBuf format. 1022 Default: ``''`` 1023 1024 - only_trainable_params (bool): Only save/load the strategy information for trainable parameter. 1025 Default: ``True`` . 1026 group_ckpt_save_file (str): The path to save parallel group checkpoint. 1027 auto_pipeline (bool): Set the pipeline stage number to automatic. Its value will be selected between 1 and the 1028 parameter `pipeline_stages`. This option requires the `parallel_mode` to be ``auto_parallel`` 1029 and the `search_mode` to be ``recursive_programming``. Default: ``False`` . 1030 1031 Raises: 1032 ValueError: If input key is not attribute in auto parallel context. 1033 1034 Examples: 1035 >>> import mindspore as ms 1036 >>> ms.set_auto_parallel_context(device_num=8) 1037 >>> ms.set_auto_parallel_context(global_rank=0) 1038 >>> ms.set_auto_parallel_context(gradients_mean=True) 1039 >>> ms.set_auto_parallel_context(gradient_fp32_sync=False) 1040 >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel") 1041 >>> ms.set_auto_parallel_context(search_mode="recursive_programming") 1042 >>> ms.set_auto_parallel_context(auto_parallel_search_mode="recursive_programming") 1043 >>> ms.set_auto_parallel_context(parameter_broadcast=False) 1044 >>> ms.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") 1045 >>> ms.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") 1046 >>> ms.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8))) 1047 >>> ms.set_auto_parallel_context(enable_parallel_optimizer=False) 1048 >>> ms.set_auto_parallel_context(enable_alltoall=False) 1049 >>> ms.set_auto_parallel_context(all_reduce_fusion_config=[8, 160]) 1050 >>> ms.set_auto_parallel_context(pipeline_stages=2) 1051 >>> ms.set_auto_parallel_context(pipeline_stages=2, pipeline_result_broadcast=True) 1052 >>> parallel_config = {"gradient_accumulation_shard": True, "parallel_optimizer_threshold": 24, 1053 ... "optimizer_weight_shard_size": 2} 1054 >>> ms.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True) 1055 >>> config = {"allreduce": {"mode": "size", "config": 32}, "allgather": {"mode": "size", "config": 32}} 1056 >>> ms.set_auto_parallel_context(comm_fusion=config) 1057 >>> stra_ckpt_dict = {"load_file": "./stra0.ckpt", "save_file": "./stra1.ckpt", "only_trainable_params": False} 1058 >>> ms.set_auto_parallel_context(strategy_ckpt_config=stra_ckpt_dict) 1059 """ 1060 _set_auto_parallel_context(**kwargs) 1061 1062 1063def get_auto_parallel_context(attr_key): 1064 """ 1065 Get auto parallel context attribute value according to the key. 1066 1067 Args: 1068 attr_key (str): The key of the attribute. 1069 1070 Returns: 1071 Returns attribute value according to the key. 1072 1073 Raises: 1074 ValueError: If input key is not attribute in auto parallel context. 1075 1076 Examples: 1077 >>> import mindspore as ms 1078 >>> parallel_mode = ms.get_auto_parallel_context("parallel_mode") 1079 >>> dataset_strategy = ms.get_auto_parallel_context("dataset_strategy") 1080 """ 1081 return _get_auto_parallel_context(attr_key) 1082 1083 1084def reset_auto_parallel_context(): 1085 """ 1086 Reset auto parallel context attributes to the default values. 1087 1088 - device_num: 1. 1089 - global_rank: 0. 1090 - gradients_mean: False. 1091 - gradient_fp32_sync: True. 1092 - parallel_mode: 'stand_alone'. 1093 - search_mode: 'recursive_programming'. 1094 - auto_parallel_search_mode: 'recursive_programming'. 1095 - parameter_broadcast: False. 1096 - strategy_ckpt_load_file: ''. 1097 - strategy_ckpt_save_file: ''. 1098 - full_batch: False. 1099 - enable_parallel_optimizer: False. 1100 - force_fp32_communication: False 1101 - enable_alltoall: False. 1102 - pipeline_stages: 1. 1103 - pipeline_result_broadcast: False. 1104 - fusion_threshold: 64. 1105 - auto_pipeline: False. 1106 1107 Examples: 1108 >>> import mindspore as ms 1109 >>> ms.reset_auto_parallel_context() 1110 """ 1111 _reset_auto_parallel_context() 1112 1113 1114@args_type_check(offload_config=dict) 1115def set_offload_context(offload_config): 1116 r""" 1117 Configure heterogeneous training detailed parameters to adjust the offload strategy. 1118 1119 Note: 1120 The offload configuration is only used if the memory offload feature is enabled 1121 via mindspore.set_context(memory_offload="ON"). 1122 1123 Args: 1124 offload_config (dict): A dict contains the keys and values for setting the offload context 1125 configure.It supports the following keys. 1126 1127 - offload_path (str): The path of offload, relative paths are supported. Default: ``"./offload"``. 1128 - offload_cpu_size (str): The cpu memory size for offload. The format is "xxGB". 1129 - offload_disk_size (str): The disk size for offload. The format is "xxGB" 1130 - hbm_ratio (float): The ratio that can be used based on the maximum device memory. 1131 The range is (0,1], Default: ``1.0``. 1132 - cpu_ratio (float): The ratio that can be used based on the maximum host memory. 1133 The range is (0,1], Default: ``1.0``. 1134 - enable_pinned_mem (bool): The flag of whether enabling Pinned Memory. Default: ``True``. 1135 - enable_aio (bool): The flag of whether enabling aio. Default: ``True``. 1136 - aio_block_size (str): The size of aio block. The format is "xxGB". 1137 - aio_queue_depth (int): The depth of aio queue. 1138 - offload_param (str): The param for offload destination, cpu or disk, Default: ``""``. 1139 - offload_checkpoint (str): The checkpoint for offload destination, only valid if recompute is turned on, 1140 cpu or disk, Default: ``""``. 1141 - auto_offload (bool): The flag of whether auto offload. Default: ``True``. 1142 - host_mem_block_size (str): The memory block size of host memory pool. The format is "xxGB" 1143 1144 Raises: 1145 ValueError: If input key is not attribute in auto parallel context. 1146 1147 Examples: 1148 >>> from mindspore import context 1149 >>> context.set_offload_context(offload_config={"offload_param":"cpu"}) 1150 """ 1151 _set_offload_context(offload_config) 1152 1153 1154def get_offload_context(): 1155 """ 1156 Gets the offload configuration parameters. Configure through interface mindspore.set_offload_context(). 1157 If the user is not set, the default configuration is obtained. 1158 1159 Returns: 1160 Dict, heterogeneous training offload detailed configuration parameters. 1161 1162 Examples: 1163 >>> from mindspore import context 1164 >>> offload_config = context.get_offload_context() 1165 """ 1166 return _get_offload_context() 1167 1168 1169def _check_target_specific_cfgs(device, arg_key): 1170 """Checking whether a config is suitable for a specified device""" 1171 device_cfgs = { 1172 'enable_graph_kernel': ['Ascend', 'GPU', 'CPU'], 1173 'graph_kernel_flags': ['Ascend', 'GPU', 'CPU'], 1174 'enable_reduce_precision': ['Ascend'], 1175 'print_file_path': ['Ascend'], 1176 'variable_memory_max_size': ['Ascend'], 1177 'max_device_memory': ['Ascend', 'GPU'], 1178 'mempool_block_size': ['GPU', 'Ascend'], 1179 'disable_format_transform': ['GPU'], 1180 'ascend_config': ['Ascend'], 1181 'gpu_config': ['GPU'], 1182 } 1183 # configs not in map device_cfgs are supposed to be suitable for all devices 1184 if arg_key not in device_cfgs: 1185 return True 1186 supported_devices = device_cfgs[arg_key] 1187 if device in supported_devices: 1188 return True 1189 logger.warning(f"For 'context.set_context', when set the argument '{arg_key}', " 1190 f"the argument 'device_target' only supports devices in '{supported_devices}', " 1191 f"but got '{device}', ignore it.") 1192 return False 1193 1194 1195def _check_ascend_device_context_initialized(device_target, settings): 1196 if device_target == 'Ascend' and is_initialized(device_target): 1197 for key, _ in settings.items(): 1198 if key in ('ascend_config', 'deterministic', 'jit_compile', 'exception_dump', 'device_id'): 1199 logger.warning(f"For 'context.set_context' in Ascend backend, the backend is already initialized, " 1200 "please set it before the definition of any Tensor and Parameter, and the " 1201 "instantiation and execution of any operation and net, otherwise the settings may not " 1202 "take effect. ") 1203 break 1204 1205 1206def _check_key(key): 1207 if key in ('precision_mode', 'jit_compile', 'atomic_clean_policy', 'matmul_allow_hf32', 'conv_allow_hf32', 1208 'op_precision_mode', 'host_scheduling_max_threshold', 'ge_options', 'op_debug_option'): 1209 raise ValueError(f"Please set '{key}' through parameter ascend_config") 1210 1211 1212@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=(bool, int), 1213 save_graphs_path=str, enable_dump=bool, aoe_tune_mode=str, aoe_config=dict, 1214 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, 1215 enable_auto_mixed_precision=bool, inter_op_parallel_num=int, 1216 enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool, 1217 max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str, 1218 graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool, 1219 grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, disable_format_transform=bool, 1220 op_timeout=int, deterministic=str, ascend_config=dict, jit_syntax_level=int, debug_level=int, 1221 jit_enable_inplace_ops=bool, gpu_config=dict, jit_config=dict, enable_compile_cache=bool) 1222def set_context(**kwargs): 1223 """ 1224 Set context for running environment. 1225 1226 Context should be configured before running your program. If there is no configuration, 1227 it will be automatically set according to the device target by default. 1228 1229 Note: 1230 Attribute name is required for setting attributes. 1231 The mode is not recommended to be changed after net was initialized because the implementations of some 1232 operations are different in graph mode and pynative mode. Default: ``PYNATIVE_MODE`` . 1233 1234 Some configurations are device specific, see the below table for details: 1235 1236 +-------------------------+------------------------------+----------------------------+ 1237 | Function Classification | Configuration Parameters | Hardware Platform Support| 1238 +=========================+==============================+============================+ 1239 | System Configuration | device_id | CPU/GPU/Ascend | 1240 | +------------------------------+----------------------------+ 1241 | | device_target | CPU/GPU/Ascend | 1242 | +------------------------------+----------------------------+ 1243 | | max_device_memory | GPU/Ascend | 1244 | +------------------------------+----------------------------+ 1245 | | variable_memory_max_size | Ascend | 1246 | +------------------------------+----------------------------+ 1247 | | mempool_block_size | GPU/Ascend | 1248 | +------------------------------+----------------------------+ 1249 | | op_timeout | Ascend | 1250 +-------------------------+------------------------------+----------------------------+ 1251 | Debug Configuration | save_graphs | CPU/GPU/Ascend | 1252 | +------------------------------+----------------------------+ 1253 | | save_graphs_path | CPU/GPU/Ascend | 1254 | +------------------------------+----------------------------+ 1255 | | enable_dump | Ascend | 1256 | +------------------------------+----------------------------+ 1257 | | save_dump_path | Ascend | 1258 | +------------------------------+----------------------------+ 1259 | | deterministic | Ascend | 1260 | +------------------------------+----------------------------+ 1261 | | print_file_path | Ascend | 1262 | +------------------------------+----------------------------+ 1263 | | env_config_path | CPU/GPU/Ascend | 1264 | +------------------------------+----------------------------+ 1265 | | precompile_only | CPU/GPU/Ascend | 1266 | +------------------------------+----------------------------+ 1267 | | reserve_class_name_in_scope | CPU/GPU/Ascend | 1268 | +------------------------------+----------------------------+ 1269 | | pynative_synchronize | CPU/GPU/Ascend | 1270 | +------------------------------+----------------------------+ 1271 | | debug_level | CPU/GPU/Ascend | 1272 +-------------------------+------------------------------+----------------------------+ 1273 | Executive Control | mode | CPU/GPU/Ascend | 1274 | +------------------------------+----------------------------+ 1275 | | enable_graph_kernel | Ascend/GPU | 1276 | +------------------------------+----------------------------+ 1277 | | graph_kernel_flags | Ascend/GPU | 1278 | +------------------------------+----------------------------+ 1279 | | enable_reduce_precision | Ascend | 1280 | +------------------------------+----------------------------+ 1281 | | aoe_tune_mode | Ascend | 1282 | +------------------------------+----------------------------+ 1283 | | aoe_config | Ascend | 1284 | +------------------------------+----------------------------+ 1285 | | check_bprop | CPU/GPU/Ascend | 1286 | +------------------------------+----------------------------+ 1287 | | max_call_depth | CPU/GPU/Ascend | 1288 | +------------------------------+----------------------------+ 1289 | | grad_for_scalar | CPU/GPU/Ascend | 1290 | +------------------------------+----------------------------+ 1291 | | enable_compile_cache | CPU/GPU/Ascend | 1292 | +------------------------------+----------------------------+ 1293 | | inter_op_parallel_num | CPU/GPU/Ascend | 1294 | +------------------------------+----------------------------+ 1295 | | runtime_num_threads | CPU/GPU/Ascend | 1296 | +------------------------------+----------------------------+ 1297 | | compile_cache_path | CPU/GPU/Ascend | 1298 | +------------------------------+----------------------------+ 1299 | | disable_format_transform | GPU | 1300 | +------------------------------+----------------------------+ 1301 | | support_binary | CPU/GPU/Ascend | 1302 | +------------------------------+----------------------------+ 1303 | | memory_optimize_level | CPU/GPU/Ascend | 1304 | +------------------------------+----------------------------+ 1305 | | memory_offload | GPU/Ascend | 1306 | +------------------------------+----------------------------+ 1307 | | ascend_config | Ascend | 1308 | +------------------------------+----------------------------+ 1309 | | jit_syntax_level | CPU/GPU/Ascend | 1310 | +------------------------------+----------------------------+ 1311 | | gpu_config | GPU | 1312 | +------------------------------+----------------------------+ 1313 | | jit_config | CPU/GPU/Ascend | 1314 +-------------------------+------------------------------+----------------------------+ 1315 1316 Args: 1317 device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1], 1318 while device_num_per_host should be no more than 4096. Default: ``0`` . 1319 device_target (str): The target device to run, support "Ascend", "GPU", and "CPU". 1320 If device target is not set, the version of MindSpore package is used. 1321 max_device_memory (str): Set the maximum memory available for devices. The format is "xxGB". 1322 Default: ``" 1024GB"`` . The actual used memory size is the minimum of the available memory of the device 1323 and max_device_memory. 'max_device_memory' should be set before the program runs. 1324 variable_memory_max_size (str): This parameter is deprecated, and will be removed in a future version. 1325 Please use parameter 'max_device_memory' instead. 1326 mempool_block_size (str): Set the size of the memory pool block in PyNative mode or jit level is 'O0'/'O1' 1327 for devices. The format is "xxGB". Default: ``"1GB"`` . Minimum size is "1G". The actual used memory block 1328 size is the minimum of the available memory of the device and mempool_block_size. 1329 op_timeout (int): Set the maximum duration of executing an operator in seconds. 1330 If the execution time exceeds this value, system will terminate the task. 1331 0 means endless wait. The defaults for AI Core and AICPU operators vary on different hardware. 1332 For more information, 1333 please refer to `Ascend Community document about aclrtSetOpExecuteTimeOut 1334 <https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/infacldevg/aclcppdevg/aclcppdevg_03_0069.html>`_. 1335 Default: ``900`` . 1336 save_graphs (bool or int): Whether to save intermediate compilation graphs. Default: ``0`` . 1337 Available values are: 1338 1339 - False or 0: disable saving of intermediate compilation graphs. 1340 - 1: some intermediate files will be generated during graph compilation. 1341 - True or 2: Generate more ir files related to backend process. 1342 - 3: Generate visualization computing graphs and detailed frontend ir graphs. 1343 1344 When the network structure is complex, setting `save_graphs` attribute to ``2`` or ``3`` may take too long. 1345 If you need quick problem locating, you can switch to ``1`` first. 1346 1347 When the `save_graphs` attribute is set as ``True`` , ``1`` , ``2`` or ``3`` , attribute of 1348 `save_graphs_path` is used to set the intermediate compilation graph storage path. By default, the graphs 1349 are saved in the current directory. 1350 save_graphs_path (str): Path to save graphs. Default: ``"."``. 1351 If the specified directory does not exist, the system will automatically create the directory. 1352 During distributed training, graphs will be saved to the directory of 1353 `save_graphs_path/rank_${rank_id}/`. `rank_id` is the ID of the current device in the cluster. 1354 deterministic (str): Whether to enable op run in deterministic mode. The value must be in the 1355 range of ['ON', 'OFF'], and the default value is ``'OFF'`` . 1356 1357 - "ON": Enable operator deterministic running mode. 1358 - "OFF": Disable operator deterministic running mode. 1359 1360 When deterministic mode is on, model ops will be deterministic in Ascend. This means that if op run 1361 multiple times with the same inputs on the same hardware, it will have the exact same outputs each time. 1362 This is useful for debugging models. 1363 enable_dump (bool): This parameters is deprecated, and will be deleted in the next version. 1364 save_dump_path (str): This parameters is deprecated, and will be deleted in the next version. 1365 print_file_path (str): The path of saving print data. If this parameter is set, print data is saved to 1366 a file by default, and print_file_path is not set, the screen will be displayed. 1367 If the saved file already exists, the timestamp suffix will be added to the file. Saving data to a file 1368 solves the problem of data loss in screen printing when a large amount of data is generated. 1369 If it is not set, an error will be reported: prompt to set the upper absolute path. 1370 When print data to file, the total output bytes of single print must be less then 2GB(limited by 1371 protobuf). 1372 env_config_path (str): Config path for DFX. 1373 Through mindspore.set_context(env_config_path="./mindspore_config.json") 1374 1375 configure RDR: 1376 1377 - enable: controls whether the RDR is enabled to collect the key data during training and 1378 save key data in the fault scenario. When set to ``true`` , the RDR will be turned on. 1379 When set to ``false`` , the RDR will be turned off. 1380 - mode: sets the mode of RDR on exporting data. When set to ``1`` , the RDR only exports data 1381 in the fault scenario. When set to ``2`` , the RDR exports data in the fault scenario and the 1382 normal end scenario. Default: ``1`` . 1383 - path: sets the path where RDR saves data. The current path must be absolute. 1384 1385 Memory reuse: 1386 1387 - mem_Reuse: controls whether the memory reuse function is turned on. When set to ``True`` , 1388 the memory reuse function is turned on. When set to ``False`` , the memory reuse function is turned off. 1389 1390 precompile_only (bool): Whether to only precompile the network. Default: ``False`` . 1391 If set to ``True`` , the network will only be compiled, not executed. 1392 reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: ``True`` . 1393 Each node has a scope. A scope of a subnode is the name of its parent node. If reserve_class_name_in_scope 1394 is set to ``True`` , the class name will be saved after keyword 'net-' in the scope. 1395 For example: 1396 1397 Default/net-Net1/net-Net2 (reserve_class_name_in_scope=True) 1398 1399 Default/net/net (reserve_class_name_in_scope=False) 1400 1401 pynative_synchronize (bool): Whether to enable synchronous execution of the device in PyNative mode. 1402 Default: ``False`` . When the value is set to ``False`` , the operator is executed asynchronously on the 1403 device. When an error occurs in the execution of the operator, the specific error script code location 1404 cannot be located, when the value is set to ``True`` , the operator is executed synchronously on the 1405 device. It will reduce the execution performance of the program. At this time, when an error occurs in the 1406 execution of the operator, the location of the error script code can be located according to the call stack 1407 of the error. 1408 mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). 1409 Both modes support all backends. Default: ``PYNATIVE_MODE`` . 1410 enable_graph_kernel (bool): Whether to enable graph kernel fusion to optimize network execution performance. 1411 Default: ``False`` . 1412 Indicates whether to enable image-computing convergence to optimize network execution performance. 1413 If enable_graph_kernel is set to ``True`` , acceleration can be enabled. 1414 For details of graph kernel fusion, please check 1415 `Enabling Graph Kernel Fusion 1416 <https://www.mindspore.cn/tutorials/experts/en/master/optimize/graph_fusion_engine.html>`_. 1417 graph_kernel_flags (str): 1418 Optimization options of graph kernel fusion, and the priority is higher when it conflicts 1419 with enable_graph_kernel. Only for experienced users. 1420 For example, 1421 1422 .. code-block:: 1423 1424 mindspore.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text") 1425 1426 Some general options: 1427 1428 - opt_level: Set the optimization level. 1429 Default: ``2`` . Graph kernel fusion can be enabled equivalently by setting opt_level greater than 0. 1430 Available values are: 1431 1432 - 0: disables graph kernel fusion; 1433 - 1: enables the basic fusion of operators; 1434 - 2: includes all optimizations of level 1, 1435 and turns on more optimizations such as CSE, arithmetic simplification and so on; 1436 - 3: includes all optimizations of level 2, and turns on more optimizations such as SitchingFusion, 1437 ParallelFusion and so on. Optimizations of this level are radical and unstable in some scenarios. 1438 Be caution when using this level. 1439 1440 - dump_as_text: dumps detail info as text files. Default: ``False`` . 1441 1442 enable_reduce_precision (bool): Whether to enable precision reduction. 1443 If the operator does not support the user-specified precision, the precision will 1444 be changed automatically. Default: ``True`` . 1445 aoe_tune_mode (str): AOE tuning mode setting, which is not set by default. 1446 When set to ``"online"`` , the tuning in online function is turned on. 1447 When set to ``"offline"`` , ge graph will be save for offline tuning. 1448 aoe_config (dict): Set the parameters specific to Ascend Optimization Engine. It is not set by default. 1449 1450 - job_type (str): Mode type setting, default value is ``"2"``. 1451 1452 - ``"1"``: subgraph tuning; 1453 - ``"2"``: operator tuning. 1454 1455 check_bprop (bool): Whether to check back propagation nodes. The checking ensures that the shape and dtype 1456 of back propagation node outputs is the same as input parameters. Default: ``False`` . 1457 max_call_depth (int): Specify the maximum depth of function call. Must be positive integer. Default: ``1000`` . 1458 The max_call_depth parameter needs to be set when the nested call is too deep or the number 1459 of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should be 1460 set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow. 1461 grad_for_scalar (bool): Whether to get gradient for scalar. Default: ``False`` . 1462 When grad_for_scalar is set to ``True`` , the function's scalar input can be derived. 1463 The default value is ``False`` . Because the back-end does not support scaling operations currently, 1464 this interface only supports simple operations that can be deduced by the front-end. 1465 enable_compile_cache (bool): Whether to save or load the cache of the graph compiled by front-end. 1466 After enable_compile_cache is set to ``True`` , during the first execution, a hardware-independent 1467 compilation cache is generated and exported to a MINDIR file. When the network is executed again, 1468 if enable_compile_cache is still set to ``True`` and the network scripts are not changed, 1469 the compile cache is loaded. Note that only limited automatic detection for the changes of 1470 python scripts is supported by now, which means that there is a correctness risk. Default: ``False`` . 1471 This is an experimental prototype that is subject to change and/or deletion. 1472 compile_cache_path (str): Path to save the compile cache. Default: ``"."``. 1473 If the specified directory does not exist, the system will automatically create the directory. 1474 The cache will be saved to the directory of `compile_cache_path/rank_${rank_id}/`. The `rank_id` is 1475 the ID of the current device in the cluster. 1476 inter_op_parallel_num(int): The thread number of op parallel at the same time. Default value is ``0`` , 1477 which means use the default num. 1478 runtime_num_threads(int): The thread pool number of cpu kernel used in runtime, 1479 which must bigger than or equal to 0. Default value is ``30`` , if you run many processes at 1480 the same time, you should set the value smaller to avoid thread contention. 1481 disable_format_transform (bool): Whether to disable the automatic format transform function from NCHW to NHWC. 1482 When the network training performance of fp16 is worse than fp32, `disable_format_transform` can be set to 1483 ``True`` to try to improve training performance. Default: ``False`` . 1484 support_binary (bool): Whether to support run .pyc or .so in graph mode. If want to support run .so or .pyc 1485 in graph mode, coulde set 'support_binary' to be ``True`` , and run once .py file. It would save the source 1486 of the interfaces would be compiled by MindSpore to the interfaces definition .py file that should be 1487 guaranteed to be writable. Then compile the .py file to the .pyc or .so file, and could run in Graph mode. 1488 memory_optimize_level (str): The memory optimize level. 1489 On Ascend hardware platform, default: ``O1``, on other hardware platforms, default: ``O0``. 1490 The value must be in ['O0', 'O1']. 1491 1492 - O0: priority performance option, disable SOMAS (Safe Optimized Memory Allocation Solver) 1493 and some other memory optimizations. 1494 - O1: priority memory option, enable SOMAS and some other memory optimizations. 1495 memory_offload (str): Whether to enable the memory offload function. When it is enabled, the idle data will be 1496 temporarily copied to the host side in the case of insufficient device memory. The value must be in the 1497 range of ['ON', 'OFF'], and the default value is ``'OFF'`` . 1498 1499 - ON: Enable the memory Offload function. On Ascend hardware platform, this parameter does not take effect 1500 when the graph compilation level is not 'O0'; This parameter does not take effect when 1501 memory_optimize_level is set 'O1'. 1502 - OFF: Turn off the memory Offload function. 1503 ascend_config (dict): Set the parameters specific to Ascend hardware platform. It is not set by default. 1504 The default value of `precision_mode`, `jit_compile` and 1505 `atomic_clean_policy` are experimental parameters, may change in the future. 1506 1507 - precision_mode (str): Mixed precision mode setting, and the default value of inference network 1508 is ``force_fp16`` . The value range is as follows: 1509 1510 - force_fp16: When the operator supports both float16 and float32, select float16 directly. 1511 - allow_fp32_to_fp16: For cube operators, use the float16. For vector operators, 1512 prefer to keep the origin dtype, if the operator in model can support float32, 1513 it will keep original dtype, otherwise it will reduce to float16. 1514 - allow_mix_precision: Automatic mixing precision, facing the whole network operator, according 1515 to the built-in optimization strategy, automatically reduces the precision of some operators 1516 to float16 or bfloat16. 1517 - must_keep_origin_dtype: Keep the accuracy of the original drawing. 1518 - force_fp32: When the input of the matrix calculation operator is float16 and the output supports 1519 float16 and float32, output is forced to float32. 1520 - allow_fp32_to_bf16: For cube operators, use the bfloat16. For vector operators, 1521 prefer to keep the origin dtype, if the operator in model can support float32, 1522 it will keep original dtype, otherwise it will reduce to bfloat16. 1523 - allow_mix_precision_fp16: Automatic mixing precision, facing the whole network operator, automatically 1524 reduces the precision of some operators to float16 according to the built-in optimization strategy. 1525 - allow_mix_precision_bf16: Automatic mixing precision, facing the whole network operator, according to 1526 the built-in optimization strategy, automatically reduces the precision of some operators to bfloat16. 1527 1528 - jit_compile (bool): Whether to select online compilation. When set to 'True', online compilation is 1529 prioritized. When set to 'False', compiled operator binary files are prioritized to improve compilation 1530 performance. The default settings are online compilation for static shape, and compiled operator binary 1531 files for dynamic shape. 1532 - atomic_clean_policy (int): The policy for cleaning memory occupied by atomic operators in the network. 1533 Default: ``1`` . 1534 1535 - 0: The memory occupied by all atomic operators in the network is cleaned centrally. 1536 - 1: Memory is not cleaned centrally and each atomic operator in the network is cleaned separately. 1537 When the memory of the network exceeds the limit, you may try this cleaning policy, but it may cause 1538 performance loss. 1539 - matmul_allow_hf32 (bool): Whether to convert FP32 to HF32 for Matmul operators. Default value: ``False``. 1540 This is an experimental prototype that is subject to change and/or deletion. 1541 For detailed information, please refer to `Ascend community <https://www.hiascend.com/>`_ . 1542 - conv_allow_hf32 (bool): Whether to convert FP32 to HF32 for Conv operators. Default value: ``True``. 1543 This is an experimental prototype that is subject to change and/or deletion. 1544 For detailed information, please refer to `Ascend community <https://www.hiascend.com/>`_ . 1545 - exception_dump (str): Enable exception dump for Ascend operators, providing the input and output data for 1546 failing Ascend operators. The value can be ``"0"`` , ``"1"`` and ``"2"``. For ``"0"`` , exception dump is 1547 turned off; for ``"1"``, all inputs and outputs will be dumped for AICore exception operators; 1548 for ``"2"``, inputs will be dumped for AICore exception operators, reducing the saved information 1549 but improving performance. Default: ``"2"`` . 1550 - op_precision_mode (str): Path to config file of op precision mode. For detailed information, please refer 1551 to `Ascend community <https://www.hiascend.com/>`_ . 1552 - op_debug_option (str): Enable debugging options for Ascend operators, default not enabled. 1553 The value currently only supports being set to ``"oom"``. 1554 1555 - ``"oom"``: When there is a memory out of bounds during the execution of an operator, 1556 AscendCL will return an error code of ``EZ9999``. 1557 1558 - ge_options (dict): Set options for CANN. The options are divided into two categories: global and session. 1559 This is an experimental prototype that is subject to change and/or deletion. 1560 For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/inferapplicationdev/graphdevg/atlasgeapi_07_0119.html>`_ . 1561 The configuration options in `ge_options` may be duplicated with the options in `ascend_config`. If the 1562 same configuration options are set in both `ascend_config` and `ge_options`, the one set in `ge_options` 1563 shall prevail. 1564 1565 - global (dict): Set global options. 1566 - session (dict): Set session options. 1567 1568 - parallel_speed_up_json_path(Union[str, None]): The path to the parallel speed up json file, configuration 1569 can refer to `parallel_speed_up.json 1570 <https://gitee.com/mindspore/mindspore/blob/master/config/parallel_speed_up.json>`_ . 1571 If its value is None or '', it does not take effect. Default None. 1572 1573 - recompute_comm_overlap (bool): Enable overlap between recompute ops and communication ops if True. 1574 Default: False. 1575 - matmul_grad_comm_overlap (bool): Enable overlap between dw matmul and 1576 tensor parallel communication ops if True. Default: False. 1577 - recompute_allgather_overlap_fagrad (bool): Enable overlap between duplicated allgather by recomputing 1578 in sequence parallel and flashattentionscoregrad ops if True. Default: False. 1579 - enable_task_opt (bool): Enable communication fusion to optimize the number of communication operator 1580 tasks if True. 1581 Default: False. 1582 - enable_grad_comm_opt (bool): Enable overlap between dx ops and data parallel communication ops if True. 1583 Currently, do not support 1584 `LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html>` 1585 Default: False. 1586 - enable_opt_shard_comm_opt (bool): Enable overlap between forward ops 1587 and optimizer parallel allgather communication if True. Currently, do not support 1588 `LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html>` 1589 Default: False. 1590 - compute_communicate_fusion_level (int): Enable the fusion between compute and communicate. 1591 Default: ``0``. 1592 1593 - 0: Disable fusion. 1594 1595 - 1: Apply fusion to forward nodes. 1596 1597 - 2: Apply fusion to backward nodes. 1598 1599 - 3: Apply fusion to all nodes. 1600 - bias_add_comm_swap (bool): Enable node execution order swap communication operators and add operators 1601 if ``True``. Only 1-dimension bias node is supported. Default: ``False``. 1602 - host_scheduling_max_threshold(int): The max threshold to control whether the dynamic shape process is 1603 used when run the static graph, the default value is 0. When the number of operations in the static graph 1604 is less than the max threshold, this graph will be executed in dynamic shape process. In large model 1605 scenarios, this approach can save stream resources. If the number of operations in the static graph is 1606 greater than the maximum threshold, this graph will be executed in original static process. 1607 1608 jit_syntax_level (int): Set JIT syntax level for graph compiling, triggered by GRAPH_MODE and @jit decorator. 1609 The value must be ``STRICT`` or ``LAX`` . Default: ``LAX`` . All levels support all backends. 1610 1611 - ``STRICT`` : Only basic syntax is supported, and execution performance is optimal. Can be used for MindIR 1612 load and export. 1613 - ``LAX`` : Compatible with all Python syntax as much as possible. However, execution performance may be 1614 affected and not optimal. Cannot be used for MindIR load and export due to some syntax that may not be 1615 able to be exported. 1616 1617 debug_level (int): Set config for debugging. Default value: ``RELEASE``. 1618 1619 - ``RELEASE``: Used for normally running, and some debug information will be discard to get a better 1620 compiling performance. 1621 - ``DEBUG``: Used for debugging when errors occur, more information will be record in compiling process. 1622 1623 gpu_config (dict): Set the parameters specific to gpu hardware platform. It is not set by default. 1624 Currently, only setting `conv_fprop_algo` and `conv_dgrad_algo` and `conv_wgrad_algo` and `conv_allow_tf32` 1625 and `matmul_allow_tf32` are supported on GPU hardware platform. 1626 1627 - conv_fprop_algo (str): Specifies convolution forward algorithm and the default value is 'normal', 1628 The value range is as follows: 1629 1630 - normal: Use the heuristic search algorithm. 1631 - performance: Use the trial search algorithm. 1632 - implicit_gemm: This algorithm expresses the convolution as a matrix product without actually explicitly 1633 forming the matrix that holds the input tensor data. 1634 - implicit_precomp_gemm: This algorithm expresses convolution as a matrix product without actually 1635 explicitly forming the matrix that holds the input tensor data, but still needs some memory workspace to 1636 precompute some indices in order to facilitate the implicit construction of the matrix that holds the 1637 input tensor data. 1638 - gemm: This algorithm expresses the convolution as an explicit matrix product. A significant memory 1639 workspace is needed to store the matrix that holds the input tensor data. 1640 - direct: This algorithm expresses the convolution as a direct convolution (for example, without 1641 implicitly or explicitly doing a matrix multiplication). 1642 - fft: This algorithm uses the Fast-Fourier Transform approach to compute the convolution. A significant 1643 memory workspace is needed to store intermediate results. 1644 - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles. 1645 A significant memory workspace is needed to store intermediate results but less than fft algorithm for 1646 large size images. 1647 - winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably 1648 sized workspace is needed to store intermediate results. 1649 - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. A 1650 significant workspace may be needed to store intermediate results. 1651 - conv_dgrad_algo (str): Specifies convolution data grad algorithm and the default value is 'normal', 1652 The value range is as follows: 1653 1654 - normal: Use the heuristic search algorithm. 1655 - performance: Use the trial search algorithm. 1656 - algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly 1657 forming the matrix that holds the input tensor data. The sum is done using the atomic add operation, 1658 thus the results are non-deterministic. 1659 - algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming 1660 the matrix that holds the input tensor data. The results are deterministic. 1661 - fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant 1662 memory workspace is needed to store intermediate results. The results are deterministic. 1663 - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles. 1664 A significant memory workspace is needed to store intermediate results but less than fft for large size 1665 images. The results are deterministic. 1666 - winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably 1667 sized workspace is needed to store intermediate results. The results are deterministic. 1668 - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. 1669 A significant workspace may be needed to store intermediate results. The results are deterministic. 1670 - conv_wgrad_algo (str): Specifies convolution filter grad algorithm and the default value is 'normal', 1671 The value range is as follows: 1672 1673 - normal: Use the heuristic search algorithm. 1674 - performance: Use the trial search algorithm. 1675 - algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly 1676 forming the matrix that holds the input tensor data. The sum is done using the atomic add operation, 1677 thus the results are non-deterministic. 1678 - algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming 1679 the matrix that holds the input tensor data. The results are deterministic. 1680 - fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant 1681 memory workspace is needed to store intermediate results. The results are deterministic. 1682 - algo_3: This algorithm is similar to algo_0 but uses some small workspace to precompute some indices. 1683 The results are also non-deterministic. 1684 - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. 1685 A significant workspace may be needed to store intermediate results. The results are deterministic. 1686 - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles. 1687 A significant memory workspace is needed to store intermediate results but less than fft for large size 1688 images. The results are deterministic. 1689 - conv_allow_tf32 (bool): The flag below controls to allow Tensor core TF32 computation on CUDNN and the 1690 default value is ``True``. 1691 - matmul_allow_tf32 (bool): The flag below controls to allow Tensor core TF32 computation on CUBLAS and the 1692 default value is ``False``. 1693 1694 jit_config (dict): Set the global jit config for compile, take effect in network defined in Cell or jit 1695 decorators. It is not set by default. 1696 The setting in context is the global jit config, while JitConfig is the local network's jit config. 1697 When both exist simultaneously, the global jit config will not overwrite the local network's jit config. 1698 1699 - jit_level (str): Used to control the compilation optimization level. Default: ``""`` , The framework 1700 automatically selects the execution method based on product, Altas training product is O2, and all other 1701 products are O0. The value range is as follows: 1702 1703 - ``"O0"``: Except for optimizations that may affect functionality, all other optimizations are turned 1704 off, adopt KernelByKernel execution mode. 1705 - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations, 1706 adopt KernelByKernel execution mode. 1707 - ``"O2"``: Ultimate performance optimization, adopt Sink execution mode. 1708 1709 - infer_boost (str): Used to control the infer mode. Default: ``"off"`` . The value range is as follows: 1710 1711 - ``"on"``: Enable infer mode, get better infer performance. 1712 - ``"off"``: Disable infer mode, use forward to infer, performance is not good. 1713 1714 Raises: 1715 ValueError: If input key is not an attribute in context. 1716 1717 Examples: 1718 >>> import mindspore as ms 1719 >>> ms.set_context(mode=ms.PYNATIVE_MODE) 1720 >>> ms.set_context(precompile_only=True) 1721 >>> ms.set_context(device_target="Ascend") 1722 >>> ms.set_context(device_id=0) 1723 >>> ms.set_context(save_graphs=True, save_graphs_path="./model.ms") 1724 >>> ms.set_context(enable_reduce_precision=True) 1725 >>> ms.set_context(enable_graph_kernel=True) 1726 >>> ms.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text") 1727 >>> ms.set_context(reserve_class_name_in_scope=True) 1728 >>> ms.set_context(variable_memory_max_size="6GB") 1729 >>> ms.set_context(aoe_tune_mode="online") 1730 >>> ms.set_context(aoe_config={"job_type": "2"}) 1731 >>> ms.set_context(check_bprop=True) 1732 >>> ms.set_context(max_device_memory="3.5GB") 1733 >>> ms.set_context(mempool_block_size="1GB") 1734 >>> ms.set_context(print_file_path="print.pb") 1735 >>> ms.set_context(max_call_depth=80) 1736 >>> ms.set_context(env_config_path="./env_config.json") 1737 >>> ms.set_context(grad_for_scalar=True) 1738 >>> ms.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms") 1739 >>> ms.set_context(pynative_synchronize=True) 1740 >>> ms.set_context(runtime_num_threads=10) 1741 >>> ms.set_context(inter_op_parallel_num=4) 1742 >>> ms.set_context(disable_format_transform=True) 1743 >>> ms.set_context(memory_optimize_level='O0') 1744 >>> ms.set_context(memory_offload='ON') 1745 >>> ms.set_context(deterministic='ON') 1746 >>> ms.set_context(ascend_config={"precision_mode": "force_fp16", "jit_compile": True, 1747 ... "atomic_clean_policy": 1, "op_precision_mode": "./op_precision_config_file", 1748 ... "op_debug_option": "oom", 1749 ... "ge_options": {"global": {"ge.opSelectImplmode": "high_precision"}, 1750 ... "session": {"ge.exec.atomicCleanPolicy": "0"}}}) 1751 >>> ms.set_context(jit_syntax_level=ms.STRICT) 1752 >>> ms.set_context(debug_level=ms.context.DEBUG) 1753 >>> ms.set_context(gpu_config={"conv_fprop_algo": "performance", "conv_allow_tf32": True, 1754 ... "matmul_allow_tf32": True}) 1755 >>> ms.set_context(jit_config={"jit_level": "O0"}) 1756 """ 1757 ctx = _context() 1758 # set device target first 1759 if 'device_target' in kwargs: 1760 ctx.set_device_target(kwargs['device_target']) 1761 device = ctx.get_param(ms_ctx_param.device_target) 1762 _check_ascend_device_context_initialized(device, kwargs) 1763 1764 for key, value in kwargs.items(): 1765 if key in ('enable_sparse', 'auto_tune_mode'): 1766 logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated, " 1767 "and will be removed in the next version.") 1768 continue 1769 if key in ('enable_auto_mixed_precision', 'enable_dump', 'save_dump_path'): 1770 logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated. " 1771 "For details, please see the interface parameter API comments") 1772 continue 1773 _check_key(key) 1774 if key == 'save_graphs': 1775 if value is True: 1776 value = 2 1777 if value is False: 1778 value = 0 1779 if value > 3: 1780 raise ValueError(f"value for save_graphs should be 0-3 but got '{value}'") 1781 if key == 'jit_syntax_level' and value not in (STRICT, COMPATIBLE, LAX): 1782 raise ValueError(f"For 'jit_syntax_level', the value should be context.STRICT" 1783 f" or context.LAX, but got {value}.") 1784 if key == 'debug_level' and value not in (RELEASE, DEBUG): 1785 raise ValueError(f"For 'debug_level', the value should be context.DEBUG" 1786 f" or context.RELEASE, but got {value}.") 1787 if key == 'enable_compile_cache': 1788 setattr(ctx, key, value) 1789 ctx.set_param(ms_ctx_param.__members__[key], int(value)) 1790 continue 1791 if not _check_target_specific_cfgs(device, key): 1792 continue 1793 if hasattr(ctx, key): 1794 setattr(ctx, key, value) 1795 continue 1796 if key in ctx.setters: 1797 ctx.setters[key](ctx, value) 1798 continue 1799 # enum variables beginning with '_' are for internal use 1800 if key in ms_ctx_param.__members__ and key[0] != '_': 1801 ctx.set_param(ms_ctx_param.__members__[key], value) 1802 continue 1803 raise ValueError(f"For 'context.set_context', the keyword argument {key} is not recognized! For detailed " 1804 f"usage of 'set_context', please refer to the Mindspore official website.") 1805 1806 1807def get_context(attr_key): 1808 """ 1809 Get context attribute value according to the input key. 1810 If some attributes are not set, they will be automatically obtained. 1811 1812 Args: 1813 attr_key (str): The key of the attribute. 1814 1815 Returns: 1816 Object, The value of given attribute key. 1817 1818 Raises: 1819 ValueError: If input key is not an attribute in context. 1820 Examples: 1821 >>> import mindspore as ms 1822 >>> ms.get_context("device_target") 1823 >>> ms.get_context("device_id") 1824 """ 1825 ctx = _context() 1826 device = ctx.get_param(ms_ctx_param.device_target) 1827 _ = _check_target_specific_cfgs(device, attr_key) 1828 if hasattr(ctx, attr_key): 1829 return getattr(ctx, attr_key) 1830 # enum variables beginning with '_' are for internal use 1831 if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_': 1832 return ctx.get_param(ms_ctx_param.__members__[attr_key]) 1833 raise ValueError(f"For 'context.get_context', the argument {attr_key} is not recognized! For detailed " 1834 f"usage of 'get_context', please refer to the Mindspore official website.") 1835 1836 1837def _get_mode(): 1838 """ 1839 Get execution mode. Only for internal using. 1840 1841 Returns: 1842 Object: The Value of execution mode. 1843 """ 1844 ctx = _context() 1845 return ctx.get_mode() 1846 1847 1848def get_jit_config(): 1849 """ 1850 Get global jit config. 1851 1852 Returns: 1853 Object: The Value of jit config. 1854 """ 1855 ctx = _context() 1856 return ctx.get_jit_config() 1857 1858 1859class ParallelMode: 1860 """ 1861 Parallel mode options. 1862 1863 There are five kinds of parallel modes, ``STAND_ALONE``, ``DATA_PARALLEL``, 1864 ``HYBRID_PARALLEL``, ``SEMI_AUTO_PARALLEL`` and ``AUTO_PARALLEL``. Default: ``STAND_ALONE``. 1865 1866 - ``STAND_ALONE``: Only one processor is working. 1867 - ``DATA_PARALLEL``: Distributes the data across different processors. 1868 - ``HYBRID_PARALLEL``: Achieves data parallelism and model parallelism manually. 1869 - ``SEMI_AUTO_PARALLEL``: Achieves data parallelism and model parallelism by setting parallel strategies. 1870 - ``AUTO_PARALLEL``: Achieves parallelism automatically. 1871 1872 ``MODE_LIST``: The list of all supported parallel modes. 1873 """ 1874 1875 STAND_ALONE = "stand_alone" 1876 DATA_PARALLEL = "data_parallel" 1877 HYBRID_PARALLEL = "hybrid_parallel" 1878 SEMI_AUTO_PARALLEL = "semi_auto_parallel" 1879 AUTO_PARALLEL = "auto_parallel" 1880 MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL] 1881 1882 1883@args_type_check(enable_ps=bool) 1884def set_ps_context(**kwargs): 1885 """ 1886 Set parameter server training mode context. 1887 1888 Note: 1889 Parameter server mode is only supported in graph mode. 1890 Some other environment variables should also be set for parameter server training mode. 1891 These environment variables are listed below: 1892 1893 - MS_SERVER_NUM: Server number 1894 - MS_WORKER_NUM: Worker number 1895 - MS_SCHED_HOST: Scheduler IP address 1896 - MS_SCHED_PORT: Scheduler port 1897 - MS_ROLE: The role of this process: 1898 1899 - MS_SCHED: represents the scheduler, 1900 - MS_WORKER: represents the worker, 1901 - MS_PSERVER/MS_SERVER: represents the Server 1902 1903 Args: 1904 enable_ps (bool): Whether to enable parameter server training mode. 1905 Only after enable_ps is set True, the environment variables will be effective. 1906 Default: ``False`` . 1907 config_file_path (string): Configuration file path used by recovery, parameter server training mode only 1908 supports Server disaster recovery currently. Default: ``''`` . 1909 scheduler_manage_port (int): Scheduler manage port used to scale out/in. Default: ``11202`` . 1910 enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False`` . 1911 client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ``''`` . 1912 server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ``''`` . 1913 1914 Raises: 1915 ValueError: If input key is not the attribute in parameter server training mode context. 1916 1917 Examples: 1918 >>> import mindspore as ms 1919 >>> ms.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456') 1920 """ 1921 _set_ps_context(**kwargs) 1922 1923 1924def get_ps_context(attr_key): 1925 """ 1926 Get parameter server training mode context attribute value according to the key. 1927 1928 Args: 1929 attr_key (str): The key of the attribute: 1930 1931 - enable_ps (bool): Whether to enable parameter server training mode. Default: ``False`` . 1932 - config_file_path (string): Configuration file path used by recovery, parameter server training mode only 1933 supports Server disaster recovery currently. Default: ``''`` . 1934 - scheduler_manage_port (int): Scheduler manage port used to scale out/in. Default: ``11202`` . 1935 - enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False`` . 1936 - client_password (str): Password to decrypt the secret key stored in the client certificate. 1937 Default: ``''`` . 1938 - server_password (str): Password to decrypt the secret key stored in the server certificate. 1939 Default: ``''`` . 1940 1941 Returns: 1942 Returns attribute value according to the key. 1943 1944 Raises: 1945 ValueError: If input key is not attribute in auto parallel context. 1946 1947 Examples: 1948 >>> import mindspore as ms 1949 >>> ms.get_ps_context("enable_ps") 1950 """ 1951 return _get_ps_context(attr_key) 1952 1953 1954def reset_ps_context(): 1955 """ 1956 Reset parameter server training mode context attributes to the default values. 1957 1958 Meaning of each field and its default value refer to :func:`mindspore.set_ps_context`. 1959 1960 Examples: 1961 >>> import mindspore as ms 1962 >>> ms.reset_ps_context() 1963 """ 1964 _reset_ps_context() 1965 1966 1967_hccl_connect_timeout = '600' 1968 1969 1970def _init_parallel_env(): 1971 """Set hccl connect timeout.""" 1972 if 'HCCL_CONNECT_TIMEOUT' not in os.environ: 1973 os.environ['HCCL_CONNECT_TIMEOUT'] = _hccl_connect_timeout 1974 1975 1976_init_parallel_env() 1977