# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell, _TrainPipelineAccuStepCell
from ..nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode, _get_pipeline_stages
from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode
from .. import boost
from .. import context


class _OutputTo16(nn.Cell):
    """Wrap a cell for amp; cast the wrapped cell's output back to float16."""

    def __init__(self, op):
        super(_OutputTo16, self).__init__(auto_prefix=False)
        self._op = op

    def construct(self, x):
        return F.cast(self._op(x), mstype.float16)


def _do_keep_batchnorm_fp32(network):
    """Recursively keep batchnorm cells running in float32, casting their outputs back to float16."""
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
            _do_keep_batchnorm_fp32(subcell)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())
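
# A minimal sketch of the intended effect, assuming a small user-defined network
# (illustrative only; not executed by this module):
#
#     net = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)])
#     net.to_float(mstype.float16)     # cast the whole network to float16
#     _do_keep_batchnorm_fp32(net)     # batchnorm now computes in float32, and its
#                                      # output is cast back to float16 by _OutputTo16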

_config_level = {
    "O0": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O2": {
        "keep_batchnorm_fp32": True,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": DynamicLossScaleManager()},
    "O3": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": None}}


def _check_kwargs(key_words):
    """Check keyword arguments passed to `build_train_network`."""
    for arg in key_words:
        if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
            raise ValueError(f"Unsupported arg '{arg}'")

    if 'cast_model_type' in key_words:
        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
                                  [mstype.float16, mstype.float32], None)
    if 'keep_batchnorm_fp32' in key_words:
        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
    if 'loss_scale_manager' in key_words:
        loss_scale_manager = key_words['loss_scale_manager']
        if loss_scale_manager:
            validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager)
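
# For instance (illustrative): `_check_kwargs({'cast_model_type': mstype.float16})`
# passes silently, while `_check_kwargs({'amp_level': 'O2'})` raises
# ValueError("Unsupported arg 'amp_level'").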


def _add_loss_network(network, loss_fn, cast_model_type):
    """Wrap the network with the given loss function."""

    class WithLossCell(nn.Cell):
        """Wrap loss for amp; cast the network output back to float32 before computing the loss."""

        def __init__(self, backbone, loss_fn):
            super(WithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn

        def construct(self, data, label):
            out = self._backbone(data)
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

    validator.check_value_type('loss_fn', loss_fn, nn.Cell)
    if cast_model_type == mstype.float16:
        network = WithLossCell(network, loss_fn)
    else:
        network = nn.WithLossCell(network, loss_fn)
    return network
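
# A minimal sketch of the wrapping behavior, assuming a user-defined backbone `net`
# (illustrative; `nn.SoftmaxCrossEntropyWithLogits` is the regular MindSpore loss):
#
#     net_with_loss = _add_loss_network(net, nn.SoftmaxCrossEntropyWithLogits(),
#                                       mstype.float16)
#     # The float16 network output and the label are cast to float32 first, so the
#     # loss itself is always computed in float32.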


def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.

    Args:
        network (Cell): Definition of the network.
        optimizer (Optimizer): Optimizer to update the parameters.
        loss_fn (Union[None, Cell]): Definition of the loss function. If None, the `network` should have the loss
            inside. Default: None.
        level (str): Supports ["O0", "O2", "O3", "auto"]. Default: "O0".

            - O0: Do not change.
            - O2: Cast the network to float16, keep batchnorm and `loss_fn` (if set) running in float32,
              and use dynamic loss scaling.
            - O3: Cast the network to float16, with the additional property `keep_batchnorm_fp32=False` .
            - auto: Set the level to the recommended one for the device: O2 on GPU, O3 on Ascend.
              The recommended level is chosen by expert experience and is not suitable for every network;
              users should specify the level explicitly for special networks.

            O2 is recommended on GPU and O3 is recommended on Ascend. The properties `keep_batchnorm_fp32` ,
            `cast_model_type` and `loss_scale_manager` determined by the `level` setting may be overwritten
            by settings in `kwargs` .

        boost_level (str): Option for the argument `level` in `mindspore.boost` , the level for boost mode
            training. Supports ["O0", "O1", "O2"]. Default: "O0".

            - O0: Do not change.
            - O1: Enable the boost mode; performance is improved by about 20% and
              accuracy is the same as the original accuracy.
            - O2: Enable the boost mode; performance is improved by about 30% and
              accuracy is reduced by less than 3%.

            If O1 or O2 is set, the boost related library will take effect automatically.

        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
            network will be cast to `cast_model_type` rather than to the type determined by the `level` setting.
        keep_batchnorm_fp32 (bool): Keep batchnorm running in `float32` when the network is cast to `float16` . If
            set, the `level` setting takes no effect on this property.
        loss_scale_manager (Union[None, LossScaleManager]): If None, the loss is not scaled; otherwise the loss is
            scaled by `LossScaleManager` . If set, the `level` setting takes no effect on this property.

    Raises:
        ValueError: If the device is CPU and `level` is "auto", because auto mixed precision is only supported
            on GPU and Ascend.
        ValueError: If the device is CPU and `loss_scale_manager` is set to anything other than `None` or
            `FixedLossScaleManager` (with `drop_overflow_update=False` ).
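
    Examples:
        A minimal usage sketch. `MyNet` is a placeholder for a user-defined Cell, and the optimizer
        hyper-parameters are illustrative assumptions.

        >>> net = MyNet()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
        >>> opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> train_net = build_train_network(net, opt, loss, level="O2")
        >>> # train_net(data, label) performs one training step with a float16 forward
        >>> # pass, a float32 loss, and dynamic loss scaling.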
    """
    validator.check_value_type('network', network, nn.Cell)
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt))
    if not isinstance(level, str):
        raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], "
                        "but got type {}.".format(type(level)))
    validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)

    if level == "auto":
        device_target = context.get_context('device_target')
        if device_target == "GPU":
            level = "O2"
        elif device_target == "Ascend":
            level = "O3"
        else:
            raise ValueError("Level `auto` is only supported when `device_target` is GPU or Ascend.")

    _check_kwargs(kwargs)
    config = dict(_config_level[level], **kwargs)

    if config["cast_model_type"] == mstype.float16:
        network.to_float(mstype.float16)

        if config["keep_batchnorm_fp32"]:
            _do_keep_batchnorm_fp32(network)

    if loss_fn:
        network = _add_loss_network(network, loss_fn, config["cast_model_type"])

    if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
        network = _VirtualDatasetCell(network)

    enable_boost = boost_level in ("O1", "O2")

    loss_scale = 1.0
    if config["loss_scale_manager"] is not None:
        loss_scale_manager = config["loss_scale_manager"]
        loss_scale = loss_scale_manager.get_loss_scale()
        update_cell = loss_scale_manager.get_update_cell()
        if update_cell is not None:
            # CPU (without GE) does not support `TrainOneStepWithLossScaleCell`, which relies on control flow.
            if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                raise ValueError("Only `loss_scale_manager=None` or "
                                 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)` "
                                 "are supported on device `CPU`.")
            if _get_pipeline_stages() > 1:
                network = _TrainPipelineWithLossScaleCell(network, optimizer,
                                                          scale_sense=update_cell).set_train()
            elif enable_boost:
                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
                                                                   scale_sense=update_cell).set_train()
            else:
                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            return network
    if _get_pipeline_stages() > 1:
        network = _TrainPipelineAccuStepCell(network, optimizer).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
    else:
        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
    return network