• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Auto mixed precision."""
16from __future__ import absolute_import
17import inspect
18import types
19
20import mindspore as ms
21from mindspore import nn
22from mindspore import _checkparam as validator
23from mindspore.common import dtype as mstype
24from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
25from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
26from mindspore.ops import functional as F
27from mindspore.parallel._utils import _get_pipeline_stages
28from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
29from mindspore import boost, context
30from mindspore.ops import operations as P
31from mindspore.ops import Primitive
32from mindspore import log as logger
33
34
# Cell and Primitive classes handed to the rewrite pass as `white_list` when
# `auto_mixed_precision` is called with amp_level "O1". `get_white_list` returns
# a copy of this list.
AMP_WHITE_LIST = [
    # nn cells
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose,
    nn.Dense,
    nn.LSTMCell,
    nn.RNNCell,
    nn.GRUCell,
    # primitives
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    P.Conv2DBackpropInput,
    P.MatMul,
    P.BatchMatMul,
    P.PReLU,
    P.ReLU,
    P.Ger
]
57
58
# Cell classes kept in float32 at amp_level "O2": they are passed as `black_list`
# to the rewrite pass or to `_auto_black_list`, and are also the classes that
# `_do_keep_batchnorm_fp32` wraps. `get_black_list` returns a copy of this list.
AMP_BLACK_LIST = [
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm
]
65
# Types listed here are rejected by both `_operator_need_cast` and `_net_need_cast`,
# so the rewrite-based pass never inserts casts for them (this includes O2/O3).
_INNER_AMP_BLACK_LIST = []

# When True, amp levels "O2" and "O3" are implemented through the rewrite pass
# (`_auto_mixed_precision_rewrite`) instead of `Cell.to_float` plus wrapper cells.
MS_AMP_BY_REWRITE = False
70
71
def amp_cast(value, dtype):
    """Cast a floating-point tensor to ``dtype``; hand back any other value untouched.

    This is the function the rewrite pass inserts around nodes during auto mixed precision.
    """
    is_float_tensor = isinstance(value, ms.Tensor) and value.dtype in mstype.float_type
    if not is_float_tensor:
        return value
    cast_op = P.Cast()
    return cast_op(value, dtype)

_amp_cast_op = amp_cast
79
80
class _OutputTo16(nn.Cell):
    """Amp wrapper cell: run the wrapped cell and cast its output to ``dtype`` (float16 by default)."""

    def __init__(self, backbone, dtype=mstype.float16):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self.dtype = dtype
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        output = self._backbone(*args, **kwargs)
        return F.cast(output, self.dtype)
91
92
class _OutputTo32(nn.Cell):
    """Amp wrapper cell: run the wrapped cell and cast its output back to float32."""

    def __init__(self, backbone):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        return F.mixed_precision_cast(mstype.float32, self._backbone(*args, **kwargs))
103
104
def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether the current node is an operator that needs to be cast.

    The node is rejected unless all of these hold:
        1) The node type is CallPrimitive and the instance type is a Primitive subclass.
        2) The instance type is not a subclass of P.Cast.
        3) The instance type is not in _INNER_AMP_BLACK_LIST.
    Once those hold, the node is accepted when any of these is true:
        4) force_cast is True, which means one of the upper layer cells is being cast.
        5) white_list is given and the instance type is in white_list.
        6) black_list is given and the instance type is not in black_list.
    """
    if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
        return False
    op_type = node.get_instance_type()
    if not (inspect.isclass(op_type) and issubclass(op_type, Primitive)):
        return False
    if issubclass(op_type, P.Cast) or op_type in _INNER_AMP_BLACK_LIST:
        return False
    if force_cast:
        return True
    return (white_list is not None and op_type in white_list) or \
        (black_list is not None and op_type not in black_list)
131
132
def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
    """Return True when any of the ``fp32``/``fp16``/``bf16`` attributes of the cell exists and is truthy."""
    return any(getattr(cell_inst, flag, False) for flag in ("fp32", "fp16", "bf16"))
