# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
The context of mindspore, used to configure the current execution environment,
including the execution mode, execution backend and other feature switches.
"""
import json
import os
import time
import threading
from collections import namedtuple
from types import FunctionType

from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check, Validator, args_unreset_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
    _reset_auto_parallel_context
from mindspore.parallel._ps_context import _set_ps_context, _get_ps_context, _reset_ps_context
from .default_config import __device_target__, __package_name__

__all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_auto_parallel_context',
           'get_auto_parallel_context', 'reset_auto_parallel_context', 'ParallelMode', 'set_ps_context',
           'get_ps_context', 'reset_ps_context', 'set_fl_context', 'get_fl_context']

GRAPH_MODE = 0
PYNATIVE_MODE = 1
_DEVICE_APP_MEMORY_SIZE = 31  # The max memory size of graph plus variable.
_re_pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
_k_context = None


def _make_directory(path):
    """Make directory."""
    real_path = None
    if path is None or not isinstance(path, str) or path.strip() == "":
        raise ValueError(f"Input path `{path}` must be a non-empty string.")

    # convert the relative path to an absolute path
    path = os.path.realpath(path)
    logger.debug("The absolute path is %r", path)

    # check whether the path already exists and has write permission
    if os.path.exists(path):
        real_path = path
    else:
        # creating the directory may fail, e.g. because of missing permissions
        logger.debug("The directory(%s) doesn't exist, will create it", path)
        try:
            os.makedirs(path)
            real_path = path
        except PermissionError as e:
            logger.error(f"No write permission on the directory `{path}`, error = {e}")
            raise ValueError(f"No write permission on the directory `{path}`.")
    return real_path
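

# Illustrative note (timestamp value is hypothetical): the helper below always appends the
# current epoch time in seconds to the given name, e.g. "print.pb" -> "print.pb.1638326400",
# and raises an error if the suffixed file somehow already exists.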
def _get_print_file_name(file_name):
    """Add timestamp suffix to file name. Rename the file name: file_name + "." + time(seconds)."""
    time_second = str(int(time.time()))
    file_name = file_name + "." + time_second
    if os.path.exists(file_name):
        raise ValueError("This file {} already exists.".format(file_name))
    return file_name


class _ThreadLocalInfo(threading.local):
    """
    Thread-local information used to store thread-local attributes.
    """

    def __init__(self):
        super(_ThreadLocalInfo, self).__init__()
        self._reserve_class_name_in_scope = True
        self.debug_runtime = False

    @property
    def reserve_class_name_in_scope(self):
        """Get whether to save the network class name in the scope."""
        return self._reserve_class_name_in_scope

    @reserve_class_name_in_scope.setter
    def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
        """Set whether to save the network class name in the scope."""
        if not isinstance(reserve_class_name_in_scope, bool):
            raise ValueError(
                "The value of reserve_class_name_in_scope must be a bool!")
        self._reserve_class_name_in_scope = reserve_class_name_in_scope


_ContextRecord = namedtuple(
    "_ContextRecord", ["is_pynative_mode", "switch_context_fn"])


class _ContextSwitchInfo(threading.local):
    """
    Record of context switch information.

    Args:
        is_pynative (bool): Whether to adopt the PyNative mode.
    """

    def __init__(self, is_pynative):
        super(_ContextSwitchInfo, self).__init__()
        self.context_stack = []
        if is_pynative:
            self.push(True, None)

    def push(self, is_pynative, switch_context_fn):
        """
        Push a context switch record onto the stack.

        Args:
            is_pynative (bool): Whether the context switches to PyNative mode.
            switch_context_fn (Function): A callable that executes the context switch.
        """
        if isinstance(switch_context_fn, FunctionType):
            switch_context_fn()
        self.context_stack.append(
            _ContextRecord(is_pynative, switch_context_fn))

    def pop(self):
        self.context_stack.pop()
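

# Note: _ContextSwitchInfo keeps a per-thread stack of mode-switch records. For example,
# _Context.set_mode (defined below) pushes (True, None) when switching to PyNative mode and
# (False, None) when switching to Graph mode, and pop() discards the most recent record.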
class _Context:
    """
    _Context is the environment in which operations are executed.

    Note:
        Creating a context by instantiating the Context object directly is not recommended.
        Use _context() to get the context instead, since Context is a singleton.
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._thread_local_info = _ThreadLocalInfo()
        self._context_switches = _ContextSwitchInfo(False)
        self._context_handle = MSContext.get_instance()

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance_lock.acquire()
            cls._instance = object.__new__(cls)
            cls._instance_lock.release()
        return cls._instance

    def __getattribute__(self, attr):
        value = object.__getattribute__(self, attr)
        if attr == "_context_handle" and value is None:
            raise ValueError("Context handle is none in context!!!")
        return value

    def get_param(self, param):
        return self._context_handle.get_param(param)

    def set_param(self, param, value):
        self._context_handle.set_param(param, value)

    def set_mode(self, mode):
        """
        Switch between Graph mode and PyNative mode.

        Args:
            mode (int): GRAPH_MODE or PYNATIVE_MODE.
        """
        if mode == PYNATIVE_MODE:
            if self.enable_debug_runtime:
                self.set_backend_policy("vm")
            self._context_switches.push(True, None)
        elif mode == GRAPH_MODE:
            if self.enable_debug_runtime:
                self.set_backend_policy("ge")
            self._context_switches.push(False, None)
        else:
            raise ValueError(f'The execution mode {mode} is invalid!')
        self.set_param(ms_ctx_param.mode, mode)

    def set_backend_policy(self, policy):
        success = self._context_handle.set_backend_policy(policy)
        if not success:
            raise RuntimeError("Backend policy must be one of ge, vm, ms.")

    def set_save_graphs_path(self, save_graphs_path):
        self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))

    def set_device_target(self, target):
        valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
        if target not in valid_targets:
            raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
        if target == "Davinci":
            target = "Ascend"
        self.set_param(ms_ctx_param.device_target, target)
        if self.enable_debug_runtime and target == "CPU":
            self.set_backend_policy("vm")

    def set_auto_tune_mode(self, tune_mode):
        candidate = ["NO_TUNE", "RL", "GA", "RL,GA", "GA,RL"]
        if tune_mode in candidate:
            self.set_param(ms_ctx_param.tune_mode, tune_mode)
        else:
            raise ValueError(f"Tune mode must be in ['NO_TUNE', 'RL', 'GA', 'RL,GA', 'GA,RL'], but got {tune_mode}")

    def set_device_id(self, device_id):
        if device_id < 0 or device_id > 4095:
            raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
        self.set_param(ms_ctx_param.device_id, device_id)

    def set_max_call_depth(self, max_call_depth):
        if max_call_depth <= 0:
            raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
        self.set_param(ms_ctx_param.max_call_depth, max_call_depth)

    def set_profiling_options(self, option):
        if not isinstance(option, str):
            raise TypeError("The parameter option must be str.")
        self.set_param(ms_ctx_param.profiling_options, option)
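
    # Illustrative note: the device app memory (31GB, see _DEVICE_APP_MEMORY_SIZE) is split between
    # variable memory and graph memory. For example, variable_memory_max_size="5GB" (a hypothetical
    # value) yields graph_memory_max_size = 31 - 5 = 26GB; both are passed down as byte expressions.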
    def set_variable_memory_max_size(self, variable_memory_max_size):
        """set values of variable_memory_max_size and graph_memory_max_size"""
        if not Validator.check_str_by_regular(variable_memory_max_size, _re_pattern):
            raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
        if int(variable_memory_max_size[:-2]) > _DEVICE_APP_MEMORY_SIZE:
            raise ValueError("Context param variable_memory_max_size should not be greater than 31GB.")
        variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
        graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
        graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
        self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
        self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)

    def set_max_device_memory(self, max_device_memory):
        if not Validator.check_str_by_regular(max_device_memory, _re_pattern):
            raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
        max_device_memory_value = float(max_device_memory[:-2])
        if max_device_memory_value == 0:
            raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
        self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)

    def set_print_file_path(self, file_path):
        """Add timestamp suffix to file name. Sets print file path."""
        print_file_path = os.path.realpath(file_path)
        if os.path.isdir(print_file_path):
            raise IOError("Print_file_path should be a file path, but got {}.".format(file_path))

        if os.path.exists(print_file_path):
            _path, _file_name = os.path.split(print_file_path)
            path = _make_directory(_path)
            file_name = _get_print_file_name(_file_name)
            full_file_name = os.path.join(path, file_name)
        else:
            full_file_name = print_file_path
        self.set_param(ms_ctx_param.print_file_path, full_file_name)

    def set_env_config_path(self, env_config_path):
        """Check and set env_config_path."""
        if not self._context_handle.enable_dump_ir():
            raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
                             "with '-D on' and recompile source.")
        env_config_path = os.path.realpath(env_config_path)
        if not os.path.isfile(env_config_path):
            raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)
        try:
            with open(env_config_path, 'r') as f:
                json.load(f)
        except (TypeError, ValueError) as exo:
            raise ValueError("The %r set by 'env_config_path' should be a json file. "
                             "Detail: %s." % (env_config_path, str(exo)))
        self.set_param(ms_ctx_param.env_config_path, env_config_path)

    setters = {
        'mode': set_mode,
        'save_graphs_path': set_save_graphs_path,
        'device_target': set_device_target,
        'device_id': set_device_id,
        'auto_tune_mode': set_auto_tune_mode,
        'max_call_depth': set_max_call_depth,
        'profiling_options': set_profiling_options,
        'variable_memory_max_size': set_variable_memory_max_size,
        'max_device_memory': set_max_device_memory,
        'print_file_path': set_print_file_path,
        'env_config_path': set_env_config_path
    }

    @property
    def reserve_class_name_in_scope(self):
        """Get whether to save the network class name in the scope."""
        return self._thread_local_info.reserve_class_name_in_scope

    @reserve_class_name_in_scope.setter
    def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
        """Set whether to save the network class name in the scope."""
        self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope

    @property
    def enable_ge(self):
        return self._context_handle.get_backend_policy() == 'ge'

    @property
    def enable_debug_runtime(self):
        return self._thread_local_info.debug_runtime

    @enable_debug_runtime.setter
    def enable_debug_runtime(self, enable):
        thread_info = self._thread_local_info
        thread_info.debug_runtime = enable


def _context():
    """
    Get the global _context; if the context has not been created yet, create a new one.

    Returns:
        _Context, the global context.
    """
    global _k_context
    if _k_context is None:
        default_backend = 'debug'
        try:
            from mindspore import default_config
            default_backend = default_config.__backend__
        except ImportError:
            logger.error("import default config fail")
        _k_context = _Context()
        _k_context.enable_debug_runtime = False
        if default_backend == 'debug':
            _k_context.enable_debug_runtime = True
            default_backend = 'vm'
        _k_context.set_backend_policy(default_backend)
    return _k_context
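

# Illustrative sketch (not executed here): because _Context overrides __new__ and _context()
# caches the instance in _k_context, repeated calls return the same object, e.g.
#   ctx1 = _context()
#   ctx2 = _context()
#   assert ctx1 is ctx2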


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str,
                 auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 all_reduce_fusion_config=list, pipeline_stages=int, grad_accumulation_step=int)
def set_auto_parallel_context(**kwargs):
    r"""
    Set auto parallel context, which is valid only for Ascend and GPU targets.

    Auto parallel context should be configured before the initialization of your network.

    Note:
        Attribute name is required for setting attributes.
        If a program has tasks on different parallel modes, before setting a new parallel mode for the
        next task, interface mindspore.context.reset_auto_parallel_context() should be called to reset
        the configuration.
        Setting or changing parallel modes must be done before creating any Initializer, otherwise
        a RuntimeError may be raised when compiling the network.

    Some configurations are parallel mode specific, see the below table for details:

    ===========================  ===========================
    Common                       AUTO_PARALLEL
    ===========================  ===========================
    device_num                   gradient_fp32_sync
    global_rank                  loss_repeated_mean
    gradients_mean               auto_parallel_search_mode
    parallel_mode                strategy_ckpt_load_file
    all_reduce_fusion_config     strategy_ckpt_save_file
    enable_parallel_optimizer    dataset_strategy
    \                            pipeline_stages
    \                            grad_accumulation_step
    ===========================  ===========================

    Args:
        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
        gradients_mean (bool): Whether to perform the mean operator after allreduce of gradients.
            "stand_alone" does not support gradients_mean. Default: False.
        gradient_fp32_sync (bool): Run allreduce of gradients in fp32. "stand_alone", "data_parallel"
            and "hybrid_parallel" do not support gradient_fp32_sync. Default: True.
        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
            "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".

            - stand_alone: Only one processor is working.

            - data_parallel: Distributes the data across different processors.

            - hybrid_parallel: Achieves data parallelism and model parallelism manually.

            - semi_auto_parallel: Achieves data and model parallelism by setting parallel strategies.

            - auto_parallel: Achieves parallelism automatically.
        auto_parallel_search_mode (str): There are two kinds of shard strategy search modes, "recursive_programming"
            and "dynamic_programming". Default: "dynamic_programming".

            - recursive_programming: Recursive programming search mode.

            - dynamic_programming: Dynamic programming search mode.
        parameter_broadcast (bool): Whether to broadcast parameters before training. Before training, in order to
            have the same network initialization parameter values for all devices, the parameters on device 0 are
            broadcast to the other devices. Parameter broadcasting differs between parallel modes: in data_parallel
            mode, all parameters are broadcast except for the parameters whose attribute layerwise_parallel is True;
            in hybrid_parallel, semi_auto_parallel and auto_parallel mode, the segmented parameters do not
            participate in broadcasting. Default: False.
        strategy_ckpt_load_file (str): The path to load the parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save the parallel strategy checkpoint. Default: ''
        full_batch (bool): If you load whole batch datasets in auto_parallel mode, this parameter
            should be set as True. Default: False. This parameter is currently not recommended;
            it is better to use 'dataset_strategy' instead.
        dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
            dataset_strategy="data_parallel" is equal to full_batch=False, dataset_strategy="full_batch" is
            equal to full_batch=True. For a dataset loaded into the net with a model parallel strategy such as
            ds_stra = ((1, 8), (1, 8)), use set_auto_parallel_context(dataset_strategy=ds_stra).
        enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation
            for data parallel training to save time and memory. Currently, auto and semi auto parallel mode support
            all optimizers in both Ascend and GPU. Data parallel mode only supports `Lamb` and `AdamWeightDecay`
            on Ascend. Default: False.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameter indices. Only support ReduceOp.SUM
            and HCCL_WORLD_GROUP/NCCL_WORLD_GROUP. There is no default value; if it is not set, fusion is disabled.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how the devices are
            distributed along the pipeline. The total devices will be divided into 'pipeline_stages' stages.
            Currently this could only be used when parallel mode semi_auto_parallel is enabled. Default: 1.
        grad_accumulation_step (int): Set the accumulation steps of gradients in auto and semi auto parallel mode.
            This should be a positive int. Default: 1.

    Raises:
        ValueError: If the input key is not an attribute in auto parallel context.

    Examples:
        >>> context.set_auto_parallel_context(device_num=8)
        >>> context.set_auto_parallel_context(global_rank=0)
        >>> context.set_auto_parallel_context(gradients_mean=True)
        >>> context.set_auto_parallel_context(gradient_fp32_sync=False)
        >>> context.set_auto_parallel_context(parallel_mode="auto_parallel")
        >>> context.set_auto_parallel_context(auto_parallel_search_mode="dynamic_programming")
        >>> context.set_auto_parallel_context(parameter_broadcast=False)
        >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
        >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
        >>> context.set_auto_parallel_context(dataset_strategy=((1, 8), (1, 8)))
        >>> context.set_auto_parallel_context(enable_parallel_optimizer=False)
        >>> context.set_auto_parallel_context(all_reduce_fusion_config=[8, 160])
        >>> context.set_auto_parallel_context(pipeline_stages=2)
    """
    _set_auto_parallel_context(**kwargs)


def get_auto_parallel_context(attr_key):
    """
    Get auto parallel context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute in auto parallel context.
    """
    return _get_auto_parallel_context(attr_key)


def reset_auto_parallel_context():
    """
    Reset auto parallel context attributes to the default values:

    - device_num: 1.
    - global_rank: 0.
    - gradients_mean: False.
    - gradient_fp32_sync: True.
    - parallel_mode: 'stand_alone'.
    - auto_parallel_search_mode: 'dynamic_programming'.
    - parameter_broadcast: False.
    - strategy_ckpt_load_file: ''.
    - strategy_ckpt_save_file: ''.
    - full_batch: False.
    - enable_parallel_optimizer: False.
    - pipeline_stages: 1.
    """
    _reset_auto_parallel_context()
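

# Illustrative usage sketch (mirrors the Note in set_auto_parallel_context; argument values are
# examples only): when switching between parallel modes for different tasks, reset the auto
# parallel context first, e.g.
#   context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=8)
#   ... run the first task ...
#   context.reset_auto_parallel_context()
#   context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)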


def _check_target_specific_cfgs(device, arg_key):
    """Check whether a config is suitable for the specified device."""
    device_cfgs = {
        'enable_dump': ['Ascend'],
        'save_dump_path': ['Ascend'],
        'enable_graph_kernel': ['Ascend', 'GPU'],
        'graph_kernel_flags': ['Ascend', 'GPU'],
        'enable_reduce_precision': ['Ascend'],
        'enable_profiling': ['Ascend'],
        'profiling_options': ['Ascend'],
        'print_file_path': ['Ascend'],
        'variable_memory_max_size': ['Ascend'],
        'auto_tune_mode': ['Ascend'],
        'max_device_memory': ['GPU']
    }
    # configs not in the device_cfgs map are supposed to be suitable for all devices
    if arg_key not in device_cfgs:
        return True
    supported_devices = device_cfgs[arg_key]
    if device in supported_devices:
        return True
    logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
                   ", ignore it.")
    return False
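
# For example (hypothetical call, not executed): _check_target_specific_cfgs("GPU", "print_file_path")
# logs a warning and returns False because 'print_file_path' is Ascend-only, while keys absent from
# device_cfgs (such as 'mode') are accepted for every device and return True.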


@args_unreset_check(device_id=int, variable_memory_max_size=str, max_device_memory=str)
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
                 save_graphs_path=str, enable_dump=bool, auto_tune_mode=str,
                 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
                 enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
                 enable_graph_kernel=bool, reserve_class_name_in_scope=bool, check_bprop=bool,
                 max_device_memory=str, print_file_path=str, enable_sparse=bool, max_call_depth=int,
                 env_config_path=str, graph_kernel_flags=str, save_compile_cache=bool,
                 load_compile_cache=bool, grad_for_scalar=bool, pynative_synchronize=bool)
def set_context(**kwargs):
    """
    Set context for the running environment.

    Context should be configured before running your program. If there is no configuration,
    it will be set automatically according to the device target by default.

    Note:
        Attribute name is required for setting attributes.
        It is not recommended to change the mode after the net was initialized, because the implementations of
        some operations are different in graph mode and pynative mode. Default: GRAPH_MODE.

    Some configurations are device specific, see the below table for details:

    +-------------------------+------------------------------+----------------------------+
    | Function Classification | Configuration Parameters     | Hardware Platform Support  |
    +=========================+==============================+============================+
    | System Configuration    | device_id                    | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | device_target                | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | max_device_memory            | GPU                        |
    |                         +------------------------------+----------------------------+
    |                         | variable_memory_max_size     | Ascend                     |
    +-------------------------+------------------------------+----------------------------+
    | Debug Configuration     | save_graphs                  | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | save_graphs_path             | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | enable_dump                  | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | save_dump_path               | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | enable_profiling             | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | profiling_options            | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | print_file_path              | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | env_config_path              | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | precompile_only              | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | reserve_class_name_in_scope  | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | pynative_synchronize         | GPU/Ascend                 |
    +-------------------------+------------------------------+----------------------------+
    | Executive Control       | mode                         | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | enable_graph_kernel          | Ascend/GPU                 |
    |                         +------------------------------+----------------------------+
    |                         | graph_kernel_flags           | Ascend/GPU                 |
    |                         +------------------------------+----------------------------+
    |                         | enable_reduce_precision      | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | auto_tune_mode               | Ascend                     |
    |                         +------------------------------+----------------------------+
    |                         | check_bprop                  | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | max_call_depth               | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | enable_sparse                | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | grad_for_scalar              | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | save_compile_cache           | CPU/GPU/Ascend             |
    |                         +------------------------------+----------------------------+
    |                         | load_compile_cache           | CPU/GPU/Ascend             |
    +-------------------------+------------------------------+----------------------------+

    Args:
        device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
            while device_num_per_host should be no more than 4096. Default: 0.
        device_target (str): The target device to run, support "Ascend", "GPU", and "CPU".
            If device target is not set, the device corresponding to the installed MindSpore package is used.
        max_device_memory (str): Set the maximum memory available for devices.
            Currently, it is only supported on GPU. The format is "xxGB". Default: "1024GB".
            The actual used memory size is the minimum of the available memory of the device and max_device_memory.
        variable_memory_max_size (str): Set the maximum size of the variable memory. Default: "30GB".
            After this parameter is set, the maximum memory used by the framework is restricted to the
            configured value.
        save_graphs (bool): Whether to save graphs. Default: False.
            When the `save_graphs` attribute is set as True, the `save_graphs_path` attribute is used to set the
            intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
        save_graphs_path (str): Path to save graphs. Default: ".".
            If the specified directory does not exist, the system will automatically create the directory.
            During distributed training, graphs will be saved to the directory of
            `save_graphs_path/rank_${rank_id}/`. `rank_id` is the ID of the current device in the cluster.
        enable_dump (bool): This parameter is deprecated, and will be deleted in the next version.
        save_dump_path (str): This parameter is deprecated, and will be deleted in the next version.
        enable_profiling (bool): This parameter is deprecated, and will be deleted in the next version.
            Please use the mindspore.profiler.Profiler api instead.
        profiling_options (str): This parameter is deprecated, and will be deleted in the next version.
            Please use the mindspore.profiler.Profiler api instead.
        print_file_path (str): The path of saving print data. If this parameter is set, print data is saved to
            a file by default; if print_file_path is not set, the data is displayed on the screen.
            If the saved file already exists, a timestamp suffix will be added to the file. Saving data to a file
            solves the problem of data loss in screen printing when a large amount of data is generated.
            If the path is not set correctly, an error will be reported, prompting to set an absolute path.
        env_config_path (str): Config path for DFX.
            Through context.set_context(env_config_path="./mindspore_config.json")

            configure RDR:

            - enable: controls whether the RDR is enabled to collect the key data during training and
              save key data in the fault scenario. When set to true, the RDR will be turned on.
              When set to false, the RDR will be turned off.
            - path: sets the path where RDR saves data. The current path must be absolute.

            Memory reuse:

            - mem_Reuse: controls whether the memory reuse function is turned on. When set to True,
              the memory reuse function is turned on. When set to False, the memory reuse function is turned off.

        precompile_only (bool): Whether to only precompile the network. Default: False.
            If set to True, the network will only be compiled, not executed.
        reserve_class_name_in_scope (bool): Whether to save the network class name in the scope. Default: True.
            Each node has a scope. A scope of a subnode is the name of its parent node. If reserve_class_name_in_scope
            is set to True, the class name will be saved after keyword 'net-' in the scope.
            For example:

            Default/net-Net1/net-Net2 (reserve_class_name_in_scope=True)

            Default/net/net (reserve_class_name_in_scope=False)

        pynative_synchronize (bool): Whether to enable synchronous execution of the device in PyNative mode.
            Default: False. When the value is set to False, the operator is executed asynchronously on the device.
            When an error occurs in the execution of the operator, the specific error script code location cannot
            be located. When the value is set to True, the operator is executed synchronously on the device, which
            will reduce the execution performance of the program. In this case, when an error occurs in the execution
            of the operator, the location of the error script code can be located according to the call stack of
            the error.
        mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: GRAPH_MODE(0).
            GRAPH_MODE or PYNATIVE_MODE can be set by the `mode` attribute and both modes support all backends;
            the default mode is GRAPH_MODE.
        enable_graph_kernel (bool): Whether to enable graph kernel fusion to optimize network execution performance.
            Default: False.
            If enable_graph_kernel is set to True, acceleration can be enabled.
            For details of graph kernel fusion, please check
            `Enabling Graph Kernel Fusion <https://www.mindspore.cn/docs/programming_guide
            /en/master/enable_graph_kernel_fusion.html>`_.
        graph_kernel_flags (str):
            Optimization options of graph kernel fusion, and the priority is higher when it conflicts
            with enable_graph_kernel. Only for experienced users.
            For example, context.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text"). Some general options:

            - opt_level: Set the optimization level.
              Default: 2. Graph kernel fusion can be enabled equivalently by setting opt_level greater than 0.
              Available values are:

              - 0: disables graph kernel fusion;
              - 1: enables the basic fusion of operators;
              - 2: includes all optimizations of level 1,
                and turns on more optimizations such as CSE, arithmetic simplification and so on;
              - 3: includes all optimizations of level 2, and turns on more optimizations such as StitchingFusion,
                ParallelFusion and so on. Optimizations of this level are radical and unstable in some scenarios.
                Be cautious when using this level.

            - dump_as_text: dumps detail info as text files. Default: false.

            More options can refer to the implementation code. These options can also be set by the environment
            variable MS_GRAPH_KERNEL_FLAGS, without modifying the network source code.
            For example, export MS_GRAPH_KERNEL_FLAGS="--opt_level=2 --dump_as_text".
        enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
            If set to True: a user-specified precision that is not supported will be changed automatically.
            If set to False: if the specified precision of the use case is not supported, an error will be reported
            and the program will exit.
            For example, on the Ascend backend, conv2d only supports fp16 input. Under the fp32 input condition,
            setting it to True will automatically insert a cast operator to convert to fp16, while setting it to
            False will report an error and exit.
        auto_tune_mode (str): The mode of auto tune when op building, to get the best tiling performance.
            Default: NO_TUNE. The value must be in ['RL', 'GA', 'RL,GA'].

            - RL: Reinforcement Learning tune.
            - GA: Genetic Algorithm tune.
            - RL,GA: When both RL and GA optimization are enabled, the tool automatically selects RL or GA based on
              different types of operators in the network model. The sequence of RL and GA is not differentiated.
              (Automatic selection).

            For more information about the enable operator tuning tool settings, please check
            `Enable the operator optimization tool <https://www.mindspore.cn/docs/programming_guide/en
            /master/enable_auto_tune.html>`_.
        check_bprop (bool): Whether to check back propagation nodes. The checking ensures that the shape and dtype
            of back propagation node outputs are the same as the input parameters. Default: False.
        max_call_depth (int): Specify the maximum depth of function call. Must be a positive integer. Default: 1000.
            The max_call_depth parameter needs to be set when the nested call is too deep or the number
            of subgraphs is too large. If max_call_depth is set larger than before, the system max stack depth should
            be set larger too, otherwise a `core dumped` exception may be raised because of system stack overflow.
        enable_sparse (bool): Whether to enable the sparsity feature. Default: False.
            For details of sparsity and sparse tensors, please check
            `sparse tensor <https://www.mindspore.cn/docs/programming_guide/en/r1.5/tensor.html#sparse-tensor>`_.
        grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
            When grad_for_scalar is set to True, the function's scalar input can be derived.
            Because the back-end does not support scalar operations currently, this interface only supports
            simple operations that can be deduced by the front-end.
        save_compile_cache (bool): Whether to cache the graph compiled by the front-end. Default: False.
            After save_compile_cache is set to True, a hardware-independent compilation cache is
            generated and exported to a MINDIR file. This is an experimental prototype that is
            subject to change and/or deletion.
        load_compile_cache (bool): Whether to use the cache of the graph compiled by the front-end.
            This parameter must be used together with save_compile_cache. After save_compile_cache is set to True,
            a hardware-independent compilation cache is generated and exported to a MINDIR file.
            When the network is executed again, if load_compile_cache is set to True, the compile cache is loaded.
            By now, we do not support automatic checking for changes.
            Default: False.
            This is an experimental prototype that is subject to change and/or deletion.

    Raises:
        ValueError: If the input key is not an attribute in context.

    Examples:
        >>> context.set_context(mode=context.PYNATIVE_MODE)
        >>> context.set_context(precompile_only=True)
        >>> context.set_context(device_target="Ascend")
        >>> context.set_context(device_id=0)
        >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
        >>> context.set_context(enable_reduce_precision=True)
        >>> context.set_context(enable_dump=True, save_dump_path=".")
        >>> context.set_context(enable_graph_kernel=True)
        >>> context.set_context(graph_kernel_flags="--opt_level=2 --dump_as_text")
        >>> context.set_context(reserve_class_name_in_scope=True)
        >>> context.set_context(variable_memory_max_size="6GB")
        >>> context.set_context(enable_profiling=True,
        ...                     profiling_options='{"output":"/home/data/output","training_trace":"on"}')
        >>> context.set_context(check_bprop=True)
        >>> context.set_context(max_device_memory="3.5GB")
        >>> context.set_context(print_file_path="print.pb")
        >>> context.set_context(enable_sparse=True)
        >>> context.set_context(max_call_depth=80)
        >>> context.set_context(env_config_path="./env_config.json")
        >>> context.set_context(auto_tune_mode="GA,RL")
        >>> context.set_context(grad_for_scalar=True)
        >>> context.set_context(save_compile_cache=True)
        >>> context.set_context(load_compile_cache=True)
        >>> context.set_context(pynative_synchronize=True)
    """
    ctx = _context()
    # set device target first
    if 'device_target' in kwargs:
        ctx.set_device_target(kwargs['device_target'])
        device = ctx.get_param(ms_ctx_param.device_target)
        if not device.lower() in __device_target__:
            raise ValueError(f"Error, package type {__package_name__} support device type {__device_target__}, "
                             f"but got device target {device}")
    device = ctx.get_param(ms_ctx_param.device_target)
    for key, value in kwargs.items():
        if key in ('enable_profiling', 'profiling_options', 'enable_auto_mixed_precision',
                   'enable_dump', 'save_dump_path'):
            logger.warning(f"The '{key}' parameter will be deprecated. "
                           "For details, please see the interface parameter API comments.")
            continue
        if not _check_target_specific_cfgs(device, key):
            continue
        if hasattr(ctx, key):
            setattr(ctx, key, value)
            continue
        if key in ctx.setters:
            ctx.setters[key](ctx, value)
            continue
        # enum variables beginning with '_' are for internal use
        if key in ms_ctx_param.__members__ and key[0] != '_':
            ctx.set_param(ms_ctx_param.__members__[key], value)
            continue
        raise ValueError("Set context keyword %s is not recognized!" % key)


def get_context(attr_key):
    """
    Get context attribute value according to the input key.
    If some attributes are not set, they will be automatically obtained.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Object, the value of the given attribute key.

    Raises:
        ValueError: If the input key is not an attribute in context.

    Examples:
        >>> context.get_context("device_target")
        >>> context.get_context("device_id")
    """
    ctx = _context()
    device = ctx.get_param(ms_ctx_param.device_target)
    _ = _check_target_specific_cfgs(device, attr_key)
    if hasattr(ctx, attr_key):
        return getattr(ctx, attr_key)
    # enum variables beginning with '_' are for internal use
    if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
        return ctx.get_param(ms_ctx_param.__members__[attr_key])
    raise ValueError("Get context keyword %s is not recognized!" % attr_key)
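

# Illustrative sketch (return values depend on the runtime environment; not executed here):
#   context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
#   context.get_context("mode")           # -> 0, i.e. GRAPH_MODE
#   context.get_context("device_target")  # -> "CPU"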


class ParallelMode:
    """
    Parallel mode options.

    There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
    "HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".

    - STAND_ALONE: Only one processor is working.
    - DATA_PARALLEL: Distributes the data across different processors.
    - HYBRID_PARALLEL: Achieves data parallelism and model parallelism manually.
    - SEMI_AUTO_PARALLEL: Achieves data parallelism and model parallelism by setting parallel strategies.
    - AUTO_PARALLEL: Achieves parallelism automatically.

    MODE_LIST: The list of all supported parallel modes.
    """

    STAND_ALONE = "stand_alone"
    DATA_PARALLEL = "data_parallel"
    HYBRID_PARALLEL = "hybrid_parallel"
    SEMI_AUTO_PARALLEL = "semi_auto_parallel"
    AUTO_PARALLEL = "auto_parallel"
    MODE_LIST = [STAND_ALONE, DATA_PARALLEL, HYBRID_PARALLEL, SEMI_AUTO_PARALLEL, AUTO_PARALLEL]


@args_type_check(enable_ps=bool)
def set_ps_context(**kwargs):
    """
    Set parameter server training mode context.

    Note:
        Some other environment variables should also be set for parameter server training mode.
        These environment variables are listed below:

        MS_SERVER_NUM: Server number

        MS_WORKER_NUM: Worker number

        MS_SCHED_HOST: Scheduler IP address

        MS_SCHED_PORT: Scheduler port

        MS_ROLE: The role of this process:

        MS_SCHED: represents the scheduler,

        MS_WORKER: represents the worker,

        MS_PSERVER: represents the Server

    Args:
        enable_ps (bool): Whether to enable parameter server training mode.
            Only after enable_ps is set True, the environment variables will be effective.
            Default: False.

    Raises:
        ValueError: If input key is not the attribute in parameter server training mode context.

    Examples:
        >>> context.set_ps_context(enable_ps=True)
    """
    _set_ps_context(**kwargs)
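

# Illustrative usage sketch (host address, port and counts below are hypothetical): parameter
# server mode expects the environment variables from the Note above to be exported before
# launching each process, for example:
#   export MS_SERVER_NUM=1 MS_WORKER_NUM=8 MS_SCHED_HOST=127.0.0.1 MS_SCHED_PORT=8118 MS_ROLE=MS_WORKER
# and then, inside the training script:
#   context.set_ps_context(enable_ps=True)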


def get_ps_context(attr_key):
    """
    Get parameter server training mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute:

            - enable_ps (bool): Whether to enable parameter server training mode.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute in parameter server training mode context.

    Examples:
        >>> context.get_ps_context("enable_ps")
    """
    return _get_ps_context(attr_key)


def reset_ps_context():
    """
    Reset parameter server training mode context attributes to the default values:

    - enable_ps: False.
    """
    _reset_ps_context()


def set_fl_context(**kwargs):
    """
    Set federated learning training mode context.

    Args:
        enable_fl (bool): Whether to enable federated learning training mode.
            Default: False.
        server_mode (str): Describe the server mode, which must be one of 'FEDERATED_LEARNING' and 'HYBRID_TRAINING'.
            Default: 'FEDERATED_LEARNING'.
        ms_role (str): The process's role in the federated learning mode,
            which must be one of 'MS_SERVER', 'MS_WORKER' and 'MS_SCHED'.
            Default: 'MS_SERVER'.
        worker_num (int): The number of workers. For the current version, this must be set to 1 or 0.
        server_num (int): The number of federated learning servers. Default: 0.
        scheduler_ip (str): The scheduler IP. Default: '0.0.0.0'.
        scheduler_port (int): The scheduler port. Default: 6667.
        fl_server_port (int): The http port of the federated learning server.
            Normally for each server this should be set to the same value. Default: 6668.
        enable_fl_client (bool): Whether this process is a federated learning client. Default: False.
        start_fl_job_threshold (int): The threshold count of startFLJob. Default: 1.
        start_fl_job_time_window (int): The time window duration for startFLJob in millisecond. Default: 3000.
        share_secrets_ratio (float): The ratio for computing the threshold count of share secrets. Default: 1.0.
        update_model_ratio (float): The ratio for computing the threshold count of updateModel. Default: 1.0.
        cipher_time_window (int): The time window duration for each cipher round in millisecond. Default: 300000.
        reconstruct_secrets_threshold (int): The threshold count of reconstruct secrets. Default: 0.
        update_model_time_window (int): The time window duration for updateModel in millisecond. Default: 3000.
        fl_name (string): The federated learning job name. Default: ''.
        fl_iteration_num (int): Iteration number of federated learning,
            which is the number of interactions between client and server. Default: 20.
        client_epoch_num (int): Client training epoch number. Default: 25.
        client_batch_size (int): Client training data batch size. Default: 32.
        client_learning_rate (float): Client training learning rate. Default: 0.001.
        worker_step_num_per_iteration (int): The worker's standalone training step number before communicating with
            the server. Default: 65.
        dp_eps (float): Epsilon budget of the differential privacy mechanism. The smaller the dp_eps, the better the
            privacy protection effect. Default: 50.0.
        dp_delta (float): Delta budget of the differential privacy mechanism, which usually equals the reciprocal of
            the client number. The smaller the dp_delta, the better the privacy protection effect. Default: 0.01.
        dp_norm_clip (float): A factor used for clipping the model's weights for the differential privacy mechanism.
            Its value is suggested to be 0.5~2. Default: 1.0.
        encrypt_type (string): Secure schema for federated learning, which can be 'NOT_ENCRYPT', 'DP_ENCRYPT' or
            'PW_ENCRYPT'. If 'DP_ENCRYPT', a differential privacy schema would be applied for clients and the privacy
            protection effect would be determined by dp_eps, dp_delta and dp_norm_clip as described above. If
            'PW_ENCRYPT', pairwise secure aggregation would be applied to protect clients' models from being stolen.
            Default: 'NOT_ENCRYPT'.
        config_file_path (string): Configuration file path used by recovery. Default: ''.
        scheduler_manage_port (int): Scheduler manage port used to scale out/in. Default: 11202.
        enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
        client_password (str): Password to decrypt the secret key stored in the client certificate.
        server_password (str): Password to decrypt the secret key stored in the server certificate.

    Raises:
        ValueError: If the input key is not the attribute in federated learning mode context.

    Examples:
        >>> context.set_fl_context(enable_fl=True, server_mode='FEDERATED_LEARNING')
    """
    _set_ps_context(**kwargs)


def get_fl_context(attr_key):
    """
    Get federated learning mode context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.
            Please refer to `set_fl_context`'s parameters to decide what key should be passed.

    Returns:
        Returns attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute in federated learning mode context.

    Examples:
        >>> context.get_fl_context("server_mode")
    """
    return _get_ps_context(attr_key)