# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode
from .. import boost
from .. import context


class _OutputTo16(nn.Cell):
    """Wrap cell for amp. Cast network output back to float16."""

    def __init__(self, op):
        super(_OutputTo16, self).__init__(auto_prefix=False)
        self._op = op

    def construct(self, x):
        return F.cast(self._op(x), mstype.float16)


def _do_keep_batchnorm_fp32(network):
    """Do keep batchnorm fp32."""
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
            _do_keep_batchnorm_fp32(subcell)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())


_config_level = {
    "O0": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O2": {
        "keep_batchnorm_fp32": True,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": DynamicLossScaleManager()},
    "O3": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": None}}


def _check_kwargs(key_words):
    """Check kwargs."""
    for arg in key_words:
        if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
            raise ValueError(f"Unsupported arg '{arg}'")

    if 'cast_model_type' in key_words:
        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
                                  [mstype.float16, mstype.float32], None)
    if 'keep_batchnorm_fp32' in key_words:
        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
    if 'loss_scale_manager' in key_words:
        loss_scale_manager = key_words['loss_scale_manager']
        if loss_scale_manager:
            validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager)
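

# A sketch (kept as a comment so it never runs on import) of how a level's
# defaults in `_config_level` merge with user overrides inside
# `build_train_network`; the override value below is illustrative:
#
#     config = dict(_config_level["O2"], **{"keep_batchnorm_fp32": False})
#     # config["cast_model_type"]     -> mstype.float16  (from the O2 defaults)
#     # config["keep_batchnorm_fp32"] -> False           (user override wins)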


def _add_loss_network(network, loss_fn, cast_model_type):
    """Add loss network."""

    class WithLossCell(nn.Cell):
        """Wrap loss for amp. Cast network output back to float32."""

        def __init__(self, backbone, loss_fn):
            super(WithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn

        def construct(self, data, label):
            out = self._backbone(data)
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

    validator.check_value_type('loss_fn', loss_fn, nn.Cell)
    if cast_model_type == mstype.float16:
        network = WithLossCell(network, loss_fn)
    else:
        network = nn.WithLossCell(network, loss_fn)
    return network


def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.

    Args:
        network (Cell): Definition of the network.
        optimizer (Optimizer): Optimizer to update the parameters.
        loss_fn (Union[None, Cell]): Definition of the loss function. If None, the `network` should
            have the loss inside. Default: None.
        level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".

            - O0: Do not change.
            - O2: Cast the network to float16, keep BatchNorm and `loss_fn` (if set) running in
              float32, and use dynamic loss scaling.
            - O3: Cast the network to float16, with the additional property `keep_batchnorm_fp32=False` .
            - auto: Set the level to the recommended one for the current device: O2 on GPU, O3 on
              Ascend. The recommended level is based on prior experience and does not suit every
              network, so users should set the level explicitly for special networks.

            O2 is recommended on GPU; O3 is recommended on Ascend. The properties `keep_batchnorm_fp32` ,
            `cast_model_type` and `loss_scale_manager` determined by the `level` setting may be
            overwritten by settings in `kwargs` .

        boost_level (str): Option for the argument `level` in `mindspore.boost` , the level for boost
            mode training. Supports ["O0", "O1", "O2"]. Default: "O0".

            - O0: Do not change.
            - O1: Enable the boost mode; performance is improved by about 20% and accuracy is the
              same as the original accuracy.
            - O2: Enable the boost mode; performance is improved by about 30% and accuracy is
              reduced by less than 3%.

            If O1 or O2 mode is set, the boost related library will take effect automatically.

        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` .
            If set, the network will be cast to `cast_model_type` , overriding the type determined
            by the `level` setting.
        keep_batchnorm_fp32 (bool): Keep BatchNorm running in `float32` when the network is cast to
            `float16` . If set, the `level` setting will take no effect on this property.
        loss_scale_manager (Union[None, LossScaleManager]): If None, do not scale the loss; otherwise
            scale the loss by `LossScaleManager` . If set, the `level` setting will take no effect on
            this property.

    Raises:
        ValueError: If `level` is set to "auto" on a device other than GPU or Ascend, since auto
            mixed precision is only supported on GPU and Ascend.
        ValueError: If the device is CPU and `loss_scale_manager` is set to anything other than `None`
            or `FixedLossScaleManager` (with `drop_overflow_update=False` ).
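
    Examples:
        A minimal usage sketch; `Net` is a user-defined network and the optimizer
        settings below are illustrative placeholders.

        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        >>> opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> train_net = build_train_network(net, opt, loss, level="O2")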
159 """ 160 validator.check_value_type('network', network, nn.Cell) 161 validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt)) 162 if not isinstance(level, str): 163 raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \ 164 but got type {}.".format(type(level))) 165 validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN) 166 validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN) 167 168 if level == "auto": 169 device_target = context.get_context('device_target') 170 if device_target == "GPU": 171 level = "O2" 172 elif device_target == "Ascend": 173 level = "O3" 174 else: 175 raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.") 176 177 _check_kwargs(kwargs) 178 config = dict(_config_level[level], **kwargs) 179 180 if config["cast_model_type"] == mstype.float16: 181 network.to_float(mstype.float16) 182 183 if config["keep_batchnorm_fp32"]: 184 _do_keep_batchnorm_fp32(network) 185 186 if loss_fn: 187 network = _add_loss_network(network, loss_fn, config["cast_model_type"]) 188 189 if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): 190 network = _VirtualDatasetCell(network) 191 192 enable_boost = False 193 if boost_level in ["O1", "O2"]: 194 enable_boost = True 195 196 loss_scale = 1.0 197 if config["loss_scale_manager"] is not None: 198 loss_scale_manager = config["loss_scale_manager"] 199 loss_scale = loss_scale_manager.get_loss_scale() 200 update_cell = loss_scale_manager.get_update_cell() 201 if update_cell is not None: 202 # only cpu not support `TrainOneStepWithLossScaleCell` for control flow. 203 if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU": 204 raise ValueError("Only `loss_scale_manager=None` or " 205 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`" 206 "are supported on device `CPU`. ") 207 if _get_pipeline_stages() > 1: 208 network = _TrainPipelineWithLossScaleCell(network, optimizer, 209 scale_sense=update_cell).set_train() 210 elif enable_boost: 211 network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer, 212 scale_sense=update_cell).set_train() 213 else: 214 network = nn.TrainOneStepWithLossScaleCell(network, optimizer, 215 scale_sense=update_cell).set_train() 216 return network 217 if _get_pipeline_stages() > 1: 218 network = _TrainPipelineAccuStepCell(network, optimizer).set_train() 219 elif enable_boost: 220 network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train() 221 else: 222 network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train() 223 return network 224