# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Conversion interface for the second-order optimizer thor."""
from __future__ import absolute_import

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context


class ConvertNetUtils:
    """
    Convert a network to a thor layer network, which is used to compute and store the second-order
    information matrix.

    Examples:
        >>> import mindspore as ms
        >>> convert_net_utils = ms.train.ConvertNetUtils()
    """
    def __init__(self):
        # Maps each convertible cell type to its thor conversion method.
        self._convert_method_map = {nn.Dense: ConvertNetUtils._convert_dense,
                                    nn.Embedding: ConvertNetUtils._convert_embedding,
                                    nn.Conv2d: ConvertNetUtils._convert_conv2d,
                                    nn.EmbeddingLookup: ConvertNetUtils._convert_embeddinglookup}

    @staticmethod
    def _convert_dense(subcell):
        """
        Convert a dense cell to a second-order cell.
        """
        weight = subcell.weight
        act_name = None
        if subcell.activation_flag:
            act_class = subcell.activation.__class__.__name__
            act_name = act_class.lower()
            if act_name == "fastgelu":
                act_name = "fast_gelu"
        if subcell.out_channels == 1001:
            # A 1001-class output layer (e.g. an ImageNet-style classification head) keeps the default
            # compute type; other dense layers are cast to float16 on Ascend and float32 on GPU.
            new_subcell = nn.DenseThor(in_channels=subcell.in_channels,
                                       out_channels=subcell.out_channels,
                                       weight_init=weight,
                                       has_bias=subcell.has_bias,
                                       bias_init='zeros',
                                       activation=act_name)
        else:
            compute_type = mstype.float16
            if context.get_context("device_target") == "GPU":
                compute_type = mstype.float32
            new_subcell = nn.DenseThor(in_channels=subcell.in_channels,
                                       out_channels=subcell.out_channels,
                                       weight_init=weight,
                                       has_bias=subcell.has_bias,
                                       bias_init='zeros',
                                       activation=act_name).to_float(compute_type)

        if subcell.has_bias:
            new_subcell.bias = subcell.bias
        return new_subcell

    @staticmethod
    def _convert_embedding(subcell):
        """
        Convert an embedding cell to a second-order cell.
        """
        new_subcell = nn.EmbeddingThor(vocab_size=subcell.vocab_size,
                                       embedding_size=subcell.embedding_size,
                                       use_one_hot=False)
        new_subcell.embedding_table = subcell.embedding_table
        return new_subcell

    @staticmethod
    def _convert_embeddinglookup(subcell):
        """
        Convert an embedding lookup cell to a second-order cell.
        """
        new_subcell = nn.EmbeddingLookupThor(vocab_size=subcell.vocab_size,
                                             embedding_size=subcell.embedding_size,
                                             target=subcell.target, sparse=subcell.sparse,
                                             vocab_cache_size=subcell.vocab_cache_size)
        new_subcell.embedding_table = subcell.embedding_table
        return new_subcell

    @staticmethod
    def _convert_conv2d(subcell):
        """
        Convert a conv2d cell to a second-order cell.
        """
        out_channel = subcell.out_channels
        in_channel = subcell.in_channels
        kernel_size = subcell.kernel_size[0]
        stride = subcell.stride
        padding = subcell.padding
        pad_mode = subcell.pad_mode
        has_bias = subcell.has_bias
        weight = subcell.weight

        new_subcell = nn.Conv2dThor(in_channel, out_channel,
                                    kernel_size=kernel_size, stride=stride, padding=padding, pad_mode=pad_mode,
                                    has_bias=has_bias, weight_init=weight)
        return new_subcell

    @staticmethod
    def _need_change(subcell, prefix):
        """Check whether the subcell should be converted to a thor layer."""
        if isinstance(subcell, (nn.Dense, nn.Conv2d)) and subcell.weight.requires_grad:
            # Layers whose parameter prefix marks them as RPN conv layers or as the wide part of a
            # network are left unchanged.
            if "rpn_with_loss.rpn_convs_list." in prefix.lower() or "wide" in prefix.lower():
                return False
            return True
        if isinstance(subcell, (nn.Embedding, nn.EmbeddingLookup)) and subcell.embedding_table.requires_grad:
            return True
        return False

    def _convert_to_thor_net(self, net):
        """
        Convert the network to a thor network recursively.
        """
        cells = net.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, (nn.DenseThor, nn.Conv2dThor, nn.EmbeddingThor, nn.EmbeddingLookupThor)):
                # Already a thor layer, nothing to convert.
                continue
            elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm,
                                      nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
                continue
            elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d, nn.EmbeddingLookup)):
                prefix = subcell.param_prefix
                if self._need_change(subcell, prefix):
                    new_subcell = self._convert_method_map.get(type(subcell))(subcell)
                    new_subcell.update_parameters_name(prefix + '.')
                    net.insert_child_to_cell(name, new_subcell)
                    change = True
            else:
                self._convert_to_thor_net(subcell)

        if isinstance(net, nn.SequentialCell) and change:
            net.cell_list = list(net.cells())

    def convert_to_thor_net(self, net):
        """
        This interface is used to convert a network to a thor layer network, in order to calculate and store the
        second-order information matrix.

        Note:
            This interface is automatically called by the second-order optimizer thor.

        Args:
            net (Cell): Network to be trained by the second-order optimizer thor.

        Supported Platforms:
            ``Ascend`` ``GPU``

        Examples:
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore import Tensor
            >>> from mindspore.nn import thor
            >>>
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], ms.float32)
            >>> opt = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
            >>> loss_scale = ms.FixedLossScaleManager(128, drop_overflow_update=False)
            >>> model = ms.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
            ...                  amp_level="O2", keep_batchnorm_fp32=False)
            >>> model = ms.train.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss,
            ...                                                          optimizer=opt, loss_scale_manager=loss_scale,
            ...                                                          metrics={'acc'}, amp_level="O2",
            ...                                                          keep_batchnorm_fp32=False)
        """
        net.update_cell_prefix()
        self._convert_to_thor_net(net)
        net.update_cell_type("second-order")


