# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Context of auto parallel"""
16import threading
17
18import mindspore.context as context
19import mindspore.log as logger
20from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
21from mindspore.parallel._ps_context import _is_role_pserver
22from mindspore._c_expression import AutoParallelContext
23from mindspore._checkparam import args_type_check, Validator
24
25_MAX_GROUP_NAME_LEN = 127
26_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
27_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
28
29
30class _AutoParallelContext:
    """
    _AutoParallelContext is the environment in which operations are executed.

    Note:
        Creating a context by instantiating a Context object directly is not recommended.
        Use auto_parallel_context() to get the context instead, since Context is a singleton.
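
    Examples:
        A minimal usage sketch (an illustrative addition, not from the original source); it assumes
        this module has been imported so that auto_parallel_context() is in scope.

        >>> ctx = auto_parallel_context()  # singleton accessor defined later in this module
        >>> ctx.set_device_num(8)
        >>> ctx.get_device_num()
        8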
    """
    _instance = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self._context_handle = AutoParallelContext.get_instance()
        self._dataset_strategy_using_str = True

    def __new__(cls):
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check inside the lock so that concurrent first calls cannot
                # create two different instances.
                if cls._instance is None:
                    cls._instance = object.__new__(cls)
        return cls._instance

    def check_context_handle(self):
        """
        Check context handle.

        Raises:
            ValueError: If the context handle is none.
        """
        if self._context_handle is None:
            raise ValueError("The context handle is none.")

    def set_device_num(self, device_num):
        """
        Set device num for auto parallel.

        Args:
            device_num (int): The device number.

        Raises:
            ValueError: If the device num is not in [1, 4096].
        """
        self.check_context_handle()
        if device_num < 1 or device_num > 4096:
            raise ValueError("Device num must be in [1, 4096], but got {}".format(device_num))
        self._context_handle.set_device_num(device_num)

    def get_device_num(self):
        """Get device num."""
        self.check_context_handle()
        return self._context_handle.get_device_num()

    def set_global_rank(self, global_rank):
        """
        Set global rank for auto parallel.

        Args:
            global_rank (int): The global rank id.

        Raises:
            ValueError: If the global rank is not in [0, 4095].
        """
        self.check_context_handle()
        if global_rank < 0 or global_rank > 4095:
            raise ValueError("Global rank must be in [0, 4095], but got {}".format(global_rank))
        self._context_handle.set_global_rank(global_rank)

    def get_global_rank(self):
        """Get current rank id."""
        self.check_context_handle()
        return self._context_handle.get_global_rank()

    def set_pipeline_stages(self, stages):
        """Set the stages of the pipeline."""
        if isinstance(stages, bool):
            raise TypeError("The type of pipeline_stage_num must be int, but got bool.")
        if not isinstance(stages, int):
            raise TypeError("The type of pipeline_stage_num must be int.")
        if stages < 1:
            raise ValueError("pipeline_stage_num can't be less than 1.")
        backend = context.get_context("device_target")
        if backend == "GPU" and stages > 1:
            raise RuntimeError("Pipeline parallel is currently not supported on GPU.")
        self.check_context_handle()
        self._context_handle.set_pipeline_stage_split_num(stages)

    def get_pipeline_stages(self):
        """Get the stages of the pipeline."""
        self.check_context_handle()
        return self._context_handle.get_pipeline_stage_split_num()

    def set_gradients_mean(self, gradients_mean):
        """
        Set gradients_mean flag.

        Note:
            If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.

        Args:
            gradients_mean (bool): The gradients_mean flag.
        """
        self.check_context_handle()
        self._context_handle.set_gradients_mean(gradients_mean)

    def get_gradients_mean(self):
        """Get gradients_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_gradients_mean()

    def set_gradient_fp32_sync(self, gradient_fp32_sync):
        """
        Set gradient_fp32_sync.

        Note:
            If gradient_fp32_sync is true,
            it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.

        Args:
            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
        """
        self.check_context_handle()
        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)

    def get_gradient_fp32_sync(self):
        """Get gradient_fp32_sync flag."""
        self.check_context_handle()
        return self._context_handle.get_gradient_fp32_sync()

    def set_loss_repeated_mean(self, loss_repeated_mean):
        """
        Set loss_repeated_mean flag.

        Note:
            If loss_repeated_mean is true, distributed automatic differentiation will insert
            a mean operator in the backward pass for repeated computations.

        Args:
            loss_repeated_mean (bool): The loss_repeated_mean flag.
        """
        if not isinstance(loss_repeated_mean, bool):
            raise TypeError(f"The type of loss_repeated_mean must be bool, but got {type(loss_repeated_mean)}.")
        self.check_context_handle()
        self._context_handle.set_loss_repeated_mean(loss_repeated_mean)

    def get_loss_repeated_mean(self):
        """Get loss_repeated_mean flag."""
        self.check_context_handle()
        return self._context_handle.get_loss_repeated_mean()

    def set_parallel_mode(self, parallel_mode):
        """
        Set parallel mode for auto parallel.

        Args:
            parallel_mode (str): The parallel mode of auto parallel.

        Raises:
            ValueError: If parallel mode is not supported.
        """
        self.check_context_handle()
        ret = self._context_handle.set_parallel_mode(parallel_mode)
        if ret is False:
            raise ValueError("Parallel mode {} is not supported.".format(parallel_mode))

    def get_parallel_mode(self):
        """Get parallel mode."""
        self.check_context_handle()
        if _is_role_pserver():
            return context.ParallelMode.STAND_ALONE
        return self._context_handle.get_parallel_mode()

    def set_strategy_search_mode(self, auto_parallel_search_mode):
        """
        Set search mode of strategy.

        Args:
            auto_parallel_search_mode (str): The search mode of strategy.

        Raises:
            ValueError: If the search mode is not supported.
        """
        self.check_context_handle()
        ret = self._context_handle.set_strategy_search_mode(auto_parallel_search_mode)
        if ret is False:
            raise ValueError("Strategy search mode {} is not supported.".format(auto_parallel_search_mode))

    def get_strategy_search_mode(self):
        """Get search mode of strategy."""
        self.check_context_handle()
        return self._context_handle.get_strategy_search_mode()

    def set_parameter_broadcast(self, parameter_broadcast):
        """
        Set parameter broadcast.

        Args:
            parameter_broadcast (bool): Whether to broadcast parameters.
        """
        self.check_context_handle()
        self._context_handle.set_parameter_broadcast(parameter_broadcast)

    def get_parameter_broadcast(self):
        """Get parameter broadcast flag."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast()

    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
        """
        Set strategy checkpoint load path.

        Args:
            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
        """
        self.check_context_handle()
        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)

    def get_strategy_ckpt_load_file(self):
        """Get strategy checkpoint load path."""
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_load_file()

    def set_full_batch(self, full_batch):
        """
        Set whether to load the full batch on each device.

        Args:
            full_batch (bool): True if the full batch is loaded on each device.
        """
        self.check_context_handle()
        self._context_handle.set_full_batch(full_batch)

    def get_full_batch(self):
        """Get whether to load the full batch on each device."""
        self.check_context_handle()
        if _is_role_pserver():
            return False
        return self._context_handle.get_full_batch()

    def set_dataset_strategy(self, dataset_strategy):
        """
        Set dataset sharding strategy.

        Args:
            dataset_strategy (str or tuple(tuple)): The dataset sharding strategy.
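
        Examples:
            An illustrative sketch (not from the original source); it assumes the context has
            already been obtained via auto_parallel_context().

            >>> ctx = auto_parallel_context()
            >>> ctx.set_dataset_strategy("data_parallel")  # string form
            >>> ctx.get_dataset_strategy()
            'data_parallel'
            >>> ctx.set_dataset_strategy(((8, 1), (8,)))  # tuple(tuple) form: one layout per dataset input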
        """
        self.check_context_handle()
        if isinstance(dataset_strategy, str):
            if dataset_strategy not in ("full_batch", "data_parallel"):
                raise ValueError("When dataset_strategy is a string, it must be 'full_batch' or "
                                 "'data_parallel'; otherwise pass a tuple(tuple) strategy.")
            self._context_handle.set_full_batch(dataset_strategy == "full_batch")
            self._dataset_strategy_using_str = True
            return
        if not isinstance(dataset_strategy, tuple):
            raise TypeError(f'dataset_strategy must be str or tuple type, but got: {type(dataset_strategy)}')
        for ele in dataset_strategy:
            if not isinstance(ele, tuple):
                raise TypeError(f'The element of strategy must be tuple type, but got: {type(ele)}')
            for dim in ele:
                if not isinstance(dim, int):
                    raise TypeError(f'The dim of each strategy value must be int type, but got: {type(dim)}')
        self._dataset_strategy_using_str = False
        self._context_handle.set_dataset_strategy(dataset_strategy)

    def get_dataset_strategy(self):
        """Get dataset sharding strategy."""
        self.check_context_handle()
        if self._dataset_strategy_using_str:
            if self._context_handle.get_full_batch():
                return "full_batch"
            return "data_parallel"
        return self._context_handle.get_dataset_strategy()

    def set_grad_accumulation_step(self, grad_accumulation_step):
        """
        Set grad accumulation step.

        Args:
            grad_accumulation_step (int): The grad accumulation step.
        """
        self.check_context_handle()
        Validator.check_positive_int(grad_accumulation_step)
        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

    def get_grad_accumulation_step(self):
        """Get grad accumulation step."""
        self.check_context_handle()
        return self._context_handle.get_grad_accumulation_step()

    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
        """
        Set strategy checkpoint save path.

        Args:
            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
        """
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(strategy_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)

    def get_strategy_ckpt_save_file(self):
        """Get strategy checkpoint save path."""
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_save_file()

    def set_group_ckpt_save_file(self, group_ckpt_save_file):
        """Set group checkpoint save path."""
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(group_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

    def get_parameter_broadcast_is_set(self):
        """Get whether parameter broadcast is set."""
        self.check_context_handle()
        return self._context_handle.get_parameter_broadcast_is_set()

    def set_all_reduce_fusion_split_indices(self, indices, group=""):
        """
        Set allreduce fusion strategy by parameters indices.

        Args:
            indices (list): Indices list.
            group (str): The communication group of hccl/nccl.

        Raises:
            TypeError: If type of indices item is not int.
            TypeError: If group is not a python str.
            ValueError: If indices is empty, contains duplicates, or is not sorted in ascending order.
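
        Examples:
            A usage sketch (illustrative only, not from the original source); the index values
            below are arbitrary and assume the model has at least 30 trainable parameters.

            >>> ctx = auto_parallel_context()
            >>> # Fuse gradients of parameters [0, 20) into one allreduce and [20, 30) into another.
            >>> ctx.set_all_reduce_fusion_split_indices([20, 30])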
        """
        self.check_context_handle()
        if not indices:
            raise ValueError('indices cannot be empty')

        if isinstance(indices, list):
            for index in indices:
                if not isinstance(index, int) or isinstance(index, bool):
                    raise TypeError(f"The type of index must be int, but got {type(index)}.")
        else:
            raise TypeError('indices must be a python list')

        if len(set(indices)) != len(indices):
            raise ValueError('indices has duplicate elements')

        if sorted(indices) != indices:
            raise ValueError('elements in indices must be sorted in ascending order')

        new_group = self._check_and_default_group(group)

        self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
        if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
            _set_fusion_strategy_by_idx(indices)

    def get_all_reduce_fusion_split_indices(self, group=""):
        """
        Get allreduce fusion split indices.

        Args:
            group (str): The communication group of hccl/nccl.

        Returns:
            Return the split indices list according to the group.

        Raises:
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        new_group = self._check_and_default_group(group)
        return self._context_handle.get_all_reduce_fusion_split_indices(new_group)

    def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
        """
        Set allreduce fusion strategy by parameters data sizes.

        Args:
            sizes (list): Sizes list.
            group (str): The communication group of hccl/nccl.

        Raises:
            TypeError: If type of sizes item is not int.
            TypeError: If group is not a python str.
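
        Examples:
            An illustrative sketch (not from the original source); the size values are arbitrary
            and denote gradient data sizes used to group allreduce operations.

            >>> ctx = auto_parallel_context()
            >>> ctx.set_all_reduce_fusion_split_sizes([512, 512, 1024])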
        """
        self.check_context_handle()
        if isinstance(sizes, list):
            for size in sizes:
                if not isinstance(size, int) or isinstance(size, bool):
                    raise TypeError(f"The type of size must be int, but got {type(size)}.")
        else:
            raise TypeError('sizes must be a python list')

        new_group = self._check_and_default_group(group)
        self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
        if context.get_context("device_target") == "Ascend":
            _set_fusion_strategy_by_size(sizes)

    def get_all_reduce_fusion_split_sizes(self, group=""):
        """
        Get allreduce fusion split sizes.

        Args:
            group (str): The communication group of hccl/nccl.

        Returns:
            Return the split sizes list according to the group.

        Raises:
            TypeError: If group is not a python str.
        """
        self.check_context_handle()
        new_group = self._check_and_default_group(group)
        return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)

    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
        """
        Set enable/disable all reduce fusion.

        Args:
            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
        """
        self.check_context_handle()
        if not isinstance(enable_all_reduce_fusion, bool):
            raise TypeError('The type of enable_all_reduce_fusion must be bool.')
        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)

    def get_enable_all_reduce_fusion(self):
        """Get all reduce fusion flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_all_reduce_fusion()

    def get_device_num_is_set(self):
        """Get whether device number is set."""
        self.check_context_handle()
        return self._context_handle.get_device_num_is_set()

    def get_global_rank_is_set(self):
        """Get whether global rank is set."""
        self.check_context_handle()
        return self._context_handle.get_global_rank_is_set()

    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
        """
        Set enable/disable parallel optimizer.

        Args:
            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
        """
        self.check_context_handle()
        if not isinstance(enable_parallel_optimizer, bool):
            raise TypeError('The type of enable_parallel_optimizer must be bool.')
        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)

    def get_enable_parallel_optimizer(self):
        """Get parallel optimizer flag."""
        self.check_context_handle()
        return self._context_handle.get_enable_parallel_optimizer()

    def set_sharding_propagation(self, sharding_propagation):
        """
        Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the
        strategy-configured operators will propagate the strategies to other operators with
        minimum redistribution cost; otherwise, the algorithm will search the desired strategies.
        Default: False.

        Args:
            sharding_propagation (bool): Enable/disable strategy propagation.
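
        Examples:
            An illustrative sketch (not from the original source); it only toggles the flag
            and reads it back.

            >>> ctx = auto_parallel_context()
            >>> ctx.set_sharding_propagation(True)
            >>> ctx.get_sharding_propagation()
            True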
        """
        self.check_context_handle()
        if not isinstance(sharding_propagation, bool):
            raise TypeError("The type of sharding_propagation must be bool.")
        self._context_handle.set_sharding_propagation(sharding_propagation)

    def get_sharding_propagation(self):
        """Get the value of sharding strategy propagation."""
        self.check_context_handle()
        return self._context_handle.get_sharding_propagation()

    def set_enable_alltoall(self, enable_a2a):
        """
        Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
        Default: False.

        Args:
            enable_a2a (bool): Enable/disable AllToAll.
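
        Examples:
            An illustrative sketch (not from the original source).

            >>> ctx = auto_parallel_context()
            >>> ctx.set_enable_alltoall(True)
            >>> ctx.get_enable_alltoall()
            True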
        """
        self.check_context_handle()
        if not isinstance(enable_a2a, bool):
            raise TypeError("The type of enable_a2a must be bool.")
        self._context_handle.set_enable_alltoall(enable_a2a)

    def get_enable_alltoall(self):
        """Get the value of enabling AllToAll."""
        self.check_context_handle()
        return self._context_handle.get_enable_alltoall()

    def set_communi_parallel_mode(self, communi_parallel_mode):
        """
        Set communication parallel mode.

        Args:
            communi_parallel_mode (str): The communication parallel mode.

        Raises:
            ValueError: If the communication parallel mode is not supported.
        """
        if not isinstance(communi_parallel_mode, str):
            raise TypeError("The type of communi_parallel_mode must be str, "
                            f"but got {type(communi_parallel_mode)}.")
        self.check_context_handle()
        ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
        if ret is False:
            raise ValueError("Communication parallel mode {} is not supported.".format(communi_parallel_mode))

    def get_communi_parallel_mode(self):
        """Get communication parallel mode."""
        self.check_context_handle()
        return self._context_handle.get_communi_parallel_mode()

    def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
        """
        Set optimizer_weight_shard_size.

        Args:
            optimizer_weight_shard_size (int): The optimizer shard group size used when the parallel
                                               optimizer is not applied globally across devices.
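
        Examples:
            An illustrative sketch (not from the original source); the group size 2 is arbitrary
            and should not exceed the data parallel size.

            >>> ctx = auto_parallel_context()
            >>> ctx.set_optimizer_weight_shard_size(2)
            >>> ctx.get_optimizer_weight_shard_size()
            2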
        """
        self.check_context_handle()
        if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool):
            raise TypeError("The type of optimizer_weight_shard_size must be int, "
                            f"but got {type(optimizer_weight_shard_size)}.")
        if optimizer_weight_shard_size <= 1:
            logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
                           "Please use an integer larger than 1.")
            return
        self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)

    def get_optimizer_weight_shard_size(self):
        """Get optimizer_weight_shard_size."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_size()

    def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
        """
        Set optimizer_weight_shard_aggregated_save.

        Args:
            optimizer_weight_shard_aggregated_save (bool): Whether to save the aggregated weight shards
                                                           when parallel optimizer is enabled.
        """
        self.check_context_handle()
        if not isinstance(optimizer_weight_shard_aggregated_save, bool):
            raise TypeError('The type of optimizer_weight_shard_aggregated_save must be bool.')
        self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)

    def get_optimizer_weight_shard_aggregated_save(self):
        """Get optimizer_weight_shard_aggregated_save."""
        self.check_context_handle()
        return self._context_handle.get_optimizer_weight_shard_aggregated_save()

    def reset(self):
        """Reset all settings."""
        self.check_context_handle()
        self._context_handle.reset()

    def _check_and_default_group(self, group):
        """Validate the given group; if group is empty, return a default fusion group."""
        if isinstance(group, str):
            group_len = len(group)
            if group_len > _MAX_GROUP_NAME_LEN:
                raise ValueError(f'Group name length exceeds {_MAX_GROUP_NAME_LEN}.')
        else:
            raise TypeError('Group must be a python str')

        if group == "":
            if context.get_context("device_target") == "Ascend":
                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
            else:
                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
        return group


_auto_parallel_context = None


def auto_parallel_context():
    """
    Get the global _auto_parallel_context. If it has not been created, create a new one.

    Returns:
        _AutoParallelContext, the global auto parallel context.
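
    Examples:
        An illustrative sketch (not from the original source); repeated calls return the same
        singleton instance.

        >>> ctx1 = auto_parallel_context()
        >>> ctx2 = auto_parallel_context()
        >>> ctx1 is ctx2
        True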
    """
    global _auto_parallel_context
    if _auto_parallel_context is None:
        _auto_parallel_context = _AutoParallelContext()
    return _auto_parallel_context


_set_auto_parallel_context_func_map = {
    "device_num": auto_parallel_context().set_device_num,
    "global_rank": auto_parallel_context().set_global_rank,
    "gradients_mean": auto_parallel_context().set_gradients_mean,
    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
    "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
    "pipeline_stages": auto_parallel_context().set_pipeline_stages,
    "parallel_mode": auto_parallel_context().set_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().set_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
    "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
    "full_batch": auto_parallel_context().set_full_batch,
    "dataset_strategy": auto_parallel_context().set_dataset_strategy,
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
    "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
    "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
    "sharding_propagation": auto_parallel_context().set_sharding_propagation,
    "enable_alltoall": auto_parallel_context().set_enable_alltoall}


_get_auto_parallel_context_func_map = {
    "device_num": auto_parallel_context().get_device_num,
    "global_rank": auto_parallel_context().get_global_rank,
    "gradients_mean": auto_parallel_context().get_gradients_mean,
    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
    "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
    "pipeline_stages": auto_parallel_context().get_pipeline_stages,
    "parallel_mode": auto_parallel_context().get_parallel_mode,
    "auto_parallel_search_mode": auto_parallel_context().get_strategy_search_mode,
    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
    "full_batch": auto_parallel_context().get_full_batch,
    "dataset_strategy": auto_parallel_context().get_dataset_strategy,
    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
    "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
    "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
    "sharding_propagation": auto_parallel_context().get_sharding_propagation,
    "enable_alltoall": auto_parallel_context().get_enable_alltoall}


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
                 communi_parallel_mode=str, optimizer_weight_shard_size=int,
                 optimizer_weight_shard_aggregated_save=bool,
                 sharding_propagation=bool, enable_alltoall=bool)
def _set_auto_parallel_context(**kwargs):
    """
    Set auto parallel context.

    Note:
        Attribute name is required for setting attributes.

    Args:
        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: False.
        loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
                        calculations. Default: True.
        gradient_fp32_sync (bool): Gradients are allreduced in fp32 even if they are fp16 when this flag
                        is True. Default: True.
        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
                     "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".

                     - stand_alone: Only one processor is working.

                     - data_parallel: Distributing the data across different processors.

                     - hybrid_parallel: Achieving data parallelism and model parallelism manually.

                     - semi_auto_parallel: Achieving data parallelism and model parallelism by
                       setting parallel strategies.

                     - auto_parallel: Achieving parallelism automatically.
        auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
                     and "dynamic_programming". Default: "dynamic_programming".

                     - recursive_programming: Recursive programming search mode.

                     - dynamic_programming: Dynamic programming search mode.
        parameter_broadcast (bool): Whether to broadcast parameters before training.
                       "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
                       broadcast. Default: False.
        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''.
        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''.
        group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''.
        full_batch (bool): Whether to load the whole batch on each device. Default: False.
        dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
        enable_parallel_optimizer (bool): Whether to enable optimizer segmentation. Default: False.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
                        the devices are distributed along the pipeline. The total devices will be divided
                        into 'pipeline_stages' stages. Currently this can only be used when the parallel
                        mode is semi_auto_parallel. Default: 0.
        communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
                     "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".

                     - all_group_parallel: All communication groups are in parallel.

                     - same_server_group_parallel: Only the communication groups within the same server are parallel.

                     - no_group_parallel: All communication groups are not parallel.
        optimizer_weight_shard_size (int): Set the optimizer shard group size when the parallel optimizer
                                    is not fully used. It should be larger than one and less than or equal
                                    to the data parallel size. Default: -1, which means the parallel
                                    optimizer is fully used along the data parallel dimension.
        optimizer_weight_shard_aggregated_save (bool): Whether to save the aggregated weight shards when
                                                       parallel optimizer is enabled. Default: False.
        sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
                                    the strategy-configured operators will propagate the strategies to other
                                    operators with minimum redistribution cost; otherwise, the algorithm will
                                    search the desired strategies. Default: False.
        enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
                                circumvent AllToAll. Default: False.

    Raises:
        ValueError: If the input key is not an attribute of the auto parallel context.
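
    Examples:
        An illustrative sketch (not from the original source); the keyword values are arbitrary
        and only use keys defined in _set_auto_parallel_context_func_map above.

        >>> _set_auto_parallel_context(device_num=8, global_rank=0,
        ...                            parallel_mode="semi_auto_parallel", gradients_mean=True)
        >>> _get_auto_parallel_context("device_num")
        8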
    """
    for key, value in kwargs.items():
        if key not in _set_auto_parallel_context_func_map:
            raise ValueError("Set context keyword %s is not recognized!" % key)
        set_func = _set_auto_parallel_context_func_map[key]
        set_func(value)


def _get_auto_parallel_context(attr_key):
    """
    Get auto parallel context attribute value according to the key.

    Args:
        attr_key (str): The key of the attribute.

    Returns:
        Return the attribute value according to the key.

    Raises:
        ValueError: If the input key is not an attribute of the auto parallel context.
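
    Examples:
        An illustrative sketch (not from the original source); it reads back a previously set value.

        >>> _set_auto_parallel_context(gradients_mean=True)
        >>> _get_auto_parallel_context("gradients_mean")
        True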
    """
    if attr_key not in _get_auto_parallel_context_func_map:
        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
    get_func = _get_auto_parallel_context_func_map[attr_key]
    return get_func()


def _reset_auto_parallel_context():
    """
    Reset auto parallel context attributes to the default values:

    - device_num: 1.
    - global_rank: 0.
    - gradients_mean: False.
    - gradient_fp32_sync: True.
    - parallel_mode: "stand_alone".
    - parameter_broadcast: False.
    - strategy_ckpt_load_file: "".
    - strategy_ckpt_save_file: "".
    - enable_parallel_optimizer: False.
    - auto_parallel_search_mode: dynamic_programming.
    - pipeline_stages: 0.
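
    Examples:
        An illustrative sketch (not from the original source); the restored default for
        device_num follows the list above.

        >>> _set_auto_parallel_context(device_num=8)
        >>> _reset_auto_parallel_context()
        >>> _get_auto_parallel_context("device_num")
        1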
    """
    auto_parallel_context().reset()