• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 // Purpose: A container for sparse weight vectors
18 // Maintains the sparse vector as a list of (name, value) pairs along with
19 // a normalizer_. All operations assume that (name, value/normalizer_) is the
20 // true value in question.
21 
22 #ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
23 #define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
24 
25 #include <hash_map>
26 #include <iosfwd>
27 #include <math.h>
28 #include <sstream>
29 #include <string>
30 
31 #include "common_defs.h"
32 
33 namespace learning_stochastic_linear {
34 
35 template<class Key = std::string, class Hash = std::hash_map<Key, double> >
36 class SparseWeightVector {
37  public:
38   typedef Hash Wmap;
39   typedef typename Wmap::iterator Witer;
40   typedef typename Wmap::const_iterator Witer_const;
SparseWeightVector()41   SparseWeightVector() {
42     normalizer_ = 1.0;
43   }
~SparseWeightVector()44   ~SparseWeightVector() {}
SparseWeightVector(const SparseWeightVector<Key,Hash> & other)45   explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
46     CopyFrom(other);
47   }
48   void operator=(const SparseWeightVector<Key, Hash> &other) {
49     CopyFrom(other);
50   }
CopyFrom(const SparseWeightVector<Key,Hash> & other)51   void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
52     w_ = other.w_;
53     wmin_ = other.wmin_;
54     wmax_ = other.wmax_;
55     normalizer_ = other.normalizer_;
56   }
57 
58   // This function implements checks to prevent unbounded vectors. It returns
59   // true if the checks succeed and false otherwise. A vector is deemed invalid
60   // if any of these conditions are met:
61   // 1. it has no values.
62   // 2. its normalizer is nan or inf or close to zero.
63   // 3. any of its values are nan or inf.
64   // 4. its L0 norm is close to zero.
65   bool IsValid() const;
66 
67   // Normalizer getters and setters.
GetNormalizer()68   double GetNormalizer() const {
69     return normalizer_;
70   }
SetNormalizer(const double norm)71   void SetNormalizer(const double norm) {
72     normalizer_ = norm;
73   }
NormalizerMultUpdate(const double mul)74   void NormalizerMultUpdate(const double mul) {
75     normalizer_ = normalizer_ * mul;
76   }
NormalizerAddUpdate(const double add)77   void NormalizerAddUpdate(const double add) {
78     normalizer_ += add;
79   }
80 
81   // Divides all the values by the normalizer, then it resets it to 1.0
82   void ResetNormalizer();
83 
84   // Bound getters and setters.
85   // True if there is a bound with val containing the bound. false otherwise.
GetElementMinBound(const Key & fname,double * val)86   bool GetElementMinBound(const Key &fname, double *val) const {
87     return GetValue(wmin_, fname, val);
88   }
GetElementMaxBound(const Key & fname,double * val)89   bool GetElementMaxBound(const Key &fname, double *val) const {
90     return GetValue(wmax_, fname, val);
91   }
SetElementMinBound(const Key & fname,const double bound)92   void SetElementMinBound(const Key &fname, const double bound) {
93     wmin_[fname] = bound;
94   }
SetElementMaxBound(const Key & fname,const double bound)95   void SetElementMaxBound(const Key &fname, const double bound) {
96     wmax_[fname] = bound;
97   }
98   // Element getters and setters.
GetElement(const Key & fname)99   double GetElement(const Key &fname) const {
100     double val = 0;
101     GetValue(w_, fname, &val);
102     return val;
103   }
SetElement(const Key & fname,const double val)104   void SetElement(const Key &fname, const double val) {
105     //DCHECK(!isnan(val));
106     w_[fname] = val;
107   }
AddUpdateElement(const Key & fname,const double val)108   void AddUpdateElement(const Key &fname, const double val) {
109     w_[fname] += val;
110   }
MultUpdateElement(const Key & fname,const double val)111   void MultUpdateElement(const Key &fname, const double val) {
112     w_[fname] *= val;
113   }
114   // Load another weight vectors. Will overwrite the current vector.
LoadWeightVector(const SparseWeightVector<Key,Hash> & vec)115   void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
116     w_.clear();
117     w_.insert(vec.w_.begin(), vec.w_.end());
118     wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
119     wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
120     normalizer_ = vec.normalizer_;
121   }
Clear()122   void Clear() {
123     w_.clear();
124     wmax_.clear();
125     wmin_.clear();
126   }
GetMap()127   const Wmap& GetMap() const {
128     return w_;
129   }
130   // Vector Operations.
131   void AdditiveWeightUpdate(const double multiplier,
132                             const SparseWeightVector<Key, Hash> &w1,
133                             const double additive_const);
134   void AdditiveSquaredWeightUpdate(const double multiplier,
135                                    const SparseWeightVector<Key, Hash> &w1,
136                                    const double additive_const);
137   void AdditiveInvSqrtWeightUpdate(const double multiplier,
138                                    const SparseWeightVector<Key, Hash> &w1,
139                                    const double additive_const);
140   void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
141   double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
142   // L-x norm. eg. L1, L2.
143   double LxNorm(const double x) const;
144   double L2Norm() const;
145   double L1Norm() const;
146   double L0Norm(const double epsilon) const;
147   // Bound preserving updates.
148   void AdditiveWeightUpdateBounded(const double multiplier,
149                                    const SparseWeightVector<Key, Hash> &w1,
150                                    const double additive_const);
151   void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
152   void ReprojectToBounds();
153   void ReprojectL0(const double l0_norm);
154   void ReprojectL1(const double l1_norm);
155   void ReprojectL2(const double l2_norm);
156   // Reproject using the given norm.
157   // Will also rescale regularizer_ if it gets too small/large.
158   int32 Reproject(const double norm, const RegularizationType r);
159   // Convert this vector to a string, simply for debugging.
DebugString()160   std::string DebugString() const {
161     std::stringstream stream;
162     stream << *this;
163     return stream.str();
164   }
165  private:
166   // The weight map.
167   Wmap w_;
168   // Constraint bounds.
169   Wmap wmin_;
170   Wmap wmax_;
171   // Normalizing constant in magnitude measurement.
172   double normalizer_;
173   // This function in necessary since by default hash_map inserts an element
174   // if it does not find the key through [] operator. It implements a lookup
175   // without the space overhead of an add.
GetValue(const Wmap & w1,const Key & fname,double * val)176   bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
177     Witer_const iter = w1.find(fname);
178     if (iter != w1.end()) {
179       (*val) = iter->second;
180       return true;
181     } else {
182       (*val) = 0;
183       return false;
184     }
185   }
186 };
187 
188 // Outputs a SparseWeightVector, for debugging.
189 template <class Key, class Hash>
190 std::ostream& operator<<(std::ostream &stream,
191                     const SparseWeightVector<Key, Hash> &vector) {
192   typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
193   stream << "[[ ";
194   for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
195        iter != w_map.end();
196        ++iter) {
197     stream << "<" << iter->first << ", " << iter->second << "> ";
198   }
199   return stream << " ]]";
200 };
201 
202 }  // namespace learning_stochastic_linear
203 #endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
204