1 /*M/////////////////////////////////////////////////////////////////////////////////////// 2 // 3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4 // 5 // By downloading, copying, installing or using the software you agree to this license. 6 // If you do not agree to this license, do not download, install, 7 // copy or use the software. 8 // 9 // 10 // Intel License Agreement 11 // 12 // Copyright (C) 2000, Intel Corporation, all rights reserved. 13 // Third party copyrights are property of their respective owners. 14 // 15 // Redistribution and use in source and binary forms, with or without modification, 16 // are permitted provided that the following conditions are met: 17 // 18 // * Redistribution's of source code must retain the above copyright notice, 19 // this list of conditions and the following disclaimer. 20 // 21 // * Redistribution's in binary form must reproduce the above copyright notice, 22 // this list of conditions and the following disclaimer in the documentation 23 // and/or other materials provided with the distribution. 24 // 25 // * The name of Intel Corporation may not be used to endorse or promote products 26 // derived from this software without specific prior written permission. 27 // 28 // This software is provided by the copyright holders and contributors "as is" and 29 // any express or implied warranties, including, but not limited to, the implied 30 // warranties of merchantability and fitness for a particular purpose are disclaimed. 31 // In no event shall the Intel Corporation or contributors be liable for any direct, 32 // indirect, incidental, special, exemplary, or consequential damages 33 // (including, but not limited to, procurement of substitute goods or services; 34 // loss of use, data, or profits; or business interruption) however caused 35 // and on any theory of liability, whether in contract, strict liability, 36 // or tort (including negligence or otherwise) arising in any way out of 37 // the use of this software, even if advised of the possibility of such damage. 38 // 39 //M*/ 40 41 #ifndef __OPENCV_ML_PRECOMP_HPP__ 42 #define __OPENCV_ML_PRECOMP_HPP__ 43 44 #include "opencv2/core.hpp" 45 #include "opencv2/ml.hpp" 46 #include "opencv2/core/core_c.h" 47 #include "opencv2/core/utility.hpp" 48 49 #include "opencv2/core/private.hpp" 50 51 #include <assert.h> 52 #include <float.h> 53 #include <limits.h> 54 #include <math.h> 55 #include <stdlib.h> 56 #include <stdio.h> 57 #include <string.h> 58 #include <time.h> 59 #include <vector> 60 61 /****************************************************************************************\ 62 * Main struct definitions * 63 \****************************************************************************************/ 64 65 /* log(2*PI) */ 66 #define CV_LOG2PI (1.8378770664093454835606594728112) 67 68 namespace cv 69 { 70 namespace ml 71 { 72 using std::vector; 73 74 #define CV_DTREE_CAT_DIR(idx,subset) \ 75 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) 76 77 template<typename _Tp> struct cmp_lt_idx 78 { cmp_lt_idxcv::ml::cmp_lt_idx79 cmp_lt_idx(const _Tp* _arr) : arr(_arr) {} operator ()cv::ml::cmp_lt_idx80 bool operator ()(int a, int b) const { return arr[a] < arr[b]; } 81 const _Tp* arr; 82 }; 83 84 template<typename _Tp> struct cmp_lt_ptr 85 { cmp_lt_ptrcv::ml::cmp_lt_ptr86 cmp_lt_ptr() {} operator ()cv::ml::cmp_lt_ptr87 bool operator ()(const _Tp* a, const _Tp* b) const { return *a < *b; } 88 }; 89 setRangeVector(std::vector<int> & vec,int n)90 static inline void setRangeVector(std::vector<int>& vec, int n) 91 { 92 vec.resize(n); 93 for( int i = 0; i < n; i++ ) 94 vec[i] = i; 95 } 96 writeTermCrit(FileStorage & fs,const TermCriteria & termCrit)97 static inline void writeTermCrit(FileStorage& fs, const TermCriteria& termCrit) 98 { 99 if( (termCrit.type & TermCriteria::EPS) != 0 ) 100 fs << "epsilon" << termCrit.epsilon; 101 if( (termCrit.type & TermCriteria::COUNT) != 0 ) 102 fs << "iterations" << termCrit.maxCount; 103 } 104 readTermCrit(const FileNode & fn)105 static inline TermCriteria readTermCrit(const FileNode& fn) 106 { 107 TermCriteria termCrit; 108 double epsilon = (double)fn["epsilon"]; 109 if( epsilon > 0 ) 110 { 111 termCrit.type |= TermCriteria::EPS; 112 termCrit.epsilon = epsilon; 113 } 114 int iters = (int)fn["iterations"]; 115 if( iters > 0 ) 116 { 117 termCrit.type |= TermCriteria::COUNT; 118 termCrit.maxCount = iters; 119 } 120 return termCrit; 121 } 122 123 struct TreeParams 124 { 125 TreeParams(); 126 TreeParams( int maxDepth, int minSampleCount, 127 double regressionAccuracy, bool useSurrogates, 128 int maxCategories, int CVFolds, 129 bool use1SERule, bool truncatePrunedTree, 130 const Mat& priors ); 131 setMaxCategoriescv::ml::TreeParams132 inline void setMaxCategories(int val) 133 { 134 if( val < 2 ) 135 CV_Error( CV_StsOutOfRange, "max_categories should be >= 2" ); 136 maxCategories = std::min(val, 15 ); 137 } setMaxDepthcv::ml::TreeParams138 inline void setMaxDepth(int val) 139 { 140 if( val < 0 ) 141 CV_Error( CV_StsOutOfRange, "max_depth should be >= 0" ); 142 maxDepth = std::min( val, 25 ); 143 } setMinSampleCountcv::ml::TreeParams144 inline void setMinSampleCount(int val) 145 { 146 minSampleCount = std::max(val, 1); 147 } setCVFoldscv::ml::TreeParams148 inline void setCVFolds(int val) 149 { 150 if( val < 0 ) 151 CV_Error( CV_StsOutOfRange, 152 "params.CVFolds should be =0 (the tree is not pruned) " 153 "or n>0 (tree is pruned using n-fold cross-validation)" ); 154 if( val == 1 ) 155 val = 0; 156 CVFolds = val; 157 } setRegressionAccuracycv::ml::TreeParams158 inline void setRegressionAccuracy(float val) 159 { 160 if( val < 0 ) 161 CV_Error( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" ); 162 regressionAccuracy = val; 163 } 164 getMaxCategoriescv::ml::TreeParams165 inline int getMaxCategories() const { return maxCategories; } getMaxDepthcv::ml::TreeParams166 inline int getMaxDepth() const { return maxDepth; } getMinSampleCountcv::ml::TreeParams167 inline int getMinSampleCount() const { return minSampleCount; } getCVFoldscv::ml::TreeParams168 inline int getCVFolds() const { return CVFolds; } getRegressionAccuracycv::ml::TreeParams169 inline float getRegressionAccuracy() const { return regressionAccuracy; } 170 171 CV_IMPL_PROPERTY(bool, UseSurrogates, useSurrogates) 172 CV_IMPL_PROPERTY(bool, Use1SERule, use1SERule) 173 CV_IMPL_PROPERTY(bool, TruncatePrunedTree, truncatePrunedTree) 174 CV_IMPL_PROPERTY_S(cv::Mat, Priors, priors) 175 176 public: 177 bool useSurrogates; 178 bool use1SERule; 179 bool truncatePrunedTree; 180 Mat priors; 181 182 protected: 183 int maxCategories; 184 int maxDepth; 185 int minSampleCount; 186 int CVFolds; 187 float regressionAccuracy; 188 }; 189 190 struct RTreeParams 191 { 192 RTreeParams(); 193 RTreeParams(bool calcVarImportance, int nactiveVars, TermCriteria termCrit ); 194 bool calcVarImportance; 195 int nactiveVars; 196 TermCriteria termCrit; 197 }; 198 199 struct BoostTreeParams 200 { 201 BoostTreeParams(); 202 BoostTreeParams(int boostType, int weakCount, double weightTrimRate); 203 int boostType; 204 int weakCount; 205 double weightTrimRate; 206 }; 207 208 class DTreesImpl : public DTrees 209 { 210 public: 211 struct WNode 212 { WNodecv::ml::DTreesImpl::WNode213 WNode() 214 { 215 class_idx = sample_count = depth = complexity = 0; 216 parent = left = right = split = defaultDir = -1; 217 Tn = INT_MAX; 218 value = maxlr = alpha = node_risk = tree_risk = tree_error = 0.; 219 } 220 221 int class_idx; 222 double Tn; 223 double value; 224 225 int parent; 226 int left; 227 int right; 228 int defaultDir; 229 230 int split; 231 232 int sample_count; 233 int depth; 234 double maxlr; 235 236 // global pruning data 237 int complexity; 238 double alpha; 239 double node_risk, tree_risk, tree_error; 240 }; 241 242 struct WSplit 243 { WSplitcv::ml::DTreesImpl::WSplit244 WSplit() 245 { 246 varIdx = next = 0; 247 inversed = false; 248 quality = c = 0.f; 249 subsetOfs = -1; 250 } 251 252 int varIdx; 253 bool inversed; 254 float quality; 255 int next; 256 float c; 257 int subsetOfs; 258 }; 259 260 struct WorkData 261 { 262 WorkData(const Ptr<TrainData>& _data); 263 264 Ptr<TrainData> data; 265 vector<WNode> wnodes; 266 vector<WSplit> wsplits; 267 vector<int> wsubsets; 268 vector<double> cv_Tn; 269 vector<double> cv_node_risk; 270 vector<double> cv_node_error; 271 vector<int> cv_labels; 272 vector<double> sample_weights; 273 vector<int> cat_responses; 274 vector<double> ord_responses; 275 vector<int> sidx; 276 int maxSubsetSize; 277 }; 278 279 CV_WRAP_SAME_PROPERTY(int, MaxCategories, params) 280 CV_WRAP_SAME_PROPERTY(int, MaxDepth, params) 281 CV_WRAP_SAME_PROPERTY(int, MinSampleCount, params) 282 CV_WRAP_SAME_PROPERTY(int, CVFolds, params) 283 CV_WRAP_SAME_PROPERTY(bool, UseSurrogates, params) 284 CV_WRAP_SAME_PROPERTY(bool, Use1SERule, params) 285 CV_WRAP_SAME_PROPERTY(bool, TruncatePrunedTree, params) 286 CV_WRAP_SAME_PROPERTY(float, RegressionAccuracy, params) 287 CV_WRAP_SAME_PROPERTY_S(cv::Mat, Priors, params) 288 289 DTreesImpl(); 290 virtual ~DTreesImpl(); 291 virtual void clear(); 292 getDefaultName() const293 String getDefaultName() const { return "opencv_ml_dtree"; } isTrained() const294 bool isTrained() const { return !roots.empty(); } isClassifier() const295 bool isClassifier() const { return _isClassifier; } getVarCount() const296 int getVarCount() const { return varType.empty() ? 0 : (int)(varType.size() - 1); } getCatCount(int vi) const297 int getCatCount(int vi) const { return catOfs[vi][1] - catOfs[vi][0]; } getSubsetSize(int vi) const298 int getSubsetSize(int vi) const { return (getCatCount(vi) + 31)/32; } 299 300 virtual void setDParams(const TreeParams& _params); 301 virtual void startTraining( const Ptr<TrainData>& trainData, int flags ); 302 virtual void endTraining(); 303 virtual void initCompVarIdx(); 304 virtual bool train( const Ptr<TrainData>& trainData, int flags ); 305 306 virtual int addTree( const vector<int>& sidx ); 307 virtual int addNodeAndTrySplit( int parent, const vector<int>& sidx ); 308 virtual const vector<int>& getActiveVars(); 309 virtual int findBestSplit( const vector<int>& _sidx ); 310 virtual void calcValue( int nidx, const vector<int>& _sidx ); 311 312 virtual WSplit findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality ); 313 314 // simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector. 315 virtual void clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels ); 316 virtual WSplit findSplitCatClass( int vi, const vector<int>& _sidx, double initQuality, int* subset ); 317 318 virtual WSplit findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality ); 319 virtual WSplit findSplitCatReg( int vi, const vector<int>& _sidx, double initQuality, int* subset ); 320 321 virtual int calcDir( int splitidx, const vector<int>& _sidx, vector<int>& _sleft, vector<int>& _sright ); 322 virtual int pruneCV( int root ); 323 324 virtual double updateTreeRNC( int root, double T, int fold ); 325 virtual bool cutTree( int root, double T, int fold, double min_alpha ); 326 virtual float predictTrees( const Range& range, const Mat& sample, int flags ) const; 327 virtual float predict( InputArray inputs, OutputArray outputs, int flags ) const; 328 329 virtual void writeTrainingParams( FileStorage& fs ) const; 330 virtual void writeParams( FileStorage& fs ) const; 331 virtual void writeSplit( FileStorage& fs, int splitidx ) const; 332 virtual void writeNode( FileStorage& fs, int nidx, int depth ) const; 333 virtual void writeTree( FileStorage& fs, int root ) const; 334 virtual void write( FileStorage& fs ) const; 335 336 virtual void readParams( const FileNode& fn ); 337 virtual int readSplit( const FileNode& fn ); 338 virtual int readNode( const FileNode& fn ); 339 virtual int readTree( const FileNode& fn ); 340 virtual void read( const FileNode& fn ); 341 getRoots() const342 virtual const std::vector<int>& getRoots() const { return roots; } getNodes() const343 virtual const std::vector<Node>& getNodes() const { return nodes; } getSplits() const344 virtual const std::vector<Split>& getSplits() const { return splits; } getSubsets() const345 virtual const std::vector<int>& getSubsets() const { return subsets; } 346 347 TreeParams params; 348 349 vector<int> varIdx; 350 vector<int> compVarIdx; 351 vector<uchar> varType; 352 vector<Vec2i> catOfs; 353 vector<int> catMap; 354 vector<int> roots; 355 vector<Node> nodes; 356 vector<Split> splits; 357 vector<int> subsets; 358 vector<int> classLabels; 359 vector<float> missingSubst; 360 vector<int> varMapping; 361 bool _isClassifier; 362 363 Ptr<WorkData> w; 364 }; 365 366 template <typename T> readVectorOrMat(const FileNode & node,std::vector<T> & v)367 static inline void readVectorOrMat(const FileNode & node, std::vector<T> & v) 368 { 369 if (node.type() == FileNode::MAP) 370 { 371 Mat m; 372 node >> m; 373 m.copyTo(v); 374 } 375 else if (node.type() == FileNode::SEQ) 376 { 377 node >> v; 378 } 379 } 380 381 }} 382 383 #endif /* __OPENCV_ML_PRECOMP_HPP__ */ 384