# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Parallel Config for the Parallel Training
This is an experimental interface that is subject to change and/or deletion.
"""
from mindspore._checkparam import Validator
from mindspore import context
import mindspore.communication.management as D
from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_parallel_mode
from mindspore import log as logger

__all__ = [
    "OpParallelConfig"
]


class _Config:
    r"""A base class of the parallel configuration."""

    def __str__(self):
        info = "[ParallelConfig]" + '\n'
        for k, v in self.__dict__.items():
            var_info = "{}:{}\n".format(k, v)
            info += var_info
        return info


class OpParallelConfig(_Config):
    r"""
    OpParallelConfig for setting data parallel and model parallel.

    Args:
        data_parallel (int): The data parallel way. Default: 1
        model_parallel (int): The model parallel way. Default: 1

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> from mindspore.parallel.nn import OpParallelConfig
        >>> config=OpParallelConfig(data_parallel=1, model_parallel=1)
    """

    def __init__(self, data_parallel=1, model_parallel=1):
        Validator.check_positive_int(data_parallel, "data_parallel")
        Validator.check_positive_int(model_parallel, "model_parallel")
        self.data_parallel = data_parallel
        self.model_parallel = model_parallel

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, value):
        Validator.check_positive_int(value, "data_parallel")
        self._data_parallel = value

    @property
    def model_parallel(self):
        return self._model_parallel

    @model_parallel.setter
    def model_parallel(self, value):
        Validator.check_positive_int(value, "model_parallel")
        self._model_parallel = value


class _PipeLineConfig(_Config):
    r"""
    _PipeLineConfig for setting the pipeline stages and micro batch number.

    Args:
        pipeline_stage (int): The number of the pipeline stages. Default: 1
        micro_batch_num (int): The number of micro batches. Default: 1

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> config=_PipeLineConfig(pipeline_stage=1, micro_batch_num=1)
    """

    def __init__(self, pipeline_stage=1, micro_batch_num=1):
        Validator.check_positive_int(pipeline_stage, "pipeline_stage")
        Validator.check_positive_int(micro_batch_num, "micro_batch_num")
        self.pipeline_stage = pipeline_stage
        self.micro_batch_num = micro_batch_num

    @property
    def pipeline_stage(self):
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        Validator.check_positive_int(value, "pipeline_stage")
        self._pipeline_stage = value
        context.set_auto_parallel_context(pipeline_stages=value)

    @property
    def micro_batch_num(self):
        return self._micro_batch_num

    @micro_batch_num.setter
    def micro_batch_num(self, value):
        Validator.check_positive_int(value, "micro_batch_num")
        self._micro_batch_num = value


# In case the user doesn't pass a config as the argument.
default_dpmp_config = OpParallelConfig()


def _check_config(config):
    """
    Check whether the parallel config is consistent with the auto parallel context and the device number.
    """
    # the pipeline_stage in the config should be the same as pipeline_stages in the auto parallel context
    pipeline_stage = context.get_auto_parallel_context("pipeline_stages")
    if hasattr(config, 'pipeline_stage') and pipeline_stage != config.pipeline_stage:
        raise ValueError(
            f"The pipeline stage {pipeline_stage} in auto_parallel_context is not equal to the pipeline_stage "
            f"{config.pipeline_stage} in the config.")

    # the following checks only apply in (semi-)auto parallel mode
    is_auto_parallel = _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
    if not is_auto_parallel:
        return

    device_num = D.get_group_size()
    optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")

    if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
        raise ValueError(f"The product of the data parallel {config.data_parallel}, "
                         f"model parallel {config.model_parallel} "
                         f"and pipeline stages {pipeline_stage} "
                         f"should be less than or equal to the device_num {device_num}.")

    # the optimizer_shard in the config should be the same as enable_parallel_optimizer in the context
    if hasattr(config, "optimizer_shard") and optimizer_shard and optimizer_shard != config.optimizer_shard:
        logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
                       f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
                       f"optimizer_shard to make them consistent.")
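

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module; names such as
# `demo_config` are hypothetical). It shows how an OpParallelConfig is
# typically built and then validated with _check_config. In stand-alone mode
# _check_config returns before any distributed query, so this sketch runs
# without a communication group being initialized.
if __name__ == "__main__":
    # A 1 x 1 split: no data or model parallelism.
    demo_config = OpParallelConfig(data_parallel=1, model_parallel=1)
    print(demo_config)

    # Outside (semi-)auto parallel mode this is a no-op; in auto parallel mode
    # it would also verify that data_parallel * model_parallel * pipeline_stages
    # does not exceed the device number reported by the communication group.
    _check_config(demo_config)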