• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #ifndef __OPENCV_ML_HPP__
42 #define __OPENCV_ML_HPP__
43 
44 #ifdef __cplusplus
45 #  include "opencv2/core.hpp"
46 #endif
47 
48 #include "opencv2/core/core_c.h"
49 #include <limits.h>
50 
51 #ifdef __cplusplus
52 
53 #include <map>
54 #include <iostream>
55 
// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header (e.g. CvParamGrid::check()).
58 #undef check
59 
60 /****************************************************************************************\
61 *                               Main struct definitions                                  *
62 \****************************************************************************************/
63 
// Natural logarithm of 2*pi, i.e. log(2*PI).
#define CV_LOG2PI (1.8378770664093454835606594728112)

// Layout flags for the <trainData> matrix:
//   CV_COL_SAMPLE - every column of <trainData> is one training sample;
//   CV_ROW_SAMPLE - every row of <trainData> is one training sample.
#define CV_COL_SAMPLE 0
#define CV_ROW_SAMPLE 1

// Evaluates to non-zero when the CV_ROW_SAMPLE bit is set in `flags`.
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
74 
// A set of vectors whose storage is reached through the `data` union, which
// offers three views of the same pointer table: uchar**, float** or double**.
// NOTE(review): presumably `count` vectors of `dims` elements each, with `type`
// selecting which union member is meaningful -- that mapping is not visible in
// this header; confirm against the .cpp before relying on it.
// `next` links to another CvVectors, allowing several sets to be chained.
struct CvVectors
{
    int type;
    int dims, count;
    CvVectors* next;
    union
    {
        uchar** ptr;
        float** fl;
        double** db;
    } data;
};
87 
#if 0
/* NOTE: everything up to the matching #endif is compiled out and kept only
   for reference.

   A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice with min/max put in ascending order and the step clamped
   to be at least 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* Returns an all-zero lattice. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
116 
// Variable (feature) kinds. "Numerical" and "ordered" are synonyms: both are 0.
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

// Type-name strings identifying each model kind.
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

// Error kinds: error on the training subset vs. on the test subset
// (used e.g. as the `type` argument of CvDTree::calc_error()).
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1
135 
// Common base class of the legacy statistical models in this header
// (CvNormalBayesClassifier, CvKNearest, CvSVM and CvDTree derive from it).
// All operations are virtual so each model can provide its own clear(),
// persistence and (de)serialization.
class CvStatModel
{
public:
    CvStatModel();
    virtual ~CvStatModel();

    virtual void clear();

    // File-name based persistence; `name` optionally names the model node
    // (default 0). Implemented outside this header.
    CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    CV_WRAP virtual void load( const char* filename, const char* name=0 );

    // Lower-level serialization into / out of an already opened CvFileStorage.
    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    // NOTE(review): presumably the node name used when save()/load() receive
    // name == 0 -- confirm in the implementation file.
    const char* default_model_name;
};
153 
154 /****************************************************************************************\
155 *                                 Normal Bayes Classifier                                *
156 \****************************************************************************************/
157 
/* The structure representing the grid range of statmodel parameters.
   It is used for optimizing statmodel accuracy by varying model parameters,
   the accuracy estimate being computed by cross-validation.
   The grid is logarithmic, so <step> must be greater than 1. */
162 
163 class CvMLData;
164 
165 struct CvParamGrid
166 {
167     // SVM params type
168     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
169 
CvParamGridCvParamGrid170     CvParamGrid()
171     {
172         min_val = max_val = step = 0;
173     }
174 
175     CvParamGrid( double min_val, double max_val, double log_step );
176     //CvParamGrid( int param_id );
177     bool check() const;
178 
179     CV_PROP_RW double min_val;
180     CV_PROP_RW double max_val;
181     CV_PROP_RW double step;
182 };
183 
CvParamGrid(double _min_val,double _max_val,double _log_step)184 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
185 {
186     min_val = _min_val;
187     max_val = _max_val;
188     step = _log_step;
189 }
190 
// Normal Bayes classifier. Every operation exists twice: once for the legacy
// CvMat* API and once for the cv::Mat API. Only the CV_WRAP-marked members are
// exported to the language bindings.
// NOTE(review): the class keeps raw CvMat* / CvMat** members and, unlike CvSVM,
// does not hide its copy operations -- confirm that copying is never done.
class CvNormalBayesClassifier : public CvStatModel
{
public:
    CV_WRAP CvNormalBayesClassifier();
    virtual ~CvNormalBayesClassifier();

    // Overload that receives the training data at construction time.
    CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx=0, const CvMat* sampleIdx=0 );

    // `varIdx` / `sampleIdx` optionally restrict the variables / samples used;
    // `update` defaults to false.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );

    // `results` and `results_prob` are optional outputs (may be 0).
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
    CV_WRAP virtual void clear();

    // cv::Mat counterparts of the constructor / train / predict above.
    CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       bool update=false );
    CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    int     var_count, var_all;
    CvMat*  var_idx;
    CvMat*  cls_labels;
    // Per-class arrays of matrices (CvMat**); presumably one entry per class
    // label -- confirm in the implementation.
    CvMat** count;
    CvMat** sum;
    CvMat** productsum;
    CvMat** avg;
    CvMat** inv_eigen_values;
    CvMat** cov_rotate_mats;
    CvMat*  c;
};
228 
229 
230 /****************************************************************************************\
231 *                          K-Nearest Neighbour Classifier                                *
232 \****************************************************************************************/
233 
// k Nearest Neighbors model. Training samples are kept in the `samples`
// CvVectors list. Classification or regression mode is recorded in
// `regression` (see is_regression()).
// NOTE(review): holds a raw CvVectors* and does not hide its copy operations --
// confirm that instances are never copied.
class CvKNearest : public CvStatModel
{
public:

    CV_WRAP CvKNearest();
    virtual ~CvKNearest();

    // Overload that receives the training data at construction time.
    CvKNearest( const CvMat* trainData, const CvMat* responses,
                const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );

    // `maxK` bounds the k later accepted by find_nearest() (see get_max_k());
    // `updateBase` defaults to false.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* sampleIdx=0, bool is_regression=false,
                        int maxK=32, bool updateBase=false );

    // Optional outputs (may be 0): per-sample results, neighbor pointers,
    // neighbor responses and distances.
    virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
        const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;

    // cv::Mat counterparts.
    CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
               const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
                       int maxK=32, bool updateBase=false );

    virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
                                const float** neighbors=0, cv::Mat* neighborResponses=0,
                                cv::Mat* dist=0 ) const;
    // Binding-friendly variant: all outputs are mandatory references.
    CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
                                        CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;

    virtual void clear();
    int get_max_k() const;
    int get_var_count() const;
    int get_sample_count() const;
    bool is_regression() const;

    // Lower-level helpers used by find_nearest(); they operate on the sample
    // range [start, end) and take caller-supplied scratch buffers.
    virtual float write_results( int k, int k1, int start, int end,
        const float* neighbor_responses, const float* dist, CvMat* _results,
        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;

    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
        float* neighbor_responses, const float** neighbors, float* dist ) const;

protected:

    int max_k, var_count;
    int total;
    bool regression;
    CvVectors* samples;
};
285 
286 /****************************************************************************************\
287 *                                   Support Vector Machines                              *
288 \****************************************************************************************/
289 
// SVM training parameters. Each numeric field applies only to the SVM / kernel
// types noted beside it. CV_PROP_RW fields are exposed read/write to the
// language bindings.
struct CvSVMParams
{
    CvSVMParams();
    CvSVMParams( int svm_type, int kernel_type,
                 double degree, double gamma, double coef0,
                 double Cvalue, double nu, double p,
                 CvMat* class_weights, CvTermCriteria term_crit );

    CV_PROP_RW int         svm_type;    // one of CvSVM::C_SVC ... CvSVM::NU_SVR
    CV_PROP_RW int         kernel_type; // one of CvSVM::LINEAR ... CvSVM::INTER
    CV_PROP_RW double      degree; // for poly
    CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid/chi2
    CV_PROP_RW double      coef0;  // for poly/sigmoid

    CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
    // Not CV_PROP_RW: raw CvMat pointer, not exported to the bindings.
    CvMat*      class_weights; // for CV_SVM_C_SVC
    CV_PROP_RW CvTermCriteria term_crit; // termination criteria
};
311 
312 
// Kernel evaluator used by the SVM solver.
// `Calc` is a pointer-to-member type whose signature matches calc_intersec,
// calc_chi2, calc_linear, calc_rbf, calc_poly and calc_sigmoid. `calc_func`
// holds the one to use and is passed in through the constructor / create().
// NOTE(review): calc() presumably evaluates the kernel between `another` and
// each of the `vcount` vectors in `vecs`, writing into `results` -- confirm.
struct CvSVMKernel
{
    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
                                       const float* another, float* results );
    CvSVMKernel();
    CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
    virtual bool create( const CvSVMParams* params, Calc _calc_func );
    virtual ~CvSVMKernel();

    virtual void clear();
    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );

    // Non-owning view of the parameters (const pointer; lifetime managed elsewhere).
    const CvSVMParams* params;
    Calc calc_func;

    // Shared helper taking two extra scalars (alpha, beta). It does not match
    // the Calc signature, so it cannot be stored in calc_func.
    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
                                    const float* another, float* results,
                                    double alpha, double beta );
    virtual void calc_intersec( int vcount, int var_count, const float** vecs,
                            const float* another, float* results );
    virtual void calc_chi2( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
                           const float* another, float* results );
    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
                            const float* another, float* results );
    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
                               const float* another, float* results );
};
344 
345 
// One node of a doubly linked list of cached kernel rows
// (CvSVMSolver keeps the list head in `lru_list` and the nodes in `rows`).
struct CvSVMKernelRow
{
    CvSVMKernelRow *prev, *next;   // neighbours in the list
    float *data;                   // row values; ownership is not visible here
};
352 
353 
// Scalar results of a solver run. Passed by non-const reference (as `si`)
// to the CvSVMSolver::solve_* methods.
struct CvSVMSolutionInfo
{
    double obj, rho;
    double upper_bound_p, upper_bound_n;
    double r;   // for Solver_NU
};
362 
// Solver behind CvSVM training.
// The three typedefs are pointer-to-member types; create() / the constructor
// store one of each in select_working_set_func, calc_rho_func and
// get_row_func. The public select_working_set*, calc_rho* and get_row_*
// members have matching signatures and are the candidates for those slots.
// The solve_* entry points correspond to the SVM types in CvSVM
// (C_SVC, NU_SVC, ONE_CLASS, EPS_SVR, NU_SVR).
class CvSVMSolver
{
public:
    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );

    CvSVMSolver();

    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual bool create( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );
    virtual ~CvSVMSolver();

    virtual void clear();
    virtual bool solve_generic( CvSVMSolutionInfo& si );

    // Classification / one-class problems: labels are schar (or absent).
    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
                              double Cp, double Cn, CvMemStorage* storage,
                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_one_class( int count, int var_count, const float** samples,
                                  CvMemStorage* storage, CvSVMKernel* kernel,
                                  double* alpha, CvSVMSolutionInfo& si );

    // Regression problems: targets are float.
    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
                                CvMemStorage* storage, CvSVMKernel* kernel,
                                double* alpha, CvSVMSolutionInfo& si );

    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );

    // Kernel-row access; `_existed` is an output flag.
    // NOTE(review): presumably it reports whether the row was already cached.
    virtual float* get_row_base( int i, bool* _existed );
    virtual float* get_row( int i, float* dst );

    int sample_count;
    int var_count;
    int cache_size;
    int cache_line_size;
    const float** samples;
    const CvSVMParams* params;
    CvMemStorage* storage;
    CvSVMKernelRow lru_list;   // list head (CvSVMKernelRow has prev/next links)
    CvSVMKernelRow* rows;

    int alpha_count;

    double* G;
    double* alpha;

    // -1 - lower bound, 0 - free, 1 - upper bound
    schar* alpha_status;

    schar* y;
    double* b;
    float* buf[2];
    double eps;
    int max_iter;
    double C[2];  // C[0] == Cn, C[1] == Cp
    CvSVMKernel* kernel;

    SelectWorkingSet select_working_set_func;
    CalcRho calc_rho_func;
    GetRow get_row_func;

    virtual bool select_working_set( int& i, int& j );
    virtual bool select_working_set_nu_svm( int& i, int& j );
    virtual void calc_rho( double& rho, double& r );
    virtual void calc_rho_nu_svm( double& rho, double& r );

    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
};
445 
446 
// One decision function of a trained SVM, as returned (read-only) by
// CvSVM::get_decision_function().
// NOTE(review): presumably `alpha` and `sv_index` each hold `sv_count`
// entries and `rho` is the offset term -- confirm in the implementation.
struct CvSVMDecisionFunc
{
    double  rho;
    int     sv_count;
    double* alpha;
    int*    sv_index;
};
454 
455 
// SVM model. Offers a legacy CvMat* API and a cv::Mat API; the CV_WRAP /
// CV_WRAP_AS members are the ones exported to the language bindings.
// Copying is disabled (private, pre-C++11 style declarations at the bottom).
class CvSVM : public CvStatModel
{
public:
    // SVM type
    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };

    // SVM kernel type
    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 };

    // SVM params type (same values as CvParamGrid::SVM_C ... SVM_DEGREE);
    // used as the argument of get_default_grid().
    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };

    CV_WRAP CvSVM();
    virtual ~CvSVM();

    CvSVM( const CvMat* trainData, const CvMat* responses,
           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,
                        CvSVMParams params=CvSVMParams() );

    // Parameter search: `kfold`-fold cross-validation over one CvParamGrid per
    // SVM parameter; each grid defaults to get_default_grid() for that parameter.
    // Note: unlike train(), varIdx / sampleIdx / params have no defaults here.
    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
        int kfold = 10,
        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
        bool balanced=false );

    // `returnDFVal` defaults to false.
    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const;

    // cv::Mat counterparts of the constructor / train / train_auto / predict.
    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
          CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
                            int k_fold = 10,
                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
                            bool balanced=false);
    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
    // Exported to the bindings under the name "predict_all".
    CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;

    CV_WRAP virtual int get_support_vector_count() const;
    virtual const float* get_support_vector(int i) const;
    // Returns a copy of the stored parameters.
    virtual CvSVMParams get_params() const { return params; }
    CV_WRAP virtual void clear();

    // Read-only access to the internal decision_func pointer.
    virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; }

    // Default search grid for a parameter id (C, GAMMA, P, NU, COEF or DEGREE).
    static CvParamGrid get_default_grid( int param_id );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    // Number of active variables: var_idx->cols when a variable subset is set,
    // otherwise var_all.
    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:

    virtual bool set_params( const CvSVMParams& params );
    virtual bool train1( int sample_count, int var_count, const float** samples,
                    const void* responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho );
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
    virtual void create_kernel();
    virtual void create_solver();

    // Raw-row overload used internally (not part of the public API).
    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    void optimize_linear_svm();

    CvSVMParams params;
    CvMat* class_labels;
    int var_all;
    float** sv;
    int sv_total;
    CvMat* var_idx;
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;

private:
    // Declared private so the model cannot be copied from outside the class.
    CvSVM(const CvSVM&);
    CvSVM& operator = (const CvSVM&);
};
563 
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/
// Two parallel pointers: one into 16-bit unsigned data, one into int data.
struct CvPair16u32s
{
    unsigned short* u;
    int*            i;
};
572 
573 
// Direction of category `idx` in a categorical split whose categories are
// stored as a bit set in the int array `subset` (32 categories per element:
// element idx>>5, bit idx&31). Evaluates to -1 when the bit is set and +1
// otherwise.
// The mask is built from an unsigned 1 so that bit 31 (the top bit of each
// element) is never produced by shifting a signed 1 into the sign bit.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((subset[(idx)>>5]&(1u << ((idx) & 31)))==0)-1)
576 
// A split candidate / decision of a tree node. Splits can be chained via
// `next`. The anonymous union holds either the categorical form (`subset`,
// a bit set as consumed by CV_DTREE_CAT_DIR) or the ordered form (`ord`).
// NOTE(review): presumably `ord.c` is the threshold and `ord.split_point` the
// position in sorted order, with the active form chosen by the variable type
// of `var_idx` -- confirm in the implementation.
struct CvDTreeSplit
{
    int var_idx;
    int condensed_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};
595 
// Node of a decision tree: links to parent/children, the chosen split chain,
// per-node sample bookkeeping and the data used by global and
// cross-validation pruning.
struct CvDTreeNode
{
    int class_idx;
    int Tn;
    double value;

    // Tree links.
    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    struct CvDTreeSplit* split;

    int sample_count;
    int depth;
    int* num_valid;   // optional per-variable valid-sample counts (may be null)
    int offset;
    int buf_idx;
    double maxlr;

    // global pruning data
    int complexity;
    double alpha;
    double node_risk, tree_risk, tree_error;

    // cross-validation pruning data
    int* cv_Tn;
    double* cv_node_risk;
    double* cv_node_error;

    // Valid-sample count for variable `vi`. Without a per-variable table,
    // every sample of the node counts as valid.
    int get_num_valid(int vi)
    {
        if( !num_valid )
            return sample_count;
        return num_valid[vi];
    }

    // Stores the valid-sample count for `vi`; a no-op when no table is attached.
    void set_num_valid(int vi, int n)
    {
        if( num_valid == 0 )
            return;
        num_valid[vi] = n;
    }
};
628 
629 
// Decision-tree training parameters. CV_PROP_RW fields are exposed read/write
// to the language bindings. Note that the parameter order of the full
// constructor differs from the field declaration order.
struct CvDTreeParams
{
    CV_PROP_RW int   max_categories;
    CV_PROP_RW int   max_depth;
    CV_PROP_RW int   min_sample_count;
    CV_PROP_RW int   cv_folds;
    CV_PROP_RW bool  use_surrogates;
    CV_PROP_RW bool  use_1se_rule;
    CV_PROP_RW bool  truncate_pruned_tree;
    CV_PROP_RW float regression_accuracy;
    // Non-owning pointer (not CV_PROP_RW, so not exported to the bindings).
    const float* priors;

