• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #ifndef __ML_H__
42 #define __ML_H__
43 
44 // disable deprecation warning which appears in VisualStudio 8.0
45 #if _MSC_VER >= 1400
46 #pragma warning( disable : 4996 )
47 #endif
48 
49 #ifndef SKIP_INCLUDES
50 
51   #include "cxcore.h"
52   #include <limits.h>
53 
54   #if defined WIN32 || defined WIN64
55     #include <windows.h>
56   #endif
57 
58 #else // SKIP_INCLUDES
59 
60   #if defined WIN32 || defined WIN64
61     #define CV_CDECL __cdecl
62     #define CV_STDCALL __stdcall
63   #else
64     #define CV_CDECL
65     #define CV_STDCALL
66   #endif
67 
68   #ifndef CV_EXTERN_C
69     #ifdef __cplusplus
70       #define CV_EXTERN_C extern "C"
71       #define CV_DEFAULT(val) = val
72     #else
73       #define CV_EXTERN_C
74       #define CV_DEFAULT(val)
75     #endif
76   #endif
77 
78   #ifndef CV_EXTERN_C_FUNCPTR
79     #ifdef __cplusplus
80       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
81     #else
82       #define CV_EXTERN_C_FUNCPTR(x) typedef x
83     #endif
84   #endif
85 
86   #ifndef CV_INLINE
87     #if defined __cplusplus
88       #define CV_INLINE inline
89     #elif (defined WIN32 || defined WIN64) && !defined __GNUC__
90       #define CV_INLINE __inline
91     #else
92       #define CV_INLINE static
93     #endif
94   #endif /* CV_INLINE */
95 
96   #if (defined WIN32 || defined WIN64) && defined CVAPI_EXPORTS
97     #define CV_EXPORTS __declspec(dllexport)
98   #else
99     #define CV_EXPORTS
100   #endif
101 
102   #ifndef CVAPI
103     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
104   #endif
105 
106 #endif // SKIP_INCLUDES
107 
108 
109 #ifdef __cplusplus
110 
// Apple defines a check() macro somewhere in its debug headers
// that interferes with a method definition in this header
113 #undef check
114 
115 /****************************************************************************************\
116 *                               Main struct definitions                                  *
117 \****************************************************************************************/
118 
/* Natural logarithm of 2*pi. */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* Layout flags for a <trainData> matrix:
     CV_COL_SAMPLE - every column holds one training sample;
     CV_ROW_SAMPLE - every row holds one training sample. */
#define CV_COL_SAMPLE 0
#define CV_ROW_SAMPLE 1

/* Non-zero exactly when <flags> contains the CV_ROW_SAMPLE bit. */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
129 
130 struct CvVectors
131 {
132     int type;
133     int dims, count;
134     CvVectors* next;
135     union
136     {
137         uchar** ptr;
138         float** fl;
139         double** db;
140     } data;
141 };
142 
#if 0
/* Disabled code, kept for reference only.
   CvParamLattice describes a logarithmic lattice over the range of a statmodel
   parameter, intended for optimizing statmodel parameters by cross-validation.
   Because the lattice is logarithmic, <step> must be greater than 1.
   NOTE(review): CvParamGrid, declared later in this header, appears to serve
   the same purpose. */
typedef struct CvParamLattice
{
    double min_val;   /* lower end of the range */
    double max_val;   /* upper end of the range */
    double step;      /* multiplicative step */
}
CvParamLattice;

/* Builds a lattice with ordered bounds and a step clamped to at least 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice lattice;
    lattice.step    = MAX( log_step, 1. );
    lattice.min_val = MIN( min_val, max_val );
    lattice.max_val = MAX( min_val, max_val );
    return lattice;
}

/* All-zero lattice. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice lattice = { 0, 0, 0 };
    return lattice;
}
#endif
171 
/* Kinds of input variables. "Numerical" and "ordered" are two names for the
   same kind (both are 0); categorical variables are the only other kind. */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* Type-name strings, one per model kind.
   NOTE(review): presumably used as type tags when models are written to or
   read from a CvFileStorage; confirm in the implementation files. */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
