# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""Loss scale manager abstract class."""
16
17from .._checkparam import Validator as validator
18from .. import nn
19
20
class LossScaleManager:
    """
    Loss scale manager abstract class.

    Derived classes, such as FixedLossScaleManager and DynamicLossScaleManager, should override all of
    LossScaleManager's methods.
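
    Note:
        A subclass only needs to implement the three methods below. The sketch in the
        example is illustrative; `ConstLossScaleManager` is a hypothetical manager with
        a constant scale of 1.0, not a MindSpore class.

    Examples:
        >>> class ConstLossScaleManager(LossScaleManager):
        ...     def get_loss_scale(self):
        ...         return 1.0
        ...     def update_loss_scale(self, overflow):
        ...         pass
        ...     def get_update_cell(self):
        ...         return None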
    """
    def get_loss_scale(self):
        """Get loss scale value."""

    def update_loss_scale(self, overflow):
        """
        Update loss scale value.

        Args:
            overflow (bool): Whether it overflows.
        """

    def get_update_cell(self):
        """Get the loss scaling update logic cell."""


class FixedLossScaleManager(LossScaleManager):
    """
    Loss scale with a fixed value, inherits from LossScaleManager.

    Args:
        loss_scale (float): Loss scale value. Note that if `drop_overflow_update` is set to False, the `loss_scale`
            of the optimizer you use must be set to the same value as here. Default: 128.0.
        drop_overflow_update (bool): Whether to drop the optimizer update when an overflow occurs. If True, the
            optimizer will not be executed when an overflow occurs. Default: True.

    Examples:
        >>> from mindspore import Model, nn, FixedLossScaleManager
        >>>
        >>> net = Net()
        >>> #1) Drop the parameter update if there is an overflow
        >>> loss_scale_manager = FixedLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
        >>>
        >>> #2) Execute parameter update even if overflow occurs
        >>> loss_scale = 1024.0
        >>> loss_scale_manager = FixedLossScaleManager(loss_scale, False)
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=loss_scale)
        >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self, loss_scale=128.0, drop_overflow_update=True):
        if loss_scale < 1:
            raise ValueError("The argument 'loss_scale' must be >= 1, "
                             "but got {}".format(loss_scale))
        self._loss_scale = loss_scale
        self._drop_overflow_update = drop_overflow_update

    def get_loss_scale(self):
        """
        Get loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self._loss_scale

    def get_drop_overflow_update(self):
        """
        Get the flag whether to drop optimizer update when there is an overflow.

        Returns:
            bool, `drop_overflow_update` value.
        """
        return self._drop_overflow_update

    def update_loss_scale(self, overflow):
        """
        Update loss scale value. This method does nothing in `FixedLossScaleManager`.

        Args:
            overflow (bool): Whether it overflows.
        """

    def get_update_cell(self):
        """
        Returns the update cell for `TrainOneStepWithLossScaleCell`.

        Returns:
            None or Cell. A Cell object used to update `loss_scale` when `drop_overflow_update` is True; None when
            `drop_overflow_update` is False.
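
        Examples:
            >>> # Illustrative sketch of using the returned cell as the `scale_sense`
            >>> # of nn.TrainOneStepWithLossScaleCell; `net_with_loss` and `optim` are
            >>> # assumed to be defined elsewhere.
            >>> manager = FixedLossScaleManager(drop_overflow_update=True)
            >>> update_cell = manager.get_update_cell()
            >>> train_net = nn.TrainOneStepWithLossScaleCell(net_with_loss, optim, scale_sense=update_cell)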
        """
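        # A fixed scale only needs an update cell when optimizer updates are dropped
        # on overflow; otherwise no runtime scale logic is required.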
        if not self._drop_overflow_update:
            return None
        return nn.FixedLossScaleUpdateCell(self._loss_scale)


class DynamicLossScaleManager(LossScaleManager):
    """
    Loss scale that dynamically adjusts itself, inherits from LossScaleManager.

    Args:
        init_loss_scale (float): Initial loss scale value. Default: 2**24.
        scale_factor (int): Coefficient by which the loss scale is increased or decreased. Default: 2.
        scale_window (int): Maximum number of consecutive normal (overflow-free) steps before the loss scale is
            increased. Default: 2000.

    Examples:
        >>> from mindspore import Model, nn, DynamicLossScaleManager
        >>>
        >>> net = Net()
        >>> loss_scale_manager = DynamicLossScaleManager()
        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
    """
    def __init__(self,
                 init_loss_scale=2 ** 24,
                 scale_factor=2,
                 scale_window=2000):
        if init_loss_scale < 1.0:
            raise ValueError("The argument 'init_loss_scale' must be >= 1, but got {}".format(init_loss_scale))
        self.loss_scale = init_loss_scale
        validator.check_positive_int(scale_window, "scale_window", self.__class__.__name__)
        self.scale_window = scale_window
        if scale_factor <= 0:
            raise ValueError("The argument 'scale_factor' must be > 0, but got {}".format(scale_factor))
        self.scale_factor = scale_factor
        self.increase_ratio = scale_factor
        self.decrease_ratio = 1 / scale_factor
        self.cur_iter = 1
        self.last_overflow_iter = 0
        self.bad_step_max = 1000
        self.bad_step = 0

    def get_loss_scale(self):
        """
        Get loss scale value.

        Returns:
            float, `loss_scale` value.
        """
        return self.loss_scale

    def update_loss_scale(self, overflow):
        """
        Update loss scale value.

        Args:
            overflow (bool): Whether it overflows.
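
        Examples:
            >>> # Illustrative dynamics (a sketch with a small window for brevity): an
            >>> # overflow divides the scale by `scale_factor`, and `scale_window`
            >>> # consecutive overflow-free steps multiply it by `scale_factor` again.
            >>> manager = DynamicLossScaleManager(init_loss_scale=1024, scale_factor=2, scale_window=1)
            >>> manager.update_loss_scale(True)
            >>> manager.get_loss_scale()
            512.0
            >>> manager.update_loss_scale(False)
            >>> manager.get_loss_scale()
            1024.0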
        """
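        # On overflow, shrink the loss scale by `scale_factor` (bounded below by 1)
        # and record the step at which the overflow happened.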
        if overflow:
            self.loss_scale = max(self.loss_scale * self.decrease_ratio, 1)
            self.last_overflow_iter = self.cur_iter
            self.bad_step += 1
        else:
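            # After `scale_window` consecutive overflow-free steps, grow the scale again.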
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.loss_scale *= self.increase_ratio
            self.bad_step = 0

        if self.bad_step > self.bad_step_max:
            raise RuntimeError("Dynamic loss scale overflowed for {} consecutive steps, "
                               "exceeding the maximum threshold of {}.".format(self.bad_step, self.bad_step_max))

        self.cur_iter += 1

    def get_drop_overflow_update(self):
        """
        Get the flag whether to drop optimizer update when there is an overflow.

        Returns:
            bool, always True for `DynamicLossScaleManager`.
        """
        return True

    def get_update_cell(self):
        """
        Returns the update cell for `TrainOneStepWithLossScaleCell`.

        Returns:
            Cell, cell object used to update `loss_scale`.
        """
        return nn.DynamicLossScaleUpdateCell(self.loss_scale, self.scale_factor, self.scale_window)