    CvDTreeParams();
    CvDTreeParams( int max_depth, int min_sample_count,
                   float regression_accuracy, bool use_surrogates,
                   int max_categories, int cv_folds,
                   bool use_1se_rule, bool truncate_pruned_tree,
                   const float* priors );
};
649 
650 
// Training data of a decision tree, plus the storage for nodes and splits
// (node_heap / split_heap / tree_storage). The `shared` flag and the
// CvDTree::read(fs, node, data) overload suggest one instance may serve
// several trees of an ensemble -- NOTE(review): confirm ownership rules.
struct CvDTreeTrainData
{
    CvDTreeTrainData();
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    // `tflag` is stored in the `tflag` member (sample layout, cf. CV_ROW_SAMPLE /
    // CV_COL_SAMPLE -- presumably; confirm).
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void do_responses_copy();

    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    // Per-node data accessors. Each takes caller-supplied buffer(s) and returns
    // a pointer to the requested values.
    // NOTE(review): whether the result points into the buffer or into internal
    // storage is not visible here.
    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    ////////////////////////////////////

    virtual bool set_params( const CvDTreeParams& params );
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;
    int tflag;

    const CvMat* train_data;
    const CvMat* responses;
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
    bool shared;
    int is_buf_16u;

    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;
    // Length of one sub-buffer: (work_var_count + 1) * sample_count. Both
    // factors are widened to size_t first, so the product is not formed in int.
    inline size_t get_length_subbuf() const
    {
        size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
        return res;
    }

    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;

    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;

    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    cv::RNG* rng;
};
758 
// Forward declarations: the tree classes, and the split-finder helper structs
// in namespace cv (CvDTree grants friendship to cv::DTreeBestSplitFinder).
class CvDTree;
class CvForestTree;

namespace cv { struct DTreeBestSplitFinder; struct ForestTreeBestSplitFinder; }
767 
// Single decision tree model derived from CvStatModel.
// - Training data is held in a CvDTreeTrainData (`data`).
// - The learned tree is reachable through `root` / get_root().
// - Split search is split into hooks by variable kind (ordered "ord" vs
//   categorical "cat") and by task (classification "class" vs regression "reg").
768 class CvDTree : public CvStatModel
769 {
770 public:
    // Creates an empty, untrained tree.
771     CV_WRAP CvDTree();
772     virtual ~CvDTree();
773 
    // Trains on CvMat inputs.
    // - tflag describes how samples are laid out in trainData (presumably
    //   CV_ROW_SAMPLE / CV_COL_SAMPLE; confirm).
    // - The optional index, type and mask matrices default to 0.
    // - Returns a bool status (presumably true on success; verify in the
    //   implementation).
774     virtual bool train( const CvMat* trainData, int tflag,
775                         const CvMat* responses, const CvMat* varIdx=0,
776                         const CvMat* sampleIdx=0, const CvMat* varType=0,
777                         const CvMat* missingDataMask=0,
778                         CvDTreeParams params=CvDTreeParams() );
779 
    // Trains on a CvMLData dataset.
780     virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
781 
    // Error of the model on the train or test part of trainData.
    // When resp is non-null it presumably receives the per-sample
    // predictions; confirm.
782     // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
783     virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
784 
    // Trains on already-prepared training data restricted to subsampleIdx.
    // Presumably the entry point used by tree ensembles; confirm.
785     virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
786 
    // Returns the tree node reached by `sample` (presumably the leaf that
    // carries the prediction; confirm).
    // - missingDataMask marks missing components; it defaults to 0.
    // - preprocessedInput defaults to false.
787     virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
788                                   bool preprocessedInput=false ) const;
789 
    // cv::Mat counterparts of train()/predict() above. Optional arguments
    // default to empty matrices instead of null pointers.
790     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
791                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
792                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
793                        const cv::Mat& missingDataMask=cv::Mat(),
794                        CvDTreeParams params=CvDTreeParams() );
795 
796     CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
797                                   bool preprocessedInput=false ) const;
    // Variable importance, returned as a cv::Mat (getVarImportance) or as a
    // CvMat pointer (get_var_importance).
798     CV_WRAP virtual cv::Mat getVarImportance();
799 
800     virtual const CvMat* get_var_importance();
    // Resets the model (presumably releases the tree and training data; confirm).
801     CV_WRAP virtual void clear();
802 
    // Standalone-model serialization.
803     virtual void read( CvFileStorage* fs, CvFileNode* node );
804     virtual void write( CvFileStorage* fs, const char* name ) const;
805 
806     // special read & write methods for trees in the tree ensembles
807     virtual void read( CvFileStorage* fs, CvFileNode* node,
808                        CvDTreeTrainData* data );
809     virtual void write( CvFileStorage* fs ) const;
810 
    // Read-only accessors for the learned tree and its training data.
811     const CvDTreeNode* get_root() const;
812     int get_pruned_tree_idx() const;
813     CvDTreeTrainData* get_data();
814 
815 protected:
    // Split-search helper that is granted access to the protected members below.
816     friend struct cv::DTreeBestSplitFinder;
817 
    // Core training routine, restricted to the given subsample index.
818     virtual bool do_train( const CvMat* _subsample_idx );
819 
    // Node splitting and split search.
    // - find_split_* variants: ordered (ord) vs categorical (cat) variable vi,
    //   classification (class) vs regression (reg).
    // - find_surrogate_split_*: surrogate splits for the same variable kinds.
    // - ext_buf: an optional caller-supplied scratch buffer (presumably
    //   allocated internally when 0; confirm).
820     virtual void try_split_node( CvDTreeNode* n );
821     virtual void split_node_data( CvDTreeNode* n );
822     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
823     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
824                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
825     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
826                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
827     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
828                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
829     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
830                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
831     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
832     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
833     virtual double calc_node_dir( CvDTreeNode* node );
834     virtual void complete_node_dir( CvDTreeNode* node );
835     virtual void cluster_categories( const int* vectors, int vector_count,
836         int var_count, int* sums, int k, int* cluster_labels );
837 
    // Computes the value stored in a node.
838     virtual void calc_node_value( CvDTreeNode* node );
839 
    // Pruning helpers (cross-validation based, judging by prune_cv; confirm).
    // The free_* members release pruning data or the whole tree.
840     virtual void prune_cv();
841     virtual double update_tree_rnc( int T, int fold );
842     virtual int cut_tree( int T, int fold, double min_alpha );
843     virtual void free_prune_data(bool cut_tree);
844     virtual void free_tree();
845 
    // Per-node / per-split serialization helpers (presumably called from
    // read()/write(); confirm).
846     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
847     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
848     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
849     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
850     virtual void write_tree_nodes( CvFileStorage* fs ) const;
851     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
852 
    // Root node of the learned tree (returned by get_root()).
853     CvDTreeNode* root;
    // Variable-importance storage (presumably what get_var_importance()
    // returns; confirm).
854     CvMat* var_importance;
    // Training data the tree was built from (returned by get_data()).
855     CvDTreeTrainData* data;
    // CvMat headers plus cv::Mat storage. Presumably these bridge the
    // cv::Mat train() overload to the CvMat-based one; confirm.
856     CvMat train_data_hdr, responses_hdr;
857     cv::Mat train_data_mat, responses_mat;
858 
859 public:
    // Pruning result index (presumably the value get_pruned_tree_idx()
    // returns; confirm).