186 
187 class CV_EXPORTS CvStatModel
188 {
189 public:
190     CvStatModel();
191     virtual ~CvStatModel();
192 
193     virtual void clear();
194 
195     virtual void save( const char* filename, const char* name=0 );
196     virtual void load( const char* filename, const char* name=0 );
197 
198     virtual void write( CvFileStorage* storage, const char* name );
199     virtual void read( CvFileStorage* storage, CvFileNode* node );
200 
201 protected:
202     const char* default_model_name;
203 };
204 
205 
206 /****************************************************************************************\
207 *                                 Normal Bayes Classifier                                *
208 \****************************************************************************************/
209 
210 /* The structure, representing the grid range of statmodel parameters.
211    It is used for optimizing statmodel accuracy by varying model parameters,
212    the accuracy estimate being computed by cross-validation.
213    The grid is logarithmic, so <step> must be greater then 1. */
214 struct CV_EXPORTS CvParamGrid
215 {
216     // SVM params type
217     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
218 
CvParamGridCvParamGrid219     CvParamGrid()
220     {
221         min_val = max_val = step = 0;
222     }
223 
CvParamGridCvParamGrid224     CvParamGrid( double _min_val, double _max_val, double log_step )
225     {
226         min_val = _min_val;
227         max_val = _max_val;
228         step = log_step;
229     }
230     //CvParamGrid( int param_id );
231     bool check() const;
232 
233     double min_val;
234     double max_val;
235     double step;
236 };
237 
238 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
239 {
240 public:
241     CvNormalBayesClassifier();
242     virtual ~CvNormalBayesClassifier();
243 
244     CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
245         const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
246 
247     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
248         const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
249 
250     virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
251     virtual void clear();
252 
253     virtual void write( CvFileStorage* storage, const char* name );
254     virtual void read( CvFileStorage* storage, CvFileNode* node );
255 
256 protected:
257     int     var_count, var_all;
258     CvMat*  var_idx;
259     CvMat*  cls_labels;
260     CvMat** count;
261     CvMat** sum;
262     CvMat** productsum;
263     CvMat** avg;
264     CvMat** inv_eigen_values;
265     CvMat** cov_rotate_mats;
266     CvMat*  c;
267 };
268 
269 
270 /****************************************************************************************\
271 *                          K-Nearest Neighbour Classifier                                *
272 \****************************************************************************************/
273 
274 // k Nearest Neighbors
275 class CV_EXPORTS CvKNearest : public CvStatModel
276 {
277 public:
278 
279     CvKNearest();
280     virtual ~CvKNearest();
281 
282     CvKNearest( const CvMat* _train_data, const CvMat* _responses,
283                 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
284 
285     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
286                         const CvMat* _sample_idx=0, bool is_regression=false,
287                         int _max_k=32, bool _update_base=false );
288 
289     virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
290         const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
291 
292     virtual void clear();
293     int get_max_k() const;
294     int get_var_count() const;
295     int get_sample_count() const;
296     bool is_regression() const;
297 
298 protected:
299 
300     virtual float write_results( int k, int k1, int start, int end,
301         const float* neighbor_responses, const float* dist, CvMat* _results,
302         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
303 
304     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
305         float* neighbor_responses, const float** neighbors, float* dist ) const;
306 
307 
308     int max_k, var_count;
309     int total;
310     bool regression;
311     CvVectors* samples;
312 };
313 
314 /****************************************************************************************\
315 *                                   Support Vector Machines                              *
316 \****************************************************************************************/
317 
318 // SVM training parameters
319 struct CV_EXPORTS CvSVMParams
320 {
321     CvSVMParams();
322     CvSVMParams( int _svm_type, int _kernel_type,
323                  double _degree, double _gamma, double _coef0,
324                  double _C, double _nu, double _p,
325                  CvMat* _class_weights, CvTermCriteria _term_crit );
326 
327     int         svm_type;
328     int         kernel_type;
329     double      degree; // for poly
330     double      gamma;  // for poly/rbf/sigmoid
331     double      coef0;  // for poly/sigmoid
332 
333     double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
334     double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
335     double      p; // for CV_SVM_EPS_SVR
336     CvMat*      class_weights; // for CV_SVM_C_SVC
337     CvTermCriteria term_crit; // termination criteria
338 };
339 
340 
341 struct CV_EXPORTS CvSVMKernel
342 {
343     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
344                                        const float* another, float* results );
345     CvSVMKernel();
346     CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
347     virtual bool create( const CvSVMParams* _params, Calc _calc_func );
348     virtual ~CvSVMKernel();
349 
350     virtual void clear();
351     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
352 
353     const CvSVMParams* params;
354     Calc calc_func;
355 
356     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
357                                     const float* another, float* results,
358                                     double alpha, double beta );
359 
360     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
361                               const float* another, float* results );
362     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
363                            const float* another, float* results );
364     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
365                             const float* another, float* results );
366     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
367                                const float* another, float* results );
368 };
369 
370 
// Doubly linked list node holding one cached row of kernel values; the SVM
// solver keeps these in its lru_list / rows members.
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;   // previous node
    CvSVMKernelRow* next;   // next node
    float* data;            // the row's values
};
377 
378 
// Scalar results reported by the CvSVMSolver::solve_* routines.
struct CvSVMSolutionInfo
{
    double obj;            // objective value
    double rho;            // decision-function offset
    double upper_bound_p;  // bound for the positive class
    double upper_bound_n;  // bound for the negative class
    double r;              // extra term used only by the nu-SVM solver
};
387 
388 class CV_EXPORTS CvSVMSolver
389 {
390 public:
391     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
392     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
393     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
394 
395     CvSVMSolver();
396 
397     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
398                  int alpha_count, double* alpha, double Cp, double Cn,
399                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
400                  SelectWorkingSet select_working_set, CalcRho calc_rho );
401     virtual bool create( int count, int var_count, const float** samples, schar* y,
402                  int alpha_count, double* alpha, double Cp, double Cn,
403                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
404                  SelectWorkingSet select_working_set, CalcRho calc_rho );
405     virtual ~CvSVMSolver();
406 
407     virtual void clear();
408     virtual bool solve_generic( CvSVMSolutionInfo& si );
409 
410     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
411                               double Cp, double Cn, CvMemStorage* storage,
412                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
413     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
414                                CvMemStorage* storage, CvSVMKernel* kernel,
415                                double* alpha, CvSVMSolutionInfo& si );
416     virtual bool solve_one_class( int count, int var_count, const float** samples,
417                                   CvMemStorage* storage, CvSVMKernel* kernel,
418                                   double* alpha, CvSVMSolutionInfo& si );
419 
420     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
421                                 CvMemStorage* storage, CvSVMKernel* kernel,
422                                 double* alpha, CvSVMSolutionInfo& si );
423 
424     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
425                                CvMemStorage* storage, CvSVMKernel* kernel,
426                                double* alpha, CvSVMSolutionInfo& si );
427 
428     virtual float* get_row_base( int i, bool* _existed );
429     virtual float* get_row( int i, float* dst );
430 
431     int sample_count;
432     int var_count;
433     int cache_size;
434     int cache_line_size;
435     const float** samples;
436     const CvSVMParams* params;
437     CvMemStorage* storage;
438     CvSVMKernelRow lru_list;
439     CvSVMKernelRow* rows;
440 
441     int alpha_count;
442 
443     double* G;
444     double* alpha;
445 
446     // -1 - lower bound, 0 - free, 1 - upper bound
447     schar* alpha_status;
448 
449     schar* y;
450     double* b;
451     float* buf[2];
452     double eps;
453     int max_iter;
454     double C[2];  // C[0] == Cn, C[1] == Cp
455     CvSVMKernel* kernel;
456 
457     SelectWorkingSet select_working_set_func;
458     CalcRho calc_rho_func;
459     GetRow get_row_func;
460 
461     virtual bool select_working_set( int& i, int& j );
462     virtual bool select_working_set_nu_svm( int& i, int& j );
463     virtual void calc_rho( double& rho, double& r );
464     virtual void calc_rho_nu_svm( double& rho, double& r );
465 
466     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
467     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
468     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
469 };
470 
471 
// One decision function of a trained SVM: an offset plus sv_count
// coefficients, each paired with the index of its support vector.
struct CvSVMDecisionFunc
{
    double rho;        // offset
    int sv_count;      // length of the alpha / sv_index arrays
    double* alpha;     // coefficients
    int* sv_index;     // support-vector indices matching alpha[]
};
479 
480 
481 // SVM model
482 class CV_EXPORTS CvSVM : public CvStatModel
483 {
484 public:
485     // SVM type
486     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
487 
488     // SVM kernel type
489     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
490 
491     // SVM params type
492     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
493 
494     CvSVM();
495     virtual ~CvSVM();
496 
497     CvSVM( const CvMat* _train_data, const CvMat* _responses,
498            const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
499            CvSVMParams _params=CvSVMParams() );
500 
501     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
502                         const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
503                         CvSVMParams _params=CvSVMParams() );
504     virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
505         const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
506         int k_fold = 10,
507         CvParamGrid C_grid      = get_default_grid(CvSVM::C),
508         CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
509         CvParamGrid p_grid      = get_default_grid(CvSVM::P),
510         CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
511         CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
512         CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
513 
514     virtual float predict( const CvMat* _sample ) const;
515 
516     virtual int get_support_vector_count() const;
517     virtual const float* get_support_vector(int i) const;
get_params()518     virtual CvSVMParams get_params() const { return params; };
519     virtual void clear();
520 
521     static CvParamGrid get_default_grid( int param_id );
522 
523     virtual void write( CvFileStorage* storage, const char* name );
524     virtual void read( CvFileStorage* storage, CvFileNode* node );
get_var_count()525     int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
526 
527 protected:
528 
529     virtual bool set_params( const CvSVMParams& _params );
530     virtual bool train1( int sample_count, int var_count, const float** samples,
531                     const void* _responses, double Cp, double Cn,
532                     CvMemStorage* _storage, double* alpha, double& rho );
533     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
534                     const CvMat* _responses, CvMemStorage* _storage, double* alpha );
535     virtual void create_kernel();
536     virtual void create_solver();
537 
538     virtual void write_params( CvFileStorage* fs );
539     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
540 
541     CvSVMParams params;
542     CvMat* class_labels;
543     int var_all;
544     float** sv;
545     int sv_total;
546     CvMat* var_idx;
547     CvMat* class_weights;
548     CvSVMDecisionFunc* decision_func;
549     CvMemStorage* storage;
550 
551     CvSVMSolver* solver;
552     CvSVMKernel* kernel;
553 };
554 
555 /****************************************************************************************\
556 *                              Expectation - Maximization                                *
557 \****************************************************************************************/
558 
559 struct CV_EXPORTS CvEMParams
560 {
CvEMParamsCvEMParams561     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
562         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
563     {
564         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
565     }
566 
567     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
568                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
569                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
570                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
nclustersCvEMParams571                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
572                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
573     {}
574 
575     int nclusters;
576     int cov_mat_type;
577     int start_step;
578     const CvMat* probs;
579     const CvMat* weights;
580     const CvMat* means;
581     const CvMat** covs;
582     CvTermCriteria term_crit;
583 };
584 
585 
586 class CV_EXPORTS CvEM : public CvStatModel
587 {
588 public:
589     // Type of covariation matrices
590     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
591 
592     // The initial step
593     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
594 
595     CvEM();
596     CvEM( const CvMat* samples, const CvMat* sample_idx=0,
597           CvEMParams params=CvEMParams(), CvMat* labels=0 );
598 
599     virtual ~CvEM();
600 
601     virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
602                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
603 
604     virtual float predict( const CvMat* sample, CvMat* probs ) const;
605     virtual void clear();
606 
607     int get_nclusters() const;
608     const CvMat* get_means() const;
609     const CvMat** get_covs() const;
610     const CvMat* get_weights() const;
611     const CvMat* get_probs() const;
612 
get_log_likelihood()613     inline double get_log_likelihood () const { return log_likelihood; };
614 
615 protected:
616 
617     virtual void set_params( const CvEMParams& params,
618                              const CvVectors& train_data );
619     virtual void init_em( const CvVectors& train_data );
620     virtual double run_em( const CvVectors& train_data );
621     virtual void init_auto( const CvVectors& samples );
622     virtual void kmeans( const CvVectors& train_data, int nclusters,
623                          CvMat* labels, CvTermCriteria criteria,
624                          const CvMat* means );
625     CvEMParams params;
626     double log_likelihood;
627 
628     CvMat* means;
629     CvMat** covs;
630     CvMat* weights;
631     CvMat* probs;
632 
633     CvMat* log_weight_div_det;
634     CvMat* inv_eigen_values;
635     CvMat** cov_rotate_mats;
636 };
637 
638 /****************************************************************************************\
639 *                                      Decision Tree                                     *
640 \****************************************************************************************/
641 
// An (int, float) pair.
struct CvPair32s32f
{
    int   i;    // integer member
    float val;  // float member
};
647 
648 
/* Direction of a categorical split for category <idx>.
   <subset> is a bit set stored in an int array, 32 categories per element.
   The result is -1 when bit <idx> is set and +1 when it is clear.
   The shifted constant is unsigned: with a plain int, testing bit 31 would
   shift 1 into the sign bit, which is undefined behaviour in C and before C++14.
   Note: <idx> is evaluated twice, so do not pass expressions with side effects. */
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((((subset)[(idx)>>5]) & (1u << ((idx) & 31))) == 0) - 1)
651 
// A split of a decision-tree node on variable var_idx. Splits can be chained
// through <next>. A categorical split stores a category bit set in <subset>
// (tested with CV_DTREE_CAT_DIR); an ordered split stores a threshold and a
// split position in <ord>. Both views share storage.
struct CvDTreeSplit
{
    int var_idx;          // variable the split tests
    int inversed;         // presumably flips the split direction when non-zero — confirm
    float quality;        // split quality score
    CvDTreeSplit* next;   // next split in the chain
    union
    {
        int subset[2];    // categorical split: category bit set
        struct
        {
            float c;          // ordered split: threshold value
            int split_point;  // ordered split: split position
        }
        ord;
    };
};
669 
670 
// A node of a decision tree: links to parent and children, the split chosen
// for it, per-node sample bookkeeping, and pruning statistics.
struct CvDTreeNode
{
    int class_idx;
    int Tn;
    double value;

    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    // Split chosen for this node. The elaborated "struct" form refers to the
    // CvDTreeSplit declared earlier and keeps this declaration self-contained.
    struct CvDTreeSplit* split;

