# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Group Loss Scale Manager"""
from __future__ import absolute_import
from __future__ import division

from mindspore.nn.cell import Cell
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple


__all__ = ["GroupLossScaleManager"]


class GroupLossScaleManager(Cell):
    r"""
    Enhanced hybrid precision algorithm supports multi-layer application of different loss scales and
    dynamic updating of loss scales.

    Args:
        init_loss_scale (Number): The initialized loss scale value.
        loss_scale_groups (List): The loss scale groups, which are divided from the param list.

    Inputs:
        - **x** (Tensor) - The output of last operator.
        - **layer1** (Int) - Current network layer value.
        - **layer2** (Int) - Last network layer value.

    Outputs:
        - **out** (Tensor) - A tensor with a group of loss scale tags that marks
          the loss scale group number of the current tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import boost, nn
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, enhanced_amp, num_class=10, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        ...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
        ...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
        ...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
        ...         self.relu = nn.ReLU()
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...         self.enhanced_amp = enhanced_amp
        ...
        ...     def construct(self, x):
        ...         x = self.enhanced_amp(x, 0, 1)
        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
        ...         x = self.flatten(x)
        ...         x = self.enhanced_amp(x, 1, 2)
        ...         x = self.relu(self.fc1(x))
        ...         x = self.relu(self.fc2(x))
        ...         x = self.fc3(x)
        ...         x = self.enhanced_amp(x, 2, 3)
        ...         return x
        >>>
        >>> loss_scale_manager = boost.GroupLossScaleManager(4096, [])
        >>> net = Net(loss_scale_manager)
        >>> param_group1 = []
        >>> param_group2 = []
        >>> for param in net.trainable_params():
        ...     if 'conv' in param.name:
        ...         param_group1.append(param)
        ...     else:
        ...         param_group2.append(param)
        >>> loss_scale_manager.loss_scale_groups = [param_group1, param_group2]
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> boost_config_dict = {"boost": {"mode": "manual", "less_bn": False, "grad_freeze": False, "adasum": False,
        ...                     "grad_accumulation": False, "dim_reduce": False, "loss_scale_group": True}}
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None,
        ...                        loss_scale_manager=loss_scale_manager,
        ...                        boost_level="O1", boost_config_dict=boost_config_dict)
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> model.train(2, dataset)
    """
    def __init__(self, init_loss_scale, loss_scale_groups):
        """Initialize with a scalar loss scale and the per-group parameter lists."""
        super(GroupLossScaleManager, self).__init__()
        # Scalar value reported back to the training Model via get_loss_scale().
        self._loss_scale = init_loss_scale
        # List of parameter groups; each group gets its own dynamic loss scale
        # slot once set_loss_scale_status() is called.
        self.loss_scale_groups = loss_scale_groups
        # The three fields below are populated by set_loss_scale_status().
        self.loss_scale_number = 0
        self.layer_loss_scale = None
        self.dynamic_loss_scale = None

    def set_loss_scale_status(self, loss_scale_number, init_loss_scale):
        """
        Generate dynamic loss scale tuple and set overflow status list.

        Args:
            loss_scale_number (int): The number of loss scale (i.e. the number of
                parameter groups).
            init_loss_scale (Union[float, list]): The initialized loss scale.
                A list supplies one initial scale per group; a scalar is applied
                to every group.
        """
        self.loss_scale_number = loss_scale_number
        # One internal _DynamicLossScale op per group boundary, indexed by the
        # `layer1` argument of construct() (0 .. loss_scale_number inclusive).
        inner_list = [P._DynamicLossScale(layer=x) for x in range(loss_scale_number + 1)]  # pylint: disable=W0212
        self.layer_loss_scale = tuple(inner_list)
        # loss_scale_number + 2 scale Parameters: slots 1..loss_scale_number hold
        # the per-group scales; the two boundary slots (0 and loss_scale_number + 1)
        # are never set below and thus stay at 1.0, acting as identity scales at
        # the network's entry and exit boundaries.
        self.dynamic_loss_scale = ParameterTuple(Parameter(Tensor(1, mstype.float32),
                                                           name='layer_loss_scale_{}'.format(x), requires_grad=False)
                                                 for x in range(loss_scale_number + 2))
        if isinstance(init_loss_scale, list):
            # Per-group initial scales; shifted by one to skip the identity slot 0.
            for i, value in enumerate(init_loss_scale):
                self.dynamic_loss_scale[i + 1].set_data(value)
        else:
            # Same initial scale for every group.
            for i in range(self.loss_scale_number):
                self.dynamic_loss_scale[i + 1].set_data(init_loss_scale)

    def update_loss_scale_status(self, layer, update_ratio):
        """
        Update dynamic loss scale.

        Args:
            layer (int): Current layer.
            update_ratio (float): The ratio of loss scale update.

        Outputs:
            float, new loss scale.
        """
        # Shift to the 1-based slot layout of dynamic_loss_scale (slot 0 is the
        # identity boundary scale left at 1.0 by set_loss_scale_status()).
        layer = layer + 1
        new_loss_scale = self.dynamic_loss_scale[layer] * update_ratio
        # Persist the scaled value back into the Parameter in place.
        P.Assign()(self.dynamic_loss_scale[layer], new_loss_scale)
        return new_loss_scale

    def construct(self, x, layer1, layer2):
        # Re-tag `x` when crossing a group boundary: apply the ratio of the two
        # groups' loss scales so the tensor carries its new group's scale.
        # `layer1`/`layer2` follow the class docstring's current/last layer
        # convention (see the Net example above).
        x = self.layer_loss_scale[layer1](x, self.dynamic_loss_scale[layer1] / self.dynamic_loss_scale[layer2])
        return x

    def get_loss_scale(self):
        """
        Get loss scale value.

        Returns:
            Number, the `loss_scale` value passed to the constructor.
        """
        return self._loss_scale

    def get_update_cell(self):
        """
        Returns the instance of :class:`mindspore.boost.GroupLossScaleManager`.

        Returns:
            :class:`mindspore.boost.GroupLossScaleManager`.
        """
        return self