860     int pruned_tree_idx;
861 };
862 
863 
864 /****************************************************************************************\
865 *                                   Random Trees Classifier                              *
866 \****************************************************************************************/
867 
868 class CvRTrees;
869 
870 class CvForestTree: public CvDTree
871 {
872 public:
873     CvForestTree();
874     virtual ~CvForestTree();
875 
876     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
877 
get_var_count() const878     virtual int get_var_count() const {return data ? data->var_count : 0;}
879     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
880 
881     /* dummy methods to avoid warnings: BEGIN */
882     virtual bool train( const CvMat* trainData, int tflag,
883                         const CvMat* responses, const CvMat* varIdx=0,
884                         const CvMat* sampleIdx=0, const CvMat* varType=0,
885                         const CvMat* missingDataMask=0,
886                         CvDTreeParams params=CvDTreeParams() );
887 
888     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
889     virtual void read( CvFileStorage* fs, CvFileNode* node );
890     virtual void read( CvFileStorage* fs, CvFileNode* node,
891                        CvDTreeTrainData* data );
892     /* dummy methods to avoid warnings: END */
893 
894 protected:
895     friend struct cv::ForestTreeBestSplitFinder;
896 
897     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
898     CvRTrees* forest;
899 };
900 
901 
// Random-forest training parameters: the single-tree settings inherited from
// CvDTreeParams plus the forest-level settings below.
902 struct CvRTParams : public CvDTreeParams
903 {
904     //Parameters for the forest
905     CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    // Number of active variables (presumably the size of the random variable
    // subset tried at each split; confirm).
906     CV_PROP_RW int nactive_vars;
    // Stopping criteria for growing the forest. Presumably filled from the
    // max_num_of_trees_in_the_forest / forest_accuracy / termcrit_type
    // constructor arguments; confirm.
907     CV_PROP_RW CvTermCriteria term_crit;
908 
    // Default constructor. The default values are set in the implementation
    // and are not visible in this header.
909     CvRTParams();
    // Sets the tree-level and forest-level parameters explicitly.
910     CvRTParams( int max_depth, int min_sample_count,
911                 float regression_accuracy, bool use_surrogates,
912                 int max_categories, const float* priors, bool calc_var_importance,
913                 int nactive_vars, int max_num_of_trees_in_the_forest,
914                 float forest_accuracy, int termcrit_type );
915 };
916 
917 
// Random forest model: an ensemble of CvForestTree objects (`trees`) built on
// one shared CvDTreeTrainData (`data`).
918 class CvRTrees : public CvStatModel
919 {
920 public:
921     CV_WRAP CvRTrees();
922     virtual ~CvRTrees();
    // Trains the forest on CvMat inputs. The optional matrices default to 0.
923     virtual bool train( const CvMat* trainData, int tflag,
924                         const CvMat* responses, const CvMat* varIdx=0,
925                         const CvMat* sampleIdx=0, const CvMat* varType=0,
926                         const CvMat* missingDataMask=0,
927                         CvRTParams params=CvRTParams() );
928 
    // Trains the forest on a CvMLData dataset.
929     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    // Forest prediction for one sample.
    // predict_prob returns a probability-like score (presumably only
    // meaningful for two-class problems; confirm).
930     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
931     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
932 
    // cv::Mat counterparts of the calls above.
933     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
934                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
935                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
936                        const cv::Mat& missingDataMask=cv::Mat(),
937                        CvRTParams params=CvRTParams() );
938     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
939     CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
940     CV_WRAP virtual cv::Mat getVarImportance();
941 
942     CV_WRAP virtual void clear();
943 
944     virtual const CvMat* get_var_importance();
    // Proximity of two samples as measured by the forest. Missing-value masks
    // are optional.
945     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
946         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
947 
948     virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
949 
950     virtual float get_train_error();
951 
952     virtual void read( CvFileStorage* fs, CvFileNode* node );
953     virtual void write( CvFileStorage* fs, const char* name ) const;
954 
955     CvMat* get_active_var_mask();
    // NOTE(review): get_rng() returns CvRNG* while the member `rng` below is a
    // cv::RNG*. Check the conversion in the implementation.
956     CvRNG* get_rng();
957 
    // Number of trees, and access to the i-th tree.
958     int get_tree_count() const;
959     CvForestTree* get_tree(int i) const;
960 
961 protected:
962     virtual cv::String getName() const;
963 
    // Grows the forest until term_crit is met.
964     virtual bool grow_forest( const CvTermCriteria term_crit );
965 
966     // array of the trees of the forest
967     CvForestTree** trees;
    // Training data shared by all trees.
968     CvDTreeTrainData* data;
    // CvMat headers plus cv::Mat storage. Presumably these bridge the
    // cv::Mat train() overload; confirm.
969     CvMat train_data_hdr, responses_hdr;
970     cv::Mat train_data_mat, responses_mat;
    // Presumably the number of entries in `trees`; confirm.
971     int ntrees;
972     int nclasses;
    // Out-of-bag error estimate (assumption based on the name; confirm).
973     double oob_error;
974     CvMat* var_importance;
975     int nsamples;
976 
977     cv::RNG* rng;
978     CvMat* active_var_mask;
979 };
980 
981 /****************************************************************************************\
982 *                           Extremely randomized trees Classifier                        *
983 \****************************************************************************************/
// Training data for extremely randomized trees. It overrides the
// data-preparation and per-node accessor methods of CvDTreeTrainData.
984 struct CvERTreeTrainData : public CvDTreeTrainData
985 {
    // Prepares the training set. The arguments mirror
    // CvDTreeTrainData::set_data; the _shared, _add_labels and _update_data
    // flags default to false.
986     virtual void set_data( const CvMat* trainData, int tflag,
987                           const CvMat* responses, const CvMat* varIdx=0,
988                           const CvMat* sampleIdx=0, const CvMat* varType=0,
989                           const CvMat* missingDataMask=0,
990                           const CvDTreeParams& params=CvDTreeParams(),
991                           bool _shared=false, bool _add_labels=false,
992                           bool _update_data=false );
    // Per-node accessors for ordered-variable values (and their missing
    // flags), sample indices, cross-validation labels and categorical values.
    // The *_buf arguments are caller-supplied buffers.
993     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
994                                    const float** ord_values, const int** missing, int* sample_buf = 0 );
995     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
996     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
997     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    // Extracts sample vectors, missing flags and responses for a subsample.
998     virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
999                               float* responses, bool get_class_idx=false );
1000     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    // Missing-data mask (presumably the one passed to set_data; confirm
    // ownership in the implementation).
1001     const CvMat* missing_mask;
1002 };
1003 
// Tree of an extremely randomized forest. It reuses CvForestTree and
// overrides only the split search and node-data splitting.
1004 class CvForestERTree : public CvForestTree
1005 {
1006 protected:
1007     virtual double calc_node_dir( CvDTreeNode* node );
    // Split-search overrides for ordered/categorical variables, in both
    // classification and regression. The signatures match CvDTree's hooks.
1008     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1009         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1010     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1011         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1012     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1013         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1014     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1015         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1016     virtual void split_node_data( CvDTreeNode* n );
1017 };
1018 
// Extremely randomized trees model. It keeps CvRTrees' public interface and
// overrides forest growing and the model name.
1019 class CvERTrees : public CvRTrees
1020 {
1021 public:
1022     CV_WRAP CvERTrees();
1023     virtual ~CvERTrees();
    // Training entry points. The CvMat, cv::Mat and CvMLData variants mirror
    // those of CvRTrees.
1024     virtual bool train( const CvMat* trainData, int tflag,
1025                         const CvMat* responses, const CvMat* varIdx=0,
1026                         const CvMat* sampleIdx=0, const CvMat* varType=0,
1027                         const CvMat* missingDataMask=0,
1028                         CvRTParams params=CvRTParams());
1029     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1030                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1031                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1032                        const cv::Mat& missingDataMask=cv::Mat(),
1033                        CvRTParams params=CvRTParams());
1034     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
1035 protected:
1036     virtual cv::String getName() const;
    // Grows the extremely randomized forest until term_crit is met.
