# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Loss scale manager abstract class."""
16from __future__ import absolute_import
17
18from mindspore import _checkparam as validator
19from mindspore import nn
20
21
class LossScaleManager:
    """
    Abstract class for the loss scale manager used with mixed precision. The loss scale is the magnification
    factor applied to the gradients when mixed precision is used.

    A derived class needs to implement all of its methods. `get_loss_scale` is used to get the current loss scale
    value. `update_loss_scale` is used to update the loss scale value and will be called during training.
    `get_update_cell` is used to get the instance of :class:`mindspore.nn.Cell` that updates the loss scale; this
    instance will be called during training. Currently, `get_update_cell` is the most commonly used method.

    For examples, see :class:`mindspore.amp.FixedLossScaleManager` and :class:`mindspore.amp.DynamicLossScaleManager`.
    """
    def get_loss_scale(self):
        """Get the value of the loss scale, which is the amplification factor of the gradients."""

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value according to the status of `overflow`.

        Args:
            overflow (bool): Whether overflow occurs during training.
        """

    def get_update_cell(self):
        """Get the instance of :class:`mindspore.nn.Cell` that is used to update the loss scale."""
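

# Illustrative only: a minimal sketch of a derived manager, using the hypothetical name
# _ExampleConstantLossScaleManager; it is not part of the public API and is kept here purely
# for illustration. It implements the three methods described in the LossScaleManager docstring
# above and reuses nn.FixedLossScaleUpdateCell so that `get_update_cell` returns a usable Cell.
class _ExampleConstantLossScaleManager(LossScaleManager):
    """Hypothetical derived manager that keeps a constant loss scale; for illustration only."""

    def __init__(self, loss_scale=1024.0):
        self._loss_scale = loss_scale

    def get_loss_scale(self):
        # Return the constant magnification factor of the gradients.
        return self._loss_scale

    def update_loss_scale(self, overflow):
        # The scale is constant, so there is nothing to adjust regardless of `overflow`.
        pass

    def get_update_cell(self):
        # The returned Cell is called by the training wrapper during training.
        return nn.FixedLossScaleUpdateCell(self._loss_scale)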


class FixedLossScaleManager(LossScaleManager):
    """
    Loss scale manager with a fixed loss scale value (the magnification factor of the gradients when mixed
    precision is used), which inherits from :class:`mindspore.amp.LossScaleManager`.

    Args:
        loss_scale (float): Magnification factor of the gradients. Note that if `drop_overflow_update` is set to
            ``False`` , the value of `loss_scale` in the optimizer should be set to the same value as here.
            Default: ``128.0`` .
        drop_overflow_update (bool): Whether to execute the optimizer if there is an overflow. If ``True`` , the
            optimizer will not be executed when overflow occurs. Default: ``True`` .

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import amp, nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_scale = 1024.0
        >>> loss_scale_manager = amp.FixedLossScaleManager(loss_scale, False)
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
        >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self, loss_scale=128.0, drop_overflow_update=True):
        if loss_scale < 1:
            raise ValueError("The argument 'loss_scale' must be >= 1, "
                             "but got {}".format(loss_scale))
        self._loss_scale = loss_scale
        self._drop_overflow_update = drop_overflow_update

    def get_loss_scale(self):
        """
        Get the loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self._loss_scale

    def get_drop_overflow_update(self):
        """
        Get `drop_overflow_update`, whether to drop optimizer update for current step when there is an overflow.

        Returns:
            bool, `drop_overflow_update` value.
        """
        return self._drop_overflow_update

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value. This method of :class:`mindspore.amp.FixedLossScaleManager` does nothing,
        because the loss scale is fixed.

        Args:
            overflow (bool): Whether overflow occurs.
        """

    def get_update_cell(self):
        """
        Returns the instance of :class:`mindspore.nn.Cell` that is used to update the loss scale and will be called
        by :class:`mindspore.nn.TrainOneStepWithLossScaleCell`. As the loss scale is fixed in this class, the
        instance does nothing.

        Returns:
            None or :class:`mindspore.nn.FixedLossScaleUpdateCell`. Instance of
            :class:`mindspore.nn.FixedLossScaleUpdateCell` when `drop_overflow_update` is ``True`` , None when
            `drop_overflow_update` is ``False`` .
        """
        if not self._drop_overflow_update:
            return None
        return nn.FixedLossScaleUpdateCell(self._loss_scale)
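

# Illustrative only: a minimal sketch, using a hypothetical helper name and caller-supplied
# `net_with_loss` (a Cell returning the loss) and `optimizer`; it is never called by this module.
# It shows one way the update cell above is typically consumed, via
# nn.TrainOneStepWithLossScaleCell, which is referenced in the docstrings of this file.
def _example_fixed_scale_wiring(net_with_loss, optimizer, loss_scale=1024.0):
    manager = FixedLossScaleManager(loss_scale, drop_overflow_update=True)
    # With drop_overflow_update=True this is an nn.FixedLossScaleUpdateCell; with
    # drop_overflow_update=False it would be None, and `loss_scale` should instead be
    # passed to the optimizer, as noted in the class docstring above.
    update_cell = manager.get_update_cell()
    return nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer, update_cell)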


class DynamicLossScaleManager(LossScaleManager):
    """
    Loss scale manager with a dynamically adjusted loss scale (the magnification factor of the gradients when
    mixed precision is used), which inherits from :class:`mindspore.amp.LossScaleManager`.

    Args:
        init_loss_scale (float): Initial loss scale. Default: ``2 ** 24`` .
        scale_factor (int): Coefficient by which the loss scale is increased or decreased. Default: ``2`` .
        scale_window (int): Number of consecutive overflow-free steps after which the loss scale is increased.
            Default: ``2000`` .

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import amp, nn
        >>>
        >>> # Define the network structure of LeNet5. Refer to
        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> loss_scale_manager = amp.DynamicLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = ms.Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self,
                 init_loss_scale=2 ** 24,
                 scale_factor=2,
                 scale_window=2000):
        if init_loss_scale < 1.0:
            raise ValueError("The argument 'init_loss_scale' must be >= 1, but got {}".format(init_loss_scale))
        self.loss_scale = init_loss_scale
        validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__)
        self.scale_window = scale_window
        if scale_factor <= 0:
            raise ValueError("The argument 'scale_factor' must be > 0, but got {}".format(scale_factor))
        self.scale_factor = scale_factor
        self.increase_ratio = scale_factor
        self.decrease_ratio = 1 / scale_factor
        self.cur_iter = 1
        self.last_overflow_iter = 0
        self.bad_step_max = 1000
        self.bad_step = 0

    def get_loss_scale(self):
        """
        Get the current loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self.loss_scale

    def update_loss_scale(self, overflow):
        """
        Update the loss scale value according to the status of `overflow`. If overflow occurs, decrease the loss
        scale; otherwise, increase the loss scale once every `scale_window` consecutive overflow-free steps.

        Args:
            overflow (bool): Whether overflow occurs.
        """
        if overflow:
            self.loss_scale = max(self.loss_scale * self.decrease_ratio, 1)
            self.last_overflow_iter = self.cur_iter
            self.bad_step += 1
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.loss_scale *= self.increase_ratio
            self.bad_step = 0

        if self.bad_step > self.bad_step_max:
            raise RuntimeError("Dynamic loss scale: continuous overflow for {} steps has exceeded "
                               "the maximum threshold.".format(self.bad_step))

        self.cur_iter += 1

    def get_drop_overflow_update(self):
        """
        Whether to drop optimizer update for current step when there is an overflow.

        Returns:
            bool, always True.
        """
        return True

    def get_update_cell(self):
        """
        Returns the instance of :class:`mindspore.nn.Cell` that is used to update the loss scale and will be called
        by :class:`mindspore.nn.TrainOneStepWithLossScaleCell`.

        Returns:
            :class:`mindspore.nn.DynamicLossScaleUpdateCell`.
        """
        return nn.DynamicLossScaleUpdateCell(self.loss_scale, self.scale_factor, self.scale_window)
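

# Illustrative only: a small, self-contained walkthrough using a hypothetical helper name that
# is never called by this module. It exercises only the methods defined above: with
# scale_window=3 the scale doubles after three consecutive overflow-free steps and is halved
# again on the next overflow (floored at 1).
def _example_dynamic_scale_walkthrough():
    manager = DynamicLossScaleManager(init_loss_scale=2 ** 10, scale_factor=2, scale_window=3)
    for _ in range(3):
        manager.update_loss_scale(overflow=False)
    assert manager.get_loss_scale() == 2 ** 11  # increased once the window is reached
    manager.update_loss_scale(overflow=True)
    assert manager.get_loss_scale() == 2 ** 10  # halved on overflow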