• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #ifndef __OPENCV_ML_PRECOMP_HPP__
42 #define __OPENCV_ML_PRECOMP_HPP__
43 
44 #include "opencv2/core.hpp"
45 #include "opencv2/ml.hpp"
46 #include "opencv2/core/core_c.h"
47 #include "opencv2/core/utility.hpp"
48 
49 #include "opencv2/core/private.hpp"
50 
51 #include <assert.h>
52 #include <float.h>
53 #include <limits.h>
54 #include <math.h>
55 #include <stdlib.h>
56 #include <stdio.h>
57 #include <string.h>
58 #include <time.h>
59 #include <vector>
60 
61 /****************************************************************************************\
62  *                               Main struct definitions                                  *
63  \****************************************************************************************/
64 
/* Natural logarithm of 2*pi: ln(2*pi) ~= 1.8378770664 (not log10, which would be ~0.798). */
#define CV_LOG2PI (1.8378770664093454835606594728112)
67 
68 namespace cv
69 {
70 namespace ml
71 {
72     using std::vector;
73 
74     #define CV_DTREE_CAT_DIR(idx,subset) \
75         (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
76 
77     template<typename _Tp> struct cmp_lt_idx
78     {
cmp_lt_idxcv::ml::cmp_lt_idx79         cmp_lt_idx(const _Tp* _arr) : arr(_arr) {}
operator ()cv::ml::cmp_lt_idx80         bool operator ()(int a, int b) const { return arr[a] < arr[b]; }
81         const _Tp* arr;
82     };
83 
84     template<typename _Tp> struct cmp_lt_ptr
85     {
cmp_lt_ptrcv::ml::cmp_lt_ptr86         cmp_lt_ptr() {}
operator ()cv::ml::cmp_lt_ptr87         bool operator ()(const _Tp* a, const _Tp* b) const { return *a < *b; }
88     };
89 
setRangeVector(std::vector<int> & vec,int n)90     static inline void setRangeVector(std::vector<int>& vec, int n)
91     {
92         vec.resize(n);
93         for( int i = 0; i < n; i++ )
94             vec[i] = i;
95     }
96 
writeTermCrit(FileStorage & fs,const TermCriteria & termCrit)97     static inline void writeTermCrit(FileStorage& fs, const TermCriteria& termCrit)
98     {
99         if( (termCrit.type & TermCriteria::EPS) != 0 )
100             fs << "epsilon" << termCrit.epsilon;
101         if( (termCrit.type & TermCriteria::COUNT) != 0 )
102             fs << "iterations" << termCrit.maxCount;
103     }
104 
readTermCrit(const FileNode & fn)105     static inline TermCriteria readTermCrit(const FileNode& fn)
106     {
107         TermCriteria termCrit;
108         double epsilon = (double)fn["epsilon"];
109         if( epsilon > 0 )
110         {
111             termCrit.type |= TermCriteria::EPS;
112             termCrit.epsilon = epsilon;
113         }
114         int iters = (int)fn["iterations"];
115         if( iters > 0 )
116         {
117             termCrit.type |= TermCriteria::COUNT;
118             termCrit.maxCount = iters;
119         }
120         return termCrit;
121     }
122 
123     struct TreeParams
124     {
125         TreeParams();
126         TreeParams( int maxDepth, int minSampleCount,
127                     double regressionAccuracy, bool useSurrogates,
128                     int maxCategories, int CVFolds,
129                     bool use1SERule, bool truncatePrunedTree,
130                     const Mat& priors );
131 
setMaxCategoriescv::ml::TreeParams132         inline void setMaxCategories(int val)
133         {
134             if( val < 2 )
135                 CV_Error( CV_StsOutOfRange, "max_categories should be >= 2" );
136             maxCategories = std::min(val, 15 );
137         }
setMaxDepthcv::ml::TreeParams138         inline void setMaxDepth(int val)
139         {
140             if( val < 0 )
141                 CV_Error( CV_StsOutOfRange, "max_depth should be >= 0" );
142             maxDepth = std::min( val, 25 );
143         }
setMinSampleCountcv::ml::TreeParams144         inline void setMinSampleCount(int val)
145         {
146             minSampleCount = std::max(val, 1);
147         }
setCVFoldscv::ml::TreeParams148         inline void setCVFolds(int val)
149         {
150             if( val < 0 )
151                 CV_Error( CV_StsOutOfRange,
152                           "params.CVFolds should be =0 (the tree is not pruned) "
153                           "or n>0 (tree is pruned using n-fold cross-validation)" );
154             if( val == 1 )
155                 val = 0;
156             CVFolds = val;
157         }
setRegressionAccuracycv::ml::TreeParams158         inline void setRegressionAccuracy(float val)
159         {
160             if( val < 0 )
161                 CV_Error( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
162             regressionAccuracy = val;
163         }
164 
getMaxCategoriescv::ml::TreeParams165         inline int getMaxCategories() const { return maxCategories; }
getMaxDepthcv::ml::TreeParams166         inline int getMaxDepth() const { return maxDepth; }
getMinSampleCountcv::ml::TreeParams167         inline int getMinSampleCount() const { return minSampleCount; }
getCVFoldscv::ml::TreeParams168         inline int getCVFolds() const { return CVFolds; }
getRegressionAccuracycv::ml::TreeParams169         inline float getRegressionAccuracy() const { return regressionAccuracy; }
170 
171         CV_IMPL_PROPERTY(bool, UseSurrogates, useSurrogates)
172         CV_IMPL_PROPERTY(bool, Use1SERule, use1SERule)
173         CV_IMPL_PROPERTY(bool, TruncatePrunedTree, truncatePrunedTree)
174         CV_IMPL_PROPERTY_S(cv::Mat, Priors, priors)
175 
176         public:
177             bool  useSurrogates;
178         bool  use1SERule;
179         bool  truncatePrunedTree;
180         Mat priors;
181 
182     protected:
183         int   maxCategories;
184         int   maxDepth;
185         int   minSampleCount;
186         int   CVFolds;
187         float regressionAccuracy;
188     };
189 
    // Additional training parameters, presumably for the random-trees model
    // (its implementation is not visible here — TODO confirm).
    // Both constructors are defined out of line.
    struct RTreeParams
    {
        RTreeParams();
        RTreeParams(bool calcVarImportance, int nactiveVars, TermCriteria termCrit );
        bool calcVarImportance;  // presumably: whether to compute variable importance
        int nactiveVars;         // presumably: number of variables considered per split — TODO confirm
        TermCriteria termCrit;   // stopping criterion; what it bounds is not visible here
    };
198 
    // Additional training parameters, presumably for the boosting model
    // (its implementation is not visible here — TODO confirm).
    // Both constructors are defined out of line.
    struct BoostTreeParams
    {
        BoostTreeParams();
        BoostTreeParams(int boostType, int weakCount, double weightTrimRate);
        int boostType;          // boosting variant; presumably a Boost::Types value — TODO confirm
        int weakCount;          // presumably the number of weak trees to train
        double weightTrimRate;  // weight-trimming parameter; exact semantics not visible here
    };
207 
    // Decision-tree model implementing the DTrees interface.
    // The trained model is kept in roots/nodes/splits/subsets (exposed by
    // getRoots()/getNodes()/getSplits()/getSubsets()); training scratch state
    // lives in the WorkData referenced by `w`.
    // NOTE: member and virtual-function declaration order define object and
    // vtable layout; do not reorder.
    class DTreesImpl : public DTrees
    {
    public:
        // Node record used while growing/pruning a tree (the finished model
        // uses the public Node type stored in `nodes`).
        struct WNode
        {
            // Counters and accumulators start at 0, link fields at -1
            // (presumably "none"), and Tn at INT_MAX.
            WNode()
            {
                class_idx = sample_count = depth = complexity = 0;
                parent = left = right = split = defaultDir = -1;
                Tn = INT_MAX;
                value = maxlr = alpha = node_risk = tree_risk = tree_error = 0.;
            }

            int class_idx;
            double Tn;      // presumably the pruning threshold at which this node is cut — TODO confirm
            double value;

            // Links to other nodes, presumably indices into WorkData::wnodes.
            int parent;
            int left;
            int right;
            int defaultDir;

            int split;      // presumably an index into WorkData::wsplits, -1 if none

            int sample_count;
            int depth;
            double maxlr;

            // global pruning data
            int complexity;
            double alpha;
            double node_risk, tree_risk, tree_error;
        };

        // Split record used during training.
        struct WSplit
        {
            // subsetOfs starts at -1; all other fields start at 0/false.
            WSplit()
            {
                varIdx = next = 0;
                inversed = false;
                quality = c = 0.f;
                subsetOfs = -1;
            }

            int varIdx;
            bool inversed;  // presumably swaps the left/right direction — TODO confirm
            float quality;
            int next;       // presumably links to another split of the same node (e.g. a surrogate)
            float c;        // presumably the threshold of an ordered-variable split
            int subsetOfs;  // presumably the offset of the category bitmask in WorkData::wsubsets, -1 if none
        };

        // Scratch data for one training run (constructor defined out of line;
        // maxSubsetSize has no in-class initializer).
        struct WorkData
        {
            WorkData(const Ptr<TrainData>& _data);

            Ptr<TrainData> data;
            vector<WNode> wnodes;
            vector<WSplit> wsplits;
            vector<int> wsubsets;
            vector<double> cv_Tn;
            vector<double> cv_node_risk;
            vector<double> cv_node_error;
            vector<int> cv_labels;
            vector<double> sample_weights;
            vector<int> cat_responses;
            vector<double> ord_responses;
            vector<int> sidx;
            int maxSubsetSize;
        };

        // Hyper-parameter accessors; presumably each forwards to the
        // same-named accessor of `params` (macro defined elsewhere).
        CV_WRAP_SAME_PROPERTY(int, MaxCategories, params)
        CV_WRAP_SAME_PROPERTY(int, MaxDepth, params)
        CV_WRAP_SAME_PROPERTY(int, MinSampleCount, params)
        CV_WRAP_SAME_PROPERTY(int, CVFolds, params)
        CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, params)
        CV_WRAP_SAME_PROPERTY(bool, Use1SERule, params)
        CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, params)
        CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, params)
        CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, params)

        DTreesImpl();
        virtual ~DTreesImpl();
        virtual void clear();

        String getDefaultName() const { return "opencv_ml_dtree"; }
        // A model counts as trained once at least one tree root exists.
        bool isTrained() const { return !roots.empty(); }
        bool isClassifier() const { return _isClassifier; }
        // varType holds one entry more than the variable count (presumably the
        // response) — TODO confirm; 0 when varType is empty.
        int getVarCount() const { return varType.empty() ? 0 : (int)(varType.size() - 1); }
        // Category count of variable `vi`: difference of the two offsets in catOfs[vi].
        int getCatCount(int vi) const { return catOfs[vi][1] - catOfs[vi][0]; }
        // Number of int words needed for one bit per category (32 bits per word, rounded up).
        int getSubsetSize(int vi) const { return (getCatCount(vi) + 31)/32; }

        // Training entry points.
        virtual void setDParams(const TreeParams& _params);
        virtual void startTraining( const Ptr<TrainData>& trainData, int flags );
        virtual void endTraining();
        virtual void initCompVarIdx();
        virtual bool train( const Ptr<TrainData>& trainData, int flags );

        // Tree growing over a set of sample indices.
        virtual int addTree( const vector<int>& sidx );
        virtual int addNodeAndTrySplit( int parent, const vector<int>& sidx );
        virtual const vector<int>& getActiveVars();
        virtual int findBestSplit( const vector<int>& _sidx );
        virtual void calcValue( int nidx, const vector<int>& _sidx );

        // Split search; the categorical variants also take an `int* subset`
        // (presumably an output bitmask of getSubsetSize(vi) words — TODO confirm).
        virtual WSplit findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality );

        // simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
        virtual void clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels );
        virtual WSplit findSplitCatClass( int vi, const vector<int>& _sidx, double initQuality, int* subset );

        virtual WSplit findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality );
        virtual WSplit findSplitCatReg( int vi, const vector<int>& _sidx, double initQuality, int* subset );

        // Partitions _sidx into _sleft/_sright according to split `splitidx`.
        virtual int calcDir( int splitidx, const vector<int>& _sidx, vector<int>& _sleft, vector<int>& _sright );
        virtual int pruneCV( int root );

        virtual double updateTreeRNC( int root, double T, int fold );
        virtual bool cutTree( int root, double T, int fold, double min_alpha );
        // Evaluates the trees selected by `range` (presumably indices into `roots`) on one sample.
        virtual float predictTrees( const Range& range, const Mat& sample, int flags ) const;
        virtual float predict( InputArray inputs, OutputArray outputs, int flags ) const;

        // Serialization.
        virtual void writeTrainingParams( FileStorage& fs ) const;
        virtual void writeParams( FileStorage& fs ) const;
        virtual void writeSplit( FileStorage& fs, int splitidx ) const;
        virtual void writeNode( FileStorage& fs, int nidx, int depth ) const;
        virtual void writeTree( FileStorage& fs, int root ) const;
        virtual void write( FileStorage& fs ) const;

        // Deserialization.
        virtual void readParams( const FileNode& fn );
        virtual int readSplit( const FileNode& fn );
        virtual int readNode( const FileNode& fn );
        virtual int readTree( const FileNode& fn );
        virtual void read( const FileNode& fn );

        // Read-only views of the trained model.
        virtual const std::vector<int>& getRoots() const { return roots; }
        virtual const std::vector<Node>& getNodes() const { return nodes; }
        virtual const std::vector<Split>& getSplits() const { return splits; }
        virtual const std::vector<int>& getSubsets() const { return subsets; }

        TreeParams params;

        vector<int> varIdx;
        vector<int> compVarIdx;
        vector<uchar> varType;
        vector<Vec2i> catOfs;     // per-variable pair of offsets; see getCatCount()
        vector<int> catMap;
        vector<int> roots;
        vector<Node> nodes;
        vector<Split> splits;
        vector<int> subsets;      // packed category bitmasks (int words)
        vector<int> classLabels;
        vector<float> missingSubst;
        vector<int> varMapping;
        bool _isClassifier;

        Ptr<WorkData> w;          // training scratch data
    };
365 
366     template <typename T>
readVectorOrMat(const FileNode & node,std::vector<T> & v)367     static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v)
368     {
369         if (node.type() == FileNode::MAP)
370         {
371             Mat m;
372             node >> m;
373             m.copyTo(v);
374         }
375         else if (node.type() == FileNode::SEQ)
376         {
377             node >> v;
378         }
379     }
380 
381 }}
382 
383 #endif /* __OPENCV_ML_PRECOMP_HPP__ */
384