/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "_ml.h"

static const float ord_nan = FLT_MAX*0.5f;
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;

CvDTreeTrainData::CvDTreeTrainData()
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = buf = direction = split_buf = 0;
    tree_storage = temp_storage = 0;

    clear();
}


CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, const CvDTreeParams& _params,
                      bool _shared, bool _add_labels )
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = buf = direction = split_buf = 0;
    tree_storage = temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
}


CvDTreeTrainData::~CvDTreeTrainData()
{
    clear();
}


bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvDTreeTrainData::set_params" );

    __BEGIN__;

    // set parameters
    params = _params;

    if( params.max_categories < 2 )
        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
    params.max_categories = MIN( params.max_categories, 15 );

    if( params.max_depth < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    params.max_depth = MIN( params.max_depth, 25 );

    params.min_sample_count = MAX(params.min_sample_count,1);

    if( params.cv_folds < 0 )
        CV_ERROR( CV_StsOutOfRange,
            "params.cv_folds should be 0 (the tree is not pruned) "
            "or >0 (the tree is pruned using n-fold cross-validation)" );

    if( params.cv_folds == 1 )
        params.cv_folds = 0;

    if( params.regression_accuracy < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );

    ok = true;

    __END__;

    return ok;
}


#define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )

#define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )

void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    bool _shared, bool _add_labels, bool _update_data )
{
    CvMat* sample_idx = 0;
    CvMat* var_type0 = 0;
    CvMat* tmp_map = 0;
    int** int_ptr = 0;
    CvDTreeTrainData* data = 0;

    CV_FUNCNAME( "CvDTreeTrainData::set_data" );

    __BEGIN__;

    int sample_all = 0, r_type = 0, cv_n;
    int total_c_count = 0;
    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    int vi, i;
    char err[100];
    const int *sidx = 0, *vidx = 0;

    if( _update_data && data_root )
    {
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );

        // compare new and old train data
        if( !(data->var_count == var_count &&
            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
            CV_ERROR( CV_StsBadArg,
                "The new training data must have the same types of input and output variables "
                "and the same categories for categorical variables" );

        cvReleaseMat( &priors );
        cvReleaseMat( &priors_mult );
        cvReleaseMat( &buf );
        cvReleaseMat( &direction );
        cvReleaseMat( &split_buf );
        cvReleaseMemStorage( &temp_storage );

        priors = data->priors; data->priors = 0;
        priors_mult = data->priors_mult; data->priors_mult = 0;
        buf = data->buf; data->buf = 0;
        buf_count = data->buf_count; buf_size = data->buf_size;
        sample_count = data->sample_count;

        direction = data->direction; data->direction = 0;
        split_buf = data->split_buf; data->split_buf = 0;
        temp_storage = data->temp_storage; data->temp_storage = 0;
        nv_heap = data->nv_heap; cv_heap = data->cv_heap;

        data_root = new_node( 0, sample_count, 0, 0 );
        EXIT;
    }

    clear();

    var_all = 0;
    rng = cvRNG(-1);

    CV_CALL( set_params( _params ));

    // check parameter types and sizes
    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
    if( _tflag == CV_ROW_SAMPLE )
    {
        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        dv_step = 1;
        if( _missing_mask )
            ms_step = _missing_mask->step, mv_step = 1;
    }
    else
    {
        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        ds_step = 1;
        if( _missing_mask )
            mv_step = _missing_mask->step, ms_step = 1;
    }

    sample_count = sample_all;
    var_count = var_all;

    if( _sample_idx )
    {
        CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
        sidx = sample_idx->data.i;
        sample_count = sample_idx->rows + sample_idx->cols - 1;
    }

    if( _var_idx )
    {
        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
        vidx = var_idx->data.i;
        var_count = var_idx->rows + var_idx->cols - 1;
    }

    if( !CV_IS_MAT(_responses) ||
        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
        (_responses->rows != 1 && _responses->cols != 1) ||
        _responses->rows + _responses->cols - 1 != sample_all )
        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
            "floating-point vector containing as many elements as "
            "the total number of samples in the training data matrix" );

    CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;

    is_classifier = r_type == CV_VAR_CATEGORICAL;

    // step 0. calc the number of categorical vars
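    // var_type->data.i[vi] >= 0 is the index of the variable among the
    // categorical variables; a negative value encodes ~(index among the
    // ordered variables) -- get_var_type(), get_ord_var_data() and
    // get_cat_var_data() below decode it the same way.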
    for( vi = 0; vi < var_count; vi++ )
    {
        var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
            cat_var_count++ : ord_var_count--;
    }

    ord_var_count = ~ord_var_count;
    cv_n = params.cv_folds;
    // set the two last elements of var_type array to be able
    // to locate responses and cross-validation labels using
    // the corresponding get_* functions.
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // in case of single ordered predictor we need dummy cv_labels
    // for safe split_node_data() operation
    have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;

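    // Layout of each row of <buf>: the ordered variables come first, each
    // stored as <sample_count> (index,value) pairs (i.e. 2 ints per sample),
    // followed by the categorical variables, the responses and (when
    // have_labels is set) the cv/fold labels, each taking <sample_count> ints;
    // get_ord_var_data() and get_cat_var_data() below rely on this layout.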
    buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
    shared = _shared;
    buf_count = shared ? 3 : 2;
    CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
    CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
    CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
    CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));

    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
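    // (note: CvDTreeSplit already embeds two 32-bit subset words, which is why
    // the formula below reserves extra ints only for bitmasks beyond 64 bits)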
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));

    nv_size = var_count*sizeof(int);
    nv_size = MAX( nv_size, (int)sizeof(CvSetElem) );

    temp_block_size = nv_size;

    if( cv_n )
    {
        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
            CV_ERROR( CV_StsOutOfRange,
288 "The many folds in cross-validation for such a small dataset" );

        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
        temp_block_size = MAX(temp_block_size, cv_size);
    }

    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    if( cv_size )
        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));

    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
    CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));

    max_c_count = 1;

    // transform the training data to convenient representation
    for( vi = 0; vi <= var_count; vi++ )
    {
        int ci;
        const uchar* mask = 0;
        int m_step = 0, step;
        const int* idata = 0;
        const float* fdata = 0;
        int num_valid = 0;

        if( vi < var_count ) // analyze i-th input variable
        {
            int vi0 = vidx ? vidx[vi] : vi;
            ci = get_var_type(vi);
            step = ds_step; m_step = ms_step;
            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
                idata = _train_data->data.i + vi0*dv_step;
            else
                fdata = _train_data->data.fl + vi0*dv_step;
            if( _missing_mask )
                mask = _missing_mask->data.ptr + vi0*mv_step;
        }
        else // analyze _responses
        {
            ci = cat_var_count;
            step = CV_IS_MAT_CONT(_responses->type) ?
                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
                idata = _responses->data.i;
            else
                fdata = _responses->data.fl;
        }

        if( (vi < var_count && ci >= 0) ||
            (vi == var_count && is_classifier) ) // process categorical variable or response
        {
            int c_count, prev_label;
            int* c_map, *dst = get_cat_var_data( data_root, vi );

            // copy data
            for( i = 0; i < sample_count; i++ )
            {
                int val = INT_MAX, si = sidx ? sidx[i] : i;
                if( !mask || !mask[si*m_step] )
                {
                    if( idata )
                        val = idata[si*step];
                    else
                    {
                        float t = fdata[si*step];
                        val = cvRound(t);
                        if( val != t )
                        {
                            sprintf( err, "%d-th value of %d-th (categorical) "
                                "variable is not an integer", i, vi );
                            CV_ERROR( CV_StsBadArg, err );
                        }
                    }

                    if( val == INT_MAX )
                    {
                        sprintf( err, "%d-th value of %d-th (categorical) "
                            "variable is too large", i, vi );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }
                dst[i] = val;
                int_ptr[i] = dst + i;
            }

            // sort all the values, including the missing measurements
            // that should all move to the end
            icvSortIntPtr( int_ptr, sample_count, 0 );
            //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );

            c_count = num_valid > 0;

            // count the categories
            for( i = 1; i < num_valid; i++ )
                c_count += *int_ptr[i] != *int_ptr[i-1];

            if( vi < var_count )
                max_c_count = MAX( max_c_count, c_count );
            cat_count->data.i[ci] = c_count;
            cat_ofs->data.i[ci] = total_c_count;

            // resize cat_map, if needed
            if( cat_map->cols < total_c_count + c_count )
            {
                tmp_map = cat_map;
                CV_CALL( cat_map = cvCreateMat( 1,
                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
                for( i = 0; i < total_c_count; i++ )
                    cat_map->data.i[i] = tmp_map->data.i[i];
                cvReleaseMat( &tmp_map );
            }

            c_map = cat_map->data.i + total_c_count;
            total_c_count += c_count;

            // compact the class indices and build the map
            prev_label = ~*int_ptr[0];
            c_count = -1;

            for( i = 0; i < num_valid; i++ )
            {
                int cur_label = *int_ptr[i];
                if( cur_label != prev_label )
                    c_map[++c_count] = prev_label = cur_label;
                *int_ptr[i] = c_count;
            }

            // replace labels for missing values with -1
            for( ; i < sample_count; i++ )
                *int_ptr[i] = -1;
        }
        else if( ci < 0 ) // process ordered variable
        {
            CvPair32s32f* dst = get_ord_var_data( data_root, vi );

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( !mask || !mask[si*m_step] )
                {
                    if( idata )
                        val = (float)idata[si*step];
                    else
                        val = fdata[si*step];

                    if( fabs(val) >= ord_nan )
                    {
                        sprintf( err, "%d-th value of %d-th (ordered) "
                            "variable (=%g) is too large", i, vi, val );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }
                dst[i].i = i;
                dst[i].val = val;
            }

            icvSortPairs( dst, sample_count, 0 );
        }
        else // special case: process ordered response,
             // it will be stored similarly to categorical vars (i.e. no pairs)
        {
            float* dst = get_ord_responses( data_root );

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( idata )
                    val = (float)idata[si*step];
                else
                    val = fdata[si*step];

                if( fabs(val) >= ord_nan )
                {
                    sprintf( err, "%d-th value of %d-th (ordered) "
                        "variable (=%g) is out of range", i, vi, val );
                    CV_ERROR( CV_StsBadArg, err );
                }
                dst[i] = val;
            }

            cat_count->data.i[cat_var_count] = 0;
            cat_ofs->data.i[cat_var_count] = total_c_count;
            num_valid = sample_count;
        }

        if( vi < var_count )
            data_root->set_num_valid(vi, num_valid);
    }

    if( cv_n )
    {
        int* dst = get_labels(data_root);
        CvRNG* r = &rng;

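        // assign fold labels 0..cv_n-1 cyclically (the bit trick below resets
        // vi to 0 once it reaches cv_n), then randomly shuffle them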
        for( i = vi = 0; i < sample_count; i++ )
        {
            dst[i] = vi++;
            vi &= vi < cv_n ? -1 : 0;
        }

        for( i = 0; i < sample_count; i++ )
        {
            int a = cvRandInt(r) % sample_count;
            int b = cvRandInt(r) % sample_count;
            CV_SWAP( dst[a], dst[b], vi );
        }
    }

    cat_map->cols = MAX( total_c_count, 1 );

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));

    have_priors = is_classifier && params.priors;
    if( is_classifier )
    {
        int m = get_num_classes();
        double sum = 0;
        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
        for( i = 0; i < m; i++ )
        {
            double val = have_priors ? params.priors[i] : 1.;
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
            priors->data.db[i] = val;
            sum += val;
        }

        // normalize weights
        if( have_priors )
            cvScale( priors, priors, 1./sum );

        CV_CALL( priors_mult = cvCloneMat( priors ));
        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    }

    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));

    __END__;

    if( data )
        delete data;

    cvFree( &int_ptr );
    cvReleaseMat( &sample_idx );
    cvReleaseMat( &var_type0 );
    cvReleaseMat( &tmp_map );
}


CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

    if( !isubsample_idx )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i, total = data_root->sample_count;
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;
        int work_var_count = get_work_var_count();
        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        for( vi = 0; vi < work_var_count; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                const int* src = get_cat_var_data( data_root, vi );
                int* dst = get_cat_var_data( root, vi );
                int num_valid = 0;

                for( i = 0; i < count; i++ )
                {
                    int val = src[sidx[i]];
                    dst[i] = val;
                    num_valid += val >= 0;
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                const CvPair32s32f* src = get_ord_var_data( data_root, vi );
                CvPair32s32f* dst = get_ord_var_data( root, vi );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                for( i = 0; i < num_valid; i++ )
                {
                    idx = src[i].i;
                    count_i = co[idx*2];
                    if( count_i )
                    {
                        float val = src[i].val;
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        {
                            dst[j].val = val;
                            dst[j].i = cur_ofs;
                        }
                    }
                }

                root->set_num_valid(vi, j);

                for( ; i < total; i++ )
                {
                    idx = src[i].i;
                    count_i = co[idx*2];
                    if( count_i )
                    {
                        float val = src[i].val;
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                        {
                            dst[j].val = val;
                            dst[j].i = cur_ofs;
                        }
                    }
                }
            }
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}


void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
                                    float* values, uchar* missing,
                                    float* responses, bool get_class_idx )
{
    CvMat* subsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );

    __BEGIN__;

    int i, vi, total = sample_count, count = total, cur_ofs = 0;
    int* sidx = 0;
    int* co = 0;

    if( _subsample_idx )
    {
        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
        sidx = subsample_idx->data.i;
        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        co = subsample_co->data.i;
        cvZero( subsample_co );
        count = subsample_idx->cols + subsample_idx->rows - 1;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            int count_i = co[i*2];
            if( count_i )
            {
                co[i*2+1] = cur_ofs*var_count;
                cur_ofs += count_i;
            }
        }
    }

    if( missing )
        memset( missing, 1, count*var_count );

    for( vi = 0; vi < var_count; vi++ )
    {
        int ci = get_var_type(vi);
        if( ci >= 0 ) // categorical
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const int* src = get_cat_var_data(data_root, vi);

            for( i = 0; i < count; i++, dst += var_count )
            {
                int idx = sidx ? sidx[i] : i;
                int val = src[idx];
                *dst = (float)val;
                if( m )
                {
                    *m = val < 0;
                    m += var_count;
                }
            }
        }
        else // ordered
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const CvPair32s32f* src = get_ord_var_data(data_root, vi);
            int count1 = data_root->get_num_valid(vi);

            for( i = 0; i < count1; i++ )
            {
                int idx = src[i].i;
                int count_i = 1;
                if( co )
                {
                    count_i = co[idx*2];
                    cur_ofs = co[idx*2+1];
                }
                else
                    cur_ofs = idx*var_count;
                if( count_i )
                {
                    float val = src[i].val;
                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
                    {
                        dst[cur_ofs] = val;
                        if( m )
                            m[cur_ofs] = 0;
                    }
                }
            }
        }
    }

    // copy responses
    if( responses )
    {
        if( is_classifier )
        {
            const int* src = get_class_labels(data_root);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                int val = get_class_idx ? src[idx] :
                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
                responses[i] = (float)val;
            }
        }
        else
        {
            const float* src = get_ord_responses(data_root);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                responses[i] = src[idx];
            }
        }
    }

    __END__;

    cvReleaseMat( &subsample_idx );
    cvReleaseMat( &subsample_co );
}


CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
{
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    node->sample_count = count;
    node->depth = parent ? parent->depth + 1 : 0;
    node->parent = parent;
    node->left = node->right = 0;
    node->split = 0;
    node->value = 0;
    node->class_idx = 0;
    node->maxlr = 0.;

    node->buf_idx = storage_idx;
    node->offset = offset;
    if( nv_heap )
        node->num_valid = (int*)cvSetNew( nv_heap );
    else
        node->num_valid = 0;
    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    node->complexity = 0;

    if( params.cv_folds > 0 && cv_heap )
    {
        int cv_n = params.cv_folds;
        node->Tn = INT_MAX;
        node->cv_Tn = (int*)cvSetNew( cv_heap );
        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
        node->cv_node_error = node->cv_node_risk + cv_n;
    }
    else
    {
        node->Tn = 0;
        node->cv_Tn = 0;
        node->cv_node_risk = 0;
        node->cv_node_error = 0;
    }

    return node;
}


CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                                               int split_point, int inversed, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    split->var_idx = vi;
    split->ord.c = cmp_val;
    split->ord.split_point = split_point;
    split->inversed = inversed;
    split->quality = quality;
    split->next = 0;

    return split;
}


CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
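    // n = number of 32-bit words needed for the category subset bitmask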
    int i, n = (max_c_count + 31)/32;

    split->var_idx = vi;
    split->inversed = 0;
    split->quality = quality;
    for( i = 0; i < n; i++ )
        split->subset[i] = 0;
    split->next = 0;

    return split;
}


void CvDTreeTrainData::free_node( CvDTreeNode* node )
{
    CvDTreeSplit* split = node->split;
    free_node_data( node );
    while( split )
    {
        CvDTreeSplit* next = split->next;
        cvSetRemoveByPtr( split_heap, split );
        split = next;
    }
    node->split = 0;
    cvSetRemoveByPtr( node_heap, node );
}


void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
{
    if( node->num_valid )
    {
        cvSetRemoveByPtr( nv_heap, node->num_valid );
        node->num_valid = 0;
    }
    // do not free cv_* fields, as all the cross-validation related data is released at once.
}


void CvDTreeTrainData::free_train_data()
{
    cvReleaseMat( &counts );
    cvReleaseMat( &buf );
    cvReleaseMat( &direction );
    cvReleaseMat( &split_buf );
    cvReleaseMemStorage( &temp_storage );
    cv_heap = nv_heap = 0;
}


void CvDTreeTrainData::clear()
{
    free_train_data();

    cvReleaseMemStorage( &tree_storage );

    cvReleaseMat( &var_idx );
    cvReleaseMat( &var_type );
    cvReleaseMat( &cat_count );
    cvReleaseMat( &cat_ofs );
    cvReleaseMat( &cat_map );
    cvReleaseMat( &priors );
    cvReleaseMat( &priors_mult );

    node_heap = split_heap = 0;

    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
    have_labels = have_priors = is_classifier = false;

    buf_count = buf_size = 0;
    shared = false;

    data_root = 0;

    rng = cvRNG(-1);
}


int CvDTreeTrainData::get_num_classes() const
{
    return is_classifier ? cat_count->data.i[cat_var_count] : 0;
}


int CvDTreeTrainData::get_var_type(int vi) const
{
    return var_type->data.i[vi];
}


int CvDTreeTrainData::get_work_var_count() const
{
    return var_count + 1 + (have_labels ? 1 : 0);
}

CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
{
    int oi = ~get_var_type(vi);
    assert( 0 <= oi && oi < ord_var_count );
    return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
        n->offset + oi*n->sample_count*2);
}


int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
{
    return get_cat_var_data( n, var_count );
}


float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
{
    return (float*)get_cat_var_data( n, var_count );
}


int* CvDTreeTrainData::get_labels( CvDTreeNode* n )
{
    return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0;
}


int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
{
    int ci = get_var_type(vi);
    assert( 0 <= ci && ci <= cat_var_count + 1 );
    return buf->data.i + n->buf_idx*buf->cols + n->offset +
        (ord_var_count*2 + ci)*n->sample_count;
}


int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
{
    int idx = n->buf_idx + 1;
    if( idx >= buf_count )
        idx = shared ? 1 : 0;
    return idx;
}


void CvDTreeTrainData::write_params( CvFileStorage* fs )
{
    CV_FUNCNAME( "CvDTreeTrainData::write_params" );

    __BEGIN__;

    int vi, vcount = var_count;

    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );

    if( is_classifier )
    {
        cvWriteInt( fs, "max_categories", params.max_categories );
    }
    else
    {
        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
    }

    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );

    if( params.cv_folds > 1 )
    {
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
    }

    if( priors )
        cvWrite( fs, "priors", priors );

    cvEndWriteStruct( fs );

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );

    for( vi = 0; vi < vcount; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );

    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
    {
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
    }

    __END__;
}


void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    __BEGIN__;

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;

    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );
    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );

    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );

    if( tparams_node ) // training parameters are not necessary
    {
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( is_classifier )
        {
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
        }
        else
        {
            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
        }

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        if( params.cv_folds > 1 )
        {
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
        }

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
        if( priors )
        {
            if( !CV_IS_MAT(priors) )
                CV_ERROR( CV_StsParseError, "priors must be stored as a matrix" );
            priors_mult = cvCloneMat( priors );
        }
    }

    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
    if( var_idx )
    {
        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
1124 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );

        for( vi = 0; vi < var_count; vi++ )
            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
    }

    ////// read var type
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    // remember the stored counts before recomputing them from var_type
    int cat_var_count0 = cat_var_count, ord_var_count0 = ord_var_count;
    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
    else
    {
        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
        {
            CvFileNode* n = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }
    var_type->data.i[var_count] = cat_var_count;

    ord_var_count = ~ord_var_count;
    if( cat_var_count != cat_var_count0 || ord_var_count != ord_var_count0 )
        CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
    //////

    if( cat_var_count > 0 || is_classifier )
    {
        int ccount, total_c_count = 0;
        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        ccount = cat_var_count + is_classifier;

        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;
        max_c_count = 1;

        for( vi = 0; vi < ccount; vi++ )
        {
            int val = cat_count->data.i[vi];
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, val );
            cat_ofs->data.i[vi+1] = total_c_count += val;
        }

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
                "cat_map vector length is not equal to the total number of categories in all categorical vars" );
    }

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
                                      sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
                                       max_split_size, tree_storage ));

    __END__;
}


/////////////////////// Decision Tree /////////////////////////

CvDTree::CvDTree()
{
    data = 0;
    var_importance = 0;
    default_model_name = "my_tree";

    clear();
}


void CvDTree::clear()
{
    cvReleaseMat( &var_importance );
    if( data )
    {
        if( !data->shared )
            delete data;
        else
            free_tree();
        data = 0;
    }
    root = 0;
    pruned_tree_idx = -1;
}


CvDTree::~CvDTree()
{
    clear();
}


const CvDTreeNode* CvDTree::get_root() const
{
    return root;
}


int CvDTree::get_pruned_tree_idx() const
{
    return pruned_tree_idx;
}


CvDTreeTrainData* CvDTree::get_data()
{
    return data;
}


bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx,
                     const CvMat* _sample_idx, const CvMat* _var_type,
                     const CvMat* _missing_mask, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
                                 _var_idx, _sample_idx, _var_type,
                                 _missing_mask, _params, false );
    CV_CALL( result = do_train(0));

    __END__;

    return result;
}


bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = _data;
    data->shared = true;
    CV_CALL( result = do_train(_subsample_idx));

    __END__;

    return result;
}


bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );

    __BEGIN__;

    root = data->subsample_data( _subsample_idx );

    CV_CALL( try_split_node(root));

    if( data->params.cv_folds > 0 )
        CV_CALL( prune_cv());

    if( !data->shared )
        data->free_train_data();

    result = true;

    __END__;

    return result;
}


void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }

    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }

    quality_scale = calc_node_dir( node );

    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }

    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}


// calculates the direction (left(-1), right(1), missing(0))
// for each sample using the best split.
// the function returns the scale coefficient for surrogate split quality factors.
// the scale is applied to normalize surrogate split quality relative to the
// best (primary) split quality: if a surrogate split is absolutely
// identical to the primary split, its quality is set to the maximum value,
// i.e. the quality of the primary split; otherwise it is lower.
// besides, the function computes node->maxlr, the minimum quality
// (without considering the above-mentioned scale) required of a surrogate
// split: surrogate splits with quality less than node->maxlr are discarded.
double CvDTree::calc_node_dir( CvDTreeNode* node )
{
    char* dir = (char*)data->direction->data.ptr;
    int i, n = node->sample_count, vi = node->split->var_idx;
    double L, R;

    assert( !node->split->inversed );

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
    {
        const int* labels = data->get_cat_var_data(node,vi);
        const int* subset = node->split->subset;

        if( !data->have_priors )
        {
            int sum = 0, sum_abs = 0;

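            // d is -1/0/+1, so (d & 1) counts the non-missing samples:
            // after the loop sum_abs = L + R and sum = R - L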
            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d; sum_abs += d & 1;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) >> 1;
            L = (sum_abs - sum) >> 1;
        }
        else
        {
            const int* responses = data->get_class_labels(node);
            const double* priors = data->priors_mult->data.db;
            double sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                double w = priors[responses[i]];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d*w; sum_abs += (d & 1)*w;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) * 0.5;
            L = (sum_abs - sum) * 0.5;
        }
    }
    else // split on ordered var
    {
        const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
        int split_point = node->split->ord.split_point;
        int n1 = node->get_num_valid(vi);

        assert( 0 <= split_point && split_point < n1-1 );

        if( !data->have_priors )
        {
            for( i = 0; i <= split_point; i++ )
                dir[sorted[i].i] = (char)-1;
            for( ; i < n1; i++ )
                dir[sorted[i].i] = (char)1;
            for( ; i < n; i++ )
                dir[sorted[i].i] = (char)0;

            L = split_point + 1;
            R = n1 - split_point - 1;
        }
        else
        {
            const int* responses = data->get_class_labels(node);
            const double* priors = data->priors_mult->data.db;
            L = R = 0;

            for( i = 0; i <= split_point; i++ )
            {
                int idx = sorted[i].i;
                double w = priors[responses[idx]];
                dir[idx] = (char)-1;
                L += w;
            }

            for( ; i < n1; i++ )
            {
                int idx = sorted[i].i;
                double w = priors[responses[idx]];
                dir[idx] = (char)1;
                R += w;
            }

            for( ; i < n; i++ )
                dir[sorted[i].i] = (char)0;
        }
    }

    node->maxlr = MAX( L, R );
    return node->split->quality/(L + R);
}


CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
{
    int vi;
    CvDTreeSplit *best_split = 0, *split = 0, *t;

    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        if( node->get_num_valid(vi) <= 1 )
            continue;

        if( data->is_classifier )
        {
            if( ci >= 0 )
                split = find_split_cat_class( node, vi );
            else
                split = find_split_ord_class( node, vi );
        }
        else
        {
            if( ci >= 0 )
                split = find_split_cat_reg( node, vi );
            else
                split = find_split_ord_reg( node, vi );
        }

        if( split )
        {
            if( !best_split || best_split->quality < split->quality )
                CV_SWAP( best_split, split, t );
            if( split )
                cvSetRemoveByPtr( data->split_heap, split );
        }
    }

    return best_split;
}


CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
{
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int m = data->get_num_classes();
    const int* rc0 = data->counts->data.i;
    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
    int i, best_i = -1;
    double lsum2 = 0, rsum2 = 0, best_val = 0;
    const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
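
    // The split quality is sum_k lc[k]^2/L + sum_k rc[k]^2/R, written below as
    // (lsum2*R + rsum2*L)/(L*R); maximizing it is equivalent to minimizing the
    // weighted Gini impurity of the two branches. lsum2 and rsum2 are updated
    // incrementally: moving one sample of class idx from the right branch to
    // the left one changes lc[idx]^2 by 2*lc[idx]+1 and rc[idx]^2 by -(2*rc[idx]-1).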

    // init arrays of class instance counters on both sides of the split
    for( i = 0; i < m; i++ )
    {
        lc[i] = 0;
        rc[i] = rc0[i];
    }

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rc[responses[sorted[i].i]]--;

    if( !priors )
    {
        int L = 0, R = n1;

        for( i = 0; i < m; i++ )
            rsum2 += (double)rc[i]*rc[i];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted[i].i];
            int lv, rv;
            L++; R--;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        double L = 0, R = 0;
        for( i = 0; i < m; i++ )
        {
            double wv = rc[i]*priors[i];
            R += wv;
            rsum2 += wv*wv;
        }

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted[i].i];
            int lv, rv;
            double p = priors[idx], p2 = p*p;
            L += p; R -= p;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( sorted[i].val + epsilon < sorted[i+1].val )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    return best_i >= 0 ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        0, (float)best_val ) : 0;
}


void CvDTree::cluster_categories( const int* vectors, int n, int m,
                                  int* csums, int k, int* labels )
{
    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
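    // This is a small k-means-style clustering of the rows of <vectors>
    // (per-category class-count histograms): both a vector and a cluster sum
    // are normalized by their total counts before the distance is computed,
    // and the assignment step is repeated until it stabilizes or max_iters
    // iterations have been done.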
    int iters = 0, max_iters = 100;
    int i, j, idx;
    double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
    double *v_weights = buf, *c_weights = buf + k;
    bool modified = true;
    CvRNG* r = &data->rng;

    // assign labels randomly
    for( i = idx = 0; i < n; i++ )
    {
        int sum = 0;
        const int* v = vectors + i*m;
        labels[i] = idx++;
        idx &= idx < k ? -1 : 0;

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = cvRandInt(r) % n;
        int i2 = cvRandInt(r) % n;
        CV_SWAP( labels[i1], labels[i2], j );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            int* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const int* s = csums + i*m;
            int sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const int* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}


CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
{
    CvDTreeSplit* split;
    const int* labels = data->get_cat_var_data(node, vi);
    const int* responses = data->get_class_labels(node);
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int m = data->get_num_classes();
    int _mi = data->cat_count->data.i[ci], mi = _mi;
    int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
    int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
    int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
    double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
    int* cluster_labels = 0;
    int** int_ptr = 0;
    int i, j, k, idx;
    double L = 0, R = 0;
    double best_val = 0;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
    const double* priors = data->priors_mult->data.db;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
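    // (the extra row at index j == -1, reserved by the "+m" offset above,
    // accumulates the samples whose value of the vi-th variable is missing)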
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        j = labels[i];
        k = responses[i];
        cjk[j*m + k]++;
    }

    if( m > 2 )
    {
        if( mi > data->params.max_categories )
        {
            mi = MIN(data->params.max_categories, n);
            cjk += _mi*m;
            cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
        for( j = 0; j < mi; j++ )
            int_ptr[j] = cjk + j*2 + 1;
        icvSortIntPtr( int_ptr, mi, 0 );
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        int sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k]*priors[k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double weight;
        int* crow;
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(int_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
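            // Gray code enumeration changes exactly one bit per step; the
            // isolated bit (reduced below 2^16 first) is converted to a float
            // and its position is read off the IEEE-754 exponent field:
            // (u.i >> 23) - 127 is the float's binary exponent, i.e. log2.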
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        crow = cjk + idx*m;
        weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] + t;
                int rval = rc[k] - t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] - t;
                int rval = rc[k] + t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    if( best_subset < 0 )
        return 0;

    split = data->new_split_cat( vi, (float)best_val );

    if( m == 2 )
    {
        for( i = 0; i <= best_subset; i++ )
        {
            idx = (int)(int_ptr[i] - cjk) >> 1;
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    else
    {
        for( i = 0; i < _mi; i++ )
        {
            idx = cluster_labels ? cluster_labels[i] : i;
            if( best_subset & (1 << idx) )
                split->subset[i >> 5] |= 1 << (i & 31);
        }
    }

    return split;
}


CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
{
    const float epsilon = FLT_EPSILON*2;
    const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    const float* responses = data->get_ord_responses(node);
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int i, best_i = -1;
    double best_val = 0, lsum = 0, rsum = node->value*n;
    int L = 0, R = n1;

    // compensate for missing values
    for( i = n1; i < n; i++ )
        rsum -= responses[sorted[i].i];

    // find the optimal split
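    // the criterion below is lsum^2/L + rsum^2/R (the between-branch sum of
    // squares); maximizing it is equivalent to minimizing the total
    // within-branch variance of the responses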
    for( i = 0; i < n1 - 1; i++ )
    {
        float t = responses[sorted[i].i];
        L++; R--;
        lsum += t;
        rsum -= t;

        if( sorted[i].val + epsilon < sorted[i+1].val )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    return best_i >= 0 ? data->new_split_ord( vi,
        (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
        0, (float)best_val ) : 0;
}


CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
{
    CvDTreeSplit* split;
    const int* labels = data->get_cat_var_data(node, vi);
    const float* responses = data->get_ord_responses(node);
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];
    double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
    int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
    double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
1978 int i, L = 0, R = 0;
1979 double best_val = 0, lsum = 0, rsum = 0;
1980 int best_subset = -1, subset_i;
1981
1982 for( i = -1; i < mi; i++ )
1983 sum[i] = counts[i] = 0;
1984
1985 // calculate sum response and weight of each category of the input var
1986 for( i = 0; i < n; i++ )
1987 {
1988 int idx = labels[i];
1989 double s = sum[idx] + responses[i];
1990 int nc = counts[idx] + 1;
1991 sum[idx] = s;
1992 counts[idx] = nc;
1993 }
1994
1995 // calculate average response in each category
1996 for( i = 0; i < mi; i++ )
1997 {
1998 R += counts[i];
1999 rsum += sum[i];
2000 sum[i] /= MAX(counts[i],1);
2001 sum_ptr[i] = sum + i;
2002 }
2003
2004 icvSortDblPtr( sum_ptr, mi, 0 );
2005
2006 // revert to unnormalized sums
2007 // (there should be very little loss of accuracy)
2008 for( i = 0; i < mi; i++ )
2009 sum[i] *= counts[i];
2010
2011 for( subset_i = 0; subset_i < mi-1; subset_i++ )
2012 {
2013 int idx = (int)(sum_ptr[subset_i] - sum);
2014 int ni = counts[idx];
2015
2016 if( ni )
2017 {
2018 double s = sum[idx];
2019 lsum += s; L += ni;
2020 rsum -= s; R -= ni;
2021
2022 if( L && R )
2023 {
2024 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2025 if( best_val < val )
2026 {
2027 best_val = val;
2028 best_subset = subset_i;
2029 }
2030 }
2031 }
2032 }
2033
2034 if( best_subset < 0 )
2035 return 0;
2036
2037 split = data->new_split_cat( vi, (float)best_val );
2038 for( i = 0; i <= best_subset; i++ )
2039 {
2040 int idx = (int)(sum_ptr[i] - sum);
2041 split->subset[idx >> 5] |= 1 << (idx & 31);
2042 }
2043
2044 return split;
2045 }
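
// Sorting categories by their mean response and scanning only "prefix" subsets
// uses the classical CART result (Breiman et al.): for squared loss the optimal
// category subset is always a prefix of the categories ordered by mean
// response, which cuts the search from 2^(mi-1) subsets down to mi-1. The scan
// in isolation (hypothetical helper over plain arrays):
#if 0
static double best_cat_reg_quality( const double* sum, const int* cnt,
                                    const int* order, int mi )
{
    // sum[c]/cnt[c] is the mean response of category c; order[] lists the
    // categories sorted by that mean in ascending order.
    double lsum = 0, rsum = 0, best = 0;
    int i, L = 0, R = 0;
    for( i = 0; i < mi; i++ )
    {
        rsum += sum[order[i]];
        R += cnt[order[i]];
    }
    for( i = 0; i < mi - 1; i++ )
    {
        int c = order[i];
        lsum += sum[c]; rsum -= sum[c];
        L += cnt[c]; R -= cnt[c];
        if( L && R )
        {
            double q = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
            if( q > best )
                best = q;
        }
    }
    return best;
}
#endif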
2046
2047
2048 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
2049 {
2050 const float epsilon = FLT_EPSILON*2;
2051 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2052 const char* dir = (char*)data->direction->data.ptr;
2053 int n1 = node->get_num_valid(vi);
2054 // LL - number of samples that both the primary and the surrogate splits send to the left
2055 // LR - ... primary split sends to the left and the surrogate split sends to the right
2056 // RL - ... primary split sends to the right and the surrogate split sends to the left
2057 // RR - ... both send to the right
2058 int i, best_i = -1, best_inversed = 0;
2059 double best_val;
2060
2061 if( !data->have_priors )
2062 {
2063 int LL = 0, RL = 0, LR, RR;
2064 int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2065 int sum = 0, sum_abs = 0;
2066
2067 for( i = 0; i < n1; i++ )
2068 {
2069 int d = dir[sorted[i].i];
2070 sum += d; sum_abs += d & 1;
2071 }
2072
2073 // sum_abs = R + L; sum = R - L
2074 RR = (sum_abs + sum) >> 1;
2075 LR = (sum_abs - sum) >> 1;
2076
2077 // initially all the samples are sent to the right by the surrogate split,
2078 // LR of them are sent to the left by the primary split and RR to the right.
2079 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2080 for( i = 0; i < n1 - 1; i++ )
2081 {
2082 int d = dir[sorted[i].i];
2083
2084 if( d < 0 )
2085 {
2086 LL++; LR--;
2087 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
2088 {
2089 _best_val = LL + RR;
2090 best_i = i; best_inversed = 0;
2091 }
2092 }
2093 else if( d > 0 )
2094 {
2095 RL++; RR--;
2096 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
2097 {
2098 _best_val = RL + LR;
2099 best_i = i; best_inversed = 1;
2100 }
2101 }
2102 }
2103 best_val = _best_val;
2104 }
2105 else
2106 {
2107 double LL = 0, RL = 0, LR, RR;
2108 double worst_val = node->maxlr;
2109 double sum = 0, sum_abs = 0;
2110 const double* priors = data->priors_mult->data.db;
2111 const int* responses = data->get_class_labels(node);
2112 best_val = worst_val;
2113
2114 for( i = 0; i < n1; i++ )
2115 {
2116 int idx = sorted[i].i;
2117 double w = priors[responses[idx]];
2118 int d = dir[idx];
2119 sum += d*w; sum_abs += (d & 1)*w;
2120 }
2121
2122 // sum_abs = R + L; sum = R - L
2123 RR = (sum_abs + sum)*0.5;
2124 LR = (sum_abs - sum)*0.5;
2125
2126 // initially all the samples are sent to the right by the surrogate split,
2127 // LR of them are sent to the left by the primary split and RR to the right.
2128 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2129 for( i = 0; i < n1 - 1; i++ )
2130 {
2131 int idx = sorted[i].i;
2132 double w = priors[responses[idx]];
2133 int d = dir[idx];
2134
2135 if( d < 0 )
2136 {
2137 LL += w; LR -= w;
2138 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
2139 {
2140 best_val = LL + RR;
2141 best_i = i; best_inversed = 0;
2142 }
2143 }
2144 else if( d > 0 )
2145 {
2146 RL += w; RR -= w;
2147 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
2148 {
2149 best_val = RL + LR;
2150 best_i = i; best_inversed = 1;
2151 }
2152 }
2153 }
2154 }
2155
2156 return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2157 (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
2158 best_inversed, (float)best_val ) : 0;
2159 }
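
// A surrogate split is kept only if it predicts the primary direction better
// than the trivial rule "send everything with the majority", whose weighted
// agreement is node->maxlr (the larger of the two primary branch weights);
// hence the final best_val > node->maxlr test. Counting the agreement of a
// candidate surrogate threshold directly (illustrative sketch, unweighted
// case):
#if 0
static int surrogate_agreement( const char* primary_dir, const float* x, int n,
                                float thresh, int inversed )
{
    int agree = 0, i;
    for( i = 0; i < n; i++ )
    {
        int d = x[i] <= thresh ? -1 : 1; // surrogate direction
        if( inversed )
            d = -d;
        agree += d == primary_dir[i];    // primary_dir[i] is -1 or +1
    }
    return agree; // equals LL+RR (or RL+LR for the inversed surrogate)
}
#endif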
2160
2161
2162 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
2163 {
2164 const int* labels = data->get_cat_var_data(node, vi);
2165 const char* dir = (char*)data->direction->data.ptr;
2166 int n = node->sample_count;
2167 // LL - number of samples that both the primary and the surrogate splits send to the left
2168 // LR - ... primary split sends to the left and the surrogate split sends to the right
2169 // RL - ... primary split sends to the right and the surrogate split sends to the left
2170 // RR - ... both send to the right
2171 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2172 int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2173 double best_val = 0;
2174 double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
2175 double* rc = lc + mi + 1;
2176
2177 for( i = -1; i < mi; i++ )
2178 lc[i] = rc[i] = 0;
2179
2180 // 1. for each category calculate the weight of samples
2181 // sent to the left (lc) and to the right (rc) by the primary split
2182 if( !data->have_priors )
2183 {
2184 int* _lc = (int*)cvStackAlloc((mi+1)*2*sizeof(_lc[0])) + 1;
2185 int* _rc = _lc + mi + 1;
2186
2187 for( i = -1; i < mi; i++ )
2188 _lc[i] = _rc[i] = 0;
2189
2190 for( i = 0; i < n; i++ )
2191 {
2192 int idx = labels[i];
2193 int d = dir[i];
2194 int sum = _lc[idx] + d;
2195 int sum_abs = _rc[idx] + (d & 1);
2196 _lc[idx] = sum; _rc[idx] = sum_abs;
2197 }
2198
2199 for( i = 0; i < mi; i++ )
2200 {
2201 int sum = _lc[i];
2202 int sum_abs = _rc[i];
2203 lc[i] = (sum_abs - sum) >> 1;
2204 rc[i] = (sum_abs + sum) >> 1;
2205 }
2206 }
2207 else
2208 {
2209 const double* priors = data->priors_mult->data.db;
2210 const int* responses = data->get_class_labels(node);
2211
2212 for( i = 0; i < n; i++ )
2213 {
2214 int idx = labels[i];
2215 double w = priors[responses[i]];
2216 int d = dir[i];
2217 double sum = lc[idx] + d*w;
2218 double sum_abs = rc[idx] + (d & 1)*w;
2219 lc[idx] = sum; rc[idx] = sum_abs;
2220 }
2221
2222 for( i = 0; i < mi; i++ )
2223 {
2224 double sum = lc[i];
2225 double sum_abs = rc[i];
2226 lc[i] = (sum_abs - sum) * 0.5;
2227 rc[i] = (sum_abs + sum) * 0.5;
2228 }
2229 }
2230
2231 // 2. now form the split.
2232 // in each category send all the samples to the same direction as majority
2233 for( i = 0; i < mi; i++ )
2234 {
2235 double lval = lc[i], rval = rc[i];
2236 if( lval > rval )
2237 {
2238 split->subset[i >> 5] |= 1 << (i & 31);
2239 best_val += lval;
2240 l_win++;
2241 }
2242 else
2243 best_val += rval;
2244 }
2245
2246 split->quality = (float)best_val;
2247 if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2248 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2249
2250 return split;
2251 }
2252
2253
2254 void CvDTree::calc_node_value( CvDTreeNode* node )
2255 {
2256 int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2257 const int* cv_labels = data->get_labels(node);
2258
2259 if( data->is_classifier )
2260 {
2261 // in case of classification tree:
2262 // * node value is the label of the class that has the largest weight in the node.
2263 // * node risk is the weighted number of misclassified samples,
2264 // * j-th cross-validation fold value and risk are calculated as above,
2265 // but using the samples with cv_labels(*)!=j.
2266 // * j-th cross-validation fold error is calculated as the weighted number of
2267 // misclassified samples with cv_labels(*)==j.
2268
2269 // compute the number of instances of each class
2270 int* cls_count = data->counts->data.i;
2271 const int* responses = data->get_class_labels(node);
2272 int m = data->get_num_classes();
2273 int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
2274 double max_val = -1, total_weight = 0;
2275 int max_k = -1;
2276 double* priors = data->priors_mult->data.db;
2277
2278 for( k = 0; k < m; k++ )
2279 cls_count[k] = 0;
2280
2281 if( cv_n == 0 )
2282 {
2283 for( i = 0; i < n; i++ )
2284 cls_count[responses[i]]++;
2285 }
2286 else
2287 {
2288 for( j = 0; j < cv_n; j++ )
2289 for( k = 0; k < m; k++ )
2290 cv_cls_count[j*m + k] = 0;
2291
2292 for( i = 0; i < n; i++ )
2293 {
2294 j = cv_labels[i]; k = responses[i];
2295 cv_cls_count[j*m + k]++;
2296 }
2297
2298 for( j = 0; j < cv_n; j++ )
2299 for( k = 0; k < m; k++ )
2300 cls_count[k] += cv_cls_count[j*m + k];
2301 }
2302
2303 if( data->have_priors && node->parent == 0 )
2304 {
2305 // compute priors_mult from priors, taking the sample ratio into account.
2306 double sum = 0;
2307 for( k = 0; k < m; k++ )
2308 {
2309 int n_k = cls_count[k];
2310 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2311 sum += priors[k];
2312 }
2313 sum = 1./sum;
2314 for( k = 0; k < m; k++ )
2315 priors[k] *= sum;
2316 }
2317
2318 for( k = 0; k < m; k++ )
2319 {
2320 double val = cls_count[k]*priors[k];
2321 total_weight += val;
2322 if( max_val < val )
2323 {
2324 max_val = val;
2325 max_k = k;
2326 }
2327 }
2328
2329 node->class_idx = max_k;
2330 node->value = data->cat_map->data.i[
2331 data->cat_ofs->data.i[data->cat_var_count] + max_k];
2332 node->node_risk = total_weight - max_val;
2333
2334 for( j = 0; j < cv_n; j++ )
2335 {
2336 double sum_k = 0, sum = 0, max_val_k = 0;
2337 max_val = -1; max_k = -1;
2338
2339 for( k = 0; k < m; k++ )
2340 {
2341 double w = priors[k];
2342 double val_k = cv_cls_count[j*m + k]*w;
2343 double val = cls_count[k]*w - val_k;
2344 sum_k += val_k;
2345 sum += val;
2346 if( max_val < val )
2347 {
2348 max_val = val;
2349 max_val_k = val_k;
2350 max_k = k;
2351 }
2352 }
2353
2354 node->cv_Tn[j] = INT_MAX;
2355 node->cv_node_risk[j] = sum - max_val;
2356 node->cv_node_error[j] = sum_k - max_val_k;
2357 }
2358 }
2359 else
2360 {
2361 // in case of regression tree:
2362 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2363 // n is the number of samples in the node.
2364 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2365 // * j-th cross-validation fold value and risk are calculated as above,
2366 // but using the samples with cv_labels(*)!=j.
2367 // * j-th cross-validation fold error is calculated
2368 // using samples with cv_labels(*)==j as the test subset:
2369 // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2370 // where node_value_j is the node value calculated
2371 // as described in the previous bullet, and summation is done
2372 // over the samples with cv_labels(*)==j.
2373
2374 double sum = 0, sum2 = 0;
2375 const float* values = data->get_ord_responses(node);
2376 double *cv_sum = 0, *cv_sum2 = 0;
2377 int* cv_count = 0;
2378
2379 if( cv_n == 0 )
2380 {
2381 for( i = 0; i < n; i++ )
2382 {
2383 double t = values[i];
2384 sum += t;
2385 sum2 += t*t;
2386 }
2387 }
2388 else
2389 {
2390 cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
2391 cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
2392 cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );
2393
2394 for( j = 0; j < cv_n; j++ )
2395 {
2396 cv_sum[j] = cv_sum2[j] = 0.;
2397 cv_count[j] = 0;
2398 }
2399
2400 for( i = 0; i < n; i++ )
2401 {
2402 j = cv_labels[i];
2403 double t = values[i];
2404 double s = cv_sum[j] + t;
2405 double s2 = cv_sum2[j] + t*t;
2406 int nc = cv_count[j] + 1;
2407 cv_sum[j] = s;
2408 cv_sum2[j] = s2;
2409 cv_count[j] = nc;
2410 }
2411
2412 for( j = 0; j < cv_n; j++ )
2413 {
2414 sum += cv_sum[j];
2415 sum2 += cv_sum2[j];
2416 }
2417 }
2418
2419 node->node_risk = sum2 - (sum/n)*sum;
2420 node->value = sum/n;
2421
2422 for( j = 0; j < cv_n; j++ )
2423 {
2424 double s = cv_sum[j], si = sum - s;
2425 double s2 = cv_sum2[j], s2i = sum2 - s2;
2426 int c = cv_count[j], ci = n - c;
2427 double r = si/MAX(ci,1);
2428 node->cv_node_risk[j] = s2i - r*r*ci;
2429 node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2430 node->cv_Tn[j] = INT_MAX;
2431 }
2432 }
2433 }
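
// For regression the node statistics reduce to two running sums: with
// sum = sum_i(Y_i) and sum2 = sum_i(Y_i^2), the node value is sum/n and the
// node risk is sum2 - sum^2/n = sum_i((Y_i - value)^2). The same identity in
// isolation (sketch):
#if 0
static void regression_node_stats( const float* y, int n,
                                   double* value, double* risk )
{
    double sum = 0, sum2 = 0;
    int i;
    for( i = 0; i < n; i++ )
    {
        sum += y[i];
        sum2 += (double)y[i]*y[i];
    }
    *value = sum/n;             // mean response
    *risk = sum2 - (sum/n)*sum; // total squared deviation from the mean
}
#endif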
2434
2435
2436 void CvDTree::complete_node_dir( CvDTreeNode* node )
2437 {
2438 int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2439 int nz = n - node->get_num_valid(node->split->var_idx);
2440 char* dir = (char*)data->direction->data.ptr;
2441
2442 // try to complete direction using surrogate splits
2443 if( nz && data->params.use_surrogates )
2444 {
2445 CvDTreeSplit* split = node->split->next;
2446 for( ; split != 0 && nz; split = split->next )
2447 {
2448 int inversed_mask = split->inversed ? -1 : 0;
2449 vi = split->var_idx;
2450
2451 if( data->get_var_type(vi) >= 0 ) // split on categorical var
2452 {
2453 const int* labels = data->get_cat_var_data(node, vi);
2454 const int* subset = split->subset;
2455
2456 for( i = 0; i < n; i++ )
2457 {
2458 int idx;
2459 if( !dir[i] && (idx = labels[i]) >= 0 )
2460 {
2461 int d = CV_DTREE_CAT_DIR(idx,subset);
2462 dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2463 if( !--nz )
2464 break;
2465 }
2466 }
2467 }
2468 else // split on ordered var
2469 {
2470 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
2471 int split_point = split->ord.split_point;
2472 int n1 = node->get_num_valid(vi);
2473
2474 assert( 0 <= split_point && split_point < n-1 );
2475
2476 for( i = 0; i < n1; i++ )
2477 {
2478 int idx = sorted[i].i;
2479 if( !dir[idx] )
2480 {
2481 int d = i <= split_point ? -1 : 1;
2482 dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
2483 if( !--nz )
2484 break;
2485 }
2486 }
2487 }
2488 }
2489 }
2490
2491 // find the default direction for the rest
2492 if( nz )
2493 {
2494 for( i = nr = 0; i < n; i++ )
2495 nr += dir[i] > 0;
2496 nl = n - nr - nz;
2497 d0 = nl > nr ? -1 : nr > nl;
2498 }
2499
2500 // make sure that every sample is directed either to the left or to the right
2501 for( i = 0; i < n; i++ )
2502 {
2503 int d = dir[i];
2504 if( !d )
2505 {
2506 d = d0;
2507 if( !d )
2508 d = d1, d1 = -d1;
2509 }
2510 d = d > 0;
2511 dir[i] = (char)d; // remap (-1,1) to (0,1)
2512 }
2513 }
2514
2515
2516 void CvDTree::split_node_data( CvDTreeNode* node )
2517 {
2518 int vi, i, n = node->sample_count, nl, nr;
2519 char* dir = (char*)data->direction->data.ptr;
2520 CvDTreeNode *left = 0, *right = 0;
2521 int* new_idx = data->split_buf->data.i;
2522 int new_buf_idx = data->get_child_buf_idx( node );
2523 int work_var_count = data->get_work_var_count();
2524
2525 // speed things up a little, especially for tree ensembles with lots of small trees:
2526 // do not physically split the input data between the left and right child nodes
2527 // when we are not going to split them further,
2528 // as calc_node_value() does not require the input features anyway.
2529 bool split_input_data;
2530
2531 complete_node_dir(node);
2532
2533 for( i = nl = nr = 0; i < n; i++ )
2534 {
2535 int d = dir[i];
2536 // initialize new indices for splitting ordered variables
2537 new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
2538 nr += d;
2539 nl += d^1;
2540 }
2541
2542 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
2543 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
2544 (data->ord_var_count + work_var_count)*nl );
2545
2546 split_input_data = node->depth + 1 < data->params.max_depth &&
2547 (node->left->sample_count > data->params.min_sample_count ||
2548 node->right->sample_count > data->params.min_sample_count);
2549
2550 // split ordered variables, keep both halves sorted.
2551 for( vi = 0; vi < data->var_count; vi++ )
2552 {
2553 int ci = data->get_var_type(vi);
2554 int n1 = node->get_num_valid(vi);
2555 CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
2556 CvPair32s32f tl, tr;
2557
2558 if( ci >= 0 || !split_input_data )
2559 continue;
2560
2561 src = data->get_ord_var_data(node, vi);
2562 ldst0 = ldst = data->get_ord_var_data(left, vi);
2563 rdst0 = rdst = data->get_ord_var_data(right, vi);
2564 tl = ldst0[nl]; tr = rdst0[nr];
2565
2566 // split sorted
2567 for( i = 0; i < n1; i++ )
2568 {
2569 int idx = src[i].i;
2570 float val = src[i].val;
2571 int d = dir[idx];
2572 idx = new_idx[idx];
2573 ldst->i = rdst->i = idx;
2574 ldst->val = rdst->val = val;
2575 ldst += d^1;
2576 rdst += d;
2577 }
2578
2579 left->set_num_valid(vi, (int)(ldst - ldst0));
2580 right->set_num_valid(vi, (int)(rdst - rdst0));
2581
2582 // split missing
2583 for( ; i < n; i++ )
2584 {
2585 int idx = src[i].i;
2586 int d = dir[idx];
2587 idx = new_idx[idx];
2588 ldst->i = rdst->i = idx;
2589 ldst->val = rdst->val = ord_nan;
2590 ldst += d^1;
2591 rdst += d;
2592 }
2593
2594 ldst0[nl] = tl; rdst0[nr] = tr;
2595 }
2596
2597 // split categorical vars, responses and cv_labels using new_idx relocation table
2598 for( vi = 0; vi < work_var_count; vi++ )
2599 {
2600 int ci = data->get_var_type(vi);
2601 int n1 = node->get_num_valid(vi), nr1 = 0;
2602 int *src, *ldst0, *rdst0, *ldst, *rdst;
2603 int tl, tr;
2604
2605 if( ci < 0 || (vi < data->var_count && !split_input_data) )
2606 continue;
2607
2608 src = data->get_cat_var_data(node, vi);
2609 ldst0 = ldst = data->get_cat_var_data(left, vi);
2610 rdst0 = rdst = data->get_cat_var_data(right, vi);
2611 tl = ldst0[nl]; tr = rdst0[nr];
2612
2613 for( i = 0; i < n; i++ )
2614 {
2615 int d = dir[i];
2616 int val = src[i];
2617 *ldst = *rdst = val;
2618 ldst += d^1;
2619 rdst += d;
2620 nr1 += (val >= 0)&d;
2621 }
2622
2623 if( vi < data->var_count )
2624 {
2625 left->set_num_valid(vi, n1 - nr1);
2626 right->set_num_valid(vi, nr1);
2627 }
2628
2629 ldst0[nl] = tl; rdst0[nr] = tr;
2630 }
2631
2632 // deallocate the parent node data that is not needed anymore
2633 data->free_node_data(node);
2634 }
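
// The relocation table above is filled branchlessly: with d in {0,1}, d-1 is
// an all-ones mask when d == 0 and zero when d == 1, while -d is the opposite,
// so (nl & (d-1)) | (nr & -d) picks the left counter for left-bound samples
// and the right counter otherwise. The readable equivalent:
#if 0
static int relocated_index( int d, int nl, int nr )
{
    // what (nl & (d-1)) | (nr & -d) computes for d in {0,1}
    return d ? nr : nl;
}
#endif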
2635
2636
2637 void CvDTree::prune_cv()
2638 {
2639 CvMat* ab = 0;
2640 CvMat* temp = 0;
2641 CvMat* err_jk = 0;
2642
2643 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
2644 // 2. choose the best tree index (if needed, apply the 1SE rule).
2645 // 3. store the best index and cut the branches.
2646
2647 CV_FUNCNAME( "CvDTree::prune_cv" );
2648
2649 __BEGIN__;
2650
2651 int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
2652 // currently, 1SE for regression is not implemented
2653 bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
2654 double* err;
2655 double min_err = 0, min_err_se = 0;
2656 int min_idx = -1;
2657
2658 CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
2659
2660 // build the main tree sequence, calculate alpha's
2661 for(;;tree_count++)
2662 {
2663 double min_alpha = update_tree_rnc(tree_count, -1);
2664 if( cut_tree(tree_count, -1, min_alpha) )
2665 break;
2666
2667 if( ab->cols <= tree_count )
2668 {
2669 CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
2670 for( ti = 0; ti < ab->cols; ti++ )
2671 temp->data.db[ti] = ab->data.db[ti];
2672 cvReleaseMat( &ab );
2673 ab = temp;
2674 temp = 0;
2675 }
2676
2677 ab->data.db[tree_count] = min_alpha;
2678 }
2679
2680 ab->data.db[0] = 0.;
2681
2682 if( tree_count > 0 )
2683 {
2684 for( ti = 1; ti < tree_count-1; ti++ )
2685 ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
2686 ab->data.db[tree_count-1] = DBL_MAX*0.5;
2687
2688 CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
2689 err = err_jk->data.db;
2690
2691 for( j = 0; j < cv_n; j++ )
2692 {
2693 int tj = 0, tk = 0;
2694 for( ; tk < tree_count; tj++ )
2695 {
2696 double min_alpha = update_tree_rnc(tj, j);
2697 if( cut_tree(tj, j, min_alpha) )
2698 min_alpha = DBL_MAX;
2699
2700 for( ; tk < tree_count; tk++ )
2701 {
2702 if( ab->data.db[tk] > min_alpha )
2703 break;
2704 err[j*tree_count + tk] = root->tree_error;
2705 }
2706 }
2707 }
2708
2709 for( ti = 0; ti < tree_count; ti++ )
2710 {
2711 double sum_err = 0;
2712 for( j = 0; j < cv_n; j++ )
2713 sum_err += err[j*tree_count + ti];
2714 if( ti == 0 || sum_err < min_err )
2715 {
2716 min_err = sum_err;
2717 min_idx = ti;
2718 if( use_1se )
2719 min_err_se = sqrt( sum_err*(n - sum_err) );
2720 }
2721 else if( sum_err < min_err + min_err_se )
2722 min_idx = ti;
2723 }
2724 }
2725
2726 pruned_tree_idx = min_idx;
2727 free_prune_data(data->params.truncate_pruned_tree != 0);
2728
2729 __END__;
2730
2731 cvReleaseMat( &err_jk );
2732 cvReleaseMat( &ab );
2733 cvReleaseMat( &temp );
2734 }
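
// prune_cv() follows the CART cross-validation pruning recipe: grow the alpha
// sequence on the full tree, smooth it, evaluate each fold's tree sequence at
// those alphas and pick the index with the smallest summed error (optionally
// relaxed by the 1SE rule). The breakpoint smoothing used above, in isolation
// (sketch):
#if 0
static void interpolate_alphas( double* alpha, int count )
{
    // replace each alpha by the geometric mean of consecutive breakpoints so
    // that every tested value lies strictly between two of them
    int i;
    for( i = 1; i < count - 1; i++ )
        alpha[i] = sqrt( alpha[i]*alpha[i+1] );
    alpha[count-1] = DBL_MAX*0.5;
}
#endif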
2735
2736
2737 double CvDTree::update_tree_rnc( int T, int fold )
2738 {
2739 CvDTreeNode* node = root;
2740 double min_alpha = DBL_MAX;
2741
2742 for(;;)
2743 {
2744 CvDTreeNode* parent;
2745 for(;;)
2746 {
2747 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2748 if( t <= T || !node->left )
2749 {
2750 node->complexity = 1;
2751 node->tree_risk = node->node_risk;
2752 node->tree_error = 0.;
2753 if( fold >= 0 )
2754 {
2755 node->tree_risk = node->cv_node_risk[fold];
2756 node->tree_error = node->cv_node_error[fold];
2757 }
2758 break;
2759 }
2760 node = node->left;
2761 }
2762
2763 for( parent = node->parent; parent && parent->right == node;
2764 node = parent, parent = parent->parent )
2765 {
2766 parent->complexity += node->complexity;
2767 parent->tree_risk += node->tree_risk;
2768 parent->tree_error += node->tree_error;
2769
2770 parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
2771 - parent->tree_risk)/(parent->complexity - 1);
2772 min_alpha = MIN( min_alpha, parent->alpha );
2773 }
2774
2775 if( !parent )
2776 break;
2777
2778 parent->complexity = node->complexity;
2779 parent->tree_risk = node->tree_risk;
2780 parent->tree_error = node->tree_error;
2781 node = parent->right;
2782 }
2783
2784 return min_alpha;
2785 }
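
// For every internal node t the loop computes the cost-complexity ratio
// alpha(t) = (R(t) - R(T_t))/(|T_t| - 1), where R(t) is the risk of t taken as
// a leaf, R(T_t) the summed risk of the subtree rooted at t, and |T_t| its
// leaf count ("complexity"); the smallest alpha marks the weakest link to cut
// next. The formula in isolation (sketch):
#if 0
static double weakest_link_alpha( double node_risk, double subtree_risk,
                                  int subtree_leaves )
{
    // subtree_leaves > 1 is assumed (internal node)
    return (node_risk - subtree_risk)/(subtree_leaves - 1);
}
#endif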
2786
2787
2788 int CvDTree::cut_tree( int T, int fold, double min_alpha )
2789 {
2790 CvDTreeNode* node = root;
2791 if( !node->left )
2792 return 1;
2793
2794 for(;;)
2795 {
2796 CvDTreeNode* parent;
2797 for(;;)
2798 {
2799 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
2800 if( t <= T || !node->left )
2801 break;
2802 if( node->alpha <= min_alpha + FLT_EPSILON )
2803 {
2804 if( fold >= 0 )
2805 node->cv_Tn[fold] = T;
2806 else
2807 node->Tn = T;
2808 if( node == root )
2809 return 1;
2810 break;
2811 }
2812 node = node->left;
2813 }
2814
2815 for( parent = node->parent; parent && parent->right == node;
2816 node = parent, parent = parent->parent )
2817 ;
2818
2819 if( !parent )
2820 break;
2821
2822 node = parent->right;
2823 }
2824
2825 return 0;
2826 }
2827
2828
2829 void CvDTree::free_prune_data(bool cut_tree)
2830 {
2831 CvDTreeNode* node = root;
2832
2833 for(;;)
2834 {
2835 CvDTreeNode* parent;
2836 for(;;)
2837 {
2838 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
2839 // as we will clear the whole cross-validation heap at the end
2840 node->cv_Tn = 0;
2841 node->cv_node_error = node->cv_node_risk = 0;
2842 if( !node->left )
2843 break;
2844 node = node->left;
2845 }
2846
2847 for( parent = node->parent; parent && parent->right == node;
2848 node = parent, parent = parent->parent )
2849 {
2850 if( cut_tree && parent->Tn <= pruned_tree_idx )
2851 {
2852 data->free_node( parent->left );
2853 data->free_node( parent->right );
2854 parent->left = parent->right = 0;
2855 }
2856 }
2857
2858 if( !parent )
2859 break;
2860
2861 node = parent->right;
2862 }
2863
2864 if( data->cv_heap )
2865 cvClearSet( data->cv_heap );
2866 }
2867
2868
2869 void CvDTree::free_tree()
2870 {
2871 if( root && data && data->shared )
2872 {
2873 pruned_tree_idx = INT_MIN;
2874 free_prune_data(true);
2875 data->free_node(root);
2876 root = 0;
2877 }
2878 }
2879
2880
2881 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
2882 const CvMat* _missing, bool preprocessed_input ) const
2883 {
2884 CvDTreeNode* result = 0;
2885 int* catbuf = 0;
2886
2887 CV_FUNCNAME( "CvDTree::predict" );
2888
2889 __BEGIN__;
2890
2891 int i, step, mstep = 0;
2892 const float* sample;
2893 const uchar* m = 0;
2894 CvDTreeNode* node = root;
2895 const int* vtype;
2896 const int* vidx;
2897 const int* cmap;
2898 const int* cofs;
2899
2900 if( !node )
2901 CV_ERROR( CV_StsError, "The tree has not been trained yet" );
2902
2903 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
2904 _sample->cols != 1 && _sample->rows != 1 ||
2905 _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
2906 _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
2907 CV_ERROR( CV_StsBadArg,
2908 "the input sample must be a 1D floating-point vector with the same "
2909 "number of elements as the total number of variables used for training" );
2910
2911 sample = _sample->data.fl;
2912 step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
2913
2914 if( data->cat_count && !preprocessed_input ) // cache for categorical variables
2915 {
2916 int n = data->cat_count->cols;
2917 catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
2918 for( i = 0; i < n; i++ )
2919 catbuf[i] = -1;
2920 }
2921
2922 if( _missing )
2923 {
2924 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
2925 !CV_ARE_SIZES_EQ(_missing, _sample) )
2926 CV_ERROR( CV_StsBadArg,
2927 "the missing data mask must be an 8-bit vector of the same size as the input sample" );
2928 m = _missing->data.ptr;
2929 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
2930 }
2931
2932 vtype = data->var_type->data.i;
2933 vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
2934 cmap = data->cat_map ? data->cat_map->data.i : 0;
2935 cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
2936
2937 while( node->Tn > pruned_tree_idx && node->left )
2938 {
2939 CvDTreeSplit* split = node->split;
2940 int dir = 0;
2941 for( ; !dir && split != 0; split = split->next )
2942 {
2943 int vi = split->var_idx;
2944 int ci = vtype[vi];
2945 i = vidx ? vidx[vi] : vi;
2946 float val = sample[i*step];
2947 if( m && m[i*mstep] )
2948 continue;
2949 if( ci < 0 ) // ordered
2950 dir = val <= split->ord.c ? -1 : 1;
2951 else // categorical
2952 {
2953 int c;
2954 if( preprocessed_input )
2955 c = cvRound(val);
2956 else
2957 {
2958 c = catbuf[ci];
2959 if( c < 0 )
2960 {
2961 int a = c = cofs[ci];
2962 int b = cofs[ci+1];
2963 int ival = cvRound(val);
2964 if( ival != val )
2965 CV_ERROR( CV_StsBadArg,
2966 "one of the input categorical variables is not an integer" );
2967
2968 while( a < b )
2969 {
2970 c = (a + b) >> 1;
2971 if( ival < cmap[c] )
2972 b = c;
2973 else if( ival > cmap[c] )
2974 a = c+1;
2975 else
2976 break;
2977 }
2978
2979 if( c < 0 || ival != cmap[c] )
2980 continue;
2981
2982 catbuf[ci] = c -= cofs[ci];
2983 }
2984 }
2985 dir = CV_DTREE_CAT_DIR(c, split->subset);
2986 }
2987
2988 if( split->inversed )
2989 dir = -dir;
2990 }
2991
2992 if( !dir )
2993 {
2994 double diff = node->right->sample_count - node->left->sample_count;
2995 dir = diff < 0 ? -1 : 1;
2996 }
2997 node = dir < 0 ? node->left : node->right;
2998 }
2999
3000 result = node;
3001
3002 __END__;
3003
3004 return result;
3005 }
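
// A typical call sequence around predict(): train on a row-per-sample float
// matrix, then read the prediction off the returned leaf node. A hedged usage
// sketch (the matrices and their shapes are assumed, not taken from this
// file):
#if 0
static double predict_demo( const CvMat* train_data, const CvMat* responses,
                            const CvMat* sample )
{
    CvDTree tree;
    double prediction = 0;
    if( tree.train( train_data, CV_ROW_SAMPLE, responses ))
    {
        // sample: 1 x var_count CV_32FC1 vector, no missing-value mask
        CvDTreeNode* leaf = tree.predict( sample, 0, false );
        prediction = leaf->value; // class label or regression value
    }
    return prediction;
}
#endif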
3006
3007
3008 const CvMat* CvDTree::get_var_importance()
3009 {
3010 if( !var_importance )
3011 {
3012 CvDTreeNode* node = root;
3013 double* importance;
3014 if( !node )
3015 return 0;
3016 var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3017 cvZero( var_importance );
3018 importance = var_importance->data.db;
3019
3020 for(;;)
3021 {
3022 CvDTreeNode* parent;
3023 for( ;; node = node->left )
3024 {
3025 CvDTreeSplit* split = node->split;
3026
3027 if( !node->left || node->Tn <= pruned_tree_idx )
3028 break;
3029
3030 for( ; split != 0; split = split->next )
3031 importance[split->var_idx] += split->quality;
3032 }
3033
3034 for( parent = node->parent; parent && parent->right == node;
3035 node = parent, parent = parent->parent )
3036 ;
3037
3038 if( !parent )
3039 break;
3040
3041 node = parent->right;
3042 }
3043
3044 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3045 }
3046
3047 return var_importance;
3048 }
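
// The importance matrix is 1 x var_count, CV_64F, and is L1-normalized above,
// so the entries sum to 1 and can be read as relative shares. Reading it out
// (sketch):
#if 0
static void print_var_importance( CvDTree& tree )
{
    const CvMat* imp = tree.get_var_importance();
    if( imp )
    {
        int i;
        for( i = 0; i < imp->cols; i++ )
            printf( "var %d: %.3f\n", i, imp->data.db[i] );
    }
}
#endif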
3049
3050
3051 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
3052 {
3053 int ci;
3054
3055 cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3056 cvWriteInt( fs, "var", split->var_idx );
3057 cvWriteReal( fs, "quality", split->quality );
3058
3059 ci = data->get_var_type(split->var_idx);
3060 if( ci >= 0 ) // split on a categorical var
3061 {
3062 int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3063 for( i = 0; i < n; i++ )
3064 to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3065
3066 // ad-hoc rule for when to use the inverse categorical split notation
3067 // to achieve a more compact and clearer representation
3068 default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3069
3070 cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3071 "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3072
3073 for( i = 0; i < n; i++ )
3074 {
3075 int dir = CV_DTREE_CAT_DIR(i,split->subset);
3076 if( dir*default_dir < 0 )
3077 cvWriteInt( fs, 0, i );
3078 }
3079 cvEndWriteStruct( fs );
3080 }
3081 else
3082 cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3083
3084 cvEndWriteStruct( fs );
3085 }
3086
3087
3088 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
3089 {
3090 CvDTreeSplit* split;
3091
3092 cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3093
3094 cvWriteInt( fs, "depth", node->depth );
3095 cvWriteInt( fs, "sample_count", node->sample_count );
3096 cvWriteReal( fs, "value", node->value );
3097
3098 if( data->is_classifier )
3099 cvWriteInt( fs, "norm_class_idx", node->class_idx );
3100
3101 cvWriteInt( fs, "Tn", node->Tn );
3102 cvWriteInt( fs, "complexity", node->complexity );
3103 cvWriteReal( fs, "alpha", node->alpha );
3104 cvWriteReal( fs, "node_risk", node->node_risk );
3105 cvWriteReal( fs, "tree_risk", node->tree_risk );
3106 cvWriteReal( fs, "tree_error", node->tree_error );
3107
3108 if( node->left )
3109 {
3110 cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3111
3112 for( split = node->split; split != 0; split = split->next )
3113 write_split( fs, split );
3114
3115 cvEndWriteStruct( fs );
3116 }
3117
3118 cvEndWriteStruct( fs );
3119 }
3120
3121
3122 void CvDTree::write_tree_nodes( CvFileStorage* fs )
3123 {
3124 //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3125
3126 __BEGIN__;
3127
3128 CvDTreeNode* node = root;
3129
3130 // traverse the tree and save all the nodes in depth-first order
3131 for(;;)
3132 {
3133 CvDTreeNode* parent;
3134 for(;;)
3135 {
3136 write_node( fs, node );
3137 if( !node->left )
3138 break;
3139 node = node->left;
3140 }
3141
3142 for( parent = node->parent; parent && parent->right == node;
3143 node = parent, parent = parent->parent )
3144 ;
3145
3146 if( !parent )
3147 break;
3148
3149 node = parent->right;
3150 }
3151
3152 __END__;
3153 }
3154
3155
3156 void CvDTree::write( CvFileStorage* fs, const char* name )
3157 {
3158 //CV_FUNCNAME( "CvDTree::write" );
3159
3160 __BEGIN__;
3161
3162 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3163
3164 get_var_importance();
3165 data->write_params( fs );
3166 if( var_importance )
3167 cvWrite( fs, "var_importance", var_importance );
3168 write( fs );
3169
3170 cvEndWriteStruct( fs );
3171
3172 __END__;
3173 }
3174
3175
3176 void CvDTree::write( CvFileStorage* fs )
3177 {
3178 //CV_FUNCNAME( "CvDTree::write" );
3179
3180 __BEGIN__;
3181
3182 cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3183
3184 cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3185 write_tree_nodes( fs );
3186 cvEndWriteStruct( fs );
3187
3188 __END__;
3189 }
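
// Both write() overloads are normally reached through CvStatModel::save(),
// which opens the file storage and calls write( fs, name ); load() performs
// the symmetric read(). A round-trip sketch (the file name is made up):
#if 0
static void save_load_demo( CvDTree& trained_tree )
{
    trained_tree.save( "tree.yml" ); // ends up in write( fs, name )
    CvDTree restored;
    restored.load( "tree.yml" );     // ends up in read( fs, fnode )
}
#endif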
3190
3191
3192 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3193 {
3194 CvDTreeSplit* split = 0;
3195
3196 CV_FUNCNAME( "CvDTree::read_split" );
3197
3198 __BEGIN__;
3199
3200 int vi, ci;
3201
3202 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3203 CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3204
3205 vi = cvReadIntByName( fs, fnode, "var", -1 );
3206 if( (unsigned)vi >= (unsigned)data->var_count )
3207 CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3208
3209 ci = data->get_var_type(vi);
3210 if( ci >= 0 ) // split on categorical var
3211 {
3212 int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3213 CvSeqReader reader;
3214 CvFileNode* inseq;
3215 split = data->new_split_cat( vi, 0 );
3216 inseq = cvGetFileNodeByName( fs, fnode, "in" );
3217 if( !inseq )
3218 {
3219 inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3220 inversed = 1;
3221 }
3222 if( !inseq ||
3223 (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3224 CV_ERROR( CV_StsParseError,
3225 "either an 'in' or a 'not_in' tag must be present in categorical split data" );
3226
3227 if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3228 {
3229 val = inseq->data.i;
3230 if( (unsigned)val >= (unsigned)n )
3231 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3232
3233 split->subset[val >> 5] |= 1 << (val & 31);
3234 }
3235 else
3236 {
3237 cvStartReadSeq( inseq->data.seq, &reader );
3238
3239 for( i = 0; i < reader.seq->total; i++ )
3240 {
3241 CvFileNode* inode = (CvFileNode*)reader.ptr;
3242 val = inode->data.i;
3243 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3244 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3245
3246 split->subset[val >> 5] |= 1 << (val & 31);
3247 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3248 }
3249 }
3250
3251 // for categorical splits we do not use inversed splits;
3252 // instead we invert the category subset stored in the split
3253 if( inversed )
3254 for( i = 0; i < (n + 31) >> 5; i++ )
3255 split->subset[i] ^= -1;
3256 }
3257 else
3258 {
3259 CvFileNode* cmp_node;
3260 split = data->new_split_ord( vi, 0, 0, 0, 0 );
3261
3262 cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3263 if( !cmp_node )
3264 {
3265 cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
3266 split->inversed = 1;
3267 }
3268
3269 split->ord.c = (float)cvReadReal( cmp_node );
3270 }
3271
3272 split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
3273
3274 __END__;
3275
3276 return split;
3277 }
3278
3279
3280 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
3281 {
3282 CvDTreeNode* node = 0;
3283
3284 CV_FUNCNAME( "CvDTree::read_node" );
3285
3286 __BEGIN__;
3287
3288 CvFileNode* splits;
3289 int i, depth;
3290
3291 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3292 CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
3293
3294 CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
3295 depth = cvReadIntByName( fs, fnode, "depth", -1 );
3296 if( depth != node->depth )
3297 CV_ERROR( CV_StsParseError, "incorrect node depth" );
3298
3299 node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
3300 node->value = cvReadRealByName( fs, fnode, "value" );
3301 if( data->is_classifier )
3302 node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
3303
3304 node->Tn = cvReadIntByName( fs, fnode, "Tn" );
3305 node->complexity = cvReadIntByName( fs, fnode, "complexity" );
3306 node->alpha = cvReadRealByName( fs, fnode, "alpha" );
3307 node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
3308 node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
3309 node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
3310
3311 splits = cvGetFileNodeByName( fs, fnode, "splits" );
3312 if( splits )
3313 {
3314 CvSeqReader reader;
3315 CvDTreeSplit* last_split = 0;
3316
3317 if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
3318 CV_ERROR( CV_StsParseError, "splits tag must be stored as a sequence" );
3319
3320 cvStartReadSeq( splits->data.seq, &reader );
3321 for( i = 0; i < reader.seq->total; i++ )
3322 {
3323 CvDTreeSplit* split;
3324 CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
3325 if( !last_split )
3326 node->split = last_split = split;
3327 else
3328 last_split = last_split->next = split;
3329
3330 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3331 }
3332 }
3333
3334 __END__;
3335
3336 return node;
3337 }
3338
3339
3340 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
3341 {
3342 CV_FUNCNAME( "CvDTree::read_tree_nodes" );
3343
3344 __BEGIN__;
3345
3346 CvSeqReader reader;
3347 CvDTreeNode _root;
3348 CvDTreeNode* parent = &_root;
3349 int i;
3350 parent->left = parent->right = parent->parent = 0;
3351
3352 cvStartReadSeq( fnode->data.seq, &reader );
3353
3354 for( i = 0; i < reader.seq->total; i++ )
3355 {
3356 CvDTreeNode* node;
3357
3358 CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
3359 if( !parent->left )
3360 parent->left = node;
3361 else
3362 parent->right = node;
3363 if( node->split )
3364 parent = node;
3365 else
3366 {
3367 while( parent && parent->right )
3368 parent = parent->parent;
3369 }
3370
3371 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3372 }
3373
3374 root = _root.left;
3375
3376 __END__;
3377 }
3378
3379
3380 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
3381 {
3382 CvDTreeTrainData* _data = new CvDTreeTrainData();
3383 _data->read_params( fs, fnode );
3384
3385 read( fs, fnode, _data );
3386 get_var_importance();
3387 }
3388
3389
3390 // a special entry point for reading weak decision trees from the tree ensembles
3391 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
3392 {
3393 CV_FUNCNAME( "CvDTree::read" );
3394
3395 __BEGIN__;
3396
3397 CvFileNode* tree_nodes;
3398
3399 clear();
3400 data = _data;
3401
3402 tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
3403 if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
3404 CV_ERROR( CV_StsParseError, "nodes tag is missing" );
3405
3406 pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
3407 read_tree_nodes( fs, tree_nodes );
3408
3409 __END__;
3410 }
3411
3412 /* End of file. */
3413