1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "include/train/lr_scheduler.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include "include/errorcode.h"
26 #include "include/lite_session.h"
27 #include "src/common/utils.h"
28 #include "src/tensor.h"
29
30 namespace mindspore {
31 namespace lite {
MultiplicativeLRLambda(float * lr,int epoch,void * lr_cb_data)32 int MultiplicativeLRLambda(float *lr, int epoch, void *lr_cb_data) {
33 if ((lr == nullptr) || (lr_cb_data == nullptr)) {
34 MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
35 return DONT_UPDATE_LR;
36 }
37 float mult = *(static_cast<float *>(lr_cb_data));
38 *lr = *lr * mult;
39 return UPDATE_LR;
40 }
41
StepLRLambda(float * lr,int epoch,void * lr_cb_data)42 int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
43 if ((lr == nullptr) || (lr_cb_data == nullptr)) {
44 MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
45 return DONT_UPDATE_LR;
46 }
47 struct StepLRLambda *step_lr_data = (static_cast<struct StepLRLambda *>(lr_cb_data));
48 if (((epoch + 1) % step_lr_data->step_size) == 0) {
49 *lr = *lr * step_lr_data->gamma;
50 return UPDATE_LR;
51 }
52 return DONT_UPDATE_LR;
53 }
54
// Constructs a scheduler that runs `lambda_func` at the end of every `step`
// epochs, forwarding `lr_cb_data` to it unchanged.
// NOTE(review): lr_cb_data is stored as a raw, non-owning pointer — the
// caller must keep it alive for the scheduler's lifetime.
LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step)
    : lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {}
57
EpochEnd(const session::TrainLoopCallBackData & cb_data)58 int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
59 if (((cb_data.epoch_ + 1) % step_) == 0) {
60 float lr = cb_data.session_->GetLearningRate();
61 int update = lambda_func_(&lr, cb_data.epoch_, lr_data_);
62 if (update == UPDATE_LR) {
63 int ret = cb_data.session_->SetLearningRate(lr);
64 if (ret != RET_OK) {
65 MS_LOG(ERROR) << "Error setting Leraning rate in train session";
66 return mindspore::session::RET_EXIT;
67 }
68 }
69 }
70 return mindspore::session::RET_CONTINUE;
71 }
72 } // namespace lite
73 } // namespace mindspore
74