1# Copyright 2020-2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Context of auto parallel"""
16from __future__ import absolute_import
17import os
18import copy
19import threading
20from mindspore import context
21import mindspore.log as logger
22from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
23from mindspore.parallel._ps_context import _is_role_pserver
24from mindspore._c_expression import AutoParallelContext
25from mindspore._checkparam import args_type_check
26from mindspore import _checkparam as Validator
27
28_MAX_GROUP_NAME_LEN = 127
29_DEFAULT_HCCL_FUSION_GROUP_NAME = "hccl_world_groupsum1"
30_DEFAULT_NCCL_FUSION_GROUP_NAME = "nccl_world_groupsum1"
31
32
33class _ParallelFusionConfig:
34    """
35    The keys of the parallel fusion method configuration.
36    """
37    ALLREDUCE = "allreduce"
38    ALLGATHER = "allgather"
39    REDUCESCATTER = "reducescatter"
40    MODE = "mode"
41    FUSION_CONFIG = "config"
42    AUTO = "auto"
43    INDEX = "index"
44    SIZE = "size"
45    OPENSTATE = "openstate"
46    CONFIG = {"openstate": True,
47              "allreduce": {"mode": "auto", "config": None},
48              "allgather": {"mode": "auto", "config": None},
49              "reducescatter": {"mode": "auto", "config": None}}
50
51    @classmethod
52    def reset(cls):
53        cls.CONFIG = {"openstate": True,
54                      "allreduce": {"mode": "auto", "config": None},
55                      "allgather": {"mode": "auto", "config": None},
56                      "reducescatter": {"mode": "auto", "config": None}}
57
58
59class _ParallelOptimizerConfig:
60    """
61    The keys of the parallel optimizer configuration. There are three supported keys, defined below.
62    """
63    GRADIENT_ACCUMULATION_SHARD = "gradient_accumulation_shard"
64    PARALLEL_OPTIMIZER_THRESHOLD = "parallel_optimizer_threshold"
65    OPTIMIZER_WEIGHT_SHARD_SIZE = "optimizer_weight_shard_size"
66
67
68class _PipelineConfig:
69    """
70    The keys of the pipeline parallelism configuration.
71    """
72    PIPELINE_INTERLEAVE = "pipeline_interleave"
73    PIPELINE_SCHEDULER = "pipeline_scheduler"
74
75
76class _PipelineScheduler:
77    PIPELINE_1F1B = "1f1b"
78    PIPELINE_GPIPE = "gpipe"
79
80
81class _AutoParallelContext:
82    """
83    _AutoParallelContext is the environment in which operations are executed.
84
85    Note:
86        Creating a context by instantiating the Context object is not recommended.
87        Use auto_parallel_context() to get the context instead, since Context is a singleton.
88    """
89    _instance = None
90    _instance_lock = threading.Lock()
91
92    def __new__(cls):
93        if cls._instance is None:
94            with cls._instance_lock:
95                if cls._instance is None:  # re-check inside the lock so only one instance is ever created
96                    cls._instance = object.__new__(cls)
97        return cls._instance
98
99    def __init__(self):
100        self._context_handle = AutoParallelContext.get_instance()
101        self._dataset_strategy_using_str = True
102
103    def check_context_handle(self):
104        """
105        Check context handle.
106
107        Raises:
108            ValueError: If the context handle is none.
109        """
110        if self._context_handle is None:
111            raise ValueError("Context handle is none in context!!!")
112
113    def set_device_num(self, device_num):
114        """
115        Set device num for auto parallel.
116
117        Args:
118            device_num (int): The device number.
119
120        Raises:
121            ValueError: If the device num is not a positive integer.
122        """
123        self.check_context_handle()
124        if device_num < 1:
125            raise ValueError("The context configuration parameter 'device_num' must be a positive integer, "
126                             "but got the value of device_num : {}.".format(device_num))
127        from mindspore.communication._comm_helper import _HCCL_TEST_AVAILABLE
128        self._context_handle.set_hccl_test_avaible(_HCCL_TEST_AVAILABLE)
129        self._context_handle.set_device_num(device_num)
130
131    def get_device_num(self):
132        """Get device num."""
133        self.check_context_handle()
134        return self._context_handle.get_device_num()
135
136    def set_comm_fusion(self, config):
137        """
138        Set fusion method for auto parallel.
139
140        Args:
141            config (dict): A dict contains the methods and values for setting the communication fusion. Currently it
142            supports: `openstate`, `allreduce`, `allgather` and `reducescatter`.
143
144        Raises:
145            KeyError: When a key of comm_fusion is not 'openstate', 'allreduce', 'allgather' or 'reducescatter'.
146        """
147        self.check_context_handle()
148        config = copy.deepcopy(config)
149        if _ParallelFusionConfig.OPENSTATE not in config.keys():
150            config[_ParallelFusionConfig.OPENSTATE] = True
151        for key in list(config.keys()):
152            if key == _ParallelFusionConfig.ALLREDUCE:
153                self._set_allreduce_comm_fusion(config[key])
154            elif key == _ParallelFusionConfig.ALLGATHER:
155                self._set_allgather_comm_fusion(config[key], key)
156            elif key == _ParallelFusionConfig.REDUCESCATTER:
157                self._set_allgather_comm_fusion(config[key], key)
158            elif key == _ParallelFusionConfig.OPENSTATE:
159                self._set_openstate_comm_fusion(config[key])
160            else:
161                raise KeyError("comm fusion type must be openstate,"
162                               "allreduce, allgather or reducescatter, but got {}".format(key))
163            if key in _ParallelFusionConfig.CONFIG:
164                _ParallelFusionConfig.CONFIG[key] = config[key]
165
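    # A minimal usage sketch for set_comm_fusion(), assuming distributed communication has already
    # been initialized; the threshold value below is an arbitrary example:
    #
    #     auto_parallel_context().set_comm_fusion(
    #         {"openstate": True,
    #          "allreduce": {"mode": "size", "config": 32},
    #          "allgather": {"mode": "auto", "config": None},
    #          "reducescatter": {"mode": "auto", "config": None}})
    #     auto_parallel_context().get_comm_fusion()
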
166    def get_comm_fusion(self):
167        """Get comm fusion config."""
168        self.check_context_handle()
169        return _ParallelFusionConfig.CONFIG
170
171    def set_fusion_threshold_mb(self, fusion_threshold=64, comm_type="allreduce"):
172        """
173        Set fusion threshold (MB) for auto parallel.
174
175        Args:
176            fusion_threshold (int): The fusion threshold (unit: MB). Default: 64.
177            comm_type (str): The name of the communication operator, `allreduce`, `allgather` or `reducescatter`.
178
179        Raises:
180            ValueError: If the fusion threshold is a negative value.
181        """
182        self.check_context_handle()
183        if fusion_threshold < 0:
184            raise ValueError("fusion threshold must not be less than 0, but got {}".format(fusion_threshold))
185
186        if comm_type == _ParallelFusionConfig.ALLREDUCE:
187            self._context_handle.set_fusion_threshold_mb(fusion_threshold)
188        if comm_type == _ParallelFusionConfig.ALLGATHER:
189            self._context_handle.set_allgather_fusion_threshold_mb(fusion_threshold)
190        if comm_type == _ParallelFusionConfig.REDUCESCATTER:
191            self._context_handle.set_reducescatter_fusion_threshold_mb(fusion_threshold)
192
193    def fusion_threshold_mb(self):
194        """Get all reduce threshold."""
195        self.check_context_handle()
196        return self._context_handle.fusion_threshold_mb()
197
198    def allgather_fusion_threshold_mb(self):
199        """Get allgather threshold."""
200        self.check_context_handle()
201        return self._context_handle.allgather_fusion_threshold_mb()
202
203    def reducescatter_fusion_threshold_mb(self):
204        """Get reducescatter threshold."""
205        self.check_context_handle()
206        return self._context_handle.reducescatter_fusion_threshold_mb()
207
208    def set_global_rank(self, global_rank):
209        """
210        Set global rank for auto parallel.
211
212        Args:
213            global_rank (int): The rank id of current rank.
214
215        Raises:
216            ValueError: If the global rank is not in [0, 4095].
217        """
218        self.check_context_handle()
219        if global_rank < 0 or global_rank > 4095:
220            raise ValueError("The context configuration parameter 'global_rank' must be in [0, 4095], "
221                             "but got the value of global_rank : {}.".format(global_rank))
222        self._context_handle.set_global_rank(global_rank)
223
224    def get_global_rank(self):
225        """Get current rank id."""
226        self.check_context_handle()
227        return self._context_handle.get_global_rank()
228
229    def set_pipeline_stages(self, stages):
230        """Set the stages of the pipeline"""
231        if isinstance(stages, bool) or not isinstance(stages, int):
232            raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
233                            "must be int, but got the type : {}.".format(type(stages)))
234        if stages < 1:
235            raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_stages' "
236                             "should be greater or equal 1, but got the value of stages : {}.".format(stages))
237        self.check_context_handle()
238        self._context_handle.set_pipeline_stage_split_num(stages)
239
240    def get_pipeline_stages(self):
241        """Get the stages of the pipeline"""
242        self.check_context_handle()
243        return self._context_handle.get_pipeline_stage_split_num()
244
245    def set_auto_pipeline(self, auto_pipeline):
246        """Set the pipeline stage number to automatic"""
247        if not isinstance(auto_pipeline, bool):
248            raise TypeError("For 'set_auto_parallel_context', the argument 'auto_pipeline' "
249                            "must be bool, but got the type : {}.".format(type(auto_pipeline)))
250        self.check_context_handle()
251        self._context_handle.set_auto_pipeline(auto_pipeline)
252
253    def get_auto_pipeline(self):
254        """Get whether the pipeline stage number is automatic"""
255        self.check_context_handle()
256        return self._context_handle.get_auto_pipeline()
257
258    def set_pipeline_result_broadcast(self, pipeline_result_broadcast):
259        """
260        Set the value of enabling pipeline result broadcast. Default: ``False``.
261
262        Args:
263            pipeline_result_broadcast (bool): Enable/disable broadcast the last stage result to all other stages.
264        """
265        self.check_context_handle()
266        if not isinstance(pipeline_result_broadcast, bool):
267            raise TypeError("For 'set_auto_parallel_context().set_pipeline_result_broadcast', the argument "
268                            "'pipeline_result_broadcast' must be bool, but got the type : {}."
269                            .format(type(pipeline_result_broadcast)))
270        self._context_handle.set_pipeline_result_broadcast(pipeline_result_broadcast)
271
272    def get_pipeline_result_broadcast(self):
273        """Get the value of enabling pipeline result broadcast"""
274        self.check_context_handle()
275        return self._context_handle.get_pipeline_result_broadcast()
276
277    def get_pipeline_interleave(self):
278        """Get pipeline interleave flag"""
279        self.check_context_handle()
280        return self._context_handle.get_pipeline_interleave()
281
282    def get_pipeline_scheduler(self):
283        """Get pipeline scheduler"""
284        self.check_context_handle()
285        return self._context_handle.get_pipeline_scheduler()
286
287    def set_pipeline_segments(self, segments):
288        """Set the segments of the pipeline"""
289        if isinstance(segments, bool) or not isinstance(segments, int):
290            raise TypeError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
291                            "must be int, but got the type : {}.".format(type(segments)))
292        if segments < 1:
293            raise ValueError("For 'set_auto_parallel_context', the argument 'pipeline_segments' "
294                             "should be greater or equal 1, but got the value of segments : {}.".format(segments))
295        self.check_context_handle()
296        self._context_handle.set_pipeline_segment_split_num(segments)
297
298    def get_pipeline_segments(self):
299        """Get the stages of the pipeline"""
300        self.check_context_handle()
301        return self._context_handle.get_pipeline_segment_split_num()
302
303    def set_gradients_mean(self, gradients_mean):
304        """
305        Set gradients_mean flag.
306
307        Note:
308            If gradients_mean is true, it will insert a div operator after parameter gradients allreduce.
309
310        Args:
311            gradients_mean (bool): The gradients_mean flag.
312        """
313        self.check_context_handle()
314        self._context_handle.set_gradients_mean(gradients_mean)
315
316    def get_gradients_mean(self):
317        """Get gradients_mean flag."""
318        self.check_context_handle()
319        return self._context_handle.get_gradients_mean()
320
321    def set_gradient_fp32_sync(self, gradient_fp32_sync):
322        """
323        Set gradient_fp32_sync.
324
325        Note:
326            If gradient_fp32_sync is true,
327            it will convert tensor type from fp16 to fp32 before parameter gradients allreduce.
328
329        Args:
330            gradient_fp32_sync (bool): The gradient_fp32_sync flag.
331        """
332        self.check_context_handle()
333        self._context_handle.set_gradient_fp32_sync(gradient_fp32_sync)
334
335    def get_gradient_fp32_sync(self):
336        """Get gradient_fp32_sync flag."""
337        self.check_context_handle()
338        return self._context_handle.get_gradient_fp32_sync()
339
340    def set_loss_repeated_mean(self, loss_repeated_mean):
341        """
342        Set loss_repeated_mean flag.
343
344        Note:
345            If loss_repeated_mean is true,
346            distributed automatic differentiation will perform a mean operator
347            in the backward pass in the case of repeated calculations.
348
349        Args:
350            loss_repeated_mean (bool): The loss_repeated_mean flag.
351        """
352        if not isinstance(loss_repeated_mean, bool):
353            raise TypeError("For 'set_auto_parallel_context', the argument 'loss_repeated_mean' "
354                            "must be bool, but got the type : {}.".format(type(loss_repeated_mean)))
355        self.check_context_handle()
356        self._context_handle.set_loss_repeated_mean(loss_repeated_mean)
357
358    def get_loss_repeated_mean(self):
359        """Get loss_repeated_mean flag."""
360        self.check_context_handle()
361        return self._context_handle.get_loss_repeated_mean()
362
363    def set_parallel_mode(self, parallel_mode):
364        """
365        Set parallel mode for auto parallel.
366
367        Args:
368            parallel_mode (str): The parallel mode of auto parallel.
369
370        Raises:
371            ValueError: If parallel mode is not supported.
372        """
373        self.check_context_handle()
374        run_mode = context.get_context("mode")
375        if run_mode == context.PYNATIVE_MODE and parallel_mode not in (
376                context.ParallelMode.DATA_PARALLEL, context.ParallelMode.STAND_ALONE,
377                context.ParallelMode.AUTO_PARALLEL):
378            raise ValueError(f"Pynative only supports STAND_ALONE, DATA_PARALLEL and AUTO_PARALLEL using"
379                             f" sharding_propagation under shard function"
380                             f" for ParallelMode, "
381                             f"but got {parallel_mode.upper()}.")
382        ret = self._context_handle.set_parallel_mode(parallel_mode)
383        if ret is False:
384            raise ValueError("The context configuration parameter 'parallel_mode' only support 'stand_alone', "
385                             "'data_parallel', 'hybrid_parallel', 'semi_auto_parallel' and 'auto_parallel', "
386                             "but got the value : {}.".format(parallel_mode))
387
388    def get_parallel_mode(self):
389        """Get parallel mode."""
390        self.check_context_handle()
391        return self._context_handle.get_parallel_mode()
392
393    def set_strategy_search_mode(self, search_mode):
394        """
395        Set search mode of strategy.
396
397        Args:
398            search_mode (str): The search mode of strategy.
399        """
400        self.check_context_handle()
401        ret = self._context_handle.set_strategy_search_mode(search_mode)
402        if ret is False:
403            raise ValueError("The context configuration parameter 'auto_parallel_search_mode' only support "
404                             "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', "
405                             "but got the value: {}."
406                             .format(search_mode))
407
408    def get_strategy_search_mode(self):
409        """Get search mode of strategy."""
410        self.check_context_handle()
411        return self._context_handle.get_strategy_search_mode()
412
413    def set_auto_parallel_search_mode(self, search_mode):
414        """
415        Set search mode of strategy searching. This is the old version of 'search_mode', and will be deleted in a future
416        MindSpore version.
417
418        Args:
419            search_mode (str): The search mode of strategy.
420        """
421        logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
422                       "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
423        self.check_context_handle()
424        ret = self._context_handle.set_strategy_search_mode(search_mode)
425        if ret is False:
426            raise ValueError("The context configuration parameter 'search_mode' only support "
427                             "'recursive_programming', 'dynamic_programming' and 'sharding_propagation', "
428                             "but got the value: {}."
429                             .format(search_mode))
430
431    def get_auto_parallel_search_mode(self):
432        """Get search mode of strategy. This is the old version of 'search_mode', and will be deleted in a future
433        MindSpore version.
434        """
435        logger.warning("The attribute 'auto_parallel_search_mode' is currently replaced by 'search_mode'. "
436                       "The attribute 'auto_parallel_search_mode' will be deleted in a future MindSpore version.")
437        self.check_context_handle()
438        return self._context_handle.get_strategy_search_mode()
439
440    def set_sharding_propagation(self, sharding_propagation):
441        """
442        Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True, the strategy-configured operators
443        will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm
444        will search the desired strategies. Default: ``False``.
445        This attribute is replaced by context.set_auto_parallel_context(search_mode="sharding_propagation").
446
447        Args:
448            sharding_propagation (bool): Enable/disable strategy propagation.
449        """
450        logger.warning("This attribute is replaced by "
451                       "context.set_auto_parallel_context(search_mode='sharding_propagation'), and this attribute will"
452                       " be deleted in a future MindSpore version.")
453        self.check_context_handle()
454        if not isinstance(sharding_propagation, bool):
455            raise TypeError("For 'set_auto_parallel_context().set_sharding_propagation', "
456                            "the argument 'sharding_propagation' must be bool, but got the type : {}."
457                            .format(type(sharding_propagation)))
458        self._context_handle.set_sharding_propagation(sharding_propagation)
459
460    def get_sharding_propagation(self):
461        """Get the value of sharding strategy propagation."""
462        self.check_context_handle()
463        return self._context_handle.get_sharding_propagation()
464
465    def set_parameter_broadcast(self, parameter_broadcast):
466        """
467        Set parameter broadcast.
468
469        Args:
470            parameter_broadcast (bool): Parameter broadcast or not.
471        """
472        self.check_context_handle()
473        self._context_handle.set_parameter_broadcast(parameter_broadcast)
474
475    def get_parameter_broadcast(self):
476        """Get parameter broadcast flag."""
477        self.check_context_handle()
478        return self._context_handle.get_parameter_broadcast()
479
480    def set_strategy_ckpt_load_file(self, strategy_ckpt_load_file):
481        """
482        Set strategy checkpoint load path.
483
484        Args:
485            strategy_ckpt_load_file (str): Path to load parallel strategy checkpoint.
486        """
487        self.check_context_handle()
488        self._context_handle.set_strategy_ckpt_load_file(strategy_ckpt_load_file)
489
490    def get_strategy_ckpt_load_file(self):
491        """Get strategy checkpoint load path."""
492        self.check_context_handle()
493        return self._context_handle.get_strategy_ckpt_load_file()
494
495    def set_full_batch(self, full_batch):
496        """
497        Set whether to load the full batch on each device.
498
499        Args:
500            full_batch (bool): True to load the full batch on each device.
501        """
502        self.check_context_handle()
503        self._context_handle.set_full_batch(full_batch)
504
505    def get_full_batch(self):
506        """Get whether load full batch on each device."""
507        self.check_context_handle()
508        if _is_role_pserver():
509            return False
510        return self._context_handle.get_full_batch()
511
512    def set_dataset_strategy(self, dataset_strategy):
513        """
514        Set dataset sharding strategy.
515
516        Args:
517            dataset_strategy (str or tuple(tuple)): The dataset sharding strategy.
518        """
519        self.check_context_handle()
520        if isinstance(dataset_strategy, str):
521            if dataset_strategy not in ("full_batch", "data_parallel"):
522                raise ValueError("For 'set_auto_parallel_context', the argument "
523                                 "'dataset_strategy' must be 'full_batch' or 'data_parallel', but got the value : {}."
524                                 .format(dataset_strategy))
525            self._context_handle.set_full_batch(dataset_strategy == "full_batch")
526            self._dataset_strategy_using_str = True
527            return
528        if not isinstance(dataset_strategy, tuple):
529            raise TypeError("For 'set_auto_parallel_context', the argument 'dataset_strategy' "
530                            "must be str or tuple type, but got the type : {}.".format(type(dataset_strategy)))
531        for ele in dataset_strategy:
532            if not isinstance(ele, tuple):
533                raise TypeError("For 'set_auto_parallel_context', the element of argument "
534                                "'dataset_strategy' must be tuple, but got the type : {} .".format(type(ele)))
535            for dim in ele:
536                if not isinstance(dim, int):
537                    raise TypeError("For 'set_auto_parallel_context', the element of argument "
538                                    "'dataset_strategy' must be int type, but got the type : {} .".format(type(dim)))
539        if context.get_context('mode') == context.PYNATIVE_MODE:
540            raise ValueError("In PyNative mode, the setting value of 'dataset_strategy' must be either 'full_batch' "
541                             f"or 'data_parallel', but got {dataset_strategy}.")
542        self._dataset_strategy_using_str = False
543        self._context_handle.set_dataset_strategy(dataset_strategy)
544
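    # A minimal usage sketch for set_dataset_strategy(); the tuple shapes are arbitrary examples
    # for a two-input dataset sharded along the batch dimension, and the tuple form is only
    # accepted in graph mode:
    #
    #     auto_parallel_context().set_dataset_strategy("data_parallel")
    #     auto_parallel_context().set_dataset_strategy(((8, 1), (8,)))  # one strategy per dataset input
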
545    def get_dataset_strategy(self):
546        """Get dataset sharding strategy."""
547        self.check_context_handle()
548        if self._dataset_strategy_using_str:
549            if self._context_handle.get_full_batch():
550                return "full_batch"
551            return "data_parallel"
552        dataset_strategy = self._context_handle.get_dataset_strategy()
553        if context.get_context('mode') == context.PYNATIVE_MODE:
554            raise ValueError("In PyNative mode, the value of 'dataset_strategy' must be either 'full_batch' "
555                             f"or 'data_parallel', but got the setting value is {dataset_strategy}.")
556        return dataset_strategy
557
558    def set_grad_accumulation_step(self, grad_accumulation_step):
559        """
560        Set grad accumulation step.
561
562        Args:
563            grad_accumulation_step (int): The grad accumulation step.
564        """
565        if grad_accumulation_step > 1:
566            raise ValueError("The interface is deprecated. To use gradient accumulation, "
567                             "please use GradAccumulationCell in mindspore.nn.wrap.cell_wrapper.")
568        self.check_context_handle()
569        Validator.check_positive_int(grad_accumulation_step)
570        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
571
572    def get_grad_accumulation_step(self):
573        """Get grad accumulation step."""
574        self.check_context_handle()
575        return self._context_handle.get_grad_accumulation_step()
576
577    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
578        """
579        Set strategy checkpoint save path.
580
581        Args:
582            strategy_ckpt_save_file (str): Path to save parallel strategy checkpoint.
583        """
584        self.check_context_handle()
585        dir_path = os.path.dirname(strategy_ckpt_save_file)
586        if dir_path and not os.path.exists(dir_path):
587            os.makedirs(dir_path, exist_ok=True)
588        self._context_handle.set_strategy_ckpt_save_file(strategy_ckpt_save_file)
589
590    def get_strategy_ckpt_save_file(self):
591        """Get strategy checkpoint save path."""
592        self.check_context_handle()
593        return self._context_handle.get_strategy_ckpt_save_file()
594
595    def set_strategy_ckpt_config(self, strategy_ckpt_config):
596        """
597        Set strategy checkpoint config.
598
599        Args:
600            strategy_ckpt_config (dict): The strategy checkpoint config.
601        """
602        self.check_context_handle()
603        if not isinstance(strategy_ckpt_config, dict):
604            raise TypeError("For 'set_auto_parallel_context', the argument 'strategy_ckpt_config' "
605                            "must be dict, but got the type : {}.".format(type(strategy_ckpt_config)))
606        for config_name in strategy_ckpt_config:
607            unknown_config = []
608            if config_name not in ["load_file", "save_file", "only_trainable_params"]:
609                unknown_config.append(config_name)
610
611            if unknown_config:
612                raise ValueError("Unknown config: {}".format(unknown_config))
613        if "load_file" in strategy_ckpt_config:
614            load_file = strategy_ckpt_config.get("load_file")
615            if not isinstance(load_file, str):
616                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
617                                "the argument 'load_file' must be str, but got the type : {} .".format(type(load_file)))
618            self._context_handle.set_strategy_ckpt_load_file(load_file)
619        if "save_file" in strategy_ckpt_config:
620            save_file = strategy_ckpt_config.get("save_file")
621            if not isinstance(save_file, str):
622                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
623                                "the argument 'save_file' must be str, but got the type : {} .".format(type(save_file)))
624            self._context_handle.set_strategy_ckpt_save_file(save_file)
625        if "only_trainable_params" in strategy_ckpt_config:
626            only_trainable_params = strategy_ckpt_config.get("only_trainable_params")
627            if not isinstance(only_trainable_params, bool):
628                raise TypeError("For 'set_auto_parallel_context().set_strategy_ckpt_config', "
629                                "the argument 'only_trainable_params' must be bool,"
630                                " but got the type : {} .".format(type(only_trainable_params)))
631            self._context_handle.set_stra_file_only_trainable_params(only_trainable_params)
632
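    # A minimal usage sketch for set_strategy_ckpt_config(); the file path is an arbitrary example:
    #
    #     auto_parallel_context().set_strategy_ckpt_config(
    #         {"save_file": "./strategy.ckpt", "only_trainable_params": True})
    #     auto_parallel_context().get_strategy_ckpt_config()
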
633    def get_strategy_ckpt_config(self):
634        """Get strategy checkpoint config."""
635        self.check_context_handle()
636        load_file = self._context_handle.get_strategy_ckpt_load_file()
637        save_file = self._context_handle.get_strategy_ckpt_save_file()
638        only_trainable_param = self._context_handle.get_stra_file_only_trainable_params()
639        return {"load_file": load_file, "save_file": save_file, "only_trainable_params": only_trainable_param}
640
641    def set_group_ckpt_save_file(self, group_ckpt_save_file):
642        """Set group checkpoint save path."""
643        self.check_context_handle()
644        dir_path = os.path.dirname(group_ckpt_save_file)
645        if dir_path and not os.path.exists(dir_path):
646            os.makedirs(dir_path)
647        self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
648
649    def get_parameter_broadcast_is_set(self):
650        """Get parameter broadcast is set or not."""
651        self.check_context_handle()
652        return self._context_handle.get_parameter_broadcast_is_set()
653
654    def set_all_reduce_fusion_split_indices(self, indices, group=""):
655        """
656        Set allreduce fusion strategy by parameters indices.
657
658        Args:
659            indices (list): Indices list.
660            group (str): The communication group of hccl/nccl.
661
662        Raises:
663            TypeError: If type of indices item is not int.
664            TypeError: If group is not a python str.
665        """
666        self.check_context_handle()
667        if not indices:
668            raise ValueError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
669                             "the argument 'indices' can not be empty")
670
671        if isinstance(indices, (list)):
672            for index in indices:
673                if not isinstance(index, int) or isinstance(index, bool):
674                    raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
675                                    "the argument 'index' must be int, but got the type : {} .".format(type(index)))
676        else:
677            raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
678                            "the argument 'indices' must be list, but got the type : {} .".format(type(indices)))
679
680        if len(set(indices)) != len(indices):
681            raise ValueError("The indices has duplicate elements")
682
683        if sorted(indices) != indices:
684            raise ValueError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_indices', "
685                             "the elements in argument 'indices' must be sorted in ascending order")
686
687        new_group = self._check_and_default_group(group)
688
689        self._context_handle.set_all_reduce_fusion_split_indices(indices, new_group)
690        if context.get_context("device_target") == "Ascend" and context.get_context("enable_ge"):
691            _set_fusion_strategy_by_idx(indices)
692
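    # A minimal usage sketch for set_all_reduce_fusion_split_indices(); the indices are arbitrary
    # examples: gradients of parameters [0, 20) are fused into the first allreduce, [20, 60) into
    # the second, and any remaining parameters into the last.
    #
    #     auto_parallel_context().set_all_reduce_fusion_split_indices([20, 60])
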
693    def get_all_reduce_fusion_split_indices(self, group=""):
694        """
695        Get allreduce fusion split indices.
696
697        Args:
698            group (str): The communication group of hccl/nccl.
699
700        Returns:
701            Return split sizes list according to the group.
702
703        Raises:
704            TypeError: If group is not a python str.
705        """
706        self.check_context_handle()
707        new_group = self._check_and_default_group(group)
708        return self._context_handle.get_all_reduce_fusion_split_indices(new_group)
709
710    def set_all_reduce_fusion_split_sizes(self, sizes, group=""):
711        """
712        Set allreduce fusion strategy by parameters data sizes.
713
714        Args:
715            sizes (list): Sizes list.
716            group (str): The communication group of hccl/nccl.
717
718        Raises:
719            TypeError: If type of sizes item is not int.
720            TypeError: If group is not a python str.
721        """
722        self.check_context_handle()
723        if isinstance(sizes, (list)):
724            for size in sizes:
725                if not isinstance(size, int) or isinstance(size, bool):
726                    raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_sizes', "
727                                    "the argument 'sizes' must be int, but got the type : {}.".format(type(size)))
728        else:
729            raise TypeError("For 'set_auto_parallel_context().set_all_reduce_fusion_split_sizes', "
730                            "the argument 'sizes' must be list, but got the type : {}.".format(type(sizes)))
731
732        new_group = self._check_and_default_group(group)
733        self._context_handle.set_all_reduce_fusion_split_sizes(sizes, new_group)
734        if context.get_context("device_target") == "Ascend":
735            _set_fusion_strategy_by_size(sizes)
736
737    def get_all_reduce_fusion_split_sizes(self, group=""):
738        """
739        Get allreduce fusion split sizes.
740
741        Args:
742            group (str): The communication group of hccl/nccl.
743
744        Returns:
745            Return split sizes list according to the group.
746
747        Raises:
748            TypeError: If group is not a python str.
749        """
750        self.check_context_handle()
751        new_group = self._check_and_default_group(group)
752        return self._context_handle.get_all_reduce_fusion_split_sizes(new_group)
753
754    def set_enable_all_reduce_fusion(self, enable_all_reduce_fusion):
755        """
756        Set enable/disable all reduce fusion.
757
758        Args:
759            enable_all_reduce_fusion (bool): Enable/disable all reduce fusion.
760        """
761        self.check_context_handle()
762        if not isinstance(enable_all_reduce_fusion, bool):
763            raise TypeError("For 'set_auto_parallel_context().set_enable_all_reduce_fusion', "
764                            "the argument 'enable_fusion' must be bool, but got the type : {}."
765                            .format(type(enable_all_reduce_fusion)))
766        self._context_handle.set_enable_all_reduce_fusion(enable_all_reduce_fusion)
767
768    def set_enable_all_gather_fusion(self, enable_all_gather_fusion):
769        """
770        Set enable/disable all gather fusion.
771
772        Args:
773            enable_all_gather_fusion (bool): Enable/disable all gather fusion.
774        """
775        self.check_context_handle()
776        if not isinstance(enable_all_gather_fusion, bool):
777            raise TypeError("For 'set_auto_parallel_context().set_enable_all_gather_fusion', "
778                            "the argument 'enable_fusion' must be bool, but got the type : {}."
779                            .format(type(enable_all_gather_fusion)))
780        self._context_handle.set_enable_all_gather_fusion(enable_all_gather_fusion)
781
782    def set_enable_reduce_scatter_fusion(self, enable_reduce_scatter_fusion):
783        """
784        Set enable/disable reduce scatter fusion.
785
786        Args:
787            enable_reduce_scatter_fusion (bool): Enable/disable reduce scatter fusion.
788        """
789        self.check_context_handle()
790        if not isinstance(enable_reduce_scatter_fusion, bool):
791            raise TypeError("For 'set_auto_parallel_context().set_enable_reduce_scatter_fusion', "
792                            "the argument 'enable_fusion' must be bool, but got the type : {}."
793                            .format(type(enable_reduce_scatter_fusion)))
794        self._context_handle.set_enable_reduce_scatter_fusion(enable_reduce_scatter_fusion)
795
796    def get_enable_all_reduce_fusion(self):
797        """Get all reduce fusion flag."""
798        self.check_context_handle()
799        return self._context_handle.get_enable_all_reduce_fusion()
800
801    def get_enable_all_gather_fusion(self):
802        """Get all gather fusion flag."""
803        self.check_context_handle()
804        return self._context_handle.get_enable_all_gather_fusion()
805
806    def get_enable_reduce_scatter_fusion(self):
807        """Get reduce scatter flag."""
808        self.check_context_handle()
809        return self._context_handle.get_enable_reduce_scatter_fusion()
810
811    def get_device_num_is_set(self):
812        """Get device number is set or not."""
813        self.check_context_handle()
814        return self._context_handle.get_device_num_is_set()
815
816    def get_global_rank_is_set(self):
817        """Get global rank is set or not."""
818        self.check_context_handle()
819        return self._context_handle.get_global_rank_is_set()
820
821    def set_enable_parallel_optimizer(self, enable_parallel_optimizer):
822        """
823        Set enable/disable parallel optimizer.
824
825        Args:
826            enable_parallel_optimizer (bool): Enable/disable parallel optimizer.
827        """
828        self.check_context_handle()
829        if not isinstance(enable_parallel_optimizer, bool):
830            raise TypeError("For 'set_auto_parallel_context', "
831                            "the argument 'enable_parallel_optimizer' must be bool, but got the type : {}."
832                            .format(type(enable_parallel_optimizer)))
833        self._context_handle.set_enable_parallel_optimizer(enable_parallel_optimizer)
834
835    def set_force_fp32_communication(self, force_fp32_communication):
836        """
837        Set enable/disable force fp32 communication.
838
839        Args:
840            force_fp32_communication (bool): Enable/disable force fp32 communication.
841        """
842        self.check_context_handle()
843        if not isinstance(force_fp32_communication, bool):
844            raise TypeError("For 'set_auto_parallel_context', "
845                            "the argument 'force_fp32_communication' must be bool, but got the type : {}."
846                            .format(type(force_fp32_communication)))
847        self._context_handle.set_force_fp32_communication(force_fp32_communication)
848
849    def get_enable_fold_pipeline(self):
850        """Get parallel optimizer flag."""
851        self.check_context_handle()
852        return self._context_handle.get_enable_fold_pipeline()
853
854    def set_pipeline_config(self, pipeline_config):
855        r"""
856        Set the configuration for pipeline parallelism. The configuration provides more detailed behavior control about
857        parallel training when pipeline parallelism is enabled.
858
859        Args:
860            pipeline_config (dict): The configuration for pipeline parallelism. It supports the following keys:
861
862            - pipeline_interleave(bool): If true, enable the interleave scheduler for pipeline parallelism. This
863                                         scheduler requires more memory but produces fewer bubbles.
864            - pipeline_scheduler(string): There are two choices, "1f1b" and "gpipe". Default: "1f1b".
865
866              - 1f1b: It requires less memory and a lower bubble ratio, because it runs the backward pass as soon as
867                      the corresponding forward pass finishes.
868              - gpipe: It requires more memory and a higher bubble ratio, because it runs the backward passes only
869                       after all forward passes finish.
870
871        Raises:
872            TypeError: If the type of `pipeline_config` is not `dict`.
873            ValueError: If a key in `pipeline_config` is not in ["pipeline_interleave", "pipeline_scheduler"].
874            ValueError: If pipeline interleave is False and pipeline scheduler is not `1f1b`.
875        """
876        self.check_context_handle()
877
878        if not isinstance(pipeline_config, dict):
879            raise TypeError("For 'set_pipeline_config', the argument 'pipeine_config' "
880                            "must be dict, but got the type : {}.".format(type(pipeline_config)))
881
882        pp_interleave = _PipelineConfig.PIPELINE_INTERLEAVE
883        pp_scheduler = _PipelineConfig.PIPELINE_SCHEDULER
884
885        for config_name in pipeline_config:
886            unknown_config = []
887            if config_name not in [pp_interleave, pp_scheduler]:
888                unknown_config.append(config_name)
889
890            if unknown_config:
891                raise ValueError("Unknown config: {}".format(unknown_config))
892
893        Validator.check_bool(
894            pipeline_config[pp_interleave], pp_interleave, pp_interleave)
895        self._context_handle.set_pipeline_interleave(
896            pipeline_config[pp_interleave])
897
898        Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
899                                                               _PipelineScheduler.PIPELINE_GPIPE])
900        if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
901            raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")
902
903        self._context_handle.set_pipeline_scheduler(pipeline_config[pp_scheduler])
904
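    # A minimal usage sketch for set_pipeline_config(): both keys are required by this method, and
    # "gpipe" is only accepted together with pipeline_interleave=True:
    #
    #     auto_parallel_context().set_pipeline_config(
    #         {"pipeline_interleave": True, "pipeline_scheduler": "1f1b"})
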
905    def get_enable_parallel_optimizer(self):
906        """Get parallel optimizer flag."""
907        self.check_context_handle()
908        return self._context_handle.get_enable_parallel_optimizer()
909
910    def get_force_fp32_communication(self):
911        """Get force fp32 communication flag."""
912        self.check_context_handle()
913        return self._context_handle.get_force_fp32_communication()
914
915
916    def set_parallel_optimizer_config(self, parallel_optimizer_config):
917        r"""
918        Set the configuration for the parallel optimizer. The configuration provides more detailed behavior control
919        about parallel training when the parallel optimizer is enabled.
920
921        Args:
922            parallel_optimizer_config(dict): A dict contains the keys and values for setting the parallel optimizer
923            configuration. It supports the following keys:
924
925            - gradient_accumulation_shard(bool): If true, the accumulation gradient parameters will be sharded
926                                                 across the data parallel devices. This will introduce additional
927                                                 communication cost(ReduceScatter) at each step when accumulate the
928                                                 gradients, but saves a lot of device memories,
929                                                 thus can make model be trained with larger batch size.
930                                                 This configuration is effective only when the model runs on pipeline
931                                                 training or gradient accumulation with data parallel.
932
933            - parallel_optimizer_threshold(int): Set the threshold of parallel optimizer. When parallel optimizer is
934                                                 enabled, parameters with size smaller than this threshold will not be
935                                                 sharded across the devices. Parameter size = shape[0] \* ... \*
936                                                 shape[n] \* size(dtype). Non-negative. Unit: KB. Default: 64.
937            - optimizer_weight_shard_size(int): Set the optimizer weight shard group size if you want to specify the
938                                                maximum group size across devices when the parallel optimizer is
939                                                enabled. The numerical range can be (0, device_num]. The default
940                                                value is -1, which means the optimizer weight shard group size will
941                                                be the data parallel group size of each parameter.
942
943        """
944        self.check_context_handle()
945        grad_shard_name = _ParallelOptimizerConfig.GRADIENT_ACCUMULATION_SHARD
946        threshold_name = _ParallelOptimizerConfig.PARALLEL_OPTIMIZER_THRESHOLD
947        optimizer_weight_shard_size_name = _ParallelOptimizerConfig.OPTIMIZER_WEIGHT_SHARD_SIZE
948
949        for config_name in parallel_optimizer_config:
950            unknown_config = []
951            if config_name not in [grad_shard_name, threshold_name, optimizer_weight_shard_size_name]:
952                unknown_config.append(config_name)
953
954            if unknown_config:
955                raise ValueError("Unknown config: {}".format(unknown_config))
956
957        if grad_shard_name in parallel_optimizer_config:
958            Validator.check_bool(
959                parallel_optimizer_config[grad_shard_name], grad_shard_name, grad_shard_name)
960            self._context_handle.set_grad_accumulation_shard(
961                parallel_optimizer_config[grad_shard_name])
962
963        if threshold_name in parallel_optimizer_config:
964            Validator.check_non_negative_int(
965                parallel_optimizer_config[threshold_name])
966            self._context_handle.set_parallel_optimizer_threshold(
967                parallel_optimizer_config[threshold_name])
968
969        if optimizer_weight_shard_size_name in parallel_optimizer_config:
970            value = parallel_optimizer_config[optimizer_weight_shard_size_name]
971            Validator.check_positive_int(value)
972            self.set_optimizer_weight_shard_size(value)
973
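    # A minimal usage sketch for set_parallel_optimizer_config(); the values are arbitrary examples:
    #
    #     auto_parallel_context().set_enable_parallel_optimizer(True)
    #     auto_parallel_context().set_parallel_optimizer_config(
    #         {"gradient_accumulation_shard": False,
    #          "parallel_optimizer_threshold": 64,  # KB
    #          "optimizer_weight_shard_size": 2})
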
974    def get_grad_accumulation_shard(self):
975        """Get grad accumulation shard."""
976        self.check_context_handle()
977        return self._context_handle.get_grad_accumulation_shard()
978
979    def get_parallel_optimizer_threshold(self):
980        """Get parallel optimizer threshold."""
981        self.check_context_handle()
982        return self._context_handle.get_parallel_optimizer_threshold()
983
984    def set_enable_alltoall(self, enable_a2a):
985        """
986        Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
987        Default: ``False``.
988
989        Args:
990            enable_a2a (bool): Enable/disable AllToAll.
991        """
992        self.check_context_handle()
993        if not isinstance(enable_a2a, bool):
994            raise TypeError("For 'set_auto_parallel_context().set_enable_alltoall', the argument 'enable_a2a' "
995                            "must be bool, but got the type : {}.".format(type(enable_a2a)))
996        self._context_handle.set_enable_alltoall(enable_a2a)
997
998    def get_enable_alltoall(self):
999        """Get the value of enabling AllToAll."""
1000        self.check_context_handle()
1001        return self._context_handle.get_enable_alltoall()
1002
1003    def set_communi_parallel_mode(self, communi_parallel_mode):
1004        """
1005        Set communication parallel mode.
1006
1007        Args:
1008            communi_parallel_mode (str): The communication parallel mode.
1009
1010        Raises:
1011            ValueError: If parallel mode is not supported.
1012        """
1013        if not isinstance(communi_parallel_mode, str):
1014            raise TypeError("For 'set_auto_parallel_context().set_communi_parallel_mode', "
1015                            "the argument 'communi_parallel_mode' must be str, but got the type : {}."
1016                            .format(type(communi_parallel_mode)))
1017        self.check_context_handle()
1018        ret = self._context_handle.set_communi_parallel_mode(communi_parallel_mode)
1019        if ret is False:
1020            raise ValueError("For 'set_auto_parallel_context().set_communi_parallel_mode', "
1021                             "the argument 'communi_parallel_mode' only support 'ALL_GROUP_PARALLEL', "
1022                             "'SAME_SEVER_GROUP_PARALLEL' and 'NO_GROUP_PARALLEL', "
1023                             "but got the value : {}.".format(communi_parallel_mode))
1024
1025    def get_communi_parallel_mode(self):
1026        """Get communication parallel mode."""
1027        self.check_context_handle()
1028        return self._context_handle.get_communi_parallel_mode()
1029
1030    def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size):
1031        """
1032        Set optimizer_weight_shard_size.
1033
1034        Args:
1035            optimizer_weight_shard_size (int): The optimizer weight shard group size to use when the parallel
1036                                               optimizer is not applied globally across devices.
1037        """
1038        self.check_context_handle()
1039        if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, bool):
1040            raise TypeError(f"The type of optimizer_weight_shard_size must be int, \
1041                but got {type(optimizer_weight_shard_size)}.")
1042        if optimizer_weight_shard_size <= 1:
1043            logger.warning("The setting 'optimizer_weight_shard_size' is invalid. "
1044                           "Please use the integer larger than 1.")
1045            return
1046        self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)
1047
1048    def get_optimizer_weight_shard_size(self):
1049        """Get optimizer_weight_shard_size."""
1050        self.check_context_handle()
1051        return self._context_handle.get_optimizer_weight_shard_size()
1052
1053    def set_ops_strategy_json_config(self, type, path, mode):
1054        """
1055        Set the configuration for saving or loading operator strategies in a .json file.
1056        """
1057        self.check_context_handle()
1058        self._context_handle.set_ops_strategy_json_config(type, path, mode)
1059
1060    def set_optimizer_weight_shard_aggregated_save(self, optimizer_weight_shard_aggregated_save):
1061        """
1062        Set optimizer_weight_shard_aggregated_save.
1063
1064        Args:
1065            optimizer_weight_shard_aggregated_save (bool): Whether to save the integrated weight shards when
1066                                                           the parallel optimizer is enabled.
1067        """
1068        self.check_context_handle()
1069        if not isinstance(optimizer_weight_shard_aggregated_save, bool):
1070            raise TypeError("The argument 'optimizer_weight_shard_aggregated_save' must be bool.")
1071        self._context_handle.set_optimizer_weight_shard_aggregated_save(optimizer_weight_shard_aggregated_save)
1072
1073    def get_optimizer_weight_shard_aggregated_save(self):
1074        """Get optimizer_weight_shard_size."""
1075        self.check_context_handle()
1076        return self._context_handle.get_optimizer_weight_shard_aggregated_save()
1077
1078    def get_full_batch_is_set(self):
1079        """Get full batch attr"""
1080        self.check_context_handle()
1081        return self._context_handle.get_full_batch_is_set()
1082
1083    def reset(self):
1084        """Reset all settings."""
1085        self.check_context_handle()
1086        self._context_handle.reset()
1087        _ParallelFusionConfig.reset()
1088
1089    def _check_and_default_group(self, group):
1090        """Validate the given group, if group is empty, returns a default fusion group"""
1091        if isinstance(group, (str)):
1092            group_len = len(group)
1093            if group_len > _MAX_GROUP_NAME_LEN:
1094                raise ValueError(f"The group name length exceeds the limit of {_MAX_GROUP_NAME_LEN} characters.")
1095        else:
1096            raise TypeError('Group must be a python str')
1097
1098        if group == "":
1099            if context.get_context("device_target") == "Ascend":
1100                group = _DEFAULT_HCCL_FUSION_GROUP_NAME
1101            else:
1102                group = _DEFAULT_NCCL_FUSION_GROUP_NAME
1103        return group
1104
1105    def _set_allgather_comm_fusion(self, comm_fusion, comm_type="allgather"):
1106        """
1107        Set allgather and reducescatter fusion method for auto parallel.
1108
1109        Args:
1110            comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
1111                                  supports two fusion modes: `auto` and `size`.
1112            comm_type (str): The name of the communication operator, `allgather` or `reducescatter`.
1113
1114        Raises:
1115            KeyError: When key of comm_fusion is not 'mode' or 'config'.
1116            KeyError: When `mode` is not 'auto', 'size'.
1117        """
1118        self.check_context_handle()
1119        if comm_type == "allgather" and not self.get_enable_all_gather_fusion():
1120            return
1121        if comm_type == "reducescatter" and not self.get_enable_reduce_scatter_fusion():
1122            return
1123        if not isinstance(comm_fusion, dict):
1124            raise TypeError("For 'comm_fusion', {} config must be dict, but got the type : {}.".format(
1125                comm_type, type(comm_fusion)))
1126        if _ParallelFusionConfig.MODE not in comm_fusion:
1127            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
1128        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
1129            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
1130        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.SIZE]
1131        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
1132            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
1133        else:
1134            raise KeyError("fusion method mode must be auto or size, but got {}".format(
1135                comm_fusion[_ParallelFusionConfig.MODE]))
1136
1137        fusion_threshold = 64
1138        if comm_fusion[_ParallelFusionConfig.MODE] != _ParallelFusionConfig.AUTO:
1139            fusion_threshold = comm_fusion[_ParallelFusionConfig.FUSION_CONFIG]
1140        self.set_fusion_threshold_mb(fusion_threshold, comm_type)
1141
1142    def _set_allreduce_comm_fusion(self, comm_fusion):
1143        """
1144        Set fusion method for auto parallel.
1145
1146        Args:
1147            comm_fusion (dict): A dict contains the methods and values for setting the fusion method. Currently it
1148                                  supports three fusion modes: `auto`, `size` and `index`.
1149
1150        Raises:
1151            KeyError: When key of comm_fusion is not 'mode' or 'config'.
1152            KeyError: When `mode` is not 'auto', 'size' or 'index'.
1153        """
1154        self.check_context_handle()
1155        if not self.get_enable_all_reduce_fusion():
1156            return
1157        if not isinstance(comm_fusion, dict):
1158            raise TypeError("For 'comm_fusion', the 'allreduce' config must be dict, but got the type : {}.".format(
1159                type(comm_fusion)))
1160        if _ParallelFusionConfig.MODE not in comm_fusion:
1161            raise KeyError("For 'comm_fusion', the key 'mode' should be contained.")
1162        if _ParallelFusionConfig.FUSION_CONFIG not in comm_fusion:
1163            raise KeyError("For 'comm_fusion', the key 'config' should be contained.")
1164        check_mode = [_ParallelFusionConfig.AUTO, _ParallelFusionConfig.INDEX, _ParallelFusionConfig.SIZE]
1165        if comm_fusion[_ParallelFusionConfig.MODE] in check_mode:
1166            self._context_handle.set_fusion_mode(comm_fusion[_ParallelFusionConfig.MODE])
1167        else:
1168            raise KeyError("fusion method mode must be auto, index or size, but got {}".format(
1169                comm_fusion[_ParallelFusionConfig.MODE]))
1170        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.AUTO:
1171            self.set_fusion_threshold_mb(fusion_threshold=64)
1172        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.SIZE:
1173            self.set_fusion_threshold_mb(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
1174        if comm_fusion[_ParallelFusionConfig.MODE] == _ParallelFusionConfig.INDEX:
1175            self.set_all_reduce_fusion_split_indices(comm_fusion[_ParallelFusionConfig.FUSION_CONFIG])
1176
1177    def _set_openstate_comm_fusion(self, openstate):
1178        """
1179        Set open state for comm fusion.
1180
1181        Args:
1182            openstate (bool): Whether to turn the communication fusion on or off. Currently it
1183                                  supports two states: `True` or `False`.
1184
1185        Raises:
1186            TypeError: When the value is not bool.
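
        Examples:
            A minimal sketch (assumes a distributed context is already initialized):

            >>> # Setting the open state to False disables allreduce, allgather and
            >>> # reducescatter fusion together.
            >>> auto_parallel_context()._set_openstate_comm_fusion(False)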
1187        """
1188        self.check_context_handle()
1189        if not isinstance(openstate, bool):
1190            raise TypeError("For 'comm_fusion', the 'openstate' must be bool, but got the type: {}.".format(
1191                type(openstate)))
1192        if not openstate:
1193            self.set_enable_all_reduce_fusion(openstate)
1194            self.set_enable_all_gather_fusion(openstate)
1195            self.set_enable_reduce_scatter_fusion(openstate)
1196
1197
1198def _set_ops_strategy_json_config(type="SAVE", path="", mode="all"):
1199    """
1200    Set strategy json configuration.
1201
1202    Args:
1203        type (str): Whether to save or load the strategy .json file, either 'SAVE' or 'LOAD'.
1204        path (str): The path used to save or load the parallel strategy .json file.
1205        mode (str): Whether to record all operators or only the principal ones, either 'all' or 'principal'.
1206
1207    Raises:
1208        KeyError: When type is not 'SAVE' or 'LOAD'.
1209        KeyError: When mode is not 'all' or 'principal'.
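
    Examples:
        A minimal sketch; the file path below is hypothetical:

        >>> # Save the parallel strategies of all operators to a .json file.
        >>> _set_ops_strategy_json_config(type="SAVE", path="/tmp/ops_strategy.json", mode="all")
        >>> # Later, load the saved strategies back.
        >>> _set_ops_strategy_json_config(type="LOAD", path="/tmp/ops_strategy.json", mode="all")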
1210    """
1211    dir_path = os.path.dirname(path)
1212    if dir_path and not os.path.exists(dir_path):
1213        os.makedirs(dir_path)
1214    check_type = ["SAVE", "LOAD"]
1215    check_mode = ["all", "principal"]
1216    if type in check_type and mode in check_mode:
1217        auto_parallel_context().set_ops_strategy_json_config(type, path, mode)
1218    else:
1219        raise KeyError("Type must be 'SAVE' or 'LOAD' and mode must be 'all' or 'principal'")
1220
1221
1222_AUTO_PARALLEL_CONTEXT = None
1223
1224
1225def auto_parallel_context():
1226    """
1227    Get the global _AUTO_PARALLEL_CONTEXT; if it has not been created yet, create a new one.
1228
1229    Returns:
1230        _AutoParallelContext, the global auto parallel context.
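
    Examples:
        A small sketch showing that repeated calls return the same singleton object:

        >>> ctx_a = auto_parallel_context()
        >>> ctx_b = auto_parallel_context()
        >>> assert ctx_a is ctx_b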
1231    """
1232    global _AUTO_PARALLEL_CONTEXT
1233    if _AUTO_PARALLEL_CONTEXT is None:
1234        _AUTO_PARALLEL_CONTEXT = _AutoParallelContext()
1235    return _AUTO_PARALLEL_CONTEXT
1236
1237
1238_set_auto_parallel_context_func_map = {
1239    "device_num": auto_parallel_context().set_device_num,
1240    "global_rank": auto_parallel_context().set_global_rank,
1241    "gradients_mean": auto_parallel_context().set_gradients_mean,
1242    "gradient_fp32_sync": auto_parallel_context().set_gradient_fp32_sync,
1243    "loss_repeated_mean": auto_parallel_context().set_loss_repeated_mean,
1244    "pipeline_stages": auto_parallel_context().set_pipeline_stages,
1245    "auto_pipeline": auto_parallel_context().set_auto_pipeline,
1246    "pipeline_result_broadcast": auto_parallel_context().set_pipeline_result_broadcast,
1247    "pipeline_segments": auto_parallel_context().set_pipeline_segments,
1248    "parallel_mode": auto_parallel_context().set_parallel_mode,
1249    "search_mode": auto_parallel_context().set_strategy_search_mode,
1250    "auto_parallel_search_mode": auto_parallel_context().set_auto_parallel_search_mode,
1251    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
1252    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
1253    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
1254    "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
1255    "full_batch": auto_parallel_context().set_full_batch,
1256    "dataset_strategy": auto_parallel_context().set_dataset_strategy,
1257    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
1258    "force_fp32_communication": auto_parallel_context().set_force_fp32_communication,
1259    "parallel_optimizer_config": auto_parallel_context().set_parallel_optimizer_config,
1260    "pipeline_config": auto_parallel_context().set_pipeline_config,
1261    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
1262    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices,
1263    "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
1264    "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
1265    "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
1266    "sharding_propagation": auto_parallel_context().set_sharding_propagation,
1267    "enable_alltoall": auto_parallel_context().set_enable_alltoall,
1268    "strategy_ckpt_config": auto_parallel_context().set_strategy_ckpt_config,
1269    "comm_fusion": auto_parallel_context().set_comm_fusion}
1270
1271_get_auto_parallel_context_func_map = {
1272    "device_num": auto_parallel_context().get_device_num,
1273    "global_rank": auto_parallel_context().get_global_rank,
1274    "gradients_mean": auto_parallel_context().get_gradients_mean,
1275    "gradient_fp32_sync": auto_parallel_context().get_gradient_fp32_sync,
1276    "loss_repeated_mean": auto_parallel_context().get_loss_repeated_mean,
1277    "pipeline_stages": auto_parallel_context().get_pipeline_stages,
1278    "auto_pipeline": auto_parallel_context().get_auto_pipeline,
1279    "pipeline_result_broadcast": auto_parallel_context().get_pipeline_result_broadcast,
1280    "pipeline_interleave": auto_parallel_context().get_pipeline_interleave,
1281    "pipeline_scheduler": auto_parallel_context().get_pipeline_scheduler,
1282    "parallel_mode": auto_parallel_context().get_parallel_mode,
1283    "search_mode": auto_parallel_context().get_strategy_search_mode,
1284    "auto_parallel_search_mode": auto_parallel_context().get_auto_parallel_search_mode,
1285    "parameter_broadcast": auto_parallel_context().get_parameter_broadcast,
1286    "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file,
1287    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
1288    "full_batch": auto_parallel_context().get_full_batch,
1289    "dataset_strategy": auto_parallel_context().get_dataset_strategy,
1290    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
1291    "force_fp32_communication": auto_parallel_context().get_force_fp32_communication,
1292    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
1293    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices,
1294    "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
1295    "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
1296    "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
1297    "sharding_propagation": auto_parallel_context().get_sharding_propagation,
1298    "enable_alltoall": auto_parallel_context().get_enable_alltoall,
1299    "comm_fusion": auto_parallel_context().get_comm_fusion,
1300    "strategy_ckpt_config": auto_parallel_context().get_strategy_ckpt_config,
1301    "full_batch_is_set": auto_parallel_context().get_full_batch_is_set}
1302
1303
1304@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
1305                 loss_repeated_mean=bool, parallel_mode=str, search_mode=str, auto_parallel_search_mode=str,
1306                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
1307                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
1308                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
1309                 communi_parallel_mode=str, optimizer_weight_shard_size=int, sharding_propagation=bool,
1310                 optimizer_weight_shard_aggregated_save=bool, enable_alltoall=bool, comm_fusion=dict,
1311                 strategy_ckpt_config=dict, force_fp32_communication=bool)
1312def _set_auto_parallel_context(**kwargs):
1313    """
1314    Set auto parallel context.
1315
1316    Note:
1317        Attribute name is required for setting attributes.
1318
1319    Args:
1320        device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
1321        global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
1322        gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror. Default: ``False``.
1323        loss_repeated_mean (bool): Whether to perform mean operator in backward in the case of repeated
1324                        calculations. Default: ``True``.
1325        gradient_fp32_sync (bool): Gradients are all-reduced in fp32 even if the gradients are fp16 when this
1326                        flag is True. Default: ``True``.
1327        parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
1328                     "hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
1329
1330                     - stand_alone: Only one processor working.
1331
1332                     - data_parallel: Distributing the data across different processors.
1333
1334                     - hybrid_parallel: Achieving data parallelism and model parallelism manually.
1335
1336                     - semi_auto_parallel: Achieving data parallelism and model parallelism by
1337                       setting parallel strategies.
1338
1339                     - auto_parallel: Achieving parallelism automatically.
1340        search_mode (str): There are three kinds of search modes: "recursive_programming", "dynamic_programming"
1341                     and "sharding_propagation". Default: "dynamic_programming".
1342
1343                     - recursive_programming: Recursive programming search mode.
1344
1345                     - dynamic_programming: Dynamic programming search mode.
1346
1347                     - sharding_propagation: Propagate shardings from configured ops to non-configured ops.
1348        auto_parallel_search_mode (str): This is the old version of 'search_mode'. This attribute is retained only
1349                     for forward compatibility, and it will be deleted in a future MindSpore version.
1350        parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
1351                       "stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
1352                       broadcast. Default: ``False``.
1353        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
1354        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
1355        group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
1356        full_batch (bool): Whether to load the whole batch on each device. Default: ``False``.
1357        dataset_strategy (Union[str, tuple]): Dataset sharding strategy. Default: "data_parallel".
1358        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: ``False``.
1359        force_fp32_communication (bool): A switch that determines whether reduce operators (AllReduce, ReduceScatter)
1360                        are forced to use the fp32 data type during communication. Setting it to ``True`` enables
1361                        the switch. Default: ``False``.
1362        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
1363        pipeline_stages (int): Set the stage information for pipeline parallel. This indicates how
1364                        the devices are distributed along the pipeline. The total number of devices will be divided
1365                        into 'pipeline_stages' stages. This currently can only be used when the
1366                        semi_auto_parallel parallel mode is enabled. Default: 0
1367        auto_pipeline (bool): Set the pipeline stage number to automatic. Its value will be selected between 1 and the
1368                        parameter `pipeline_stages`. This option requires the `parallel_mode` to be ``auto_parallel``
1369                        and the `search_mode` to be ``recursive_programming``. Default: ``False`` .
1370        pipeline_result_broadcast (bool): A switch that broadcasts the last stage result to all other stages in
1371                        pipeline parallel inference. Default: ``False``.
1372        communi_parallel_mode (str): There are three kinds of communication parallel modes, "all_group_parallel",
1373                     "same_server_group_parallel" and "no_group_parallel". Default: "all_group_parallel".
1374
1375                     - all_group_parallel: All communication groups are in parallel.
1376
1377                     - same_server_group_parallel: Only the communication groups within the same server are parallel.
1378
1379                     - no_group_parallel: All communication groups are not parallel.
1380        optimizer_weight_shard_size (int): Set the optimizer shard group size when the parallel optimizer is not fully used.
1381                                    It should be larger than one and less than or equal to the data parallel size.
1382                                    Default: -1, which means the parallel optimizer is fully used in the data parallel dimension.
1383        optimizer_weight_shard_aggregated_save (bool): Whether to save weight shards in an aggregated way when the
1384                                                       parallel optimizer is enabled. Default: ``False``.
1385        sharding_propagation (bool): Set the value of sharding strategy propagation in AUTO_PARALLEL mode. If True,
1386                                    the strategy-configured operators will propagate the strategies to other
1387                                    operators with minimum redistribution cost; otherwise, the algorithm will
1388                                    search the desired strategies. Default: ``False``.
1389        enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
1390                                circumvent AllToAll. Default: ``False``.
1391        comm_fusion (dict): A dict that contains the types and configurations for setting the communication fusion. Each
1392                    communication fusion config has two keys: "mode" and "config".
1393                    It supports the following communication fusion types and configurations:
1394
1395                    - openstate: Whether to turn on the communication fusion or not. If `openstate` is `True`, turn on
1396                        the communication fusion, otherwise, turn off the communication fusion. Default: `True`.
1397
1398                    - allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
1399                        and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default
1400                        fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size
1401                        manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is the same
1402                        as `all_reduce_fusion_config`.
1403
1404                    - allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`.
1405                        In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
1406                        threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size
1407                        manually, and the fusion threshold must be larger than `0` MB.
1408
1409                    - reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
1410                        and `size`. The config is the same as `allgather`.
1411
1414    Raises:
1415        ValueError: If the input key is not an attribute of the auto parallel context.
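
    Examples:
        An illustrative sketch only; the values are hypothetical and a distributed
        context is assumed to be initialized already:

        >>> _set_auto_parallel_context(
        ...     device_num=8,
        ...     global_rank=0,
        ...     parallel_mode="semi_auto_parallel",
        ...     enable_parallel_optimizer=True,
        ...     comm_fusion={"openstate": True,
        ...                  "allreduce": {"mode": "size", "config": 32},
        ...                  "allgather": {"mode": "auto", "config": None}})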
1416    """
1417    for key, value in kwargs.items():
1418        if key not in _set_auto_parallel_context_func_map:
1419            raise ValueError("Set context keyword %s is not recognized!" % key)
1420        set_func = _set_auto_parallel_context_func_map[key]
1421        set_func(value)
1422
1423
1424def _get_auto_parallel_context(attr_key):
1425    """
1426    Get auto parallel context attribute value according to the key.
1427
1428    Args:
1429        attr_key (str): The key of the attribute.
1430
1431    Returns:
1432        Return attribute value according to the key.
1433
1434    Raises:
1435        ValueError: If the input key is not an attribute of the auto parallel context.
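
    Examples:
        A minimal sketch (the returned values depend on the current configuration):

        >>> mode = _get_auto_parallel_context("parallel_mode")
        >>> dev_num = _get_auto_parallel_context("device_num")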
1436    """
1437    if attr_key not in _get_auto_parallel_context_func_map:
1438        raise ValueError("Get context keyword %s is not recognized!" % attr_key)
1439    get_func = _get_auto_parallel_context_func_map[attr_key]
1440    return get_func()
1441
1442
1443def _reset_auto_parallel_context():
1444    """
1445    Reset auto parallel context attributes to the default values:
1446
1447    - device_num: 1.
1448    - global_rank: 0.
1449    - gradients_mean: False.
1450    - gradient_fp32_sync: True.
1451    - parallel_mode: "stand_alone".
1452    - parameter_broadcast: False.
1453    - strategy_ckpt_load_file: ""
1454    - strategy_ckpt_save_file: ""
1455    - enable_parallel_optimizer: False
1456    - force_fp32_communication: False
1457    - search_mode: 'recursive_programming'
1458    - auto_parallel_search_mode: 'recursive_programming'
1459    - sharding_propagation: False
1460    - pipeline_stages: 0
1461    - auto_pipeline: False
1462    - pipeline_result_broadcast: False
1463    - gradient_accumulation_shard: True
1464    - fusion_threshold: 64
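
    Examples:
        A short sketch of typical use (the configured values are hypothetical):

        >>> _set_auto_parallel_context(device_num=8, parallel_mode="semi_auto_parallel")
        >>> # Restore every attribute listed above to its default value.
        >>> _reset_auto_parallel_context()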
1465    """
1466    auto_parallel_context().reset()
1467