1# Copyright 2020 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Context of auto parallel""" 16import threading 17 18import mindspore.context as context 19import mindspore.log as logger 20from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size 21from mindspore.parallel._ps_context import _is_role_pserver 22from mindspore._c_expression import AutoParallelContext 23from mindspore._checkparam import args_type_check, Validator 24 25_MAX_GROUP_NAME_LEN = 127 26_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" 27_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" 28 29 30class _AutoParallelContext: 31 """ 32 _AutoParallelContext is the environment in which operations are executed 33 34 Note: 35 Create a context through instantiating Context object is not recommended. 36 Should use auto_parallel_context() to get the context since Context is singleton. 37 """ 38 _instance = None 39 _instance_lock = threading.Lock() 40 41 def __init__(self): 42 self._context_handle = AutoParallelContext.get_instance() 43 self._dataset_strategy_using_str = True 44 45 def __new__(cls): 46 if cls._instance is None: 47 cls._instance_lock.acquire() 48 cls._instance = object.__new__(cls) 49 cls._instance_lock.release() 50 return cls._instance 51 52 def check_context_handle(self): 53 """ 54 Check context handle. 55 56 Raises: 57 ValueError: If the context handle is none. 58 """ 59 if self._context_handle is None: 60 raise ValueError("Context handle is none in context!!!") 61 62 def set_device_num(self, device_num): 63 """ 64 Set device num for auto parallel. 65 66 Args: 67 device_num (int): The device number. 68 69 Raises: 70 ValueError: If the device num is not in [1, 4096]. 71 """ 72 self.check_context_handle() 73 if device_num < 1 or device_num > 4096: 74 raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num)) 75 self._context_handle.set_device_num(device_num) 76 77 def get_device_num(self): 78 """Get device num.""" 79 self.check_context_handle() 80 return self._context_handle.get_device_num() 81 82 def set_global_rank(self, global_rank): 83 """ 84 Set global rank for auto parallel. 85 86 Args: 87 global_rank (int): The rank id of current rank. 88 89 Raises: 90 ValueError: If the global rank is not in [1, 4096]. 91 """ 92 self.check_context_handle() 93 if global_rank < 0 or global_rank > 4095: 94 raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank)) 95 self._context_handle.set_global_rank(global_rank) 96 97 def get_global_rank(self): 98 """Get current rank id.""" 99 self.check_context_handle() 100 return self._context_handle.get_global_rank() 101 102 def set_pipeline_stages(self, stages): 103 """Set the stages of the pipeline""" 104 if isinstance(stages, bool): 105 raise TypeError("The type of pipeline_stage_num must be int, but got bool.") 106 if not isinstance(stages, int): 107 raise TypeError("The type of pipeline_stage_num must be int.") 108 if stages < 1: 109 raise ValueError("pipeline_stage_num can't be less than 1.") 110 backend = context.get_context("device_target") 111 if backend == "GPU" and stages > 1: 112 raise RuntimeError("Now GPU don't support pipeline parallel.") 113 self.check_context_handle() 114 self._context_handle.set_pipeline_stage_split_num(stages) 115 116 def get_pipeline_stages(self): 117 """Get the stages of the pipeline""" 118 self.check_context_handle() 119 return self._context_handle.get_pipeline_stage_split_num() 120 121 def set_gradients_mean(self, gradients_mean): 122 """ 123 Set gradients_mean flag. 124 125 Note: 126 If gradients_mean is true, it will insert a div operator after parameter gradients allreduce. 127 128 Args: 129 gradients_mean (bool): The gradients_mean flag. 130 """ 131 self.check_context_handle() 132 self._context_handle.set_gradients_mean(gradients_mean) 133 134 def get_gradients_mean(self): 135 """Get gradients_mean flag.""" 136 self.check_context_handle() 137 return self._context_handle.get_gradients_mean() 138 139 def set_gradient_fp32_sync(self, gradient_fp32_sync): 140 """ 141 Set gradient_fp32_sync. 142 143 Note: 144 If gradient_fp32_sync is true, 145 it will convert tensor type from fp16 to fp32 before parameter gradients allreduce. 146 147 Args: 148 gradient_fp32_sync (bool): The gradient_fp32_sync flag. 149 """ 150 self.check_context_handle() 151 self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync) 152 153 def get_gradient_fp32_sync(self): 154 """Get gradient_fp32_sync flag.""" 155 self.check_context_handle() 156 return self._context_handle.get_gradient_fp32_sync() 157 158 def set_loss_repeated_mean(self, loss_repeated_mean): 159 """ 160 Set loss_repeated_mean flag. 161 162 Note: 163 If loss_repeated_mean is true, 164 Distributed automatic differentiation will perform a mean operator 165 in backward in the case of repeated calculations. 166 167 Args: 168 loss_repeated_mean (bool): The loss_repeated_mean flag. 169 """ 170 if not isinstance(loss_repeated_mean, bool): 171 raise TypeError(f"The type of loss_repeated_mean must be bool, but got {type(loss_repeated_mean)}.") 172 self.check_context_handle() 173 self._context_handle.set_loss_repeated_mean(loss_repeated_mean) 174 175 def get_loss_repeated_mean(self): 176 """Get loss_repeated_mean flag.""" 177 self.check_context_handle() 178 return self._context_handle.get_loss_repeated_mean() 179 180 def set_parallel_mode(self, parallel_mode): 181 """ 182 Set parallel mode for auto parallel. 183 184 Args: 185 parallel_mode (str): The parallel mode of auto parallel. 186 187 Raises: 188 ValueError: If parallel mode is not supported. 189 """ 190 self.check_context_handle() 191 ret = self._context_handle.set_parallel_mode(parallel_mode) 192 if ret is False: 193 raise ValueError("Parallel mode does not support {}".format(parallel_mode)) 194 195 def get_parallel_mode(self): 196 """Get parallel mode.""" 197 self.check_context_handle() 198 if _is_role_pserver(): 199 return context.ParallelMode.STAND_ALONE 200 return self._context_handle.get_parallel_mode() 201 202 def set_strategy_search_mode(self, auto_parallel_search_mode): 203 """ 204 Set search mode of strategy. 205 206 Args: 207 auto_parallel_search_mode (str): The search mode of strategy. 208 """ 209 self.check_context_handle() 210 ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode) 211 if ret is False: 212 raise ValueError("Strategy search mode does not support {}".format(auto_parallel_search_mode)) 213 214 def get_strategy_search_mode(self): 215 """Get search mode of strategy.""" 216 self.check_context_handle() 217 return self._context_handle.get_strategy_search_mode() 218 219 def set_parameter_broadcast(self, parameter_broadcast): 220 """ 221 Set parameter broadcast. 222 223 Args: 224 parameter_broadcast (bool): Parameter broadcast or not. 225 """ 226 self.check_context_handle() 227 self._context_handle.set_parameter_broadcast(parameter_broadcast) 228 229 def get_parameter_broadcast(self): 230 """Get parameter broadcast flag.""" 231 self.check_context_handle() 232 return self._context_handle.get_parameter_broadcast() 233 234 def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file): 235 """ 236 Set strategy checkpoint load path. 237 238 Args: 239 strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint. 240 """ 241 self.check_context_handle() 242 self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file) 243 244 def get_strategy_ckpt_load_file(self): 245 """Get strategy checkpoint load path.""" 246 self.check_context_handle() 247 return self._context_handle.get_strategy_ckpt_load_file() 248 249 def set_full_batch(self, full_batch): 250 """ 251 Set whether load full batch on each device. 252 253 Args: 254 full_batch (bool): True if load full batch on each device. 255 """ 256 self.check_context_handle() 257 self._context_handle.set_full_batch(full_batch) 258 259 def get_full_batch(self): 260 """Get whether load full batch on each device.""" 261 self.check_context_handle() 262 if _is_role_pserver(): 263 return False 264 return self._context_handle.get_full_batch() 265 266 def set_dataset_strategy(self, dataset_strategy): 267 """ 268 Set dataset sharding strategy. 269 270 Args: 271 dataset_strategy (str or tuple(tuple)): The dataset sharding strategy. 272 """ 273 self.check_context_handle() 274 if isinstance(dataset_strategy, str): 275 if dataset_strategy not in ("full_batch", "data_parallel"): 276 raise ValueError("The dataset_strategy string should be 'full_batch' or 'data_parallel', " 277 "otherwise, incoming tuple(tuple) type strategy") 278 self._context_handle.set_full_batch(dataset_strategy == "full_batch") 279 self._dataset_strategy_using_str = True 280 return 281 if not isinstance(dataset_strategy, tuple): 282 raise TypeError(f'strategy must be str or tuple type, but got:{type(dataset_strategy)}') 283 for ele in dataset_strategy: 284 if not isinstance(ele, tuple): 285 raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}') 286 for dim in ele: 287 if not isinstance(dim, int): 288 raise TypeError(f'The dim of each strategy value must be int type, but got:{type(dim)}') 289 self._dataset_strategy_using_str = False 290 self._context_handle.set_dataset_strategy(dataset_strategy) 291 292 def get_dataset_strategy(self): 293 """Get dataset sharding strategy.""" 294 self.check_context_handle() 295 if self._dataset_strategy_using_str: 296 if self._context_handle.get_full_batch(): 297 return "full_batch" 298 return "data_parallel" 299 return self._context_handle.get_dataset_strategy() 300 301 def set_grad_accumulation_step(self, grad_accumulation_step): 302 """ 303 Set grad accumulation step. 304 305 Args: 306 grad_accumulation_step (int): The grad accumulation step. 307 """ 308 self.check_context_handle() 309 Validator.check_positive_int(grad_accumulation_step) 310 self._context_handle.set_grad_accumulation_step(grad_accumulation_step) 311 312 def get_grad_accumulation_step(self): 313 """Get grad accumulation step.""" 314 self.check_context_handle() 315 return self._context_handle.get_grad_accumulation_step() 316 317 def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): 318 """ 319 Set strategy checkpoint save path. 320 321 Args: 322 strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint. 323 """ 324 self.check_context_handle() 325 import os 326 dir_path = os.path.dirname(strategy_ckpt_save_file) 327 if dir_path and not os.path.exists(dir_path): 328 os.makedirs(dir_path) 329 self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) 330 331 def get_strategy_ckpt_save_file(self): 332 """Get strategy checkpoint save path.""" 333 self.check_context_handle() 334 return self._context_handle.get_strategy_ckpt_save_file() 335 336 def set_group_ckpt_save_file(self, group_ckpt_save_file): 337 """Set group checkpoint save path.""" 338 self.check_context_handle() 339 import os 340 dir_path = os.path.dirname(group_ckpt_save_file) 341 if dir_path and not os.path.exists(dir_path): 342 os.makedirs(dir_path) 343 self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file) 344 345 def get_parameter_broadcast_is_set(self): 346 """Get parameter broadcast is set or not.""" 347 self.check_context_handle() 348 return self._context_handle.get_parameter_broadcast_is_set() 349 350 def set_all_reduce_fusion_split_indices(self, indices, group=""): 351 """ 352 Set allreduce fusion strategy by parameters indices. 353 354 Args: 355 indices (list): Indices list. 356 group (str): The communication group of hccl/nccl. 357 358 Raises: 359 TypeError: If type of indices item is not int. 360 TypeError: If group is not a python str. 361 """ 362 self.check_context_handle() 363 if not indices: 364 raise ValueError('indices can not be empty') 365 366 if isinstance(indices, (list)): 367 for index in indices: 368 if not isinstance(index, int) or isinstance(index, bool): 369 raise TypeError(f"The type of index must be int, but got {type(index)}.") 370 else: 371 raise TypeError('indices must be a python list') 372 373 if len(set(indices)) != len(indices): 374 raise ValueError('indices has duplicate elements') 375 376 if sorted(indices) != indices: 377 raise ValueError('elements in indices must be sorted in ascending order') 378 379 new_group = self._check_and_default_group(group) 380 381 self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group) 382 if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"): 383 _set_fusion_strategy_by_idx(indices) 384 385 def get_all_reduce_fusion_split_indices(self, group=""): 386 """ 387 Get allreduce fusion split indices. 388 389 Args: 390 group (str): The communication group of hccl/nccl. 391 392 Returns: 393 Return split sizes list according to the group. 394 395 Raises: 396 TypeError: If group is not a python str. 397 """ 398 self.check_context_handle() 399 new_group = self._check_and_default_group(group) 400 return self._context_handle.get_all_reduce_fusion_split_indices(new_group) 401 402 def set_all_reduce_fusion_split_sizes(self, sizes, group=""): 403 """ 404 Set allreduce fusion strategy by parameters data sizes. 405 406 Args: 407 sizes (list): Sizes list. 408 group (str): The communication group of hccl/nccl. 409 410 Raises: 411 TypeError: If type of sizes item is not int. 412 TypeError: If group is not a python str. 413 """ 414 self.check_context_handle() 415 if isinstance(sizes, (list)): 416 for size in sizes: 417 if not isinstance(size, int) or isinstance(size, bool): 418 raise TypeError(f"The type of size must be int, but got {type(size)}.") 419 else: 420 raise TypeError('sizes must be a python list') 421 422 new_group = self._check_and_default_group(group) 423 self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group) 424 if context.get_context("device_target") == "Ascend": 425 _set_fusion_strategy_by_size(sizes) 426 427 def get_all_reduce_fusion_split_sizes(self, group=""): 428 """ 429 Get allreduce fusion split sizes. 430 431 Args: 432 group (str): The communication group of hccl/nccl. 433 434 Returns: 435 Return split sizes list according to the group. 436 437 Raises: 438 TypeError: If group is not a python str. 439 """ 440 self.check_context_handle() 441 new_group = self._check_and_default_group(group) 442 return self._context_handle.get_all_reduce_fusion_split_sizes(new_group) 443 444 def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): 445 """ 446 Set enable/disable all reduce fusion. 447 448 Args: 449 enable_all_reduce_fusion (bool): Enable/disable all reduce fusion. 450 """ 451 self.check_context_handle() 452 if not isinstance(enable_all_reduce_fusion, bool): 453 raise TypeError('enable_all_reduce_fusion is invalid type') 454 self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion) 455 456 def get_enable_all_reduce_fusion(self): 457 """Get all reduce fusion flag.""" 458 self.check_context_handle() 459 return self._context_handle.get_enable_all_reduce_fusion() 460 461 def get_device_num_is_set(self): 462 """Get device number is set or not.""" 463 self.check_context_handle() 464 return self._context_handle.get_device_num_is_set() 465 466 def get_global_rank_is_set(self): 467 """Get global rank is set or not.""" 468 self.check_context_handle() 469 return self._context_handle.get_global_rank_is_set() 470 471 def set_enable_parallel_optimizer(self, enable_parallel_optimizer): 472 """ 473 Set enable/disable parallel optimizer. 474 475 Args: 476 set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer. 477 """ 478 self.check_context_handle() 479 if not isinstance(enable_parallel_optimizer, bool): 480 raise TypeError('enable_parallel_optimizer is invalid type') 481 self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer) 482 483 def get_enable_parallel_optimizer(self): 484 """Get parallel optimizer flag.""" 485 self.check_context_handle() 486 return self._context_handle.get_enable_parallel_optimizer() 487 488 def set_sharding_propagation(self, sharding_propagation): 489 """ 490 Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators 491 will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm 492 will search the desired strategies. 493 Default: False. 494 495 Args: 496 sharding_propagation (bool): Enable/disable strategy propagation. 497 """ 498 self.check_context_handle() 499 if not isinstance(sharding_propagation, bool): 500 raise TypeError("'sharding_propagation' is an invalid type.") 501 self._context_handle.set_sharding_propagation(sharding_propagation) 502 503 def get_sharding_propagation(self): 504 """Get the value of sharding strategy propagation.""" 505 self.check_context_handle() 506 return self._context_handle.get_sharding_propagation() 507 508 def set_enable_alltoall(self, enable_a2a): 509 """ 510 Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll. 511 Default: False. 512 513 Args: 514 enable_a2a (bool): Enable/disable AllToAll. 515 """ 516 self.check_context_handle() 517 if not isinstance(enable_a2a, bool): 518 raise TypeError("'enable_a2a' is an invalid type.") 519 self._context_handle.set_enable_alltoall(enable_a2a) 520 521 def get_enable_alltoall(self): 522 """Get the value of enabling AllToAll.""" 523 self.check_context_handle() 524 return self._context_handle.get_enable_alltoall() 525 526 def set_communi_parallel_mode(self, communi_parallel_mode): 527 """ 528 Set communication parallel mode. 529 530 Args: 531 communi_parallel_mode (str): The communication parallel mode. 532 533 Raises: 534 ValueError: If parallel mode is not supported. 535 """ 536 if not isinstance(communi_parallel_mode, str): 537 raise TypeError(f"The type of communi_parallel_mode must be str, \ 538 but got {type(communi_parallel_mode)}.") 539 self.check_context_handle() 540 ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode) 541 if ret is False: 542 raise ValueError("Communication parallel mode does not support {}".format(communi_parallel_mode)) 543 544 def get_communi_parallel_mode(self): 545 """Get communication parallel mode.""" 546 self.check_context_handle() 547 return self._context_handle.get_communi_parallel_mode() 548 549 def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size): 550 """ 551 Set optimizer_weight_shard_size. 552 553 Args: 554 optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel 555 optimizer across devices. 556 """ 557 self.check_context_handle() 558 if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool): 559 raise TypeError(f"The type of optimizer_weight_shard_size must be int, \ 560 but got {type(optimizer_weight_shard_size)}.") 561 if optimizer_weight_shard_size <= 1: 562 logger.warning("The setting 'optimizer_weight_shard_size' is invalid. " 563 "Please use the integer larger than 1.") 564 return 565 self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size) 566 567 def get_optimizer_weight_shard_size(self): 568 """Get optimizer_weight_shard_size.""" 569 self.check_context_handle() 570 return self._context_handle.get_optimizer_weight_shard_size() 571 572 def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save): 573 """ 574 Set optimizer_weight_shard_aggregated_save. 575 576 Args: 577 optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when 578 enable parallel optimizer. 579 """ 580 self.check_context_handle() 581 if not isinstance(optimizer_weight_shard_aggregated_save, bool): 582 raise TypeError('optimizer_weight_shard_aggregated_save is invalid type') 583 self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save) 584 585 586 def get_optimizer_weight_shard_aggregated_save(self): 587 """Get optimizer_weight_shard_size.""" 588 self.check_context_handle() 589 return self._context_handle.get_optimizer_weight_shard_aggregated_save() 590 591 592 def reset(self): 593 """Reset all settings.""" 594 self.check_context_handle() 595 self._context_handle.reset() 596 597 598 def _check_and_default_group(self, group): 599 """Validate the given group, if group is empty, returns a default fusion group""" 600 if isinstance(group, (str)): 601 group_len = len(group) 602 if group_len > _MAX_GROUP_NAME_LEN: 603 raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') 604 else: 605 raise TypeError('Group must be a python str') 606 607 if group == "": 608 if context.get_context("device_target") == "Ascend": 609 group = _DEFAULT_HCCL_FUSION_GROUP_NAME 610 else: 611 group = _DEFAULT_NCCL_FUSION_GROUP_NAME 612 return group 613 614 615_auto_parallel_context = None 616 617 618def auto_parallel_context(): 619 """ 620 Get the global _auto_parallel_context, if it is not created, create a new one. 621 622 Returns: 623 _AutoParallelContext, the global auto parallel context. 624 """ 625 global _auto_parallel_context 626 if _auto_parallel_context is None: 627 _auto_parallel_context = _AutoParallelContext() 628 return _auto_parallel_context 629 630 631_set_auto_parallel_context_func_map = { 632 "device_num": auto_parallel_context().set_device_num, 633 "global_rank": auto_parallel_context().set_global_rank, 634 "gradients_mean": auto_parallel_context().set_gradients_mean, 635 "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync, 636 "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, 637 "pipeline_stages": auto_parallel_context().set_pipeline_stages, 638 "parallel_mode": auto_parallel_context().set_parallel_mode, 639 "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode, 640 "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, 641 "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, 642 "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, 643 "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file, 644 "full_batch": auto_parallel_context().set_full_batch, 645 "dataset_strategy": auto_parallel_context().set_dataset_strategy, 646 "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, 647 "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, 648 "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices, 649 "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode, 650 "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size, 651 "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save, 652 "sharding_propagation": auto_parallel_context().set_sharding_propagation, 653 "enable_alltoall": auto_parallel_context().set_enable_alltoall} 654 655 656_get_auto_parallel_context_func_map = { 657 "device_num": auto_parallel_context().get_device_num, 658 "global_rank": auto_parallel_context().get_global_rank, 659 "gradients_mean": auto_parallel_context().get_gradients_mean, 660 "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync, 661 "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, 662 "pipeline_stages": auto_parallel_context().get_pipeline_stages, 663 "parallel_mode": auto_parallel_context().get_parallel_mode, 664 "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode, 665 "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, 666 "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, 667 "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, 668 "full_batch": auto_parallel_context().get_full_batch, 669 "dataset_strategy": auto_parallel_context().get_dataset_strategy, 670 "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, 671 "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, 672 "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices, 673 "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode, 674 "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size, 675 "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save, 676 "sharding_propagation": auto_parallel_context().get_sharding_propagation, 677 "enable_alltoall": auto_parallel_context().get_enable_alltoall} 678 679 680@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, 681 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, 682 parameter_broadcast=bool, strategy_ckpt_load_file=str, 683 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, 684 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, 685 communi_parallel_mode=str, optimizer_weight_shard_size=int, 686 optimizer_weight_shard_aggregated_save=bool, 687 sharding_propagation=bool, enable_alltoall=bool) 688 689def _set_auto_parallel_context(**kwargs): 690 """ 691 Set auto parallel context. 692 693 Note: 694 Attribute name is required for setting attributes. 695 696 Args: 697 device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. 698 global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. 699 gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False. 700 loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated 701 calculations. Default: True. 702 gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True. 703 Default: True. 704 parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", 705 "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". 706 707 - stand_alone: Only one processor working. 708 709 - data_parallel: Distributing the data across different processors. 710 711 - hybrid_parallel: Achieving data parallelism and model parallelism manually. 712 713 - semi_auto_parallel: Achieving data parallelism and model parallelism by 714 setting parallel strategies. 715 716 - auto_parallel: Achieving parallelism automatically. 717 auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming" 718 and "dynamic_programming". Default: "dynamic_programming". 719 720 - recursive_programming: Recursive programming search mode. 721 722 - dynamic_programming: Dynamic programming search mode. 723 parameter_broadcast (bool): Indicating whether to broadcast parameters before training. 724 "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter 725 broadcast. Default: False. 726 strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' 727 strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' 728 group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: '' 729 full_batch (bool): Whether to load the whole batch on each device. Default: False. 730 dataset_strategy Union[str, tuple]: Dataset sharding strategy. Default: "data_parallel". 731 enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False. 732 all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. 733 pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how 734 the devices are distributed alone the pipeline. The total devices will be divided into 735 'pipeline_stags' stages. This currently could only be used when 736 parallel mode semi_auto_parallel is enabled. Default: 0 737 communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel", 738 "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel". 739 740 - all_group_parallel: All communication groups are in parallel. 741 742 - same_server_group_parallel: Only the communication groups within the same server are parallel. 743 744 - no_group_parallel: All communication groups are not parallel. 745 optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer. 746 It should be larger than one and less than or equal with the data parallel size. 747 Default: -1, which means fully use parallel optimizer in data parallel dimension. 748 optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel 749 optimizer. Default: False. 750 sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, 751 the strategy-configured operators will propagate the strategies to other 752 operators with minimum redistribution cost; otherwise, the algorithm will 753 search the desired strategies. Default: False. 754 enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to 755 circumvent AllToAll. Default: False. 756 757 Raises: 758 ValueError: If input key is not attribute in auto parallel context. 759 """ 760 for key, value in kwargs.items(): 761 if key not in _set_auto_parallel_context_func_map: 762 raise ValueError("Set context keyword %s is not recognized!" % key) 763 set_func = _set_auto_parallel_context_func_map[key] 764 set_func(value) 765 766 767def _get_auto_parallel_context(attr_key): 768 """ 769 Get auto parallel context attribute value according to the key. 770 771 Args: 772 attr_key (str): The key of the attribute. 773 774 Returns: 775 Return attribute value according to the key. 776 777 Raises: 778 ValueError: If input key is not attribute in auto parallel context. 779 """ 780 if attr_key not in _get_auto_parallel_context_func_map: 781 raise ValueError("Get context keyword %s is not recognized!" % attr_key) 782 get_func = _get_auto_parallel_context_func_map[attr_key] 783 return get_func() 784 785 786def _reset_auto_parallel_context(): 787 """ 788 Reset auto parallel context attributes to the default values: 789 790 - device_num: 1. 791 - global_rank: 0. 792 - gradients_mean: False. 793 - gradient_fp32_sync: True. 794 - parallel_mode: "stand_alone". 795 - parameter_broadcast: False. 796 - strategy_ckpt_load_file: "" 797 - strategy_ckpt_save_file: "" 798 - enable_parallel_optimizer: False 799 - auto_parallel_search_mode: dynamic_programming 800 - pipeline_stages: 0 801 """ 802 auto_parallel_context().reset() 803