• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""LearningRateScheduler Callback class."""
16
17import math
18import numpy as np
19
20from mindspore import log as logger
21import mindspore.common.dtype as mstype
22from mindspore.common.tensor import Tensor
23from mindspore.train.callback._callback import Callback
24from mindspore.ops import functional as F
25
26
class LearningRateScheduler(Callback):
    """
    Callback that adjusts the optimizer's learning rate while training runs.

    At the end of every step the current learning rate is read from the optimizer and passed,
    together with the current step number, to `learning_rate_function`. If the value it returns
    differs from the current one, the new value is assigned back to the optimizer.

    Args:
        learning_rate_function (Function): Called as `learning_rate_function(lr, cur_step_num)`;
            its return value is used as the learning rate from then on.

    Examples:
        >>> from mindspore import Model
        >>> from mindspore.train.callback import LearningRateScheduler
        >>> import mindspore.nn as nn
        ...
        >>> def learning_rate_function(lr, cur_step_num):
        ...     if cur_step_num%1000 == 0:
        ...         lr = lr*0.1
        ...     return lr
        ...
        >>> lr = 0.1
        >>> momentum = 0.9
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        ...
        >>> dataset = create_custom_dataset("custom_dataset_path")
        >>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
        ...             dataset_sink_mode=False)
    """

    def __init__(self, learning_rate_function):
        super().__init__()
        self.learning_rate_function = learning_rate_function

    def step_end(self, run_context):
        """
        Recompute the learning rate once a step has finished and apply it if it changed.

        Args:
            run_context (RunContext): Context of the train running.
        """
        params = run_context.original_args()
        optimizer = params.optimizer

        # The value is recovered from the printed form of the array.
        # NOTE(review): presumably this is done to obtain the short decimal value (e.g. 0.1)
        # rather than the widened float32 one, and it assumes a scalar learning rate - confirm.
        lr_text = np.array2string(optimizer.learning_rate.asnumpy())
        current_lr = float(lr_text)

        step = params.cur_step_num
        updated_lr = self.learning_rate_function(current_lr, step)

        # Nothing to do when the function left the learning rate (effectively) unchanged.
        if math.isclose(current_lr, updated_lr, rel_tol=1e-10):
            return

        F.assign(optimizer.learning_rate, Tensor(updated_lr, mstype.float32))
        logger.info(f'At step {step}, learning_rate change to {updated_lr}')
74