class ConvertModelUtils:
    """
    Convert a model to a thor model.
    """
    @staticmethod
    def convert_to_thor_model(model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0",
                              loss_scale_manager=None, keep_batchnorm_fp32=False):
        """
        This interface is used to convert a model to a thor model.

        Args:
            model (Object): High-Level API for Training.
            network (Cell): A training network.
            loss_fn (Cell): Objective function. Default: ``None`` .
            optimizer (Cell): Optimizer used to update the weights. Default: ``None`` .
            metrics (Union[dict, set]): A dictionary or a set of metrics to be evaluated by the model during
                training, e.g. {'accuracy', 'recall'}. Default: ``None`` .
            amp_level (str): Level for mixed precision training. Supports ["O0", "O2", "O3", "auto"].
                Default: ``"O0"`` .

                - O0: Do not change.
                - O2: Cast the network to float16, keep BatchNorm running in float32, use dynamic loss scale.
                - O3: Cast the network to float16, with the additional property 'keep_batchnorm_fp32=False'.
                - auto: Set the level to the recommended level for the device. O2 is recommended on GPU, O3 is
                  recommended on Ascend. The recommended level is based on expert experience and cannot always
                  generalize; users should specify the level explicitly for special networks.

            loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss is not scaled.
                Otherwise, the loss is scaled by the LossScaleManager, and the optimizer cannot be None. It is a
                keyword argument, e.g. use `loss_scale_manager=None` to set the value. Default: ``None`` .
            keep_batchnorm_fp32 (bool): Keep BatchNorm running in `float32`. If True, the setting made by
                `amp_level` will be overwritten. Default: ``False`` .

        Returns:
            model (Object), High-Level API for Training.

        Supported Platforms:
            ``Ascend`` ``GPU``

        Examples:
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore import Tensor
            >>> from mindspore.nn import thor
            >>>
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
            >>> dataset = create_dataset()
            >>> temp = Tensor([4e-4, 1e-4, 1e-5, 1e-5], ms.float32)
            >>> opt = thor(net, learning_rate=temp, damping=temp, momentum=0.9, loss_scale=128, frequency=4)
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
            >>> loss_scale = ms.amp.FixedLossScaleManager(128, drop_overflow_update=False)
            >>> model = ms.train.Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
            ...                        metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
            >>> model = ms.train.ConvertModelUtils.convert_to_thor_model(model=model, network=net, loss_fn=loss,
            ...                                                          optimizer=opt, loss_scale_manager=loss_scale,
            ...                                                          metrics={'acc'}, amp_level="O2",
            ...                                                          keep_batchnorm_fp32=False)
        """
        optim_name = type(optimizer).__name__
        if optim_name in ("ThorAscend", "ThorGpu"):
            # Only the thor optimizers need the dedicated ModelThor wrapper; any other optimizer keeps
            # the original model unchanged.
            from mindspore.train.train_thor.model_thor import ModelThor
            if isinstance(network, nn.TrainOneStepCell):
                model = ModelThor(network=network)
            else:
                model = ModelThor(network=network, loss_fn=loss_fn, optimizer=optimizer, amp_level=amp_level,
                                  loss_scale_manager=loss_scale_manager,
                                  keep_batchnorm_fp32=keep_batchnorm_fp32, metrics=metrics)

        return model