# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel config for parallel training.
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore._checkparam import Validator
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore import log as logger

__all__ = [
    "OpParallelConfig"
]


class _Config:
    r"""A basic class for parallel configurations."""

    def __str__(self):
        info = "[ParallelConfig]" + '\n'
        for k, v in self.__dict__.items():
            var_info = "{}:{}\n".format(k, v)
            info += var_info
        return info


class OpParallelConfig(_Config):
    r"""
        OpParallelConfig for setting the data parallel and model parallel degrees.

        Args:
            data_parallel (int): The data parallel degree. Default: 1.
            model_parallel (int): The model parallel degree. Default: 1.

        Supported Platforms:
            ``Ascend`` ``GPU``

        Examples:
            >>> from mindspore.parallel.nn import OpParallelConfig
            >>> config = OpParallelConfig(data_parallel=1, model_parallel=1)
    """

    def __init__(self, data_parallel=1, model_parallel=1):
        Validator.check_positive_int(data_parallel, "data_parallel")
        Validator.check_positive_int(model_parallel, "model_parallel")
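        # The assignments below go through the property setters defined on this
        # class, which validate the values a second time.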
        self.data_parallel = data_parallel
        self.model_parallel = model_parallel

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, value):
        Validator.check_positive_int(value, "data_parallel")
        self._data_parallel = value

    @property
    def model_parallel(self):
        return self._model_parallel

    @model_parallel.setter
    def model_parallel(self, value):
        Validator.check_positive_int(value, "model_parallel")
        self._model_parallel = value


class _PipeLineConfig(_Config):
    r"""
        _PipeLineConfig for setting the pipeline stage number and micro batch number.

        Args:
            pipeline_stage (int): The number of pipeline stages. Default: 1.
            micro_batch_num (int): The number of micro batches. Default: 1.

        Supported Platforms:
            ``Ascend`` ``GPU``

        Examples:
            >>> config = _PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
    """

    def __init__(self, pipeline_stage=1, micro_batch_num=1):
        Validator.check_positive_int(pipeline_stage, "pipeline_stage")
        Validator.check_positive_int(micro_batch_num, "micro_batch_num")
        self.pipeline_stage = pipeline_stage
        self.micro_batch_num = micro_batch_num

    @property
    def pipeline_stage(self):
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        Validator.check_positive_int(value, "pipeline_stage")
        self._pipeline_stage = value
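        # Keep the pipeline stage number in the global auto parallel context in
        # sync with this config.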
        context.set_auto_parallel_context(pipeline_stages=value)

    @property
    def micro_batch_num(self):
        return self._micro_batch_num

    @micro_batch_num.setter
    def micro_batch_num(self, value):
        Validator.check_positive_int(value, "micro_batch_num")
        self._micro_batch_num = value


# In case the user doesn't pass a config as args.
default_dpmp_config = OpParallelConfig()
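# A minimal usage sketch (illustrative only, not executed here): callers that need
# a non-default layout build their own config instead of relying on
# default_dpmp_config, and can validate it against the current context, e.g.
#     my_config = OpParallelConfig(data_parallel=2, model_parallel=4)
#     _check_config(my_config)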


def _check_config(config):
    """
    Check that the parallel config is consistent with the auto parallel context
    and that it fits into the available device number.
    """
    # The pipeline_stage in the config must be the same as the one in auto_parallel_context.
    pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
    if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
        raise ValueError(
            f"The pipeline stage {pipeline_stage} in auto_parallel_context is not equal to the pipeline_stage "
            f"{config.pipeline_stage} in the config.")

    # The remaining checks only apply in (semi) auto parallel mode.
    is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if not is_auto_parallel:
        return

    device_num = D.get_group_size()
    optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")

    if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
        raise ValueError(f"The product of the data parallel {config.data_parallel}, "
                         f"model parallel {config.model_parallel} and "
                         f"pipeline stages {pipeline_stage} "
                         f"should be no larger than device_num {device_num}.")

    # Warn if optimizer_shard in the config disagrees with enable_parallel_optimizer in the context.
    if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
        logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
                       f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
                       f"optimizer_shard to make them consistent.")
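# Worked example for the device check above (illustrative numbers, assuming the job
# was launched on 8 devices): data_parallel=2, model_parallel=2 and pipeline_stages=2
# give 2 * 2 * 2 = 8 <= 8, so _check_config passes; raising any one of the three to 4
# gives 16 > 8 and _check_config raises ValueError.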