• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Implements learning rate adaptations common to most stochastic algorithms.
18 
19 #ifndef LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_
20 #define LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_
21 
22 #include <cmath>
23 #include "common_defs.h"
24 
25 namespace learning_stochastic_linear {
26 
27 class LearningRateController {
28  public:
LearningRateController()29   LearningRateController() {
30     iteration_num_ = 1;
31     lambda_ = 1.0;
32     mini_batch_size_ = 1;
33     mini_batch_counter_ = 1;
34     sample_num_ = 1;
35     mode_ = INV_LINEAR;
36     is_first_sample_ = true;
37   }
~LearningRateController()38   ~LearningRateController() {}
39   // Getters and Setters for learning rate parameter lambda_
GetLambda()40   double GetLambda() const {
41     return lambda_;
42   }
SetLambda(double lambda)43   void SetLambda(double lambda) {
44     lambda_ = lambda;
45   }
46   // Operations on current iteration number
SetIterationNumber(uint64 num)47   void SetIterationNumber(uint64 num) {
48     iteration_num_ = num;
49   }
IncrementIteration()50   void IncrementIteration() {
51     ++iteration_num_;
52   }
GetIterationNumber()53   uint64 GetIterationNumber() const {
54     return iteration_num_;
55   }
56   // Mini batch operations
GetMiniBatchSize()57   uint64 GetMiniBatchSize() const {
58     return mini_batch_size_;
59   }
SetMiniBatchSize(uint64 size)60   void SetMiniBatchSize(uint64 size) {
61     //CHECK_GT(size, 0);
62     mini_batch_size_ = size;
63   }
IncrementSample()64   void IncrementSample() {
65     // If this is the first sample we've already counted it to prevent NaNs
66     // in the learning rate computation
67     if (is_first_sample_) {
68       is_first_sample_ = false;
69       return;
70     }
71     ++sample_num_;
72     if (1 == mini_batch_size_) {
73       IncrementIteration();
74       mini_batch_counter_ = 0;
75     } else {
76       ++mini_batch_counter_;
77       if ((mini_batch_counter_ % mini_batch_size_ == 0)) {
78         IncrementIteration();
79         mini_batch_counter_ = 0;
80       }
81     }
82   }
GetMiniBatchCounter()83   uint64 GetMiniBatchCounter() const {
84     return mini_batch_counter_;
85   }
86   // Getters and setters for adaptation mode
GetAdaptationMode()87   AdaptationMode GetAdaptationMode() const {
88     return mode_;
89   }
SetAdaptationMode(AdaptationMode m)90   void SetAdaptationMode(AdaptationMode m) {
91     mode_ = m;
92   }
GetLearningRate()93   double GetLearningRate() const {
94     if (mode_ == CONST) {
95       return (1.0 / (lambda_ * mini_batch_size_));
96     } else if (mode_ == INV_LINEAR) {
97       return (1.0 / (lambda_ * iteration_num_ * mini_batch_size_));
98     } else if (mode_ == INV_QUADRATIC) {
99       return (1.0 / (lambda_ *
100                      mini_batch_size_ *
101                      (static_cast<double>(iteration_num_) * iteration_num_)));
102     } else if (mode_ == INV_SQRT) {
103       return (1.0 / (lambda_ *
104                      mini_batch_size_ *
105                      sqrt((double)iteration_num_)));
106     }
107     return 0;
108   }
CopyFrom(const LearningRateController & other)109   void CopyFrom(const LearningRateController &other) {
110     iteration_num_ = other.iteration_num_;
111     sample_num_ = other.sample_num_;
112     mini_batch_size_ = other.mini_batch_size_;
113     mini_batch_counter_ = other.mini_batch_counter_;
114     mode_ = other.mode_;
115     is_first_sample_ = other.is_first_sample_;
116   }
117  private:
118   uint64 iteration_num_;
119   uint64 sample_num_;
120   uint64 mini_batch_size_;
121   uint64 mini_batch_counter_;
122   double lambda_;
123   AdaptationMode mode_;
124   bool is_first_sample_;
125 };
126 }  // namespace learning_stochastic_linear
127 #endif  // LEARNING_STOCHASTIC_LINEAR_LEARNING_RATE_CONTROLLER_INL_H_
128