• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #include "_ml.h"
42 
43 /****************************************************************************************\
44                                 COPYRIGHT NOTICE
45                                 ----------------
46 
47   The code has been derived from libsvm library (version 2.6)
48   (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
49 
50   Here is the original copyright:
51 ------------------------------------------------------------------------------------------
52     Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
53     All rights reserved.
54 
55     Redistribution and use in source and binary forms, with or without
56     modification, are permitted provided that the following conditions
57     are met:
58 
59     1. Redistributions of source code must retain the above copyright
60     notice, this list of conditions and the following disclaimer.
61 
62     2. Redistributions in binary form must reproduce the above copyright
63     notice, this list of conditions and the following disclaimer in the
64     documentation and/or other materials provided with the distribution.
65 
66     3. Neither name of copyright holders nor the names of its contributors
67     may be used to endorse or promote products derived from this software
68     without specific prior written permission.
69 
70 
71     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
72     ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
73     LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
74     A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
75     CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
76     EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
77     PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
78     PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
79     LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
80     NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
81     SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
82 \****************************************************************************************/
83 
84 #define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
85 
86 #include <stdarg.h>
87 #include <ctype.h>
88 
89 #if _MSC_VER >= 1200
90 #pragma warning( disable: 4514 ) /* unreferenced inline functions */
91 #endif
92 
93 #if 1
94 typedef float Qfloat;
95 #define QFLOAT_TYPE CV_32F
96 #else
97 typedef double Qfloat;
98 #define QFLOAT_TYPE CV_64F
99 #endif
100 
101 // Param Grid
check() const102 bool CvParamGrid::check() const
103 {
104     bool ok = false;
105 
106     CV_FUNCNAME( "CvParamGrid::check" );
107     __BEGIN__;
108 
109     if( min_val > max_val )
110         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
111     if( min_val < DBL_EPSILON )
112         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
113     if( step < 1. + FLT_EPSILON )
114         CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
115 
116     ok = true;
117 
118     __END__;
119 
120     return ok;
121 }
122 
get_default_grid(int param_id)123 CvParamGrid CvSVM::get_default_grid( int param_id )
124 {
125     CvParamGrid grid;
126     if( param_id == CvSVM::C )
127     {
128         grid.min_val = 0.1;
129         grid.max_val = 500;
130         grid.step = 5; // total iterations = 5
131     }
132     else if( param_id == CvSVM::GAMMA )
133     {
134         grid.min_val = 1e-5;
135         grid.max_val = 0.6;
136         grid.step = 15; // total iterations = 4
137     }
138     else if( param_id == CvSVM::P )
139     {
140         grid.min_val = 0.01;
141         grid.max_val = 100;
142         grid.step = 7; // total iterations = 4
143     }
144     else if( param_id == CvSVM::NU )
145     {
146         grid.min_val = 0.01;
147         grid.max_val = 0.2;
148         grid.step = 3; // total iterations = 3
149     }
150     else if( param_id == CvSVM::COEF )
151     {
152         grid.min_val = 0.1;
153         grid.max_val = 300;
154         grid.step = 14; // total iterations = 3
155     }
156     else if( param_id == CvSVM::DEGREE )
157     {
158         grid.min_val = 0.01;
159         grid.max_val = 4;
160         grid.step = 7; // total iterations = 3
161     }
162     else
163         cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
164             "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
165     return grid;
166 }
167 
168 // SVM training parameters
CvSVMParams()169 CvSVMParams::CvSVMParams() :
170     svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
171     gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
172 {
173     term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
174 }
175 
176 
CvSVMParams(int _svm_type,int _kernel_type,double _degree,double _gamma,double _coef0,double _Con,double _nu,double _p,CvMat * _class_weights,CvTermCriteria _term_crit)177 CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
178     double _degree, double _gamma, double _coef0,
179     double _Con, double _nu, double _p,
180     CvMat* _class_weights, CvTermCriteria _term_crit ) :
181     svm_type(_svm_type), kernel_type(_kernel_type),
182     degree(_degree), gamma(_gamma), coef0(_coef0),
183     C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
184 {
185 }
186 
187 
188 /////////////////////////////////////// SVM kernel ///////////////////////////////////////
189 
CvSVMKernel()190 CvSVMKernel::CvSVMKernel()
191 {
192     clear();
193 }
194 
195 
clear()196 void CvSVMKernel::clear()
197 {
198     params = 0;
199     calc_func = 0;
200 }
201 
202 
~CvSVMKernel()203 CvSVMKernel::~CvSVMKernel()
204 {
205 }
206 
207 
CvSVMKernel(const CvSVMParams * _params,Calc _calc_func)208 CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
209 {
210     clear();
211     create( _params, _calc_func );
212 }
213 
214 
create(const CvSVMParams * _params,Calc _calc_func)215 bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
216 {
217     clear();
218     params = _params;
219     calc_func = _calc_func;
220 
221     if( !calc_func )
222         calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
223                     params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
224                     params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
225                     &CvSVMKernel::calc_linear;
226 
227     return true;
228 }
229 
230 
calc_non_rbf_base(int vcount,int var_count,const float ** vecs,const float * another,Qfloat * results,double alpha,double beta)231 void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
232                                      const float* another, Qfloat* results,
233                                      double alpha, double beta )
234 {
235     int j, k;
236     for( j = 0; j < vcount; j++ )
237     {
238         const float* sample = vecs[j];
239         double s = 0;
240         for( k = 0; k <= var_count - 4; k += 4 )
241             s += sample[k]*another[k] + sample[k+1]*another[k+1] +
242                  sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
243         for( ; k < var_count; k++ )
244             s += sample[k]*another[k];
245         results[j] = (Qfloat)(s*alpha + beta);
246     }
247 }
248 
249 
calc_linear(int vcount,int var_count,const float ** vecs,const float * another,Qfloat * results)250 void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
251                                const float* another, Qfloat* results )
252 {
253     calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
254 }
255 
256 
calc_poly(int vcount,int var_count,const float ** vecs,const float * another,Qfloat * results)257 void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
258                              const float* another, Qfloat* results )
259 {
260     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
261     calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
262     cvPow( &R, &R, params->degree );
263 }
264 
265 
calc_sigmoid(int vcount,int var_count,const float ** vecs,const float * another,Qfloat * results)266 void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
267                                 const float* another, Qfloat* results )
268 {
269     int j;
270     calc_non_rbf_base( vcount, var_count, vecs, another, results,
271                        -2*params->gamma, -2*params->coef0 );
272     // TODO: speedup this
273     for( j = 0; j < vcount; j++ )
274     {
275         Qfloat t = results[j];
276         double e = exp(-fabs(t));
277         if( t > 0 )
278             results[j] = (Qfloat)((1. - e)/(1. + e));
279         else
280             results[j] = (Qfloat)((e - 1.)/(e + 1.));
281     }
282 }
283 
284 
calc_rbf(int vcount,int var_count,const float ** vecs,const float * another,Qfloat * results)285 void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
286                             const float* another, Qfloat* results )
287 {
288     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
289     double gamma = -params->gamma;
290     int j, k;
291 
292     for( j = 0; j < vcount; j++ )
293     {
294         const float* sample = vecs[j];
295         double s = 0;
296 
297         for( k = 0; k <= var_count - 4; k += 4 )
298         {
299             double t0 = sample[k] - another[k];
300             double t1 = sample[k+1] - another[k+1];
301 
302             s += t0*t0 + t1*t1;
303 
304             t0 = sample[k+2] - another[k+2];
305             t1 = sample[k+3] - another[k+3];
306 
307             s += t0*t0 + t1*t1;
308         }
309 
310         for( ; k < var_count; k++ )
311         {
312             double t0 = sample[k] - another[k];
313             s += t0*t0;
314         }
315         results[j] = (Qfloat)(s*gamma);
316     }
317 
318     cvExp( &R, &R );
319 }
320 
321 
calc(int vcount,int var_count,const float ** vecs,const float * another,Qfloat * results)322 void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
323                         const float* another, Qfloat* results )
324 {
325     const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
326     int j;
327     (this->*calc_func)( vcount, var_count, vecs, another, results );
328     for( j = 0; j < vcount; j++ )
329     {
330         if( results[j] > max_val )
331             results[j] = max_val;
332     }
333 }
334 
335 
336 // Generalized SMO+SVMlight algorithm
337 // Solves:
338 //
339 //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
340 //
341 //      y^T \alpha = \delta
342 //      y_i = +1 or -1
343 //      0 <= alpha_i <= Cp for y_i = 1
344 //      0 <= alpha_i <= Cn for y_i = -1
345 //
346 // Given:
347 //
348 //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
349 //  l is the size of vectors and matrices
350 //  eps is the stopping criterion
351 //
352 // solution will be put in \alpha, objective value will be put in obj
353 //
354 
clear()355 void CvSVMSolver::clear()
356 {
357     G = 0;
358     alpha = 0;
359     y = 0;
360     b = 0;
361     buf[0] = buf[1] = 0;
362     cvReleaseMemStorage( &storage );
363     kernel = 0;
364     select_working_set_func = 0;
365     calc_rho_func = 0;
366 
367     rows = 0;
368     samples = 0;
369     get_row_func = 0;
370 }
371 
372 
CvSVMSolver()373 CvSVMSolver::CvSVMSolver()
374 {
375     storage = 0;
376     clear();
377 }
378 
379 
~CvSVMSolver()380 CvSVMSolver::~CvSVMSolver()
381 {
382     clear();
383 }
384 
385 
CvSVMSolver(int _sample_count,int _var_count,const float ** _samples,schar * _y,int _alpha_count,double * _alpha,double _Cp,double _Cn,CvMemStorage * _storage,CvSVMKernel * _kernel,GetRow _get_row,SelectWorkingSet _select_working_set,CalcRho _calc_rho)386 CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
387                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
388                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
389                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
390 {
391     storage = 0;
392     create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
393             _storage, _kernel, _get_row, _select_working_set, _calc_rho );
394 }
395 
396 
create(int _sample_count,int _var_count,const float ** _samples,schar * _y,int _alpha_count,double * _alpha,double _Cp,double _Cn,CvMemStorage * _storage,CvSVMKernel * _kernel,GetRow _get_row,SelectWorkingSet _select_working_set,CalcRho _calc_rho)397 bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
398                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
399                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
400                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
401 {
402     bool ok = false;
403     int i, svm_type;
404 
405     CV_FUNCNAME( "CvSVMSolver::create" );
406 
407     __BEGIN__;
408 
409     int rows_hdr_size;
410 
411     clear();
412 
413     sample_count = _sample_count;
414     var_count = _var_count;
415     samples = _samples;
416     y = _y;
417     alpha_count = _alpha_count;
418     alpha = _alpha;
419     kernel = _kernel;
420 
421     C[0] = _Cn;
422     C[1] = _Cp;
423     eps = kernel->params->term_crit.epsilon;
424     max_iter = kernel->params->term_crit.max_iter;
425     storage = cvCreateChildMemStorage( _storage );
426 
427     b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
428     alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
429     G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
430     for( i = 0; i < 2; i++ )
431         buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
432     svm_type = kernel->params->svm_type;
433 
434     select_working_set_func = _select_working_set;
435     if( !select_working_set_func )
436         select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
437         &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
438 
439     calc_rho_func = _calc_rho;
440     if( !calc_rho_func )
441         calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
442             &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
443 
444     get_row_func = _get_row;
445     if( !get_row_func )
446         get_row_func = params->svm_type == CvSVM::EPS_SVR ||
447                        params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
448                        params->svm_type == CvSVM::C_SVC ||
449                        params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
450                        &CvSVMSolver::get_row_one_class;
451 
452     cache_line_size = sample_count*sizeof(Qfloat);
453     // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 64Kb)
454     // (assuming that for large training sets ~25% of Q matrix is used)
455     cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
456 
457     // the size of Q matrix row headers
458     rows_hdr_size = sample_count*sizeof(rows[0]);
459     if( rows_hdr_size > storage->block_size )
460         CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
461 
462     lru_list.prev = lru_list.next = &lru_list;
463     rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
464     memset( rows, 0, rows_hdr_size );
465 
466     ok = true;
467 
468     __END__;
469 
470     return ok;
471 }
472 
473 
get_row_base(int i,bool * _existed)474 float* CvSVMSolver::get_row_base( int i, bool* _existed )
475 {
476     int i1 = i < sample_count ? i : i - sample_count;
477     CvSVMKernelRow* row = rows + i1;
478     bool existed = row->data != 0;
479     Qfloat* data;
480 
481     if( existed || cache_size <= 0 )
482     {
483         CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
484         data = del_row->data;
485         assert( data != 0 );
486 
487         // delete row from the LRU list
488         del_row->data = 0;
489         del_row->prev->next = del_row->next;
490         del_row->next->prev = del_row->prev;
491     }
492     else
493     {
494         data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
495         cache_size -= cache_line_size;
496     }
497 
498     // insert row into the LRU list
499     row->data = data;
500     row->prev = &lru_list;
501     row->next = lru_list.next;
502     row->prev->next = row->next->prev = row;
503 
504     if( !existed )
505     {
506         kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
507     }
508 
509     if( _existed )
510         *_existed = existed;
511 
512     return row->data;
513 }
514 
515 
get_row_svc(int i,float * row,float *,bool existed)516 float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
517 {
518     if( !existed )
519     {
520         const schar* _y = y;
521         int j, len = sample_count;
522         assert( _y && i < sample_count );
523 
524         if( _y[i] > 0 )
525         {
526             for( j = 0; j < len; j++ )
527                 row[j] = _y[j]*row[j];
528         }
529         else
530         {
531             for( j = 0; j < len; j++ )
532                 row[j] = -_y[j]*row[j];
533         }
534     }
535     return row;
536 }
537 
538 
get_row_one_class(int,float * row,float *,bool)539 float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
540 {
541     return row;
542 }
543 
544 
get_row_svr(int i,float * row,float * dst,bool)545 float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
546 {
547     int j, len = sample_count;
548     Qfloat* dst_pos = dst;
549     Qfloat* dst_neg = dst + len;
550     if( i >= len )
551     {
552         Qfloat* temp;
553         CV_SWAP( dst_pos, dst_neg, temp );
554     }
555 
556     for( j = 0; j < len; j++ )
557     {
558         Qfloat t = row[j];
559         dst_pos[j] = t;
560         dst_neg[j] = -t;
561     }
562     return dst;
563 }
564 
565 
566 
get_row(int i,float * dst)567 float* CvSVMSolver::get_row( int i, float* dst )
568 {
569     bool existed = false;
570     float* row = get_row_base( i, &existed );
571     return (this->*get_row_func)( i, row, dst, existed );
572 }
573 
574 
// Helper macros for the solver routines below. They expand against the names
// alpha, alpha_status, y and C visible at the point of use.
//
// alpha_status[i] is maintained by update_alpha_status:
//    1 : alpha[i] >= its upper bound
//   -1 : alpha[i] <= 0
//    0 : strictly between the two bounds

