• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16The context of mindspore, used to configure the current execution environment,
17includes the execution mode, execution backend and other feature switches.
18"""
19from __future__ import absolute_import
20
21import json
22import os
23import time
24import threading
25from collections import namedtuple
26from types import FunctionType
27
28from mindspore import log as logger
29from mindspore._c_expression import MSContext, ms_ctx_param
30from mindspore import _checkparam as Validator
31from mindspore._checkparam import args_type_check
32from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
33    _reset_auto_parallel_context
34from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context, \
35    _need_reset_device_target_for_ps
36from mindspore.parallel._offload_context import _set_offload_context, _get_offload_context
37from mindspore.hal.device import is_initialized
38
39__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'STRICT', 'COMPATIBLE', 'LAX', 'set_context', 'get_context',
40           'set_auto_parallel_context', 'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode',
41           'set_ps_context', 'get_ps_context', 'reset_ps_context', 'set_offload_context', 'get_offload_context']
42
43GRAPH_MODE = 0
44PYNATIVE_MODE = 1
45_DEVICE_APP_MEMORY_SIZE = 31  # The max memory size of graph plus variable.
46_RE_PATTERN = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
47K_CONTEXT = None
48
49# Enumerate for the property 'jit_syntax_level'.
50STRICT = 0
51COMPATIBLE = 1
52LAX = 2
53
54# Enumerate for the property 'debug_level'.
55RELEASE = 0
56DEBUG = 1
57
58
59def _make_directory(path):
60    """Make directory."""
61    if path is None or not isinstance(path, str) or path.strip() == "":
62        raise ValueError(f"For 'context.set_context', the 'save_graphs_path' or the 'print_file_path' is invalid "
63                         f"type, it should be Non-empty string, but got '{path}'.")
64
65    path = os.path.realpath(path)
66    logger.debug("The absolute path is %r", path)
67
68    if not os.path.exists(path):
69        logger.debug("The directory(%s) doesn't exist, will create it", path)
70        try:
71            os.makedirs(path)
72        except FileExistsError:
73            logger.debug("The directory(%s) already exist.", path)
74        except PermissionError as e:
75            logger.critical(f"No write permission on the directory '{path}'', error = {e}")
76            raise ValueError(e.__str__() + f"\nNo write permission on the directory '{path}'.")
77    return path
78
79
80def _get_print_file_name(file_name):
81    """Add timestamp suffix to file name. Rename the file name:  file_name + "." + time(seconds)."""
82    time_second = str(int(time.time()))
83    file_name = file_name + "." + time_second
84    if os.path.exists(file_name):
85        raise ValueError("For 'context.set_context', the argument 'print_file_path' {} already exists, "
86                         "please check it".format(file_name))
87    return file_name
88
89
90class _ThreadLocalInfo(threading.local):
91    """
92    Thread local Info used for store thread local attributes.
93    """
94
95    def __init__(self):
96        super(_ThreadLocalInfo, self).__init__()
97        self._reserve_class_name_in_scope = True
98        self.debug_runtime = False
99
100    @property
101    def reserve_class_name_in_scope(self):
102        """Get whether to save the network class name in the scope."""
103        return self._reserve_class_name_in_scope
104
105    @reserve_class_name_in_scope.setter
106    def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
107        """Set whether to save the network class name in the scope."""
108        self._reserve_class_name_in_scope = reserve_class_name_in_scope
109
110
111_ContextRecord = namedtuple(
112    "_ContextRecord", ["is_pynative_mode", "switch_context_fn"])
113
114
115class _ContextSwitchInfo(threading.local):
116    """
117    Record of context switch information.
118
119    Args:
120        is_pynative (bool): Whether to adopt the PyNative mode.
121    """
122
123    def __init__(self, is_pynative):
124        super(_ContextSwitchInfo, self).__init__()
125        self.context_stack = []
126        if is_pynative:
127            self.push(True, None)
128
129    def push(self, is_pynative, switch_context_fn):
130        """
131        Push a context switch record onto the stack.
132
133        Args:
134            is_pynative (bool): Whether context switch to PyNative mode.
135            switch_context_fn (Function): A callable that executes the context switch.
136        """
137        if isinstance(switch_context_fn, FunctionType):
138            switch_context_fn()
139        self.context_stack.append(
140            _ContextRecord(is_pynative, switch_context_fn))
141
142    def pop(self):
143        self.context_stack.pop()
144
145
146class _Context:
147    """
148    _Context is the environment in which operations are executed
149
150    Note:
151        Create a context through instantiating Context object is not recommended.
152        should use context() to get the context since Context is a singleton.
153    """
154    _instance = None
155    _instance_lock = threading.Lock()
156
157    def __new__(cls, *args, **kwargs):
158        if cls._instance is None:
159            cls._instance_lock.acquire()
160            cls._instance = object.__new__(cls)
161            cls._instance_lock.release()
162        return cls._instance
163
164    def __init__(self):
165        self._thread_local_info = _ThreadLocalInfo()
166        self._context_switches = _ContextSwitchInfo(False)
167        self._context_handle = MSContext.get_instance()
168        self._support_binary = False
169        self.enable_compile_cache = None
170        self._mode = PYNATIVE_MODE
171        self._jit_config = {}
172
173    def __getattribute__(self, attr):
174        value = object.__getattribute__(self, attr)
175        if attr == "_context_handle" and value is None:
176            raise ValueError("Get {} failed, please check whether 'env_config_path' is correct.".format(attr))
177        return value
178
179    def get_param(self, param):
180        return self._context_handle.get_param(param)
181
182    def set_param(self, param, value):
183        self._context_handle.set_param(param, value)
184
185    def get_mode(self):
186        """Get current mode."""
187        return self._mode
188
189    def get_jit_config(self):
190        """Get current jit_config."""
191        return self._jit_config
192
193    def set_mode(self, mode):
194        """
195        Switch between Graph mode and PyNative mode.
196
197        Args:
198            mode (int): GRAPH_MODE or PYNATIVE_MODE.
199        """
200        if mode == PYNATIVE_MODE:
201            if self.enable_debug_runtime:
202                self.set_backend_policy("vm")
203            parallel_mode = _get_auto_parallel_context("parallel_mode")
204            if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE, ParallelMode.AUTO_PARALLEL):
205                raise ValueError(f"Got {parallel_mode}, when the user enabled SEMI_AUTO_PARALELL, "
206                                 f"pynative mode dose not support, you should set either "
207                                 f"context.set_auto_parallel_context(parallel_mode='data_parallel'), "
208                                 f"context.set_auto_parallel_context(parallel_mode='stand_alone') "
209                                 f"or context.set_auto_parallel_context(parallel_mode='auto_parallel').")
210            self._context_switches.push(True, None)
211        elif mode == GRAPH_MODE:
212            if self.enable_debug_runtime:
213                self.set_backend_policy("ge")
214            self._context_switches.push(False, None)
215        else:
216            raise ValueError(f"For 'context.set_context', the argument 'mode' should be context.GRAPH_MODE (0) "
217                             f"or context.PYNATIVE_MODE (1), but got {mode}.")
218        self.set_param(ms_ctx_param.mode, mode)
219        self._mode = mode
220
221    def set_jit_syntax_level(self, level):
222        """"Set the JIT syntax level for graph compiling"""
223        if level != STRICT and level != COMPATIBLE and level != LAX:
224            raise ValueError(f"For 'context.set_jit_syntax_level', the argument 'level' should be context.STRICT "
225                             f"or context.LAX, but got {level}.")
226        self.set_param(ms_ctx_param.jit_syntax_level, level)
227
228    def set_debug_level(self, level):
229        """"Set the debug level for graph compiling"""
230        if level != RELEASE and level != DEBUG:
231            raise ValueError(f"For 'context.set_debug_level', the argument 'level' should be context.RELEASE "
232                             f"or context.DEBUG, but got {level}.")
233        self.set_param(ms_ctx_param.debug_level, level)
234
235    def set_memory_optimize_level(self, memory_optimize_level):
236        """
237        The memory optimize level, support "O0", "O1".
238
239        Args:
240            target (str): "O0", "O1"
241        """
242        memory_optimize_levels = ["O0", "O1"]
243        if memory_optimize_level not in memory_optimize_levels:
244            raise ValueError(f"For 'context.set_context', the argument 'memory_optimize_level' must be one of "
245                             f"{memory_optimize_levels}, but got {memory_optimize_level}.")
246        if memory_optimize_level == "O0":
247            self.set_param(ms_ctx_param.memory_optimize_level, 0)
248        else:
249            self.set_param(ms_ctx_param.memory_optimize_level, 1)
250
251    def set_memory_offload(self, memory_offload):
252        """
253        Enable memory offload or not, support "ON", "OFF".
254
255        Args:
256            memory_offload (str): "ON", "OFF"
257        """
258        memory_offload_options = ["ON", "OFF"]
259        if memory_offload not in memory_offload_options:
260            raise ValueError(f"For 'context.set_context', the argument 'memory_offload' must be one of "
261                             f"{memory_offload_options}, but got {memory_offload}.")
262        if memory_offload == "ON":
263            self.set_param(ms_ctx_param.memory_offload, True)
264        else:
265            self.set_param(ms_ctx_param.memory_offload, False)
266
267    def set_deterministic(self, deterministic):
268        """
269        Enable model run in deterministic, and support the values "ON" and "OFF".
270
271        Args:
272            deterministic (str): "ON", "OFF"
273        """
274        deterministic_options = ["ON", "OFF"]
275        if deterministic not in deterministic_options:
276            raise ValueError(f"For 'context.set_context', the argument 'deterministic' must be one of "
277                             f"{deterministic_options}, but got {deterministic}.")
278        self.set_param(ms_ctx_param.deterministic, deterministic)
279
280    def set_ascend_config(self, ascend_config):
281        """
282        Enable ascend config.
283
284        Args:
285            ascend_config (dict):
286                - precision_mode (str): "force_fp16", "allow_fp32_to_fp16", "allow_mix_precision",
287                            "must_keep_origin_dtype", "force_fp32", "allow_fp32_to_bf16",
288                            "allow_mix_precision_fp16" and "allow_mix_precision_bf16".
289                - jit_compile (bool): ``False`` and ``True``.
290                - atomic_clean_policy (int): ``0`` and ``1``. Default: ``1`` .
291                - op_precision_mode (str): precision mode config file path.
292                - op_debug_option (str): Enable debugging options for Ascend operators,
293                  default not enabled, only supports ``"oom"`` currently.
294                  ``"oom"``: Detect memory out of bounds.
295                - ge_options (dict): Global or session CANN options.
296                - exception_dump (str): Enable exception dump for Ascend operators. ``"0"`` , ``"1"`` and ``"2"``.
297                  Default: ``"2"`` .
298                - parallel_speed_up_json_path(Union[str, None]): The path to the parallel speed up json file.
299                  If its value is None or '', it does not take effect. Default None.
300                - host_scheduling_max_threshold(int): The host scheduling max threshold.
301        """
302        ascend_cfg_modes = {
303            'precision_mode': ["force_fp16", "allow_fp32_to_fp16", "allow_mix_precision", "must_keep_origin_dtype",
304                               "force_fp32", "allow_fp32_to_bf16", "allow_mix_precision_fp16",
305                               "allow_mix_precision_bf16"],
306            'jit_compile': [True, False],
307            'atomic_clean_policy': [0, 1],
308            'matmul_allow_hf32': [True, False],
309            'conv_allow_hf32': [True, False],
310            'exception_dump': ["0", "1", "2"],
311            'op_precision_mode': (str,),
312            'ge_options': (dict,),
313            'parallel_speed_up_json_path': (str, None),
314            'host_scheduling_max_threshold': (int,),
315            'cur_step_num': (int,),
316            'save_checkpoint_steps': (int,),
317            'need_ckpt': (bool,),
318            'last_triggered_step': (int,),
319            'topo_order': (dict,),
320            'op_debug_option': (str, None),
321        }
322        ascend_cfg_setters = {
323            'precision_mode': self._get_ascend_config_setter('precision_mode'),
324            'jit_compile': self._get_ascend_config_setter('jit_compile', lambda v: "1" if v else "0"),
325            'atomic_clean_policy': self._get_ascend_config_setter('atomic_clean_policy', str),
326            'matmul_allow_hf32': self._get_ascend_config_setter('matmul_allow_hf32', lambda v: "1" if v else "0"),
327            'conv_allow_hf32': self._get_ascend_config_setter('conv_allow_hf32', lambda v: "1" if v else "0"),
328            'exception_dump': self._get_ascend_config_setter('exception_dump'),
329            'op_debug_option': self._set_op_debug_option,
330            'op_precision_mode': self._set_op_precision_mode,
331            'ge_options': self._set_ge_options,
332            'parallel_speed_up_json_path': self._set_speedup_config_path,
333            'host_scheduling_max_threshold': self._get_ascend_config_setter('host_scheduling_max_threshold', str),
334            'cur_step_num': self._set_cur_step_num,
335            'save_checkpoint_steps': self._set_save_checkpoint_steps,
336            'need_ckpt': self._set_need_ckpt,
337            'last_triggered_step': self._set_last_triggered_step,
338            'topo_order': self._set_topo_order
339        }
340        ascend_cfg_set = tuple(ascend_cfg_modes.keys())
341        for ascend_key, ascend_value in ascend_config.items():
342            if ascend_key not in ascend_cfg_set:
343                raise ValueError(f"For 'context.set_context', the key of argument 'ascend_config' must be one of "
344                                 f"{ascend_cfg_set}, but got {ascend_key}.")
345            supported_modes = ascend_cfg_modes.get(ascend_key)
346            if isinstance(supported_modes, list) and ascend_value not in supported_modes:
347                raise ValueError(f"For 'ascend_config', the value of argument {ascend_key} must be one of "
348                                 f"{supported_modes}, but got {ascend_value}.")
349            if isinstance(supported_modes, tuple) and not isinstance(ascend_value, supported_modes):
350                raise TypeError(f"For 'ascend_config', the type of argument {ascend_key} must be one of "
351                                f"{supported_modes}, but got {type(ascend_value)}.")
352            cfg_setter = ascend_cfg_setters.get(ascend_key)
353            cfg_setter(ascend_value)
354
355    def set_gpu_config(self, gpu_config):
356        """
357        Enable gpu config.
358
359        Args:
360            gpu_config (dict):
361
362                - conv_fprop_algo (str): "normal", "performance" or user specifies conv forward algorithm directly.
363                - conv_dgrad_algo (str): "normal", "performance" or user specifies conv data grad algorithm directly.
364                - conv_wgrad_algo (str): "normal", "performance" or user specifies conv weight grad algorithm directly.
365                - conv_allow_tf32 (bool): ``False`` and ``True``.
366                - matmul_allow_tf32 (bool): ``False`` and ``True``.
367        """
368
369        gpu_cfgs = {'conv_fprop_algo': ["normal", "performance", "implicit_gemm", "precomp_gemm", "gemm", "direct",
370                                        "fft", "fft_tiling", "winograd", "winograd_nonfused"],
371                    'conv_dgrad_algo': ["normal", "performance", "algo_0", "algo_1", "fft", "fft_tiling", "winograd",
372                                        "winograd_nonfused"],
373                    'conv_wgrad_algo': ["normal", "performance", "algo_0", "algo_1", "fft", "algo_3", "fft_tiling",
374                                        "winograd_nonfused"],
375                    'conv_allow_tf32': [True, False],
376                    'matmul_allow_tf32': [True, False]}
377        for gpu_key in gpu_config:
378            if gpu_key not in gpu_cfgs:
379                raise ValueError(f"For 'context.set_context', the key of argument 'gpu_config' must be one of "
380                                 f"{gpu_cfgs}, but got {gpu_key}.")
381            supported_value = gpu_cfgs.get(gpu_key)
382            if gpu_config[gpu_key] not in supported_value:
383                raise ValueError(f"For 'gpu_config', the value of argument {gpu_key} must be one of "
384                                 f"{supported_value}, but got {gpu_config[gpu_key]}.")
385            if gpu_key == 'conv_fprop_algo':
386                self.set_param(ms_ctx_param.conv_fprop_algo, gpu_config[gpu_key])
387            if gpu_key == 'conv_dgrad_algo':
388                self.set_param(ms_ctx_param.conv_dgrad_algo, gpu_config[gpu_key])
389            if gpu_key == 'conv_wgrad_algo':
390                self.set_param(ms_ctx_param.conv_wgrad_algo, gpu_config[gpu_key])
391            if gpu_key == 'conv_allow_tf32':
392                self.set_param(ms_ctx_param.conv_allow_tf32, gpu_config[gpu_key])
393            if gpu_key == 'matmul_allow_tf32':
394                self.set_param(ms_ctx_param.matmul_allow_tf32, gpu_config[gpu_key])
395
396    def set_jit_config(self, jit_config):
397        """
398        Enable jit config.
399
400        Args:
401            jit_config (dict):
402
403                - jit_level (str): "O0", "O1" or "O2" to control the compilation optimization level.
404        """
405        jit_cfgs = {'jit_level': ["O0", "O1", "O2"], 'infer_boost': ["on", "off"]}
406        key_args_map = {'jit_level': ms_ctx_param.jit_level, 'infer_boost': ms_ctx_param.infer_boost}
407        for jit_key in jit_config:
408            if jit_key not in jit_cfgs:
409                raise ValueError(f"For 'context.set_context', the key of argument 'jit_config' must be one of "
410                                 f"{jit_cfgs}, but got {jit_key}.")
411            supported_value = jit_cfgs.get(jit_key)
412            if jit_config[jit_key] not in supported_value:
413                raise ValueError(f"For 'jit_cfgs', the value of argument {jit_key} must be one of "
414                                 f"{supported_value}, but got {jit_config[jit_key]}.")
415            self._jit_config = jit_config
416            self.set_param(key_args_map[jit_key], jit_config[jit_key])
417
418        if 'infer_boost' in jit_config and jit_config['infer_boost'] == "on" and jit_config['jit_level'] != "O0":
419            raise ValueError(f"Only jit_level set O0 can set infer_boost to on.")
420
421    def set_backend_policy(self, policy):
422        success = self._context_handle.set_backend_policy(policy)
423        if not success:
424            raise RuntimeError("Backend policy must be one of values in ['ge', 'vm', 'ms']. "
425                               "But got {}.".format(policy))
426
427    def set_save_graphs_path(self, save_graphs_path):
428        self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
429
430    def set_device_target(self, target):
431        """
432        The target device to run, support "Ascend", "GPU", and "CPU".
433
434        Args:
435            target (str): "Ascend", "GPU", and "CPU".
436        """
437        valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
438        if target not in valid_targets:
439            raise ValueError(f"For 'context.set_context', the argument 'device_target' must be one of "
440                             f"{valid_targets}, but got {target}.")
441        if target == "Davinci":
442            target = "Ascend"
443            logger.warning("The device 'Davinci' is deprecated and will be removed in the next version. "
444                           "For 'context.set_context', please set the argument 'device_target' "
445                           "to 'CPU', 'GPU' or 'Ascend',if you set it to 'Davinci', it will be automatically "
446                           "changed to 'Ascend'.")
447        # If in Parameter Server mode, Ascend card should not be used by server and scheduler.
448        if _need_reset_device_target_for_ps(target):
449            logger.info("Reset device target to CPU when set_device_target.")
450            target = "CPU"
451        self.set_param(ms_ctx_param.device_target, target)
452        if self.enable_debug_runtime and target == "CPU":
453            self.set_backend_policy("vm")
454
455    def set_aoe_tune_mode(self, tune_mode):
456        """
457        Set aoe tune mode, support "online" and "offline".
458
459        Args:
460            tune_mode (str): "online" and "offline".
461        """
462        candidate = ["online", "offline"]
463        if tune_mode in candidate:
464            self.set_param(ms_ctx_param.aoe_tune_mode, tune_mode)
465        else:
466            raise ValueError(f"For 'context.set_context', the argument 'aoe_tune_mode' must be in "
467                             f"['online', 'offline'], but got {tune_mode}.")
468
469    def set_aoe_config(self, aoe_config):
470        """
471        Enable aoe config.
472
473        Args:
474            aoe_config (dict):
475                - job_type (str): ``"1"``, ``"2"``. Default: ``"2"`` .
476                  - ``"1"``: subgraph tuning.
477                  - ``"2"``: operator tuning.
478        """
479
480        aoe_cfgs = {'job_type': ["1", "2"]}
481        for aoe_config_key in aoe_config:
482            if aoe_config_key not in aoe_cfgs:
483                raise ValueError(f"For 'context.set_context', the key of argument 'aoe_config' must be one of "
484                                 f"{aoe_cfgs}, but got {aoe_config_key}.")
485            supported_value = aoe_cfgs.get(aoe_config_key)
486            if aoe_config[aoe_config_key] not in supported_value:
487                raise ValueError(f"For 'aoe_config', the value of argument {aoe_config_key} must be one of "
488                                 f"{supported_value}, but got {aoe_config[aoe_config_key]}.")
489            if aoe_config_key == 'job_type':
490                self.set_param(ms_ctx_param.aoe_job_type, aoe_config[aoe_config_key])
491
492    def set_device_id(self, device_id):
493        if device_id < 0 or device_id > 4095:
494            raise ValueError(f"For 'context.set_context', the argument 'device_id' must be in range [0, 4095], "
495                             f"but got {device_id}.")
496        self.set_param(ms_ctx_param.device_id, device_id)
497
498    def set_max_call_depth(self, max_call_depth):
499        if max_call_depth <= 0:
500            raise ValueError(f"For 'context.set_context', the argument 'max_call_depth' must be greater than 0, "
501                             f"but got {max_call_depth}.")
502        self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
503
504    def set_profiling_options(self, option):
505        if not isinstance(option, str):
506            raise TypeError("For 'context.set_context', the argument 'profiling_option' must be string, "
507                            "but got {}.".format(type(option)))
508        self.set_param(ms_ctx_param.profiling_options, option)
509
510    def set_variable_memory_max_size(self, variable_memory_max_size):
511        """set values of variable_memory_max_size and graph_memory_max_size"""
512        logger.warning("For 'context.set_context', the parameter 'variable_memory_max_size' is deprecated, "
513                       "and will be removed in a future "
514                       "version. Please use parameter 'max_device_memory' instead.")
515        if not Validator.check_str_by_regular(variable_memory_max_size, _RE_PATTERN):
516            raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should be in correct"
517                             " format! It must be a string ending with 'GB', in addition to that, it must contain "
518                             "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}GB."
519                             .format(variable_memory_max_size))
520        if float(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE:
521            raise ValueError("For 'context.set_context', the argument 'variable_memory_max_size' should not be "
522                             "greater than 31GB, but got {}GB.".format(variable_memory_max_size))
523        variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
524        graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
525        graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
526        self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
527        self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
528
529    def set_max_device_memory(self, max_device_memory):
530        if not Validator.check_str_by_regular(max_device_memory, _RE_PATTERN):
531            raise ValueError("For 'context.set_context', the argument 'max_device_memory' should be in correct "
532                             " format! It must be a string ending with 'GB', in addition to that, it must contain "
533                             "only numbers or decimal points, such as \"5GB\" or \"3.5GB\", but got {}."
534                             .format(max_device_memory))
535        max_device_memory_value = float(max_device_memory[:-2])
536        if max_device_memory_value == 0:
537            raise ValueError("For 'context.set_context', the argument 'max_device_memory' should not be \"0GB\".")
538        self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
539
540    def set_mempool_block_size(self, mempool_block_size):
541        """Set the block size of memory pool."""
542        global_jit_config = get_jit_config()
543        is_force_kbk = False
544        if global_jit_config:
545            is_force_kbk = global_jit_config.get('jit_level') == "O0" or global_jit_config.get('jit_level') == "O1"
546        if _get_mode() == GRAPH_MODE and not is_force_kbk:
547            logger.warning("Graph mode doesn't support to set parameter 'mempool_block_size' of context currently, "
548                           "you can use context.set_context to set pynative mode or set jit_level=O0/O1.")
549            return
550        if not Validator.check_str_by_regular(mempool_block_size, _RE_PATTERN):
551            raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in "
552                             "correct format! Such as \"10GB\", "
553                             "but got {}".format(mempool_block_size))
554        mempool_block_size_value = float(mempool_block_size[:-2])
555        if mempool_block_size_value < 1.0:
556            raise ValueError("For 'context.set_context',  the argument 'mempool_block_size' should be "
557                             "greater or equal to \"1GB\", "
558                             "but got {}GB".format(float(mempool_block_size[:-2])))
559        self.set_param(ms_ctx_param.mempool_block_size, mempool_block_size_value)
560
561    def set_print_file_path(self, file_path):
562        """Add timestamp suffix to file name. Sets print file path."""
563        print_file_path = os.path.realpath(file_path)
564        if os.path.isdir(print_file_path):
565            raise IOError("For 'context.set_context', the argument 'print_file_path' should be file path, "
566                          "but got directory {}.".format(file_path))
567
568        if os.path.exists(print_file_path):
569            _path, _file_name = os.path.split(print_file_path)
570            path = _make_directory(_path)
571            file_name = _get_print_file_name(_file_name)
572            full_file_name = os.path.join(path, file_name)
573        else:
574            full_file_name = print_file_path
575        self.set_param(ms_ctx_param.print_file_path, full_file_name)
576
577    def set_env_config_path(self, env_config_path):
578        """Check and set env_config_path."""
579        if not self._context_handle.enable_dump_ir():
580            raise ValueError("For 'context.set_context', the argument 'env_config_path' is not supported, please "
581                             "enable ENABLE_DUMP_IR with '-D on' and recompile source firstly.")
582        env_config_path = os.path.realpath(env_config_path)
583        if not os.path.isfile(env_config_path):
584            raise ValueError("For 'context.set_context', the 'env_config_path' file %r is not exists, "
585                             "please check whether 'env_config_path' is correct." % env_config_path)
586        try:
587            with open(env_config_path, 'r') as f:
588                json.load(f)
589        except (TypeError, ValueError) as exo:
590            raise ValueError(str(exo) + "\nFor 'context.set_context', open or load the 'env_config_path' file {} "
591                                        "failed, please check whether 'env_config_path' is json file and correct, "
592                                        "or may not have permission to read it.".format(env_config_path)) from exo
593        self.set_param(ms_ctx_param.env_config_path, env_config_path)
594
595    def set_runtime_num_threads(self, runtime_num_threads):
596        """Check and set runtime_num_threads."""
597        if runtime_num_threads < 0:
598            raise ValueError("The num of thread must bigger than or equal to 0.")
599        self.set_param(ms_ctx_param.runtime_num_threads, runtime_num_threads)
600
601    def set_op_timeout(self, op_timeout):
602        """Set the maximum duration of executing an operator in seconds."""
603        if op_timeout < 0:
604            raise ValueError("The num of op exe timeout must bigger than or equal to 0.")
605        self.set_param(ms_ctx_param.op_timeout, op_timeout)
606
607    def set_inter_op_parallel_num(self, inter_op_parallel_num):
608        """Check and set inter_op_parallel_num."""
609        if inter_op_parallel_num < 0:
610            raise ValueError("The num of parallel thread must bigger than or equal to 0.")
611        self.set_param(ms_ctx_param.inter_op_parallel_num, inter_op_parallel_num)
612
613    setters = {
614        'mode': set_mode,
615        'save_graphs_path': set_save_graphs_path,
616        'device_target': set_device_target,
617        'aoe_tune_mode': set_aoe_tune_mode,
618        'device_id': set_device_id,
619        'max_call_depth': set_max_call_depth,
620        'profiling_options': set_profiling_options,
621        'variable_memory_max_size': set_variable_memory_max_size,
622        'max_device_memory': set_max_device_memory,
623        'mempool_block_size': set_mempool_block_size,
624        'print_file_path': set_print_file_path,
625        'env_config_path': set_env_config_path,
626        'inter_op_parallel_num': set_inter_op_parallel_num,
627        'runtime_num_threads': set_runtime_num_threads,
628        'memory_optimize_level': set_memory_optimize_level,
629        'op_timeout': set_op_timeout,
630        'memory_offload': set_memory_offload,
631        'deterministic': set_deterministic,
632        'ascend_config': set_ascend_config,
633        'jit_syntax_level': set_jit_syntax_level,
634        'debug_level': set_debug_level,
635        'gpu_config': set_gpu_config,
636        'aoe_config': set_aoe_config,
637        'jit_config': set_jit_config,
638    }
639
640    @property
641    def reserve_class_name_in_scope(self):
642        """Get whether to save the network class name in the scope."""
643        return self._thread_local_info.reserve_class_name_in_scope
644
645    @reserve_class_name_in_scope.setter
646    def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
647        """Set whether to save the network class name in the scope."""
648        if not isinstance(reserve_class_name_in_scope, bool):
649            raise ValueError("For 'context.set_context', the type of the property 'reserve_class_name_in_scope' must "
650                             "be bool, but got {}.".format(type(reserve_class_name_in_scope)))
651        self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
652
653    @property
654    def enable_ge(self):
655        return self._context_handle.get_backend_policy() == 'ge'
656
657    @property
658    def enable_debug_runtime(self):
659        return self._thread_local_info.debug_runtime
660
661    @enable_debug_runtime.setter
662    def enable_debug_runtime(self, enable):
663        thread_info = self._thread_local_info
664        thread_info.debug_runtime = enable
665
666    @property
667    def support_binary(self):
668        """Whether support run .pyc or .so in graph mode."""
669        return self._support_binary
670
671    @support_binary.setter
672    def support_binary(self, support: bool):
673        if not isinstance(support, bool):
674            raise TypeError(f"The attribute 'support_binary' should be a bool, but got {type(support)}.")
675        self._support_binary = support
676
677    def _get_ascend_config_setter(self, ascend_key, trans_fn=None):
678        def _config_setter(ascend_value):
679            self.set_param(ms_ctx_param.__members__[ascend_key], trans_fn(ascend_value))
680
681        if trans_fn is None:
682            trans_fn = lambda x: x
683        return _config_setter
684
685    def _set_op_debug_option(self, option_value):
686        valid_order = {'oom'}
687        if not isinstance(option_value, str):
688            raise TypeError(f"For 'ascend_config', the type of 'op_debug_option' must be str, "
689                            f"but got {type(option_value)}.")
690        if option_value not in valid_order:
691            raise ValueError(f"For 'ascend_config', the 'op_debug_option' supports being set to 'oom' currently, "
692                             f"but got {option_value}.")
693        self.set_param(ms_ctx_param.op_debug_option, option_value)
694
695    def _set_op_precision_mode(self, ascend_value):
696        op_precision_path = ascend_value
697        real_path = os.path.realpath(op_precision_path)
698        if not os.path.exists(real_path):
699            raise ValueError(f"For 'ascend_config', the 'op_precision_mode' is invalid path, "
700                             f"got '{op_precision_path}'.")
701        self.set_param(ms_ctx_param.op_precision_mode, ascend_value)
702
703    def _set_ge_options(self, ge_options):
704        """Set ge options."""
705        for level, options in ge_options.items():
706            if level not in ['global', 'session']:
707                raise ValueError(f"For 'ascend_config', the key of ge_options must be one of "
708                                 f"('global', 'session'), but got {level}.")
709
710            if not isinstance(options, dict):
711                raise TypeError(f"For 'ge_options', the type of {level} options must be dict, "
712                                f"but got {type(options)}. The error options: {options}.")
713
714            for key, value in options.items():
715                if not isinstance(key, str):
716                    raise TypeError(f"For 'ge_options', the type of key and value must be str, "
717                                    f"but got {type(key)}. The error key is {key}.")
718                if not isinstance(value, str):
719                    raise TypeError(f"For 'ge_options', the type of key and value must be str, "
720                                    f"but got {type(value)}. The error value is {value}")
721
722        options_str = json.dumps(ge_options)
723        self.set_param(ms_ctx_param.ge_options, options_str)
724
725    def _set_topo_order(self, topo_order):
726        """
727        Set topo order.
728
729        Args:
730            topo_order (dict):
731                key: str, the name of the graph.
732                value: str, the topo order of the graph, should be one of 'dfs', 'bfs', 'rdfs'.
733        """
734        valid_order = {'dfs', 'bfs', 'rdfs'}
735        if not isinstance(topo_order, dict):
736            raise TypeError(f"For 'ascend_config', the 'topo_order' should be a dict, "
737                            f"got '{type(topo_order)}'.")
738        for k, v in topo_order.items():
739            if not isinstance(k, str):
740                raise TypeError("key {} is not a str".format(k))
741            if v not in valid_order:
742                raise ValueError("value {} should be one of {}.".format(v, valid_order))
743
744        options_str = json.dumps(topo_order)
745        self.set_param(ms_ctx_param.topo_order, options_str)
746
747    def _set_need_ckpt(self, need_ckpt):
748        """Set need ckpt flag"""
749        if not isinstance(need_ckpt, bool):
750            raise TypeError(f"For step num, the value type should be int, but got {type(need_ckpt)}, {need_ckpt}")
751        self.set_param(ms_ctx_param.need_ckpt, need_ckpt)
752
753    def _set_cur_step_num(self, step_num):
754        """set current step num at every step begin"""
755        if not isinstance(step_num, int):
756            raise TypeError(f"For step num, the value type should be int, but got {type(step_num)}, {step_num}")
757        self.set_param(ms_ctx_param.cur_step_num, step_num)
758
759    def _set_save_checkpoint_steps(self, steps):
760        """set save checkpoint steps before run"""
761        if not isinstance(steps, int):
762            raise TypeError(f"For step num, the value type should be int, but got {type(steps)}, {steps}")
763        self.set_param(ms_ctx_param.save_checkpoint_steps, steps)
764
765    def _set_last_triggered_step(self, step):
766        """set last triggered save ckpt steps before run"""
767        if not isinstance(step, int):
768            raise TypeError(f"For step num, the value type should be int, but got {type(step)}, {step}")
769        self.set_param(ms_ctx_param.last_triggered_step, step)
770
771    def _set_speedup_config_path(self, speedup_config_path):
772        """"Check and set speedup config for auto parallel."""
773        if speedup_config_path is None or speedup_config_path == "":
774            return
775        speedup_config_real_path = os.path.abspath(speedup_config_path)
776        if not os.path.exists(speedup_config_real_path):
777            raise ValueError(f"For 'ascend_config', the path to parallel_speed_up_json: "
778                             f"{speedup_config_real_path} does not exist, please check whether the "
779                             f"'parallel_speed_up_json_path' is correct.")
780        try:
781            valid_option = {"recompute_comm_overlap": (ms_ctx_param.recompute_comm_overlap, bool),
782                            "matmul_grad_comm_overlap": (ms_ctx_param.matmul_grad_comm_overlap, bool),
783                            "enable_task_opt": (ms_ctx_param.enable_task_opt, bool),
784                            "enable_grad_comm_opt": (ms_ctx_param.enable_grad_comm_opt, bool),
785                            "recompute_allgather_overlap_fagrad":
786                                (ms_ctx_param.recompute_allgather_overlap_fagrad, bool),
787                            "interleaved_matmul_comm": (ms_ctx_param.interleaved_matmul_comm, bool),
788                            "bias_add_comm_swap": (ms_ctx_param.bias_add_comm_swap, bool),
789                            "enable_opt_shard_comm_opt": (ms_ctx_param.enable_opt_shard_comm_opt, bool),
790                            "enable_begin_end_inline_opt": (ms_ctx_param.enable_begin_end_inline_opt, bool),
791                            "enable_concat_eliminate_opt": (ms_ctx_param.enable_concat_eliminate_opt, bool),
792                            "interleaved_layernorm_comm": (ms_ctx_param.interleaved_layernorm_comm, bool),
793                            "compute_communicate_fusion_level":
794                                (ms_ctx_param.compute_communicate_fusion_level, int),
795                            "enable_flash_attention_load_balance":
796                                (ms_ctx_param.enable_flash_attention_load_balance, bool)}
797            with open(speedup_config_real_path, 'r') as f:
798                speedup_config = json.load(f)
799                for key, value in speedup_config.items():
800                    if not isinstance(key, str):
801                        raise TypeError("key {} is not a str".format(key))
802                    if key not in valid_option:
803                        raise ValueError("key {} should be one of {}.".format(key, valid_option.keys()))
804                    set_func, valid_type = valid_option.get(key)
805                    if not isinstance(value, valid_type):
806                        raise TypeError(f"The value type of {key} must be {valid_type}, "
807                                        f"but got value is {value} and type is {type(value)}.")
808                    self.set_param(set_func, value)
809        except (TypeError, ValueError) as exo:
810            raise ValueError(str(exo) + "\nFor 'context.set_context', "
811                                        "open or load the 'speedup_config_path' file {} "
812                                        "failed, please check whether 'speedup_config_path' is json file and correct, "
813                                        "or may not have permission to read it.".format(speedup_config_real_path)) \
814                                        from exo
815
816
817def _context():
818    """
819    Get the global _context, if context is not created, create a new one.
820
821    Returns:
822        _Context, the global context in PyNative mode.
823    """
824    global K_CONTEXT
825    if K_CONTEXT is None:
826        default_backend = 'debug'
827        try:
828            from mindspore import default_config
829            default_backend = default_config.__backend__
830        except ImportError:
831            logger.error("import default config fail")
832        K_CONTEXT = _Context()
833        K_CONTEXT.enable_debug_runtime = False
834        if default_backend == 'debug':
835            K_CONTEXT.enable_debug_runtime = True
836            default_backend = 'vm'
837            K_CONTEXT.set_backend_policy(default_backend)
838    return K_CONTEXT
839
840
841@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
842                 auto_parallel_search_mode=str, search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
843                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, enable_alltoall=bool,
844                 all_reduce_fusion_config=list, pipeline_stages=int, pipeline_segments=int,
845                 pipeline_result_broadcast=bool, parallel_optimizer_config=dict,
846                 pipeline_config=dict,
847                 comm_fusion=dict, strategy_ckpt_config=dict, force_fp32_communication=bool)
848def set_auto_parallel_context(**kwargs):
849    r"""
850    Set auto parallel context, only data parallel supported on CPU.
851
852    Note:
853        Attribute name is required for setting attributes.
854        If a program has tasks on different parallel modes, before setting a new parallel mode for the
855        next task, interface :func:`mindspore.reset_auto_parallel_context` should be called to reset
856        the configuration.
857        Setting or changing parallel modes must be called before creating any Initializer, otherwise,
858        it may have RuntimeError when compiling the network.
859
860    Some configurations are parallel mode specific, see the below table for details:
861
862    ===========================  ===========================
863    Common                       AUTO_PARALLEL
864    ===========================  ===========================
865    device_num                   gradient_fp32_sync
866    global_rank                  loss_repeated_mean
867    gradients_mean               search_mode
868    parallel_mode                parameter_broadcast
869    all_reduce_fusion_config     strategy_ckpt_load_file
870    enable_parallel_optimizer    strategy_ckpt_save_file
871    parallel_optimizer_config    dataset_strategy
872    enable_alltoall              pipeline_stages
873    pipeline_config              auto_parallel_search_mode
874    force_fp32_communication     pipeline_result_broadcast
875               \                 comm_fusion
876               \                 strategy_ckpt_config
877               \                 group_ckpt_save_file
878               \                 auto_pipeline
879    ===========================  ===========================
880
881    Args:
882        device_num (int): Available device number, the value must be in [1, 4096]. Default: ``1`` .
883        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: ``0`` .
884        gradients_mean (bool): Whether to perform mean operator after allreduce of gradients.
885                     "stand_alone" do not support gradients_mean. Default: ``False`` .
886        gradient_fp32_sync (bool): Run allreduce of gradients in fp32. "stand_alone", "data_parallel"
887                     and "hybrid_parallel" do not support gradient_fp32_sync. Default: ``True`` .
888        loss_repeated_mean (bool) - Indicates whether the mean operator is executed backwards when the
889                     calculation is repeated. Default: ``True`` .
890        parallel_mode (str): There are five kinds of parallel modes, ``"stand_alone"`` , ``"data_parallel"`` ,
891                     ``"hybrid_parallel"`` , ``"semi_auto_parallel"`` and ``"auto_parallel"`` . Note the pynative mode
892                     only supports the ``"stand_alone"`` and ``"data_parallel"`` mode. Default: ``"stand_alone"`` .
893
894                     - stand_alone: Only one processor is working.
895
896                     - data_parallel: Distributes the data across different processors.
897
898                     - hybrid_parallel: Achieves data parallelism and model parallelism manually.
899
900                     - semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies.
901
902                     - auto_parallel: Achieving parallelism automatically.
903        search_mode (str): There are three kinds of shard strategy search modes: ``"recursive_programming"`` ,
904                     ``"sharding_propagation"`` and ``"dynamic_programming"`` (Not recommended).
905                     Default: ``"recursive_programming"`` .
906
907                     - recursive_programming: Recursive programming search mode. In order to obtain optimal performance,
908                       it is recommended that users set the batch size to be greater than or equal to the product of
909                       the number of devices and the number of multi-copy parallelism.
910
911                     - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
912
913                     - dynamic_programming: Dynamic programming search mode.
914        auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is
915                     for forward compatibility, and this attribute will be deleted in a future MindSpore version.
916        parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
917                     the same network initialization parameter values for all devices, broadcast the parameters
918                     on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
919                     ``data_parallel`` mode, all parameters are broadcast except for the parameter whose attribute
920                     layerwise_parallel is ``True`` . ``Hybrid_parallel`` , ``semi_auto_parallel``  and
921                     ``auto_parallel mode`` , the segmented parameters do not participate in broadcasting.
922                     Default: ``False`` .
923        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. The parameter is not to be
924                       recommended currently, it is better using 'strategy_ckpt_config' to replace it. Default: ``''``
925        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. The parameter is not to be
926                       recommended currently, it is better using 'strategy_ckpt_config' to replace it. Default: ``''``
927        full_batch (bool): If you load whole batch datasets in ``auto_parallel`` mode, this parameter
928                       should be set as ``True`` . Default: ``False`` . The interface is not to be recommended
929                       currently, it is better using 'dataset_strategy' to replace it.
930        dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: ``"data_parallel"`` .
931                       dataset_strategy="data_parallel" is equal to full_batch=False, dataset_strategy="full_batch" is
932                       equal to full_batch=True. For execution mode is 'GRAPH_MODE' and dataset load into net by model
933                       parallel strategy likes ds_stra ((1, 8), (1, 8)), it requires using
934                       set_auto_parallel_context(dataset_strategy=ds_stra).
935        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
936                       data parallel training in the benefit of time and memory saving. Currently, auto and semi auto
937                       parallel mode support all optimizers in both Ascend and GPU. Data parallel mode only supports
938                       `Lamb` and `AdamWeightDecay` in Ascend . Default: ``False`` .
939        force_fp32_communication (bool): A switch that determines whether reduce operators (AllReduce, ReduceScatter)
940                        are forced to use the fp32 data type for communication during communication. True is the enable
941                        switch. Default: ``False`` .
942        enable_alltoall (bool): A switch that allows AllToAll operators to be generated during communication. If its
943                        value is ``False`` , there will be a combination of operators such as AllGather, Split and
944                        Concat instead of AllToAll. Default: ``False`` .
945        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
946                       and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
947        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how the devices are
948                        distributed alone in the pipeline. The total devices will be divided into 'pipeline_stags'
949                        stages.
950                        Default: ``1`` .
951        pipeline_result_broadcast (bool): A switch that broadcast the last stage result to all other stage in pipeline
952                        parallel inference. Default: ``False`` .
953        pipeline_config (dict): A dict contains the keys and values for setting the pipeline parallelism configuration.
954                        It supports the following keys:
955
956                        - pipeline_interleave(bool): Indicates whether to enable the interleaved execution mode.
957                        - pipeline_scheduler(str): Indicates the scheduling mode for pipeline parallelism. Only support
958                          ``gpipe/1f1b``.
959        parallel_optimizer_config (dict): A dict contains the keys and values for setting the parallel optimizer
960                        configure. The configure provides more detailed behavior control about parallel training
961                        when parallel optimizer is enabled. The configure will be effective when we use
962                        mindspore.set_auto_parallel_context(enable_parallel_optimizer=True).
963                        It supports the following keys.
964
965                        - gradient_accumulation_shard(bool): If ``true`` , the accumulation gradient parameters will be
966                          sharded across the data parallel devices. This will
967                          introduce additional communication(ReduceScatter) at
968                          each step when accumulate the gradients, but saves a
969                          lot of device memories, thus can make model be trained
970                          with larger batch size. This configure is effective only
971                          when the model runs on pipeline training or gradient
972                          accumulation with data parallel. Default ``False`` .
973
974                        - parallel_optimizer_threshold(int): Set the threshold of parallel optimizer. When parallel
975                          optimizer is enabled, parameters with size smaller than this threshold will not be sharded
976                          across the devices. Parameter size = shape[0] \* ... \* shape[n] \* size(dtype). Non-negative.
977                          Unit: KB. Default: ``64`` .
978
979                        - optimizer_weight_shard_size(int): Set the optimizer weight shard group size, if you want to
980                          specific the maximum group size across devices when the parallel optimizer is enabled.
981                          The numerical range can be (0, device_num]. If pipeline parallel is enabled, the numerical
982                          range is (0, device_num/stage]. If the size of data parallel communication domain
983                          of the parameter cannot be divided by `optimizer_weight_shard_size`, then the specified
984                          communication group size will not take effect. Default value is ``-1`` , which means the
985                          optimizer weight shard group size will be the size of data parallel group of each parameter.
986
987        comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. each
988                        communication fusion config has two keys: "mode" and "config".
989                        It supports following communication fusion types and configurations:
990
991                        - openstate: Whether turn on the communication fusion or not. If `openstate` is ``True`` ,
992                          turn on the communication fusion, otherwise, turn off the communication fusion.
993                          Default: ``True`` .
994
995                        - allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
996                          and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default
997                          fusion threshold is `64` MB. In 'size' mode, AllReduce fusion is configured by gradients size
998                          manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as
999                          `all_reduce_fusion_config`.
1000
1001                        - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`.
1002                          In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
1003                          threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size
1004                          manually, and the fusion threshold must be larger than `0` MB.
1005
1006                        - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
1007                          and `size`. Config is same as `allgather`.
1008
1009        strategy_ckpt_config (dict): A dict contains the configurations for setting the parallel strategy file. This
1010                        interface contains the functions of parameter `strategy_ckpt_load_file` and
1011                        `strategy_ckpt_save_file`, it is recommonded to use this parameter to replace those two
1012                        parameters.
1013                        It contains following configurations:
1014
1015                        - load_file (str): The path to load parallel strategy checkpoint. If the file name extension is
1016                          `.json`, the file is loaded in JSON format. Otherwise, the file is loaded in ProtoBuf
1017                          format.
1018                          Default: ``''``
1019
1020                        - save_file (str): The path to save parallel strategy checkpoint. If the file name extension is
1021                          `.json`, the file is saved in JSON format. Otherwise, the file is saved in ProtoBuf format.
1022                          Default: ``''``
1023
1024                        - only_trainable_params (bool): Only save/load the strategy information for trainable parameter.
1025                          Default: ``True`` .
1026        group_ckpt_save_file (str): The path to save parallel group checkpoint.
1027        auto_pipeline (bool): Set the pipeline stage number to automatic. Its value will be selected between 1 and the
1028                        parameter `pipeline_stages`. This option requires the `parallel_mode` to be ``auto_parallel``
1029                        and the `search_mode` to be ``recursive_programming``. Default: ``False`` .
1030
1031    Raises:
1032        ValueError: If input key is not attribute in auto parallel context.
1033
1034    Examples:
1035        >>> import mindspore as ms
1036        >>> ms.set_auto_parallel_context(device_num=8)
1037        >>> ms.set_auto_parallel_context(global_rank=0)
1038        >>> ms.set_auto_parallel_context(gradients_mean=True)
1039        >>> ms.set_auto_parallel_context(gradient_fp32_sync=False)
1040        >>> ms.set_auto_parallel_context(parallel_mode="auto_parallel")
1041        >>> ms.set_auto_parallel_context(search_mode="recursive_programming")
1042        >>> ms.set_auto_parallel_context(auto_parallel_search_mode="recursive_programming")
1043        >>> ms.set_auto_parallel_context(parameter_broadcast=False)
1044        >>> ms.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
1045        >>> ms.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
1046        >>> ms.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
1047        >>> ms.set_auto_parallel_context(enable_parallel_optimizer=False)
1048        >>> ms.set_auto_parallel_context(enable_alltoall=False)
1049        >>> ms.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
1050        >>> ms.set_auto_parallel_context(pipeline_stages=2)
1051        >>> ms.set_auto_parallel_context(pipeline_stages=2, pipeline_result_broadcast=True)
1052        >>> parallel_config = {"gradient_accumulation_shard": True, "parallel_optimizer_threshold": 24,
1053        ...                    "optimizer_weight_shard_size": 2}
1054        >>> ms.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
1055        >>> config = {"allreduce": {"mode": "size", "config": 32}, "allgather": {"mode": "size", "config": 32}}
1056        >>> ms.set_auto_parallel_context(comm_fusion=config)
1057        >>> stra_ckpt_dict = {"load_file": "./stra0.ckpt", "save_file": "./stra1.ckpt", "only_trainable_params": False}
1058        >>> ms.set_auto_parallel_context(strategy_ckpt_config=stra_ckpt_dict)
1059    """
1060    _set_auto_parallel_context(**kwargs)
1061
1062
1063def get_auto_parallel_context(attr_key):
1064    """
1065    Get auto parallel context attribute value according to the key.
1066
1067    Args:
1068        attr_key (str): The key of the attribute.
1069
1070    Returns:
1071        Returns attribute value according to the key.
1072
1073    Raises:
1074        ValueError: If input key is not attribute in auto parallel context.
1075
1076    Examples:
1077        >>> import mindspore as ms
1078        >>> parallel_mode = ms.get_auto_parallel_context("parallel_mode")
1079        >>> dataset_strategy = ms.get_auto_parallel_context("dataset_strategy")
1080    """
1081    return _get_auto_parallel_context(attr_key)
1082
1083
1084def reset_auto_parallel_context():
1085    """
1086    Reset auto parallel context attributes to the default values.
1087
1088    - device_num: 1.
1089    - global_rank: 0.
1090    - gradients_mean: False.
1091    - gradient_fp32_sync: True.
1092    - parallel_mode: 'stand_alone'.
1093    - search_mode: 'recursive_programming'.
1094    - auto_parallel_search_mode: 'recursive_programming'.
1095    - parameter_broadcast: False.
1096    - strategy_ckpt_load_file: ''.
1097    - strategy_ckpt_save_file: ''.
1098    - full_batch: False.
1099    - enable_parallel_optimizer: False.
1100    - force_fp32_communication: False
1101    - enable_alltoall: False.
1102    - pipeline_stages: 1.
1103    - pipeline_result_broadcast: False.
1104    - fusion_threshold: 64.
1105    - auto_pipeline: False.
1106
1107    Examples:
1108        >>> import mindspore as ms
1109        >>> ms.reset_auto_parallel_context()
1110    """
1111    _reset_auto_parallel_context()
1112
1113
1114@args_type_check(offload_config=dict)
1115def set_offload_context(offload_config):
1116    r"""
1117    Configure heterogeneous training detailed parameters to adjust the offload strategy.
1118
1119    Note:
1120        The offload configuration is only used if the memory offload feature is enabled
1121        via mindspore.set_context(memory_offload="ON").
1122
1123    Args:
1124        offload_config (dict): A dict contains the keys and values for setting the offload context
1125            configure.It supports the following keys.
1126
1127            - offload_path (str):  The path of offload, relative paths are supported. Default: ``"./offload"``.
1128            - offload_cpu_size (str):  The cpu memory size for offload. The format is "xxGB".
1129            - offload_disk_size (str): The disk size for offload. The format is "xxGB"
1130            - hbm_ratio (float): The ratio that can be used based on the maximum device memory.
1131              The range is (0,1], Default: ``1.0``.
1132            - cpu_ratio (float): The ratio that can be used based on the maximum host memory.
1133              The range is (0,1], Default: ``1.0``.
1134            - enable_pinned_mem (bool): The flag of whether enabling Pinned Memory. Default: ``True``.
1135            - enable_aio (bool): The flag of whether enabling aio. Default: ``True``.
1136            - aio_block_size (str): The size of aio block. The format is "xxGB".
1137            - aio_queue_depth (int): The depth of aio queue.
1138            - offload_param (str):  The param for offload destination, cpu or disk, Default: ``""``.
1139            - offload_checkpoint (str):  The checkpoint for offload destination, only valid if recompute is turned on,
1140              cpu or disk, Default: ``""``.
1141            - auto_offload (bool): The flag of whether auto offload. Default: ``True``.
1142            - host_mem_block_size (str): The memory block size of host memory pool. The format is "xxGB"
1143
1144    Raises:
1145        ValueError: If input key is not attribute in auto parallel context.
1146
1147    Examples:
1148        >>> from mindspore import context
1149        >>> context.set_offload_context(offload_config={"offload_param":"cpu"})
1150    """
1151    _set_offload_context(offload_config)
1152
1153
1154def get_offload_context():
1155    """
1156    Gets the offload configuration parameters. Configure through interface mindspore.set_offload_context().
1157    If the user is not set, the default configuration is obtained.
1158
1159    Returns:
1160        Dict, heterogeneous training offload detailed configuration parameters.
1161
1162    Examples:
1163        >>> from mindspore import context
1164        >>> offload_config = context.get_offload_context()
1165    """
1166    return _get_offload_context()
1167
1168
1169def _check_target_specific_cfgs(device, arg_key):
1170    """Checking whether a config is suitable for a specified device"""
1171    device_cfgs = {
1172        'enable_graph_kernel': ['Ascend', 'GPU', 'CPU'],
1173        'graph_kernel_flags': ['Ascend', 'GPU', 'CPU'],
1174        'enable_reduce_precision': ['Ascend'],
1175        'print_file_path': ['Ascend'],
1176        'variable_memory_max_size': ['Ascend'],
1177        'max_device_memory': ['Ascend', 'GPU'],
1178        'mempool_block_size': ['GPU', 'Ascend'],
1179        'disable_format_transform': ['GPU'],
1180        'ascend_config': ['Ascend'],
1181        'gpu_config': ['GPU'],
1182    }
1183    # configs not in map device_cfgs are supposed to be suitable for all devices
1184    if arg_key not in device_cfgs:
1185        return True
1186    supported_devices = device_cfgs[arg_key]
1187    if device in supported_devices:
1188        return True
1189    logger.warning(f"For 'context.set_context', when set the argument '{arg_key}', "
1190                   f"the argument 'device_target' only supports devices in '{supported_devices}', "
1191                   f"but got '{device}', ignore it.")
1192    return False
1193
1194
1195def _check_ascend_device_context_initialized(device_target, settings):
1196    if device_target == 'Ascend' and is_initialized(device_target):
1197        for key, _ in settings.items():
1198            if key in ('ascend_config', 'deterministic', 'jit_compile', 'exception_dump', 'device_id'):
1199                logger.warning(f"For 'context.set_context' in Ascend backend, the backend is already initialized, "
1200                               "please set it before the definition of any Tensor and Parameter, and the "
1201                               "instantiation and execution of any operation and net, otherwise the settings may not "
1202                               "take effect. ")
1203                break
1204
1205
1206def _check_key(key):
1207    if key in ('precision_mode', 'jit_compile', 'atomic_clean_policy', 'matmul_allow_hf32', 'conv_allow_hf32',
1208               'op_precision_mode', 'host_scheduling_max_threshold', 'ge_options', 'op_debug_option'):
1209        raise ValueError(f"Please set '{key}' through parameter ascend_config")
1210
1211
1212@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=(bool, int),
1213                 save_graphs_path=str, enable_dump=bool, aoe_tune_mode=str, aoe_config=dict,
1214                 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
1215                 enable_auto_mixed_precision=bool, inter_op_parallel_num=int,
1216                 enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
1217                 max_device_memory=str, print_file_path=str, max_call_depth=int, env_config_path=str,
1218                 graph_kernel_flags=str, save_compile_cache=bool, runtime_num_threads=int, load_compile_cache=bool,
1219                 grad_for_scalar=bool, pynative_synchronize=bool, mempool_block_size=str, disable_format_transform=bool,
1220                 op_timeout=int, deterministic=str, ascend_config=dict, jit_syntax_level=int, debug_level=int,
1221                 jit_enable_inplace_ops=bool, gpu_config=dict, jit_config=dict, enable_compile_cache=bool)
1222def set_context(**kwargs):
1223    """
1224    Set context for running environment.
1225
1226    Context should be configured before running your program. If there is no configuration,
1227    it will be automatically set according to the device target by default.
1228
1229    Note:
1230        Attribute name is required for setting attributes.
1231        The mode is not recommended to be changed after net was initialized because the implementations of some
1232        operations are different in graph mode and pynative mode. Default: ``PYNATIVE_MODE`` .
1233
1234    Some configurations are device specific, see the below table for details:
1235
1236    +-------------------------+------------------------------+----------------------------+
1237    | Function Classification |   Configuration Parameters   |   Hardware Platform Support|
1238    +=========================+==============================+============================+
1239    | System Configuration    |   device_id                  |   CPU/GPU/Ascend           |
1240    |                         +------------------------------+----------------------------+
1241    |                         |   device_target              |   CPU/GPU/Ascend           |
1242    |                         +------------------------------+----------------------------+
1243    |                         |  max_device_memory           |  GPU/Ascend                |
1244    |                         +------------------------------+----------------------------+
1245    |                         |  variable_memory_max_size    |  Ascend                    |
1246    |                         +------------------------------+----------------------------+
1247    |                         |  mempool_block_size          |  GPU/Ascend                |
1248    |                         +------------------------------+----------------------------+
1249    |                         |  op_timeout                  |  Ascend                    |
1250    +-------------------------+------------------------------+----------------------------+
1251    | Debug Configuration     |  save_graphs                 |  CPU/GPU/Ascend            |
1252    |                         +------------------------------+----------------------------+
1253    |                         |  save_graphs_path            |  CPU/GPU/Ascend            |
1254    |                         +------------------------------+----------------------------+
1255    |                         |  enable_dump                 |  Ascend                    |
1256    |                         +------------------------------+----------------------------+
1257    |                         |  save_dump_path              |  Ascend                    |
1258    |                         +------------------------------+----------------------------+
1259    |                         |  deterministic               |  Ascend                    |
1260    |                         +------------------------------+----------------------------+
1261    |                         |  print_file_path             |  Ascend                    |
1262    |                         +------------------------------+----------------------------+
1263    |                         |  env_config_path             |  CPU/GPU/Ascend            |
1264    |                         +------------------------------+----------------------------+
1265    |                         |  precompile_only             |  CPU/GPU/Ascend            |
1266    |                         +------------------------------+----------------------------+
1267    |                         |  reserve_class_name_in_scope |  CPU/GPU/Ascend            |
1268    |                         +------------------------------+----------------------------+
1269    |                         |  pynative_synchronize        |  CPU/GPU/Ascend            |
1270    |                         +------------------------------+----------------------------+
1271    |                         |  debug_level                 |  CPU/GPU/Ascend            |
1272    +-------------------------+------------------------------+----------------------------+
1273    | Executive Control       |   mode                       |   CPU/GPU/Ascend           |
1274    |                         +------------------------------+----------------------------+
1275    |                         |  enable_graph_kernel         |  Ascend/GPU                |
1276    |                         +------------------------------+----------------------------+
1277    |                         |  graph_kernel_flags          |  Ascend/GPU                |
1278    |                         +------------------------------+----------------------------+
1279    |                         |  enable_reduce_precision     |  Ascend                    |
1280    |                         +------------------------------+----------------------------+
1281    |                         |  aoe_tune_mode               |  Ascend                    |
1282    |                         +------------------------------+----------------------------+
1283    |                         |  aoe_config                  |  Ascend                    |
1284    |                         +------------------------------+----------------------------+
1285    |                         |  check_bprop                 |  CPU/GPU/Ascend            |
1286    |                         +------------------------------+----------------------------+
1287    |                         |  max_call_depth              |  CPU/GPU/Ascend            |
1288    |                         +------------------------------+----------------------------+
1289    |                         |  grad_for_scalar             |  CPU/GPU/Ascend            |
1290    |                         +------------------------------+----------------------------+
1291    |                         |  enable_compile_cache        |  CPU/GPU/Ascend            |
1292    |                         +------------------------------+----------------------------+
1293    |                         |  inter_op_parallel_num       |  CPU/GPU/Ascend            |
1294    |                         +------------------------------+----------------------------+
1295    |                         |  runtime_num_threads         |  CPU/GPU/Ascend            |
1296    |                         +------------------------------+----------------------------+
1297    |                         |  compile_cache_path          |  CPU/GPU/Ascend            |
1298    |                         +------------------------------+----------------------------+
1299    |                         |  disable_format_transform    |  GPU                       |
1300    |                         +------------------------------+----------------------------+
1301    |                         |  support_binary              |  CPU/GPU/Ascend            |
1302    |                         +------------------------------+----------------------------+
1303    |                         |  memory_optimize_level       |  CPU/GPU/Ascend            |
1304    |                         +------------------------------+----------------------------+
1305    |                         |  memory_offload              |  GPU/Ascend                |
1306    |                         +------------------------------+----------------------------+
1307    |                         |  ascend_config               |  Ascend                    |
1308    |                         +------------------------------+----------------------------+
1309    |                         |  jit_syntax_level            |  CPU/GPU/Ascend            |
1310    |                         +------------------------------+----------------------------+
1311    |                         |  gpu_config                  |  GPU                       |
1312    |                         +------------------------------+----------------------------+
1313    |                         |  jit_config                  |  CPU/GPU/Ascend            |
1314    +-------------------------+------------------------------+----------------------------+
1315
1316    Args:
1317        device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
1318            while device_num_per_host should be no more than 4096. Default: ``0`` .
1319        device_target (str): The target device to run, support "Ascend", "GPU", and "CPU".
1320            If device target is not set, the version of MindSpore package is used.
1321        max_device_memory (str): Set the maximum memory available for devices. The format is "xxGB".
1322            Default: ``" 1024GB"`` . The actual used memory size is the minimum of the available memory of the device
1323            and max_device_memory. 'max_device_memory' should be set before the program runs.
1324        variable_memory_max_size (str): This parameter is deprecated, and will be removed in a future version.
1325            Please use parameter 'max_device_memory' instead.
1326        mempool_block_size (str): Set the size of the memory pool block in PyNative mode or jit level is 'O0'/'O1'
1327            for devices. The format is "xxGB". Default: ``"1GB"`` . Minimum size is "1G". The actual used memory block
1328            size is the minimum of the available memory of the device and mempool_block_size.
1329        op_timeout (int): Set the maximum duration of executing an operator in seconds.
1330            If the execution time exceeds this value, system will terminate the task.
1331            0 means endless wait. The defaults for AI Core and AICPU operators vary on different hardware.
1332            For more information,
1333            please refer to `Ascend Community document about aclrtSetOpExecuteTimeOut
1334            <https://www.hiascend.com/document/detail/en/CANNCommunityEdition/600alphaX/infacldevg/aclcppdevg/aclcppdevg_03_0069.html>`_.
1335            Default: ``900`` .
1336        save_graphs (bool or int): Whether to save intermediate compilation graphs. Default: ``0`` .
1337            Available values are:
1338
1339            - False or 0: disable saving of intermediate compilation graphs.
1340            - 1: some intermediate files will be generated during graph compilation.
1341            - True or 2: Generate more ir files related to backend process.
1342            - 3: Generate visualization computing graphs and detailed frontend ir graphs.
1343
1344            When the network structure is complex, setting `save_graphs` attribute to ``2`` or ``3`` may take too long.
1345            If you need quick problem locating, you can switch to ``1`` first.
1346
1347            When the `save_graphs` attribute is set as ``True`` , ``1`` , ``2`` or ``3`` , attribute of
1348            `save_graphs_path` is used to set the intermediate compilation graph storage path. By default, the graphs
1349            are saved in the current directory.
1350        save_graphs_path (str): Path to save graphs. Default: ``"."``.
1351            If the specified directory does not exist, the system will automatically create the directory.
1352            During distributed training, graphs will be saved to the directory of
1353            `save_graphs_path/rank_${rank_id}/`. `rank_id` is the ID of the current device in the cluster.
1354        deterministic (str): Whether to enable op run in deterministic mode. The value must be in the
1355            range of ['ON', 'OFF'], and the default value is ``'OFF'`` .
1356
1357            - "ON": Enable operator deterministic running mode.
1358            - "OFF": Disable operator deterministic running mode.
1359
1360            When deterministic mode is on, model ops will be deterministic in Ascend. This means that if op run
1361            multiple times with the same inputs on the same hardware, it will have the exact same outputs each time.
1362            This is useful for debugging models.
1363        enable_dump (bool): This parameters is deprecated, and will be deleted in the next version.
1364        save_dump_path (str): This parameters is deprecated, and will be deleted in the next version.
1365        print_file_path (str): The path of saving print data. If this parameter is set, print data is saved to
1366            a file by default, and print_file_path is not set, the screen will be displayed.
1367            If the saved file already exists, the timestamp suffix will be added to the file. Saving data to a file
1368            solves the problem of data loss in screen printing when a large amount of data is generated.
1369            If it is not set, an error will be reported: prompt to set the upper absolute path.
1370            When print data to file, the total output bytes of single print must be less then 2GB(limited by
1371            protobuf).
1372        env_config_path (str): Config path for DFX.
1373            Through mindspore.set_context(env_config_path="./mindspore_config.json")
1374
1375            configure RDR:
1376
1377            - enable: controls whether the RDR is enabled to collect the key data during training and
1378              save key data in the fault scenario. When set to ``true`` , the RDR will be turned on.
1379              When set to ``false`` , the RDR will be turned off.
1380            - mode: sets the mode of RDR on exporting data. When set to ``1`` , the RDR only exports data
1381              in the fault scenario. When set to ``2`` , the RDR exports data in the fault scenario and the
1382              normal end scenario. Default: ``1`` .
1383            - path: sets the path where RDR saves data. The current path must be absolute.
1384
1385            Memory reuse:
1386
1387            - mem_Reuse: controls whether the memory reuse function is turned on. When set to ``True`` ,
1388              the memory reuse function is turned on. When set to ``False`` , the memory reuse function is turned off.
1389
1390        precompile_only (bool): Whether to only precompile the network. Default: ``False`` .
1391            If set to ``True`` , the network will only be compiled, not executed.
1392        reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: ``True`` .
1393            Each node has a scope. A scope of a subnode is the name of its parent node. If reserve_class_name_in_scope
1394            is set to ``True`` , the class name will be saved after keyword 'net-' in the scope.
1395            For example:
1396
1397            Default/net-Net1/net-Net2 (reserve_class_name_in_scope=True)
1398
1399            Default/net/net (reserve_class_name_in_scope=False)
1400
1401        pynative_synchronize (bool): Whether to enable synchronous execution of the device in PyNative mode.
1402            Default: ``False`` . When the value is set to ``False`` , the operator is executed asynchronously on the
1403            device. When an error occurs in the execution of the operator, the specific error script code location
1404            cannot be located, when the value is set to ``True`` , the operator is executed synchronously on the
1405            device. It will reduce the execution performance of the program. At this time, when an error occurs in the
1406            execution of the operator, the location of the error script code can be located according to the call stack
1407            of the error.
1408        mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
1409            Both modes support all backends. Default: ``PYNATIVE_MODE`` .
1410        enable_graph_kernel (bool): Whether to enable graph kernel fusion to optimize network execution performance.
1411            Default: ``False`` .
1412            Indicates whether to enable image-computing convergence to optimize network execution performance.
1413            If enable_graph_kernel is set to ``True`` , acceleration can be enabled.
1414            For details of graph kernel fusion, please check
1415            `Enabling Graph Kernel Fusion
1416            <https://www.mindspore.cn/tutorials/experts/en/master/optimize/graph_fusion_engine.html>`_.
1417        graph_kernel_flags (str):
1418            Optimization options of graph kernel fusion, and the priority is higher when it conflicts
1419            with enable_graph_kernel. Only for experienced users.
1420            For example,
1421
1422            .. code-block::
1423
1424                mindspore.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text")
1425
1426            Some general options:
1427
1428            - opt_level: Set the optimization level.
1429              Default: ``2`` . Graph kernel fusion can be enabled equivalently by setting opt_level greater than 0.
1430              Available values are:
1431
1432              - 0: disables graph kernel fusion;
1433              - 1: enables the basic fusion of operators;
1434              - 2: includes all optimizations of level 1,
1435                and turns on more optimizations such as CSE, arithmetic simplification and so on;
1436              - 3: includes all optimizations of level 2, and turns on more optimizations such as SitchingFusion,
1437                ParallelFusion and so on. Optimizations of this level are radical and unstable in some scenarios.
1438                Be caution when using this level.
1439
1440            - dump_as_text: dumps detail info as text files. Default: ``False`` .
1441
1442        enable_reduce_precision (bool): Whether to enable precision reduction.
1443            If the operator does not support the user-specified precision, the precision will
1444            be changed automatically. Default: ``True`` .
1445        aoe_tune_mode (str): AOE tuning mode setting, which is not set by default.
1446            When set to ``"online"`` , the tuning in online function is turned on.
1447            When set to ``"offline"`` , ge graph will be save for offline tuning.
1448        aoe_config (dict): Set the parameters specific to Ascend Optimization Engine. It is not set by default.
1449
1450            - job_type (str): Mode type setting, default value is ``"2"``.
1451
1452              - ``"1"``: subgraph tuning;
1453              - ``"2"``: operator tuning.
1454
1455        check_bprop (bool): Whether to check back propagation nodes. The checking ensures that the shape and dtype
1456            of back propagation node outputs is the same as input parameters. Default: ``False`` .
1457        max_call_depth (int): Specify the maximum depth of function call. Must be positive integer. Default: ``1000`` .
1458            The max_call_depth parameter needs to be set when the nested call is too deep or the number
1459            of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should be
1460            set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow.
1461        grad_for_scalar (bool):  Whether to get gradient for scalar. Default: ``False`` .
1462            When grad_for_scalar is set to ``True`` , the function's scalar input can be derived.
1463            The default value is ``False`` . Because the back-end does not support scaling operations currently,
1464            this interface only supports simple operations that can be deduced by the front-end.
1465        enable_compile_cache (bool): Whether to save or load the cache of the graph compiled by front-end.
1466            After enable_compile_cache is set to ``True`` , during the first execution, a hardware-independent
1467            compilation cache is generated and exported to a MINDIR file. When the network is executed again,
1468            if enable_compile_cache is still set to ``True`` and the network scripts are not changed,
1469            the compile cache is loaded. Note that only limited automatic detection for the changes of
1470            python scripts is supported by now, which means that there is a correctness risk. Default: ``False`` .
1471            This is an experimental prototype that is subject to change and/or deletion.
1472        compile_cache_path (str): Path to save the compile cache. Default: ``"."``.
1473            If the specified directory does not exist, the system will automatically create the directory.
1474            The cache will be saved to the directory of `compile_cache_path/rank_${rank_id}/`. The `rank_id` is
1475            the ID of the current device in the cluster.
1476        inter_op_parallel_num(int): The thread number of op parallel at the same time. Default value is ``0`` ,
1477            which means use the default num.
1478        runtime_num_threads(int): The thread pool number of cpu kernel used in runtime,
1479            which must bigger than or equal to 0. Default value is ``30`` , if you run many processes at
1480            the same time, you should set the value smaller to avoid thread contention.
1481        disable_format_transform (bool): Whether to disable the automatic format transform function from NCHW to NHWC.
1482            When the network training performance of fp16 is worse than fp32, `disable_format_transform` can be set to
1483            ``True`` to try to improve training performance. Default: ``False`` .
1484        support_binary (bool): Whether to support run .pyc or .so in graph mode. If want to support run .so or .pyc
1485            in graph mode, coulde set 'support_binary' to be ``True`` , and run once .py file. It would save the source
1486            of the interfaces would be compiled by MindSpore to the interfaces definition .py file that should be
1487            guaranteed to be writable. Then compile the .py file to the .pyc or .so file, and could run in Graph mode.
1488        memory_optimize_level (str): The memory optimize level.
1489            On Ascend hardware platform, default: ``O1``, on other hardware platforms, default: ``O0``.
1490            The value must be in ['O0', 'O1'].
1491
1492            - O0: priority performance option, disable SOMAS (Safe Optimized Memory Allocation Solver)
1493              and some other memory optimizations.
1494            - O1: priority memory option, enable SOMAS and some other memory optimizations.
1495        memory_offload (str): Whether to enable the memory offload function. When it is enabled, the idle data will be
1496            temporarily copied to the host side in the case of insufficient device memory. The value must be in the
1497            range of ['ON', 'OFF'], and the default value is ``'OFF'`` .
1498
1499            - ON: Enable the memory Offload function. On Ascend hardware platform, this parameter does not take effect
1500              when the graph compilation level is not 'O0'; This parameter does not take effect when
1501              memory_optimize_level is set 'O1'.
1502            - OFF: Turn off the memory Offload function.
1503        ascend_config (dict): Set the parameters specific to Ascend hardware platform. It is not set by default.
1504            The default value of `precision_mode`, `jit_compile` and
1505            `atomic_clean_policy` are experimental parameters, may change in the future.
1506
1507            - precision_mode (str): Mixed precision mode setting, and the default value of inference network
1508              is ``force_fp16`` . The value range is as follows:
1509
1510              - force_fp16: When the operator supports both float16 and float32, select float16 directly.
1511              - allow_fp32_to_fp16: For cube operators, use the float16. For vector operators,
1512                prefer to keep the origin dtype, if the operator in model can support float32,
1513                it will keep original dtype, otherwise it will reduce to float16.
1514              - allow_mix_precision: Automatic mixing precision, facing the whole network operator, according
1515                to the built-in optimization strategy, automatically reduces the precision of some operators
1516                to float16 or bfloat16.
1517              - must_keep_origin_dtype: Keep the accuracy of the original drawing.
1518              - force_fp32: When the input of the matrix calculation operator is float16 and the output supports
1519                float16 and float32, output is forced to float32.
1520              - allow_fp32_to_bf16: For cube operators, use the bfloat16. For vector operators,
1521                prefer to keep the origin dtype, if the operator in model can support float32,
1522                it will keep original dtype, otherwise it will reduce to bfloat16.
1523              - allow_mix_precision_fp16: Automatic mixing precision, facing the whole network operator, automatically
1524                reduces the precision of some operators to float16 according to the built-in optimization strategy.
1525              - allow_mix_precision_bf16: Automatic mixing precision, facing the whole network operator, according to
1526                the built-in optimization strategy, automatically reduces the precision of some operators to bfloat16.
1527
1528            - jit_compile (bool): Whether to select online compilation. When set to 'True', online compilation is
1529              prioritized. When set to 'False', compiled operator binary files are prioritized to improve compilation
1530              performance. The default settings are online compilation for static shape, and compiled operator binary
1531              files for dynamic shape.
1532            - atomic_clean_policy (int): The policy for cleaning memory occupied by atomic operators in the network.
1533              Default: ``1`` .
1534
1535              - 0: The memory occupied by all atomic operators in the network is cleaned centrally.
1536              - 1: Memory is not cleaned centrally and each atomic operator in the network is cleaned separately.
1537                When the memory of the network exceeds the limit, you may try this cleaning policy, but it may cause
1538                performance loss.
1539            - matmul_allow_hf32 (bool): Whether to convert FP32 to HF32 for Matmul operators. Default value: ``False``.
1540              This is an experimental prototype that is subject to change and/or deletion.
1541              For detailed information, please refer to `Ascend community <https://www.hiascend.com/>`_ .
1542            - conv_allow_hf32 (bool): Whether to convert FP32 to HF32 for Conv operators. Default value: ``True``.
1543              This is an experimental prototype that is subject to change and/or deletion.
1544              For detailed information, please refer to `Ascend community <https://www.hiascend.com/>`_ .
1545            - exception_dump (str): Enable exception dump for Ascend operators, providing the input and output data for
1546              failing Ascend operators. The value can be ``"0"`` , ``"1"`` and ``"2"``. For ``"0"`` , exception dump is
1547              turned off; for ``"1"``, all inputs and outputs will be dumped for AICore exception operators;
1548              for ``"2"``, inputs will be dumped for AICore exception operators, reducing the saved information
1549              but improving performance. Default: ``"2"`` .
1550            - op_precision_mode (str): Path to config file of op precision mode. For detailed information, please refer
1551              to `Ascend community <https://www.hiascend.com/>`_ .
1552            - op_debug_option (str): Enable debugging options for Ascend operators, default not enabled.
1553              The value currently only supports being set to ``"oom"``.
1554
1555              - ``"oom"``: When there is a memory out of bounds during the execution of an operator,
1556                AscendCL will return an error code of ``EZ9999``.
1557
1558            - ge_options (dict): Set options for CANN. The options are divided into two categories: global and session.
1559              This is an experimental prototype that is subject to change and/or deletion.
1560              For detailed information, please refer to `Ascend community <https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/inferapplicationdev/graphdevg/atlasgeapi_07_0119.html>`_ .
1561              The configuration options in `ge_options` may be duplicated with the options in `ascend_config`. If the
1562              same configuration options are set in both `ascend_config` and `ge_options`, the one set in `ge_options`
1563              shall prevail.
1564
1565              - global (dict): Set global options.
1566              - session (dict): Set session options.
1567
1568            - parallel_speed_up_json_path(Union[str, None]): The path to the parallel speed up json file, configuration
1569              can refer to `parallel_speed_up.json
1570              <https://gitee.com/mindspore/mindspore/blob/master/config/parallel_speed_up.json>`_ .
1571              If its value is None or '', it does not take effect. Default None.
1572
1573              - recompute_comm_overlap (bool): Enable overlap between recompute ops and communication ops if True.
1574                Default: False.
1575              - matmul_grad_comm_overlap (bool): Enable overlap between dw matmul and
1576                tensor parallel communication ops if True. Default: False.
1577              - recompute_allgather_overlap_fagrad (bool): Enable overlap between duplicated allgather by recomputing
1578                in sequence parallel and flashattentionscoregrad ops if True. Default: False.
1579              - enable_task_opt (bool): Enable communication fusion to optimize the number of communication operator
1580                tasks if True.
1581                Default: False.
1582              - enable_grad_comm_opt (bool): Enable overlap between dx ops and data parallel communication ops if True.
1583                Currently, do not support
1584                `LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html>`
1585                Default: False.
1586              - enable_opt_shard_comm_opt (bool): Enable overlap between forward ops
1587                and optimizer parallel allgather communication if True. Currently, do not support
1588                `LazyInline <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.lazy_inline.html>`
1589                Default: False.
1590              - compute_communicate_fusion_level (int): Enable the fusion between compute and communicate.
1591                Default: ``0``.
1592
1593                - 0: Disable fusion.
1594
1595                - 1: Apply fusion to forward nodes.
1596
1597                - 2: Apply fusion to backward nodes.
1598
1599                - 3: Apply fusion to all nodes.
1600              - bias_add_comm_swap (bool): Enable node execution order swap communication operators and add operators
1601                if ``True``. Only 1-dimension bias node is supported. Default: ``False``.
1602            - host_scheduling_max_threshold(int): The max threshold to control whether the dynamic shape process is
1603              used when run the static graph, the default value is 0. When the number of operations in the static graph
1604              is less than the max threshold, this graph will be executed in dynamic shape process. In large model
1605              scenarios, this approach can save stream resources. If the number of operations in the static graph is
1606              greater than the maximum threshold, this graph will be executed in original static process.
1607
1608        jit_syntax_level (int): Set JIT syntax level for graph compiling, triggered by GRAPH_MODE and @jit decorator.
1609            The value must be ``STRICT`` or ``LAX`` . Default: ``LAX`` . All levels support all backends.
1610
1611            - ``STRICT`` : Only basic syntax is supported, and execution performance is optimal. Can be used for MindIR
1612              load and export.
1613            - ``LAX`` : Compatible with all Python syntax as much as possible. However, execution performance may be
1614              affected and not optimal. Cannot be used for MindIR load and export due to some syntax that may not be
1615              able to be exported.
1616
1617        debug_level (int): Set config for debugging. Default value: ``RELEASE``.
1618
1619            - ``RELEASE``: Used for normally running, and some debug information will be discard to get a better
1620              compiling performance.
1621            - ``DEBUG``: Used for debugging when errors occur, more information will be record in compiling process.
1622
1623        gpu_config (dict): Set the parameters specific to gpu hardware platform. It is not set by default.
1624            Currently, only setting `conv_fprop_algo` and `conv_dgrad_algo` and `conv_wgrad_algo` and `conv_allow_tf32`
1625            and `matmul_allow_tf32` are supported on GPU hardware platform.
1626
1627            - conv_fprop_algo (str): Specifies convolution forward algorithm and the default value is 'normal',
1628              The value range is as follows:
1629
1630              - normal: Use the heuristic search algorithm.
1631              - performance: Use the trial search algorithm.
1632              - implicit_gemm: This algorithm expresses the convolution as a matrix product without actually explicitly
1633                forming the matrix that holds the input tensor data.
1634              - implicit_precomp_gemm: This algorithm expresses convolution as a matrix product without actually
1635                explicitly forming the matrix that holds the input tensor data, but still needs some memory workspace to
1636                precompute some indices in order to facilitate the implicit construction of the matrix that holds the
1637                input tensor data.
1638              - gemm: This algorithm expresses the convolution as an explicit matrix product. A significant memory
1639                workspace is needed to store the matrix that holds the input tensor data.
1640              - direct: This algorithm expresses the convolution as a direct convolution (for example, without
1641                implicitly or explicitly doing a matrix multiplication).
1642              - fft: This algorithm uses the Fast-Fourier Transform approach to compute the convolution. A significant
1643                memory workspace is needed to store intermediate results.
1644              - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles.
1645                A significant memory workspace is needed to store intermediate results but less than fft algorithm for
1646                large size images.
1647              - winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably
1648                sized workspace is needed to store intermediate results.
1649              - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution. A
1650                significant workspace may be needed to store intermediate results.
1651            - conv_dgrad_algo (str): Specifies convolution data grad algorithm and the default value is 'normal',
1652              The value range is as follows:
1653
1654              - normal: Use the heuristic search algorithm.
1655              - performance: Use the trial search algorithm.
1656              - algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly
1657                forming the matrix that holds the input tensor data. The sum is done using the atomic add operation,
1658                thus the results are non-deterministic.
1659              - algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming
1660                the matrix that holds the input tensor data. The results are deterministic.
1661              - fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant
1662                memory workspace is needed to store intermediate results. The results are deterministic.
1663              - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles.
1664                A significant memory workspace is needed to store intermediate results but less than fft for large size
1665                images. The results are deterministic.
1666              - winograd: This algorithm uses the Winograd Transform approach to compute the convolution. A reasonably
1667                sized workspace is needed to store intermediate results. The results are deterministic.
1668              - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution.
1669                A significant workspace may be needed to store intermediate results. The results are deterministic.
1670            - conv_wgrad_algo (str): Specifies convolution filter grad algorithm and the default value is 'normal',
1671              The value range is as follows:
1672
1673              - normal: Use the heuristic search algorithm.
1674              - performance: Use the trial search algorithm.
1675              - algo_0: This algorithm expresses the convolution as a sum of matrix products without actually explicitly
1676                forming the matrix that holds the input tensor data. The sum is done using the atomic add operation,
1677                thus the results are non-deterministic.
1678              - algo_1: This algorithm expresses the convolution as a matrix product without actually explicitly forming
1679                the matrix that holds the input tensor data. The results are deterministic.
1680              - fft: This algorithm uses a Fast-Fourier Transform approach to compute the convolution. A significant
1681                memory workspace is needed to store intermediate results. The results are deterministic.
1682              - algo_3: This algorithm is similar to algo_0 but uses some small workspace to precompute some indices.
1683                The results are also non-deterministic.
1684              - winograd_nonfused: This algorithm uses the Winograd Transform approach to compute the convolution.
1685                A significant workspace may be needed to store intermediate results. The results are deterministic.
1686              - fft_tiling: This algorithm uses the Fast-Fourier Transform approach but splits the inputs into tiles.
1687                A significant memory workspace is needed to store intermediate results but less than fft for large size
1688                images. The results are deterministic.
1689            - conv_allow_tf32 (bool): The flag below controls to allow Tensor core TF32 computation on CUDNN and the
1690              default value is ``True``.
1691            - matmul_allow_tf32 (bool): The flag below controls to allow Tensor core TF32 computation on CUBLAS and the
1692              default value is ``False``.
1693
1694        jit_config (dict): Set the global jit config for compile, take effect in network defined in Cell or jit
1695            decorators. It is not set by default.
1696            The setting in context is the global jit config, while JitConfig is the local network's jit config.
1697            When both exist simultaneously, the global jit config will not overwrite the local network's jit config.
1698
1699            - jit_level (str): Used to control the compilation optimization level. Default: ``""`` , The framework
1700              automatically selects the execution method based on product, Altas training product is O2, and all other
1701              products are O0. The value range is as follows:
1702
1703              - ``"O0"``: Except for optimizations that may affect functionality, all other optimizations are turned
1704                off, adopt KernelByKernel execution mode.
1705              - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations,
1706                adopt KernelByKernel execution mode.
1707              - ``"O2"``: Ultimate performance optimization, adopt Sink execution mode.
1708
1709            - infer_boost (str): Used to control the infer mode. Default: ``"off"`` . The value range is as follows:
1710
1711              - ``"on"``: Enable infer mode, get better infer performance.
1712              - ``"off"``: Disable infer mode, use forward to infer, performance is not good.
1713
1714    Raises:
1715        ValueError: If input key is not an attribute in context.
1716
1717    Examples:
1718        >>> import mindspore as ms
1719        >>> ms.set_context(mode=ms.PYNATIVE_MODE)
1720        >>> ms.set_context(precompile_only=True)
1721        >>> ms.set_context(device_target="Ascend")
1722        >>> ms.set_context(device_id=0)
1723        >>> ms.set_context(save_graphs=True, save_graphs_path="./model.ms")
1724        >>> ms.set_context(enable_reduce_precision=True)
1725        >>> ms.set_context(enable_graph_kernel=True)
1726        >>> ms.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text")
1727        >>> ms.set_context(reserve_class_name_in_scope=True)
1728        >>> ms.set_context(variable_memory_max_size="6GB")
1729        >>> ms.set_context(aoe_tune_mode="online")
1730        >>> ms.set_context(aoe_config={"job_type": "2"})
1731        >>> ms.set_context(check_bprop=True)
1732        >>> ms.set_context(max_device_memory="3.5GB")
1733        >>> ms.set_context(mempool_block_size="1GB")
1734        >>> ms.set_context(print_file_path="print.pb")
1735        >>> ms.set_context(max_call_depth=80)
1736        >>> ms.set_context(env_config_path="./env_config.json")
1737        >>> ms.set_context(grad_for_scalar=True)
1738        >>> ms.set_context(enable_compile_cache=True, compile_cache_path="./cache.ms")
1739        >>> ms.set_context(pynative_synchronize=True)
1740        >>> ms.set_context(runtime_num_threads=10)
1741        >>> ms.set_context(inter_op_parallel_num=4)
1742        >>> ms.set_context(disable_format_transform=True)
1743        >>> ms.set_context(memory_optimize_level='O0')
1744        >>> ms.set_context(memory_offload='ON')
1745        >>> ms.set_context(deterministic='ON')
1746        >>> ms.set_context(ascend_config={"precision_mode": "force_fp16", "jit_compile": True,
1747        ...                "atomic_clean_policy": 1, "op_precision_mode": "./op_precision_config_file",
1748        ...                "op_debug_option": "oom",
1749        ...                "ge_options": {"global": {"ge.opSelectImplmode": "high_precision"},
1750        ...                               "session": {"ge.exec.atomicCleanPolicy": "0"}}})
1751        >>> ms.set_context(jit_syntax_level=ms.STRICT)
1752        >>> ms.set_context(debug_level=ms.context.DEBUG)
1753        >>> ms.set_context(gpu_config={"conv_fprop_algo": "performance", "conv_allow_tf32": True,
1754        ...                "matmul_allow_tf32": True})
1755        >>> ms.set_context(jit_config={"jit_level": "O0"})
1756    """
1757    ctx = _context()
1758    # set device target first
1759    if 'device_target' in kwargs:
1760        ctx.set_device_target(kwargs['device_target'])
1761    device = ctx.get_param(ms_ctx_param.device_target)
1762    _check_ascend_device_context_initialized(device, kwargs)
1763
1764    for key, value in kwargs.items():
1765        if key in ('enable_sparse', 'auto_tune_mode'):
1766            logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated, "
1767                           "and will be removed in the next version.")
1768            continue
1769        if key in ('enable_auto_mixed_precision', 'enable_dump', 'save_dump_path'):
1770            logger.warning(f"For 'context.set_context', '{key}' parameter is deprecated. "
1771                           "For details, please see the interface parameter API comments")
1772            continue
1773        _check_key(key)
1774        if key == 'save_graphs':
1775            if value is True:
1776                value = 2
1777            if value is False:
1778                value = 0
1779            if value > 3:
1780                raise ValueError(f"value for save_graphs should be 0-3 but got '{value}'")
1781        if key == 'jit_syntax_level' and value not in (STRICT, COMPATIBLE, LAX):
1782            raise ValueError(f"For 'jit_syntax_level', the value should be context.STRICT"
1783                             f" or context.LAX, but got {value}.")
1784        if key == 'debug_level' and value not in (RELEASE, DEBUG):
1785            raise ValueError(f"For 'debug_level', the value should be context.DEBUG"
1786                             f" or context.RELEASE, but got {value}.")
1787        if key == 'enable_compile_cache':
1788            setattr(ctx, key, value)
1789            ctx.set_param(ms_ctx_param.__members__[key], int(value))
1790            continue
1791        if not _check_target_specific_cfgs(device, key):
1792            continue
1793        if hasattr(ctx, key):
1794            setattr(ctx, key, value)
1795            continue
1796        if key in ctx.setters:
1797            ctx.setters[key](ctx, value)
1798            continue
1799        # enum variables beginning with '_' are for internal use
1800        if key in ms_ctx_param.__members__ and key[0] != '_':
1801            ctx.set_param(ms_ctx_param.__members__[key], value)
1802            continue
1803        raise ValueError(f"For 'context.set_context', the keyword argument {key} is not recognized! For detailed "
1804                         f"usage of 'set_context', please refer to the Mindspore official website.")
1805
1806
1807def get_context(attr_key):
1808    """
1809    Get context attribute value according to the input key.
1810    If some attributes are not set, they will be automatically obtained.
1811
1812    Args:
1813        attr_key (str): The key of the attribute.
1814
1815    Returns:
1816        Object, The value of given attribute key.
1817
1818    Raises:
1819        ValueError: If input key is not an attribute in context.
1820    Examples:
1821        >>> import mindspore as ms
1822        >>> ms.get_context("device_target")
1823        >>> ms.get_context("device_id")
1824    """
1825    ctx = _context()
1826    device = ctx.get_param(ms_ctx_param.device_target)
1827    _ = _check_target_specific_cfgs(device, attr_key)
1828    if hasattr(ctx, attr_key):
1829        return getattr(ctx, attr_key)
1830    # enum variables beginning with '_' are for internal use
1831    if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
1832        return ctx.get_param(ms_ctx_param.__members__[attr_key])
1833    raise ValueError(f"For 'context.get_context', the argument {attr_key} is not recognized! For detailed "
1834                     f"usage of 'get_context', please refer to the Mindspore official website.")
1835
1836
1837def _get_mode():
1838    """
1839    Get execution mode. Only for internal using.
1840
1841    Returns:
1842        Object: The Value of execution mode.
1843    """
1844    ctx = _context()
1845    return ctx.get_mode()
1846
1847
1848def get_jit_config():
1849    """
1850    Get global jit config.
1851
1852    Returns:
1853        Object: The Value of jit config.
1854    """
1855    ctx = _context()
1856    return ctx.get_jit_config()
1857
1858
1859class ParallelMode:
1860    """
1861    Parallel mode options.
1862
1863    There are five kinds of parallel modes, ``STAND_ALONE``, ``DATA_PARALLEL``,
1864    ``HYBRID_PARALLEL``, ``SEMI_AUTO_PARALLEL`` and ``AUTO_PARALLEL``. Default: ``STAND_ALONE``.
1865
1866    - ``STAND_ALONE``: Only one processor is working.
1867    - ``DATA_PARALLEL``: Distributes the data across different processors.
1868    - ``HYBRID_PARALLEL``: Achieves data parallelism and model parallelism manually.
1869    - ``SEMI_AUTO_PARALLEL``: Achieves data parallelism and model parallelism by setting parallel strategies.
1870    - ``AUTO_PARALLEL``: Achieves parallelism automatically.
1871
1872    ``MODE_LIST``: The list of all supported parallel modes.
1873    """
1874
1875    STAND_ALONE = "stand_alone"
1876    DATA_PARALLEL = "data_parallel"
1877    HYBRID_PARALLEL = "hybrid_parallel"
1878    SEMI_AUTO_PARALLEL = "semi_auto_parallel"
1879    AUTO_PARALLEL = "auto_parallel"
1880    MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]
1881
1882
1883@args_type_check(enable_ps=bool)
1884def set_ps_context(**kwargs):
1885    """
1886    Set parameter server training mode context.
1887
1888    Note:
1889        Parameter server mode is only supported in graph mode.
1890        Some other environment variables should also be set for parameter server training mode.
1891        These environment variables are listed below:
1892
1893        - MS_SERVER_NUM: Server number
1894        - MS_WORKER_NUM: Worker number
1895        - MS_SCHED_HOST: Scheduler IP address
1896        - MS_SCHED_PORT: Scheduler port
1897        - MS_ROLE: The role of this process:
1898
1899          - MS_SCHED: represents the scheduler,
1900          - MS_WORKER: represents the worker,
1901          - MS_PSERVER/MS_SERVER: represents the Server
1902
1903    Args:
1904        enable_ps (bool): Whether to enable parameter server training mode.
1905                          Only after enable_ps is set True, the environment variables will be effective.
1906                          Default: ``False`` .
1907        config_file_path (string): Configuration file path used by recovery, parameter server training mode only
1908                                   supports Server disaster recovery currently. Default: ``''`` .
1909        scheduler_manage_port (int): Scheduler manage port used to scale out/in. Default: ``11202`` .
1910        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False`` .
1911        client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ``''`` .
1912        server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ``''`` .
1913
1914    Raises:
1915        ValueError: If input key is not the attribute in parameter server training mode context.
1916
1917    Examples:
1918        >>> import mindspore as ms
1919        >>> ms.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
1920    """
1921    _set_ps_context(**kwargs)
1922
1923
1924def get_ps_context(attr_key):
1925    """
1926    Get parameter server training mode context attribute value according to the key.
1927
1928    Args:
1929        attr_key (str): The key of the attribute:
1930
1931            - enable_ps (bool): Whether to enable parameter server training mode. Default: ``False`` .
1932            - config_file_path (string): Configuration file path used by recovery, parameter server training mode only
1933              supports Server disaster recovery currently. Default: ``''`` .
1934            - scheduler_manage_port (int): Scheduler manage port used to scale out/in. Default: ``11202`` .
1935            - enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False`` .
1936            - client_password (str): Password to decrypt the secret key stored in the client certificate.
1937              Default: ``''`` .
1938            - server_password (str): Password to decrypt the secret key stored in the server certificate.
1939              Default: ``''`` .
1940
1941    Returns:
1942        Returns attribute value according to the key.
1943
1944    Raises:
1945        ValueError: If input key is not attribute in auto parallel context.
1946
1947    Examples:
1948        >>> import mindspore as ms
1949        >>> ms.get_ps_context("enable_ps")
1950    """
1951    return _get_ps_context(attr_key)
1952
1953
1954def reset_ps_context():
1955    """
1956    Reset parameter server training mode context attributes to the default values.
1957
1958    Meaning of each field and its default value refer to :func:`mindspore.set_ps_context`.
1959
1960    Examples:
1961        >>> import mindspore as ms
1962        >>> ms.reset_ps_context()
1963    """
1964    _reset_ps_context()
1965
1966
1967_hccl_connect_timeout = '600'
1968
1969
1970def _init_parallel_env():
1971    """Set hccl connect timeout."""
1972    if 'HCCL_CONNECT_TIMEOUT' not in os.environ:
1973        os.environ['HCCL_CONNECT_TIMEOUT'] = _hccl_connect_timeout
1974
1975
1976_init_parallel_env()
1977