• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42 
43 #include "precomp.hpp"
44 
45 namespace cv {
46 namespace ml {
47 
48 //////////////////////////////////////////////////////////////////////////////////////////
49 //                                  Random trees                                        //
50 //////////////////////////////////////////////////////////////////////////////////////////
RTreeParams()51 RTreeParams::RTreeParams()
52 {
53     calcVarImportance = false;
54     nactiveVars = 0;
55     termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1);
56 }
57 
RTreeParams(bool _calcVarImportance,int _nactiveVars,TermCriteria _termCrit)58 RTreeParams::RTreeParams(bool _calcVarImportance,
59                          int _nactiveVars,
60                          TermCriteria _termCrit )
61 {
62     calcVarImportance = _calcVarImportance;
63     nactiveVars = _nactiveVars;
64     termCrit = _termCrit;
65 }
66 
67 
68 class DTreesImplForRTrees : public DTreesImpl
69 {
70 public:
DTreesImplForRTrees()71     DTreesImplForRTrees()
72     {
73         params.setMaxDepth(5);
74         params.setMinSampleCount(10);
75         params.setRegressionAccuracy(0.f);
76         params.useSurrogates = false;
77         params.setMaxCategories(10);
78         params.setCVFolds(0);
79         params.use1SERule = false;
80         params.truncatePrunedTree = false;
81         params.priors = Mat();
82     }
~DTreesImplForRTrees()83     virtual ~DTreesImplForRTrees() {}
84 
clear()85     void clear()
86     {
87         DTreesImpl::clear();
88         oobError = 0.;
89         rng = RNG((uint64)-1);
90     }
91 
getActiveVars()92     const vector<int>& getActiveVars()
93     {
94         int i, nvars = (int)allVars.size(), m = (int)activeVars.size();
95         for( i = 0; i < nvars; i++ )
96         {
97             int i1 = rng.uniform(0, nvars);
98             int i2 = rng.uniform(0, nvars);
99             std::swap(allVars[i1], allVars[i2]);
100         }
101         for( i = 0; i < m; i++ )
102             activeVars[i] = allVars[i];
103         return activeVars;
104     }
105 
startTraining(const Ptr<TrainData> & trainData,int flags)106     void startTraining( const Ptr<TrainData>& trainData, int flags )
107     {
108         DTreesImpl::startTraining(trainData, flags);
109         int nvars = w->data->getNVars();
110         int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
111         m = std::min(std::max(m, 1), nvars);
112         allVars.resize(nvars);
113         activeVars.resize(m);
114         for( i = 0; i < nvars; i++ )
115             allVars[i] = varIdx[i];
116     }
117 
endTraining()118     void endTraining()
119     {
120         DTreesImpl::endTraining();
121         vector<int> a, b;
122         std::swap(allVars, a);
123         std::swap(activeVars, b);
124     }
125 
train(const Ptr<TrainData> & trainData,int flags)126     bool train( const Ptr<TrainData>& trainData, int flags )
127     {
128         startTraining(trainData, flags);
129         int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
130             rparams.termCrit.maxCount : 10000;
131         int i, j, k, vi, vi_, n = (int)w->sidx.size();
132         int nclasses = (int)classLabels.size();
133         double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 &&
134             rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.;
135         vector<int> sidx(n);
136         vector<uchar> oobmask(n);
137         vector<int> oobidx;
138         vector<int> oobperm;
139         vector<double> oobres(n, 0.);
140         vector<int> oobcount(n, 0);
141         vector<int> oobvotes(n*nclasses, 0);
142         int nvars = w->data->getNVars();
143         int nallvars = w->data->getNAllVars();
144         const int* vidx = !varIdx.empty() ? &varIdx[0] : 0;
145         vector<float> samplebuf(nallvars);
146         Mat samples = w->data->getSamples();
147         float* psamples = samples.ptr<float>();
148         size_t sstep0 = samples.step1(), sstep1 = 1;
149         Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]);
150         int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM;
151 
152         bool calcOOBError = eps > 0 || rparams.calcVarImportance;
153         double max_response = 0.;
154 
155         if( w->data->getLayout() == COL_SAMPLE )
156             std::swap(sstep0, sstep1);
157 
158         if( !_isClassifier )
159         {
160             for( i = 0; i < n; i++ )
161             {
162                 double val = std::abs(w->ord_responses[w->sidx[i]]);
163                 max_response = std::max(max_response, val);
164             }
165         }
166 
167         if( rparams.calcVarImportance )
168             varImportance.resize(nallvars, 0.f);
169 
170         for( treeidx = 0; treeidx < ntrees; treeidx++ )
171         {
172             for( i = 0; i < n; i++ )
173                 oobmask[i] = (uchar)1;
174 
175             for( i = 0; i < n; i++ )
176             {
177                 j = rng.uniform(0, n);
178                 sidx[i] = w->sidx[j];
179                 oobmask[j] = (uchar)0;
180             }
181             int root = addTree( sidx );
182             if( root < 0 )
183                 return false;
184 
185             if( calcOOBError )
186             {
187                 oobidx.clear();
188                 for( i = 0; i < n; i++ )
189                 {
190                     if( !oobmask[i] )
191                         oobidx.push_back(i);
192                 }
193                 int n_oob = (int)oobidx.size();
194                 // if there is no out-of-bag samples, we can not compute OOB error
195                 // nor update the variable importance vector; so we proceed to the next tree
196                 if( n_oob == 0 )
197                     continue;
198                 double ncorrect_responses = 0.;
199 
200                 oobError = 0.;
201                 for( i = 0; i < n_oob; i++ )
202                 {
203                     j = oobidx[i];
204                     sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
205 
206                     double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
207                     if( !_isClassifier )
208                     {
209                         oobres[j] += val;
210                         oobcount[j]++;
211                         double true_val = w->ord_responses[w->sidx[j]];
212                         double a = oobres[j]/oobcount[j] - true_val;
213                         oobError += a*a;
214                         val = (val - true_val)/max_response;
215                         ncorrect_responses += std::exp( -val*val );
216                     }
217                     else
218                     {
219                         int ival = cvRound(val);
220                         int* votes = &oobvotes[j*nclasses];
221                         votes[ival]++;
222                         int best_class = 0;
223                         for( k = 1; k < nclasses; k++ )
224                             if( votes[best_class] < votes[k] )
225                                 best_class = k;
226                         int diff = best_class != w->cat_responses[w->sidx[j]];
227                         oobError += diff;
228                         ncorrect_responses += diff == 0;
229                     }
230                 }
231 
232                 oobError /= n_oob;
233                 if( rparams.calcVarImportance && n_oob > 1 )
234                 {
235                     oobperm.resize(n_oob);
236                     for( i = 0; i < n_oob; i++ )
237                         oobperm[i] = oobidx[i];
238 
239                     for( vi_ = 0; vi_ < nvars; vi_++ )
240                     {
241                         vi = vidx ? vidx[vi_] : vi_;
242                         double ncorrect_responses_permuted = 0;
243                         for( i = 0; i < n_oob; i++ )
244                         {
245                             int i1 = rng.uniform(0, n_oob);
246                             int i2 = rng.uniform(0, n_oob);
247                             std::swap(i1, i2);
248                         }
249 
250                         for( i = 0; i < n_oob; i++ )
251                         {
252                             j = oobidx[i];
253                             int vj = oobperm[i];
254                             sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
255                             for( k = 0; k < nallvars; k++ )
256                                 sample.at<float>(k) = sample0.at<float>(k);
257                             sample.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
258 
259                             double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
260                             if( !_isClassifier )
261                             {
262                                 val = (val - w->ord_responses[w->sidx[j]])/max_response;
263                                 ncorrect_responses_permuted += exp( -val*val );
264                             }
265                             else
266                                 ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
267                         }
268                         varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
269                     }
270                 }
271             }
272             if( calcOOBError && oobError < eps )
273                 break;
274         }
275 
276         if( rparams.calcVarImportance )
277         {
278             for( vi_ = 0; vi_ < nallvars; vi_++ )
279                 varImportance[vi_] = std::max(varImportance[vi_], 0.f);
280             normalize(varImportance, varImportance, 1., 0, NORM_L1);
281         }
282         endTraining();
283         return true;
284     }
285 
writeTrainingParams(FileStorage & fs) const286     void writeTrainingParams( FileStorage& fs ) const
287     {
288         DTreesImpl::writeTrainingParams(fs);
289         fs << "nactive_vars" << rparams.nactiveVars;
290     }
291 
write(FileStorage & fs) const292     void write( FileStorage& fs ) const
293     {
294         if( roots.empty() )
295             CV_Error( CV_StsBadArg, "RTrees have not been trained" );
296 
297         writeParams(fs);
298 
299         fs << "oob_error" << oobError;
300         if( !varImportance.empty() )
301             fs << "var_importance" << varImportance;
302 
303         int k, ntrees = (int)roots.size();
304 
305         fs << "ntrees" << ntrees
306            << "trees" << "[";
307 
308         for( k = 0; k < ntrees; k++ )
309         {
310             fs << "{";
311             writeTree(fs, roots[k]);
312             fs << "}";
313         }
314 
315         fs << "]";
316     }
317 
readParams(const FileNode & fn)318     void readParams( const FileNode& fn )
319     {
320         DTreesImpl::readParams(fn);
321 
322         FileNode tparams_node = fn["training_params"];
323         rparams.nactiveVars = (int)tparams_node["nactive_vars"];
324     }
325 
read(const FileNode & fn)326     void read( const FileNode& fn )
327     {
328         clear();
329 
330         //int nclasses = (int)fn["nclasses"];
331         //int nsamples = (int)fn["nsamples"];
332         oobError = (double)fn["oob_error"];
333         int ntrees = (int)fn["ntrees"];
334 
335         readVectorOrMat(fn["var_importance"], varImportance);
336 
337         readParams(fn);
338 
339         FileNode trees_node = fn["trees"];
340         FileNodeIterator it = trees_node.begin();
341         CV_Assert( ntrees == (int)trees_node.size() );
342 
343         for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
344         {
345             FileNode nfn = (*it)["nodes"];
346             readTree(nfn);
347         }
348     }
349 
350     RTreeParams rparams;
351     double oobError;
352     vector<float> varImportance;
353     vector<int> allVars, activeVars;
354     RNG rng;
355 };
356 
357 
358 class RTreesImpl : public RTrees
359 {
360 public:
361     CV_IMPL_PROPERTY(bool, CalculateVarImportance, impl.rparams.calcVarImportance)
362     CV_IMPL_PROPERTY(int, ActiveVarCount, impl.rparams.nactiveVars)
363     CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, impl.rparams.termCrit)
364 
365     CV_WRAP_SAME_PROPERTY(int, MaxCategories, impl.params)
366     CV_WRAP_SAME_PROPERTY(int, MaxDepth, impl.params)
367     CV_WRAP_SAME_PROPERTY(int, MinSampleCount, impl.params)
368     CV_WRAP_SAME_PROPERTY(int, CVFolds, impl.params)
369     CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, impl.params)
370     CV_WRAP_SAME_PROPERTY(bool, Use1SERule, impl.params)
371     CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, impl.params)
372     CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, impl.params)
373     CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, impl.params)
374 
RTreesImpl()375     RTreesImpl() {}
~RTreesImpl()376     virtual ~RTreesImpl() {}
377 
getDefaultName() const378     String getDefaultName() const { return "opencv_ml_rtrees"; }
379 
train(const Ptr<TrainData> & trainData,int flags)380     bool train( const Ptr<TrainData>& trainData, int flags )
381     {
382         return impl.train(trainData, flags);
383     }
384 
predict(InputArray samples,OutputArray results,int flags) const385     float predict( InputArray samples, OutputArray results, int flags ) const
386     {
387         return impl.predict(samples, results, flags);
388     }
389 
write(FileStorage & fs) const390     void write( FileStorage& fs ) const
391     {
392         impl.write(fs);
393     }
394 
read(const FileNode & fn)395     void read( const FileNode& fn )
396     {
397         impl.read(fn);
398     }
399 
getVarImportance() const400     Mat getVarImportance() const { return Mat_<float>(impl.varImportance, true); }
getVarCount() const401     int getVarCount() const { return impl.getVarCount(); }
402 
isTrained() const403     bool isTrained() const { return impl.isTrained(); }
isClassifier() const404     bool isClassifier() const { return impl.isClassifier(); }
405 
getRoots() const406     const vector<int>& getRoots() const { return impl.getRoots(); }
getNodes() const407     const vector<Node>& getNodes() const { return impl.getNodes(); }
getSplits() const408     const vector<Split>& getSplits() const { return impl.getSplits(); }
getSubsets() const409     const vector<int>& getSubsets() const { return impl.getSubsets(); }
410 
411     DTreesImplForRTrees impl;
412 };
413 
414 
create()415 Ptr<RTrees> RTrees::create()
416 {
417     return makePtr<RTreesImpl>();
418 }
419 
420 }}
421 
422 // End of file.
423