1037     virtual bool grow_forest( const CvTermCriteria term_crit );
1038 };
1039 
1040 
1041 /****************************************************************************************\
1042 *                                   Boosted tree classifier                              *
1043 \****************************************************************************************/
1044 
// Boosting parameters: the weak-tree settings inherited from CvDTreeParams
// plus the ensemble settings below.
1045 struct CvBoostParams : public CvDTreeParams
1046 {
    // Boosting algorithm (presumably one of CvBoost::DISCRETE/REAL/LOGIT/GENTLE; confirm).
1047     CV_PROP_RW int boost_type;
    // Number of weak trees to build.
1048     CV_PROP_RW int weak_count;
    // Split criterion (presumably one of CvBoost::DEFAULT/GINI/MISCLASS/SQERR; confirm).
1049     CV_PROP_RW int split_criteria;
    // Weight-trimming threshold (assumption based on the name and on
    // CvBoost::trim_weights; confirm).
1050     CV_PROP_RW double weight_trim_rate;
1051 
    // Default constructor. The default values are set in the implementation.
1052     CvBoostParams();
1053     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
1054                    int max_depth, bool use_surrogates, const float* priors );
1055 };
1056 
1057 
1058 class CvBoost;
1059 
// Weak learner used by CvBoost. It is a CvDTree that keeps a back-pointer to
// its ensemble and overrides split search and node evaluation.
1060 class CvBoostTree: public CvDTree
1061 {
1062 public:
1063     CvBoostTree();
1064     virtual ~CvBoostTree();
1065 
    // Trains this weak tree on the shared data, restricted to subsample_idx,
    // as part of `ensemble`.
1066     virtual bool train( CvDTreeTrainData* trainData,
1067                         const CvMat* subsample_idx, CvBoost* ensemble );
1068 
    // Scales the tree by factor s (presumably its node values; confirm).
1069     virtual void scale( double s );
    // Reads a weak tree that belongs to `ensemble` and shares `_data`.
1070     virtual void read( CvFileStorage* fs, CvFileNode* node,
1071                        CvBoost* ensemble, CvDTreeTrainData* _data );
1072     virtual void clear();
1073 
    // Re-declarations that keep the CvDTree overloads of train()/read()
    // from being hidden by the ensemble-aware overloads above.
1074     /* dummy methods to avoid warnings: BEGIN */
1075     virtual bool train( const CvMat* trainData, int tflag,
1076                         const CvMat* responses, const CvMat* varIdx=0,
1077                         const CvMat* sampleIdx=0, const CvMat* varType=0,
1078                         const CvMat* missingDataMask=0,
1079                         CvDTreeParams params=CvDTreeParams() );
1080     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
1081 
1082     virtual void read( CvFileStorage* fs, CvFileNode* node );
1083     virtual void read( CvFileStorage* fs, CvFileNode* node,
1084                        CvDTreeTrainData* data );
1085     /* dummy methods to avoid warnings: END */
1086 
1087 protected:
1088 
    // Boosting-specific overrides of CvDTree's node-splitting, split-search
    // and node-evaluation hooks.
1089     virtual void try_split_node( CvDTreeNode* n );
1090     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1091     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
1092     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
1093         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1094     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
1095         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1096     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
1097         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1098     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
1099         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
1100     virtual void calc_node_value( CvDTreeNode* n );
1101     virtual double calc_node_dir( CvDTreeNode* n );
1102 
    // Owning boosting ensemble.
1103     CvBoost* ensemble;
1104 };
1105 
1106 
// Boosted tree classifier: an ensemble of CvBoostTree weak learners stored in
// the `weak` sequence.
1107 class CvBoost : public CvStatModel
1108 {
1109 public:
1110     // Boosting type
    // (presumably the values accepted by CvBoostParams::boost_type).
1111     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
1112 
1113     // Splitting criteria
    // (presumably the values accepted by CvBoostParams::split_criteria).
    // Value 2 is unused. Do not renumber: these values may be persisted by
    // write_params (confirm).
1114     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
1115 
1116     CV_WRAP CvBoost();
1117     virtual ~CvBoost();
1118 
    // Constructor taking training inputs (presumably trains immediately,
    // like train(); confirm).
1119     CvBoost( const CvMat* trainData, int tflag,
1120              const CvMat* responses, const CvMat* varIdx=0,
1121              const CvMat* sampleIdx=0, const CvMat* varType=0,
1122              const CvMat* missingDataMask=0,
1123              CvBoostParams params=CvBoostParams() );
1124 
    // Trains the ensemble. `update` defaults to false (presumably true adds
    // weak trees to an existing model; confirm).
1125     virtual bool train( const CvMat* trainData, int tflag,
1126              const CvMat* responses, const CvMat* varIdx=0,
1127              const CvMat* sampleIdx=0, const CvMat* varType=0,
1128              const CvMat* missingDataMask=0,
1129              CvBoostParams params=CvBoostParams(),
1130              bool update=false );
1131 
1132     virtual bool train( CvMLData* data,
1133              CvBoostParams params=CvBoostParams(),
1134              bool update=false );
1135 
    // Ensemble prediction.
    // - weak_responses: optional output, presumably receiving the individual
    //   weak-tree responses.
    // - slice: selects which weak trees are used (CV_WHOLE_SEQ = all).
    // - raw_mode / return_sum: control the input/output form (confirm
    //   semantics in the implementation).
1136     virtual float predict( const CvMat* sample, const CvMat* missing=0,
1137                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
1138                            bool raw_mode=false, bool return_sum=false ) const;
1139 
    // cv::Mat counterparts of the constructor, train() and predict().
1140     CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
1141             const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1142             const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1143             const cv::Mat& missingDataMask=cv::Mat(),
1144             CvBoostParams params=CvBoostParams() );
1145 
1146     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1147                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1148                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1149                        const cv::Mat& missingDataMask=cv::Mat(),
1150                        CvBoostParams params=CvBoostParams(),
1151                        bool update=false );
1152 
1153     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1154                                    const cv::Range& slice=cv::Range::all(), bool rawMode=false,
1155                                    bool returnSum=false ) const;
1156 
1157     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
1158 
    // Removes the weak trees in `slice` (assumption based on the name; confirm).
1159     CV_WRAP virtual void prune( CvSlice slice );
1160 
1161     CV_WRAP virtual void clear();
1162 
1163     virtual void write( CvFileStorage* storage, const char* name ) const;
1164     virtual void read( CvFileStorage* storage, CvFileNode* node );
    // Variables used by the ensemble. absolute_idx selects absolute vs
    // relative indexing (presumably backed by active_vars_abs / active_vars).
1165     virtual const CvMat* get_active_vars(bool absolute_idx=true);
1166 
    // Sequence of weak predictors (presumably the `weak` member).
1167     CvSeq* get_weak_predictors();
1168 
    // Accessors for internal training state and parameters.
1169     CvMat* get_weights();
1170     CvMat* get_subtree_weights();
1171     CvMat* get_weak_response();
1172     const CvBoostParams& get_params() const;
1173     const CvDTreeTrainData* get_data() const;
1174 
1175 protected:
1176 
    // Training-time helpers: parameter validation, sample re-weighting after
    // each weak tree, weight trimming, and parameter (de)serialization.
1177     virtual bool set_params( const CvBoostParams& params );
1178     virtual void update_weights( CvBoostTree* tree );
1179     virtual void trim_weights();
1180     virtual void write_params( CvFileStorage* fs ) const;
1181     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1182 
    // Initializes sample weights. p is a two-element array reference.
1183     virtual void initialize_weights(double (&p)[2]);
1184 
    // Shared training data and the cv::Mat bridging storage.
1185     CvDTreeTrainData* data;
1186     CvMat train_data_hdr, responses_hdr;
1187     cv::Mat train_data_mat, responses_mat;
1188     CvBoostParams params;
    // Weak trees of the ensemble.
1189     CvSeq* weak;
1190 
1191     CvMat* active_vars;
1192     CvMat* active_vars_abs;
1193     bool have_active_cat_vars;
1194 
    // Per-sample training buffers (responses, sums, weak evaluations,
    // subsample mask and weights).
