# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""
16The context of mindspore, used to configure the current execution environment,
17includes the execution mode, execution backend and other feature switches.
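
Examples:
    >>> from mindspore import context
    >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")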
18"""
import json
import os
import time
import threading
from collections import namedtuple
from types import FunctionType

from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check, Validator, args_unreset_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
    _reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
from .default_config import __device_target__, __package_name__

__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
           'get_ps_context', 'reset_ps_context', 'set_fl_context', 'get_fl_context']

GRAPH_MODE = 0
PYNATIVE_MODE = 1
_DEVICE_APP_MEMORY_SIZE = 31  # The max memory size of graph plus variable.
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
_k_context = None


def _make_directory(path):
    """Make a directory."""
    real_path = None
    if path is None or not isinstance(path, str) or path.strip() == "":
        raise ValueError(f"Input path `{path}` is invalid.")

    # convert a relative path to an absolute path
    path = os.path.realpath(path)
    logger.debug("The absolute path is %r", path)

    # check whether the path already exists and has write permission
    if os.path.exists(path):
        real_path = path
    else:
        # creating the directory may fail, e.g. due to missing permissions
        logger.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            os.makedirs(path)
            real_path = path
        except PermissionError as e:
            logger.error(f"No write permission on the directory `{path}`, error = {e}")
            raise ValueError(f"No write permission on the directory `{path}`.")
    return real_path


def _get_print_file_name(file_name):
    """Add a timestamp suffix to the file name: file_name + "." + time(seconds)."""
    time_second = str(int(time.time()))
    file_name = file_name + "." + time_second
    if os.path.exists(file_name):
        raise ValueError("This file {} already exists.".format(file_name))
    return file_name


class _ThreadLocalInfo(threading.local):
    """
    Thread-local info used to store thread-local attributes.
    """

    def __init__(self):
        super(_ThreadLocalInfo, self).__init__()
        self._reserve_class_name_in_scope = True
        self.debug_runtime = False

    @property
    def reserve_class_name_in_scope(self):
        """Get whether to save the network class name in the scope."""
        return self._reserve_class_name_in_scope

    @reserve_class_name_in_scope.setter
    def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
        """Set whether to save the network class name in the scope."""
        if not isinstance(reserve_class_name_in_scope, bool):
            raise ValueError(
                "The reserve_class_name_in_scope value must be a bool!")
        self._reserve_class_name_in_scope = reserve_class_name_in_scope


_ContextRecord = namedtuple(
    "_ContextRecord", ["is_pynative_mode", "switch_context_fn"])


class _ContextSwitchInfo(threading.local):
    """
    Record of context switch information.

    Args:
        is_pynative (bool): Whether to adopt PyNative mode.
    """

    def __init__(self, is_pynative):
        super(_ContextSwitchInfo, self).__init__()
        self.context_stack = []
        if is_pynative:
            self.push(True, None)

    def push(self, is_pynative, switch_context_fn):
        """
        Push a context switch record onto the stack.

        Args:
            is_pynative (bool): Whether the context switches to PyNative mode.
            switch_context_fn (Function): A callable that executes the context switch.
        """
        if isinstance(switch_context_fn, FunctionType):
            switch_context_fn()
        self.context_stack.append(
            _ContextRecord(is_pynative, switch_context_fn))

    def pop(self):
        self.context_stack.pop()


class _Context:
    """
    _Context is the environment in which operations are executed.

    Note:
        Creating a context by instantiating a Context object is not recommended.
        Use _context() to get the global context instead, since Context is a singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._thread_local_info = _ThreadLocalInfo()
        self._context_switches = _ContextSwitchInfo(False)
        self._context_handle = MSContext.get_instance()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # double-checked locking keeps instance creation thread-safe
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def __getattribute__(self, attr):
        value = object.__getattribute__(self, attr)
        if attr == "_context_handle" and value is None:
            raise ValueError("Context handle is none in context!")
        return value

    def get_param(self, param):
        return self._context_handle.get_param(param)

    def set_param(self, param, value):
        self._context_handle.set_param(param, value)

    def set_mode(self, mode):
        """
        Switch between Graph mode and PyNative mode.

        Args:
            mode (int): GRAPH_MODE or PYNATIVE_MODE.
        """
        if mode == PYNATIVE_MODE:
            if self.enable_debug_runtime:
                self.set_backend_policy("vm")
            self._context_switches.push(True, None)
        elif mode == GRAPH_MODE:
            if self.enable_debug_runtime:
                self.set_backend_policy("ge")
            self._context_switches.push(False, None)
        else:
            raise ValueError(f'The execution mode {mode} is invalid!')
        self.set_param(ms_ctx_param.mode, mode)

    def set_backend_policy(self, policy):
        success = self._context_handle.set_backend_policy(policy)
        if not success:
            raise RuntimeError("Backend policy must be one of ge, vm, ms.")

    def set_save_graphs_path(self, save_graphs_path):
        self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))

    def set_device_target(self, target):
        valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
        if target not in valid_targets:
            raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
        if target == "Davinci":
            target = "Ascend"
        self.set_param(ms_ctx_param.device_target, target)
        if self.enable_debug_runtime and target == "CPU":
            self.set_backend_policy("vm")

    def set_auto_tune_mode(self, tune_mode):
        candidate = ["NO_TUNE", "RL", "GA", "RL,GA", "GA,RL"]
        if tune_mode in candidate:
            self.set_param(ms_ctx_param.tune_mode, tune_mode)
        else:
            raise ValueError(f"Tune mode must be in ['NO_TUNE', 'RL', 'GA', 'RL,GA', 'GA,RL'], but got {tune_mode}")

    def set_device_id(self, device_id):
        if device_id < 0 or device_id > 4095:
            raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
        self.set_param(ms_ctx_param.device_id, device_id)

    def set_max_call_depth(self, max_call_depth):
        if max_call_depth <= 0:
            raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
        self.set_param(ms_ctx_param.max_call_depth, max_call_depth)

    def set_profiling_options(self, option):
        if not isinstance(option, str):
            raise TypeError("The parameter option must be str.")
        self.set_param(ms_ctx_param.profiling_options, option)

    def set_variable_memory_max_size(self, variable_memory_max_size):
        """Set variable_memory_max_size and derive graph_memory_max_size from it."""
        if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
            raise ValueError("Context param variable_memory_max_size should be in the correct format, such as \"5GB\".")
        if int(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE:
            raise ValueError("Context param variable_memory_max_size should not be greater than 31GB.")
        # the device provides _DEVICE_APP_MEMORY_SIZE (31) GB in total; whatever is not
        # reserved for variables is left for the graph
        variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
        graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
        graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
        self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
        self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)

    def set_max_device_memory(self, max_device_memory):
        if not Validator.check_str_by_regular(max_device_memory, _re_pattern):
            raise ValueError("Context param max_device_memory should be in the correct format, such as \"3.5GB\".")
        max_device_memory_value = float(max_device_memory[:-2])
        if max_device_memory_value == 0:
            raise ValueError("Context param max_device_memory should be in the correct format, such as \"3.5GB\".")
        self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)

    def set_print_file_path(self, file_path):
        """Set the print file path; a timestamp suffix is added if the file already exists."""
        print_file_path = os.path.realpath(file_path)
        if os.path.isdir(print_file_path):
            raise IOError("print_file_path should be a file path, but got {}.".format(file_path))

        if os.path.exists(print_file_path):
            _path, _file_name = os.path.split(print_file_path)
            path = _make_directory(_path)
            file_name = _get_print_file_name(_file_name)
            full_file_name = os.path.join(path, file_name)
        else:
            full_file_name = print_file_path
        self.set_param(ms_ctx_param.print_file_path, full_file_name)

    def set_env_config_path(self, env_config_path):
        """Check and set env_config_path."""
        if not self._context_handle.enable_dump_ir():
            raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
                             "with '-D on' and recompile source.")
        env_config_path = os.path.realpath(env_config_path)
        if not os.path.isfile(env_config_path):
            raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)
        try:
            with open(env_config_path, 'r') as f:
                json.load(f)
        except (TypeError, ValueError) as exo:
            raise ValueError("The %r set by 'env_config_path' should be a json file. "
                             "Detail: %s." % (env_config_path, str(exo)))
        self.set_param(ms_ctx_param.env_config_path, env_config_path)

    # dispatch table used by set_context(): maps context keys to their setter methods
    setters = {
        'mode': set_mode,
        'save_graphs_path': set_save_graphs_path,
        'device_target': set_device_target,
        'device_id': set_device_id,
        'auto_tune_mode': set_auto_tune_mode,
        'max_call_depth': set_max_call_depth,
        'profiling_options': set_profiling_options,
        'variable_memory_max_size': set_variable_memory_max_size,
        'max_device_memory': set_max_device_memory,
        'print_file_path': set_print_file_path,
        'env_config_path': set_env_config_path
    }

    @property
    def reserve_class_name_in_scope(self):
        """Get whether to save the network class name in the scope."""
        return self._thread_local_info.reserve_class_name_in_scope

    @reserve_class_name_in_scope.setter
    def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
        """Set whether to save the network class name in the scope."""
        self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope

    @property
    def enable_ge(self):
        return self._context_handle.get_backend_policy() == 'ge'

    @property
    def enable_debug_runtime(self):
        return self._thread_local_info.debug_runtime

    @enable_debug_runtime.setter
    def enable_debug_runtime(self, enable):
        thread_info = self._thread_local_info
        thread_info.debug_runtime = enable


def _context():
    """
    Get the global _context. If the context is not created, create a new one.

    Returns:
        _Context, the global context.
    """
    global _k_context
    if _k_context is None:
        default_backend = 'debug'
        try:
            from mindspore import default_config
            default_backend = default_config.__backend__
        except ImportError:
            logger.error("Import default config failed.")
        _k_context = _Context()
        _k_context.enable_debug_runtime = False
        if default_backend == 'debug':
            _k_context.enable_debug_runtime = True
            default_backend = 'vm'
        _k_context.set_backend_policy(default_backend)
    return _k_context


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
def set_auto_parallel_context(**kwargs):
350    r"""
351    Set auto parallel context, which is valid only for Ascend and GPU target.
352
353    Auto parallel context should be configured before the initialization of your network.
354
355    Note:
356        Attribute name is required for setting attributes.
357        If a program has tasks on different parallel modes, before setting a new parallel mode for the
358        next task, interface mindspore.context.reset_auto_parallel_context() should be called to reset
359        the configuration.
360        Setting or changing parallel modes must be called before creating any Initializer, otherwise,
361        it may have RuntimeError when compiling the network.
362
363    Some configurations are parallel mode specific, see the below table for details:
364
365    ===========================  ===========================
366    Common                       AUTO_PARALLEL
367    ===========================  ===========================
368    device_num                   gradient_fp32_sync
369    global_rank                  loss_repeated_mean
370    gradients_mean               auto_parallel_search_mode
371    parallel_mode                strategy_ckpt_load_file
372    all_reduce_fusion_config     strategy_ckpt_save_file
373    enable_parallel_optimizer    dataset_strategy
374               \                 pipeline_stages
375               \                 grad_accumulation_step
376    ===========================  ===========================
377
378    Args:
379        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
380        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
381        gradients_mean (bool): Whether to perform mean operator after allreduce of gradients.
382                     "stand_alone" do not support gradients_mean. Default: False.
383        gradient_fp32_sync (bool): Run allreduce of gradients in fp32. "stand_alone", "data_parallel"
384                     and "hybrid_parallel" do not support gradient_fp32_sync. Default: True.
385        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
386                     "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
387
388                     - stand_alone: Only one processor is working.
389
390                     - data_parallel: Distributes the data across different processors.
391
392                     - hybrid_parallel: Achieves data parallelism and model parallelism manually.
393
394                     - semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies.
395
396                     - auto_parallel: Achieving parallelism automatically.
397        auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
398                     and "dynamic_programming". Default: "dynamic_programming".
399
400                     - recursive_programming: Recursive programming search mode.
401
402                     - dynamic_programming: Dynamic programming search mode.
403        parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to have
404                     the same network initialization parameter values for all devices, broadcast the parameters
405                     on device 0 to other devices. Parameter broadcasting in different parallel modes is different,
406                     data_parallel mode, all parameters are broadcast except for the parameter whose attribute
407                     layerwise_parallel is True. Hybrid_parallel, semi_auto_parallel and auto_parallel mode, the
408                     segmented parameters do not participate in broadcasting. Default: False.
409        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
410        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
411        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
412                       should be set as True. Default: False. The interface is not be recommended currently,
413                       it is better using 'dataset_strategy' to replace it.
414        dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
415                       dataset_strategy="data_parallel" is equal to full_batch=False, dataset_strategy="full_batch" is
416                       equal to full_batch=True. For dataset load into net by model parallel strategy likes
417                       ds_stra ((1, 8), (1, 8)), it requires using set_auto_parallel_context(dataset_strategy=ds_stra).
418        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation for
419                       data parallel training in the benefit of time and memory saving. Currently, auto and semi auto
420                       parallel mode support all optimizers in both Ascend and GPU. Data parallel mode only supports
421                       `Lamb` and `AdamWeightDecay` in Ascend . Default: False.
422        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Only support ReduceOp.SUM
423                       and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. No Default, if it is not set, the fusion is closed.
424        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how the devices are
425                        distributed alone the pipeline. The total devices will be divided into 'pipeline_stags' stages.
426                        Currently this could only be used when parallel mode semi_auto_parallel is enabled. Default: 1.
427        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
428                        This should be a positive int. Default: 1.
429
430    Raises:
431        ValueError: If input key is not attribute in auto parallel context.
432
433    Examples:
434        >>> context.set_auto_parallel_context(device_num=8)
435        >>> context.set_auto_parallel_context(global_rank=0)
436        >>> context.set_auto_parallel_context(gradients_mean=True)
437        >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
438        >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
439        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
440        >>> context.set_auto_parallel_context(parameter_broadcast=False)
441        >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
442        >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
443        >>> context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
444        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
445        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
446        >>> context.set_auto_parallel_context(pipeline_stages=2)
447    """
    _set_auto_parallel_context(**kwargs)


def get_auto_parallel_context(attr_key):
    """
    Get auto parallel context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute in auto parallel context.
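
    Examples:
        >>> parallel_mode = context.get_auto_parallel_context("parallel_mode")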
463    """
464    return _get_auto_parallel_context(attr_key)


def reset_auto_parallel_context():
    """
    Reset auto parallel context attributes to the default values:

    - device_num: 1.
    - global_rank: 0.
    - gradients_mean: False.
    - gradient_fp32_sync: True.
    - parallel_mode: 'stand_alone'.
    - auto_parallel_search_mode: 'dynamic_programming'.
    - parameter_broadcast: False.
    - strategy_ckpt_load_file: ''.
    - strategy_ckpt_save_file: ''.
    - full_batch: False.
    - enable_parallel_optimizer: False.
    - pipeline_stages: 1.
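
    Examples:
        >>> context.reset_auto_parallel_context()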
483    """
484    _reset_auto_parallel_context()


def _check_target_specific_cfgs(device, arg_key):
    """Check whether a config is suitable for the specified device."""
    device_cfgs = {
        'enable_dump': ['Ascend'],
        'save_dump_path': ['Ascend'],
        'enable_graph_kernel': ['Ascend', 'GPU'],
        'graph_kernel_flags': ['Ascend', 'GPU'],
        'enable_reduce_precision': ['Ascend'],
        'enable_profiling': ['Ascend'],
        'profiling_options': ['Ascend'],
        'print_file_path': ['Ascend'],
        'variable_memory_max_size': ['Ascend'],
        'auto_tune_mode': ['Ascend'],
        'max_device_memory': ['GPU']
    }
    # configs not in the device_cfgs map are supposed to be suitable for all devices
    if arg_key not in device_cfgs:
        return True
    supported_devices = device_cfgs[arg_key]
    if device in supported_devices:
        return True
    logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
                   ", ignore it.")
    return False


@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
                 save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
                 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                 enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                 enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
                 max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
                 env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool,
                 load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool)
def set_context(**kwargs):
523    """
524    Set context for running environment.
525
526    Context should be configured before running your program. If there is no configuration,
527    it will be automatically set according to the device target by default.
528
529    Note:
530        Attribute name is required for setting attributes.
531        The mode is not recommended to be changed after net was initialized because the implementations of some
532        operations are different in graph mode and pynative mode. Default: GRAPH_MODE.
533
534    Some configurations are device specific, see the below table for details:

    +-------------------------+------------------------------+----------------------------+
    | Function Classification |   Configuration Parameters   |   Hardware Platform Support|
    +=========================+==============================+============================+
    | System Configuration    |   device_id                  |   CPU/GPU/Ascend           |
    |                         +------------------------------+----------------------------+
    |                         |   device_target              |   CPU/GPU/Ascend           |
    |                         +------------------------------+----------------------------+
    |                         |  max_device_memory           |  GPU                       |
    |                         +------------------------------+----------------------------+
    |                         |  variable_memory_max_size    |  Ascend                    |
    +-------------------------+------------------------------+----------------------------+
    | Debug Configuration     |  save_graphs                 |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  save_graphs_path            |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  enable_dump                 |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  save_dump_path              |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  enable_profiling            |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  profiling_options           |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  print_file_path             |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  env_config_path             |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  precompile_only             |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  reserve_class_name_in_scope |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  pynative_synchronize        |  GPU/Ascend                |
    +-------------------------+------------------------------+----------------------------+
    | Executive Control       |   mode                       |   CPU/GPU/Ascend           |
    |                         +------------------------------+----------------------------+
    |                         |  enable_graph_kernel         |  Ascend/GPU                |
    |                         +------------------------------+----------------------------+
    |                         |  graph_kernel_flags          |  Ascend/GPU                |
    |                         +------------------------------+----------------------------+
    |                         |  enable_reduce_precision     |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  auto_tune_mode              |  Ascend                    |
    |                         +------------------------------+----------------------------+
    |                         |  check_bprop                 |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  max_call_depth              |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  enable_sparse               |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  grad_for_scalar             |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  save_compile_cache          |  CPU/GPU/Ascend            |
    |                         +------------------------------+----------------------------+
    |                         |  load_compile_cache          |  CPU/GPU/Ascend            |
    +-------------------------+------------------------------+----------------------------+

    Args:
        device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
            while device_num_per_host should be no more than 4096. Default: 0.
        device_target (str): The target device to run, supporting "Ascend", "GPU", and "CPU".
            If the device target is not set, it is determined by the installed MindSpore package.
        max_device_memory (str): Set the maximum memory available for devices.
            Currently, it is only supported on GPU. The format is "xxGB". Default: "1024GB".
            The actual used memory size is the minimum of the available memory of the device and max_device_memory.
        variable_memory_max_size (str): Set the maximum size of the variable memory. Default: "30GB".
            After this parameter is set, the maximum memory used by the framework is restricted to the configured value.
        save_graphs (bool): Whether to save graphs. Default: False.
            When the `save_graphs` attribute is set to True, the `save_graphs_path` attribute is used to set the
            intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
        save_graphs_path (str): Path to save graphs. Default: ".".
            If the specified directory does not exist, the system will automatically create the directory.
            During distributed training, graphs will be saved to the directory of
            `save_graphs_path/rank_${rank_id}/`. `rank_id` is the ID of the current device in the cluster.
        enable_dump (bool): This parameter is deprecated and will be deleted in the next version.
        save_dump_path (str): This parameter is deprecated and will be deleted in the next version.
        enable_profiling (bool): This parameter is deprecated and will be deleted in the next version.
            Please use the mindspore.profiler.Profiler api instead.
        profiling_options (str): This parameter is deprecated and will be deleted in the next version.
            Please use the mindspore.profiler.Profiler api instead.
        print_file_path (str): The path for saving print data. If this parameter is set, print data is saved to
            a file by default; if print_file_path is not set, the data is displayed on the screen.
            If the saved file already exists, a timestamp suffix will be added to the file. Saving data to a file
            solves the problem of data loss in screen printing when a large amount of data is generated.
            If it is not set with an absolute path, an error prompting to set the absolute path will be reported.
        env_config_path (str): Config path for DFX, set through
            context.set_context(env_config_path="./mindspore_config.json").

            It configures RDR:

            - enable: controls whether the RDR is enabled to collect the key data during training and
              save key data in the fault scenario. When set to true, the RDR will be turned on.
              When set to false, the RDR will be turned off.
            - path: sets the path where RDR saves data. The current path must be absolute.

            Memory reuse:

            - mem_Reuse: controls whether the memory reuse function is turned on. When set to True,
              the memory reuse function is turned on. When set to False, the memory reuse function is turned off.

        precompile_only (bool): Whether to only precompile the network. Default: False.
            If set to True, the network will only be compiled, not executed.
        reserve_class_name_in_scope (bool): Whether to save the network class name in the scope. Default: True.
            Each node has a scope. A scope of a subnode is the name of its parent node. If reserve_class_name_in_scope
            is set to True, the class name will be saved after the keyword 'net-' in the scope.
            For example:

            Default/net-Net1/net-Net2 (reserve_class_name_in_scope=True)

            Default/net/net (reserve_class_name_in_scope=False)

        pynative_synchronize (bool): Whether to enable synchronous execution of the device in PyNative mode.
            Default: False. When the value is set to False, the operator is executed asynchronously on the device.
            When an error occurs in the execution of the operator, the specific error script code location cannot
            be located. When the value is set to True, the operator is executed synchronously on the device, which
            reduces the execution performance of the program. In this case, when an error occurs in the execution of
            the operator, the location of the error script code can be located according to the call stack of the error.
        mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: GRAPH_MODE(0).
            GRAPH_MODE or PYNATIVE_MODE can be set by the `mode` attribute and both modes support all backends. The
            default mode is GRAPH_MODE.
        enable_graph_kernel (bool): Whether to enable graph kernel fusion to optimize network execution performance.
            Default: False.
            If enable_graph_kernel is set to True, acceleration can be enabled.
            For details of graph kernel fusion, please check
            `Enabling Graph Kernel Fusion <https://www.mindspore.cn/docs/programming_guide
            /en/master/enable_graph_kernel_fusion.html>`_.
        graph_kernel_flags (str): Optimization options of graph kernel fusion, which take priority when they conflict
            with enable_graph_kernel. Only for experienced users.
            For example, context.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text"). Some general options:

            - opt_level: Set the optimization level.
              Default: 2. Graph kernel fusion can be enabled equivalently by setting opt_level greater than 0.
              Available values are:

              - 0: disables graph kernel fusion;
              - 1: enables the basic fusion of operators;
              - 2: includes all optimizations of level 1,
                and turns on more optimizations such as CSE, arithmetic simplification and so on;
              - 3: includes all optimizations of level 2, and turns on more optimizations such as StitchingFusion,
                ParallelFusion and so on. Optimizations of this level are radical and unstable in some scenarios.
                Be cautious when using this level.

            - dump_as_text: dump detailed info as text files. Default: false.

            More options can be found in the implementation code. These options can also be set by the environment
            variable MS_GRAPH_KERNEL_FLAGS, without modifying the network source code.
            For example, export MS_GRAPH_KERNEL_FLAGS="--opt_level=2 --dump_as_text".
        enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
            If set to True: when a user-specified precision is not supported, the precision will be changed
            automatically. If set to False: an error will be reported and execution will exit when the specified
            precision is not supported.
            For example, on the Ascend backend, conv2d only supports fp16 input. With fp32 input, setting True
            will automatically insert a cast operator to convert the input to fp16, while setting False will
            report an error and exit.
        auto_tune_mode (str): The mode of auto tune when building ops, used to get the best tiling performance.
            Default: NO_TUNE. The value must be in ['RL', 'GA', 'RL,GA'].

            - RL: Reinforcement Learning tune.
            - GA: Genetic Algorithm tune.
            - RL,GA: When both RL and GA optimization are enabled, the tool automatically selects RL or GA based on
              different types of operators in the network model. The sequence of RL and GA is not differentiated.
              (Automatic selection).

            For more information about the operator tuning tool settings, please check
            `Enable the operator optimization tool <https://www.mindspore.cn/docs/programming_guide/en
            /master/enable_auto_tune.html>`_.
        check_bprop (bool): Whether to check back propagation nodes. The checking ensures that the shape and dtype
            of back propagation node outputs are the same as the input parameters. Default: False.
        max_call_depth (int): Specify the maximum depth of a function call. Must be a positive integer. Default: 1000.
            The max_call_depth parameter needs to be set when the nested call is too deep or the number
            of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should be
            set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow.
        enable_sparse (bool): Whether to enable the sparsity feature. Default: False.
            For details of sparsity and sparse tensor, please check
            `sparse tensor <https://www.mindspore.cn/docs/programming_guide/en/r1.5/tensor.html#sparse-tensor>`_.
        grad_for_scalar (bool): Whether to get gradients for scalars. Default: False.
            When grad_for_scalar is set to True, the function's scalar input can be derived.
            Because the back-end does not support scalar operations currently,
            this interface only supports simple operations that can be deduced by the front-end.
        save_compile_cache (bool): Whether to cache the graph compiled by the front-end. Default: False.
            After save_compile_cache is set to True, a hardware-independent compilation cache is
            generated and exported to a MINDIR file. This is an experimental prototype that is
            subject to change and/or deletion.
        load_compile_cache (bool): Whether to use the cache of the graph compiled by the front-end.
            This parameter must be used together with save_compile_cache. After save_compile_cache is set to True,
            a hardware-independent compilation cache is generated and exported to a MINDIR file.
            When the network is executed again, if load_compile_cache is set to True, the compile cache is loaded.
            For now, automatic checking for changes is not supported.
            Default: False.
            This is an experimental prototype that is subject to change and/or deletion.

    Raises:
        ValueError: If input key is not an attribute in context.

    Examples:
        >>> context.set_context(mode=context.PYNATIVE_MODE)
        >>> context.set_context(precompile_only=True)
        >>> context.set_context(device_target="Ascend")
        >>> context.set_context(device_id=0)
        >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
        >>> context.set_context(enable_reduce_precision=True)
        >>> context.set_context(enable_dump=True, save_dump_path=".")
        >>> context.set_context(enable_graph_kernel=True)
        >>> context.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text")
        >>> context.set_context(reserve_class_name_in_scope=True)
        >>> context.set_context(variable_memory_max_size="6GB")
        >>> context.set_context(enable_profiling=True,
        ...                     profiling_options='{"output":"/home/data/output","training_trace":"on"}')
        >>> context.set_context(check_bprop=True)
        >>> context.set_context(max_device_memory="3.5GB")
        >>> context.set_context(print_file_path="print.pb")
        >>> context.set_context(enable_sparse=True)
        >>> context.set_context(max_call_depth=80)
        >>> context.set_context(env_config_path="./env_config.json")
        >>> context.set_context(auto_tune_mode="GA,RL")
        >>> context.set_context(grad_for_scalar=True)
        >>> context.set_context(save_compile_cache=True)
        >>> context.set_context(load_compile_cache=True)
        >>> context.set_context(pynative_synchronize=True)
    """
    ctx = _context()
    # set device target first
    if 'device_target' in kwargs:
        ctx.set_device_target(kwargs['device_target'])
        device = ctx.get_param(ms_ctx_param.device_target)
        if device.lower() not in __device_target__:
            raise ValueError(f"Error: package type {__package_name__} supports device types {__device_target__}, "
                             f"but got device target {device}")
    device = ctx.get_param(ms_ctx_param.device_target)
    for key, value in kwargs.items():
        # warn about deprecated parameters and skip them
        if key in ('enable_profiling', 'profiling_options', 'enable_auto_mixed_precision',
                   'enable_dump', 'save_dump_path'):
            logger.warning(f"'{key}' parameter will be deprecated. "
                           "For details, please see the interface parameter API comments.")
            continue
        # skip configs that the current device target does not support
        if not _check_target_specific_cfgs(device, key):
            continue
        if hasattr(ctx, key):
            setattr(ctx, key, value)
            continue
        if key in ctx.setters:
            ctx.setters[key](ctx, value)
            continue
        # enum variables beginning with '_' are for internal use
        if key in ms_ctx_param.__members__ and key[0] != '_':
            ctx.set_param(ms_ctx_param.__members__[key], value)
            continue
        raise ValueError("Set context keyword %s is not recognized!" % key)


def get_context(attr_key):
    """
    Get context attribute value according to the input key.
    If some attributes are not set, they will be automatically obtained.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Object, the value of the given attribute key.

    Raises:
        ValueError: If input key is not an attribute in context.

    Examples:
        >>> context.get_context("device_target")
        >>> context.get_context("device_id")
    """
    ctx = _context()
    device = ctx.get_param(ms_ctx_param.device_target)
    _ = _check_target_specific_cfgs(device, attr_key)
    if hasattr(ctx, attr_key):
        return getattr(ctx, attr_key)
    # enum variables beginning with '_' are for internal use
    if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
        return ctx.get_param(ms_ctx_param.__members__[attr_key])
    raise ValueError("Get context keyword %s is not recognized!" % attr_key)


class ParallelMode:
    """
    Parallel mode options.

    There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
    "HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".

    - STAND_ALONE: Only one processor is working.
    - DATA_PARALLEL: Distributes the data across different processors.
    - HYBRID_PARALLEL: Achieves data parallelism and model parallelism manually.
    - SEMI_AUTO_PARALLEL: Achieves data parallelism and model parallelism by setting parallel strategies.
    - AUTO_PARALLEL: Achieves parallelism automatically.

    MODE_LIST: The list of all supported parallel modes.
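
    Examples:
        >>> from mindspore.context import ParallelMode
        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)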
828    """
829
830    STAND_ALONE = "stand_alone"
831    DATA_PARALLEL = "data_parallel"
832    HYBRID_PARALLEL = "hybrid_parallel"
833    SEMI_AUTO_PARALLEL = "semi_auto_parallel"
834    AUTO_PARALLEL = "auto_parallel"
835    MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]


@args_type_check(enable_ps=bool)
def set_ps_context(**kwargs):
    """
    Set parameter server training mode context.

    Note:
        Some other environment variables should also be set for parameter server training mode.
        These environment variables are listed below:

        MS_SERVER_NUM: Server number

        MS_WORKER_NUM: Worker number

        MS_SCHED_HOST: Scheduler IP address

        MS_SCHED_PORT: Scheduler port

        MS_ROLE: The role of this process:

        MS_SCHED: represents the scheduler,

        MS_WORKER: represents the worker,

        MS_PSERVER: represents the Server

    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
                          The environment variables take effect only after enable_ps is set to True.
                          Default: False.

    Raises:
        ValueError: If the input key is not an attribute in parameter server training mode context.

    Examples:
        >>> context.set_ps_context(enable_ps=True)
    """
    _set_ps_context(**kwargs)


def get_ps_context(attr_key):
    """
    Get parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute:

            - enable_ps (bool): Whether to enable parameter server training mode.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute in parameter server training mode context.

    Examples:
        >>> context.get_ps_context("enable_ps")
    """
    return _get_ps_context(attr_key)


def reset_ps_context():
    """
    Reset parameter server training mode context attributes to the default values:

    - enable_ps: False.
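
    Examples:
        >>> context.reset_ps_context()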
903    """
904    _reset_ps_context()


def set_fl_context(**kwargs):
    """
    Set federated learning training mode context.

    Args:
        enable_fl (bool): Whether to enable federated learning training mode.
                          Default: False.
        server_mode (str): Describe the server mode, which must be one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
                              Default: 'FEDERATED_LEARNING'.
        ms_role (str): The process's role in the federated learning mode,
                          which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
                          Default: 'MS_SERVER'.
        worker_num (int): The number of workers. For the current version, this must be set to 1 or 0.
        server_num (int): The number of federated learning servers. Default: 0.
        scheduler_ip (str): The scheduler IP. Default: '0.0.0.0'.
        scheduler_port (int): The scheduler port. Default: 6667.
        fl_server_port (int): The http port of the federated learning server.
                              Normally for each server this should be set to the same value. Default: 6668.
        enable_fl_client (bool): Whether this process is a federated learning client. Default: False.
        start_fl_job_threshold (int): The threshold count of startFLJob. Default: 1.
        start_fl_job_time_window (int): The time window duration for startFLJob in milliseconds. Default: 3000.
        share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
        update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
        cipher_time_window (int): The time window duration for each cipher round in milliseconds. Default: 300000.
        reconstruct_secrets_threshold (int): The threshold count for reconstructing secrets. Default: 0.
        update_model_time_window (int): The time window duration for updateModel in milliseconds. Default: 3000.
        fl_name (string): The federated learning job name. Default: ''.
        fl_iteration_num (int): Iteration number of federated learning,
                                which is the number of interactions between client and server. Default: 20.
        client_epoch_num (int): Client training epoch number. Default: 25.
        client_batch_size (int): Client training data batch size. Default: 32.
        client_learning_rate (float): Client training learning rate. Default: 0.001.
        worker_step_num_per_iteration (int): The worker's standalone training step number before communicating with
                                             the server. Default: 65.
        dp_eps (float): Epsilon budget of the differential privacy mechanism. The smaller the dp_eps, the better the
            privacy protection effect. Default: 50.0.
        dp_delta (float): Delta budget of the differential privacy mechanism, which usually equals the reciprocal of
            the client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
        dp_norm_clip (float): A factor used for clipping the model's weights for the differential privacy mechanism.
            Its suggested value range is 0.5~2. Default: 1.0.
        encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
            'PW_ENCRYPT'. If 'DP_ENCRYPT', a differential privacy schema would be applied for clients and the privacy
            protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
            'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' models from being stolen.
            Default: 'NOT_ENCRYPT'.
        config_file_path (string): Configuration file path used by recovery. Default: ''.
        scheduler_manage_port (int): Scheduler manage port used to scale out/in. Default: 11202.
        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: True.
        client_password (str): Password to decrypt the secret key stored in the client certificate.
        server_password (str): Password to decrypt the secret key stored in the server certificate.

    Raises:
        ValueError: If the input key is not an attribute in federated learning mode context.

    Examples:
        >>> context.set_fl_context(enable_fl=True, server_mode='FEDERATED_LEARNING')
    """
    _set_ps_context(**kwargs)


def get_fl_context(attr_key):
    """
    Get federated learning mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.
                        Please refer to `set_fl_context`'s parameters to decide what key should be passed.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute in federated learning mode context.

    Examples:
        >>> context.get_fl_context("server_mode")
    """
    return _get_ps_context(attr_key)