    int sample_count;
    int depth;
    int* num_valid;   // optional per-variable valid-sample counts (may be 0)
    int offset;
    int buf_idx;
    double maxlr;

    // global pruning data
    int complexity;
    double alpha;
    double node_risk, tree_risk, tree_error;

    // cross-validation pruning data
    int* cv_Tn;
    double* cv_node_risk;
    double* cv_node_error;

    // Valid-sample count for variable <vi>. Falls back to sample_count when
    // no per-variable counts are stored (num_valid == 0). Declared const so it
    // can be used through const nodes, such as the one CvDTree::get_root() returns.
    int get_num_valid(int vi) const { return num_valid ? num_valid[vi] : sample_count; }

    // Stores <n> as the valid-sample count for variable <vi>.
    // Does nothing when num_valid == 0.
    void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
};
703 
704 
705 struct CV_EXPORTS CvDTreeParams
706 {
707     int   max_categories;
708     int   max_depth;
709     int   min_sample_count;
710     int   cv_folds;
711     bool  use_surrogates;
712     bool  use_1se_rule;
713     bool  truncate_pruned_tree;
714     float regression_accuracy;
715     const float* priors;
716 
CvDTreeParamsCvDTreeParams717     CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
718         cv_folds(10), use_surrogates(true), use_1se_rule(true),
719         truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
720     {}
721 
CvDTreeParamsCvDTreeParams722     CvDTreeParams( int _max_depth, int _min_sample_count,
723                    float _regression_accuracy, bool _use_surrogates,
724                    int _max_categories, int _cv_folds,
725                    bool _use_1se_rule, bool _truncate_pruned_tree,
726                    const float* _priors ) :
727         max_categories(_max_categories), max_depth(_max_depth),
728         min_sample_count(_min_sample_count), cv_folds (_cv_folds),
729         use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
730         truncate_pruned_tree(_truncate_pruned_tree),
731         regression_accuracy(_regression_accuracy),
732         priors(_priors)
733     {}
734 };
735 
736 
737 struct CV_EXPORTS CvDTreeTrainData
738 {
739     CvDTreeTrainData();
740     CvDTreeTrainData( const CvMat* _train_data, int _tflag,
741                       const CvMat* _responses, const CvMat* _var_idx=0,
742                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
743                       const CvMat* _missing_mask=0,
744                       const CvDTreeParams& _params=CvDTreeParams(),
745                       bool _shared=false, bool _add_labels=false );
746     virtual ~CvDTreeTrainData();
747 
748     virtual void set_data( const CvMat* _train_data, int _tflag,
749                           const CvMat* _responses, const CvMat* _var_idx=0,
750                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
751                           const CvMat* _missing_mask=0,
752                           const CvDTreeParams& _params=CvDTreeParams(),
753                           bool _shared=false, bool _add_labels=false,
754                           bool _update_data=false );
755 
756     virtual void get_vectors( const CvMat* _subsample_idx,
757          float* values, uchar* missing, float* responses, bool get_class_idx=false );
758 
759     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
760 
761     virtual void write_params( CvFileStorage* fs );
762     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
763 
764     // release all the data
765     virtual void clear();
766 
767     int get_num_classes() const;
768     int get_var_type(int vi) const;
769     int get_work_var_count() const;
770 
771     virtual int* get_class_labels( CvDTreeNode* n );
772     virtual float* get_ord_responses( CvDTreeNode* n );
773     virtual int* get_labels( CvDTreeNode* n );
774     virtual int* get_cat_var_data( CvDTreeNode* n, int vi );
775     virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi );
776     virtual int get_child_buf_idx( CvDTreeNode* n );
777 
778     ////////////////////////////////////
779 
780     virtual bool set_params( const CvDTreeParams& params );
781     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
782                                    int storage_idx, int offset );
783 
784     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
785                 int split_point, int inversed, float quality );
786     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
787     virtual void free_node_data( CvDTreeNode* node );
788     virtual void free_train_data();
789     virtual void free_node( CvDTreeNode* node );
790 
791     int sample_count, var_all, var_count, max_c_count;
792     int ord_var_count, cat_var_count;
793     bool have_labels, have_priors;
794     bool is_classifier;
795 
796     int buf_count, buf_size;
797     bool shared;
798 
799     CvMat* cat_count;
800     CvMat* cat_ofs;
801     CvMat* cat_map;
802 
803     CvMat* counts;
804     CvMat* buf;
805     CvMat* direction;
806     CvMat* split_buf;
807 
808     CvMat* var_idx;
809     CvMat* var_type; // i-th element =
810                      //   k<0  - ordered
811                      //   k>=0 - categorical, see k-th element of cat_* arrays
812     CvMat* priors;
813     CvMat* priors_mult;
814 
815     CvDTreeParams params;
816 
817     CvMemStorage* tree_storage;
818     CvMemStorage* temp_storage;
819 
820     CvDTreeNode* data_root;
821 
822     CvSet* node_heap;
823     CvSet* split_heap;
824     CvSet* cv_heap;
825     CvSet* nv_heap;
826 
827     CvRNG rng;
828 };
829 
830 
/* Single decision tree model. The paired find_split_*_class / find_split_*_reg
   hooks indicate that it handles both classification and regression.
   The class is CV_EXPORTS and most members are virtual, so the order of the
   declarations below is part of the binary interface (vtable and object
   layout); reordering them affects ABI compatibility. */
