1# Copyright 2020-2023 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Context of auto parallel""" 16from __future__ import absolute_import 17import os 18import copy 19import threading 20from mindspore import context 21import mindspore.log as logger 22from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size 23from mindspore.parallel._ps_context import _is_role_pserver 24from mindspore._c_expression import AutoParallelContext 25from mindspore._checkparam import args_type_check 26from mindspore import _checkparam as Validator 27 28_MAX_GROUP_NAME_LEN = 127 29_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1" 30_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1" 31 32 33class _ParallelFusionConfig: 34 """ 35 The key of the Parallel fusion method configuration. 36 """ 37 ALLREDUCE = "allreduce" 38 ALLGATHER = "allgather" 39 REDUCESCATTER = "reducescatter" 40 MODE = "mode" 41 FUSION_CONFIG = "config" 42 AUTO = "auto" 43 INDEX = "index" 44 SIZE = "size" 45 OPENSTATE = "openstate" 46 CONFIG = {"openstate": True, 47 "allreduce": {"mode": "auto", "config": None}, 48 "allgather": {"mode": "auto", "config": None}, 49 "reducescatter": {"mode": "auto", "config": None}} 50 51 @classmethod 52 def reset(cls): 53 cls.CONFIG = {"openstate": True, 54 "allreduce": {"mode": "auto", "config": None}, 55 "allgather": {"mode": "auto", "config": None}, 56 "reducescatter": {"mode": "auto", "config": None}} 57 58 59class _ParallelOptimizerConfig: 60 """ 61 The key of the Parallel Optimizer. There are three 62 """ 63 GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard" 64 PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold" 65 OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size" 66 67 68class _PipelineConfig: 69 """ 70 The key of the Pipeline parallelism. 71 """ 72 PIPELINE_INTERLEAVE = "pipeline_interleave" 73 PIPELINE_SCHEDULER = "pipeline_scheduler" 74 75 76class _PipelineScheduler: 77 PIPELINE_1F1B = "1f1b" 78 PIPELINE_GPIPE = "gpipe" 79 80 81class _AutoParallelContext: 82 """ 83 _AutoParallelContext is the environment in which operations are executed 84 85 Note: 86 Create a context through instantiating Context object is not recommended. 87 Should use auto_parallel_context() to get the context since Context is singleton. 88 """ 89 _instance = None 90 _instance_lock = threading.Lock() 91 92 def __new__(cls): 93 if cls._instance is None: 94 cls._instance_lock.acquire() 95 cls._instance = object.__new__(cls) 96 cls._instance_lock.release() 97 return cls._instance 98 99 def __init__(self): 100 self._context_handle = AutoParallelContext.get_instance() 101 self._dataset_strategy_using_str = True 102 103 def check_context_handle(self): 104 """ 105 Check context handle. 106 107 Raises: 108 ValueError: If the context handle is none. 109 """ 110 if self._context_handle is None: 111 raise ValueError("Context handle is none in context!!!") 112 113 def set_device_num(self, device_num): 114 """ 115 Set device num for auto parallel. 116 117 Args: 118 device_num (int): The device number. 119 120 Raises: 121 ValueError: If the device num is not a positive integer. 122 """ 123 self.check_context_handle() 124 if device_num < 1: 125 raise ValueError("The context configuration parameter 'device_num' must be a positive integer, " 126 "but got the value of device_num : {}.".format(device_num)) 127 from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE 128 self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE) 129 self._context_handle.set_device_num(device_num) 130 131 def get_device_num(self): 132 """Get device num.""" 133 self.check_context_handle() 134 return self._context_handle.get_device_num() 135 136 def set_comm_fusion(self, config): 137 """ 138 Set fusion method for auto parallel. 139 140 Args: 141 config (dict): A dict contains the methods and values for setting the communication fusion. Currently it 142 supports: `allreduce`. 143 144 Raises: 145 KeyError: When key of comm_fusion is not 'allreduce'. 146 """ 147 self.check_context_handle() 148 config = copy.deepcopy(config) 149 if _ParallelFusionConfig.OPENSTATE not in config.keys(): 150 config[_ParallelFusionConfig.OPENSTATE] = True 151 for key in list(config.keys()): 152 if key == _ParallelFusionConfig.ALLREDUCE: 153 self._set_allreduce_comm_fusion(config[key]) 154 elif key == _ParallelFusionConfig.ALLGATHER: 155 self._set_allgather_comm_fusion(config[key], key) 156 elif key == _ParallelFusionConfig.REDUCESCATTER: 157 self._set_allgather_comm_fusion(config[key], key) 158 elif key == _ParallelFusionConfig.OPENSTATE: 159 self._set_openstate_comm_fusion(config[key]) 160 else: 161 raise KeyError("comm fusion type must be openstate," 162 "allreduce, allgather or reducescatter, but got {}".format(key)) 163 if key in _ParallelFusionConfig.CONFIG: 164 _ParallelFusionConfig.CONFIG[key] = config[key] 165 166 def get_comm_fusion(self): 167 """Get comm fusion config.""" 168 self.check_context_handle() 169 return _ParallelFusionConfig.CONFIG 170 171 def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"): 172 """ 173 Set fusion threshold (MB) for auto parallel. 174 175 Args: 176 fusion_threshold (int): The fusion threshold (unit: MB). Default: 64. 177 comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`. 178 179 Raises: 180 ValueError: If the fusion threshold is not in [0, +inf]. 181 """ 182 self.check_context_handle() 183 if fusion_threshold < 0: 184 raise ValueError("fusion threshold must be larger than 0, but got {}".format(fusion_threshold)) 185 186 if comm_type == _ParallelFusionConfig.ALLREDUCE: 187 self._context_handle.set_fusion_threshold_mb(fusion_threshold) 188 if comm_type == _ParallelFusionConfig.ALLGATHER: 189 self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold) 190 if comm_type == _ParallelFusionConfig.REDUCESCATTER: 191 self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold) 192 193 def fusion_threshold_mb(self): 194 """Get all reduce threshold.""" 195 self.check_context_handle() 196 return self._context_handle.fusion_threshold_mb() 197 198 def allgather_fusion_threshold_mb(self): 199 """Get allgather threshold.""" 200 self.check_context_handle() 201 return self._context_handle.allgather_fusion_threshold_mb() 202 203 def reducescatter_fusion_threshold_mb(self): 204 """Get reducescatter threshold.""" 205 self.check_context_handle() 206 return self._context_handle.reducescatter_fusion_threshold_mb() 207 208 def set_global_rank(self, global_rank): 209 """ 210 Set global rank for auto parallel. 211 212 Args: 213 global_rank (int): The rank id of current rank. 214 215 Raises: 216 ValueError: If the global rank is not in [1, 4096]. 217 """ 218 self.check_context_handle() 219 if global_rank < 0 or global_rank > 4095: 220 raise ValueError("The context configuration parameter 'global_rank' must be in [0, 4095], " 221 "but got the value of global_rank : {}.".format(global_rank)) 222 self._context_handle.set_global_rank(global_rank) 223 224 def get_global_rank(self): 225 """Get current rank id.""" 226 self.check_context_handle() 227 return self._context_handle.get_global_rank() 228 229 def set_pipeline_stages(self, stages): 230 """Set the stages of the pipeline""" 231 if isinstance(stages, bool) or not isinstance(stages, int): 232 raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' " 233 "must be int, but got the type : {}.".format(type(stages))) 234 if stages < 1: 235 raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' " 236 "should be greater or equal 1, but got the value of stages : {}.".format(stages)) 237 self.check_context_handle() 238 self._context_handle.set_pipeline_stage_split_num(stages) 239 240 def get_pipeline_stages(self): 241 """Get the stages of the pipeline""" 242 self.check_context_handle() 243 return self._context_handle.get_pipeline_stage_split_num() 244 245 def set_auto_pipeline(self, auto_pipeline): 246 """Set the pipeline stage number to automatic""" 247 if not isinstance(auto_pipeline, bool): 248 raise TypeError("For 'set_auto_parallel_context', the argument 'auto_pipeline' " 249 "must be bool, but got the type : {}.".format(type(auto_pipeline))) 250 self.check_context_handle() 251 self._context_handle.set_auto_pipeline(auto_pipeline) 252 253 def get_auto_pipeline(self): 254 """Get whether the pipeline stage number is automatic""" 255 self.check_context_handle() 256 return self._context_handle.get_auto_pipeline() 257 258 def set_pipeline_result_broadcast(self, pipeline_result_broadcast): 259 """ 260 Set the value of enabling pipeline result broadcast. Default: ``False``. 261 262 Args: 263 pipeline_result_broadcast (bool): Enable/disable broadcast the last stage result to all other stages. 264 """ 265 self.check_context_handle() 266 if not isinstance(pipeline_result_broadcast, bool): 267 raise TypeError("For 'set_auto_parallel_context().set_pipeline_result_broadcast', the argument " 268 "'pipeline_result_broadcast' must be bool, but got the type : {}." 269 .format(type(pipeline_result_broadcast))) 270 self._context_handle.set_pipeline_result_broadcast(pipeline_result_broadcast) 271 272 def get_pipeline_result_broadcast(self): 273 """Get the value of enabling pipeline result broadcast""" 274 self.check_context_handle() 275 return self._context_handle.get_pipeline_result_broadcast() 276 277 def get_pipeline_interleave(self): 278 """Get pipeline interleave flag""" 279 self.check_context_handle() 280 return self._context_handle.get_pipeline_interleave() 281 282 def get_pipeline_scheduler(self): 283 """Get pipeline scheduler""" 284 self.check_context_handle() 285 return self._context_handle.get_pipeline_scheduler() 286 287 def set_pipeline_segments(self, segments): 288 """Set the segments of the pipeline""" 289 if isinstance(segments, bool) or not isinstance(segments, int): 290 raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_segments' " 291 "must be int, but got the type : {}.".format(type(segments))) 292 if segments < 1: 293 raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_segments' " 294 "should be greater or equal 1, but got the value of segments : {}.".format(segments)) 295 self.check_context_handle() 296 self._context_handle.set_pipeline_segment_split_num(segments) 297 298 def get_pipeline_segments(self): 299 """Get the stages of the pipeline""" 300 self.check_context_handle() 301 return self._context_handle.get_pipeline_segment_split_num() 302 303 def set_gradients_mean(self, gradients_mean): 304 """ 305 Set gradients_mean flag. 306 307 Note: 308 If gradients_mean is true, it will insert a div operator after parameter gradients allreduce. 309 310 Args: 311 gradients_mean (bool): The gradients_mean flag. 312 """ 313 self.check_context_handle() 314 self._context_handle.set_gradients_mean(gradients_mean) 315 316 def get_gradients_mean(self): 317 """Get gradients_mean flag.""" 318 self.check_context_handle() 319 return self._context_handle.get_gradients_mean() 320 321 def set_gradient_fp32_sync(self, gradient_fp32_sync): 322 """ 323 Set gradient_fp32_sync. 324 325 Note: 326 If gradient_fp32_sync is true, 327 it will convert tensor type from fp16 to fp32 before parameter gradients allreduce. 328 329 Args: 330 gradient_fp32_sync (bool): The gradient_fp32_sync flag. 331 """ 332 self.check_context_handle() 333 self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync) 334 335 def get_gradient_fp32_sync(self): 336 """Get gradient_fp32_sync flag.""" 337 self.check_context_handle() 338 return self._context_handle.get_gradient_fp32_sync() 339 340 def set_loss_repeated_mean(self, loss_repeated_mean): 341 """ 342 Set loss_repeated_mean flag. 343 344 Note: 345 If loss_repeated_mean is true, 346 Distributed automatic differentiation will perform a mean operator 347 in backward in the case of repeated calculations. 348 349 Args: 350 loss_repeated_mean (bool): The loss_repeated_mean flag. 351 """ 352 if not isinstance(loss_repeated_mean, bool): 353 raise TypeError("For 'set_auto_parallel_context', the argument 'loss_repeated_mean' " 354 "must be bool, but got the type : {}.".format(type(loss_repeated_mean))) 355 self.check_context_handle() 356 self._context_handle.set_loss_repeated_mean(loss_repeated_mean) 357 358 def get_loss_repeated_mean(self): 359 """Get loss_repeated_mean flag.""" 360 self.check_context_handle() 361 return self._context_handle.get_loss_repeated_mean() 362 363 def set_parallel_mode(self, parallel_mode): 364 """ 365 Set parallel mode for auto parallel. 366 367 Args: 368 parallel_mode (str): The parallel mode of auto parallel. 369 370 Raises: 371 ValueError: If parallel mode is not supported. 372 """ 373 self.check_context_handle() 374 run_mode = context.get_context("mode") 375 if run_mode == context.PYNATIVE_MODE and parallel_mode not in ( 376 context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE, 377 context.ParallelMode.AUTO_PARALLEL): 378 raise ValueError(f"Pynative only supports STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL using" 379 f" sharding_propagation under shard function" 380 f" for ParallelMode, " 381 f"but got {parallel_mode.upper()}.") 382 ret = self._context_handle.set_parallel_mode(parallel_mode) 383 if ret is False: 384 raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', " 385 "'data_parallel', 'hybrid_parallel', 'semi_auto_parallel' and 'auto_parallel', " 386 "but got the value : {}.".format(parallel_mode)) 387 388 def get_parallel_mode(self): 389 """Get parallel mode.""" 390 self.check_context_handle() 391 return self._context_handle.get_parallel_mode() 392 393 def set_strategy_search_mode(self, search_mode): 394 """ 395 Set search mode of strategy. 396 397 Args: 398 search_mode (str): The search mode of strategy. 399 """ 400 self.check_context_handle() 401 ret = self._context_handle.set_strategy_search_mode(search_mode) 402 if ret is False: 403 raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support " 404 "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', " 405 "but got the value: {}." 406 .format(search_mode)) 407 408 def get_strategy_search_mode(self): 409 """Get search mode of strategy.""" 410 self.check_context_handle() 411 return self._context_handle.get_strategy_search_mode() 412 413 def set_auto_parallel_search_mode(self, search_mode): 414 """ 415 Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future 416 MindSpore version. 417 418 Args: 419 search_mode (str): The search mode of strategy. 420 """ 421 logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. " 422 "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.") 423 self.check_context_handle() 424 ret = self._context_handle.set_strategy_search_mode(search_mode) 425 if ret is False: 426 raise ValueError("The context configuration parameter 'search_mode' only support " 427 "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', " 428 "but got the value: {}." 429 .format(search_mode)) 430 431 def get_auto_parallel_search_mode(self): 432 """Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future 433 MindSpore version. 434 """ 435 logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. " 436 "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.") 437 self.check_context_handle() 438 return self._context_handle.get_strategy_search_mode() 439 440 def set_sharding_propagation(self, sharding_propagation): 441 """ 442 Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators 443 will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm 444 will search the desired strategies. Default: ``False``. 445 This attribute is replaced by context.set_auto_parallel_context(search_mode="sharding_propagation"). 446 447 Args: 448 sharding_propagation (bool): Enable/disable strategy propagation. 449 """ 450 logger.warning("This attribute is replaced by " 451 "context.set_auto_parallel_context(search_mode='sharding_propagation'), and this attribute will" 452 " be deleted in a future MindSpore version.") 453 self.check_context_handle() 454 if not isinstance(sharding_propagation, bool): 455 raise TypeError("For 'set_auto_parallel_context().set_sharding_propagation', " 456 "the argument 'sharding_propagation' must be bool, but got the type : {}." 457 .format(type(sharding_propagation))) 458 self._context_handle.set_sharding_propagation(sharding_propagation) 459 460 def get_sharding_propagation(self): 461 """Get the value of sharding strategy propagation.""" 462 self.check_context_handle() 463 return self._context_handle.get_sharding_propagation() 464 465 def set_parameter_broadcast(self, parameter_broadcast): 466 """ 467 Set parameter broadcast. 468 469 Args: 470 parameter_broadcast (bool): Parameter broadcast or not. 471 """ 472 self.check_context_handle() 473 self._context_handle.set_parameter_broadcast(parameter_broadcast) 474 475 def get_parameter_broadcast(self): 476 """Get parameter broadcast flag.""" 477 self.check_context_handle() 478 return self._context_handle.get_parameter_broadcast() 479 480 def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file): 481 """ 482 Set strategy checkpoint load path. 483 484 Args: 485 strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint. 486 """ 487 self.check_context_handle() 488 self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file) 489 490 def get_strategy_ckpt_load_file(self): 491 """Get strategy checkpoint load path.""" 492 self.check_context_handle() 493 return self._context_handle.get_strategy_ckpt_load_file() 494 495 def set_full_batch(self, full_batch): 496 """ 497 Set whether load full batch on each device. 498 499 Args: 500 full_batch (bool): True if load full batch on each device. 501 """ 502 self.check_context_handle() 503 self._context_handle.set_full_batch(full_batch) 504 505 def get_full_batch(self): 506 """Get whether load full batch on each device.""" 507 self.check_context_handle() 508 if _is_role_pserver(): 509 return False 510 return self._context_handle.get_full_batch() 511 512 def set_dataset_strategy(self, dataset_strategy): 513 """ 514 Set dataset sharding strategy. 515 516 Args: 517 dataset_strategy (str or tuple(tuple)): The dataset sharding strategy. 518 """ 519 self.check_context_handle() 520 if isinstance(dataset_strategy, str): 521 if dataset_strategy not in ("full_batch", "data_parallel"): 522 raise ValueError("For 'set_auto_parallel_context', the argument " 523 "'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}." 524 .format(dataset_strategy)) 525 self._context_handle.set_full_batch(dataset_strategy == "full_batch") 526 self._dataset_strategy_using_str = True 527 return 528 if not isinstance(dataset_strategy, tuple): 529 raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' " 530 "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy))) 531 for ele in dataset_strategy: 532 if not isinstance(ele, tuple): 533 raise TypeError("For 'set_auto_parallel_context', the element of argument " 534 "'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele))) 535 for dim in ele: 536 if not isinstance(dim, int): 537 raise TypeError("For 'set_auto_parallel_context', the element of argument " 538 "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim))) 539 if context.get_context('mode') == context.PYNATIVE_MODE: 540 raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' " 541 f"or 'data_parallel', but got {dataset_strategy}.") 542 self._dataset_strategy_using_str = False 543 self._context_handle.set_dataset_strategy(dataset_strategy) 544 545 def get_dataset_strategy(self): 546 """Get dataset sharding strategy.""" 547 self.check_context_handle() 548 if self._dataset_strategy_using_str: 549 if self._context_handle.get_full_batch(): 550 return "full_batch" 551 return "data_parallel" 552 dataset_strategy = self._context_handle.get_dataset_strategy() 553 if context.get_context('mode') == context.PYNATIVE_MODE: 554 raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' " 555 f"or 'data_parallel', but got the setting value is {dataset_strategy}.") 556 return dataset_strategy 557 558 def set_grad_accumulation_step(self, grad_accumulation_step): 559 """ 560 Set grad accumulation step. 561 562 Args: 563 grad_accumulation_step (int): The grad accumulation step. 564 """ 565 if grad_accumulation_step > 1: 566 raise ValueError("The interface is deprecated. To use gradient accumulation, " 567 "please use GradAccumulationCell in mindspore.nn.wrap.cell_wrapper.") 568 self.check_context_handle() 569 Validator.check_positive_int(grad_accumulation_step) 570 self._context_handle.set_grad_accumulation_step(grad_accumulation_step) 571 572 def get_grad_accumulation_step(self): 573 """Get grad accumulation step.""" 574 self.check_context_handle() 575 return self._context_handle.get_grad_accumulation_step() 576 577 def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file): 578 """ 579 Set strategy checkpoint save path. 580 581 Args: 582 strategy_ckpt_save_file (bool): Path to save parallel strategy checkpoint. 583 """ 584 self.check_context_handle() 585 dir_path = os.path.dirname(strategy_ckpt_save_file) 586 if dir_path and not os.path.exists(dir_path): 587 os.makedirs(dir_path, exist_ok=True) 588 self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file) 589 590 def get_strategy_ckpt_save_file(self): 591 """Get strategy checkpoint save path.""" 592 self.check_context_handle() 593 return self._context_handle.get_strategy_ckpt_save_file() 594 595 def set_strategy_ckpt_config(self, strategy_ckpt_config): 596 """ 597 Set strategy checkpoint config. 598 599 Args: 600 strategy_ckpt_config (dict): The strategy checkpoint config. 601 """ 602 self.check_context_handle() 603 if not isinstance(strategy_ckpt_config, dict): 604 raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' " 605 "must be dict, but got the type : {}.".format(type(strategy_ckpt_config))) 606 for config_name in strategy_ckpt_config: 607 unknown_config = [] 608 if config_name not in ["load_file", "save_file", "only_trainable_params"]: 609 unknown_config.append(config_name) 610 611 if unknown_config: 612 raise ValueError("Unknown config: {}".format(unknown_config)) 613 if "load_file" in strategy_ckpt_config: 614 load_file = strategy_ckpt_config.get("load_file") 615 if not isinstance(load_file, str): 616 raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', " 617 "the argument 'load_file' must be str, but got the type : {} .".format(type(load_file))) 618 self._context_handle.set_strategy_ckpt_load_file(load_file) 619 if "save_file" in strategy_ckpt_config: 620 save_file = strategy_ckpt_config.get("save_file") 621 if not isinstance(save_file, str): 622 raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', " 623 "the argument 'save_file' must be str, but got the type : {} .".format(type(save_file))) 624 self._context_handle.set_strategy_ckpt_save_file(save_file) 625 if "only_trainable_params" in strategy_ckpt_config: 626 only_trainable_params = strategy_ckpt_config.get("only_trainable_params") 627 if not isinstance(only_trainable_params, bool): 628 raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', " 629 "the argument 'only_trainable_params' must be bool," 630 " but got the type : {} .".format(type(only_trainable_params))) 631 self._context_handle.set_stra_file_only_trainable_params(only_trainable_params) 632 633 def get_strategy_ckpt_config(self): 634 """Get strategy checkpoint config.""" 635 self.check_context_handle() 636 load_file = self._context_handle.get_strategy_ckpt_load_file() 637 save_file = self._context_handle.get_strategy_ckpt_save_file() 638 only_trainable_param = self._context_handle.get_stra_file_only_trainable_params() 639 return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param} 640 641 def set_group_ckpt_save_file(self, group_ckpt_save_file): 642 """Set group checkpoint save path.""" 643 self.check_context_handle() 644 dir_path = os.path.dirname(group_ckpt_save_file) 645 if dir_path and not os.path.exists(dir_path): 646 os.makedirs(dir_path) 647 self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file) 648 649 def get_parameter_broadcast_is_set(self): 650 """Get parameter broadcast is set or not.""" 651 self.check_context_handle() 652 return self._context_handle.get_parameter_broadcast_is_set() 653 654 def set_all_reduce_fusion_split_indices(self, indices, group=""): 655 """ 656 Set allreduce fusion strategy by parameters indices. 657 658 Args: 659 indices (list): Indices list. 660 group (str): The communication group of hccl/nccl. 661 662 Raises: 663 TypeError: If type of indices item is not int. 664 TypeError: If group is not a python str. 665 """ 666 self.check_context_handle() 667 if not indices: 668 raise ValueError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', " 669 "the argument 'indices' can not be empty") 670 671 if isinstance(indices, (list)): 672 for index in indices: 673 if not isinstance(index, int) or isinstance(index, bool): 674 raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', " 675 "the argument 'index' must be int, but got the type : {} .".format(type(index))) 676 else: 677 raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', " 678 "the argument 'indices' must be list, but got the type : {} .".format(type(indices))) 679 680 if len(set(indices)) != len(indices): 681 raise ValueError("The indices has duplicate elements") 682 683 if sorted(indices) != indices: 684 raise ValueError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', " 685 "the elements in argument 'indices' must be sorted in ascending order") 686 687 new_group = self._check_and_default_group(group) 688 689 self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group) 690 if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"): 691 _set_fusion_strategy_by_idx(indices) 692 693 def get_all_reduce_fusion_split_indices(self, group=""): 694 """ 695 Get allreduce fusion split indices. 696 697 Args: 698 group (str): The communication group of hccl/nccl. 699 700 Returns: 701 Return split sizes list according to the group. 702 703 Raises: 704 TypeError: If group is not a python str. 705 """ 706 self.check_context_handle() 707 new_group = self._check_and_default_group(group) 708 return self._context_handle.get_all_reduce_fusion_split_indices(new_group) 709 710 def set_all_reduce_fusion_split_sizes(self, sizes, group=""): 711 """ 712 Set allreduce fusion strategy by parameters data sizes. 713 714 Args: 715 sizes (list): Sizes list. 716 group (str): The communication group of hccl/nccl. 717 718 Raises: 719 TypeError: If type of sizes item is not int. 720 TypeError: If group is not a python str. 721 """ 722 self.check_context_handle() 723 if isinstance(sizes, (list)): 724 for size in sizes: 725 if not isinstance(size, int) or isinstance(size, bool): 726 raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_sizes', " 727 "the argument 'sizes' must be int, but got the type : {}.".format(type(size))) 728 else: 729 raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_sizes', " 730 "the argument 'sizes' must be list, but got the type : {}.".format(type(sizes))) 731 732 new_group = self._check_and_default_group(group) 733 self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group) 734 if context.get_context("device_target") == "Ascend": 735 _set_fusion_strategy_by_size(sizes) 736 737 def get_all_reduce_fusion_split_sizes(self, group=""): 738 """ 739 Get allreduce fusion split sizes. 740 741 Args: 742 group (str): The communication group of hccl/nccl. 743 744 Returns: 745 Return split sizes list according to the group. 746 747 Raises: 748 TypeError: If group is not a python str. 749 """ 750 self.check_context_handle() 751 new_group = self._check_and_default_group(group) 752 return self._context_handle.get_all_reduce_fusion_split_sizes(new_group) 753 754 def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion): 755 """ 756 Set enable/disable all reduce fusion. 757 758 Args: 759 enable_all_reduce_fusion (bool): Enable/disable all reduce fusion. 760 """ 761 self.check_context_handle() 762 if not isinstance(enable_all_reduce_fusion, bool): 763 raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', " 764 "the argument 'enable_fusion' must be bool, but got the type : {}." 765 .format(type(enable_all_reduce_fusion))) 766 self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion) 767 768 def set_enable_all_gather_fusion(self, enable_all_gather_fusion): 769 """ 770 Set enable/disable all gather fusion. 771 772 Args: 773 enable_all_gather_fusion (bool): Enable/disable all gather fusion. 774 """ 775 self.check_context_handle() 776 if not isinstance(enable_all_gather_fusion, bool): 777 raise TypeError("For 'set_auto_parallel_context().set_enable_all_gather_fusion', " 778 "the argument 'enable_fusion' must be bool, but got the type : {}." 779 .format(type(enable_all_gather_fusion))) 780 self._context_handle.set_enable_all_gather_fusion(enable_all_gather_fusion) 781 782 def set_enable_reduce_scatter_fusion(self, enable_reduce_scatter_fusion): 783 """ 784 Set enable/disable reduce scatter fusion. 785 786 Args: 787 enable_reduce_scatter_fusion (bool): Enable/disable reduce scatter fusion. 788 """ 789 self.check_context_handle() 790 if not isinstance(enable_reduce_scatter_fusion, bool): 791 raise TypeError("For 'set_auto_parallel_context().set_enable_reduce_scatter_fusion', " 792 "the argument 'enable_fusion' must be bool, but got the type : {}." 793 .format(type(enable_reduce_scatter_fusion))) 794 self._context_handle.set_enable_reduce_scatter_fusion(enable_reduce_scatter_fusion) 795 796 def get_enable_all_reduce_fusion(self): 797 """Get all reduce fusion flag.""" 798 self.check_context_handle() 799 return self._context_handle.get_enable_all_reduce_fusion() 800 801 def get_enable_all_gather_fusion(self): 802 """Get all gather fusion flag.""" 803 self.check_context_handle() 804 return self._context_handle.get_enable_all_gather_fusion() 805 806 def get_enable_reduce_scatter_fusion(self): 807 """Get reduce scatter flag.""" 808 self.check_context_handle() 809 return self._context_handle.get_enable_reduce_scatter_fusion() 810 811 def get_device_num_is_set(self): 812 """Get device number is set or not.""" 813 self.check_context_handle() 814 return self._context_handle.get_device_num_is_set() 815 816 def get_global_rank_is_set(self): 817 """Get global rank is set or not.""" 818 self.check_context_handle() 819 return self._context_handle.get_global_rank_is_set() 820 821 def set_enable_parallel_optimizer(self, enable_parallel_optimizer): 822 """ 823 Set enable/disable parallel optimizer. 824 825 Args: 826 set_enable_parallel_optimizer (bool): Enable/disable parallel optimizer. 827 """ 828 self.check_context_handle() 829 if not isinstance(enable_parallel_optimizer, bool): 830 raise TypeError("For 'set_auto_parallel_context', " 831 "the argument 'enable_parallel_optimizer' must be bool, but got the type : {}." 832 .format(type(enable_parallel_optimizer))) 833 self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer) 834 835 def set_force_fp32_communication(self, force_fp32_communication): 836 """ 837 Set enable/disable force fp32 communication. 838 839 Args: 840 set_force_fp32_communication (bool): Enable/disable force fp32 communication. 841 """ 842 self.check_context_handle() 843 if not isinstance(force_fp32_communication, bool): 844 raise TypeError("For 'set_auto_parallel_context', " 845 "the argument 'force_fp32_communication' must be bool, but got the type : {}." 846 .format(type(force_fp32_communication))) 847 self._context_handle.set_force_fp32_communication(force_fp32_communication) 848 849 def get_enable_fold_pipeline(self): 850 """Get parallel optimizer flag.""" 851 self.check_context_handle() 852 return self._context_handle.get_enable_fold_pipeline() 853 854 def set_pipeline_config(self, pipeline_config): 855 r""" 856 Set the configuration for pipeline parallelism. The configuration provides more detailed behavior control about 857 parallel training when pipeline parallelism is enabled. 858 859 Args: 860 pipeline_config (dict): The configuration for pipeline parallelism. It supports following keys: 861 862 - pipeline_interleave(bool): Setting true enable interleave scheduler for pipeline parallelism. This 863 scheduler requires more memory but less bubble. 864 - pipeline_scheduler(string): There are two choices, "1f1b" and "gpipe". default is "1f1b" 865 866 - 1f1b: It requires less memory and bubble ratio, for it run backward pass when corresponding forward pass 867 finished. 868 - gpipe: It requires more memory and bubble ratio, for it run backward pass after all forward pass 869 finished. 870 871 Raises: 872 TypeError: If the type of `pipeline_config` is not `dict`. 873 ValueError: If the key in `pipeline_config` not in ["pipeline_interleave", "pipeline_scheduler"]. 874 ValueError: If pipeline interleave is False, pipeline scheduler is not `1f1b`. 875 """ 876 self.check_context_handle() 877 878 if not isinstance(pipeline_config, dict): 879 raise TypeError("For 'set_pipeline_config', the argument 'pipeine_config' " 880 "must be dict, but got the type : {}.".format(type(pipeline_config))) 881 882 pp_interleave = _PipelineConfig.PIPELINE_INTERLEAVE 883 pp_scheduler = _PipelineConfig.PIPELINE_SCHEDULER 884 885 for config_name in pipeline_config: 886 unknown_config = [] 887 if config_name not in [pp_interleave, pp_scheduler]: 888 unknown_config.append(config_name) 889 890 if unknown_config: 891 raise ValueError("Unknown config: {}".format(unknown_config)) 892 893 Validator.check_bool( 894 pipeline_config[pp_interleave], pp_interleave, pp_interleave) 895 self._context_handle.set_pipeline_interleave( 896 pipeline_config[pp_interleave]) 897 898 Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B, 899 _PipelineScheduler.PIPELINE_GPIPE]) 900 if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B: 901 raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported") 902 903 self._context_handle.set_pipeline_scheduler(pipeline_config[pp_scheduler]) 904 905 def get_enable_parallel_optimizer(self): 906 """Get parallel optimizer flag.""" 907 self.check_context_handle() 908 return self._context_handle.get_enable_parallel_optimizer() 909 910 def get_force_fp32_communication(self): 911 """Get force fp32 communication flag.""" 912 self.check_context_handle() 913 return self._context_handle.get_force_fp32_communication() 914 915 916 def set_parallel_optimizer_config(self, parallel_optimizer_config): 917 r""" 918 Set the configure for parallel optimizer. The configure provides more detailed behavior control about parallel 919 training when parallel optimizer is enabled. 920 921 Args: 922 parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer 923 configure. It supports the following keys: 924 925 - gradient_accumulation_shard(bool): If true, the accumulation gradient parameters will be sharded 926 across the data parallel devices. This will introduce additional 927 communication cost(ReduceScatter) at each step when accumulate the 928 gradients, but saves a lot of device memories, 929 thus can make model be trained with larger batch size. 930 This configuration is effective only when the model runs on pipeline 931 training or gradient accumulation with data parallel. 932 933 - parallel_optimizer_threshold(int): Set the threshold of parallel optimizer. When parallel optimizer is 934 enabled, parameters with size smaller than this threshold will not be 935 sharded across the devices. Parameter size = shape[0] \* ... \* 936 shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64. 937 - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specific the 938 maximum group size across devices when the parallel optimizer is 939 enabled. The numerical range can be (0, device_num]. Default value 940 is -1, which means the optimizer weight shard group size will 941 the data parallel group of each parameter. Default -1. 942 943 """ 944 self.check_context_handle() 945 grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD 946 threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD 947 optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE 948 949 for config_name in parallel_optimizer_config: 950 unknown_config = [] 951 if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name]: 952 unknown_config.append(config_name) 953 954 if unknown_config: 955 raise ValueError("Unknown config: {}".format(unknown_config)) 956 957 if grad_shard_name in parallel_optimizer_config: 958 Validator.check_bool( 959 parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name) 960 self._context_handle.set_grad_accumulation_shard( 961 parallel_optimizer_config[grad_shard_name]) 962 963 if threshold_name in parallel_optimizer_config: 964 Validator.check_non_negative_int( 965 parallel_optimizer_config[threshold_name]) 966 self._context_handle.set_parallel_optimizer_threshold( 967 parallel_optimizer_config[threshold_name]) 968 969 if optimizer_weight_shard_size_name in parallel_optimizer_config: 970 value = parallel_optimizer_config[optimizer_weight_shard_size_name] 971 Validator.check_positive_int(value) 972 self.set_optimizer_weight_shard_size(value) 973 974 def get_grad_accumulation_shard(self): 975 """Get grad accumulation shard.""" 976 self.check_context_handle() 977 return self._context_handle.get_grad_accumulation_shard() 978 979 def get_parallel_optimizer_threshold(self): 980 """Get parallel optimizer threshold.""" 981 self.check_context_handle() 982 return self._context_handle.get_parallel_optimizer_threshold() 983 984 def set_enable_alltoall(self, enable_a2a): 985 """ 986 Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll. 987 Default: ``False``. 988 989 Args: 990 enable_a2a (bool): Enable/disable AllToAll. 991 """ 992 self.check_context_handle() 993 if not isinstance(enable_a2a, bool): 994 raise TypeError("For 'set_auto_parallel_context().set_enable_alltoall', the argument 'enable_a2a' " 995 "must be bool, but got the type : {}.".format(type(enable_a2a))) 996 self._context_handle.set_enable_alltoall(enable_a2a) 997 998 def get_enable_alltoall(self): 999 """Get the value of enabling AllToAll.""" 1000 self.check_context_handle() 1001 return self._context_handle.get_enable_alltoall() 1002 1003 def set_communi_parallel_mode(self, communi_parallel_mode): 1004 """ 1005 Set communication parallel mode. 1006 1007 Args: 1008 communi_parallel_mode (str): The communication parallel mode. 1009 1010 Raises: 1011 ValueError: If parallel mode is not supported. 1012 """ 1013 if not isinstance(communi_parallel_mode, str): 1014 raise TypeError("For 'set_auto_parallel_context().set_communi_parallel_mode', " 1015 "the argument 'communi_parallel_mode' must be str, but got the type : {}." 1016 .format(type(communi_parallel_mode))) 1017 self.check_context_handle() 1018 ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode) 1019 if ret is False: 1020 raise ValueError("For 'set_auto_parallel_context().set_communi_parallel_mode', " 1021 "the argument 'communi_parallel_mode' only support 'ALL_GROUP_PARALLEL', " 1022 "'SAME_SEVER_GROUP_PARALLEL' and 'NO_GROUP_PARALLEL', " 1023 "but got the value : {}.".format(communi_parallel_mode)) 1024 1025 def get_communi_parallel_mode(self): 1026 """Get communication parallel mode.""" 1027 self.check_context_handle() 1028 return self._context_handle.get_communi_parallel_mode() 1029 1030 def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size): 1031 """ 1032 Set optimizer_weight_shard_size. 1033 1034 Args: 1035 optimizer_weight_shard_size (int): Opt shard group size when not globally use parallel 1036 optimizer across devices. 1037 """ 1038 self.check_context_handle() 1039 if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool): 1040 raise TypeError(f"The type of optimizer_weight_shard_size must be int, \ 1041 but got {type(optimizer_weight_shard_size)}.") 1042 if optimizer_weight_shard_size <= 1: 1043 logger.warning("The setting 'optimizer_weight_shard_size' is invalid. " 1044 "Please use the integer larger than 1.") 1045 return 1046 self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size) 1047 1048 def get_optimizer_weight_shard_size(self): 1049 """Get optimizer_weight_shard_size.""" 1050 self.check_context_handle() 1051 return self._context_handle.get_optimizer_weight_shard_size() 1052 1053 def set_ops_strategy_json_config(self, type, path, mode): 1054 """ 1055 Set configuration of saving ops strategy in file .json. 1056 """ 1057 self.check_context_handle() 1058 self._context_handle.set_ops_strategy_json_config(type, path, mode) 1059 1060 def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save): 1061 """ 1062 Set optimizer_weight_shard_aggregated_save. 1063 1064 Args: 1065 optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when 1066 enable parallel optimizer. 1067 """ 1068 self.check_context_handle() 1069 if not isinstance(optimizer_weight_shard_aggregated_save, bool): 1070 raise TypeError('optimizer_weight_shard_aggregated_save is invalid type') 1071 self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save) 1072 1073 def get_optimizer_weight_shard_aggregated_save(self): 1074 """Get optimizer_weight_shard_size.""" 1075 self.check_context_handle() 1076 return self._context_handle.get_optimizer_weight_shard_aggregated_save() 1077 1078 def get_full_batch_is_set(self): 1079 """Get full batch attr""" 1080 self.check_context_handle() 1081 return self._context_handle.get_full_batch_is_set() 1082 1083 def reset(self): 1084 """Reset all settings.""" 1085 self.check_context_handle() 1086 self._context_handle.reset() 1087 _ParallelFusionConfig.reset() 1088 1089 def _check_and_default_group(self, group): 1090 """Validate the given group, if group is empty, returns a default fusion group""" 1091 if isinstance(group, (str)): 1092 group_len = len(group) 1093 if group_len > _MAX_GROUP_NAME_LEN: 1094 raise ValueError('Group name len is out of range {_MAX_GROUP_NAME_LEN}') 1095 else: 1096 raise TypeError('Group must be a python str') 1097 1098 if group == "": 1099 if context.get_context("device_target") == "Ascend": 1100 group = _DEFAULT_HCCL_FUSION_GROUP_NAME 1101 else: 1102 group = _DEFAULT_NCCL_FUSION_GROUP_NAME 1103 return group 1104 1105 def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"): 1106 """ 1107 Set allgather and reducescatter fusion method for auto parallel. 1108 1109 Args: 1110 comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it 1111 supports four fusion methods: `auto` and `size`. 1112 comm_type (str): The name of the communication operator, `allgather` or `reducescatter`. 1113 1114 Raises: 1115 KeyError: When key of comm_fusion is not 'mode' or 'config'. 1116 KeyError: When `mode` is not 'auto', 'size'. 1117 """ 1118 self.check_context_handle() 1119 if comm_type == "allgather" and not self.get_enable_all_gather_fusion(): 1120 return 1121 if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion(): 1122 return 1123 if not isinstance(comm_fusion, dict): 1124 raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format( 1125 comm_type, type(comm_fusion))) 1126 if _ParallelFusionConfig.MODE not in comm_fusion: 1127 raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") 1128 if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: 1129 raise KeyError("For 'comm_fusion', the key 'config' should be contained.") 1130 check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE] 1131 if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: 1132 self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) 1133 else: 1134 raise KeyError("fusion method mode must be auto or size, but got {}".format( 1135 comm_fusion[_ParallelFusionConfig.MODE])) 1136 1137 fusion_threshold = 64 1138 if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO: 1139 fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG] 1140 self.set_fusion_threshold_mb(fusion_threshold, comm_type) 1141 1142 def _set_allreduce_comm_fusion(self, comm_fusion): 1143 """ 1144 Set fusion method for auto parallel. 1145 1146 Args: 1147 comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it 1148 supports four fusion methods: `auto`, `size` and `index`. 1149 1150 Raises: 1151 KeyError: When key of comm_fusion is not 'mode' or 'config'. 1152 KeyError: When `mode` is not 'auto', 'size' or 'index'. 1153 """ 1154 self.check_context_handle() 1155 if not self.get_enable_all_reduce_fusion(): 1156 return 1157 if not isinstance(comm_fusion, dict): 1158 raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format( 1159 type(comm_fusion))) 1160 if _ParallelFusionConfig.MODE not in comm_fusion: 1161 raise KeyError("For 'comm_fusion', the key 'mode' should be contained.") 1162 if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion: 1163 raise KeyError("For 'comm_fusion', the key 'config' should be contained.") 1164 check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE] 1165 if comm_fusion[_ParallelFusionConfig.MODE] in check_mode: 1166 self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE]) 1167 else: 1168 raise KeyError("fusion method mode must be auto, index or size, but got {}".format( 1169 comm_fusion[_ParallelFusionConfig.MODE])) 1170 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO: 1171 self.set_fusion_threshold_mb(fusion_threshold=64) 1172 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE: 1173 self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) 1174 if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX: 1175 self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]) 1176 1177 def _set_openstate_comm_fusion(self, openstate): 1178 """ 1179 Set open state for comm fusion. 1180 1181 Args: 1182 openstate (bool): The open state value to set the fusion method whether or not. Currently it 1183 supports two states: `True`, or `Flase`. 1184 1185 Raises: 1186 TypeError: When the value is not bool. 1187 """ 1188 self.check_context_handle() 1189 if not isinstance(openstate, bool): 1190 raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type : {}.".format( 1191 type(openstate))) 1192 if not openstate: 1193 self.set_enable_all_reduce_fusion(openstate) 1194 self.set_enable_all_gather_fusion(openstate) 1195 self.set_enable_reduce_scatter_fusion(openstate) 1196 1197 1198def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"): 1199 """ 1200 Set strategy json configuration. 1201 1202 Args: 1203 type (str): The parameter for choosing save or load .json file. 1204 path (str): Path to save or load parallel strategy json. 1205 mode (str): The parameter for choosing save all or important operators. 1206 1207 Raises: 1208 KeyError: When type is not 'SAVE' or 'LOAD'. 1209 KeyError: When mode is not 'all' or 'principal'. 1210 """ 1211 dir_path = os.path.dirname(path) 1212 if dir_path and not os.path.exists(dir_path): 1213 os.makedirs(dir_path) 1214 check_type = ["SAVE", "LOAD"] 1215 check_mode = ["all", "principal"] 1216 if type in check_type and mode in check_mode: 1217 auto_parallel_context().set_ops_strategy_json_config(type, path, mode) 1218 else: 1219 raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'") 1220 1221 1222_AUTO_PARALLEL_CONTEXT = None 1223 1224 1225def auto_parallel_context(): 1226 """ 1227 Get the global _AUTO_PARALLEL_CONTEXT, if it is not created, create a new one. 1228 1229 Returns: 1230 _AutoParallelContext, the global auto parallel context. 1231 """ 1232 global _AUTO_PARALLEL_CONTEXT 1233 if _AUTO_PARALLEL_CONTEXT is None: 1234 _AUTO_PARALLEL_CONTEXT = _AutoParallelContext() 1235 return _AUTO_PARALLEL_CONTEXT 1236 1237 1238_set_auto_parallel_context_func_map = { 1239 "device_num": auto_parallel_context().set_device_num, 1240 "global_rank": auto_parallel_context().set_global_rank, 1241 "gradients_mean": auto_parallel_context().set_gradients_mean, 1242 "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync, 1243 "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean, 1244 "pipeline_stages": auto_parallel_context().set_pipeline_stages, 1245 "auto_pipeline": auto_parallel_context().set_auto_pipeline, 1246 "pipeline_result_broadcast": auto_parallel_context().set_pipeline_result_broadcast, 1247 "pipeline_segments": auto_parallel_context().set_pipeline_segments, 1248 "parallel_mode": auto_parallel_context().set_parallel_mode, 1249 "search_mode": auto_parallel_context().set_strategy_search_mode, 1250 "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode, 1251 "parameter_broadcast": auto_parallel_context().set_parameter_broadcast, 1252 "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, 1253 "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, 1254 "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file, 1255 "full_batch": auto_parallel_context().set_full_batch, 1256 "dataset_strategy": auto_parallel_context().set_dataset_strategy, 1257 "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, 1258 "force_fp32_communication": auto_parallel_context().set_force_fp32_communication, 1259 "parallel_optimizer_config": auto_parallel_context().set_parallel_optimizer_config, 1260 "pipeline_config": auto_parallel_context().set_pipeline_config, 1261 "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step, 1262 "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices, 1263 "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode, 1264 "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size, 1265 "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save, 1266 "sharding_propagation": auto_parallel_context().set_sharding_propagation, 1267 "enable_alltoall": auto_parallel_context().set_enable_alltoall, 1268 "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config, 1269 "comm_fusion": auto_parallel_context().set_comm_fusion} 1270 1271_get_auto_parallel_context_func_map = { 1272 "device_num": auto_parallel_context().get_device_num, 1273 "global_rank": auto_parallel_context().get_global_rank, 1274 "gradients_mean": auto_parallel_context().get_gradients_mean, 1275 "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync, 1276 "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean, 1277 "pipeline_stages": auto_parallel_context().get_pipeline_stages, 1278 "auto_pipeline": auto_parallel_context().get_auto_pipeline, 1279 "pipeline_result_broadcast": auto_parallel_context().get_pipeline_result_broadcast, 1280 "pipeline_interleave": auto_parallel_context().get_pipeline_interleave, 1281 "pipeline_scheduler": auto_parallel_context().get_pipeline_scheduler, 1282 "parallel_mode": auto_parallel_context().get_parallel_mode, 1283 "search_mode": auto_parallel_context().get_strategy_search_mode, 1284 "auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode, 1285 "parameter_broadcast": auto_parallel_context().get_parameter_broadcast, 1286 "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, 1287 "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, 1288 "full_batch": auto_parallel_context().get_full_batch, 1289 "dataset_strategy": auto_parallel_context().get_dataset_strategy, 1290 "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, 1291 "force_fp32_communication": auto_parallel_context().get_force_fp32_communication, 1292 "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step, 1293 "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices, 1294 "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode, 1295 "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size, 1296 "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save, 1297 "sharding_propagation": auto_parallel_context().get_sharding_propagation, 1298 "enable_alltoall": auto_parallel_context().get_enable_alltoall, 1299 "comm_fusion": auto_parallel_context().get_comm_fusion, 1300 "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config, 1301 "full_batch_is_set": auto_parallel_context().get_full_batch_is_set} 1302 1303 1304@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, 1305 loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str, 1306 parameter_broadcast=bool, strategy_ckpt_load_file=str, 1307 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, 1308 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, 1309 communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool, 1310 optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict, 1311 strategy_ckpt_config=dict, force_fp32_communication=bool) 1312def _set_auto_parallel_context(**kwargs): 1313 """ 1314 Set auto parallel context. 1315 1316 Note: 1317 Attribute name is required for setting attributes. 1318 1319 Args: 1320 device_num (int): Available device number, the value must be in [1, 4096]. Default: 1. 1321 global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0. 1322 gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: ``False``. 1323 loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated 1324 calculations. Default: ``True``. 1325 gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True. 1326 Default: ``True``. 1327 parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel", 1328 "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone". 1329 1330 - stand_alone: Only one processor working. 1331 1332 - data_parallel: Distributing the data across different processors. 1333 1334 - hybrid_parallel: Achieving data parallelism and model parallelism manually. 1335 1336 - semi_auto_parallel: Achieving data parallelism and model parallelism by 1337 setting parallel strategies. 1338 1339 - auto_parallel: Achieving parallelism automatically. 1340 search_mode (str): There are two kinds of search modes: "recursive_programming", "dynamic_programming" 1341 and "sharding_propagation". Default: "dynamic_programming". 1342 1343 - recursive_programming: Recursive programming search mode. 1344 1345 - dynamic_programming: Dynamic programming search mode. 1346 1347 - sharding_propagation: Propagate shardings from configured ops to non-configured ops. 1348 auto_parallel_search_mode (str): This is the old version of 'search_mode'. Here, remaining this attribute is 1349 for forward compatibility, and this attribute will be deleted in a future MindSpore version. 1350 parameter_broadcast (bool): Indicating whether to broadcast parameters before training. 1351 "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter 1352 broadcast. Default: ``False``. 1353 strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' 1354 strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' 1355 group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: '' 1356 full_batch (bool): Whether to load the whole batch on each device. Default: ``False``. 1357 dataset_strategy Union[str, tuple]: Dataset sharding strategy. Default: "data_parallel". 1358 enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: ``False``. 1359 force_fp32_communication (bool): A switch that determines whether reduce operators (AllReduce, ReduceScatter) 1360 are forced to use the fp32 data type for communication during communication. True is the enable 1361 switch. Default: ``False`` . 1362 all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. 1363 pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how 1364 the devices are distributed alone the pipeline. The total devices will be divided into 1365 'pipeline_stags' stages. This currently could only be used when 1366 parallel mode semi_auto_parallel is enabled. Default: 0 1367 auto_pipeline (bool): Set the pipeline stage number to automatic. Its value will be selected between 1 and the 1368 parameter `pipeline_stages`. This option requires the `parallel_mode` to be ``auto_parallel`` 1369 and the `search_mode` to be ``recursive_programming``. Default: ``False`` . 1370 pipeline_result_broadcast (bool): A switch that broadcast the last stage result to all other stage in pipeline 1371 parallel inference. Default: ``False`` . 1372 communi_parallel_mode (str): There are tree kinds of communication parallel modes, "all_group_parallel", 1373 "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel". 1374 1375 - all_group_parallel: All communication groups are in parallel. 1376 1377 - same_server_group_parallel: Only the communication groups within the same server are parallel. 1378 1379 - no_group_parallel: All communication groups are not parallel. 1380 optimizer_weight_shard_size (int): Set optimizer shard group size when not fully use parallel optimizer. 1381 It should be larger than one and less than or equal with the data parallel size. 1382 Default: -1, which means fully use parallel optimizer in data parallel dimension. 1383 optimizer_weight_shard_aggregated_save (bool): Whether to integrated save weight shard when enable parallel 1384 optimizer. Default: ``False``. 1385 sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, 1386 the strategy-configured operators will propagate the strategies to other 1387 operators with minimum redistribution cost; otherwise, the algorithm will 1388 search the desired strategies. Default: ``False``. 1389 enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to 1390 circumvent AllToAll. Default: ``False``. 1391 comm_fusion (dict): A dict contains the types and configurations for setting the communication fusion. each 1392 communication fusion config has two keys: "mode" and "config". 1393 It supports following communication fusion types and configurations: 1394 1395 - openstate: Whether turn on the communication fusion or not. If `openstate` is `True`, turn on 1396 the communication fusion, otherwise, turn off the communication fusion. Default: `True`. 1397 1398 - allreduce: if communication fusion type is `allreduce`. The `mode` contains: `auto`, `size` 1399 and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default 1400 fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size 1401 manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as 1402 `all_reduce_fusion_config`. 1403 1404 - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`. 1405 In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion 1406 threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size 1407 manually, and the fusion threshold must be larger than `0` MB. 1408 1409 - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto` 1410 and `size`. Config is same as `allgather`. 1411 1412 1413 1414 Raises: 1415 ValueError: If input key is not attribute in auto parallel context. 1416 """ 1417 for key, value in kwargs.items(): 1418 if key not in _set_auto_parallel_context_func_map: 1419 raise ValueError("Set context keyword %s is not recognized!" % key) 1420 set_func = _set_auto_parallel_context_func_map[key] 1421 set_func(value) 1422 1423 1424def _get_auto_parallel_context(attr_key): 1425 """ 1426 Get auto parallel context attribute value according to the key. 1427 1428 Args: 1429 attr_key (str): The key of the attribute. 1430 1431 Returns: 1432 Return attribute value according to the key. 1433 1434 Raises: 1435 ValueError: If input key is not attribute in auto parallel context. 1436 """ 1437 if attr_key not in _get_auto_parallel_context_func_map: 1438 raise ValueError("Get context keyword %s is not recognized!" % attr_key) 1439 get_func = _get_auto_parallel_context_func_map[attr_key] 1440 return get_func() 1441 1442 1443def _reset_auto_parallel_context(): 1444 """ 1445 Reset auto parallel context attributes to the default values: 1446 1447 - device_num: 1. 1448 - global_rank: 0. 1449 - gradients_mean: False. 1450 - gradient_fp32_sync: True. 1451 - parallel_mode: "stand_alone". 1452 - parameter_broadcast: False. 1453 - strategy_ckpt_load_file: "" 1454 - strategy_ckpt_save_file: "" 1455 - enable_parallel_optimizer: False 1456 - force_fp32_communication: False 1457 - search_mode: 'recursive_programming 1458 - auto_parallel_search_mode: 'recursive_programming 1459 - sharding_propagation: False 1460 - pipeline_stages: 0 1461 - auto_pipeline: False 1462 - pipeline_result_broadcast: False 1463 - gradient_accumulation_shard: True 1464 - fusion_threshold: 64 1465 """ 1466 auto_parallel_context().reset() 1467