# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dynamic Learning Rate"""
import math

from mindspore._checkparam import Validator as validator


def piecewise_constant_lr(milestone, learning_rates):
    r"""
    Get piecewise constant learning rate.

    Calculate the learning rate from the given `milestone` and `learning_rates`. Let the value of `milestone` be
    :math:`(M_1, M_2, ..., M_N)` and the value of `learning_rates` be :math:`(x_1, x_2, ..., x_N)`. N is the length
    of `milestone`. Let the output learning rate be `y`.

    .. math::
        y[i] = x_t,\ for\ i \in [M_{t-1}, M_t)

    where :math:`M_0 = 0`.

    Args:
        milestone (Union[list[int], tuple[int]]): A list of milestone steps. This list must be monotonically
            increasing, and every element must be greater than 0.
        learning_rates (Union[list[float], tuple[float]]): A list of learning rates with the same length as
            `milestone`.

    Returns:
        list[float]. The size of the list is :math:`M_N`.

    Examples:
        >>> milestone = [2, 5, 10]
        >>> learning_rates = [0.1, 0.05, 0.01]
        >>> output = piecewise_constant_lr(milestone, learning_rates)
        >>> print(output)
        [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01]
    """
    validator.check_value_type('milestone', milestone, (tuple, list))
    validator.check_value_type('learning_rates', learning_rates, (tuple, list))
    if len(milestone) != len(learning_rates):
        raise ValueError('The size of `milestone` must be the same as the size of `learning_rates`.')

    lr = []
    last_item = 0
    for i, item in enumerate(milestone):
        validator.check_positive_int(item, f'milestone[{i}]')
        validator.check_is_float(learning_rates[i], f'learning_rates[{i}]')
        if item < last_item:
            raise ValueError(f'The value of milestone[{i}] must not be less than milestone[{i - 1}].')
        # Hold learning_rates[i] for every step in [last_item, item).
        lr += [learning_rates[i]] * (item - last_item)
        last_item = item

    return lr
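
# Usage sketch (illustrative only): the per-step lists produced by the functions in this
# module are typically passed to a MindSpore optimizer as a dynamic learning rate. The
# optimizer choice and `net` below are placeholders, not part of this module.
#     >>> from mindspore import nn
#     >>> lr = piecewise_constant_lr([2, 5, 10], [0.1, 0.05, 0.01])
#     >>> opt = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)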


def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair):
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
    validator.check_positive_int(decay_epoch, 'decay_epoch')
    validator.check_positive_float(learning_rate, 'learning_rate')
    validator.check_is_float(learning_rate, 'learning_rate')
    validator.check_positive_float(decay_rate, 'decay_rate')
    validator.check_is_float(decay_rate, 'decay_rate')
    validator.check_value_type('is_stair', is_stair, [bool])


def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculates the learning rate based on an exponential decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * decay\_rate^{\frac{current\_epoch}{decay\_epoch}}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): The number of epochs over which the learning rate decays by a factor of `decay_rate`.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 1
        >>> output = exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch)
        >>> print(output)
        [0.1, 0.1, 0.09000000000000001, 0.09000000000000001, 0.08100000000000002, 0.08100000000000002]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            # Staircase decay: the exponent only increases once every `decay_epoch` epochs.
            lr.append(learning_rate * decay_rate ** math.floor(math.floor(i / step_per_epoch) / decay_epoch))
        else:
            lr.append(learning_rate * decay_rate ** (math.floor(i / step_per_epoch) / decay_epoch))
    return lr
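
# Supplementary illustration (values rounded for readability): with `decay_epoch` greater
# than 1, the smooth schedule updates every epoch while the staircase variant holds the
# rate constant within each `decay_epoch` window.
#     >>> smooth = exponential_decay_lr(0.1, 0.9, 6, 2, 2)
#     >>> stair = exponential_decay_lr(0.1, 0.9, 6, 2, 2, is_stair=True)
#     >>> [round(x, 4) for x in smooth]
#     [0.1, 0.1, 0.0949, 0.0949, 0.09, 0.09]
#     >>> [round(x, 4) for x in stair]
#     [0.1, 0.1, 0.1, 0.1, 0.09, 0.09]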


def natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculates the learning rate based on a natural exponential decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate * e^{-decay\_rate * current\_epoch}

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): The staircase interval in epochs; it only takes effect when `is_stair` is True.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.9
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> output = natural_exp_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        >>> print(output)
        [0.1, 0.1, 0.1, 0.1, 0.016529888822158657, 0.016529888822158657]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    # `function` maps (current_epoch, decay_epoch) to the epoch used in the exponent; with
    # `is_stair`, the epoch is rounded down to the nearest multiple of `decay_epoch`.
    function = lambda x, y: x
    if is_stair:
        function = lambda x, y: math.floor(x / y) * y

    lr = []
    for i in range(total_step):
        lr.append(learning_rate * math.e ** (-decay_rate * function(math.floor(i / step_per_epoch), decay_epoch)))
    return lr
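
# Supplementary illustration (values rounded for readability): without the staircase
# option the rate decays every epoch, in contrast to the `is_stair=True` example above.
#     >>> output = natural_exp_decay_lr(0.1, 0.9, 6, 2, 2)
#     >>> [round(x, 4) for x in output]
#     [0.1, 0.1, 0.0407, 0.0407, 0.0165, 0.0165]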


def inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False):
    r"""
    Calculates the learning rate based on an inverse-time decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = learning\_rate / (1 + decay\_rate * current\_epoch / decay\_epoch)

    Where :math:`current\_epoch=floor(\frac{i}{step\_per\_epoch})`.

    Args:
        learning_rate (float): The initial value of learning rate.
        decay_rate (float): The decay rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): The number of epochs over which the denominator grows by `decay_rate`.
        is_stair (bool): If true, the learning rate is decayed once every `decay_epoch` epochs. Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> decay_rate = 0.5
        >>> total_step = 6
        >>> step_per_epoch = 1
        >>> decay_epoch = 1
        >>> output = inverse_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, True)
        >>> print(output)
        [0.1, 0.06666666666666667, 0.05, 0.04, 0.03333333333333333, 0.028571428571428574]
    """
    _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair)

    lr = []
    for i in range(total_step):
        if is_stair:
            # Staircase decay: the denominator only grows once every `decay_epoch` epochs.
            lr.append(learning_rate / (1 + decay_rate * math.floor(math.floor(i / step_per_epoch) / decay_epoch)))
        else:
            lr.append(learning_rate / (1 + decay_rate * math.floor(i / step_per_epoch) / decay_epoch))
    return lr
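
# Supplementary illustration (values rounded for readability): the smooth variant shrinks
# the rate every epoch, while the staircase variant changes it only at multiples of
# `decay_epoch`.
#     >>> smooth = inverse_decay_lr(0.1, 0.5, 6, 1, 2)
#     >>> stair = inverse_decay_lr(0.1, 0.5, 6, 1, 2, is_stair=True)
#     >>> [round(x, 4) for x in smooth]
#     [0.1, 0.08, 0.0667, 0.0571, 0.05, 0.0444]
#     >>> [round(x, 4) for x in stair]
#     [0.1, 0.1, 0.0667, 0.0667, 0.05, 0.05]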


def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    r"""
    Calculates the learning rate based on the cosine decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = min\_learning\_rate + 0.5 * (max\_learning\_rate - min\_learning\_rate) *
        (1 + cos(\frac{current\_epoch}{decay\_epoch}\pi))

    Where :math:`current\_epoch=min(floor(\frac{i}{step\_per\_epoch}), decay\_epoch)`, so the learning rate is
    held at `min_lr` once `decay_epoch` epochs have elapsed.

    Args:
        min_lr (float): The minimum value of learning rate.
        max_lr (float): The maximum value of learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): The number of epochs over which the learning rate decays from `max_lr` to `min_lr`.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> min_lr = 0.01
        >>> max_lr = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> output = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
        >>> print(output)
        [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01]
    """
    if not isinstance(min_lr, float):
        raise TypeError("min_lr must be float.")
    validator.check_non_negative_float(min_lr, "min_lr", None)
    validator.check_positive_float(max_lr, 'max_lr')
    validator.check_is_float(max_lr, 'max_lr')
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
    validator.check_positive_int(decay_epoch, 'decay_epoch')
    if min_lr >= max_lr:
        raise ValueError('The `max_lr` must be greater than the `min_lr`.')

    delta = 0.5 * (max_lr - min_lr)
    lr = []
    for i in range(total_step):
        # Clamp the epoch so the schedule stays at `min_lr` after `decay_epoch` epochs.
        tmp_epoch = min(math.floor(i / step_per_epoch), decay_epoch)
        lr.append(min_lr + delta * (1 + math.cos(math.pi * tmp_epoch / decay_epoch)))
    return lr
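
# Supplementary illustration (values rounded for readability): the schedule anneals from
# `max_lr` to `min_lr` over `decay_epoch` epochs and then stays at `min_lr`.
#     >>> output = cosine_decay_lr(0.01, 0.1, 8, 2, 2)
#     >>> [round(x, 4) for x in output]
#     [0.1, 0.1, 0.055, 0.055, 0.01, 0.01, 0.01, 0.01]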


def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power,
                        update_decay_epoch=False):
    r"""
    Calculates the learning rate based on the polynomial decay function.

    For the i-th step, the formula for computing decayed_learning_rate[i] is:

    .. math::
        decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) *
        (1 - tmp\_epoch / tmp\_decay\_epoch)^{power} + end\_learning\_rate

    Where:

    .. math::
        tmp\_epoch = min(current\_epoch, decay\_epoch)

    .. math::
        current\_epoch=floor(\frac{i}{step\_per\_epoch})

    .. math::
        tmp\_decay\_epoch = decay\_epoch

    If `update_decay_epoch` is true, update the value of `tmp_decay_epoch` every epoch. The formula is:

    .. math::
        tmp\_decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)

    Args:
        learning_rate (float): The initial value of learning rate.
        end_learning_rate (float): The end value of learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        decay_epoch (int): The number of epochs over which the learning rate decays to `end_learning_rate`.
        power (float): The power of the polynomial. It must be greater than 0.
        update_decay_epoch (bool): If true, update the value of `decay_epoch` every epoch as described above.
            Default: False.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> end_learning_rate = 0.01
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> decay_epoch = 2
        >>> power = 0.5
        >>> r = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power)
        >>> print(r)
        [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01]
    """
    validator.check_positive_float(learning_rate, 'learning_rate')
    validator.check_is_float(learning_rate, 'learning_rate')
    if not isinstance(end_learning_rate, float):
        raise TypeError("end_learning_rate must be float.")
    validator.check_non_negative_float(end_learning_rate, "end_learning_rate", None)
    validator.check_positive_float(power, 'power')
    validator.check_is_float(power, 'power')
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')
    validator.check_positive_int(decay_epoch, 'decay_epoch')
    validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool])

    origin_decay_epoch = decay_epoch
    # `function` maps (decay_epoch, current_epoch) to the (tmp_decay_epoch, tmp_epoch) pair
    # used in the formula above.
    function = lambda x, y: (x, min(x, y))
    if update_decay_epoch:
        function = lambda x, y: (origin_decay_epoch * max(math.ceil(y / origin_decay_epoch), 1), y)

    lr = []
    delta = learning_rate - end_learning_rate
    for i in range(total_step):
        current_epoch = math.floor(i / step_per_epoch)
        decay_epoch, tmp_epoch = function(decay_epoch, current_epoch)
        lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate)
    return lr
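
# Supplementary illustration (values rounded for readability): with `update_decay_epoch=True`
# the decay window is stretched to `decay_epoch * ceil(current_epoch / decay_epoch)`, so the
# rate rises again after each window instead of staying at `end_learning_rate`.
#     >>> output = polynomial_decay_lr(0.1, 0.01, 10, 2, 2, 0.5, update_decay_epoch=True)
#     >>> [round(x, 4) for x in output]
#     [0.1, 0.1, 0.0736, 0.0736, 0.01, 0.01, 0.055, 0.055, 0.01, 0.01]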


def warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch):
    r"""
    Gets the learning rate for the warm-up phase.

    For the i-th step, the formula for computing warmup_learning_rate[i] is:

    .. math::
        warmup\_learning\_rate[i] = learning\_rate * tmp\_epoch / warmup\_epoch

    Where :math:`tmp\_epoch=min(current\_epoch, warmup\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`

    Args:
        learning_rate (float): The initial value of learning rate.
        total_step (int): The total number of steps.
        step_per_epoch (int): The number of steps per epoch.
        warmup_epoch (int): The number of epochs over which the learning rate is warmed up.

    Returns:
        list[float]. The size of the list is `total_step`.

    Examples:
        >>> learning_rate = 0.1
        >>> total_step = 6
        >>> step_per_epoch = 2
        >>> warmup_epoch = 2
        >>> output = warmup_lr(learning_rate, total_step, step_per_epoch, warmup_epoch)
        >>> print(output)
        [0.0, 0.0, 0.05, 0.05, 0.1, 0.1]
    """
    if not isinstance(learning_rate, float):
        raise TypeError("learning_rate must be float.")
    validator.check_non_negative_float(learning_rate, "learning_rate", None)
    validator.check_positive_int(warmup_epoch, 'warmup_epoch')
    validator.check_positive_int(total_step, 'total_step')
    validator.check_positive_int(step_per_epoch, 'step_per_epoch')

    # Clamp the epoch so the rate stays at `learning_rate` once warm-up is finished.
    function = lambda x, y: (x, min(x, y))

    lr = []
    for i in range(total_step):
        current_epoch = math.floor(i / step_per_epoch)
        warmup_epoch, tmp_epoch = function(warmup_epoch, current_epoch)
        lr.append(learning_rate * tmp_epoch / warmup_epoch)
    return lr
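
# Supplementary illustration (values rounded for readability): a warm-up list is commonly
# concatenated with a decay list to build a full schedule. This is a sketch of that
# pattern, not an API of this module.
#     >>> warm = warmup_lr(0.1, 4, 2, 2)
#     >>> decay = exponential_decay_lr(0.1, 0.9, 6, 2, 1)
#     >>> [round(x, 4) for x in warm + decay]
#     [0.0, 0.0, 0.05, 0.05, 0.1, 0.1, 0.09, 0.09, 0.081, 0.081]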


__all__ = [
    'piecewise_constant_lr',
    'exponential_decay_lr',
    'natural_exp_decay_lr',
    'inverse_decay_lr',
    'cosine_decay_lr',
    'polynomial_decay_lr',
    'warmup_lr'
]