class CV_EXPORTS CvDTree : public CvStatModel
{
public:
    CvDTree();
    virtual ~CvDTree();

    // Trains the tree from raw matrices. The optional index, type and mask
    // arrays default to 0 (not supplied). NOTE(review): _tflag presumably
    // selects the sample layout of _train_data (row vs. column samples) — confirm.
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );

    // Trains on an existing CvDTreeTrainData object. NOTE(review): _subsample_idx
    // presumably restricts training to a subset of its samples — confirm.
    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );

    // Returns the tree node selected for _sample. NOTE(review): presumably the
    // leaf reached by the sample, with preprocessed_input meaning _sample is
    // already in the internal representation — confirm in the implementation.
    virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
                                  bool preprocessed_input=false ) const;
    virtual const CvMat* get_var_importance();
    virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name );

    // special read & write methods for trees in the tree ensembles
    // (read() is handed the CvDTreeTrainData to use; write() takes no node name)
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs );

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:

    virtual bool do_train( const CvMat* _subsample_idx );

    // Split search: the best split over all variables, per-variable searches
    // for ordered (ord) / categorical (cat) variables in classification
    // (class) and regression (reg) mode, and surrogate splits.
    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    // Groups category vectors into k clusters, writing the assignment to
    // cluster_labels. NOTE(review): presumably used to limit the number of
    // categories examined per split — confirm.
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    // Pruning. NOTE(review): the prune_cv/fold naming suggests
    // cross-validation-based pruning over a sequence of subtrees indexed by T
    // — confirm in the implementation.
    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    // Serialization of individual nodes/splits and of the whole node tree.
    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node );
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs );
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;

    int pruned_tree_idx;
    CvMat* var_importance;

    CvDTreeTrainData* data;
};
902 
903 
904 /****************************************************************************************\
905 *                                   Random Trees Classifier                              *
906 \****************************************************************************************/
907 
/* Forward declaration: CvForestTree below takes and stores CvRTrees* before
   CvRTrees itself is defined. */