1195     CvMat* orig_response;
1196     CvMat* sum_response;
1197     CvMat* weak_eval;
1198     CvMat* subsample_mask;
1199     CvMat* weights;
1200     CvMat* subtree_weights;
1201     bool have_subsample;
1202 };
1203 
1204 
1205 /****************************************************************************************\
1206 *                                   Gradient Boosted Trees                               *
1207 \****************************************************************************************/
1208 
1209 // DataType: STRUCT CvGBTreesParams
1210 // Parameters of GBT (Gradient Boosted trees model), including single
1211 // tree settings and ensemble parameters.
1212 //
1213 // weak_count          - count of trees in the ensemble
1214 // loss_function_type  - loss function used for ensemble training
1215 // subsample_portion   - portion of the whole training set used for
1216 //                       training every single tree.
1217 //                       subsample_portion value is in (0.0, 1.0].
1218 //                       subsample_portion == 1.0 when the whole dataset is
1219 //                       used on each step. The number of samples used on each
1220 //                       step is computed as
1221 //                       int(total_samples_count * subsample_portion).
1222 // shrinkage           - regularization parameter.
1223 //                       Each tree prediction is multiplied by the shrinkage value.
1224 
1225 
// Gradient-boosted trees parameters: single-tree settings inherited from
// CvDTreeParams plus the ensemble settings below.
1226 struct CvGBTreesParams : public CvDTreeParams
1227 {
    // Number of trees in the ensemble.
1228     CV_PROP_RW int weak_count;
    // Loss function (presumably one of the CvGBTrees loss enum values; confirm).
1229     CV_PROP_RW int loss_function_type;
    // Fraction of the training set used for each tree, in (0.0, 1.0].
1230     CV_PROP_RW float subsample_portion;
    // Regularization factor applied to each tree's prediction.
1231     CV_PROP_RW float shrinkage;
1232 
    // Default constructor. The default values are set in the implementation.
1233     CvGBTreesParams();
1234     CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
1235         float subsample_portion, int max_depth, bool use_surrogates );
1236 };
1237 
1238 // DataType: CLASS CvGBTrees
1239 // Gradient Boosting Trees (GBT) algorithm implementation.
1240 //
1241 // data             - training dataset
1242 // params           - parameters of the CvGBTrees
1243 // weak             - array[0..(class_count-1)] of CvSeq
1244 //                    for storing tree ensembles
1245 // orig_response    - original responses of the training set samples
1246 // sum_response     - predictions of the current model on the training dataset.
1247 //                    this matrix is updated on every iteration.
1248 // sum_response_tmp - predictions of the model on the training set for the next
1249 //                    step. On every iteration the values of sum_response_tmp are
1250 //                    computed from the sum_response values. When the current
1251 //                    step is complete, the sum_response values become equal to
1252 //                    sum_response_tmp.
1253 // sample_idx       - indices of samples used for training the ensemble.
1254 //                    The CvGBTrees training procedure takes a set of samples
1255 //                    (train_data) and a set of responses (responses).
1256 //                    Only pairs (train_data[i], responses[i]) where i is
1257 //                    in sample_idx are used for training the ensemble.
1258 // subsample_train  - indices of samples used for training a single decision
1259 //                    tree on the current step. These indices are counted
1260 //                    relative to sample_idx, so that pairs
1261 //                    (train_data[sample_idx[i]], responses[sample_idx[i]])
1262 //                    are used for training a decision tree.
1263 //                    The training set is randomly split
1264 //                    into two parts (subsample_train and subsample_test)
1265 //                    on every iteration according to the subsample_portion parameter.
1266 // subsample_test   - relative indices of samples from the training set,
1267 //                    which are not used for training a tree on the current
1268 //                    step.
1269 // missing          - mask of the missing values in the training set. This
1270 //                    matrix has the same size as train_data. 1 - missing
1271 //                    value, 0 - not a missing value.
1272 // class_labels     - output class labels map.
1273 // rng              - random number generator. Used for splitting the
1274 //                    training set.
1275 // class_count      - count of output classes.
1276 //                    class_count == 1 in the case of regression,
1277 //                    and > 1 in the case of classification.
1278 // delta            - Huber loss function parameter.
1279 // base_value       - start point of the gradient descent procedure.
1280 //                    model prediction is
1281 //                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
1282 //                    f_0 is the base value.
1283 
1284 
1285 
1286 class CvGBTrees : public CvStatModel
1287 {
1288 public:
1289 
1290     /*
1291     // DataType: ENUM
1292     // Loss functions implemented in CvGBTrees.
1293     //
1294     // SQUARED_LOSS
1295     // problem: regression
1296     // loss = (x - x')^2
1297     //
1298     // ABSOLUTE_LOSS
1299     // problem: regression
1300     // loss = abs(x - x')
1301     //
1302     // HUBER_LOSS
1303     // problem: regression
1304     // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
1305     //           1/2*(x - x')^2, if abs(x - x') <= delta,
1306     //           where delta is the alpha-quantile of pseudo responses from
1307     //           the training set.
1308     //
1309     // DEVIANCE_LOSS
1310     // problem: classification
1311     //
1312     */
1313     enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
1314 
1315 
1316     /*
1317     // Default constructor. Creates a model only (without training).
1318     // Should be followed by one form of the train(...) function.
1319     //
1320     // API
1321     // CvGBTrees();
1322 
1323     // INPUT
1324     // OUTPUT
1325     // RESULT
1326     */
1327     CV_WRAP CvGBTrees();
1328 
1329 
1330     /*
1331     // Full form constructor. Creates a gradient boosting model and does the
1332     // train.
1333     //
1334     // API
1335     // CvGBTrees( const CvMat* trainData, int tflag,
1336              const CvMat* responses, const CvMat* varIdx=0,
1337              const CvMat* sampleIdx=0, const CvMat* varType=0,
1338              const CvMat* missingDataMask=0,
1339              CvGBTreesParams params=CvGBTreesParams() );
1340 
1341     // INPUT
1342     // trainData    - a set of input feature vectors.
1343     //                  size of matrix is
1344     //                  <count of samples> x <variables count>
1345     //                  or <variables count> x <count of samples>
1346     //                  depending on the tflag parameter.
1347     //                  matrix values are float.
1348     // tflag         - a flag showing how do samples stored in the
1349     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1350     //                  or column by column (tflag=CV_COL_SAMPLE).
1351     // responses     - a vector of responses corresponding to the samples
1352     //                  in trainData.
1353     // varIdx       - indices of used variables. zero value means that all
1354     //                  variables are active.
1355     // sampleIdx    - indices of used samples. zero value means that all
1356     //                  samples from trainData are in the training set.
1357     // varType      - vector of <variables count> length. gives every
1358     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1359     //                  varType = 0 means all variables are numerical.
1360     // missingDataMask  - a mask of misiing values in trainData.
1361     //                  missingDataMask = 0 means that there are no missing
1362     //                  values.
1363     // params         - parameters of GTB algorithm.
1364     // OUTPUT
1365     // RESULT
1366     */
1367     CvGBTrees( const CvMat* trainData, int tflag,
1368              const CvMat* responses, const CvMat* varIdx=0,
1369              const CvMat* sampleIdx=0, const CvMat* varType=0,
1370              const CvMat* missingDataMask=0,
1371              CvGBTreesParams params=CvGBTreesParams() );
1372 
1373 
1374     /*
1375     // Destructor.
1376     */
1377     virtual ~CvGBTrees();
1378 
1379 
1380     /*
1381     // Gradient tree boosting model training
1382     //
1383     // API
1384     // virtual bool train( const CvMat* trainData, int tflag,
1385              const CvMat* responses, const CvMat* varIdx=0,
1386              const CvMat* sampleIdx=0, const CvMat* varType=0,
1387              const CvMat* missingDataMask=0,
1388              CvGBTreesParams params=CvGBTreesParams(),
1389              bool update=false );
1390 
1391     // INPUT
1392     // trainData    - a set of input feature vectors.
1393     //                  size of matrix is
1394     //                  <count of samples> x <variables count>
1395     //                  or <variables count> x <count of samples>
1396     //                  depending on the tflag parameter.
1397     //                  matrix values are float.
1398     // tflag         - a flag showing how do samples stored in the
1399     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
1400     //                  or column by column (tflag=CV_COL_SAMPLE).
1401     // responses     - a vector of responses corresponding to the samples
1402     //                  in trainData.
1403     // varIdx       - indices of used variables. zero value means that all
1404     //                  variables are active.
1405     // sampleIdx    - indices of used samples. zero value means that all
1406     //                  samples from trainData are in the training set.
1407     // varType      - vector of <variables count> length. gives every
1408     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
1409     //                  varType = 0 means all variables are numerical.
1410     // missingDataMask  - a mask of misiing values in trainData.
1411     //                  missingDataMask = 0 means that there are no missing
1412     //                  values.
1413     // params         - parameters of GTB algorithm.
1414     // update         - is not supported now. (!)
1415     // OUTPUT
1416     // RESULT
1417     // Error state.
1418     */
1419     virtual bool train( const CvMat* trainData, int tflag,
1420              const CvMat* responses, const CvMat* varIdx=0,
1421              const CvMat* sampleIdx=0, const CvMat* varType=0,
1422              const CvMat* missingDataMask=0,
1423              CvGBTreesParams params=CvGBTreesParams(),
1424              bool update=false );
1425 
1426 
1427     /*
1428     // Gradient tree boosting model training
1429     //
1430     // API
1431     // virtual bool train( CvMLData* data,
1432              CvGBTreesParams params=CvGBTreesParams(),
1433              bool update=false ) {return false;}
1434 
1435     // INPUT
1436     // data          - training set.
1437     // params        - parameters of GTB algorithm.
1438     // update        - is not supported now. (!)
1439     // OUTPUT
1440     // RESULT
1441     // Error state.
1442     */
1443     virtual bool train( CvMLData* data,
1444              CvGBTreesParams params=CvGBTreesParams(),
1445              bool update=false );
1446 
1447 
1448     /*
1449     // Response value prediction
1450     //
1451     // API
1452     // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1453              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1454              int k=-1 ) const;
1455 
1456     // INPUT
1457     // sample         - input sample of the same type as in the training set.
1458     // missing        - missing values mask. missing=0 if there are no
1459     //                   missing values in sample vector.
1460     // weak_responses  - predictions of all of the trees.
1461     //                   not implemented (!)
1462     // slice           - part of the ensemble used for prediction.
1463     //                   slice = CV_WHOLE_SEQ when all trees are used.
1464     // k               - number of ensemble used.
1465     //                   k is in {-1,0,1,..,<count of output classes-1>}.
1466     //                   in the case of classification problem
1467     //                   <count of output classes-1> ensembles are built.
1468     //                   If k = -1 ordinary prediction is the result,
1469     //                   otherwise function gives the prediction of the
1470     //                   k-th ensemble only.
1471     // OUTPUT
1472     // RESULT
1473     // Predicted value.
1474     */
1475     virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
1476             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1477             int k=-1 ) const;
1478 
1479     /*
1480     // Response value prediction.
1481     // Parallel version (in the case of TBB existence)
1482     //
1483     // API
1484     // virtual float predict( const CvMat* sample, const CvMat* missing=0,
1485              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
1486              int k=-1 ) const;
1487 
1488     // INPUT
1489     // sample         - input sample of the same type as in the training set.
1490     // missing        - missing values mask. missing=0 if there are no
1491     //                   missing values in sample vector.
1492     // weak_responses  - predictions of all of the trees.
1493     //                   not implemented (!)
1494     // slice           - part of the ensemble used for prediction.
1495     //                   slice = CV_WHOLE_SEQ when all trees are used.
1496     // k               - number of ensemble used.
1497     //                   k is in {-1,0,1,..,<count of output classes-1>}.
1498     //                   in the case of classification problem
1499     //                   <count of output classes-1> ensembles are built.
1500     //                   If k = -1 ordinary prediction is the result,
1501     //                   otherwise function gives the prediction of the
1502     //                   k-th ensemble only.
1503     // OUTPUT
1504     // RESULT
1505     // Predicted value.
1506     */
1507     virtual float predict( const CvMat* sample, const CvMat* missing=0,
1508             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
1509             int k=-1 ) const;
1510 
1511     /*
1512     // Deletes all the data.
1513     //
1514     // API
1515     // virtual void clear();
1516 
1517     // INPUT
1518     // OUTPUT
1519     // delete data, weak, orig_response, sum_response,
1520     //        weak_eval, subsample_train, subsample_test,
1521     //        sample_idx, missing, lass_labels
1522     // delta = 0.0
1523     // RESULT
1524     */
1525     CV_WRAP virtual void clear();
1526 
1527     /*
1528     // Compute error on the train/test set.
1529     //
1530     // API
1531     // virtual float calc_error( CvMLData* _data, int type,
1532     //        std::vector<float> *resp = 0 );
1533     //
1534     // INPUT
1535     // data  - dataset
1536     // type  - defines which error is to compute: train (CV_TRAIN_ERROR) or
1537     //         test (CV_TEST_ERROR).
1538     // OUTPUT
1539     // resp  - vector of predicitons
1540     // RESULT
1541     // Error value.
1542     */
1543     virtual float calc_error( CvMLData* _data, int type,
1544             std::vector<float> *resp = 0 );
1545 
1546     /*
1547     //
1548     // Write parameters of the gtb model and data. Write learned model.
1549     //
1550     // API
1551     // virtual void write( CvFileStorage* fs, const char* name ) const;
1552     //
1553     // INPUT
1554     // fs     - file storage to read parameters from.
1555     // name   - model name.
1556     // OUTPUT
1557     // RESULT
1558     */
1559     virtual void write( CvFileStorage* fs, const char* name ) const;
1560 
1561 
1562     /*
1563     //
1564     // Read parameters of the gtb model and data. Read learned model.
1565     //
1566     // API
1567     // virtual void read( CvFileStorage* fs, CvFileNode* node );
1568     //
1569     // INPUT
1570     // fs     - file storage to read parameters from.
1571     // node   - file node.
1572     // OUTPUT
1573     // RESULT
1574     */
1575     virtual void read( CvFileStorage* fs, CvFileNode* node );
1576 
1577 
1578     // new-style C++ interface
1579     CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
1580               const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1581               const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1582               const cv::Mat& missingDataMask=cv::Mat(),
1583               CvGBTreesParams params=CvGBTreesParams() );
1584 
1585     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
1586                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
1587                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
1588                        const cv::Mat& missingDataMask=cv::Mat(),
1589                        CvGBTreesParams params=CvGBTreesParams(),
1590                        bool update=false );
1591 
1592     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
1593                            const cv::Range& slice = cv::Range::all(),
1594                            int k=-1 ) const;
1595 
1596 protected:
1597 
1598     /*
1599     // Compute the gradient vector components.
1600     //
1601     // API
1602     // virtual void find_gradient( const int k = 0);
1603 
1604     // INPUT
1605     // k        - used for classification problem, determining current
1606     //            tree ensemble.
1607     // OUTPUT
1608     // changes components of data->responses
1609     // which correspond to samples used for training
1610     // on the current step.
1611     // RESULT
1612     */
1613     virtual void find_gradient( const int k = 0);
1614 
1615 
1616     /*
1617     //
1618     // Change values in tree leaves according to the used loss function.
1619     //
1620     // API
1621     // virtual void change_values(CvDTree* tree, const int k = 0);
1622     //
1623     // INPUT
1624     // tree      - decision tree to change.
1625     // k         - used for classification problem, determining current
1626     //             tree ensemble.
1627     // OUTPUT
1628     // changes 'value' fields of the trees' leaves.
1629     // changes sum_response_tmp.
1630     // RESULT
1631     */
1632     virtual void change_values(CvDTree* tree, const int k = 0);
1633 
1634 
1635     /*
1636     //
1637     // Find optimal constant prediction value according to the used loss
1638     // function.
1639     // The goal is to find a constant which gives the minimal summary loss
1640     // on the _Idx samples.
1641     //
1642     // API
1643     // virtual float find_optimal_value( const CvMat* _Idx );
1644     //
1645     // INPUT
1646     // _Idx        - indices of the samples from the training set.
1647     // OUTPUT
1648     // RESULT
1649     // optimal constant value.
1650     */
1651     virtual float find_optimal_value( const CvMat* _Idx );
1652 
1653 
1654     /*
1655     //
1656     // Randomly split the whole training set in two parts according
1657     // to params.portion.
1658     //
1659     // API
1660     // virtual void do_subsample();
1661     //
1662     // INPUT
1663     // OUTPUT
1664     // subsample_train - indices of samples used for training
1665     // subsample_test  - indices of samples used for test
1666     // RESULT
1667     */
1668     virtual void do_subsample();
1669 
1670 
1671     /*
1672     //
1673     // Internal recursive function giving an array of subtree tree leaves.
1674     //
1675     // API
1676     // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1677     //
1678     // INPUT
1679     // node         - current leaf.
1680     // OUTPUT
1681     // count        - count of leaves in the subtree.
1682     // leaves       - array of pointers to leaves.
1683     // RESULT
1684     */
1685     void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
1686 
1687 
1688     /*
1689     //
1690     // Get leaves of the tree.
1691     //
1692     // API
1693     // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1694     //
1695     // INPUT
1696     // dtree            - decision tree.
1697     // OUTPUT
1698     // len              - count of the leaves.
1699     // RESULT
1700     // CvDTreeNode**    - array of pointers to leaves.
1701     */
1702     CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
1703 
1704 
1705     /*
1706     //
1707     // Is it a regression or a classification.
1708     //
1709     // API
1710     // bool problem_type();
1711     //
1712     // INPUT
1713     // OUTPUT
1714     // RESULT
1715     // false if it is a classification problem,
1716     // true - if regression.
1717     */
1718     virtual bool problem_type() const;
1719 
1720 
1721     /*
1722     //
1723     // Write parameters of the gtb model.
1724     //
1725     // API
1726     // virtual void write_params( CvFileStorage* fs ) const;
1727     //
1728     // INPUT
1729     // fs           - file storage to write parameters to.
1730     // OUTPUT
1731     // RESULT
1732     */
1733     virtual void write_params( CvFileStorage* fs ) const;
1734 
1735 
1736     /*
1737     //
1738     // Read parameters of the gtb model and data.
1739     //
1740     // API
1741     // virtual void read_params( CvFileStorage* fs );
1742     //
1743     // INPUT
1744     // fs           - file storage to read parameters from.
1745     // OUTPUT
1746     // params       - parameters of the gtb model.
1747     // data         - contains information about the structure
1748     //                of the data set (count of variables,
1749     //                their types, etc.).
1750     // class_labels - output class labels map.
1751     // RESULT
1752     */
1753     virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
1754     int get_len(const CvMat* mat) const;
1755 
1756 
    // Structure of the training data set (count of variables, their types, etc.).
    CvDTreeTrainData* data;
    // Parameters of the GBT model.
    CvGBTreesParams params;

    // Tree ensembles; presumably one CvSeq of trees per ensemble (see `k` in
    // predict()) — TODO confirm.
    CvSeq** weak;
    // NOTE(review): roles marked "presumably" below are inferred from names only.
    CvMat* orig_response;     // presumably the original training responses
    CvMat* sum_response;      // presumably the accumulated ensemble response per sample
    CvMat* sum_response_tmp;  // updated by change_values()
    CvMat* sample_idx;        // presumably indices of the samples used for training
    CvMat* subsample_train;   // indices of samples used for training (filled by do_subsample())
    CvMat* subsample_test;    // indices of samples used for test (filled by do_subsample())
    CvMat* missing;           // presumably the missing-values mask of the training data
    CvMat* class_labels;      // output class labels map (filled by read_params())

    cv::RNG* rng;             // presumably the random generator used by do_subsample()

    int class_count;          // presumably the number of output classes
    float delta;              // reset to 0.0 by clear()
    float base_value;         // presumably the initial constant prediction — TODO confirm