139
140
def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether the current node is a sub-network (tree) that needs to be cast.

    The node is rejected unless all of these hold:
        1) The node type is Tree and the instance type is a Cell subclass.
        2) The instance type is not in _INNER_AMP_BLACK_LIST.
        3) No precision flag (fp32/fp16/bf16) is set on the cell instance.
    Once those hold, the node is accepted when any of these is true:
        4) force_cast is True, which means one of the upper layer networks is being cast.
        5) white_list is given and the instance type is in white_list.
        6) black_list is given and the instance type is not in black_list.
    """
    if node.get_node_type() != ms.rewrite.NodeType.Tree:
        return False
    cell_type = node.get_instance_type()
    if not (inspect.isclass(cell_type) and issubclass(cell_type, nn.Cell)):
        return False
    if cell_type in _INNER_AMP_BLACK_LIST:
        return False
    if _precision_set_by_user(node.get_instance()):
        return False
    if force_cast:
        return True
    return (white_list is not None and cell_type in white_list) or \
        (black_list is not None and cell_type not in black_list)
168
169
def _insert_cast_for_operator(node, dtype):
    """Surround ``node`` with a cast pair: fp16/bf16 casts on its inputs, float32 casts on its outputs."""
    low_precision_name = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    stree = node.get_symbol_tree()

    # Inputs: feed every named argument through a low-precision cast placed before the node.
    for idx, arg in enumerate(node.get_args()):
        if arg.type != ms.rewrite.ValueType.NamingValue:
            continue
        incast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, low_precision_name],
                                                                [arg.scope, "mindspore"])
        providers = node.get_arg_providers()
        # A fresh target name is needed unless this node is the only user of the provider's target.
        needs_new_name = True
        if providers and idx in providers:
            provider = providers[idx]
            needs_new_name = len(provider[0].get_target_users(provider[1])) > 1
        if needs_new_name:
            incast_targets = [stree.unique_name(f"{arg.value}_var")]
        else:
            incast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
        incast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=incast_targets, args=incast_args)
        stree.insert(stree.before(node), incast_node)
        node.set_arg_by_node(idx, incast_node)

    # Outputs: cast every named target back to float32 right after the node, reusing the target name.
    for target in node.get_targets():
        if target.type != ms.rewrite.ValueType.NamingValue:
            continue
        outcast_args = ms.rewrite.ScopedValue.create_name_values([target.value, "float32"],
                                                                 [target.scope, "mindspore"])
        outcast_targets = ms.rewrite.ScopedValue.create_name_values([target.value], [target.scope])
        outcast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=outcast_targets, args=outcast_args)
        stree.insert(stree.after(node), outcast_node)
198
199
def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
    """Walk ``stree`` and insert cast pairs around every operator selected by the white/black list.

    Sub-networks are processed recursively; ``force_cast`` is propagated to a subtree once
    the enclosing network itself has been selected for casting.
    """
    # all_nodes(False) yields the nodes of this tree only, excluding nodes in subtrees.
    for node in stree.all_nodes(False):
        if not node.get_targets():
            continue
        if _operator_need_cast(node, force_cast, white_list, black_list):
            _insert_cast_for_operator(node, dtype)
            continue
        if node.get_node_type() != ms.rewrite.NodeType.Tree:
            continue
        sub_force_cast = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
        # Sub-networks whose precision flags are already set are not descended into.
        if _precision_set_by_user(node.get_instance()):
            continue
        _insert_cast_for_operators(node.get_sub_tree(), dtype, sub_force_cast,
                                   white_list=white_list, black_list=black_list)
214
215
def _need_removed_cast_pair(node, dtype):
    """Tell whether ``node`` is a float32 cast whose users are all fp16/bf16 casts, i.e. a removable pair.

    Returns False when ``node`` is not a float32 amp cast, when it has no users, when any
    user is not a low-precision amp cast, or when a ControlFlow node lies between ``node``
    and one of its users.
    """
    low_precision_name = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    dtype_values = ms.rewrite.ScopedValue.create_name_values([low_precision_name, "float32"],
                                                             ["mindspore", "mindspore"])
    low_dtype_value = dtype_values[0]
    fp32_dtype_value = dtype_values[1]

    # The node itself has to be a cast to float32.
    if node.get_instance_type() != _amp_cast_op:
        return False
    if node.get_args()[1] != fp32_dtype_value:
        return False

    # Every user has to be a cast to fp16/bf16.
    users = node.get_users()
    if not users:
        return False
    ordered_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
    for user in users:
        # If a ControlFlow node (e.g. if, for, while) exists between the current node and
        # the user node, the cast pair must not be removed.
        between = ordered_nodes[ordered_nodes.index(node): ordered_nodes.index(user)]
        for middle in between:
            if middle.get_node_type() == ms.rewrite.NodeType.ControlFlow:
                return False
        if user.get_instance_type() != _amp_cast_op:
            return False
        if user.get_args()[1] != low_dtype_value:
            return False
    return True
246
247
def _remove_duplicated_cast(stree, dtype):
    """Erase back-to-back cast pairs (float32 cast directly followed by fp16/bf16 casts) from ``stree``."""
    # Materialize the node list first: nodes are erased from the tree while looping.
    for fp32_cast in list(stree.nodes(all_nodes=True)):
        if not _need_removed_cast_pair(fp32_cast, dtype):
            continue
        # Drop the fp16/bf16 cast nodes that consume the float32 cast.
        for low_cast in fp32_cast.get_users():
            # get_target_users() returns {target0: [(user0, arg_idx), ...], ...}
            users_per_target = list(low_cast.get_target_users().values())
            if not users_per_target or not users_per_target[0]:
                continue
            # Rewire each consumer to the value that was fed into the low-precision cast.
            for consumer, arg_idx in users_per_target[0]:
                consumer.set_arg(arg_idx, low_cast.get_args()[0])
            stree.erase(low_cast)
        # Finally drop the float32 cast itself.
        stree.erase(fp32_cast)
265
266
def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
    """Implement auto mixed precision by rewrite.

    Args:
        network (Cell): Network to convert.
        dtype: Lower-precision type used for the inserted input casts.
        white_list (list, optional): Types to cast; mutually exclusive with `black_list`.
        black_list (list, optional): Types to keep uncast; mutually exclusive with `white_list`.

    Returns:
        The network produced by the symbol tree after casts are inserted and
        redundant cast pairs are removed.

    Raises:
        ValueError: If neither or both of `white_list` and `black_list` are provided.
    """
    # Exactly one of the two lists has to be given.
    if (white_list is None) == (black_list is None):
        raise ValueError("For _auto_mixed_precision_rewrite, one of white_list and black_list must be provided.")
    # enable rewrite configs for amp
    ms.rewrite.common.namespace._ms_cells_to_subtree = True
    ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
    try:
        # insert casts by rewrite
        stree = ms.rewrite.SymbolTree.create(network)
        _insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
        _remove_duplicated_cast(stree, dtype)
        return stree.get_network()
    finally:
        # These flags are process-wide; reset them even when the rewrite fails so that a
        # failed conversion does not leave rewrite in amp mode for later callers.
        ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
        ms.rewrite.common.namespace._ms_cells_to_subtree = False
        ms.rewrite.common.config.clear_caches()
284
285
def _auto_black_list(network, black_list, dtype):
    """Cast ``network`` to ``dtype`` while keeping cells of the black-listed classes in float32.

    Each black-listed sub-cell is set to float32 and wrapped in ``_OutputTo16`` so that its
    output is cast to ``dtype``; every other sub-cell is processed recursively.
    Returns the same ``network`` object.
    """
    network.to_float(dtype)
    replaced = False
    for name, subcell in network.name_cells().items():
        if subcell == network:
            continue
        if not isinstance(subcell, tuple(black_list)):
            _auto_black_list(subcell, black_list, dtype)
            continue
        network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
        replaced = True
    # SequentialCell keeps its own cell_list; rebuild it after sub-cells were swapped.
    if replaced and isinstance(network, nn.SequentialCell):
        network.cell_list = list(network.cells())
    return network
303
304
def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
    """
    Returns a network processed with auto mixed precision.

    This interface will automatically perform mixed-precision processing on the input network, and the cells
    and operators in the processed network will add precision conversion operations to calculate with lower
    precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
    converted to lower precision float, and calculation results are converted back to full precision float,
    i.e. ``mstype.float32`` .

    The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells and
    operators are specifically converted.

    The current built-in whitelist contents are:

    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
    :class:`mindspore.ops.Conv3DTranspose`, ``mindspore.ops.Conv2DBackpropInput`` ,
    :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]

    The current built-in blacklist contents are:

    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
    :class:`mindspore.nn.LayerNorm`]

    For details on automatic mixed precision, refer to
    `Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ .

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train the network which is converted by
          mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
          needs to be configured to ``O0`` to avoid the duplicated accuracy conversion.

    Args:
        network (Cell): Definition of the network.
        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .

            - "O0": Do not change.
            - "O1": Convert cells and operators in whitelist to lower precision operations, and keep full
              precision operations for the rest.
            - "O2": Keep full precision operations for cells and operators in blacklist, and convert the rest
              to lower precision operations.
            - "O3": Cast network to lower precision.

        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or ``mstype.bfloat16`` ,
            default: ``mstype.float16`` .

    Returns:
        Cell, the input network itself for ``"O0"`` , otherwise the converted network.

    Raises:
        TypeError: If `network` is not a Cell.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
        ValueError: If `amp_level` is not within the supported range.

    Examples:
        >>> from mindspore import amp
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> amp_level = "O1"
        >>> net = amp.auto_mixed_precision(network, amp_level)
    """
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    if dtype not in (mstype.float16, mstype.bfloat16):
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    if amp_level == "O0":
        return network

    # Warn (without returning) when the network has already been converted by a
    # mixed-precision interface. A default is supplied so that a network lacking the
    # `_amp_level` attribute is treated as "not converted yet" instead of raising.
    previous_level = getattr(network, "_amp_level", None)
    if previous_level in ("O1", "O2", "O3"):
        logger.warning(f"The network's auto mixed-precision level is adjusted from {previous_level} "
                       f"to {amp_level}, and repeated calls to mixed-precision interfaces can cause performance "
                       f"degradation.")

    if amp_level == "O1":
        network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
    elif amp_level == "O2":
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
        else:
            network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
            network = _OutputTo32(network)
    elif amp_level == "O3":
        if MS_AMP_BY_REWRITE:
            # An empty black list means nothing is excluded from casting.
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
        else:
            network.to_float(dtype)
            network = _OutputTo32(network)
    else:
        raise ValueError("The amp level {} is not supported".format(amp_level))

    # Record the applied level so that a repeated conversion can be detected above.
    setattr(network, "_amp_level", amp_level)

    return network
405
406
def _do_keep_batchnorm_fp32(network):
    """Keep cells of the AMP_BLACK_LIST classes in float32 inside ``network``, in place.

    Each matching sub-cell is set to float32 and wrapped in ``_OutputTo16``; every other
    sub-cell is processed recursively.
    """
    replaced = False
    for name, subcell in network.name_cells().items():
        if subcell == network:
            continue
        if isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            replaced = True
            continue
        _do_keep_batchnorm_fp32(subcell)
    # SequentialCell keeps its own cell_list; rebuild it after sub-cells were swapped.
    if replaced and isinstance(network, nn.SequentialCell):
        network.cell_list = list(network.cells())
422
423
# Default `build_train_network` settings for each amp level. `build_train_network`
# copies the entry for the chosen level and overlays the caller's kwargs on top.
# NOTE(review): the "O2" DynamicLossScaleManager instance is created once at import
# time and is shared by every call that relies on the O2 defaults - confirm that
# sharing a single manager object is intended.
_config_level = {
    "O0": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O1": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O2": {
        "keep_batchnorm_fp32": True,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": DynamicLossScaleManager()},
    "O3": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": None}}
441
442
443def _check_kwargs(key_words):
444    """Check kwargs."""
445    for arg in key_words:
446        if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
447            raise ValueError(f"Unsupported arg '{arg}'")
448
449    if 'cast_model_type' in key_words:
450        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
451                                  [mstype.float16, mstype.float32], None)
452    if 'keep_batchnorm_fp32' in key_words:
453        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
454    if 'loss_scale_manager' in key_words:
455        loss_scale_manager = key_words['loss_scale_manager']
456        if loss_scale_manager:
457            validator.check_value_type('loss_scale_manager', loss_scale_manager,
458                                       [LossScaleManager, boost.GroupLossScaleManager])
459
460
461def _check_level(level, boost_level):
462    """Check level."""
463    if not isinstance(level, str):
464        raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], \
465                         but got type {}.".format(type(level)))
466    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
467    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)
468
469    if level == "auto":
470        device_target = context.get_context('device_target')
471        if device_target == "GPU":
472            level = "O2"
473        elif device_target == "Ascend":
474            level = "O3"
475        else:
476            raise ValueError("Level `auto` only support when `device_target` is GPU or Ascend.")
477
478    enable_boost = False
479    if boost_level in ["O1", "O2"]:
480        enable_boost = True
481
482    return level, enable_boost
483
484
def _add_loss_network(network, loss_fn, cast_model_type):
    """Attach ``loss_fn`` to ``network`` and return the combined cell.

    When ``cast_model_type`` is float16 a local wrapper is used that casts both the network
    output and the label to float32 before the loss is computed; otherwise the stock
    ``nn.WithLossCell`` is used.
    """

    class WithLossCell(nn.Cell):
        """Loss wrapper for amp: the backbone output and the label are cast to float32 for the loss."""

        def __init__(self, backbone, loss_fn):
            super().__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn
            self._get_attr_from_cell(backbone)

        def construct(self, data, label):
            out = self._backbone(data)
            fp32_label = F.mixed_precision_cast(mstype.float32, label)
            fp32_out = F.mixed_precision_cast(mstype.float32, out)
            return self._loss_fn(fp32_out, fp32_label)

    validator.check_value_type('loss_fn', loss_fn, nn.Cell)
    wrapper_cls = WithLossCell if cast_model_type == mstype.float16 else nn.WithLossCell
    return wrapper_cls(network, loss_fn)
507
508
509def _is_grad_accumulation(mcell):
510    if mcell.cls_name == "GradAccumulationCell":
511        return True
512    for cell in mcell.cells():
513        if _is_grad_accumulation(cell):
514            return True
515    return False
516
517
def _auto_mixed_precision_process(network, config, level):
    """Apply mixed precision to ``network`` according to ``config`` and ``level``; return the result.

    ``config`` supplies "cast_model_type" and "keep_batchnorm_fp32". With MS_AMP_BY_REWRITE
    enabled the settings are mapped onto an amp level and delegated to
    ``auto_mixed_precision``; otherwise the network is cast with ``to_float`` where the
    config asks for float16.
    """
    cast_type = config["cast_model_type"]

    if MS_AMP_BY_REWRITE:
        if cast_type == mstype.float16 or level == "O2":
            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
        elif cast_type == mstype.float32 and level in ("O2", "O3"):
            # cast_model_type set by kwargs
            level = "O0"
        return auto_mixed_precision(network, level)

    if cast_type == mstype.float16:
        network.to_float(mstype.float16)
        if config["keep_batchnorm_fp32"]:
            _do_keep_batchnorm_fp32(network)
        return network
    if not config["keep_batchnorm_fp32"] and level == "O2":
        network.to_float(mstype.float16)
        return network
    if cast_type == mstype.float32 and level in ("O2", "O3"):
        # float32 requested through kwargs for an O2/O3 level: leave the network as is.
        return network
    return auto_mixed_precision(network, level)
540
541
def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.

    Note:
        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
          to perform the precision conversion again. If  `build_train_network` is used to train a converted network,
          `level` need to be configured to ``O0`` to avoid the duplicated accuracy conversion.

    Args:
        network (Cell): Definition of the network.
        optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
        loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss inside.
            Default: ``None`` .
        level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .

            - 'O0': Do not change.
            - 'O1': Cast the operators in white_list to float16, the remaining operators are kept in float32.
              The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
              Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
            - 'O2': Cast network to float16, keep `mindspore.nn.BatchNorm` series interface,
              :class:`mindspore.nn.LayerNorm` and `loss_fn` (if set) run in float32, using dynamic loss scale.
            - 'O3': Cast network to float16, with additional property `keep_batchnorm_fp32=False` .
            - 'auto': Set level to the recommended level for the current device. Set level to 'O2' on GPU, set
              level to 'O3' on Ascend. The recommended level is chosen by expert experience, not applicable to all
              scenarios. User should specify the level for special network.

            'O2' is recommended on GPU, 'O3' is recommended on Ascend. Property of `keep_batchnorm_fp32`,
            `cast_model_type` and `loss_scale_manager` determined by `level` setting may be overwritten by settings in
            `kwargs`.

        boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
            training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .

            - 'O0': Do not change.
            - 'O1': Enable the boost mode, the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
            - 'O2': Enable the boost mode, the performance is improved by about 30%, and
              the accuracy is reduced by less than 3%.

            If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.

        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set, the
            network will be casted to `cast_model_type` ( `mstype.float16` or `mstype.float32` ), but not to be casted
            to the type determined by `level` setting.
        keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32` when the network is set to cast to `float16` . If
            set, the `level` setting will take no effect on this property.
        loss_scale_manager (Union[None, LossScaleManager]): If not None, must be subclass of
            :class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
            take no effect on this property.

    Returns:
        Cell, the training cell wrapping `network` (and `loss_fn` if given) with `optimizer`, set to train mode.

    Raises:
        ValueError: If device is CPU, property `loss_scale_manager` is not `None` or
            :class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> amp_level="O3"
        >>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)
    """
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
                                                        nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))

    level, enable_boost = _check_level(level, boost_level)

    _check_kwargs(kwargs)
    # Start from the defaults of the chosen level and let explicit kwargs override them.
    config = dict(_config_level.get(level), **kwargs)

    network = _auto_mixed_precision_process(network, config, level)

    if loss_fn:
        network = _add_loss_network(network, loss_fn, config["cast_model_type"])

    loss_scale = None
    if config["loss_scale_manager"] is not None:
        loss_scale_manager = config["loss_scale_manager"]
        loss_scale = loss_scale_manager.get_loss_scale()
        update_cell = loss_scale_manager.get_update_cell()
        if update_cell is not None:
            # only cpu not support `TrainOneStepWithLossScaleCell` for control flow.
            if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                raise ValueError("Only `loss_scale_manager=None` or "
                                 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)` "
                                 "are supported on device `CPU`. ")
            if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
                network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                          scale_sense=update_cell).set_train()
            elif enable_boost:
                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
                                                                   scale_sense=update_cell).set_train()
            else:
                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            return network
    # No update cell: plain training cells, with `loss_scale` passed as a fixed value
    # (None when no loss scale manager is configured).
    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
        network = _TrainGradAccuStepCell(network, optimizer).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
    else:
        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
    return network
648
649
def get_white_list():
    """
    Provide a copy of internal white list used by auto mixed precision.

    The current built-in whitelist contents are:

    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
    :class:`mindspore.ops.Conv3DTranspose`, ``mindspore.ops.Conv2DBackpropInput`` ,
    :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]

    Returns:
        list, A copy of internal white list. Modifying it does not affect the internal list.

    Examples:
        >>> from mindspore import amp
        >>> white_list = amp.get_white_list()
        >>> print(white_list)
        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
         <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
         <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
         <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
         <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
         <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
         <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
         <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
         <class 'mindspore.ops.operations.math_ops.Ger'>]
    """
    return list(AMP_WHITE_LIST)
686
687
def get_black_list():
    """
    Return a copy of the internal black list used by auto mixed precision.

    The result is a new list object, so the caller may edit it freely without changing the
    internal black list.

    The current built-in blacklist contents are:

    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
    :class:`mindspore.nn.LayerNorm`]

    Returns:
        list, A copy of internal black list.

    Examples:
        >>> from mindspore import amp
        >>> black_list = amp.get_black_list()
        >>> print(black_list)
        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
         <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
    """
    # Shallow copy: the entries are shared, the container is not.
    return list(AMP_BLACK_LIST)
709
710
def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
    """
    Apply mixed precision to a network according to a custom whitelist or blacklist.

    When the `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
    When the `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
    Exactly one of `white_list` and `black_list` must be provided.

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
          can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train a network that has been converted
          by mixed-precision interfaces such as `custom_mixed_precision` and `auto_mixed_precision`, `amp_level`
          needs to be configured to ``O0`` to avoid the duplicated accuracy conversion.
        - Primitives in the blacklist are not supported yet.

    Args:
        network (Cell): Definition of the network.
        white_list (list[Primitive, Cell], optional): White list of custom mixed precision. Default: ``None`` ,
            means white list is not used.
        black_list (list[Cell], optional): Black list of custom mixed precision. Default: ``None`` , means
            black list is not used.
        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
            ``mstype.bfloat16`` , default: ``mstype.float16`` .

    Returns:
        network (Cell), A network supporting mixed precision.

    Raises:
        TypeError: The network type is not Cell.
        ValueError: Neither `white_list` nor `black_list` is provided.
        ValueError: Both `white_list` and `black_list` are provided.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> custom_white_list = amp.get_white_list()
        >>> custom_white_list.append(nn.Flatten)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
    """
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    has_white_list = white_list is not None
    has_black_list = black_list is not None

    if not (has_white_list or has_black_list):
        raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")

    if has_white_list and has_black_list:
        raise ValueError("For custom_mixed_precision, the white_list or black_list cannot be provided "
                         "at the same time, please provide one or the other.")

    supported_dtypes = (mstype.float16, mstype.bfloat16)
    if dtype not in supported_dtypes:
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    # Whitelist mode always goes through the rewrite-based conversion.
    if has_white_list:
        _list_check(white_list, "white_list")
        return _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)

    # Blacklist mode: the conversion strategy depends on the module-level MS_AMP_BY_REWRITE switch.
    _list_check(black_list, "black_list")
    if MS_AMP_BY_REWRITE:
        return _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)

    # Non-rewrite path: convert via _auto_black_list, then wrap the result in _OutputTo32.
    converted = _auto_black_list(network, black_list, dtype)
    return _OutputTo32(converted)
777
778
def _list_check(custom_list: list, list_name: str):
    """
    Validate a user-supplied white list or black list.

    Each element is checked in order, so the first offending element determines the error that is raised.
    For a black list, a warning is additionally logged for every entry of the internal black list that
    the custom list does not contain.

    Args:
        custom_list (list): The list to validate.
        list_name (str): Name of the list, ``"white_list"`` or ``"black_list"``. It selects which base classes
            are accepted and is used in the error messages.

    Raises:
        TypeError: The type of custom_list is not list.
        TypeError: The element in custom_list is not a class.
        TypeError: The element of a white list is not a subclass of 'Cell' or 'Primitive', or the element of a
            black list is not a subclass of 'Cell'.
    """
    if not isinstance(custom_list, list):
        raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")

    # The list kind does not change while iterating, so resolve it once.
    is_white_list = list_name == "white_list"
    is_black_list = list_name == "black_list"

    for candidate in custom_list:
        if not isinstance(candidate, type):
            raise TypeError(f"The element in {list_name} should be a class, but got {candidate}")

        # White list entries may be cells or primitives.
        if is_white_list and not issubclass(candidate, (nn.Cell, Primitive)):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
                            f"but got {candidate}")

        # Black list entries must be cells.
        if is_black_list and not issubclass(candidate, nn.Cell):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {candidate}")

    if is_black_list:
        # Tell the user which built-in black list entries their custom list leaves out.
        for builtin_cell in AMP_BLACK_LIST:
            if builtin_cell in custom_list:
                continue
            logger.warning(f"{builtin_cell} is removed from internal black list.")
806
807
def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None):  # pylint: disable=unused-variable
    """
    Configure auto mixed precision.

    Both arguments are keyword-only, and an argument left as ``None`` keeps the corresponding
    module-level setting unchanged.

    Args:
        enable_rewrite (bool, optional): New value for the module-level ``MS_AMP_BY_REWRITE`` switch.
        cast_op (types.FunctionType, optional): New value for the module-level ``_amp_cast_op``.
    """
    global MS_AMP_BY_REWRITE, _amp_cast_op

    # Only overwrite a setting when the caller passed a value for it.
    if enable_rewrite is not None:
        MS_AMP_BY_REWRITE = enable_rewrite
    if cast_op is not None:
        _amp_cast_op = cast_op
818