class CvRTrees;
909 
910 class CV_EXPORTS CvForestTree: public CvDTree
911 {
912 public:
913     CvForestTree();
914     virtual ~CvForestTree();
915 
916     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
917 
get_var_count()918     virtual int get_var_count() const {return data ? data->var_count : 0;}
919     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
920 
921     /* dummy methods to avoid warnings: BEGIN */
922     virtual bool train( const CvMat* _train_data, int _tflag,
923                         const CvMat* _responses, const CvMat* _var_idx=0,
924                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
925                         const CvMat* _missing_mask=0,
926                         CvDTreeParams params=CvDTreeParams() );
927 
928     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
929     virtual void read( CvFileStorage* fs, CvFileNode* node );
930     virtual void read( CvFileStorage* fs, CvFileNode* node,
931                        CvDTreeTrainData* data );
932     /* dummy methods to avoid warnings: END */
933 
934 protected:
935     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
936     CvRTrees* forest;
937 };
938 
939 
940 struct CV_EXPORTS CvRTParams : public CvDTreeParams
941 {
942     //Parameters for the forest
943     bool calc_var_importance; // true <=> RF processes variable importance
944     int nactive_vars;
945     CvTermCriteria term_crit;
946 
CvRTParamsCvRTParams947     CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
948         calc_var_importance(false), nactive_vars(0)
949     {
950         term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
951     }
952 
CvRTParamsCvRTParams953     CvRTParams( int _max_depth, int _min_sample_count,
954                 float _regression_accuracy, bool _use_surrogates,
955                 int _max_categories, const float* _priors, bool _calc_var_importance,
956                 int _nactive_vars, int max_num_of_trees_in_the_forest,
957                 float forest_accuracy, int termcrit_type ) :
958         CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
959                        _use_surrogates, _max_categories, 0,
960                        false, false, _priors ),
961         calc_var_importance(_calc_var_importance),
962         nactive_vars(_nactive_vars)
963     {
964         term_crit = cvTermCriteria(termcrit_type,
965             max_num_of_trees_in_the_forest, forest_accuracy);
966     }
967 };
968 
969 
/* Random-trees (random forest) model: an ensemble of `ntrees` CvForestTree
   objects held in the `trees` array. */
class CV_EXPORTS CvRTrees : public CvStatModel
{
public:
    CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvRTParams params=CvRTParams() );
    // Returns the forest's response for `sample`. NOTE(review): presumably a
    // vote over trees for classification / an average for regression — confirm.
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual void clear();

    virtual const CvMat* get_var_importance();
    // Similarity of two samples as measured by the forest. NOTE(review):
    // presumably the fraction of trees in which both samples reach the same
    // leaf — confirm.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name );

    // NOTE(review): presumably queried by member trees (CvForestTree keeps a
    // CvRTrees* back-pointer) during training — confirm.
    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:

    bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    double oob_error;         // NOTE(review): presumably the out-of-bag error estimate
    CvMat* var_importance;
    int nsamples;

    CvRNG rng;
    CvMat* active_var_mask;
};
1012 
1013 
1014 /****************************************************************************************\
1015 *                                   Boosted tree classifier                              *
1016 \****************************************************************************************/
1017 
/* Training parameters for CvBoost; extends the per-tree CvDTreeParams. */
struct CV_EXPORTS CvBoostParams : public CvDTreeParams
{
    int boost_type;          // NOTE(review): presumably one of CvBoost::DISCRETE/REAL/LOGIT/GENTLE
    int weak_count;          // NOTE(review): presumably the number of weak trees to build
    int split_criteria;      // NOTE(review): presumably one of CvBoost::DEFAULT/GINI/MISCLASS/SQERR
    double weight_trim_rate;

    CvBoostParams();
    // Note: split_criteria is not a parameter of this constructor.
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};
1029 
1030 
/* Forward declaration: CvBoostTree below takes and stores CvBoost* before
   CvBoost itself is defined. */
