1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Group Loss Scale Manager"""
16from __future__ import absolute_import
17from __future__ import division
18
19from mindspore.nn.cell import Cell
20import mindspore.common.dtype as mstype
21from mindspore.ops import operations as P
22from mindspore.common.tensor import Tensor
23from mindspore.common.parameter import Parameter, ParameterTuple
24
25
26__all__ = ["GroupLossScaleManager"]
27
28
class GroupLossScaleManager(Cell):
    r"""
    Enhanced hybrid precision algorithm supports multi-layer application of different loss scales and
    dynamic updating of loss scales.

    Args:
        init_loss_scale (Number): The initialized loss scale value.
        loss_scale_groups (List): The loss scale groups, which are divided from the param list.

    Inputs:
        - **x** (Tensor) - The output of last operator.
        - **layer1** (Int) - Current network layer value.
        - **layer2** (Int) - Last network layer value.

    Outputs:
        - **out** (Tensor) - A tensor with a group of loss scale tags that marks
          the loss scale group number of the current tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import boost, nn
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self, enhanced_amp, num_class=10, num_channel=1):
        ...         super(Net, self).__init__()
        ...         self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        ...         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        ...         self.fc1 = nn.Dense(16*5*5, 120, weight_init='ones')
        ...         self.fc2 = nn.Dense(120, 84, weight_init='ones')
        ...         self.fc3 = nn.Dense(84, num_class, weight_init='ones')
        ...         self.relu = nn.ReLU()
        ...         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        ...         self.flatten = nn.Flatten()
        ...         self.enhanced_amp = enhanced_amp
        ...
        ...     def construct(self, x):
        ...         x = self.enhanced_amp(x, 0, 1)
        ...         x = self.max_pool2d(self.relu(self.conv1(x)))
        ...         x = self.max_pool2d(self.relu(self.conv2(x)))
        ...         x = self.flatten(x)
        ...         x = self.enhanced_amp(x, 1, 2)
        ...         x = self.relu(self.fc1(x))
        ...         x = self.relu(self.fc2(x))
        ...         x = self.fc3(x)
        ...         x = self.enhanced_amp(x, 2, 3)
        ...         return x
        >>>
        >>> loss_scale_manager = boost.GroupLossScaleManager(4096, [])
        >>> net = Net(loss_scale_manager)
        >>> param_group1 = []
        >>> param_group2 = []
        >>> for param in net.trainable_params():
        ...     if 'conv' in param.name:
        ...         param_group1.append(param)
        ...     else:
        ...         param_group2.append(param)
        >>> loss_scale_manager.loss_scale_groups = [param_group1, param_group2]
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> boost_config_dict = {"boost": {"mode": "manual", "less_bn": False, "grad_freeze": False, "adasum": False,
        ...                      "grad_accumulation": False, "dim_reduce": False, "loss_scale_group": True}}
        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim, metrics=None,
        ...                        loss_scale_manager=loss_scale_manager,
        ...                        boost_level="O1", boost_config_dict=boost_config_dict)
        >>> # Create the dataset taking MNIST as an example. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
        >>> dataset = create_dataset()
        >>> model.train(2, dataset)
    """
    def __init__(self, init_loss_scale, loss_scale_groups):
        super(GroupLossScaleManager, self).__init__()
        # Scalar loss scale reported by get_loss_scale(); the per-group values
        # live in `dynamic_loss_scale` once set_loss_scale_status() is called.
        self._loss_scale = init_loss_scale
        # Groups of parameters sharing one loss scale; callers may assign this
        # after construction (see the class Examples).
        self.loss_scale_groups = loss_scale_groups
        # Number of loss-scale groups; 0 until set_loss_scale_status() runs.
        self.loss_scale_number = 0
        # Tuple of _DynamicLossScale ops (one per layer boundary), built lazily.
        self.layer_loss_scale = None
        # ParameterTuple of per-group loss-scale values, built lazily.
        self.dynamic_loss_scale = None

    def set_loss_scale_status(self, loss_scale_number, init_loss_scale):
        """
        Generate dynamic loss scale tuple and set overflow status list.

        Args:
            loss_scale_number (int): The number of loss scale.
            init_loss_scale (Union[float, list]): The initialized loss scale; a list
                gives each group its own initial value, a scalar is shared by all groups.
        """
        self.loss_scale_number = loss_scale_number
        # One scaling op per layer boundary (layers 0..loss_scale_number); the
        # private op is used intentionally, hence the pylint suppression.
        inner_list = [P._DynamicLossScale(layer=x) for x in range(loss_scale_number + 1)] # pylint: disable=W0212
        self.layer_loss_scale = tuple(inner_list)
        # loss_scale_number + 2 parameters: indices 1..loss_scale_number hold the
        # per-group scales; indices 0 and loss_scale_number + 1 keep the neutral
        # value 1.0 — presumably boundary sentinels so construct()'s ratio divides
        # by 1 at the first/last layer (TODO confirm against _DynamicLossScale).
        self.dynamic_loss_scale = ParameterTuple(Parameter(Tensor(1, mstype.float32),
                                                           name='layer_loss_scale_{}'.format(x), requires_grad=False)
                                                 for x in range(loss_scale_number + 2))
        if isinstance(init_loss_scale, list):
            # Per-group initial scales; +1 skips the sentinel parameter at index 0.
            for i, value in enumerate(init_loss_scale):
                self.dynamic_loss_scale[i + 1].set_data(value)
        else:
            # Single scalar shared by every group.
            for i in range(self.loss_scale_number):
                self.dynamic_loss_scale[i + 1].set_data(init_loss_scale)

    def update_loss_scale_status(self, layer, update_ratio):
        """
        Update dynamic loss scale.

        Args:
            layer (int): Current layer.
            update_ratio (float): The ratio of loss scale update.

        Outputs:
            float, new loss scale.
        """
        # Group index -> parameter index: parameters are offset by one because
        # index 0 is the sentinel (see set_loss_scale_status).
        layer = layer + 1
        new_loss_scale = self.dynamic_loss_scale[layer] * update_ratio
        # Graph-mode in-place write of the new scale into the Parameter.
        P.Assign()(self.dynamic_loss_scale[layer], new_loss_scale)
        return new_loss_scale

    def construct(self, x, layer1, layer2):
        # Re-scale x at the boundary between two loss-scale regions: apply the
        # layer1 scaling op with the ratio of the two groups' current scales.
        x = self.layer_loss_scale[layer1](x, self.dynamic_loss_scale[layer1] / self.dynamic_loss_scale[layer2])
        return x

    def get_loss_scale(self):
        """
        Get loss scale value.

        Returns:
            Number, the `init_loss_scale` value passed to the constructor.
        """
        return self._loss_scale

    def get_update_cell(self):
        """
        Returns the instance of :class:`mindspore.boost.GroupLossScaleManager`.

        Returns:
            :class:`mindspore.boost.GroupLossScaleManager`.
        """
        return self
167