# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LearningRateScheduler Callback class."""

import math
import numpy as np

from mindspore import log as logger
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.callback._callback import Callback
from mindspore.ops import functional as F


class LearningRateScheduler(Callback):
    """
    Callback that adjusts the optimizer's learning rate while training runs.

    At the end of every step the current learning rate and the current step
    number are handed to a user-supplied function; the value it returns is
    written back to the optimizer when it differs from the current one.

    Args:
        learning_rate_function (Function): Called as
            ``learning_rate_function(lr, cur_step_num)``; its return value is
            used as the new learning rate.

    Examples:
        >>> from mindspore import Model
        >>> from mindspore.train.callback import LearningRateScheduler
        >>> import mindspore.nn as nn
        ...
        >>> def learning_rate_function(lr, cur_step_num):
        ...     if cur_step_num%1000 == 0:
        ...         lr = lr*0.1
        ...     return lr
        ...
        >>> lr = 0.1
        >>> momentum = 0.9
        >>> net = Net()
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = nn.Momentum(net.trainable_params(), learning_rate=lr, momentum=momentum)
        >>> model = Model(net, loss_fn=loss, optimizer=optim)
        ...
        >>> dataset = create_custom_dataset("custom_dataset_path")
        >>> model.train(1, dataset, callbacks=[LearningRateScheduler(learning_rate_function)],
        ...             dataset_sink_mode=False)
    """

    def __init__(self, learning_rate_function):
        super().__init__()
        # Stored as-is; it is only invoked from step_end().
        self.learning_rate_function = learning_rate_function

    def step_end(self, run_context):
        """
        Recompute the learning rate once a training step has finished.

        Args:
            run_context (RunContext): Context of the train running.
        """
        cb_params = run_context.original_args()
        optimizer = cb_params.optimizer

        # The current value is read through its printed form: the array is
        # formatted as text and that text is parsed back into a Python float.
        # NOTE(review): this presumably requires a scalar learning rate, since
        # a multi-element array would not format to a single number - confirm.
        lr_array = optimizer.learning_rate.asnumpy()
        current_lr = float(np.array2string(lr_array))

        new_lr = self.learning_rate_function(current_lr, cb_params.cur_step_num)

        # Nothing to do when the function left the rate (effectively) unchanged.
        if math.isclose(current_lr, new_lr, rel_tol=1e-10):
            return

        updated_value = Tensor(new_lr, mstype.float32)
        F.assign(optimizer.learning_rate, updated_value)
        logger.info(f'At step {cb_params.cur_step_num}, learning_rate change to {new_lr}')