class CvBoost;
1032 
/* Weak learner of a CvBoost ensemble: a CvDTree that keeps a pointer to its
   ensemble and overrides the split-search and node-value hooks. */
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    // Trains this tree as a member of `ensemble`.
    virtual bool train( CvDTreeTrainData* _train_data,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    // NOTE(review): presumably multiplies the tree's node values by s — confirm.
    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    // The train()/read() overloads declared above hide the CvDTree overloads
    // of the same names; redeclaring the base signatures keeps them visible.
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    // Ensemble this tree belongs to.
    CvBoost* ensemble;
};
1074 
1075 
/* Boosted ensemble of CvBoostTree weak learners. */
class CV_EXPORTS CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria (note: the value 2 is not used)
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CvBoost();
    virtual ~CvBoost();

    // Construct-and-train convenience constructor (same arguments as train()
    // minus `update`). NOTE(review): presumably forwards to train() — confirm.
    CvBoost( const CvMat* _train_data, int _tflag,
             const CvMat* _responses, const CvMat* _var_idx=0,
             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
             const CvMat* _missing_mask=0,
             CvBoostParams params=CvBoostParams() );

    // NOTE(review): update=true presumably extends the existing ensemble
    // instead of training from scratch — confirm.
    virtual bool train( const CvMat* _train_data, int _tflag,
             const CvMat* _responses, const CvMat* _var_idx=0,
             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
             const CvMat* _missing_mask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    // `slice` selects the weak predictors that take part (default: the whole
    // sequence). NOTE(review): weak_responses, if non-NULL, presumably
    // receives each weak tree's response — confirm.
    virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false ) const;

    // NOTE(review): presumably removes the weak predictors in `slice` — confirm.
    virtual void prune( CvSlice slice );

    virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name );
    virtual void read( CvFileStorage* storage, CvFileNode* node );

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;

protected:

    virtual bool set_params( const CvBoostParams& _params );
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs );
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;              // NOTE(review): presumably the sequence returned by get_weak_predictors()

    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};
1139 
1140 
1141 /****************************************************************************************\
1142 *                              Artificial Neural Networks (ANN)                          *
1143 \****************************************************************************************/
1144 
1145 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
1146 
/* Training parameters for CvANN_MLP::train(). */
struct CV_EXPORTS CvANN_MLP_TrainParams
{
    CvANN_MLP_TrainParams();
    // NOTE(review): param1/param2 are presumably interpreted according to
    // train_method (mapped onto the bp_* or rp_* fields below) — confirm.
    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
                           double param1, double param2=0 );
    ~CvANN_MLP_TrainParams();

    // values for train_method
    enum { BACKPROP=0, RPROP=1 };

    CvTermCriteria term_crit;   // termination criteria for training
    int train_method;           // BACKPROP or RPROP

    // backpropagation parameters
    double bp_dw_scale, bp_moment_scale;

    // rprop parameters
    double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
};
1165 
1166 
1167 class CV_EXPORTS CvANN_MLP : public CvStatModel
1168 {
1169 public:
1170     CvANN_MLP();
1171     CvANN_MLP( const CvMat* _layer_sizes,
1172                int _activ_func=SIGMOID_SYM,
1173                double _f_param1=0, double _f_param2=0 );
1174 
1175     virtual ~CvANN_MLP();
1176 
1177     virtual void create( const CvMat* _layer_sizes,
1178                          int _activ_func=SIGMOID_SYM,
1179                          double _f_param1=0, double _f_param2=0 );
1180 
1181     virtual int train( const CvMat* _inputs, const CvMat* _outputs,
1182                        const CvMat* _sample_weights, const CvMat* _sample_idx=0,
1183                        CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
1184                        int flags=0 );
1185     virtual float predict( const CvMat* _inputs,
1186                            CvMat* _outputs ) const;
1187 
1188     virtual void clear();
1189 
1190     // possible activation functions
1191     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
1192 
1193     // available training flags
1194     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
1195 
1196     virtual void read( CvFileStorage* fs, CvFileNode* node );
1197     virtual void write( CvFileStorage* storage, const char* name );
1198 
get_layer_count()1199     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
get_layer_sizes()1200     const CvMat* get_layer_sizes() { return layer_sizes; }
get_weights(int layer)1201     double* get_weights(int layer)
1202     {
1203         return layer_sizes && weights &&
1204             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
1205     }
1206 
1207 protected:
1208 
1209     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
1210             const CvMat* _sample_weights, const CvMat* _sample_idx,
1211             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
1212 
1213     // sequential random backpropagation
1214     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1215 
1216     // RPROP algorithm
1217     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
1218 
1219     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
1220     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
1221     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
1222                                  double _f_param1=0, double _f_param2=0 );
1223     virtual void init_weights();
1224     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
1225     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
1226     virtual void calc_input_scale( const CvVectors* vecs, int flags );
1227     virtual void calc_output_scale( const CvVectors* vecs, int flags );
1228 
1229     virtual void write_params( CvFileStorage* fs );
1230     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
1231 
1232     CvMat* layer_sizes;
1233     CvMat* wbuf;
1234     CvMat* sample_weights;
1235     double** weights;
1236     double f_param1, f_param2;
1237     double min_val, max_val, min_val1, max_val1;
1238     int activ_func;
1239     int max_count, max_buf_sz;
1240     CvANN_MLP_TrainParams params;
1241     CvRNG rng;
1242 };
1243 
/* Disabled code: everything from the #if 0 below to its matching #endif
   (convolutional-network and classifier-estimation declarations) is excluded
   from compilation. NOTE(review): presumably kept for reference — confirm
   whether it should be restored or removed. */
#if 0
/****************************************************************************************\
*                            Convolutional Neural Network                                *
\****************************************************************************************/
typedef struct CvCNNLayer CvCNNLayer;
typedef struct CvCNNetwork CvCNNetwork;

/* Values for a layer's learn_rate_decrease_type field. */
#define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
#define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
#define CV_CNN_LEARN_RATE_DECREASE_LOG_INV         3

/* Values for CvCNNStatModelParams::grad_estim_type. */
#define CV_CNN_GRAD_ESTIM_RANDOM        0
#define CV_CNN_GRAD_ESTIM_BY_WORST_IMG  1

/* Layer type tags stored in a layer's `flags`: the CV_MAGIC_MASK bits hold
   ICV_CNN_LAYER, the remaining bits identify the concrete layer kind. */
#define ICV_CNN_LAYER                0x55550000
#define ICV_CNN_CONVOLUTION_LAYER    0x00001111
#define ICV_CNN_SUBSAMPLING_LAYER    0x00002222
#define ICV_CNN_FULLCONNECT_LAYER    0x00003333

/* Type checks for a (possibly NULL) layer pointer. */
#define ICV_IS_CNN_LAYER( layer )                                          \
    ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\
        == ICV_CNN_LAYER ))

