• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
// Purpose: A container for sparse weight vectors.
// Maintains the sparse vector as a list of (name, value) pairs along with
// a normalizer_. All operations assume that (name, value/normalizer_) is the
// true value in question.
21 
22 #ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
23 #define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
24 
25 #include <math.h>
26 
27 #include <iosfwd>
28 #include <sstream>
29 #include <string>
30 #include <unordered_map>
31 
32 #include "common_defs.h"
33 
34 namespace learning_stochastic_linear {
35 
36 template<class Key = std::string, class Hash = std::unordered_map<Key, double> >
37 class SparseWeightVector {
38  public:
39   typedef Hash Wmap;
40   typedef typename Wmap::iterator Witer;
41   typedef typename Wmap::const_iterator Witer_const;
SparseWeightVector()42   SparseWeightVector() {
43     normalizer_ = 1.0;
44   }
~SparseWeightVector()45   ~SparseWeightVector() {}
SparseWeightVector(const SparseWeightVector<Key,Hash> & other)46   explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
47     CopyFrom(other);
48   }
49   void operator=(const SparseWeightVector<Key, Hash> &other) {
50     CopyFrom(other);
51   }
CopyFrom(const SparseWeightVector<Key,Hash> & other)52   void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
53     w_ = other.w_;
54     wmin_ = other.wmin_;
55     wmax_ = other.wmax_;
56     normalizer_ = other.normalizer_;
57   }
58 
59   // This function implements checks to prevent unbounded vectors. It returns
60   // true if the checks succeed and false otherwise. A vector is deemed invalid
61   // if any of these conditions are met:
62   // 1. it has no values.
63   // 2. its normalizer is nan or inf or close to zero.
64   // 3. any of its values are nan or inf.
65   // 4. its L0 norm is close to zero.
66   bool IsValid() const;
67 
68   // Normalizer getters and setters.
GetNormalizer()69   double GetNormalizer() const {
70     return normalizer_;
71   }
SetNormalizer(const double norm)72   void SetNormalizer(const double norm) {
73     normalizer_ = norm;
74   }
NormalizerMultUpdate(const double mul)75   void NormalizerMultUpdate(const double mul) {
76     normalizer_ = normalizer_ * mul;
77   }
NormalizerAddUpdate(const double add)78   void NormalizerAddUpdate(const double add) {
79     normalizer_ += add;
80   }
81 
82   // Divides all the values by the normalizer, then it resets it to 1.0
83   void ResetNormalizer();
84 
85   // Bound getters and setters.
86   // True if there is a bound with val containing the bound. false otherwise.
GetElementMinBound(const Key & fname,double * val)87   bool GetElementMinBound(const Key &fname, double *val) const {
88     return GetValue(wmin_, fname, val);
89   }
GetElementMaxBound(const Key & fname,double * val)90   bool GetElementMaxBound(const Key &fname, double *val) const {
91     return GetValue(wmax_, fname, val);
92   }
SetElementMinBound(const Key & fname,const double bound)93   void SetElementMinBound(const Key &fname, const double bound) {
94     wmin_[fname] = bound;
95   }
SetElementMaxBound(const Key & fname,const double bound)96   void SetElementMaxBound(const Key &fname, const double bound) {
97     wmax_[fname] = bound;
98   }
99   // Element getters and setters.
GetElement(const Key & fname)100   double GetElement(const Key &fname) const {
101     double val = 0;
102     GetValue(w_, fname, &val);
103     return val;
104   }
SetElement(const Key & fname,const double val)105   void SetElement(const Key &fname, const double val) {
106     //DCHECK(!isnan(val));
107     w_[fname] = val;
108   }
AddUpdateElement(const Key & fname,const double val)109   void AddUpdateElement(const Key &fname, const double val) {
110     w_[fname] += val;
111   }
MultUpdateElement(const Key & fname,const double val)112   void MultUpdateElement(const Key &fname, const double val) {
113     w_[fname] *= val;
114   }
115   // Load another weight vectors. Will overwrite the current vector.
LoadWeightVector(const SparseWeightVector<Key,Hash> & vec)116   void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
117     w_.clear();
118     w_.insert(vec.w_.begin(), vec.w_.end());
119     wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
120     wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
121     normalizer_ = vec.normalizer_;
122   }
Clear()123   void Clear() {
124     w_.clear();
125     wmax_.clear();
126     wmin_.clear();
127   }
GetMap()128   const Wmap& GetMap() const {
129     return w_;
130   }
131   // Vector Operations.
132   void AdditiveWeightUpdate(const double multiplier,
133                             const SparseWeightVector<Key, Hash> &w1,
134                             const double additive_const);
135   void AdditiveSquaredWeightUpdate(const double multiplier,
136                                    const SparseWeightVector<Key, Hash> &w1,
137                                    const double additive_const);
138   void AdditiveInvSqrtWeightUpdate(const double multiplier,
139                                    const SparseWeightVector<Key, Hash> &w1,
140                                    const double additive_const);
141   void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
142   double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
143   // L-x norm. eg. L1, L2.
144   double LxNorm(const double x) const;
145   double L2Norm() const;
146   double L1Norm() const;
147   double L0Norm(const double epsilon) const;
148   // Bound preserving updates.
149   void AdditiveWeightUpdateBounded(const double multiplier,
150                                    const SparseWeightVector<Key, Hash> &w1,
151                                    const double additive_const);
152   void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
153   void ReprojectToBounds();
154   void ReprojectL0(const double l0_norm);
155   void ReprojectL1(const double l1_norm);
156   void ReprojectL2(const double l2_norm);
157   // Reproject using the given norm.
158   // Will also rescale regularizer_ if it gets too small/large.
159   int32 Reproject(const double norm, const RegularizationType r);
160   // Convert this vector to a string, simply for debugging.
DebugString()161   std::string DebugString() const {
162     std::stringstream stream;
163     stream << *this;
164     return stream.str();
165   }
166  private:
167   // The weight map.
168   Wmap w_;
169   // Constraint bounds.
170   Wmap wmin_;
171   Wmap wmax_;
172   // Normalizing constant in magnitude measurement.
173   double normalizer_;
174   // This function is necessary since by default unordered_map inserts an
175   // element if it does not find the key through [] operator. It implements a
176   // lookup without the space overhead of an add.
GetValue(const Wmap & w1,const Key & fname,double * val)177   bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
178     Witer_const iter = w1.find(fname);
179     if (iter != w1.end()) {
180       (*val) = iter->second;
181       return true;
182     } else {
183       (*val) = 0;
184       return false;
185     }
186   }
187 };
188 
189 // Outputs a SparseWeightVector, for debugging.
190 template <class Key, class Hash>
191 std::ostream& operator<<(std::ostream &stream,
192                     const SparseWeightVector<Key, Hash> &vector) {
193   typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
194   stream << "[[ ";
195   for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
196        iter != w_map.end();
197        ++iter) {
198     stream << "<" << iter->first << ", " << iter->second << "> ";
199   }
200   return stream << " ]]";
201 };
202 
203 }  // namespace learning_stochastic_linear
204 #endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
205