1775 
1776 };
1777 
1778 
1779 
1780 /****************************************************************************************\
1781 *                              Artificial Neural Networks (ANN)                          *
1782 \****************************************************************************************/
1783 
1784 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1785 
1786 struct CvANN_MLP_TrainParams
1787 {
1788     CvANN_MLP_TrainParams();
1789     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
1790                            double param1, double param2=0 );
1791     ~CvANN_MLP_TrainParams();
1792 
1793     enum { BACKPROP=0, RPROP=1 };
1794 
1795     CV_PROP_RW CvTermCriteria term_crit;
1796     CV_PROP_RW int train_method;
1797 
1798     // backpropagation parameters
1799     CV_PROP_RW double bp_dw_scale, bp_moment_scale;
1800 
1801     // rprop parameters
1802     CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
1803 };
1804 
1805 
1806 class CvANN_MLP : public CvStatModel
1807 {
1808 public:
1809     CV_WRAP CvANN_MLP();
1810     CvANN_MLP( const CvMat* layerSizes,
1811                int activateFunc=CvANN_MLP::SIGMOID_SYM,
1812                double fparam1=0, double fparam2=0 );
1813 
1814     virtual ~CvANN_MLP();
1815 
1816     virtual void create( const CvMat* layerSizes,
1817                          int activateFunc=CvANN_MLP::SIGMOID_SYM,
1818                          double fparam1=0, double fparam2=0 );
1819 
1820     virtual int train( const CvMat* inputs, const CvMat* outputs,
1821                        const CvMat* sampleWeights, const CvMat* sampleIdx=0,
1822                        CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1823                        int flags=0 );
1824     virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
1825 
1826     CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
1827               int activateFunc=CvANN_MLP::SIGMOID_SYM,
1828               double fparam1=0, double fparam2=0 );
1829 
1830     CV_WRAP virtual void create( const cv::Mat& layerSizes,
1831                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
1832                         double fparam1=0, double fparam2=0 );
1833 
1834     CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
1835                       const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
1836                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
1837                       int flags=0 );
1838 
1839     CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
1840 
1841     CV_WRAP virtual void clear();
1842 
1843     // possible activation functions
1844     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1845 
1846     // available training flags
1847     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1848 
1849     virtual void read( CvFileStorage* fs, CvFileNode* node );
1850     virtual void write( CvFileStorage* storage, const char* name ) const;
1851 
get_layer_count()1852     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
get_layer_sizes()1853     const CvMat* get_layer_sizes() { return layer_sizes; }
get_weights(int layer)1854     double* get_weights(int layer)
1855     {
1856         return layer_sizes && weights &&
1857             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1858     }
1859 
1860     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1861 
1862 protected:
1863 
1864     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1865             const CvMat* _sample_weights, const CvMat* sampleIdx,
1866             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1867 
1868     // sequential random backpropagation
1869     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1870 
1871     // RPROP algorithm
1872     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1873 
1874     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1875     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1876                                  double _f_param1=0, double _f_param2=0 );
1877     virtual void init_weights();
1878     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1879     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1880     virtual void calc_input_scale( const CvVectors* vecs, int flags );
1881     virtual void calc_output_scale( const CvVectors* vecs, int flags );
1882 
1883     virtual void write_params( CvFileStorage* fs ) const;
1884     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1885 
1886     CvMat* layer_sizes;
1887     CvMat* wbuf;
1888     CvMat* sample_weights;
1889     double** weights;
1890     double f_param1, f_param2;
1891     double min_val, max_val, min_val1, max_val1;
1892     int activ_func;
1893     int max_count, max_buf_sz;
1894     CvANN_MLP_TrainParams params;
1895     cv::RNG* rng;
1896 };
1897 
1898 /****************************************************************************************\
*                           Auxiliary function declarations                              *
1900 \****************************************************************************************/
1901 
1902 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
1903    average row vector, <cov> - symmetric covariation matrix */
1904 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
1905                            CvRNG* rng CV_DEFAULT(0) );
1906 
1907 /* Generates sample from gaussian mixture distribution */
1908 CVAPI(void) cvRandGaussMixture( CvMat* means[],
1909                                CvMat* covs[],
1910                                float weights[],
1911                                int clsnum,
1912                                CvMat* sample,
1913                                CvMat* sampClasses CV_DEFAULT(0) );
1914 
1915 #define CV_TS_CONCENTRIC_SPHERES 0
1916 
1917 /* creates test set */
1918 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
1919                  int num_samples,
1920                  int num_features,
1921                  CvMat** responses,
1922                  int num_classes, ... );
1923 
1924 /****************************************************************************************\
1925 *                                      Data                                             *
1926 \****************************************************************************************/
1927 
// Split modes; presumably the values of CvTrainTestSplit::train_sample_part_mode.
#define CV_COUNT     0
#define CV_PORTION   1

