• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2012 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
// This file declares the MulticlassPA class, which implements a simple
// linear multi-class classifier based on the multi-prototype version of
// the passive-aggressive (PA) algorithm.
20 
21 #ifndef LEARNINGFW_MULTICLASS_PA_H_
22 #define LEARNINGFW_MULTICLASS_PA_H_
23 
#include <cmath>
#include <utility>
#include <vector>
26 
// Small positive tolerance. It is not referenced in this header; presumably
// the implementation uses it to guard floating-point comparisons or
// divisions -- confirm against the .cc file.
// NOTE(review): declared at global scope, outside namespace learningfw; being
// a namespace-scope `const`, each translation unit gets its own copy.
const float kEpsilon = 1e-4f;
28 
namespace learningfw {

// MulticlassPA: a simple linear multi-class classifier trained with the
// multi-prototype version of the passive-aggressive (PA) algorithm.
//
// Every method taking a dense input (std::vector<float>) has a Sparse*
// counterpart taking std::vector<std::pair<int, float> > instead
// (presumably (feature index, value) pairs -- confirm against the
// implementation). Datasets are vectors of pairs whose second member is an
// int, presumably the target class of the example -- confirm.
class MulticlassPA {
 public:
  // Creates a classifier for `num_classes` classes over input vectors of
  // `num_dimensions` dimensions. `aggressiveness` controls how "aggressive"
  // training should be.
  MulticlassPA(int num_classes,
               int num_dimensions,
               float aggressiveness);
  virtual ~MulticlassPA();

  // Initializes all parameters to 0.0.
  void InitializeParameters();

  // Returns a random class that is different from the target class.
  int PickAClassExcept(int target);

  // Returns a random example (presumably an index in [0, num_examples) --
  // confirm against the implementation).
  int PickAnExample(int num_examples);

  // Computes the score of a given input vector for a given parameter vector,
  // i.e. the dot product between the two.
  float Score(const std::vector<float>& inputs,
              const std::vector<float>& parameters) const;
  float SparseScore(const std::vector<std::pair<int, float> >& inputs,
                    const std::vector<float>& parameters) const;

  // Returns the square of the L2 norm of the input.
  float L2NormSquare(const std::vector<float>& inputs) const;
  float SparseL2NormSquare(
      const std::vector<std::pair<int, float> >& inputs) const;

  // Verifies whether the given example is correctly classified, with margin,
  // with respect to a random class. If not, modifies the corresponding
  // parameters using the passive-aggressive update.
  // NOTE(review): the meaning of the returned float is defined by the
  // implementation (presumably this example's loss) -- confirm.
  virtual float TrainOneExample(const std::vector<float>& inputs, int target);
  virtual float SparseTrainOneExample(
      const std::vector<std::pair<int, float> >& inputs, int target);

  // Iteratively trains the model for `num_iterations` on the given dataset.
  // NOTE(review): the returned float's meaning is defined by the
  // implementation -- confirm.
  float Train(const std::vector<std::pair<std::vector<float>, int> >& data,
              int num_iterations);
  float SparseTrain(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data,
      int num_iterations);

  // Returns the best class for a given input vector.
  virtual int GetClass(const std::vector<float>& inputs);
  virtual int SparseGetClass(const std::vector<std::pair<int, float> >& inputs);

  // Computes the test error of the given test set on the current model.
  float Test(const std::vector<std::pair<std::vector<float>, int> >& data);
  float SparseTest(
      const std::vector<std::pair<std::vector<std::pair<int, float> >, int> >& data);

  // A few accessors, primarily used by the sub-classes.
  float aggressiveness() const {
    return aggressiveness_;
  }

  // Mutable access to the parameter vectors (kept for existing callers).
  std::vector<std::vector<float> >& parameters() {
    return parameters_;
  }

  // Read-only access to the parameter vectors. Without this overload the
  // parameters could not be read through a const MulticlassPA&.
  const std::vector<std::vector<float> >& parameters() const {
    return parameters_;
  }

  std::vector<std::vector<float> >* mutable_parameters() {
    return &parameters_;
  }

  int num_classes() const {
    return num_classes_;
  }

  int num_dimensions() const {
    return num_dimensions_;
  }

 private:
  // Keeps the current parameter vectors.
  std::vector<std::vector<float> > parameters_;

  // The number of classes of the problem.
  int num_classes_;

  // The number of dimensions of the input vectors.
  int num_dimensions_;

  // Controls how "aggressive" training should be.
  float aggressiveness_;
};

}  // namespace learningfw
117 #endif  // LEARNINGFW_MULTICLASS_PA_H_
118