• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "include/train/lr_scheduler.h"
18 #include <sys/stat.h>
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 #include <iostream>
23 #include <fstream>
24 #include <memory>
25 #include "include/errorcode.h"
26 #include "include/lite_session.h"
27 #include "src/common/utils.h"
28 #include "src/tensor.h"
29 
30 namespace mindspore {
31 namespace lite {
MultiplicativeLRLambda(float * lr,int epoch,void * lr_cb_data)32 int MultiplicativeLRLambda(float *lr, int epoch, void *lr_cb_data) {
33   if ((lr == nullptr) || (lr_cb_data == nullptr)) {
34     MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
35     return DONT_UPDATE_LR;
36   }
37   float mult = *(static_cast<float *>(lr_cb_data));
38   *lr = *lr * mult;
39   return UPDATE_LR;
40 }
41 
StepLRLambda(float * lr,int epoch,void * lr_cb_data)42 int StepLRLambda(float *lr, int epoch, void *lr_cb_data) {
43   if ((lr == nullptr) || (lr_cb_data == nullptr)) {
44     MS_LOG(ERROR) << "nullptr passed as input to MultiplicativeLRLambda";
45     return DONT_UPDATE_LR;
46   }
47   struct StepLRLambda *step_lr_data = (static_cast<struct StepLRLambda *>(lr_cb_data));
48   if (((epoch + 1) % step_lr_data->step_size) == 0) {
49     *lr = *lr * step_lr_data->gamma;
50     return UPDATE_LR;
51   }
52   return DONT_UPDATE_LR;
53 }
54 
LRScheduler(LR_Lambda lambda_func,void * lr_cb_data,int step)55 LRScheduler::LRScheduler(LR_Lambda lambda_func, void *lr_cb_data, int step)
56     : lambda_func_(lambda_func), lr_data_(lr_cb_data), step_(step) {}
57 
EpochEnd(const session::TrainLoopCallBackData & cb_data)58 int LRScheduler::EpochEnd(const session::TrainLoopCallBackData &cb_data) {
59   if (((cb_data.epoch_ + 1) % step_) == 0) {
60     float lr = cb_data.session_->GetLearningRate();
61     int update = lambda_func_(&lr, cb_data.epoch_, lr_data_);
62     if (update == UPDATE_LR) {
63       int ret = cb_data.session_->SetLearningRate(lr);
64       if (ret != RET_OK) {
65         MS_LOG(ERROR) << "Error setting Leraning rate in train session";
66         return mindspore::session::RET_EXIT;
67       }
68     }
69   }
70   return mindspore::session::RET_CONTINUE;
71 }
72 }  // namespace lite
73 }  // namespace mindspore
74