#define ICV_IS_CNN_CONVOLUTION_LAYER( layer )                              \
    ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
        & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER )

#define ICV_IS_CNN_SUBSAMPLING_LAYER( layer )                              \
    ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
        & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER )

#define ICV_IS_CNN_FULLCONNECT_LAYER( layer )                              \
    ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags       \
        & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER )

/* Function-pointer "methods": the layer ones are stored in the forward /
   backward / release fields of CV_CNN_LAYER_FIELDS, the network ones in the
   add_layer / release fields of CvCNNetwork. */
typedef void (CV_CDECL *CvCNNLayerForward)
    ( CvCNNLayer* layer, const CvMat* input, CvMat* output );

typedef void (CV_CDECL *CvCNNLayerBackward)
    ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );

typedef void (CV_CDECL *CvCNNLayerRelease)
    (CvCNNLayer** layer);

typedef void (CV_CDECL *CvCNNetworkAddLayer)
    (CvCNNetwork* network, CvCNNLayer* layer);

typedef void (CV_CDECL *CvCNNetworkRelease)
    (CvCNNetwork** network);

/* Fields shared by every layer struct. Each layer type expands this macro
   first, so any layer can be inspected through a CvCNNLayer* (as the
   ICV_IS_CNN_* macros above do). */
#define CV_CNN_LAYER_FIELDS()           \
    /* Indicator of the layer's type */ \
    int flags;                          \
                                        \
    /* Number of input images */        \
    int n_input_planes;                 \
    /* Height of each input image */    \
    int input_height;                   \
    /* Width of each input image */     \
    int input_width;                    \
                                        \
    /* Number of output images */       \
    int n_output_planes;                \
    /* Height of each output image */   \
    int output_height;                  \
    /* Width of each output image */    \
    int output_width;                   \
                                        \
    /* Learning rate at the first iteration */                      \
    float init_learn_rate;                                          \
    /* Dynamics of learning rate decreasing */                      \
    int learn_rate_decrease_type;                                   \
    /* Trainable weights of the layer (including bias) */           \
    /* i-th row is a set of weights of the i-th output plane */     \
    CvMat* weights;                                                 \
                                                                    \
    CvCNNLayerForward  forward;                                     \
    CvCNNLayerBackward backward;                                    \
    CvCNNLayerRelease  release;                                     \
    /* Pointers to the previous and next layers in the network */   \
    CvCNNLayer* prev_layer;                                         \
    CvCNNLayer* next_layer

typedef struct CvCNNLayer
{
    CV_CNN_LAYER_FIELDS();
}CvCNNLayer;

typedef struct CvCNNConvolutionLayer
{
    CV_CNN_LAYER_FIELDS();
    // Kernel size (height and width) for convolution.
    int K;
    // connections matrix, (i,j)-th element is 1 iff there is a connection between
    // i-th plane of the current layer and j-th plane of the previous layer;
    // (i,j)-th element is equal to 0 otherwise
    CvMat *connect_mask;
    // NOTE(review): no field follows; the learning rate for the first
    // iteration is init_learn_rate, declared in CV_CNN_LAYER_FIELDS.
}CvCNNConvolutionLayer;

typedef struct CvCNNSubSamplingLayer
{
    CV_CNN_LAYER_FIELDS();
    // ratio between the heights (or widths - ratios are supposed to be equal)
    // of the input and output planes
    int sub_samp_scale;
    // amplitude of sigmoid activation function
    float a;
    // scale parameter of sigmoid activation function
    float s;
    // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
    // - the vector used when computing the activation function in the backward pass
    CvMat* exp2ssumWX;
    // (x1+x2+x3+x4), where x1,...x4 are some elements of X
    // - the vector used when computing the activation function in the backward pass
    CvMat* sumX;
}CvCNNSubSamplingLayer;

// Structure of the last layer.
typedef struct CvCNNFullConnectLayer
{
    CV_CNN_LAYER_FIELDS();
    // amplitude of sigmoid activation function
    float a;
    // scale parameter of sigmoid activation function
    float s;
    // exp2ssumWX = exp(2*<s>*(W*X)) - the vector used in computing the
    // activation function and its derivative by the formulae
    // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
    // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
    CvMat* exp2ssumWX;
}CvCNNFullConnectLayer;

typedef struct CvCNNetwork
{
    int n_layers;
    // NOTE(review): presumably the first layer; the rest are reached through
    // next_layer — confirm.
    CvCNNLayer* layers;
    CvCNNetworkAddLayer add_layer;
    CvCNNetworkRelease release;
}CvCNNetwork;

typedef struct CvCNNStatModel
{
    CV_STAT_MODEL_FIELDS();
    CvCNNetwork* network;
    // etalons are allocated as rows, the i-th etalon has label cls_labels[i]
    CvMat* etalons;
    // classes labels
    CvMat* cls_labels;
}CvCNNStatModel;

typedef struct CvCNNStatModelParams
{
    CV_STAT_MODEL_PARAM_FIELDS();
    // network must be created by the functions cvCreateCNNetwork and <add_layer>
    CvCNNetwork* network;
    CvMat* etalons;
    // termination criteria
    int max_iter;
    int start_iter;
    int grad_estim_type;
}CvCNNStatModelParams;

/* Layer constructors; `weights` may be 0 (the default). */
CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
    int n_input_planes, int input_height, int input_width,
    int n_output_planes, int K,
    float init_learn_rate, int learn_rate_decrease_type,
    CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
    int n_input_planes, int input_height, int input_width,
    int sub_samp_scale, float a, float s,
    float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
    int n_inputs, int n_outputs, float a, float s,
    float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );

/* Its signature matches the createClassifier callback of cvCrossValidation.
   NOTE(review): the unnamed CvMat* parameters presumably correspond to that
   function's compIdx / typeMask / missedMeasurementMask — confirm. */