#undef is_upper_bound
#define is_upper_bound(i) (alpha_status[(i)] > 0)

#undef is_lower_bound
#define is_lower_bound(i) (alpha_status[(i)] < 0)

#undef is_free
#define is_free(i) (alpha_status[(i)] == 0)

// upper bound for alpha[i]: C[1] when y[i] > 0, C[0] otherwise
#undef get_C
#define get_C(i) (C[y[(i)] > 0])

// expands to a bare assignment expression; use it as a full statement
#undef update_alpha_status
#define update_alpha_status(i) \
    alpha_status[(i)] = (schar)(alpha[(i)] >= get_C(i) ? 1 : alpha[(i)] <= 0 ? -1 : 0)

#undef reconstruct_gradient
#define reconstruct_gradient() /* empty for now */
593 
594 
solve_generic(CvSVMSolutionInfo & si)595 bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
596 {
597     int iter = 0;
598     int i, j, k;
599 
600     // 1. initialize gradient and alpha status
601     for( i = 0; i < alpha_count; i++ )
602     {
603         update_alpha_status(i);
604         G[i] = b[i];
605         if( fabs(G[i]) > 1e200 )
606             return false;
607     }
608 
609     for( i = 0; i < alpha_count; i++ )
610     {
611         if( !is_lower_bound(i) )
612         {
613             const Qfloat *Q_i = get_row( i, buf[0] );
614             double alpha_i = alpha[i];
615 
616             for( j = 0; j < alpha_count; j++ )
617                 G[j] += alpha_i*Q_i[j];
618         }
619     }
620 
621     // 2. optimization loop
622     for(;;)
623     {
624         const Qfloat *Q_i, *Q_j;
625         double C_i, C_j;
626         double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
627         double delta_alpha_i, delta_alpha_j;
628 
629 #ifdef _DEBUG
630         for( i = 0; i < alpha_count; i++ )
631         {
632             if( fabs(G[i]) > 1e+300 )
633                 return false;
634 
635             if( fabs(alpha[i]) > 1e16 )
636                 return false;
637         }
638 #endif
639 
640         if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
641             break;
642 
643         Q_i = get_row( i, buf[0] );
644         Q_j = get_row( j, buf[1] );
645 
646         C_i = get_C(i);
647         C_j = get_C(j);
648 
649         alpha_i = old_alpha_i = alpha[i];
650         alpha_j = old_alpha_j = alpha[j];
651 
652         if( y[i] != y[j] )
653         {
654             double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
655             double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
656             double diff = alpha_i - alpha_j;
657             alpha_i += delta;
658             alpha_j += delta;
659 
660             if( diff > 0 && alpha_j < 0 )
661             {
662                 alpha_j = 0;
663                 alpha_i = diff;
664             }
665             else if( diff <= 0 && alpha_i < 0 )
666             {
667                 alpha_i = 0;
668                 alpha_j = -diff;
669             }
670 
671             if( diff > C_i - C_j && alpha_i > C_i )
672             {
673                 alpha_i = C_i;
674                 alpha_j = C_i - diff;
675             }
676             else if( diff <= C_i - C_j && alpha_j > C_j )
677             {
678                 alpha_j = C_j;
679                 alpha_i = C_j + diff;
680             }
681         }
682         else
683         {
684             double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
685             double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
686             double sum = alpha_i + alpha_j;
687             alpha_i -= delta;
688             alpha_j += delta;
689 
690             if( sum > C_i && alpha_i > C_i )
691             {
692                 alpha_i = C_i;
693                 alpha_j = sum - C_i;
694             }
695             else if( sum <= C_i && alpha_j < 0)
696             {
697                 alpha_j = 0;
698                 alpha_i = sum;
699             }
700 
701             if( sum > C_j && alpha_j > C_j )
702             {
703                 alpha_j = C_j;
704                 alpha_i = sum - C_j;
705             }
706             else if( sum <= C_j && alpha_i < 0 )
707             {
708                 alpha_i = 0;
709                 alpha_j = sum;
710             }
711         }
712 
713         // update alpha
714         alpha[i] = alpha_i;
715         alpha[j] = alpha_j;
716         update_alpha_status(i);
717         update_alpha_status(j);
718 
719         // update G
720         delta_alpha_i = alpha_i - old_alpha_i;
721         delta_alpha_j = alpha_j - old_alpha_j;
722 
723         for( k = 0; k < alpha_count; k++ )
724             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
725     }
726 
727     // calculate rho
728     (this->*calc_rho_func)( si.rho, si.r );
729 
730     // calculate objective value
731     for( i = 0, si.obj = 0; i < alpha_count; i++ )
732         si.obj += alpha[i] * (G[i] + b[i]);
733 
734     si.obj *= 0.5;
735 
736     si.upper_bound_p = C[1];
737     si.upper_bound_n = C[0];
738 
739     return true;
740 }
741 
742 
743 // return 1 if already optimal, return 0 otherwise
744 bool
select_working_set(int & out_i,int & out_j)745 CvSVMSolver::select_working_set( int& out_i, int& out_j )
746 {
747     // return i,j which maximize -grad(f)^T d , under constraint
748     // if alpha_i == C, d != +1
749     // if alpha_i == 0, d != -1
750     double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
751     int Gmax1_idx = -1;
752 
753     double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
754     int Gmax2_idx = -1;
755 
756     int i;
757 
758     for( i = 0; i < alpha_count; i++ )
759     {
760         double t;
761 
762         if( y[i] > 0 )    // y = +1
763         {
764             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
765             {
766                 Gmax1 = t;
767                 Gmax1_idx = i;
768             }
769             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
770             {
771                 Gmax2 = t;
772                 Gmax2_idx = i;
773             }
774         }
775         else        // y = -1
776         {
777             if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
778             {
779                 Gmax2 = t;
780                 Gmax2_idx = i;
781             }
782             if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
783             {
784                 Gmax1 = t;
785                 Gmax1_idx = i;
786             }
787         }
788     }
789 
790     out_i = Gmax1_idx;
791     out_j = Gmax2_idx;
792 
793     return Gmax1 + Gmax2 < eps;
794 }
795 
796 
797 void
calc_rho(double & rho,double & r)798 CvSVMSolver::calc_rho( double& rho, double& r )
799 {
800     int i, nr_free = 0;
801     double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
802 
803     for( i = 0; i < alpha_count; i++ )
804     {
805         double yG = y[i]*G[i];
806 
807         if( is_lower_bound(i) )
808         {
809             if( y[i] > 0 )
810                 ub = MIN(ub,yG);
811             else
812                 lb = MAX(lb,yG);
813         }
814         else if( is_upper_bound(i) )
815         {
816             if( y[i] < 0)
817                 ub = MIN(ub,yG);
818             else
819                 lb = MAX(lb,yG);
820         }
821         else
822         {
823             ++nr_free;
824             sum_free += yG;
825         }
826     }
827 
828     rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
829     r = 0;
830 }
831 
832 
833 bool
select_working_set_nu_svm(int & out_i,int & out_j)834 CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
835 {
836     // return i,j which maximize -grad(f)^T d , under constraint
837     // if alpha_i == C, d != +1
838     // if alpha_i == 0, d != -1
839     double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
840     int Gmax1_idx = -1;
841 
842     double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
843     int Gmax2_idx = -1;
844 
845     double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
846     int Gmax3_idx = -1;
847 
848     double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
849     int Gmax4_idx = -1;
850 
851     int i;
852 
853     for( i = 0; i < alpha_count; i++ )
854     {
855         double t;
856 
857         if( y[i] > 0 )    // y == +1
858         {
859             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
860             {
861                 Gmax1 = t;
862                 Gmax1_idx = i;
863             }
864             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
865             {
866                 Gmax2 = t;
867                 Gmax2_idx = i;
868             }
869         }
870         else        // y == -1
871         {
872             if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
873             {
874                 Gmax3 = t;
875                 Gmax3_idx = i;
876             }
877             if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
878             {
879                 Gmax4 = t;
880                 Gmax4_idx = i;
881             }
882         }
883     }
884 
885     if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
886         return 1;
887 
888     if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
889     {
890         out_i = Gmax1_idx;
891         out_j = Gmax2_idx;
892     }
893     else
894     {
895         out_i = Gmax3_idx;
896         out_j = Gmax4_idx;
897     }
898     return 0;
899 }
900 
901 
902 void
calc_rho_nu_svm(double & rho,double & r)903 CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
904 {
905     int nr_free1 = 0, nr_free2 = 0;
906     double ub1 = DBL_MAX, ub2 = DBL_MAX;
907     double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
908     double sum_free1 = 0, sum_free2 = 0;
909     double r1, r2;
910 
911     int i;
912 
913     for( i = 0; i < alpha_count; i++ )
914     {
915         double G_i = G[i];
916         if( y[i] > 0 )
917         {
918             if( is_lower_bound(i) )
919                 ub1 = MIN( ub1, G_i );
920             else if( is_upper_bound(i) )
921                 lb1 = MAX( lb1, G_i );
922             else
923             {
924                 ++nr_free1;
925                 sum_free1 += G_i;
926             }
927         }
928         else
929         {
930             if( is_lower_bound(i) )
931                 ub2 = MIN( ub2, G_i );
932             else if( is_upper_bound(i) )
933                 lb2 = MAX( lb2, G_i );
934             else
935             {
936                 ++nr_free2;
937                 sum_free2 += G_i;
938             }
939         }
940     }
941 
942     r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
943     r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
944 
945     rho = (r1 - r2)*0.5;
946     r = (r1 + r2)*0.5;
947 }
948 
949 
950 /*
951 ///////////////////////// construct and solve various formulations ///////////////////////
952 */
953 
solve_c_svc(int _sample_count,int _var_count,const float ** _samples,schar * _y,double _Cp,double _Cn,CvMemStorage * _storage,CvSVMKernel * _kernel,double * _alpha,CvSVMSolutionInfo & _si)954 bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
955                                double _Cp, double _Cn, CvMemStorage* _storage,
956                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
957 {
958     int i;
959 
960     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
961                  _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
962                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
963         return false;
964 
965     for( i = 0; i < sample_count; i++ )
966     {
967         alpha[i] = 0;
968         b[i] = -1;
969     }
970 
971     if( !solve_generic( _si ))
972         return false;
973 
974     for( i = 0; i < sample_count; i++ )
975         alpha[i] *= y[i];
976 
977     return true;
978 }
979 
980 
solve_nu_svc(int _sample_count,int _var_count,const float ** _samples,schar * _y,CvMemStorage * _storage,CvSVMKernel * _kernel,double * _alpha,CvSVMSolutionInfo & _si)981 bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
982                                 CvMemStorage* _storage, CvSVMKernel* _kernel,
983                                 double* _alpha, CvSVMSolutionInfo& _si )
984 {
985     int i;
986     double sum_pos, sum_neg, inv_r;
987 
988     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
989                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
990                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
991         return false;
992 
993     sum_pos = kernel->params->nu * sample_count * 0.5;
994     sum_neg = kernel->params->nu * sample_count * 0.5;
995 
996     for( i = 0; i < sample_count; i++ )
997     {
998         if( y[i] > 0 )
999         {
1000             alpha[i] = MIN(1.0, sum_pos);
1001             sum_pos -= alpha[i];
1002         }
1003         else
1004         {
1005             alpha[i] = MIN(1.0, sum_neg);
1006             sum_neg -= alpha[i];
1007         }
1008         b[i] = 0;
1009     }
1010 
1011     if( !solve_generic( _si ))
1012         return false;
1013 
1014     inv_r = 1./_si.r;
1015 
1016     for( i = 0; i < sample_count; i++ )
1017         alpha[i] *= y[i]*inv_r;
1018 
1019     _si.rho *= inv_r;
1020     _si.obj *= (inv_r*inv_r);
1021     _si.upper_bound_p = inv_r;
1022     _si.upper_bound_n = inv_r;
1023 
1024     return true;
1025 }
1026 
1027 
solve_one_class(int _sample_count,int _var_count,const float ** _samples,CvMemStorage * _storage,CvSVMKernel * _kernel,double * _alpha,CvSVMSolutionInfo & _si)1028 bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
1029                                    CvMemStorage* _storage, CvSVMKernel* _kernel,
1030                                    double* _alpha, CvSVMSolutionInfo& _si )
1031 {
1032     int i, n;
1033     double nu = _kernel->params->nu;
1034 
1035     if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
1036                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
1037                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1038         return false;
1039 
1040     y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
1041     n = cvRound( nu*sample_count );
1042 
1043     for( i = 0; i < sample_count; i++ )
1044     {
1045         y[i] = 1;
1046         b[i] = 0;
1047         alpha[i] = i < n ? 1 : 0;
1048     }
1049 
1050     if( n < sample_count )
1051         alpha[n] = nu * sample_count - n;
1052     else
1053         alpha[n-1] = nu * sample_count - (n-1);
1054 
1055     return solve_generic(_si);
1056 }
1057 
1058 
solve_eps_svr(int _sample_count,int _var_count,const float ** _samples,const float * _y,CvMemStorage * _storage,CvSVMKernel * _kernel,double * _alpha,CvSVMSolutionInfo & _si)1059 bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
1060                                  const float* _y, CvMemStorage* _storage,
1061                                  CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1062 {
1063     int i;
1064     double p = _kernel->params->p, C = _kernel->params->C;
1065 
1066     if( !create( _sample_count, _var_count, _samples, 0,
1067                  _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
1068                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
1069         return false;
1070 
1071     y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1072     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1073 
1074     for( i = 0; i < sample_count; i++ )
1075     {
1076         alpha[i] = 0;
1077         b[i] = p - _y[i];
1078         y[i] = 1;
1079 
1080         alpha[i+sample_count] = 0;
1081         b[i+sample_count] = p + _y[i];
1082         y[i+sample_count] = -1;
1083     }
1084 
1085     if( !solve_generic( _si ))
1086         return false;
1087 
1088     for( i = 0; i < sample_count; i++ )
1089         _alpha[i] = alpha[i] - alpha[i+sample_count];
1090 
1091     return true;
1092 }
1093 
1094 
solve_nu_svr(int _sample_count,int _var_count,const float ** _samples,const float * _y,CvMemStorage * _storage,CvSVMKernel * _kernel,double * _alpha,CvSVMSolutionInfo & _si)1095 bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
1096                                 const float* _y, CvMemStorage* _storage,
1097                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
1098 {
1099     int i;
1100     double C = _kernel->params->C, sum;
1101 
1102     if( !create( _sample_count, _var_count, _samples, 0,
1103                  _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
1104                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
1105         return false;
1106 
1107     y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
1108     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
1109     sum = C * _kernel->params->nu * sample_count * 0.5;
1110 
1111     for( i = 0; i < sample_count; i++ )
1112     {
1113         alpha[i] = alpha[i + sample_count] = MIN(sum, C);
1114         sum -= alpha[i];
1115 
1116         b[i] = -_y[i];
1117         y[i] = 1;
1118 
1119         b[i + sample_count] = _y[i];
1120         y[i + sample_count] = -1;
1121     }
1122 
1123     if( !solve_generic( _si ))
1124         return false;
1125 
1126     for( i = 0; i < sample_count; i++ )
1127         _alpha[i] = alpha[i] - alpha[i+sample_count];
1128 
1129     return true;
1130 }
1131 
1132 
1133 //////////////////////////////////////////////////////////////////////////////////////////
1134 
CvSVM()1135 CvSVM::CvSVM()
1136 {
1137     decision_func = 0;
1138     class_labels = 0;
1139     class_weights = 0;
1140     storage = 0;
1141     var_idx = 0;
1142     kernel = 0;
1143     solver = 0;
1144     default_model_name = "my_svm";
1145 
1146     clear();
1147 }
1148 
1149 
~CvSVM()1150 CvSVM::~CvSVM()
1151 {
1152     clear();
1153 }
1154 
1155 
clear()1156 void CvSVM::clear()
1157 {
1158     cvFree( &decision_func );
1159     cvReleaseMat( &class_labels );
1160     cvReleaseMat( &class_weights );
1161     cvReleaseMemStorage( &storage );
1162     cvReleaseMat( &var_idx );
1163     delete kernel;
1164     delete solver;
1165     kernel = 0;
1166     solver = 0;
1167     var_all = 0;
1168     sv = 0;
1169     sv_total = 0;
1170 }
1171 
1172 
CvSVM(const CvMat * _train_data,const CvMat * _responses,const CvMat * _var_idx,const CvMat * _sample_idx,CvSVMParams _params)1173 CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
1174     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1175 {
1176     decision_func = 0;
1177     class_labels = 0;
1178     class_weights = 0;
1179     storage = 0;
1180     var_idx = 0;
1181     kernel = 0;
1182     solver = 0;
1183     default_model_name = "my_svm";
1184 
1185     train( _train_data, _responses, _var_idx, _sample_idx, _params );
1186 }
1187 
1188 
get_support_vector_count() const1189 int CvSVM::get_support_vector_count() const
1190 {
1191     return sv_total;
1192 }
1193 
1194 
get_support_vector(int i) const1195 const float* CvSVM::get_support_vector(int i) const
1196 {
1197     return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
1198 }
1199 
1200 
set_params(const CvSVMParams & _params)1201 bool CvSVM::set_params( const CvSVMParams& _params )
1202 {
1203     bool ok = false;
1204 
1205     CV_FUNCNAME( "CvSVM::set_params" );
1206 
1207     __BEGIN__;
1208 
1209     int kernel_type, svm_type;
1210 
1211     params = _params;
1212 
1213     kernel_type = params.kernel_type;
1214     svm_type = params.svm_type;
1215 
1216     if( kernel_type != LINEAR && kernel_type != POLY &&
1217         kernel_type != SIGMOID && kernel_type != RBF )
1218         CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
1219 
1220     if( kernel_type == LINEAR )
1221         params.gamma = 1;
1222     else if( params.gamma <= 0 )
1223         CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
1224 
1225     if( kernel_type != SIGMOID && kernel_type != POLY )
1226         params.coef0 = 0;
1227     else if( params.coef0 < 0 )
1228         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
1229 
1230     if( kernel_type != POLY )
1231         params.degree = 0;
1232     else if( params.degree <= 0 )
1233         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
1234 
1235     if( svm_type != C_SVC && svm_type != NU_SVC &&
1236         svm_type != ONE_CLASS && svm_type != EPS_SVR &&
1237         svm_type != NU_SVR )
1238         CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
1239 
1240     if( svm_type == ONE_CLASS || svm_type == NU_SVC )
1241         params.C = 0;
1242     else if( params.C <= 0 )
1243         CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
1244 
1245     if( svm_type == C_SVC || svm_type == EPS_SVR )
1246         params.nu = 0;
1247     else if( params.nu <= 0 || params.nu >= 1 )
1248         CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
1249 
1250     if( svm_type != EPS_SVR )
1251         params.p = 0;
1252     else if( params.p <= 0 )
1253         CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
1254 
1255     if( svm_type != C_SVC )
1256         params.class_weights = 0;
1257 
1258     params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
1259     params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
1260     ok = true;
1261 
1262     __END__;
1263 
1264     return ok;
1265 }
1266 
1267 
1268 
create_kernel()1269 void CvSVM::create_kernel()
1270 {
1271     kernel = new CvSVMKernel(&params,0);
1272 }
1273 
1274 
create_solver()1275 void CvSVM::create_solver( )
1276 {
1277     solver = new CvSVMSolver;
1278 }
1279 
1280 
1281 // switching function
train1(int sample_count,int var_count,const float ** samples,const void * _responses,double Cp,double Cn,CvMemStorage * _storage,double * alpha,double & rho)1282 bool CvSVM::train1( int sample_count, int var_count, const float** samples,
1283                     const void* _responses, double Cp, double Cn,
1284                     CvMemStorage* _storage, double* alpha, double& rho )
1285 {
1286     bool ok = false;
1287 
1288     //CV_FUNCNAME( "CvSVM::train1" );
1289 
1290     __BEGIN__;
1291 
1292     CvSVMSolutionInfo si;
1293     int svm_type = params.svm_type;
1294 
1295     si.rho = 0;
1296 
1297     ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
1298                                                   Cp, Cn, _storage, kernel, alpha, si ) :
1299          svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
1300                                                     _storage, kernel, alpha, si ) :
1301          svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
1302                                                           _storage, kernel, alpha, si ) :
1303          svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
1304                                                       _storage, kernel, alpha, si ) :
1305          svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
1306                                                     _storage, kernel, alpha, si ) : false;
1307 
1308     rho = si.rho;
1309 
1310     __END__;
1311 
1312     return ok;
1313 }
1314 
1315 
do_train(int svm_type,int sample_count,int var_count,const float ** samples,const CvMat * responses,CvMemStorage * temp_storage,double * alpha)1316 bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
1317                     const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
1318 {
1319     bool ok = false;
1320 
1321     CV_FUNCNAME( "CvSVM::do_train" );
1322 
1323     __BEGIN__;
1324 
1325     CvSVMDecisionFunc* df = 0;
1326     const int sample_size = var_count*sizeof(samples[0][0]);
1327     int i, j, k;
1328 
1329     if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
1330     {
1331         int sv_count = 0;
1332 
1333         CV_CALL( decision_func = df =
1334             (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
1335 
1336         df->rho = 0;
1337         if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
1338             responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
1339             EXIT;
1340 
1341         for( i = 0; i < sample_count; i++ )
1342             sv_count += fabs(alpha[i]) > 0;
1343 
1344         sv_total = df->sv_count = sv_count;
1345         CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
1346         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
1347 
1348         for( i = k = 0; i < sample_count; i++ )
1349         {
1350             if( fabs(alpha[i]) > 0 )
1351             {
1352                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1353                 memcpy( sv[k], samples[i], sample_size );
1354                 df->alpha[k++] = alpha[i];
1355             }
1356         }
1357     }
1358     else
1359     {
1360         int class_count = class_labels->cols;
1361         int* sv_tab = 0;
1362         const float** temp_samples = 0;
1363         int* class_ranges = 0;
1364         schar* temp_y = 0;
1365         assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
1366 
1367         if( svm_type == CvSVM::C_SVC && params.class_weights )
1368         {
1369             const CvMat* cw = params.class_weights;
1370 
1371             if( !CV_IS_MAT(cw) || cw->cols != 1 && cw->rows != 1 ||
1372                 cw->rows + cw->cols - 1 != class_count ||
1373                 CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1 )
1374                 CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
1375                     "containing as many elements as the number of classes" );
1376 
1377             CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
1378             CV_CALL( cvConvert( cw, class_weights ));
1379             CV_CALL( cvScale( class_weights, class_weights, params.C ));
1380         }
1381 
1382         CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
1383             (class_count*(class_count-1)/2)*sizeof(df[0])));
1384 
1385         CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
1386         memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
1387         CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
1388                             (class_count + 1)*sizeof(class_ranges[0])));
1389         CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
1390                             sample_count*sizeof(temp_samples[0])));
1391         CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));
1392 
1393         class_ranges[class_count] = 0;
1394         cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
1395         //check that while cross-validation there were the samples from all the classes
1396         if( class_ranges[class_count] <= 0 )
1397             CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
1398             "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
1399 
1400         if( svm_type == NU_SVC )
1401         {
1402             // check if nu is feasible
1403             for(i = 0; i < class_count; i++ )
1404             {
1405                 int ci = class_ranges[i+1] - class_ranges[i];
1406                 for( j = i+1; j< class_count; j++ )
1407                 {
1408                     int cj = class_ranges[j+1] - class_ranges[j];
1409                     if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
1410                     {
1411                         // !!!TODO!!! add some diagnostic
1412                         EXIT; // exit immediately; will release the model and return NULL pointer
1413                     }
1414                 }
1415             }
1416         }
1417 
1418         // train n*(n-1)/2 classifiers
1419         for( i = 0; i < class_count; i++ )
1420         {
1421             for( j = i+1; j < class_count; j++, df++ )
1422             {
1423                 int si = class_ranges[i], ci = class_ranges[i+1] - si;
1424                 int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
1425                 double Cp = params.C, Cn = Cp;
1426                 int k1 = 0, sv_count = 0;
1427 
1428                 for( k = 0; k < ci; k++ )
1429                 {
1430                     temp_samples[k] = samples[si + k];
1431                     temp_y[k] = 1;
1432                 }
1433 
1434                 for( k = 0; k < cj; k++ )
1435                 {
1436                     temp_samples[ci + k] = samples[sj + k];
1437                     temp_y[ci + k] = -1;
1438                 }
1439 
1440                 if( class_weights )
1441                 {
1442                     Cp = class_weights->data.db[i];
1443                     Cn = class_weights->data.db[j];
1444                 }
1445 
1446                 if( !train1( ci + cj, var_count, temp_samples, temp_y,
1447                              Cp, Cn, temp_storage, alpha, df->rho ))
1448                     EXIT;
1449 
1450                 for( k = 0; k < ci + cj; k++ )
1451                     sv_count += fabs(alpha[k]) > 0;
1452 
1453                 df->sv_count = sv_count;
1454 
1455                 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
1456                                                 sv_count*sizeof(df->alpha[0])));
1457                 CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
1458                                                 sv_count*sizeof(df->sv_index[0])));
1459 
1460                 for( k = 0; k < ci; k++ )
1461                 {
1462                     if( fabs(alpha[k]) > 0 )
1463                     {
1464                         sv_tab[si + k] = 1;
1465                         df->sv_index[k1] = si + k;
1466                         df->alpha[k1++] = alpha[k];
1467                     }
1468                 }
1469 
1470                 for( k = 0; k < cj; k++ )
1471                 {
1472                     if( fabs(alpha[ci + k]) > 0 )
1473                     {
1474                         sv_tab[sj + k] = 1;
1475                         df->sv_index[k1] = sj + k;
1476                         df->alpha[k1++] = alpha[ci + k];
1477                     }
1478                 }
1479             }
1480         }
1481 
1482         // allocate support vectors and initialize sv_tab
1483         for( i = 0, k = 0; i < sample_count; i++ )
1484         {
1485             if( sv_tab[i] )
1486                 sv_tab[i] = ++k;
1487         }
1488 
1489         sv_total = k;
1490         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
1491 
1492         for( i = 0, k = 0; i < sample_count; i++ )
1493         {
1494             if( sv_tab[i] )
1495             {
1496                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
1497                 memcpy( sv[k], samples[i], sample_size );
1498                 k++;
1499             }
1500         }
1501 
1502         df = (CvSVMDecisionFunc*)decision_func;
1503 
1504         // set sv pointers
1505         for( i = 0; i < class_count; i++ )
1506         {
1507             for( j = i+1; j < class_count; j++, df++ )
1508             {
1509                 for( k = 0; k < df->sv_count; k++ )
1510                 {
1511                     df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
1512                     assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
1513                 }
1514             }
1515         }
1516     }
1517 
1518     ok = true;
1519 
1520     __END__;
1521 
1522     return ok;
1523 }
1524 
train(const CvMat * _train_data,const CvMat * _responses,const CvMat * _var_idx,const CvMat * _sample_idx,CvSVMParams _params)1525 bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
1526     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
1527 {
1528     bool ok = false;
1529     CvMat* responses = 0;
1530     CvMemStorage* temp_storage = 0;
1531     const float** samples = 0;
1532 
1533     CV_FUNCNAME( "CvSVM::train" );
1534 
1535     __BEGIN__;
1536 
1537     int svm_type, sample_count, var_count, sample_size;
1538     int block_size = 1 << 16;
1539     double* alpha;
1540 
1541     clear();
1542     CV_CALL( set_params( _params ));
1543 
1544     svm_type = _params.svm_type;
1545 
1546     /* Prepare training data and related parameters */
1547     CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
1548                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1549                                  svm_type == CvSVM::C_SVC ||
1550                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1551                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
1552                                  false, &samples, &sample_count, &var_count, &var_all,
1553                                  &responses, &class_labels, &var_idx ));
1554 
1555 
1556     sample_size = var_count*sizeof(samples[0][0]);
1557 
1558     // make the storage block size large enough to fit all
1559     // the temporary vectors and output support vectors.
1560     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1561     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1562     block_size = MAX( block_size, sample_size*2 + 1024 );
1563 
1564     CV_CALL( storage = cvCreateMemStorage(block_size));
1565     CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
1566     CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1567 
1568     create_kernel();
1569     create_solver();
1570 
1571     if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
1572         EXIT;
1573 
1574     ok = true; // model has been trained succesfully
1575 
1576     __END__;
1577 
1578     delete solver;
1579     solver = 0;
1580     cvReleaseMemStorage( &temp_storage );
1581     cvReleaseMat( &responses );
1582     cvFree( &samples );
1583 
1584     if( cvGetErrStatus() < 0 || !ok )
1585         clear();
1586 
1587     return ok;
1588 }
1589 
train_auto(const CvMat * _train_data,const CvMat * _responses,const CvMat * _var_idx,const CvMat * _sample_idx,CvSVMParams _params,int k_fold,CvParamGrid C_grid,CvParamGrid gamma_grid,CvParamGrid p_grid,CvParamGrid nu_grid,CvParamGrid coef_grid,CvParamGrid degree_grid)1590 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
1591     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
1592     CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
1593     CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
1594 {
1595     bool ok = false;
1596     CvMat* responses = 0;
1597     CvMat* responses_local = 0;
1598     CvMemStorage* temp_storage = 0;
1599     const float** samples = 0;
1600     const float** samples_local = 0;
1601 
1602     CV_FUNCNAME( "CvSVM::train_auto" );
1603     __BEGIN__;
1604 
1605     int svm_type, sample_count, var_count, sample_size;
1606     int block_size = 1 << 16;
1607     double* alpha;
1608     int i, k;
1609     CvRNG rng = cvRNG(-1);
1610 
1611     // all steps are logarithmic and must be > 1
1612     double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
1613     double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
1614     double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
1615     float min_error = FLT_MAX, error;
1616 
1617     if( _params.svm_type == CvSVM::ONE_CLASS )
1618     {
1619         if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
1620             EXIT;
1621         return true;
1622     }
1623 
1624     clear();
1625 
1626     if( k_fold < 2 )
1627         CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
1628 
1629     CV_CALL(set_params( _params ));
1630     svm_type = _params.svm_type;
1631 
1632     // All the parameters except, possibly, <coef0> are positive.
1633     // <coef0> is nonnegative
1634     if( C_grid.step <= 1 )
1635     {
1636         C_grid.min_val = C_grid.max_val = params.C;
1637         C_grid.step = 10;
1638     }
1639     else
1640         CV_CALL(C_grid.check());
1641 
1642     if( gamma_grid.step <= 1 )
1643     {
1644         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1645         gamma_grid.step = 10;
1646     }
1647     else
1648         CV_CALL(gamma_grid.check());
1649 
1650     if( p_grid.step <= 1 )
1651     {
1652         p_grid.min_val = p_grid.max_val = params.p;
1653         p_grid.step = 10;
1654     }
1655     else
1656         CV_CALL(p_grid.check());
1657 
1658     if( nu_grid.step <= 1 )
1659     {
1660         nu_grid.min_val = nu_grid.max_val = params.nu;
1661         nu_grid.step = 10;
1662     }
1663     else
1664         CV_CALL(nu_grid.check());
1665 
1666     if( coef_grid.step <= 1 )
1667     {
1668         coef_grid.min_val = coef_grid.max_val = params.coef0;
1669         coef_grid.step = 10;
1670     }
1671     else
1672         CV_CALL(coef_grid.check());
1673 
1674     if( degree_grid.step <= 1 )
1675     {
1676         degree_grid.min_val = degree_grid.max_val = params.degree;
1677         degree_grid.step = 10;
1678     }
1679     else
1680         CV_CALL(degree_grid.check());
1681 
1682     // these parameters are not used:
1683     if( params.kernel_type != CvSVM::POLY )
1684         degree_grid.min_val = degree_grid.max_val = params.degree;
1685     if( params.kernel_type == CvSVM::LINEAR )
1686         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
1687     if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
1688         coef_grid.min_val = coef_grid.max_val = params.coef0;
1689     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
1690         C_grid.min_val = C_grid.max_val = params.C;
1691     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
1692         nu_grid.min_val = nu_grid.max_val = params.nu;
1693     if( svm_type != CvSVM::EPS_SVR )
1694         p_grid.min_val = p_grid.max_val = params.p;
1695 
1696     CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
1697     CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
1698 
1699     /* Prepare training data and related parameters */
1700     CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
1701                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
1702                                  svm_type == CvSVM::C_SVC ||
1703                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
1704                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
1705                                  false, &samples, &sample_count, &var_count, &var_all,
1706                                  &responses, &class_labels, &var_idx ));
1707 
1708     sample_size = var_count*sizeof(samples[0][0]);
1709 
1710     // make the storage block size large enough to fit all
1711     // the temporary vectors and output support vectors.
1712     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
1713     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
1714     block_size = MAX( block_size, sample_size*2 + 1024 );
1715 
1716     CV_CALL(storage = cvCreateMemStorage(block_size));
1717     CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
1718     CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
1719 
1720     create_kernel();
1721     create_solver();
1722 
1723     {
1724     const int testset_size = sample_count/k_fold;
1725     const int trainset_size = sample_count - testset_size;
1726     const int last_testset_size = sample_count - testset_size*(k_fold-1);
1727     const int last_trainset_size = sample_count - last_testset_size;
1728     const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
1729 
1730     size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
1731     size_t size = 2*last_trainset_size*sizeof(samples[0]);
1732 
1733     samples_local = (const float**) cvAlloc( size );
1734     memset( samples_local, 0, size );
1735 
1736     responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
1737     cvZero( responses_local );
1738 
1739     // randomly permute samples and responses
1740     for( i = 0; i < sample_count; i++ )
1741     {
1742         int i1 = cvRandInt( &rng ) % sample_count;
1743         int i2 = cvRandInt( &rng ) % sample_count;
1744         const float* temp;
1745         float t;
1746         int y;
1747 
1748         CV_SWAP( samples[i1], samples[i2], temp );
1749         if( is_regression )
1750             CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
1751         else
1752             CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
1753     }
1754 
1755     C = C_grid.min_val;
1756     do
1757     {
1758       params.C = C;
1759       gamma = gamma_grid.min_val;
1760       do
1761       {
1762         params.gamma = gamma;
1763         p = p_grid.min_val;
1764         do
1765         {
1766           params.p = p;
1767           nu = nu_grid.min_val;
1768           do
1769           {
1770             params.nu = nu;
1771             coef = coef_grid.min_val;
1772             do
1773             {
1774               params.coef0 = coef;
1775               degree = degree_grid.min_val;
1776               do
1777               {
1778                 params.degree = degree;
1779 
1780                 float** test_samples_ptr = (float**)samples;
1781                 uchar* true_resp = responses->data.ptr;
1782                 int test_size = testset_size;
1783                 int train_size = trainset_size;
1784 
1785                 error = 0;
1786                 for( k = 0; k < k_fold; k++ )
1787                 {
1788                     memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
1789                     memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
1790                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1791 
1792                     memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
1793                     memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
1794                         true_resp + resp_elem_size*test_size,
1795                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
1796 
1797                     if( k == k_fold - 1 )
1798                     {
1799                         test_size = last_testset_size;
1800                         train_size = last_trainset_size;
1801                         responses_local->cols = last_trainset_size;
1802                     }
1803 
1804                     // Train SVM on <train_size> samples
1805                     if( !do_train( svm_type, train_size, var_count,
1806                         (const float**)samples_local, responses_local, temp_storage, alpha ) )
1807                         EXIT;
1808 
1809                     // Compute test set error on <test_size> samples
1810                     CvMat s = cvMat( 1, var_count, CV_32FC1 );
1811                     for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
1812                     {
1813                         float resp;
1814                         s.data.fl = *test_samples_ptr;
1815                         resp = predict( &s );
1816                         error += is_regression ? powf( resp - *(float*)true_resp, 2 )
1817                             : ((int)resp != *(int*)true_resp);
1818                     }
1819                 }
1820                 if( min_error > error )
1821                 {
1822                     min_error   = error;
1823                     best_degree = degree;
1824                     best_gamma  = gamma;
1825                     best_coef   = coef;
1826                     best_C      = C;
1827                     best_nu     = nu;
1828                     best_p      = p;
1829                 }
1830                 degree *= degree_grid.step;
1831               }
1832               while( degree < degree_grid.max_val );
1833               coef *= coef_grid.step;
1834             }
1835             while( coef < coef_grid.max_val );
1836             nu *= nu_grid.step;
1837           }
1838           while( nu < nu_grid.max_val );
1839           p *= p_grid.step;
1840         }
1841         while( p < p_grid.max_val );
1842         gamma *= gamma_grid.step;
1843       }
1844       while( gamma < gamma_grid.max_val );
1845       C *= C_grid.step;
1846     }
1847     while( C < C_grid.max_val );
1848     }
1849 
1850     min_error /= (float) sample_count;
1851 
1852     params.C      = best_C;
1853     params.nu     = best_nu;
1854     params.p      = best_p;
1855     params.gamma  = best_gamma;
1856     params.degree = best_degree;
1857     params.coef0  = best_coef;
1858 
1859     CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
1860 
1861     __END__;
1862 
1863     delete solver;
1864     solver = 0;
1865     cvReleaseMemStorage( &temp_storage );
1866     cvReleaseMat( &responses );
1867     cvReleaseMat( &responses_local );
1868     cvFree( &samples );
1869     cvFree( &samples_local );
1870 
1871     if( cvGetErrStatus() < 0 || !ok )
1872         clear();
1873 
1874     return ok;
1875 }
1876 
predict(const CvMat * sample) const1877 float CvSVM::predict( const CvMat* sample ) const
1878 {
1879     bool local_alloc = 0;
1880     float result = 0;
1881     float* row_sample = 0;
1882     Qfloat* buffer = 0;
1883 
1884     CV_FUNCNAME( "CvSVM::predict" );
1885 
1886     __BEGIN__;
1887 
1888     int class_count;
1889     int var_count, buf_sz;
1890 
1891     if( !kernel )
1892         CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
1893 
1894     class_count = class_labels ? class_labels->cols :
1895                   params.svm_type == ONE_CLASS ? 1 : 0;
1896 
1897     CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
1898                                    class_count, 0, &row_sample ));
1899 
1900     var_count = get_var_count();
1901 
1902     buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
1903     if( buf_sz <= CV_MAX_LOCAL_SIZE )
1904     {
1905         CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
1906         local_alloc = 1;
1907     }
1908     else
1909         CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
1910 
1911     if( params.svm_type == EPS_SVR ||
1912         params.svm_type == NU_SVR ||
1913         params.svm_type == ONE_CLASS )
1914     {
1915         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1916         int i, sv_count = df->sv_count;
1917         double sum = -df->rho;
1918 
1919         kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
1920         for( i = 0; i < sv_count; i++ )
1921             sum += buffer[i]*df->alpha[i];
1922 
1923         result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
1924     }
1925     else if( params.svm_type == C_SVC ||
1926              params.svm_type == NU_SVC )
1927     {
1928         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
1929         int* vote = (int*)(buffer + sv_total);
1930         int i, j, k;
1931 
1932         memset( vote, 0, class_count*sizeof(vote[0]));
1933         kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
1934 
1935         for( i = 0; i < class_count; i++ )
1936         {
1937             for( j = i+1; j < class_count; j++, df++ )
1938             {
1939                 double sum = -df->rho;
1940                 int sv_count = df->sv_count;
1941                 for( k = 0; k < sv_count; k++ )
1942                     sum += df->alpha[k]*buffer[df->sv_index[k]];
1943 
1944                 vote[sum > 0 ? i : j]++;
1945             }
1946         }
1947 
1948         for( i = 1, k = 0; i < class_count; i++ )
1949         {
1950             if( vote[i] > vote[k] )
1951                 k = i;
1952         }
1953 
1954         result = (float)(class_labels->data.i[k]);
1955     }
1956     else
1957         CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
1958                                 "the SVM structure is probably corrupted" );
1959 
1960     __END__;
1961 
1962     if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
1963         cvFree( &row_sample );
1964 
1965     if( !local_alloc )
1966         cvFree( &buffer );
1967 
1968     return result;
1969 }
1970 
1971 
write_params(CvFileStorage * fs)1972 void CvSVM::write_params( CvFileStorage* fs )
1973 {
1974     //CV_FUNCNAME( "CvSVM::write_params" );
1975 
1976     __BEGIN__;
1977 
1978     int svm_type = params.svm_type;
1979     int kernel_type = params.kernel_type;
1980 
1981     const char* svm_type_str =
1982         svm_type == CvSVM::C_SVC ? "C_SVC" :
1983         svm_type == CvSVM::NU_SVC ? "NU_SVC" :
1984         svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
1985         svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
1986         svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
1987     const char* kernel_type_str =
1988         kernel_type == CvSVM::LINEAR ? "LINEAR" :
1989         kernel_type == CvSVM::POLY ? "POLY" :
1990         kernel_type == CvSVM::RBF ? "RBF" :
1991         kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
1992 
1993     if( svm_type_str )
1994         cvWriteString( fs, "svm_type", svm_type_str );
1995     else
1996         cvWriteInt( fs, "svm_type", svm_type );
1997 
1998     // save kernel
1999     cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
2000 
2001     if( kernel_type_str )
2002         cvWriteString( fs, "type", kernel_type_str );
2003     else
2004         cvWriteInt( fs, "type", kernel_type );
2005 
2006     if( kernel_type == CvSVM::POLY || !kernel_type_str )
2007         cvWriteReal( fs, "degree", params.degree );
2008 
2009     if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
2010         cvWriteReal( fs, "gamma", params.gamma );
2011 
2012     if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
2013         cvWriteReal( fs, "coef0", params.coef0 );
2014 
2015     cvEndWriteStruct(fs);
2016 
2017     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
2018         svm_type == CvSVM::NU_SVR || !svm_type_str )
2019         cvWriteReal( fs, "C", params.C );
2020 
2021     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
2022         svm_type == CvSVM::NU_SVR || !svm_type_str )
2023         cvWriteReal( fs, "nu", params.nu );
2024 
2025     if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
2026         cvWriteReal( fs, "p", params.p );
2027 
2028     cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
2029     if( params.term_crit.type & CV_TERMCRIT_EPS )
2030         cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
2031     if( params.term_crit.type & CV_TERMCRIT_ITER )
2032         cvWriteInt( fs, "iterations", params.term_crit.max_iter );
2033     cvEndWriteStruct( fs );
2034 
2035     __END__;
2036 }
2037 
2038 
write(CvFileStorage * fs,const char * name)2039 void CvSVM::write( CvFileStorage* fs, const char* name )
2040 {
2041     CV_FUNCNAME( "CvSVM::write" );
2042 
2043     __BEGIN__;
2044 
2045     int i, var_count = get_var_count(), df_count, class_count;
2046     const CvSVMDecisionFunc* df = decision_func;
2047 
2048     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
2049 
2050     write_params( fs );
2051 
2052     cvWriteInt( fs, "var_all", var_all );
2053     cvWriteInt( fs, "var_count", var_count );
2054 
2055     class_count = class_labels ? class_labels->cols :
2056                   params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2057 
2058     if( class_count )
2059     {
2060         cvWriteInt( fs, "class_count", class_count );
2061 
2062         if( class_labels )
2063             cvWrite( fs, "class_labels", class_labels );
2064 
2065         if( class_weights )
2066             cvWrite( fs, "class_weights", class_weights );
2067     }
2068 
2069     if( var_idx )
2070         cvWrite( fs, "var_idx", var_idx );
2071 
2072     // write the joint collection of support vectors
2073     cvWriteInt( fs, "sv_total", sv_total );
2074     cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
2075     for( i = 0; i < sv_total; i++ )
2076     {
2077         cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
2078         cvWriteRawData( fs, sv[i], var_count, "f" );
2079         cvEndWriteStruct( fs );
2080     }
2081 
2082     cvEndWriteStruct( fs );
2083 
2084     // write decision functions
2085     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2086     df = decision_func;
2087 
2088     cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
2089     for( i = 0; i < df_count; i++ )
2090     {
2091         int sv_count = df[i].sv_count;
2092         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
2093         cvWriteInt( fs, "sv_count", sv_count );
2094         cvWriteReal( fs, "rho", df[i].rho );
2095         cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
2096         cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
2097         cvEndWriteStruct( fs );
2098         if( class_count > 1 )
2099         {
2100             cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
2101             cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
2102             cvEndWriteStruct( fs );
2103         }
2104         else
2105             CV_ASSERT( sv_count == sv_total );
2106         cvEndWriteStruct( fs );
2107     }
2108     cvEndWriteStruct( fs );
2109     cvEndWriteStruct( fs );
2110 
2111     __END__;
2112 }
2113 
2114 
read_params(CvFileStorage * fs,CvFileNode * svm_node)2115 void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
2116 {
2117     CV_FUNCNAME( "CvSVM::read_params" );
2118 
2119     __BEGIN__;
2120 
2121     int svm_type, kernel_type;
2122     CvSVMParams _params;
2123 
2124     CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
2125     CvFileNode* kernel_node;
2126     if( !tmp_node )
2127         CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
2128 
2129     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2130         svm_type = cvReadInt( tmp_node, -1 );
2131     else
2132     {
2133         const char* svm_type_str = cvReadString( tmp_node, "" );
2134         svm_type =
2135             strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
2136             strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
2137             strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
2138             strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
2139             strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
2140 
2141         if( svm_type < 0 )
2142             CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
2143     }
2144 
2145     kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
2146     if( !kernel_node )
2147         CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
2148 
2149     tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
2150     if( !tmp_node )
2151         CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
2152 
2153     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
2154         kernel_type = cvReadInt( tmp_node, -1 );
2155     else
2156     {
2157         const char* kernel_type_str = cvReadString( tmp_node, "" );
2158         kernel_type =
2159             strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
2160             strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
2161             strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
2162             strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
2163 
2164         if( kernel_type < 0 )
2165             CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
2166     }
2167 
2168     _params.svm_type = svm_type;
2169     _params.kernel_type = kernel_type;
2170     _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
2171     _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
2172     _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
2173 
2174     _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
2175     _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
2176     _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
2177     _params.class_weights = 0;
2178 
2179     tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
2180     if( tmp_node )
2181     {
2182         _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
2183         _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
2184         _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
2185                                (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
2186     }
2187     else
2188         _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
2189 
2190     set_params( _params );
2191 
2192     __END__;
2193 }
2194 
2195 
read(CvFileStorage * fs,CvFileNode * svm_node)2196 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
2197 {
2198     const double not_found_dbl = DBL_MAX;
2199 
2200     CV_FUNCNAME( "CvSVM::read" );
2201 
2202     __BEGIN__;
2203 
2204     int i, var_count, df_count, class_count;
2205     int block_size = 1 << 16, sv_size;
2206     CvFileNode *sv_node, *df_node;
2207     CvSVMDecisionFunc* df;
2208     CvSeqReader reader;
2209 
2210     if( !svm_node )
2211         CV_ERROR( CV_StsParseError, "The requested element is not found" );
2212 
2213     clear();
2214 
2215     // read SVM parameters
2216     read_params( fs, svm_node );
2217 
2218     // and top-level data
2219     sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
2220     var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
2221     var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
2222     class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
2223 
2224     if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
2225         CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
2226 
2227     CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
2228     CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
2229     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "comp_idx" ));
2230 
2231     if( class_count > 1 && (!class_labels ||
2232         !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
2233         CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
2234 
2235     if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
2236         CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
2237 
2238     // read support vectors
2239     sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
2240     if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
2241         CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
2242 
2243     block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
2244     block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
2245     block_size = MAX( block_size, var_all*(int)sizeof(double));
2246     CV_CALL( storage = cvCreateMemStorage( block_size ));
2247     CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
2248                                 sv_total*sizeof(sv[0]) ));
2249 
2250     CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
2251     sv_size = var_count*sizeof(sv[0][0]);
2252 
2253     for( i = 0; i < sv_total; i++ )
2254     {
2255         CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
2256         CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
2257                    sv_elem->data.seq->total == var_count) );
2258 
2259         CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
2260         CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
2261         CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
2262     }
2263 
2264     // read decision functions
2265     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2266     df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
2267     if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
2268         df_node->data.seq->total != df_count )
2269         CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
2270                   "or has a wrong number of elements" );
2271 
2272     CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
2273     cvStartReadSeq( df_node->data.seq, &reader, 0 );
2274 
2275     for( i = 0; i < df_count; i++ )
2276     {
2277         CvFileNode* df_elem = (CvFileNode*)reader.ptr;
2278         CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
2279 
2280         int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
2281         if( sv_count <= 0 )
2282             CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
2283         df[i].sv_count = sv_count;
2284 
2285         df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
2286         if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
2287             CV_ERROR( CV_StsParseError, "rho is missing" );
2288 
2289         if( !alpha_node )
2290             CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
2291 
2292         CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
2293                                         sv_count*sizeof(df[i].alpha[0])));
2294         CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(alpha_node->tag) &&
2295                    alpha_node->data.seq->total == sv_count );
2296         CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
2297 
2298         if( class_count > 1 )
2299         {
2300             CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
2301             if( !index_node )
2302                 CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
2303             CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
2304                                             sv_count*sizeof(df[i].sv_index[0])));
2305             CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(index_node->tag) &&
2306                    index_node->data.seq->total == sv_count );
2307             CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
2308         }
2309         else
2310             df[i].sv_index = 0;
2311 
2312         CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
2313     }
2314 
2315     create_kernel();
2316 
2317     __END__;
2318 }
2319 
2320 #if 0
2321 
2322 static void*
2323 icvCloneSVM( const void* _src )
2324 {
2325     CvSVMModel* dst = 0;
2326 
2327     CV_FUNCNAME( "icvCloneSVM" );
2328 
2329     __BEGIN__;
2330 
2331     const CvSVMModel* src = (const CvSVMModel*)_src;
2332     int var_count, class_count;
2333     int i, sv_total, df_count;
2334     int sv_size;
2335 
2336     if( !CV_IS_SVM(src) )
2337         CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
2338 
2339     // 0. create initial CvSVMModel structure
2340     CV_CALL( dst = icvCreateSVM() );
2341     dst->params = src->params;
2342     dst->params.weight_labels = 0;
2343     dst->params.weights = 0;
2344 
2345     dst->var_all = src->var_all;
2346     if( src->class_labels )
2347         dst->class_labels = cvCloneMat( src->class_labels );
2348     if( src->class_weights )
2349         dst->class_weights = cvCloneMat( src->class_weights );
2350     if( src->comp_idx )
2351         dst->comp_idx = cvCloneMat( src->comp_idx );
2352 
2353     var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
2354     class_count = src->class_labels ? src->class_labels->cols :
2355                   src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
2356     sv_total = dst->sv_total = src->sv_total;
2357     CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
2358     CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
2359                                     sv_total*sizeof(dst->sv[0]) ));
2360 
2361     sv_size = var_count*sizeof(dst->sv[0][0]);
2362 
2363     for( i = 0; i < sv_total; i++ )
2364     {
2365         CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
2366         memcpy( dst->sv[i], src->sv[i], sv_size );
2367     }
2368 
2369     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
2370 
2371     CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
2372 
2373     for( i = 0; i < df_count; i++ )
2374     {
2375         const CvSVMDecisionFunc *sdf =
2376             (const CvSVMDecisionFunc*)src->decision_func+i;
2377         CvSVMDecisionFunc *ddf =
2378             (CvSVMDecisionFunc*)dst->decision_func+i;
2379         int sv_count = sdf->sv_count;
2380         ddf->sv_count = sv_count;
2381         ddf->rho = sdf->rho;
2382         CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
2383                                         sv_count*sizeof(ddf->alpha[0])));
2384         memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
2385 
2386         if( class_count > 1 )
2387         {
2388             CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
2389                                                 sv_count*sizeof(ddf->sv_index[0])));
2390             memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
2391         }
2392         else
2393             ddf->sv_index = 0;
2394     }
2395 
2396     __END__;
2397 
2398     if( cvGetErrStatus() < 0 && dst )
2399         icvReleaseSVM( &dst );
2400 
2401     return dst;
2402 }
2403 
2404 static int icvRegisterSVMType()
2405 {
2406     CvTypeInfo info;
2407     memset( &info, 0, sizeof(info) );
2408 
2409     info.flags = 0;
2410     info.header_size = sizeof( info );
2411     info.is_instance = icvIsSVM;
2412     info.release = (CvReleaseFunc)icvReleaseSVM;
2413     info.read = icvReadSVM;
2414     info.write = icvWriteSVM;
2415     info.clone = icvCloneSVM;
2416     info.type_name = CV_TYPE_NAME_ML_SVM;
2417     cvRegisterType( &info );
2418 
2419     return 1;
2420 }
2421 
2422 
2423 static int svm = icvRegisterSVMType();
2424 
/* The function trains an SVM model with optimal parameters obtained by cross-validation.
The parameters to be estimated should be indicated by setting their values to FLT_MAX.
The optimal parameters are saved in <model_params>. */
2428 CV_IMPL CvStatModel*
2429 cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
2430             const CvMat* responses,
2431             CvStatModelParams* model_params,
2432             const CvStatModelParams* cross_valid_params,
2433             const CvMat* comp_idx,
2434             const CvMat* sample_idx,
2435             const CvParamGrid* degree_grid,
2436             const CvParamGrid* gamma_grid,
2437             const CvParamGrid* coef_grid,
2438             const CvParamGrid* C_grid,
2439             const CvParamGrid* nu_grid,
2440             const CvParamGrid* p_grid )
2441 {
2442     CvStatModel* svm = 0;
2443 
2444     CV_FUNCNAME("cvTainSVMCrossValidation");
2445     __BEGIN__;
2446 
2447     double degree_step = 7,
2448 	       g_step      = 15,
2449 		   coef_step   = 14,
2450 		   C_step      = 20,
2451 		   nu_step     = 5,
2452 		   p_step      = 7; // all steps must be > 1
2453     double degree_begin = 0.01, degree_end = 2;
2454     double g_begin      = 1e-5, g_end      = 0.5;
2455     double coef_begin   = 0.1,  coef_end   = 300;
2456     double C_begin      = 0.1,  C_end      = 6000;
2457     double nu_begin     = 0.01,  nu_end    = 0.4;
2458     double p_begin      = 0.01, p_end      = 100;
2459 
2460     double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
2461 
2462 	double best_rate    = 0;
2463     double best_degree  = degree_begin;
2464     double best_gamma   = g_begin;
2465     double best_coef    = coef_begin;
2466 	double best_C       = C_begin;
2467 	double best_nu      = nu_begin;
2468     double best_p       = p_begin;
2469 
2470     CvSVMModelParams svm_params, *psvm_params;
2471     CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
2472     int svm_type, kernel;
2473     int is_regression;
2474 
2475     if( !model_params )
2476         CV_ERROR( CV_StsBadArg, "" );
2477     if( !cv_params )
2478         CV_ERROR( CV_StsBadArg, "" );
2479 
2480     svm_params = *(CvSVMModelParams*)model_params;
2481     psvm_params = (CvSVMModelParams*)model_params;
2482     svm_type = svm_params.svm_type;
2483     kernel = svm_params.kernel_type;
2484 
2485     svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
2486     svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
2487     svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
2488     svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
2489     svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
2490     svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
2491 
2492     if( degree_grid )
2493     {
2494         if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
2495               degree_grid->step == 0) )
2496         {
2497             if( degree_grid->min_val > degree_grid->max_val )
2498                 CV_ERROR( CV_StsBadArg,
2499                 "low bound of grid should be less then the upper one");
2500             if( degree_grid->step <= 1 )
2501                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2502             degree_begin = degree_grid->min_val;
2503             degree_end   = degree_grid->max_val;
2504             degree_step  = degree_grid->step;
2505         }
2506     }
2507     else
2508         degree_begin = degree_end = svm_params.degree;
2509 
2510     if( gamma_grid )
2511     {
2512         if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
2513               gamma_grid->step == 0) )
2514         {
2515             if( gamma_grid->min_val > gamma_grid->max_val )
2516                 CV_ERROR( CV_StsBadArg,
2517                 "low bound of grid should be less then the upper one");
2518             if( gamma_grid->step <= 1 )
2519                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2520             g_begin = gamma_grid->min_val;
2521             g_end   = gamma_grid->max_val;
2522             g_step  = gamma_grid->step;
2523         }
2524     }
2525     else
2526         g_begin = g_end = svm_params.gamma;
2527 
2528     if( coef_grid )
2529     {
2530         if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
2531               coef_grid->step == 0) )
2532         {
2533             if( coef_grid->min_val > coef_grid->max_val )
2534                 CV_ERROR( CV_StsBadArg,
2535                 "low bound of grid should be less then the upper one");
2536             if( coef_grid->step <= 1 )
2537                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2538             coef_begin = coef_grid->min_val;
2539             coef_end   = coef_grid->max_val;
2540             coef_step  = coef_grid->step;
2541         }
2542     }
2543     else
2544         coef_begin = coef_end = svm_params.coef0;
2545 
2546     if( C_grid )
2547     {
2548         if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
2549         {
2550             if( C_grid->min_val > C_grid->max_val )
2551                 CV_ERROR( CV_StsBadArg,
2552                 "low bound of grid should be less then the upper one");
2553             if( C_grid->step <= 1 )
2554                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2555             C_begin = C_grid->min_val;
2556             C_end   = C_grid->max_val;
2557             C_step  = C_grid->step;
2558         }
2559     }
2560     else
2561         C_begin = C_end = svm_params.C;
2562 
2563     if( nu_grid )
2564     {
2565         if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
2566         {
2567             if( nu_grid->min_val > nu_grid->max_val )
2568                 CV_ERROR( CV_StsBadArg,
2569                 "low bound of grid should be less then the upper one");
2570             if( nu_grid->step <= 1 )
2571                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2572             nu_begin = nu_grid->min_val;
2573             nu_end   = nu_grid->max_val;
2574             nu_step  = nu_grid->step;
2575         }
2576     }
2577     else
2578         nu_begin = nu_end = svm_params.nu;
2579 
2580     if( p_grid )
2581     {
2582         if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
2583         {
2584             if( p_grid->min_val > p_grid->max_val )
2585                 CV_ERROR( CV_StsBadArg,
2586                 "low bound of grid should be less then the upper one");
2587             if( p_grid->step <= 1 )
2588                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
2589             p_begin = p_grid->min_val;
2590             p_end   = p_grid->max_val;
2591             p_step  = p_grid->step;
2592         }
2593     }
2594     else
2595         p_begin = p_end = svm_params.p;
2596 
2597     // parameters the chosen kernel / SVM type does not use: collapse their range to the single current value so they are not searched
2598     if( kernel != CvSVM::POLY )
2599         degree_begin = degree_end = svm_params.degree;
2600 
2601    if( kernel == CvSVM::LINEAR )
2602         g_begin = g_end = svm_params.gamma;
2603 
2604     if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
2605         coef_begin = coef_end = svm_params.coef0;
2606 
2607     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
2608         C_begin = C_end = svm_params.C;
2609 
2610     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
2611         nu_begin = nu_end = svm_params.nu;
2612 
2613     if( svm_type != CvSVM::EPS_SVR )
2614         p_begin = p_end = svm_params.p;
2615 
2616     is_regression = cv_params->is_regression;
2617     best_rate = is_regression ? FLT_MAX : 0;
2618 
2619     assert( g_step > 1 && degree_step > 1 && coef_step > 1);
2620     assert( p_step > 1 && C_step > 1 && nu_step > 1 );
2621 
2622     for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
2623     {
2624       svm_params.degree = degree;
2625       //printf("degree = %.3f\n", degree );
2626       for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
2627       {
2628         svm_params.gamma = gamma;
2629         //printf("   gamma = %.3f\n", gamma );
2630         for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
2631         {
2632           svm_params.coef0 = coef;
2633           //printf("      coef = %.3f\n", coef );
2634           for( C = C_begin; C <= C_end; C *= C_step )
2635           {
2636             svm_params.C = C;
2637             //printf("         C = %.3f\n", C );
2638             for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
2639             {
2640               svm_params.nu = nu;
2641               //printf("            nu = %.3f\n", nu );
2642               for( p = p_begin; p <= p_end; p *= p_step )
2643               {
2644                 int well;
2645                 svm_params.p = p;
2646                 //printf("               p = %.3f\n", p );
2647 
2648                 CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
2649                     cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
2650 
2651                 well =  rate > best_rate && !is_regression || rate < best_rate && is_regression;
2652                 if( well || (rate == best_rate && C < best_C) )
2653                 {
2654                     best_rate   = rate;
2655                     best_degree = degree;
2656                     best_gamma  = gamma;
2657                     best_coef   = coef;
2658                     best_C      = C;
2659                     best_nu     = nu;
2660                     best_p      = p;
2661                 }
2662                 //printf("                  rate = %.2f\n", rate );
2663               }
2664             }
2665           }
2666         }
2667       }
2668     }
2669     //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
2670       //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
2671 
2672     psvm_params->C      = best_C;
2673     psvm_params->nu     = best_nu;
2674     psvm_params->p      = best_p;
2675     psvm_params->gamma  = best_gamma;
2676     psvm_params->degree = best_degree;
2677     psvm_params->coef0  = best_coef;
2678 
2679     CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
2680 
2681     __END__;
2682 
2683     return svm;
2684 }
2685 
2686 #endif
2687 
2688 /* End of file. */
2689 
2690