# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
from __future__ import absolute_import
import inspect
import types

import mindspore as ms
from mindspore import nn
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import _TrainGradAccuStepCell
from mindspore.nn.wrap.loss_scale import _TrainGradAccuWithLossScaleCell
from mindspore.ops import functional as F
from mindspore.parallel._utils import _get_pipeline_stages
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from mindspore import boost, context
from mindspore.ops import operations as P
from mindspore.ops import Primitive
from mindspore import log as logger


AMP_WHITE_LIST = [
    nn.Conv1d,
    nn.Conv2d,
    nn.Conv3d,
    nn.Conv1dTranspose,
    nn.Conv2dTranspose,
    nn.Conv3dTranspose,
    nn.Dense,
    nn.LSTMCell,
    nn.RNNCell,
    nn.GRUCell,
    P.Conv2D,
    P.Conv3D,
    P.Conv2DTranspose,
    P.Conv3DTranspose,
    P.Conv2DBackpropInput,
    P.MatMul,
    P.BatchMatMul,
    P.PReLU,
    P.ReLU,
    P.Ger
]


AMP_BLACK_LIST = [
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.LayerNorm
]

# Primitives in inner amp black list will not be converted in O2/O3
_INNER_AMP_BLACK_LIST = []

MS_AMP_BY_REWRITE = False


def amp_cast(value, dtype):
    """This function is used to insert cast operators for tensors during auto mixed precision."""
    if isinstance(value, ms.Tensor) and value.dtype in mstype.float_type:
        return P.Cast()(value, dtype)
    return value

_amp_cast_op = amp_cast


class _OutputTo16(nn.Cell):
    """Wrap cell for amp. Cast network output back to float16."""
    def __init__(self, backbone, dtype=mstype.float16):
        super(_OutputTo16, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self.dtype = dtype
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        return F.cast(self._backbone(*args, **kwargs), self.dtype)


class _OutputTo32(nn.Cell):
    """Wrap loss for amp. Cast network output back to float32."""
    def __init__(self, backbone):
        super(_OutputTo32, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._get_attr_from_cell(backbone)

    def construct(self, *args, **kwargs):
        out = self._backbone(*args, **kwargs)
        return F.mixed_precision_cast(mstype.float32, out)
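

# Illustrative sketch (comments only, not executed): the rewrite passes below surround
# selected operators with `amp_cast` pairs, so a call such as
#
#     out = self.matmul(x, w)
#
# is effectively rewritten to
#
#     x = amp_cast(x, mindspore.float16)
#     w = amp_cast(w, mindspore.float16)
#     out = self.matmul(x, w)
#     out = amp_cast(out, mindspore.float32)
#
# `_OutputTo16` and `_OutputTo32` play the same role at cell granularity: a black-list
# cell kept in float32 is wrapped with `_OutputTo16` so its output re-enters the lower
# precision graph, and the whole network is wrapped with `_OutputTo32` so the final
# output is float32 again.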


def _operator_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether the current node is an operator that needs to be cast.
    Conditions 1) and 2) must be satisfied, along with at least one of 3), 4) and 5):
    1) Type of node is CallPrimitive and type of instance is Primitive
    2) Type of instance is not P.Cast
    3) force_cast is True, which means one of the upper-layer cells is under casting
    4) white_list exists and type of node is in white_list
    5) black_list exists and type of node is not in black_list
    """
    if node.get_node_type() != ms.rewrite.NodeType.CallPrimitive:
        return False
    if not inspect.isclass(node.get_instance_type()):
        return False
    if not issubclass(node.get_instance_type(), Primitive):
        return False
    if issubclass(node.get_instance_type(), P.Cast):
        return False
    if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
        return False
    if force_cast:
        return True
    if white_list is not None and node.get_instance_type() in white_list:
        return True
    if black_list is not None and node.get_instance_type() not in black_list:
        return True
    return False


def _precision_set_by_user(cell_inst: nn.Cell) -> bool:
    """Check whether the cell's precision was set by the user."""
    for flag in ["fp32", "fp16", "bf16"]:
        if hasattr(cell_inst, flag) and getattr(cell_inst, flag):
            return True
    return False


def _net_need_cast(node, force_cast: bool, white_list=None, black_list=None) -> bool:
    """
    Check whether the current node is a Tree node whose network needs to be cast.
    Conditions 1) and 2) must be satisfied, along with at least one of 3), 4) and 5):
    1) Type of node is Tree and type of instance is Cell
    2) Cell.to_float(xxx) is not set by user
    3) force_cast is True, which means one of the upper-layer networks is under casting
    4) white_list exists and type of node is in white_list
    5) black_list exists and type of node is not in black_list
    """
    if node.get_node_type() != ms.rewrite.NodeType.Tree:
        return False
    if not inspect.isclass(node.get_instance_type()):
        return False
    if not issubclass(node.get_instance_type(), nn.Cell):
        return False
    if node.get_instance_type() in _INNER_AMP_BLACK_LIST:
        return False
    if _precision_set_by_user(node.get_instance()):
        return False
    if force_cast:
        return True
    if white_list is not None and node.get_instance_type() in white_list:
        return True
    if black_list is not None and node.get_instance_type() not in black_list:
        return True
    return False
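

# Illustrative behavior of the checks above (a sketch, not executed):
# - white-list mode (O1): a P.MatMul node is cast because it is in AMP_WHITE_LIST,
#   while a primitive outside the list is left untouched.
# - black-list mode (O2): any primitive not in the black list is cast; an nn.LayerNorm
#   subtree is skipped because its type is in AMP_BLACK_LIST.
# - once a whole cell qualifies, `force_cast=True` propagates the decision to every
#   operator inside its subtree.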


def _insert_cast_for_operator(node, dtype):
    """Insert a cast pair around the node: cast inputs to fp16/bf16 and cast outputs back to fp32."""
    dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    stree = node.get_symbol_tree()
    # insert cast fp16/bf16 for inputs of node
    for idx, arg in enumerate(node.get_args()):
        if arg.type != ms.rewrite.ValueType.NamingValue:
            continue
        incast_args = ms.rewrite.ScopedValue.create_name_values([arg.value, dtype_str], [arg.scope, "mindspore"])
        arg_providers = node.get_arg_providers()
        if not arg_providers or idx not in arg_providers or \
                len(arg_providers[idx][0].get_target_users(arg_providers[idx][1])) > 1:
            # create a new target name when the argument is used by other nodes
            incast_targets = [stree.unique_name(f"{arg.value}_var")]
        else:
            incast_targets = ms.rewrite.ScopedValue.create_name_values([arg.value], [arg.scope])
        incast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=incast_targets, args=incast_args)
        stree.insert(stree.before(node), incast_node)
        node.set_arg_by_node(idx, incast_node)
    # insert cast fp32 for outputs of node
    for _, target in enumerate(node.get_targets()):
        if target.type != ms.rewrite.ValueType.NamingValue:
            continue
        outcast_args = ms.rewrite.ScopedValue.create_name_values([target.value, "float32"],
                                                                 [target.scope, "mindspore"])
        outcast_targets = ms.rewrite.ScopedValue.create_name_values([target.value], [target.scope])
        outcast_node = ms.rewrite.Node.create_call_function(_amp_cast_op, targets=outcast_targets, args=outcast_args)
        stree.insert(stree.after(node), outcast_node)


def _insert_cast_for_operators(stree, dtype, force_cast, *, white_list=None, black_list=None):
    """Insert casts for operators selected by white_list/black_list."""
    # get all nodes of stree, excluding nodes inside subtrees
    all_nodes = stree.all_nodes(False)
    for node in all_nodes:
        if not node.get_targets():
            continue
        if _operator_need_cast(node, force_cast, white_list, black_list):
            _insert_cast_for_operator(node, dtype)
        elif node.get_node_type() == ms.rewrite.NodeType.Tree:
            force_cast_ = force_cast or _net_need_cast(node, force_cast, white_list, black_list)
            if not _precision_set_by_user(node.get_instance()):
                subtree = node.get_sub_tree()
                _insert_cast_for_operators(subtree, dtype, force_cast_, white_list=white_list, black_list=black_list)
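

# Sketch of the rewrite produced by `_insert_cast_for_operator` when an argument has
# several consumers (names are illustrative):
#
#     before:  out = self.matmul(x, w)                  # `x` also feeds another node
#     after:   x_var = amp_cast(x, mindspore.float16)
#              w = amp_cast(w, mindspore.float16)
#              out = self.matmul(x_var, w)
#              out = amp_cast(out, mindspore.float32)
#
# `x` keeps its original value for the other consumer, while this node reads the fresh
# `x_var` produced by `stree.unique_name`.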


def _need_removed_cast_pair(node, dtype):
    """Check whether the cast pair around the node should be removed."""
    dtype_str = "bfloat16" if dtype == mstype.bfloat16 else "float16"
    cast_dtypes = ms.rewrite.ScopedValue.create_name_values([dtype_str, "float32"], ["mindspore", "mindspore"])
    cast_dtype_f16 = cast_dtypes[0]
    cast_dtype_f32 = cast_dtypes[1]
    # current node should be a cast to fp32
    if node.get_instance_type() != _amp_cast_op:
        return False
    node_cast_type = node.get_args()[1]
    if node_cast_type != cast_dtype_f32:
        return False
    # all user nodes should be casts to fp16/bf16
    if not node.get_users():
        return False
    all_nodes = [ms.rewrite.Node(n) for n in node.get_handler().get_node_manager().nodes()]
    for user in node.get_users():
        # If a ControlFlow node (e.g. if, for, while) exists between the current node
        # and the user node, the cast pair should not be removed.
        middle_nodes = all_nodes[all_nodes.index(node): all_nodes.index(user)]
        if any(n.get_node_type() == ms.rewrite.NodeType.ControlFlow for n in middle_nodes):
            return False
        if user.get_instance_type() != _amp_cast_op:
            return False
        user_cast_type = user.get_args()[1]
        if user_cast_type != cast_dtype_f16:
            return False
        # cast pair detected, check the next user
        continue
    return True


def _remove_duplicated_cast(stree, dtype):
    """Remove duplicated cast operators."""
    all_nodes = list(stree.nodes(all_nodes=True))
    for node in all_nodes:
        if _need_removed_cast_pair(node, dtype):
            incast_nodes = node.get_users()
            # remove the cast fp16/bf16 nodes
            for incast_node in incast_nodes:
                # get_target_users() returns {target0: [(user0, arg_idx), ...], ...}
                target_users = list(incast_node.get_target_users().values())
                if not target_users or not target_users[0]:
                    continue
                for user_node, arg_idx in target_users[0]:
                    user_node.set_arg(arg_idx, incast_node.get_args()[0])
                stree.erase(incast_node)
            # remove the cast fp32 node
            stree.erase(node)


def _auto_mixed_precision_rewrite(network, dtype, *, white_list=None, black_list=None):
    """Implement auto mixed precision via the rewrite module."""
    if (white_list is None and black_list is None) or (white_list is not None and black_list is not None):
        raise ValueError("For _auto_mixed_precision_rewrite, exactly one of white_list and black_list "
                         "must be provided.")
    # enable rewrite configs for amp
    ms.rewrite.common.namespace._ms_cells_to_subtree = True
    ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = True
    # insert casts by rewrite
    stree = ms.rewrite.SymbolTree.create(network)
    _insert_cast_for_operators(stree, dtype, False, white_list=white_list, black_list=black_list)
    _remove_duplicated_cast(stree, dtype)
    new_net = stree.get_network()
    # disable rewrite configs
    ms.rewrite.parsers.assign_parser.AssignParser._share_one_implementation = False
    ms.rewrite.common.namespace._ms_cells_to_subtree = False
    ms.rewrite.common.config.clear_caches()
    return new_net


def _auto_black_list(network, black_list, dtype):
    """Process the black list of the network: keep matched cells running in float32."""
    network.to_float(dtype)
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        if isinstance(subcell, tuple(black_list)):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32), dtype)
            change = True
        else:
            _auto_black_list(subcell, black_list, dtype)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())
    return network
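

# Sketch of the clean-up done by `_remove_duplicated_cast` (illustrative names):
#
#     before:  out = amp_cast(out, mindspore.float32)       # output cast of the producer
#              out_var = amp_cast(out, mindspore.float16)   # input cast of the consumer
#              y = self.matmul(out_var, w)
#     after:   y = self.matmul(out, w)
#
# A float32/float16 pair is only erased when every user of the float32 cast is a
# matching float16/bfloat16 cast and no control-flow node sits between them.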


def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
    """
    Returns a network processed with auto mixed precision.

    This interface will automatically perform mixed-precision processing on the input network, and the cells
    and operators in the processed network will add precision conversion operations to calculate with lower
    precision: ``mstype.float16`` or ``mstype.bfloat16`` . Inputs and parameters of cells and operators are
    converted to lower precision float, and calculation results are converted back to full precision float,
    i.e. ``mstype.float32`` .

    The framework has a set of built-in blacklists and whitelists, and the `amp_level` determines which cells
    and operators are specifically converted.

    The current built-in whitelist contents are:

    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.Conv2DBackpropInput`,
    :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]

    The current built-in blacklist contents are:

    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
    :class:`mindspore.nn.LayerNorm`]

    For details on automatic mixed precision, refer to
    `Automatic Mixed Precision <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ .

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and
          `auto_mixed_precision`, can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train a network that has already
          been converted by mixed-precision interfaces such as `custom_mixed_precision` and
          `auto_mixed_precision`, `amp_level` needs to be configured to ``O0`` to avoid duplicated
          precision conversion.

    Args:
        network (Cell): Definition of the network.
        amp_level (str): Supports ["O0", "O1", "O2", "O3"]. Default: ``"O0"`` .

            - "O0": Do not change.
            - "O1": Convert cells and operators in the whitelist to lower precision operations, and keep full
              precision operations for the rest.
            - "O2": Keep full precision operations for cells and operators in the blacklist, and convert the
              rest to lower precision operations.
            - "O3": Cast the network to lower precision.

        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
            ``mstype.bfloat16`` . Default: ``mstype.float16`` .

    Raises:
        TypeError: If `network` is not a Cell.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
        ValueError: If `amp_level` is not within the supported range.

    Examples:
        >>> from mindspore import amp
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> amp_level = "O1"
        >>> net = amp.auto_mixed_precision(network, amp_level)
    """
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    if dtype not in (mstype.float16, mstype.bfloat16):
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    if amp_level == "O0":
        return network

    # Warn if the network has already been processed by a mixed-precision interface
    if getattr(network, "_amp_level", None) in ("O1", "O2", "O3"):
        logger.warning(f"The network's auto mixed-precision level is adjusted from "
                       f"{getattr(network, '_amp_level')} to {amp_level}, and repeated calls to "
                       f"mixed-precision interfaces can cause performance degradation.")

    if amp_level == "O1":
        network = _auto_mixed_precision_rewrite(network, dtype, white_list=AMP_WHITE_LIST)
    elif amp_level == "O2":
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=AMP_BLACK_LIST)
        else:
            network = _auto_black_list(network, AMP_BLACK_LIST, dtype)
            network = _OutputTo32(network)
    elif amp_level == "O3":
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=[])
        else:
            network.to_float(dtype)
            network = _OutputTo32(network)
    else:
        raise ValueError("The amp level {} is not supported".format(amp_level))

    setattr(network, "_amp_level", amp_level)

    return network
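

# Usage sketch beyond the docstring example above (illustrative; `MyNet` stands for any
# user-defined Cell):
#
#     net = auto_mixed_precision(MyNet(), amp_level="O2", dtype=mstype.bfloat16)
#
# With "O2", black-list cells (BatchNorm/LayerNorm) keep running in float32 while the
# rest of the network computes in bfloat16; the final output is cast back to float32.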


def _do_keep_batchnorm_fp32(network):
    """Keep BatchNorm cells running in float32."""
    cells = network.name_cells()
    change = False
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif isinstance(subcell, nn.Cell) and isinstance(subcell, tuple(AMP_BLACK_LIST)):
            network._cells[name] = _OutputTo16(subcell.to_float(mstype.float32))
            change = True
        else:
            _do_keep_batchnorm_fp32(subcell)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())


_config_level = {
    "O0": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O1": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float32,
        "loss_scale_manager": None},
    "O2": {
        "keep_batchnorm_fp32": True,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": DynamicLossScaleManager()},
    "O3": {
        "keep_batchnorm_fp32": False,
        "cast_model_type": mstype.float16,
        "loss_scale_manager": None}}


def _check_kwargs(key_words):
    """Check kwargs."""
    for arg in key_words:
        if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
            raise ValueError(f"Unsupported arg '{arg}'")

    if 'cast_model_type' in key_words:
        validator.check_type_name('cast_model_type', key_words['cast_model_type'],
                                  [mstype.float16, mstype.float32], None)
    if 'keep_batchnorm_fp32' in key_words:
        validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool)
    if 'loss_scale_manager' in key_words:
        loss_scale_manager = key_words['loss_scale_manager']
        if loss_scale_manager:
            validator.check_value_type('loss_scale_manager', loss_scale_manager,
                                       [LossScaleManager, boost.GroupLossScaleManager])
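

# Sketch of how the per-level defaults above combine with user kwargs in
# `build_train_network` (values shown are from `_config_level`):
#
#     config = dict(_config_level.get("O2"), **{"keep_batchnorm_fp32": False})
#     # config == {"keep_batchnorm_fp32": False,
#     #            "cast_model_type": mstype.float16,
#     #            "loss_scale_manager": DynamicLossScaleManager()}
#
# i.e. kwargs validated by `_check_kwargs` override the defaults chosen by `level`.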


def _check_level(level, boost_level):
    """Check level."""
    if not isinstance(level, str):
        raise TypeError("The argument `level` must be a string in ['O0', 'O1', 'O2', 'O3', 'auto'], "
                        "but got type {}.".format(type(level)))
    validator.check('level', level, "", ['O0', 'O1', 'O2', 'O3', 'auto'], validator.IN)
    validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], validator.IN)

    if level == "auto":
        device_target = context.get_context('device_target')
        if device_target == "GPU":
            level = "O2"
        elif device_target == "Ascend":
            level = "O3"
        else:
            raise ValueError("Level `auto` is only supported when `device_target` is GPU or Ascend.")

    enable_boost = False
    if boost_level in ["O1", "O2"]:
        enable_boost = True

    return level, enable_boost


def _add_loss_network(network, loss_fn, cast_model_type):
    """Add loss network."""

    class WithLossCell(nn.Cell):
        """Wrap loss for amp. Cast network output back to float32."""
        def __init__(self, backbone, loss_fn):
            super(WithLossCell, self).__init__(auto_prefix=False)
            self._backbone = backbone
            self._loss_fn = loss_fn
            self._get_attr_from_cell(backbone)

        def construct(self, data, label):
            out = self._backbone(data)
            label = F.mixed_precision_cast(mstype.float32, label)
            return self._loss_fn(F.mixed_precision_cast(mstype.float32, out), label)

    validator.check_value_type('loss_fn', loss_fn, nn.Cell)
    if cast_model_type == mstype.float16:
        network = WithLossCell(network, loss_fn)
    else:
        network = nn.WithLossCell(network, loss_fn)
    return network


def _is_grad_accumulation(mcell):
    """Check whether the cell or any of its sub-cells is a GradAccumulationCell."""
    if mcell.cls_name == "GradAccumulationCell":
        return True
    for cell in mcell.cells():
        if _is_grad_accumulation(cell):
            return True
    return False


def _auto_mixed_precision_process(network, config, level):
    """Auto mixed precision process."""
    if MS_AMP_BY_REWRITE:
        if config["cast_model_type"] == mstype.float16 or level == "O2":
            level = "O2" if config["keep_batchnorm_fp32"] else "O3"
        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
            # cast_model_type was set to float32 by kwargs, so skip the conversion
            level = "O0"
        network = auto_mixed_precision(network, level)
    else:
        if config["cast_model_type"] == mstype.float16:
            network.to_float(mstype.float16)

            if config["keep_batchnorm_fp32"]:
                _do_keep_batchnorm_fp32(network)
        elif not config["keep_batchnorm_fp32"] and level == "O2":
            network.to_float(mstype.float16)
        elif config["cast_model_type"] == mstype.float32 and level in ("O2", "O3"):
            pass
        else:
            network = auto_mixed_precision(network, level)
    return network
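

# Sketch of the level remapping performed above when MS_AMP_BY_REWRITE is enabled:
#
#     level="O2", keep_batchnorm_fp32=True   -> "O2"  (black list kept in float32)
#     level="O2", keep_batchnorm_fp32=False  -> "O3"  (whole network cast)
#     level="O3", cast_model_type=float32    -> "O0"  (kwargs override the level)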


def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_level='O0', **kwargs):
    """
    Build the mixed precision training cell automatically.

    Note:
        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not
          supported to perform the precision conversion again. If `build_train_network` is used to train a
          converted network, `level` needs to be configured to ``O0`` to avoid duplicated precision
          conversion.

    Args:
        network (Cell): Definition of the network.
        optimizer (:class:`mindspore.nn.Optimizer`): Define the optimizer to update the Parameter.
        loss_fn (Union[None, Cell]): Define the loss function. If None, the `network` should have the loss
            inside. Default: ``None`` .
        level (str): Supports ['O0', 'O1', 'O2', 'O3', 'auto']. Default: ``'O0'`` .

            - 'O0': Do not change.
            - 'O1': Cast the operators in the whitelist to float16; the remaining operators are kept in
              float32. The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose,
              Conv2dTranspose, Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul,
              PReLU, ReLU, Ger].
            - 'O2': Cast the network to float16; keep the `mindspore.nn.BatchNorm` series interfaces,
              :class:`mindspore.nn.LayerNorm` and `loss_fn` (if set) running in float32; use dynamic loss
              scaling.
            - 'O3': Cast the network to float16, with the additional property `keep_batchnorm_fp32=False` .
            - 'auto': Set `level` to the recommended level for the current device: 'O2' on GPU and 'O3' on
              Ascend. The recommended level is chosen by expert experience and is not applicable to all
              scenarios. Users should specify the level for special networks.

            'O2' is recommended on GPU, 'O3' is recommended on Ascend. The properties `keep_batchnorm_fp32`,
            `cast_model_type` and `loss_scale_manager` determined by the `level` setting may be overwritten
            by settings in `kwargs`.

        boost_level (str): Option for argument `level` in `mindspore.boost` , level for boost mode
            training. Supports ['O0', 'O1', 'O2']. Default: ``'O0'`` .

            - 'O0': Do not change.
            - 'O1': Enable the boost mode; the performance is improved by about 20%, and
              the accuracy is the same as the original accuracy.
            - 'O2': Enable the boost mode; the performance is improved by about 30%, and
              the accuracy is reduced by less than 3%.

            If 'O1' or 'O2' mode is set, the boost related library will take effect automatically.

        cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32` . If set,
            the network will be cast to `cast_model_type` ( `mstype.float16` or `mstype.float32` ) instead
            of the type determined by the `level` setting.
        keep_batchnorm_fp32 (bool): Keep BatchNorm running in `float32` when the network is cast to
            `float16` . If set, the `level` setting will take no effect on this property.
        loss_scale_manager (Union[None, LossScaleManager]): If not None, must be a subclass of
            :class:`mindspore.amp.LossScaleManager` for scaling the loss. If set, the `level` setting will
            take no effect on this property.

    Raises:
        ValueError: If the device is CPU and `loss_scale_manager` is neither ``None`` nor
            :class:`mindspore.amp.FixedLossScaleManager` (with property `drop_overflow_update=False` ).

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> network = LeNet5()
        >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
        >>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
        >>> amp_level = "O3"
        >>> net = amp.build_train_network(network, net_opt, net_loss, amp_level)
    """
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt,
                                                        nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell))

    level, enable_boost = _check_level(level, boost_level)

    _check_kwargs(kwargs)
    config = dict(_config_level.get(level), **kwargs)

    network = _auto_mixed_precision_process(network, config, level)

    if loss_fn:
        network = _add_loss_network(network, loss_fn, config["cast_model_type"])

    loss_scale = None
    if config["loss_scale_manager"] is not None:
        loss_scale_manager = config["loss_scale_manager"]
        loss_scale = loss_scale_manager.get_loss_scale()
        update_cell = loss_scale_manager.get_update_cell()
        if update_cell is not None:
            # only CPU does not support `TrainOneStepWithLossScaleCell`, due to control flow
            if not context.get_context("enable_ge") and context.get_context("device_target") == "CPU":
                raise ValueError("Only `loss_scale_manager=None` or "
                                 "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)` "
                                 "are supported on device `CPU`.")
            if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
                network = _TrainGradAccuWithLossScaleCell(network, optimizer,
                                                          scale_sense=update_cell).set_train()
            elif enable_boost:
                network = boost.BoostTrainOneStepWithLossScaleCell(network, optimizer,
                                                                   scale_sense=update_cell).set_train()
            else:
                network = nn.TrainOneStepWithLossScaleCell(network, optimizer,
                                                           scale_sense=update_cell).set_train()
            return network
    if _get_pipeline_stages() > 1 or _is_grad_accumulation(network):
        network = _TrainGradAccuStepCell(network, optimizer).set_train()
    elif enable_boost:
        network = boost.BoostTrainOneStepCell(network, optimizer, loss_scale).set_train()
    else:
        network = nn.TrainOneStepCell(network, optimizer, loss_scale).set_train()
    return network
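

# Usage sketch with a kwarg override (LeNet5 as in the docstring example above;
# FixedLossScaleManager comes from mindspore.amp):
#
#     from mindspore import amp, nn
#     network = LeNet5()
#     net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
#     net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
#     train_net = amp.build_train_network(
#         network, net_opt, net_loss, level="O2",
#         loss_scale_manager=amp.FixedLossScaleManager(drop_overflow_update=False))
#
# Here the fixed loss scale replaces the DynamicLossScaleManager that "O2" would pick
# by default, which is also the only loss-scale configuration allowed on CPU.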


def get_white_list():
    """
    Provide a copy of the internal white list used by auto mixed precision.

    The current built-in whitelist contents are:

    [:class:`mindspore.nn.Conv1d`, :class:`mindspore.nn.Conv2d`, :class:`mindspore.nn.Conv3d`,
    :class:`mindspore.nn.Conv1dTranspose`, :class:`mindspore.nn.Conv2dTranspose`,
    :class:`mindspore.nn.Conv3dTranspose`, :class:`mindspore.nn.Dense`, :class:`mindspore.nn.LSTMCell`,
    :class:`mindspore.nn.RNNCell`, :class:`mindspore.nn.GRUCell`, :class:`mindspore.ops.Conv2D`,
    :class:`mindspore.ops.Conv3D`, :class:`mindspore.ops.Conv2DTranspose`,
    :class:`mindspore.ops.Conv3DTranspose`, :class:`mindspore.ops.Conv2DBackpropInput`,
    :class:`mindspore.ops.MatMul`, :class:`mindspore.ops.BatchMatMul`,
    :class:`mindspore.ops.PReLU`, :class:`mindspore.ops.ReLU`, :class:`mindspore.ops.Ger`]

    Returns:
        list, A copy of internal white list.

    Examples:
        >>> from mindspore import amp
        >>> white_list = amp.get_white_list()
        >>> print(white_list)
        [<class 'mindspore.nn.layer.conv.Conv1d'>, <class 'mindspore.nn.layer.conv.Conv2d'>,
        <class 'mindspore.nn.layer.conv.Conv3d'>, <class 'mindspore.nn.layer.conv.Conv1dTranspose'>,
        <class 'mindspore.nn.layer.conv.Conv2dTranspose'>, <class 'mindspore.nn.layer.conv.Conv3dTranspose'>,
        <class 'mindspore.nn.layer.basic.Dense'>, <class 'mindspore.nn.layer.rnn_cells.LSTMCell'>,
        <class 'mindspore.nn.layer.rnn_cells.RNNCell'>, <class 'mindspore.nn.layer.rnn_cells.GRUCell'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2D'>, <class 'mindspore.ops.operations.nn_ops.Conv3D'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2DTranspose'>,
        <class 'mindspore.ops.operations.nn_ops.Conv3DTranspose'>,
        <class 'mindspore.ops.operations.nn_ops.Conv2DBackpropInput'>,
        <class 'mindspore.ops.operations.math_ops.MatMul'>, <class 'mindspore.ops.operations.math_ops.BatchMatMul'>,
        <class 'mindspore.ops.operations.nn_ops.PReLU'>, <class 'mindspore.ops.operations.nn_ops.ReLU'>,
        <class 'mindspore.ops.operations.math_ops.Ger'>]
    """
    white_list = AMP_WHITE_LIST.copy()
    return white_list


def get_black_list():
    """
    Provide a copy of the internal black list used by auto mixed precision.

    The current built-in blacklist contents are:

    [:class:`mindspore.nn.BatchNorm1d`, :class:`mindspore.nn.BatchNorm2d`, :class:`mindspore.nn.BatchNorm3d`,
    :class:`mindspore.nn.LayerNorm`]

    Returns:
        list, A copy of internal black list.

    Examples:
        >>> from mindspore import amp
        >>> black_list = amp.get_black_list()
        >>> print(black_list)
        [<class 'mindspore.nn.layer.normalization.BatchNorm1d'>, <class 'mindspore.nn.layer.normalization.BatchNorm2d'>,
        <class 'mindspore.nn.layer.normalization.BatchNorm3d'>, <class 'mindspore.nn.layer.normalization.LayerNorm'>]
    """
    black_list = AMP_BLACK_LIST.copy()
    return black_list
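

# Sketch: extending the black list for `custom_mixed_precision` below (mirrors the
# white-list example in its docstring; nn.GroupNorm is just an illustration):
#
#     custom_black_list = get_black_list()
#     custom_black_list.append(nn.GroupNorm)
#     net = custom_mixed_precision(net, black_list=custom_black_list)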


def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=mstype.float16):
    """
    Custom mixed precision by setting a whitelist or a blacklist.
    When `white_list` is provided, primitives and cells in `white_list` will perform the precision conversion.
    When `black_list` is provided, cells that are not in `black_list` will perform the precision conversion.
    Only one of `white_list` and `black_list` should be provided.

    Note:
        - Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and
          `auto_mixed_precision`, can result in a larger network hierarchy and slower performance.
        - If interfaces like `Model` and `build_train_network` are used to train a network that has already
          been converted by mixed-precision interfaces such as `custom_mixed_precision` and
          `auto_mixed_precision`, `amp_level` needs to be configured to ``O0`` to avoid duplicated
          precision conversion.
        - Primitives in the black list are not supported yet.

    Args:
        network (Cell): Definition of the network.
        white_list (list[Primitive, Cell], optional): White list of custom mixed precision.
            Default: ``None`` , means the white list is not used.
        black_list (list[Cell], optional): Black list of custom mixed precision. Default: ``None`` , means
            the black list is not used.
        dtype (Type): The type used in lower precision calculations, can be ``mstype.float16`` or
            ``mstype.bfloat16`` . Default: ``mstype.float16`` .

    Returns:
        network (Cell), A network supporting mixed precision.

    Raises:
        TypeError: The network type is not Cell.
        ValueError: Neither `white_list` nor `black_list` is provided.
        ValueError: If `dtype` is not one of ``mstype.float16`` , ``mstype.bfloat16`` .
        ValueError: Both `white_list` and `black_list` are provided.

    Examples:
        >>> from mindspore import amp, nn
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> custom_white_list = amp.get_white_list()
        >>> custom_white_list.append(nn.Flatten)
        >>> net = amp.custom_mixed_precision(net, white_list=custom_white_list)
    """
    if not isinstance(network, nn.Cell):
        raise TypeError("The network type should be Cell.")

    if white_list is None and black_list is None:
        raise ValueError("For custom_mixed_precision, one of white_list and black_list must be provided.")

    if white_list is not None and black_list is not None:
        raise ValueError("For custom_mixed_precision, the white_list and black_list cannot be provided "
                         "at the same time, please provide one or the other.")

    if dtype not in (mstype.float16, mstype.bfloat16):
        raise ValueError(f"The dtype should be one of (mstype.float16, mstype.bfloat16), but got {dtype}.")

    if white_list is not None:
        _list_check(white_list, "white_list")
        network = _auto_mixed_precision_rewrite(network, dtype, white_list=white_list)
    else:
        _list_check(black_list, "black_list")
        if MS_AMP_BY_REWRITE:
            network = _auto_mixed_precision_rewrite(network, dtype, black_list=black_list)
        else:
            network = _auto_black_list(network, black_list, dtype)
            network = _OutputTo32(network)
    return network


def _list_check(custom_list: list, list_name: str):
    """
    Check whether the custom list is valid.

    Raises:
        TypeError: The type of custom_list is not list.
        TypeError: The element in custom_list is not a class.
        TypeError: The subclass of element in custom_list is not one of ['Cell', 'Primitive'].
    """
    if not isinstance(custom_list, list):
        raise TypeError(f"The type of {list_name} should be list, but got {type(custom_list)}")

    for elem in custom_list:
        if not isinstance(elem, type):
            raise TypeError(f"The element in {list_name} should be a class, but got {elem}")

        if list_name == "white_list" and not issubclass(elem, nn.Cell) and not issubclass(elem, Primitive):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell' and 'Primitive', "
                            f"but got {elem}")

        if list_name == "black_list" and not issubclass(elem, nn.Cell):
            raise TypeError(f"The subclass of element in {list_name} should be one of 'Cell', but got {elem}")

    if list_name == 'black_list':
        for elem in AMP_BLACK_LIST:
            if elem not in custom_list:
                logger.warning(f"{elem} is removed from the internal black list.")


def _config_amp(*, enable_rewrite: bool = None, cast_op: types.FunctionType = None):  # pylint: disable=unused-variable
    """Configure auto mixed precision."""
    global MS_AMP_BY_REWRITE
    global _amp_cast_op

    if enable_rewrite is not None:
        MS_AMP_BY_REWRITE = enable_rewrite

    if cast_op is not None:
        _amp_cast_op = cast_op
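

# Usage sketch for the internal switch above (illustrative): `_config_amp` lets other
# framework code enable the rewrite-based implementation for O2/O3 or swap the cast
# function used by the rewrite passes.
#
#     _config_amp(enable_rewrite=True)   # O2/O3 now go through _auto_mixed_precision_rewrite
#     _config_amp(cast_op=my_cast)       # `my_cast(value, dtype)` is a hypothetical replacement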