CVAPI(CvStatModel*) cvTrainCNNClassifier(
            const CvMat* train_data, int tflag,
            const CvMat* responses,
            const CvStatModelParams* params,
            const CvMat* CV_DEFAULT(0),
            const CvMat* sample_idx CV_DEFAULT(0),
            const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
1431 
/****************************************************************************************\
*                              Classifier estimation algorithms                          *
\****************************************************************************************/
/* Callback types for estimation models. The cross-validation model stores
   them in the fields declared by CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
   CvStatModelEstimateCheckClassifierEasy is declared but not used there. */
typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
                    ( const CvStatModel* estimateModel );

typedef int (CV_CDECL *CvStatModelEstimateNextStep)
                    ( CvStatModel* estimateModel );

typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
                    ( CvStatModel* estimateModel,
                const CvStatModel* model,
                const CvMat*       features,
                      int          sample_t_flag,
                const CvMat*       responses );

typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
                    ( CvStatModel* estimateModel,
                const CvStatModel* model );

typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
                    ( const CvStatModel* estimateModel,
                            float*       correlation );

typedef void (CV_CDECL *CvStatModelEstimateReset)
                    ( CvStatModel* estimateModel );

//-------------------------------- Cross-validation --------------------------------------
/* Parameter block: number of folds, regression vs. classification flag, and
   an RNG (NOTE(review): presumably used to assign samples to folds — confirm). */
#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
    CV_STAT_MODEL_PARAM_FIELDS();                                 \
    int     k_fold;                                               \
    int     is_regression;                                        \
    CvRNG*  rng

typedef struct CvCrossValidationParams
{
    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
} CvCrossValidationParams;

/* Cross-validation state: the callbacks above, the fold assignment of the
   samples, the current fold, and running sums. NOTE(review): the sum_* / sq_error
   accumulators presumably feed the error and correlation reported by getResult. */
#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
    CvStatModelEstimateGetMat               getTrainIdxMat; \
    CvStatModelEstimateGetMat               getCheckIdxMat; \
    CvStatModelEstimateNextStep             nextStep;       \
    CvStatModelEstimateCheckClassifier      check;          \
    CvStatModelEstimateGetCurrentResult     getResult;      \
    CvStatModelEstimateReset                reset;          \
    int     is_regression;                                  \
    int     folds_all;                                      \
    int     samples_all;                                    \
    int*    sampleIdxAll;                                   \
    int*    folds;                                          \
    int     max_fold_size;                                  \
    int         current_fold;                               \
    int         is_checked;                                 \
    CvMat*      sampleIdxTrain;                             \
    CvMat*      sampleIdxEval;                              \
    CvMat*      predict_results;                            \
    int     correct_results;                                \
    int     all_results;                                    \
    double  sq_error;                                       \
    double  sum_correct;                                    \
    double  sum_predict;                                    \
    double  sum_cc;                                         \
    double  sum_pp;                                         \
    double  sum_cp

typedef struct CvCrossValidationModel
{
    CV_STAT_MODEL_FIELDS();
    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
} CvCrossValidationModel;

CVAPI(CvStatModel*)
cvCreateCrossValidationEstimateModel
           ( int                samples_all,
       const CvStatModelParams* estimateParams CV_DEFAULT(0),
       const CvMat*             sampleIdx CV_DEFAULT(0) );

/* Cross-validates the classifier produced by createClassifier (same signature
   as cvTrainCNNClassifier). NOTE(review): pCrValModel, if non-NULL, presumably
   receives the estimation model — confirm. */
CVAPI(float)
cvCrossValidation( const CvMat*             trueData,
                         int                tflag,
                   const CvMat*             trueClasses,
                         CvStatModel*     (*createClassifier)( const CvMat*,
                                                                     int,
                                                               const CvMat*,
                                                               const CvStatModelParams*,
                                                               const CvMat*,
                                                               const CvMat*,
                                                               const CvMat*,
                                                               const CvMat* ),
                   const CvStatModelParams* estimateParams CV_DEFAULT(0),
                   const CvStatModelParams* trainParams CV_DEFAULT(0),
                   const CvMat*             compIdx CV_DEFAULT(0),
                   const CvMat*             sampleIdx CV_DEFAULT(0),
                         CvStatModel**      pCrValModel CV_DEFAULT(0),
                   const CvMat*             typeMask CV_DEFAULT(0),
                   const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
#endif
1530 
/****************************************************************************************\
*                           Auxiliary function declarations                              *
\****************************************************************************************/
1534 
/* Generates <sample> from a multivariate normal distribution, where <mean> is
   the mean (average) row vector and <cov> is the symmetric covariance matrix.
   NOTE(review): with <rng> == NULL (the default) the function presumably uses
   an internal generator — confirm. */
CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
                           CvRNG* rng CV_DEFAULT(0) );
1539 
/* Generates <sample> from a Gaussian mixture distribution with <clsnum>
   components. NOTE(review): component i presumably has mean means[i],
   covariance covs[i] and mixing weight weights[i], and <sampClasses>, if
   non-NULL, presumably receives the component index of each generated
   sample — confirm. */
CVAPI(void) cvRandGaussMixture( CvMat* means[],
                               CvMat* covs[],
                               float weights[],
                               int clsnum,
                               CvMat* sample,
                               CvMat* sampClasses CV_DEFAULT(0) );
1547 
/* Test-set type; NOTE(review): presumably a value for cvCreateTestSet's <type>. */
#define CV_TS_CONCENTRIC_SPHERES 0

/* Creates a synthetic test set of <num_samples> samples with <num_features>
   features and <num_classes> classes, returned through <samples> and
   <responses>. NOTE(review): the matrices are presumably allocated by the
   function, and the meaning of the variadic arguments depends on <type> —
   confirm in the implementation. */
CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
                 int num_samples,
                 int num_features,
                 CvMat** responses,
                 int num_classes, ... );
1556 
/* Makes <matrix> symmetric by copying one triangle onto the other:
   Aij <- Aji for i > j if lower_to_upper != 0
              for i < j if lower_to_upper = 0
   NOTE(review): the parameter name suggests the opposite mapping (lower
   triangle copied into the upper one, i.e. Aij <- Aji for i < j when
   lower_to_upper != 0); verify against the implementation and fix whichever
   is wrong. */
CVAPI(void) cvCompleteSymm( CvMat* matrix, int lower_to_upper );
1560 
1561 #endif
1562 
1563 #endif /*__ML_H__*/
1564 /* End of file. */
1565