// Describes how the samples of a CvMLData set are split into a training and
// a test part: either by an absolute training sample count or by a portion.
struct CvTrainTestSplit
{
    CvTrainTestSplit();
    // Split by an absolute number of training samples.
    CvTrainTestSplit( int train_sample_count, bool mix = true);
    // Split by the fraction of samples used for training.
    CvTrainTestSplit( float train_sample_portion, bool mix = true);

    // Either the training sample count or the training portion.
    union
    {
        int count;
        float portion;
    } train_sample_part;
    // Presumably CV_COUNT or CV_PORTION, telling which union member is active.
    int train_sample_part_mode;

    // Presumably whether the samples are shuffled before splitting — TODO confirm.
    bool mix;
};
1946 
1947 class CvMLData
1948 {
1949 public:
1950     CvMLData();
1951     virtual ~CvMLData();
1952 
1953     // returns:
1954     // 0 - OK
1955     // -1 - file can not be opened or is not correct
1956     int read_csv( const char* filename );
1957 
1958     const CvMat* get_values() const;
1959     const CvMat* get_responses();
1960     const CvMat* get_missing() const;
1961 
1962     void set_header_lines_number( int n );
1963     int get_header_lines_number() const;
1964 
1965     void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
1966                                       // if idx < 0 there will be no response
1967     int get_response_idx() const;
1968 
1969     void set_train_test_split( const CvTrainTestSplit * spl );
1970     const CvMat* get_train_sample_idx() const;
1971     const CvMat* get_test_sample_idx() const;
1972     void mix_train_and_test_idx();
1973 
1974     const CvMat* get_var_idx();
1975     void chahge_var_idx( int vi, bool state ); // misspelled (saved for back compitability),
1976                                                // use change_var_idx
1977     void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
1978 
1979     const CvMat* get_var_types();
1980     int get_var_type( int var_idx ) const;
1981     // following 2 methods enable to change vars type
1982     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
1983     // with numerical labels; in the other cases var types are correctly determined automatically
1984     void set_var_types( const char* str );  // str examples:
1985                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
1986                                             // "cat", "ord" (all vars are categorical/ordered)
1987     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
1988 
1989     void set_delimiter( char ch );
1990     char get_delimiter() const;
1991 
1992     void set_miss_ch( char ch );
1993     char get_miss_ch() const;
1994 
1995     const std::map<cv::String, int>& get_class_labels_map() const;
1996 
1997 protected:
1998     virtual void clear();
1999 
2000     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
2001     void free_train_test_idx();
2002 
2003     char delimiter;
2004     char miss_ch;
2005     //char flt_separator;
2006 
2007     CvMat* values;
2008     CvMat* missing;
2009     CvMat* var_types;
2010     CvMat* var_idx_mask;
2011 
2012     CvMat* response_out; // header
2013     CvMat* var_idx_out; // mat
2014     CvMat* var_types_out; // mat
2015 
2016     int header_lines_number;
2017 
2018     int response_idx;
2019 
2020     int train_sample_count;
2021     bool mix;
2022 
2023     int total_class_count;
2024     std::map<cv::String, int> class_map;
2025 
2026     CvMat* train_sample_idx;
2027     CvMat* test_sample_idx;
2028     int* sample_idx; // data of train_sample_idx and test_sample_idx
2029 
2030     cv::RNG* rng;
2031 };
2032 
2033 
2034 namespace cv
2035 {
2036 
2037 typedef CvStatModel StatModel;
2038 typedef CvParamGrid ParamGrid;
2039 typedef CvNormalBayesClassifier NormalBayesClassifier;
2040 typedef CvKNearest KNearest;
2041 typedef CvSVMParams SVMParams;
2042 typedef CvSVMKernel SVMKernel;
2043 typedef CvSVMSolver SVMSolver;
2044 typedef CvSVM SVM;
2045 typedef CvDTreeParams DTreeParams;
2046 typedef CvMLData TrainData;
2047 typedef CvDTree DecisionTree;
2048 typedef CvForestTree ForestTree;
2049 typedef CvRTParams RandomTreeParams;
2050 typedef CvRTrees RandomTrees;
2051 typedef CvERTreeTrainData ERTreeTRainData;
2052 typedef CvForestERTree ERTree;
2053 typedef CvERTrees ERTrees;
2054 typedef CvBoostParams BoostParams;
2055 typedef CvBoostTree BoostTree;
2056 typedef CvBoost Boost;
2057 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
2058 typedef CvANN_MLP NeuralNet_MLP;
2059 typedef CvGBTreesParams GradientBoostingTreeParams;
2060 typedef CvGBTrees GradientBoostingTrees;
2061 
2062 template<> void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const;
2063 }
2064 
2065 #endif // __cplusplus
2066 #endif // __OPENCV_ML_HPP__
2067 
2068 /* End of file. */
2069