/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

namespace cv { namespace ml {

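// Clamped log-odds helper: log(val/(1-val)) with val clipped to
// [eps, 1-eps] so that probabilities at 0 or 1 cannot produce +/-Inf.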
static inline double
log_ratio( double val )
{
    const double eps = 1e-5;
    val = std::max( val, eps );
    val = std::min( val, 1. - eps );
    return log( val/(1. - val) );
}


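// Defaults: Real AdaBoost with 100 weak trees; weightTrimRate = 0.95 keeps,
// at each iteration, the highest-weight samples that together carry 95% of
// the total weight (the rest are skipped when growing the next tree).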
BoostTreeParams::BoostTreeParams()
{
    boostType = Boost::REAL;
    weakCount = 100;
    weightTrimRate = 0.95;
}

BoostTreeParams::BoostTreeParams( int _boostType, int _weak_count,
                                  double _weightTrimRate)
{
    boostType = _boostType;
    weakCount = _weak_count;
    weightTrimRate = _weightTrimRate;
}
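// Decision-tree implementation specialized for boosting: by default the weak
// learners are depth-1 stumps and built-in cross-validation pruning is off.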
class DTreesImplForBoost : public DTreesImpl
{
public:
    DTreesImplForBoost()
    {
        params.setCVFolds(0);
        params.setMaxDepth(1);
    }
    virtual ~DTreesImplForBoost() {}

    bool isClassifier() const { return true; }

    void clear()
    {
        DTreesImpl::clear();
    }

    void startTraining( const Ptr<TrainData>& trainData, int flags )
    {
        DTreesImpl::startTraining(trainData, flags);
        sumResult.assign(w->sidx.size(), 0.);

        if( bparams.boostType != Boost::DISCRETE )
        {
            _isClassifier = false;
            int i, n = (int)w->cat_responses.size();
            w->ord_responses.resize(n);

            double a = -1, b = 1;
            if( bparams.boostType == Boost::LOGIT )
            {
                a = -2, b = 2;
            }
            for( i = 0; i < n; i++ )
                w->ord_responses[i] = w->cat_responses[i] > 0 ? b : a;
        }

        normalizeWeights();
    }

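    // Rescale the sample weights so that they sum to 1; if the total weight
    // has collapsed below DBL_EPSILON, reset every weight to 1 instead.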
    void normalizeWeights()
    {
        int i, n = (int)w->sidx.size();
        double sumw = 0, a, b;
        for( i = 0; i < n; i++ )
            sumw += w->sample_weights[w->sidx[i]];
        if( sumw > DBL_EPSILON )
        {
            a = 1./sumw;
            b = 0;
        }
        else
        {
            a = 0;
            b = 1;
        }
        for( i = 0; i < n; i++ )
        {
            double& wval = w->sample_weights[w->sidx[i]];
            wval = wval*a + b;
        }
    }

    void endTraining()
    {
        DTreesImpl::endTraining();
        vector<double> e;
        std::swap(sumResult, e);
    }

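    // Multiply every node value in the tree rooted at `root` by `scale`.
    // Discrete AdaBoost uses this to fold the tree coefficient C into the
    // stored tree, so prediction can simply sum the tree outputs.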
    void scaleTree( int root, double scale )
    {
        int nidx = root, pidx = 0;
        Node *node = 0;

        // traverse the tree in depth-first order and scale every node value
        for(;;)
        {
            for(;;)
            {
                node = &nodes[nidx];
                node->value *= scale;
                if( node->left < 0 )
                    break;
                nidx = node->left;
            }

            for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
                 nidx = pidx, pidx = nodes[pidx].parent )
                ;

            if( pidx < 0 )
                break;

            nidx = nodes[pidx].right;
        }
    }

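    // Turn the raw leaf value into the weak-hypothesis output: a -1/+1 class
    // vote for Discrete AdaBoost, or half the log-odds of P(y=1|x) for Real
    // AdaBoost (the regression leaf value is E[y] over {-1,1}, so p=(v+1)/2).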
    void calcValue( int nidx, const vector<int>& _sidx )
    {
        DTreesImpl::calcValue(nidx, _sidx);
        WNode* node = &w->wnodes[nidx];
        if( bparams.boostType == Boost::DISCRETE )
        {
            node->value = node->class_idx == 0 ? -1 : 1;
        }
        else if( bparams.boostType == Boost::REAL )
        {
            double p = (node->value+1)*0.5;
            node->value = 0.5*log_ratio(p);
        }
    }

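    // Boosting loop: grow one weak tree per iteration on the (possibly
    // trimmed) working set, then reweight the samples. A negative weakCount
    // falls back to a cap of 10000 trees.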
    bool train( const Ptr<TrainData>& trainData, int flags )
    {
        startTraining(trainData, flags);
        int treeidx, ntrees = bparams.weakCount >= 0 ? bparams.weakCount : 10000;
        vector<int> sidx = w->sidx;

        for( treeidx = 0; treeidx < ntrees; treeidx++ )
        {
            int root = addTree( sidx );
            if( root < 0 )
                return false;
            updateWeightsAndTrim( treeidx, sidx );
        }
        endTraining();
        return true;
    }

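    // One boosting step: evaluate the tree just built on all training
    // samples, update the sample weights according to the boosting variant,
    // renormalize them, and trim low-weight samples from the next working set.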
    void updateWeightsAndTrim( int treeidx, vector<int>& sidx )
    {
        int i, n = (int)w->sidx.size();
        int nvars = (int)varIdx.size();
        double sumw = 0., C = 1.;
        cv::AutoBuffer<double> buf(n + nvars);
        double* result = buf;
        float* sbuf = (float*)(result + n);
        Mat sample(1, nvars, CV_32F, sbuf);
        int predictFlags = bparams.boostType == Boost::DISCRETE ? (PREDICT_MAX_VOTE | RAW_OUTPUT) : PREDICT_SUM;
        predictFlags |= COMPRESSED_INPUT;

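        // evaluate the newly added tree on every training sample; the
        // COMPRESSED_INPUT flag is passed because `sample` holds only the
        // active variables already extracted via varIdx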
        for( i = 0; i < n; i++ )
        {
            w->data->getSample(varIdx, w->sidx[i], sbuf );
            result[i] = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
        }

        // now update weights and other parameters for each type of boosting
        if( bparams.boostType == Boost::DISCRETE )
        {
            // Discrete AdaBoost:
            //   weak_eval[i] (=f(x_i)) is in {-1,1}
            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
            //   C = log((1-err)/err)
            //   w_i *= exp(C*(f(x_i) != y_i))
            double err = 0.;

            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                double wval = w->sample_weights[si];
                sumw += wval;
                err += wval*(result[i] != w->cat_responses[si]);
            }

            if( sumw != 0 )
                err /= sumw;
            C = -log_ratio( err );
            double scale = std::exp(C);

            sumw = 0;
            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                double wval = w->sample_weights[si];
                if( result[i] != w->cat_responses[si] )
                    wval *= scale;
                sumw += wval;
                w->sample_weights[si] = wval;
            }

            scaleTree(roots[treeidx], C);
        }
        else if( bparams.boostType == Boost::REAL || bparams.boostType == Boost::GENTLE )
        {
            // Real AdaBoost:
            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
            //   w_i *= exp(-y_i*f(x_i))

            // Gentle AdaBoost:
            //   weak_eval[i] = f(x_i) in [-1,1]
            //   w_i *= exp(-y_i*f(x_i))
            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                CV_Assert( std::abs(w->ord_responses[si]) == 1 );
                double wval = w->sample_weights[si]*std::exp(-result[i]*w->ord_responses[si]);
                sumw += wval;
                w->sample_weights[si] = wval;
            }
        }
        else if( bparams.boostType == Boost::LOGIT )
        {
            // LogitBoost:
            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
            //   sum_response = F(x_i).
            //   F(x_i) += 0.5*f(x_i)
            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))) = 1/(1+exp(-2*F(x_i)))
            //   reuse weak_eval: weak_eval[i] <- p(x_i)
            //   w_i = p(x_i)*(1 - p(x_i))
            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
            //   store z_i to the data->data_root as the new target responses
            const double lb_weight_thresh = FLT_EPSILON;
            const double lb_z_max = 10.;

            for( i = 0; i < n; i++ )
            {
                int si = w->sidx[i];
                sumResult[i] += 0.5*result[i];
                double p = 1./(1 + std::exp(-2*sumResult[i]));
                double wval = std::max( p*(1 - p), lb_weight_thresh ), z;
                w->sample_weights[si] = wval;
                sumw += wval;
                if( w->ord_responses[si] > 0 )
                {
                    z = 1./p;
                    w->ord_responses[si] = std::min(z, lb_z_max);
                }
                else
                {
                    z = 1./(1-p);
                    w->ord_responses[si] = -std::min(z, lb_z_max);
                }
            }
        }
        else
            CV_Error(CV_StsNotImplemented, "Unknown boosting type");

        /*if( bparams.boostType != Boost::LOGIT )
        {
            double err = 0;
            for( i = 0; i < n; i++ )
            {
                sumResult[i] += result[i]*C;
                if( bparams.boostType != Boost::DISCRETE )
                    err += sumResult[i]*w->ord_responses[w->sidx[i]] < 0;
                else
                    err += sumResult[i]*w->cat_responses[w->sidx[i]] < 0;
            }
            printf("%d trees. C=%.2f, training error=%.1f%%, working set size=%d (out of %d)\n", (int)roots.size(), C, err*100./n, (int)sidx.size(), n);
        }*/

        // renormalize weights
        if( sumw > FLT_EPSILON )
            normalizeWeights();

        if( bparams.weightTrimRate <= 0. || bparams.weightTrimRate >= 1. )
            return;

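        // weight trimming: sort the weights in ascending order and find the
        // threshold below which the smallest weights together carry at most
        // (1 - weightTrimRate) of the total mass; only samples at or above
        // that threshold are kept in the working set for the next tree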
        for( i = 0; i < n; i++ )
            result[i] = w->sample_weights[w->sidx[i]];
        std::sort(result, result + n);

        // as weight trimming occurs immediately after updating the weights,
        // where they are renormalized, we assume that the weight sum = 1.
        sumw = 1. - bparams.weightTrimRate;

        for( i = 0; i < n; i++ )
        {
            double wval = result[i];
            if( sumw <= 0 )
                break;
            sumw -= wval;
        }

        double threshold = i < n ? result[i] : DBL_MAX;
        sidx.clear();

        for( i = 0; i < n; i++ )
        {
            int si = w->sidx[i];
            if( w->sample_weights[si] >= threshold )
                sidx.push_back(si);
        }
    }

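    // Sum the raw outputs of the requested trees; for callers that did not
    // ask for PREDICT_SUM (Discrete AdaBoost voting), the sign of the sum is
    // mapped back to a class label unless RAW_OUTPUT was requested.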
    float predictTrees( const Range& range, const Mat& sample, int flags0 ) const
    {
        int flags = (flags0 & ~PREDICT_MASK) | PREDICT_SUM;
        float val = DTreesImpl::predictTrees(range, sample, flags);
        if( flags != flags0 )
        {
            int ival = (int)(val > 0);
            if( !(flags0 & RAW_OUTPUT) )
                ival = classLabels[ival];
            val = (float)ival;
        }
        return val;
    }

    void writeTrainingParams( FileStorage& fs ) const
    {
        fs << "boosting_type" <<
        (bparams.boostType == Boost::DISCRETE ? "DiscreteAdaboost" :
        bparams.boostType == Boost::REAL ? "RealAdaboost" :
        bparams.boostType == Boost::LOGIT ? "LogitBoost" :
        bparams.boostType == Boost::GENTLE ? "GentleAdaboost" : "Unknown");

        DTreesImpl::writeTrainingParams(fs);
        fs << "weight_trimming_rate" << bparams.weightTrimRate;
    }


    void write( FileStorage& fs ) const
    {
        if( roots.empty() )
            CV_Error( CV_StsBadArg, "Boost model has not been trained" );

        writeParams(fs);

        int k, ntrees = (int)roots.size();

        fs << "ntrees" << ntrees
        << "trees" << "[";

        for( k = 0; k < ntrees; k++ )
        {
            fs << "{";
            writeTree(fs, roots[k]);
            fs << "}";
        }

        fs << "]";
    }

    void readParams( const FileNode& fn )
    {
        DTreesImpl::readParams(fn);

        FileNode tparams_node = fn["training_params"];
        // check for old layout
        String bts = (String)(fn["boosting_type"].empty() ?
                         tparams_node["boosting_type"] : fn["boosting_type"]);
        bparams.boostType = (bts == "DiscreteAdaboost" ? Boost::DISCRETE :
                             bts == "RealAdaboost" ? Boost::REAL :
                             bts == "LogitBoost" ? Boost::LOGIT :
                             bts == "GentleAdaboost" ? Boost::GENTLE : -1);
        _isClassifier = bparams.boostType == Boost::DISCRETE;
        // check for old layout
        bparams.weightTrimRate = (double)(fn["weight_trimming_rate"].empty() ?
                                    tparams_node["weight_trimming_rate"] : fn["weight_trimming_rate"]);
    }

    void read( const FileNode& fn )
    {
        clear();

        int ntrees = (int)fn["ntrees"];
        readParams(fn);

        FileNode trees_node = fn["trees"];
        FileNodeIterator it = trees_node.begin();
        CV_Assert( ntrees == (int)trees_node.size() );

        for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
        {
            FileNode nfn = (*it)["nodes"];
            readTree(nfn);
        }
    }

    BoostTreeParams bparams;
    vector<double> sumResult;
};


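// Public cv::ml::Boost implementation: a thin wrapper that exposes the
// parameters as get/set properties and forwards all work to the
// DTreesImplForBoost instance below.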
class BoostImpl : public Boost
{
public:
    BoostImpl() {}
    virtual ~BoostImpl() {}

    CV_IMPL_PROPERTY(int, BoostType, impl.bparams.boostType)
    CV_IMPL_PROPERTY(int, WeakCount, impl.bparams.weakCount)
    CV_IMPL_PROPERTY(double, WeightTrimRate, impl.bparams.weightTrimRate)

    CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
    CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
    CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
    CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
    CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
    CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
    CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
    CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
    CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)

    String getDefaultName() const { return "opencv_ml_boost"; }

    bool train( const Ptr<TrainData>& trainData, int flags )
    {
        return impl.train(trainData, flags);
    }

    float predict( InputArray samples, OutputArray results, int flags ) const
    {
        return impl.predict(samples, results, flags);
    }

    void write( FileStorage& fs ) const
    {
        impl.write(fs);
    }

    void read( const FileNode& fn )
    {
        impl.read(fn);
    }

    int getVarCount() const { return impl.getVarCount(); }

    bool isTrained() const { return impl.isTrained(); }
    bool isClassifier() const { return impl.isClassifier(); }

    const vector<int>& getRoots() const { return impl.getRoots(); }
    const vector<Node>& getNodes() const { return impl.getNodes(); }
    const vector<Split>& getSplits() const { return impl.getSplits(); }
    const vector<int>& getSubsets() const { return impl.getSubsets(); }

    DTreesImplForBoost impl;
};


Ptr<Boost> Boost::create()
{
    return makePtr<BoostImpl>();
}

}}
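
/* A minimal usage sketch (not part of this file; `samples` is assumed to be a
   CV_32F matrix with one sample per row, `responses` a CV_32S label vector,
   and `querySample` a single CV_32F row):

       Ptr<Boost> model = Boost::create();
       model->setBoostType(Boost::GENTLE);
       model->setWeakCount(100);
       model->setWeightTrimRate(0.95);
       model->train(TrainData::create(samples, ROW_SAMPLE, responses));
       float response = model->predict(querySample);
*/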

/* End of file. */