/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistributions of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistributions in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "old_ml_precomp.hpp"
#include <ctype.h>

using namespace cv;

static const float ord_nan = FLT_MAX*0.5f;
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;

CvDTreeTrainData::CvDTreeTrainData()
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
    buf = 0;
    tree_storage = temp_storage = 0;

    clear();
}


CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, const CvDTreeParams& _params,
                      bool _shared, bool _add_labels )
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
    buf = 0;

    tree_storage = temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
}


CvDTreeTrainData::~CvDTreeTrainData()
{
    clear();
}


bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvDTreeTrainData::set_params" );

    __BEGIN__;

    // set parameters
    params = _params;

    if( params.max_categories < 2 )
        CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
    params.max_categories = MIN( params.max_categories, 15 );

    if( params.max_depth < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    params.max_depth = MIN( params.max_depth, 25 );

    params.min_sample_count = MAX(params.min_sample_count,1);

    if( params.cv_folds < 0 )
        CV_ERROR( CV_StsOutOfRange,
            "params.cv_folds should be =0 (the tree is not pruned) "
            "or n>0 (tree is pruned using n-fold cross-validation)" );

    if( params.cv_folds == 1 )
        params.cv_folds = 0;

    if( params.regression_accuracy < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );

    ok = true;

    __END__;

    return ok;
}
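
// Illustrative sketch (not called anywhere in this file): set_params() both
// validates and silently clamps its input, so the effective parameters may
// differ from the requested ones. The values below are assumptions chosen to
// demonstrate the clamping, not recommended settings.
static inline bool example_set_params_clamping( CvDTreeTrainData& train_data )
{
    CvDTreeParams p;
    p.max_categories = 100;  // silently clamped to 15
    p.max_depth = 1000;      // silently clamped to 25
    p.min_sample_count = 0;  // raised to 1
    p.cv_folds = 1;          // folded to 0, i.e. no cross-validation pruning
    return train_data.set_params( p );
}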

template<typename T>
class LessThanPtr
{
public:
    bool operator()(T* a, T* b) const { return *a < *b; }
};

template<typename T, typename Idx>
class LessThanIdx
{
public:
    LessThanIdx( const T* _arr ) : arr(_arr) {}
    bool operator()(Idx a, Idx b) const { return arr[a] < arr[b]; }
    const T* arr;
};

class LessThanPairs
{
public:
    bool operator()(const CvPair16u32s& a, const CvPair16u32s& b) const { return *a.i < *b.i; }
};
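
// Illustrative sketch (not used by the library code): LessThanIdx turns
// std::sort into an "argsort" -- it orders an index array by the values the
// indices refer to, which is exactly how ordered predictors are sorted in
// set_data() below. The function and data here are hypothetical.
static inline void example_argsort_with_LessThanIdx()
{
    const float vals[] = { 3.f, 1.f, 2.f };
    int order[] = { 0, 1, 2 };
    // After the call, order == {1, 2, 0}: indices of vals in increasing order.
    std::sort( order, order + 3, LessThanIdx<float, int>(vals) );
}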

void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    bool _shared, bool _add_labels, bool _update_data )
{
    CvMat* sample_indices = 0;
    CvMat* var_type0 = 0;
    CvMat* tmp_map = 0;
    int** int_ptr = 0;
    CvPair16u32s* pair16u32s_ptr = 0;
    CvDTreeTrainData* data = 0;
    float *_fdst = 0;
    int *_idst = 0;
    unsigned short* udst = 0;
    int* idst = 0;

    CV_FUNCNAME( "CvDTreeTrainData::set_data" );

    __BEGIN__;

    int sample_all = 0, r_type, cv_n;
    int total_c_count = 0;
    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    int vi, i, size;
    char err[100];
    const int *sidx = 0, *vidx = 0;

    uint64 effective_buf_size = 0;
    int effective_buf_height = 0, effective_buf_width = 0;

    if( _update_data && data_root )
    {
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );

        // compare new and old train data
        if( !(data->var_count == var_count &&
            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
            CV_ERROR( CV_StsBadArg,
                "The new training data must have the same types of input and output variables "
                "and the same categories for categorical variables" );

        cvReleaseMat( &priors );
        cvReleaseMat( &priors_mult );
        cvReleaseMat( &buf );
        cvReleaseMat( &direction );
        cvReleaseMat( &split_buf );
        cvReleaseMemStorage( &temp_storage );

        priors = data->priors; data->priors = 0;
        priors_mult = data->priors_mult; data->priors_mult = 0;
        buf = data->buf; data->buf = 0;
        buf_count = data->buf_count; buf_size = data->buf_size;
        sample_count = data->sample_count;

        direction = data->direction; data->direction = 0;
        split_buf = data->split_buf; data->split_buf = 0;
        temp_storage = data->temp_storage; data->temp_storage = 0;
        nv_heap = data->nv_heap; cv_heap = data->cv_heap;

        data_root = new_node( 0, sample_count, 0, 0 );
        EXIT;
    }

    clear();

    var_all = 0;
    rng = &cv::theRNG();

    CV_CALL( set_params( _params ));

    // check parameter types and sizes
    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));

    train_data = _train_data;
    responses = _responses;

    if( _tflag == CV_ROW_SAMPLE )
    {
        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        dv_step = 1;
        if( _missing_mask )
            ms_step = _missing_mask->step, mv_step = 1;
    }
    else
    {
        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        ds_step = 1;
        if( _missing_mask )
            mv_step = _missing_mask->step, ms_step = 1;
    }
    tflag = _tflag;

    sample_count = sample_all;
    var_count = var_all;

    if( _sample_idx )
    {
        CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
        sidx = sample_indices->data.i;
        sample_count = sample_indices->rows + sample_indices->cols - 1;
    }

    if( _var_idx )
    {
        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
        vidx = var_idx->data.i;
        var_count = var_idx->rows + var_idx->cols - 1;
    }

    is_buf_16u = false;
    if ( sample_count < 65536 )
        is_buf_16u = true;

    if( !CV_IS_MAT(_responses) ||
        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
        (_responses->rows != 1 && _responses->cols != 1) ||
        _responses->rows + _responses->cols - 1 != sample_all )
        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
                  "floating-point vector containing as many elements as "
                  "the total number of samples in the training data matrix" );

    r_type = CV_VAR_CATEGORICAL;
    if( _var_type )
        CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));

    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;

    is_classifier = r_type == CV_VAR_CATEGORICAL;

    // step 0. calc the number of categorical vars
    for( vi = 0; vi < var_count; vi++ )
    {
        char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
        var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
    }

    ord_var_count = ~ord_var_count;
    cv_n = params.cv_folds;
    // set the two last elements of var_type array to be able
    // to locate responses and cross-validation labels using
    // the corresponding get_* functions.
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // in case of single ordered predictor we need dummy cv_labels
    // for safe split_node_data() operation
    have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;

    work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
                               + (have_labels ? 1 : 0);  // for cv_labels

    shared = _shared;
    buf_count = shared ? 2 : 1;

    buf_size = -1; // the member buf_size is obsolete

    effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
    effective_buf_width = sample_count;
    effective_buf_height = work_var_count+1;

    if (effective_buf_width >= effective_buf_height)
        effective_buf_height *= buf_count;
    else
        effective_buf_width *= buf_count;

    if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
    {
        CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
    }

    if ( is_buf_16u )
    {
        CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
        CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
    }
    else
    {
        CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
        CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
    }

    size = is_classifier ? (cat_var_count+1) : cat_var_count;
    size = !size ? 1 : size;
    CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
    CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));

    size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
    size = !size ? 1 : size;
    CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));

    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));

    nv_size = var_count*sizeof(int);
    nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));

    temp_block_size = nv_size;

    if( cv_n )
    {
        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
            CV_ERROR( CV_StsOutOfRange,
                "Too many folds in cross-validation for such a small dataset" );

        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
        temp_block_size = MAX(temp_block_size, cv_size);
    }

    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    if( cv_size )
        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));

    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));

    max_c_count = 1;

    _fdst = 0;
    _idst = 0;
    if (ord_var_count)
        _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
    if (is_buf_16u && (cat_var_count || is_classifier))
        _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));

    // transform the training data to convenient representation
    for( vi = 0; vi <= var_count; vi++ )
    {
        int ci;
        const uchar* mask = 0;
        int64 m_step = 0, step;
        const int* idata = 0;
        const float* fdata = 0;
        int num_valid = 0;

        if( vi < var_count ) // analyze i-th input variable
        {
            int vi0 = vidx ? vidx[vi] : vi;
            ci = get_var_type(vi);
            step = ds_step; m_step = ms_step;
            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
                idata = _train_data->data.i + vi0*dv_step;
            else
                fdata = _train_data->data.fl + vi0*dv_step;
            if( _missing_mask )
                mask = _missing_mask->data.ptr + vi0*mv_step;
        }
        else // analyze _responses
        {
            ci = cat_var_count;
            step = CV_IS_MAT_CONT(_responses->type) ?
                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
                idata = _responses->data.i;
            else
                fdata = _responses->data.fl;
        }

        if( (vi < var_count && ci>=0) ||
            (vi == var_count && is_classifier) ) // process categorical variable or response
        {
            int c_count, prev_label;
            int* c_map;

            if (is_buf_16u)
                udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
            else
                idst = buf->data.i + (size_t)vi*sample_count;

            // copy data
            for( i = 0; i < sample_count; i++ )
            {
                int val = INT_MAX, si = sidx ? sidx[i] : i;
                if( !mask || !mask[(size_t)si*m_step] )
                {
                    if( idata )
                        val = idata[(size_t)si*step];
                    else
                    {
                        float t = fdata[(size_t)si*step];
                        val = cvRound(t);
                        if( fabs(t - val) > FLT_EPSILON )
                        {
                            sprintf( err, "%d-th value of %d-th (categorical) "
                                "variable is not an integer", i, vi );
                            CV_ERROR( CV_StsBadArg, err );
                        }
                    }

                    if( val == INT_MAX )
                    {
                        sprintf( err, "%d-th value of %d-th (categorical) "
                            "variable is too large", i, vi );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }
                if (is_buf_16u)
                {
                    _idst[i] = val;
                    pair16u32s_ptr[i].u = udst + i;
                    pair16u32s_ptr[i].i = _idst + i;
                }
                else
                {
                    idst[i] = val;
                    int_ptr[i] = idst + i;
                }
            }

            c_count = num_valid > 0;
            if (is_buf_16u)
            {
                std::sort(pair16u32s_ptr, pair16u32s_ptr + sample_count, LessThanPairs());
                // count the categories
                for( i = 1; i < num_valid; i++ )
                    if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
                        c_count++;
            }
            else
            {
                std::sort(int_ptr, int_ptr + sample_count, LessThanPtr<int>());
                // count the categories
                for( i = 1; i < num_valid; i++ )
                    c_count += *int_ptr[i] != *int_ptr[i-1];
            }

            if( vi > 0 )
                max_c_count = MAX( max_c_count, c_count );
            cat_count->data.i[ci] = c_count;
            cat_ofs->data.i[ci] = total_c_count;

            // resize cat_map, if needed
            if( cat_map->cols < total_c_count + c_count )
            {
                tmp_map = cat_map;
                CV_CALL( cat_map = cvCreateMat( 1,
                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
                for( i = 0; i < total_c_count; i++ )
                    cat_map->data.i[i] = tmp_map->data.i[i];
                cvReleaseMat( &tmp_map );
            }

            c_map = cat_map->data.i + total_c_count;
            total_c_count += c_count;

            c_count = -1;
            if (is_buf_16u)
            {
                // compact the class indices and build the map
                prev_label = ~*pair16u32s_ptr[0].i;
                for( i = 0; i < num_valid; i++ )
                {
                    int cur_label = *pair16u32s_ptr[i].i;
                    if( cur_label != prev_label )
                        c_map[++c_count] = prev_label = cur_label;
                    *pair16u32s_ptr[i].u = (unsigned short)c_count;
                }
                // replace labels for missing values with 65535
                for( ; i < sample_count; i++ )
                    *pair16u32s_ptr[i].u = 65535;
            }
            else
            {
                // compact the class indices and build the map
                prev_label = ~*int_ptr[0];
                for( i = 0; i < num_valid; i++ )
                {
                    int cur_label = *int_ptr[i];
                    if( cur_label != prev_label )
                        c_map[++c_count] = prev_label = cur_label;
                    *int_ptr[i] = c_count;
                }
                // replace labels for missing values with -1
                for( ; i < sample_count; i++ )
                    *int_ptr[i] = -1;
            }
        }
        else if( ci < 0 ) // process ordered variable
        {
            if (is_buf_16u)
                udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
            else
                idst = buf->data.i + (size_t)vi*sample_count;

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ? sidx[i] : i;
                if( !mask || !mask[(size_t)si*m_step] )
                {
                    if( idata )
                        val = (float)idata[(size_t)si*step];
                    else
                        val = fdata[(size_t)si*step];

                    if( fabs(val) >= ord_nan )
                    {
                        sprintf( err, "%d-th value of %d-th (ordered) "
                            "variable (=%g) is too large", i, vi, val );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }

                if (is_buf_16u)
                    udst[i] = (unsigned short)i; // TODO: memory corruption may be here
                else
                    idst[i] = i;
                _fdst[i] = val;
            }
            if (is_buf_16u)
                std::sort(udst, udst + sample_count, LessThanIdx<float, unsigned short>(_fdst));
            else
                std::sort(idst, idst + sample_count, LessThanIdx<float, int>(_fdst));
        }

        if( vi < var_count )
            data_root->set_num_valid(vi, num_valid);
    }

    // set sample labels
    if (is_buf_16u)
        udst = (unsigned short*)(buf->data.s + (size_t)work_var_count*sample_count);
    else
        idst = buf->data.i + (size_t)work_var_count*sample_count;

    for (i = 0; i < sample_count; i++)
    {
        if (udst)
            udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
        else
            idst[i] = sidx ? sidx[i] : i;
    }

    if( cv_n )
    {
        unsigned short* usdst = 0;
        int* idst2 = 0;

        if (is_buf_16u)
        {
            usdst = (unsigned short*)(buf->data.s + (size_t)(get_work_var_count()-1)*sample_count);
            for( i = vi = 0; i < sample_count; i++ )
            {
                usdst[i] = (unsigned short)vi++;
                vi &= vi < cv_n ? -1 : 0;
            }

            for( i = 0; i < sample_count; i++ )
            {
                int a = (*rng)(sample_count);
                int b = (*rng)(sample_count);
                unsigned short unsh = (unsigned short)vi;
                CV_SWAP( usdst[a], usdst[b], unsh );
            }
        }
        else
        {
            idst2 = buf->data.i + (size_t)(get_work_var_count()-1)*sample_count;
            for( i = vi = 0; i < sample_count; i++ )
            {
                idst2[i] = vi++;
                vi &= vi < cv_n ? -1 : 0;
            }

            for( i = 0; i < sample_count; i++ )
            {
                int a = (*rng)(sample_count);
                int b = (*rng)(sample_count);
                CV_SWAP( idst2[a], idst2[b], vi );
            }
        }
    }

    if ( cat_map )
        cat_map->cols = MAX( total_c_count, 1 );

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));

    have_priors = is_classifier && params.priors;
    if( is_classifier )
    {
        int m = get_num_classes();
        double sum = 0;
        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
        for( i = 0; i < m; i++ )
        {
            double val = have_priors ? params.priors[i] : 1.;
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
            priors->data.db[i] = val;
            sum += val;
        }

        // normalize weights
        if( have_priors )
            cvScale( priors, priors, 1./sum );

        CV_CALL( priors_mult = cvCloneMat( priors ));
        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    }

    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));

    __END__;

    if( data )
        delete data;

    if (_fdst)
        cvFree( &_fdst );
    if (_idst)
        cvFree( &_idst );
    cvFree( &int_ptr );
    cvFree( &pair16u32s_ptr);
    cvReleaseMat( &var_type0 );
    cvReleaseMat( &sample_indices );
    cvReleaseMat( &tmp_map );
}
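
// Worked example of the buffer sizing in set_data() (illustrative numbers):
// with sample_count = 1000, var_count = 3, a classification response and
// cv_labels (so work_var_count = 5) and shared buffers (buf_count = 2), the
// effective buffer holds (5 + 1) * 1000 * 2 = 12000 elements. Since the width
// (sample_count = 1000) is >= the height (work_var_count + 1 = 6), the height
// is multiplied by buf_count, giving a 12 x 1000 matrix for "buf".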

void CvDTreeTrainData::do_responses_copy()
{
    responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
    cvCopy( responses, responses_copy);
    responses = responses_copy;
}

CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    bool isMakeRootCopy = true;

    CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
    {
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
        {
            const int* sidx = isubsample_idx->data.i;
            for( int i = 0; i < sample_count; i++ )
            {
                if( sidx[i] != i )
                {
                    isMakeRootCopy = false;
                    break;
                }
            }
        }
        else
            isMakeRootCopy = false;
    }

    if( isMakeRootCopy )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i;
        int workVarCount = get_work_var_count();
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;

        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < sample_count; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
        for( vi = 0; vi < workVarCount; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                int num_valid = 0;
                const int* src = CvDTreeTrainData::get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );

                if (is_buf_16u)
                {
                    unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset);
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        udst[i] = (unsigned short)val;
                        num_valid += val >= 0;
                    }
                }
                else
                {
                    int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        idst[i] = val;
                        num_valid += val >= 0;
                    }
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                int *src_idx_buf = (int*)(uchar*)inn_buf;
                float *src_val_buf = (float*)(src_idx_buf + sample_count);
                int* sample_indices_buf = (int*)(src_val_buf + sample_count);
                const int* src_idx = 0;
                const float* src_val = 0;
                get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                if (is_buf_16u)
                {
                    unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + data_root->offset);
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }
                }
                else
                {
                    int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }
                }
            }
        }
        // sample indices subsampling
        const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
        if (is_buf_16u)
        {
            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset);
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
        }
        else
        {
            int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset;
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = sample_idx_src[sidx[i]];
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}
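
// Illustrative sketch (hypothetical helper): a null index array makes
// subsample_data() return a copy of the root node over all samples, while an
// index vector builds a root over just the referenced samples (duplicates
// allowed, as in bagging).
static inline CvDTreeNode* example_subsample_all( CvDTreeTrainData* data )
{
    return data->subsample_data( 0 ); // 0 => every training sample
}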


void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
                                    float* values, uchar* missing,
                                    float* _responses, bool get_class_idx )
{
    CvMat* subsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );

    __BEGIN__;

    int i, vi, total = sample_count, count = total, cur_ofs = 0;
    int* sidx = 0;
    int* co = 0;

    cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
    if( _subsample_idx )
    {
        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
        sidx = subsample_idx->data.i;
        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        co = subsample_co->data.i;
        cvZero( subsample_co );
        count = subsample_idx->cols + subsample_idx->rows - 1;
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            int count_i = co[i*2];
            if( count_i )
            {
                co[i*2+1] = cur_ofs*var_count;
                cur_ofs += count_i;
            }
        }
    }

    if( missing )
        memset( missing, 1, count*var_count );

    for( vi = 0; vi < var_count; vi++ )
    {
        int ci = get_var_type(vi);
        if( ci >= 0 ) // categorical
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);

            for( i = 0; i < count; i++, dst += var_count )
            {
                int idx = sidx ? sidx[i] : i;
                int val = src[idx];
                *dst = (float)val;
                if( m )
                {
                    *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
                    m += var_count;
                }
            }
        }
        else // ordered
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            int count1 = data_root->get_num_valid(vi);
            float *src_val_buf = (float*)(uchar*)inn_buf;
            int* src_idx_buf = (int*)(src_val_buf + sample_count);
            int* sample_indices_buf = src_idx_buf + sample_count;
            const float *src_val = 0;
            const int* src_idx = 0;
            get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);

            for( i = 0; i < count1; i++ )
            {
                int idx = src_idx[i];
                int count_i = 1;
                if( co )
                {
                    count_i = co[idx*2];
                    cur_ofs = co[idx*2+1];
                }
                else
                    cur_ofs = idx*var_count;
                if( count_i )
                {
                    float val = src_val[i];
                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
                    {
                        dst[cur_ofs] = val;
                        if( m )
                            m[cur_ofs] = 0;
                    }
                }
            }
        }
    }

    // copy responses
    if( _responses )
    {
        if( is_classifier )
        {
            const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                int val = get_class_idx ? src[idx] :
                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
                _responses[i] = (float)val;
            }
        }
        else
        {
            float* val_buf = (float*)(uchar*)inn_buf;
            int* sample_idx_buf = (int*)(val_buf + sample_count);
            const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                _responses[i] = _values[idx];
            }
        }
    }

    __END__;

    cvReleaseMat( &subsample_idx );
    cvReleaseMat( &subsample_co );
}
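
// Illustrative sketch (hypothetical helper and buffers): get_vectors()
// densifies the internal column-wise representation back into row-major float
// vectors, writing var_count values per sample. Here no missing-value mask is
// requested (null pointer) and class indices, rather than the original
// labels, are returned for a classifier.
static inline void example_get_vectors( CvDTreeTrainData* data,
                                        float* values,     // sample_count * var_count floats
                                        float* responses ) // sample_count floats
{
    data->get_vectors( 0, values, 0, responses, true );
}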


CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
{
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    node->sample_count = count;
    node->depth = parent ? parent->depth + 1 : 0;
    node->parent = parent;
    node->left = node->right = 0;
    node->split = 0;
    node->value = 0;
    node->class_idx = 0;
    node->maxlr = 0.;

    node->buf_idx = storage_idx;
    node->offset = offset;
    if( nv_heap )
        node->num_valid = (int*)cvSetNew( nv_heap );
    else
        node->num_valid = 0;
    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    node->complexity = 0;

    if( params.cv_folds > 0 && cv_heap )
    {
        int cv_n = params.cv_folds;
        node->Tn = INT_MAX;
        node->cv_Tn = (int*)cvSetNew( cv_heap );
        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
        node->cv_node_error = node->cv_node_risk + cv_n;
    }
    else
    {
        node->Tn = 0;
        node->cv_Tn = 0;
        node->cv_node_risk = 0;
        node->cv_node_error = 0;
    }

    return node;
}


CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->ord.c = cmp_val;
    split->ord.split_point = split_point;
    split->inversed = inversed;
    split->quality = quality;
    split->next = 0;

    return split;
}


CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    int i, n = (max_c_count + 31)/32;

    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->inversed = 0;
    split->quality = quality;
    for( i = 0; i < n; i++ )
        split->subset[i] = 0;
    split->next = 0;

    return split;
}
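
// Illustrative sketch (hypothetical helper): a categorical split keeps a
// bitmask in split->subset, one bit per compacted category index (the
// split_heap element size allocated above leaves room for max_c_count bits).
// Setting bit c sends category c to one branch; a clear bit sends it to the
// other (see CV_DTREE_CAT_DIR in calc_node_dir() below, which maps the bit
// to a +1/-1 direction).
static inline void example_mark_category( CvDTreeSplit* split, int c )
{
    split->subset[c >> 5] |= 1 << (c & 31); // word c/32, bit c%32
}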


void CvDTreeTrainData::free_node( CvDTreeNode* node )
{
    CvDTreeSplit* split = node->split;
    free_node_data( node );
    while( split )
    {
        CvDTreeSplit* next = split->next;
        cvSetRemoveByPtr( split_heap, split );
        split = next;
    }
    node->split = 0;
    cvSetRemoveByPtr( node_heap, node );
}


void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
{
    if( node->num_valid )
    {
        cvSetRemoveByPtr( nv_heap, node->num_valid );
        node->num_valid = 0;
    }
    // do not free cv_* fields, as all the cross-validation related data is released at once.
}


void CvDTreeTrainData::free_train_data()
{
    cvReleaseMat( &counts );
    cvReleaseMat( &buf );
    cvReleaseMat( &direction );
    cvReleaseMat( &split_buf );
    cvReleaseMemStorage( &temp_storage );
    cvReleaseMat( &responses_copy );
    cv_heap = nv_heap = 0;
}


void CvDTreeTrainData::clear()
{
    free_train_data();

    cvReleaseMemStorage( &tree_storage );

    cvReleaseMat( &var_idx );
    cvReleaseMat( &var_type );
    cvReleaseMat( &cat_count );
    cvReleaseMat( &cat_ofs );
    cvReleaseMat( &cat_map );
    cvReleaseMat( &priors );
    cvReleaseMat( &priors_mult );

    node_heap = split_heap = 0;

    sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
    have_labels = have_priors = is_classifier = false;

    buf_count = buf_size = 0;
    shared = false;

    data_root = 0;

    rng = &cv::theRNG();
}


int CvDTreeTrainData::get_num_classes() const
{
    return is_classifier ? cat_count->data.i[cat_var_count] : 0;
}


int CvDTreeTrainData::get_var_type(int vi) const
{
    return var_type->data.i[vi];
}
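
// Note on the encoding returned by get_var_type(): a categorical variable vi
// yields a non-negative index ci into cat_count/cat_ofs/cat_map, while an
// ordered variable yields the bitwise complement of its ordinal index (a
// negative value). This is why callers throughout this file test "ci >= 0"
// to detect categorical variables.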

void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                         const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
{
    int vidx = var_idx ? var_idx->data.i[vi] : vi;
    int node_sample_count = n->sample_count;
    int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);

    const int* sample_indices = get_sample_indices(n, sample_indices_buf);

    if( !is_buf_16u )
        *sorted_indices = buf->data.i + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset;
    else {
        const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset );
        for( int i = 0; i < node_sample_count; i++ )
            sorted_indices_buf[i] = short_indices[i];
        *sorted_indices = sorted_indices_buf;
    }

    if( tflag == CV_ROW_SAMPLE )
    {
        for( int i = 0; i < node_sample_count &&
             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
        {
            int idx = (*sorted_indices)[i];
            idx = sample_indices[idx];
            ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
        }
    }
    else
        for( int i = 0; i < node_sample_count &&
             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
        {
            int idx = (*sorted_indices)[i];
            idx = sample_indices[idx];
            ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
        }

    *ord_values = ord_values_buf;
}


const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
{
    if (is_classifier)
        return get_cat_var_data( n, var_count, labels_buf);
    return 0;
}

const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
{
    return get_cat_var_data( n, get_work_var_count(), indices_buf );
}

const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf )
{
    int _sample_count = n->sample_count;
    int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
    const int* indices = get_sample_indices(n, sample_indices_buf);

    for( int i = 0; i < _sample_count &&
         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
    {
        int idx = indices[i];
        values_buf[i] = *(responses->data.fl + idx * r_step);
    }

    return values_buf;
}


const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
{
    if (have_labels)
        return get_cat_var_data( n, get_work_var_count() - 1, labels_buf);
    return 0;
}


const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf)
{
    const int* cat_values = 0;
    if( !is_buf_16u )
        cat_values = buf->data.i + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset;
    else {
        const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
            (size_t)vi*sample_count + n->offset);
        for( int i = 0; i < n->sample_count; i++ )
            cat_values_buf[i] = short_values[i];
        cat_values = cat_values_buf;
    }
    return cat_values;
}


int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
{
    int idx = n->buf_idx + 1;
    if( idx >= buf_count )
        idx = shared ? 1 : 0;
    return idx;
}


void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
{
    CV_FUNCNAME( "CvDTreeTrainData::write_params" );

    __BEGIN__;

    int vi, vcount = var_count;

    cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );

    if( is_classifier )
    {
        cvWriteInt( fs, "max_categories", params.max_categories );
    }
    else
    {
        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
    }

    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );

    if( params.cv_folds > 1 )
    {
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
    }

    if( priors )
        cvWrite( fs, "priors", priors );

    cvEndWriteStruct( fs );

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );

    for( vi = 0; vi < vcount; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );

    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
    {
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
    }

    __END__;
}


void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    __BEGIN__;

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;

    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );
    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );

    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );

    if( tparams_node ) // training parameters are not necessary
    {
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( is_classifier )
        {
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
        }
        else
        {
            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
        }

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        if( params.cv_folds > 1 )
        {
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
        }

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
        if( priors )
        {
            if( !CV_IS_MAT(priors) )
                CV_ERROR( CV_StsParseError, "priors must be stored as a matrix" );
            priors_mult = cvCloneMat( priors );
        }
    }

    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
    if( var_idx )
    {
        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "var_idx (if it exists) must be a valid 1d integer vector containing <var_count> elements" );

        for( vi = 0; vi < var_count; vi++ )
            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
    }

    ////// read var type
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
    else
    {
        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
        {
            CvFileNode* n = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }
    var_type->data.i[var_count] = cat_var_count;

    ord_var_count = ~ord_var_count;
    //////

    if( cat_var_count > 0 || is_classifier )
    {
        int ccount, total_c_count = 0;
        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        ccount = cat_var_count + is_classifier;

        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;
        max_c_count = 1;

        for( vi = 0; vi < ccount; vi++ )
        {
            int val = cat_count->data.i[vi];
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, val );
            cat_ofs->data.i[vi+1] = total_c_count += val;
        }

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
                "cat_map vector length is not equal to the total number of categories in all categorical vars" );
    }

    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
                                      sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
                                       max_split_size, tree_storage ));

    __END__;
}

/////////////////////// Decision Tree /////////////////////////
CvDTreeParams::CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
    cv_folds(10), use_surrogates(true), use_1se_rule(true),
    truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
{}

CvDTreeParams::CvDTreeParams( int _max_depth, int _min_sample_count,
                              float _regression_accuracy, bool _use_surrogates,
                              int _max_categories, int _cv_folds,
                              bool _use_1se_rule, bool _truncate_pruned_tree,
                              const float* _priors ) :
    max_categories(_max_categories), max_depth(_max_depth),
    min_sample_count(_min_sample_count), cv_folds (_cv_folds),
    use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
    truncate_pruned_tree(_truncate_pruned_tree),
    regression_accuracy(_regression_accuracy),
    priors(_priors)
{}
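
// Usage sketch (illustrative values, not recommendations): the defaults above
// give a tree of unlimited depth pruned with 10-fold cross-validation; an
// explicit construction might look like
//
//   CvDTreeParams params( 8,     // max_depth
//                         10,    // min_sample_count
//                         0.01f, // regression_accuracy (ignored for classification)
//                         true,  // use_surrogates
//                         15,    // max_categories (clamped to 15 anyway)
//                         10,    // cv_folds
//                         true,  // use_1se_rule
//                         true,  // truncate_pruned_tree
//                         0 );   // priors (0 => equal class weights)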

CvDTree::CvDTree()
{
    data = 0;
    var_importance = 0;
    default_model_name = "my_tree";

    clear();
}


void CvDTree::clear()
{
    cvReleaseMat( &var_importance );
    if( data )
    {
        if( !data->shared )
            delete data;
        else
            free_tree();
        data = 0;
    }
    root = 0;
    pruned_tree_idx = -1;
}


CvDTree::~CvDTree()
{
    clear();
}


const CvDTreeNode* CvDTree::get_root() const
{
    return root;
}


int CvDTree::get_pruned_tree_idx() const
{
    return pruned_tree_idx;
}


CvDTreeTrainData* CvDTree::get_data()
{
    return data;
}


bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx,
                     const CvMat* _sample_idx, const CvMat* _var_type,
                     const CvMat* _missing_mask, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
                                 _var_idx, _sample_idx, _var_type,
                                 _missing_mask, _params, false );
    CV_CALL( result = do_train(0) );

    __END__;

    return result;
}

bool CvDTree::train( const Mat& _train_data, int _tflag,
                     const Mat& _responses, const Mat& _var_idx,
                     const Mat& _sample_idx, const Mat& _var_type,
                     const Mat& _missing_mask, CvDTreeParams _params )
{
    train_data_hdr = _train_data;
    train_data_mat = _train_data;
    responses_hdr = _responses;
    responses_mat = _responses;

    CvMat vidx=_var_idx, sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask;

    return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
                 vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
}
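
// Illustrative sketch (hypothetical helper and data): a minimal training call
// through the cv::Mat overload above. Empty Mats stand for the optional
// arguments (variable/sample subsets, variable types, missing-value mask).
static inline bool example_train_tree( CvDTree& tree,
                                       const cv::Mat& samples,    // CV_32F, one sample per row
                                       const cv::Mat& responses ) // one label per sample
{
    return tree.train( samples, CV_ROW_SAMPLE, responses,
                       cv::Mat(), cv::Mat(), cv::Mat(), cv::Mat(),
                       CvDTreeParams() );
}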


bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* var_types = _data->get_var_types();
    const CvMat* train_sidx = _data->get_train_sample_idx();
    const CvMat* var_idx = _data->get_var_idx();

    CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
                             train_sidx, var_types, missing, _params ) );

    __END__;

    return result;
}

bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = _data;
    data->shared = true;
    CV_CALL( result = do_train(_subsample_idx));

    __END__;

    return result;
}


bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );

    __BEGIN__;

    root = data->subsample_data( _subsample_idx );

    CV_CALL( try_split_node(root));

    if( root->split )
    {
        CV_Assert( root->left );
        CV_Assert( root->right );

        if( data->params.cv_folds > 0 )
            CV_CALL( prune_cv() );

        if( !data->shared )
            data->free_train_data();

        result = true;
    }

    __END__;

    return result;
}


void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split = false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }
    if( !can_split || !best_split )
    {
        data->free_node_data(node);
        return;
    }

    quality_scale = calc_node_dir( node );
    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }
    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}
1759
1760
1761 // calculate direction (left(-1),right(1),missing(0))
1762 // for each sample using the best split
1763 // the function returns scale coefficients for surrogate split quality factors.
1764 // the scale is applied to normalize surrogate split quality relatively to the
1765 // best (primary) split quality. That is, if a surrogate split is absolutely
1766 // identical to the primary split, its quality will be set to the maximum value =
1767 // quality of the primary split; otherwise, it will be lower.
1768 // besides, the function compute node->maxlr,
1769 // minimum possible quality (w/o considering the above mentioned scale)
1770 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1771 // are not discarded.
calc_node_dir(CvDTreeNode * node)1772 double CvDTree::calc_node_dir( CvDTreeNode* node )
1773 {
1774 char* dir = (char*)data->direction->data.ptr;
1775 int i, n = node->sample_count, vi = node->split->var_idx;
1776 double L, R;
1777
1778 assert( !node->split->inversed );
1779
1780 if( data->get_var_type(vi) >= 0 ) // split on categorical var
1781 {
1782 cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
1783 int* labels_buf = (int*)inn_buf;
1784 const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1785 const int* subset = node->split->subset;
1786 if( !data->have_priors )
1787 {
1788 int sum = 0, sum_abs = 0;
1789
1790 for( i = 0; i < n; i++ )
1791 {
1792 int idx = labels[i];
1793 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
1794 CV_DTREE_CAT_DIR(idx,subset) : 0;
1795 sum += d; sum_abs += d & 1;
1796 dir[i] = (char)d;
1797 }
1798
1799 R = (sum_abs + sum) >> 1;
1800 L = (sum_abs - sum) >> 1;
1801 }
1802 else
1803 {
1804 const double* priors = data->priors_mult->data.db;
1805 double sum = 0, sum_abs = 0;
1806 int* responses_buf = labels_buf + n;
1807 const int* responses = data->get_class_labels(node, responses_buf);
1808
1809 for( i = 0; i < n; i++ )
1810 {
1811 int idx = labels[i];
1812 double w = priors[responses[i]];
1813 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1814 sum += d*w; sum_abs += (d & 1)*w;
1815 dir[i] = (char)d;
1816 }
1817
1818 R = (sum_abs + sum) * 0.5;
1819 L = (sum_abs - sum) * 0.5;
1820 }
1821 }
1822 else // split on ordered var
1823 {
1824 int split_point = node->split->ord.split_point;
1825 int n1 = node->get_num_valid(vi);
1826 cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
1827 float* val_buf = (float*)(uchar*)inn_buf;
1828 int* sorted_buf = (int*)(val_buf + n);
1829 int* sample_idx_buf = sorted_buf + n;
1830 const float* val = 0;
1831 const int* sorted = 0;
1832 data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
1833
1834 assert( 0 <= split_point && split_point < n1-1 );
1835
1836 if( !data->have_priors )
1837 {
1838 for( i = 0; i <= split_point; i++ )
1839 dir[sorted[i]] = (char)-1;
1840 for( ; i < n1; i++ )
1841 dir[sorted[i]] = (char)1;
1842 for( ; i < n; i++ )
1843 dir[sorted[i]] = (char)0;
1844
1845 L = split_point + 1; // samples sent left/right, consistent with the weighted branch below
1846 R = n1 - split_point - 1;
1847 }
1848 else
1849 {
1850 const double* priors = data->priors_mult->data.db;
1851 int* responses_buf = sample_idx_buf + n;
1852 const int* responses = data->get_class_labels(node, responses_buf);
1853 L = R = 0;
1854
1855 for( i = 0; i <= split_point; i++ )
1856 {
1857 int idx = sorted[i];
1858 double w = priors[responses[idx]];
1859 dir[idx] = (char)-1;
1860 L += w;
1861 }
1862
1863 for( ; i < n1; i++ )
1864 {
1865 int idx = sorted[i];
1866 double w = priors[responses[idx]];
1867 dir[idx] = (char)1;
1868 R += w;
1869 }
1870
1871 for( ; i < n; i++ )
1872 dir[sorted[i]] = (char)0;
1873 }
1874 }
1875 node->maxlr = MAX( L, R );
1876 return node->split->quality/(L + R);
1877 }
1878
1879
1880 namespace cv
1881 {
1882
1883 template<> CV_EXPORTS void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const
1884 {
1885 fastFree(obj);
1886 }
1887
1888 DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
1889 {
1890 tree = _tree;
1891 node = _node;
1892 splitSize = tree->get_data()->split_heap->elem_size;
1893
1894 bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
1895 memset(bestSplit.get(), 0, splitSize);
1896 bestSplit->quality = -1;
1897 bestSplit->condensed_idx = INT_MIN;
1898 split.reset((CvDTreeSplit*)fastMalloc(splitSize));
1899 memset(split.get(), 0, splitSize);
1900 //haveSplit = false;
1901 }
1902
1903 DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
1904 {
1905 tree = finder.tree;
1906 node = finder.node;
1907 splitSize = tree->get_data()->split_heap->elem_size;
1908
1909 bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
1910 memcpy(bestSplit.get(), finder.bestSplit.get(), splitSize);
1911 split.reset((CvDTreeSplit*)fastMalloc(splitSize));
1912 memset(split.get(), 0, splitSize);
1913 }
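// DTreeBestSplitFinder follows the parallel_reduce protocol: the range of
// variable indices is partitioned among copies created by the splitting
// constructor above, each copy scans its sub-range in operator(), and join()
// merges the partial results by keeping the higher-quality split.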
1914
1915 void DTreeBestSplitFinder::operator()(const BlockedRange& range)
1916 {
1917 int vi, vi1 = range.begin(), vi2 = range.end();
1918 int n = node->sample_count;
1919 CvDTreeTrainData* data = tree->get_data();
1920 AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
1921
1922 for( vi = vi1; vi < vi2; vi++ )
1923 {
1924 CvDTreeSplit *res;
1925 int ci = data->get_var_type(vi);
1926 if( node->get_num_valid(vi) <= 1 )
1927 continue;
1928
1929 if( data->is_classifier )
1930 {
1931 if( ci >= 0 )
1932 res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1933 else
1934 res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1935 }
1936 else
1937 {
1938 if( ci >= 0 )
1939 res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1940 else
1941 res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1942 }
1943
1944 if( res && bestSplit->quality < split->quality )
1945 memcpy( bestSplit.get(), split.get(), splitSize );
1946 }
1947 }
1948
1949 void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
1950 {
1951 if( bestSplit->quality < rhs.bestSplit->quality )
1952 memcpy( bestSplit.get(), rhs.bestSplit.get(), splitSize );
1953 }
1954 }
1955
1956
1957 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1958 {
1959 DTreeBestSplitFinder finder( this, node );
1960
1961 cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
1962
1963 CvDTreeSplit *bestSplit = 0;
1964 if( finder.bestSplit->quality > 0 )
1965 {
1966 bestSplit = data->new_split_cat( 0, -1.0f );
1967 memcpy( bestSplit, finder.bestSplit, finder.splitSize );
1968 }
1969
1970 return bestSplit;
1971 }
1972
1973 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
1974 float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
1975 {
1976 const float epsilon = FLT_EPSILON*2;
1977 int n = node->sample_count;
1978 int n1 = node->get_num_valid(vi);
1979 int m = data->get_num_classes();
1980
1981 int base_size = 2*m*sizeof(int);
1982 cv::AutoBuffer<uchar> inn_buf(base_size);
1983 if( !_ext_buf )
1984 inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
1985 uchar* base_buf = (uchar*)inn_buf;
1986 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1987 float* values_buf = (float*)ext_buf;
1988 int* sorted_indices_buf = (int*)(values_buf + n);
1989 int* sample_indices_buf = sorted_indices_buf + n;
1990 const float* values = 0;
1991 const int* sorted_indices = 0;
1992 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
1993 &sorted_indices, sample_indices_buf );
1994 int* responses_buf = sample_indices_buf + n;
1995 const int* responses = data->get_class_labels( node, responses_buf );
1996
1997 const int* rc0 = data->counts->data.i;
1998 int* lc = (int*)base_buf;
1999 int* rc = lc + m;
2000 int i, best_i = -1;
2001 double lsum2 = 0, rsum2 = 0, best_val = init_quality;
2002 const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
2003
2004 // init arrays of class instance counters on both sides of the split
2005 for( i = 0; i < m; i++ )
2006 {
2007 lc[i] = 0;
2008 rc[i] = rc0[i];
2009 }
2010
2011 // compensate for missing values
2012 for( i = n1; i < n; i++ )
2013 {
2014 rc[responses[sorted_indices[i]]]--;
2015 }
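// the quality maximized below is lsum2/L + rsum2/R, written as
// (lsum2*R + rsum2*L)/(L*R), where lsum2 = sum_k lc[k]^2 and rsum2 = sum_k rc[k]^2.
// since L + R is fixed for the node, maximizing it is equivalent to minimizing
// the weighted Gini impurity L*(1 - lsum2/L^2) + R*(1 - rsum2/R^2).
// the running sums are updated incrementally using (v+1)^2 = v^2 + 2*v + 1.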
2016
2017 if( !priors )
2018 {
2019 int L = 0, R = n1;
2020
2021 for( i = 0; i < m; i++ )
2022 rsum2 += (double)rc[i]*rc[i];
2023
2024 for( i = 0; i < n1 - 1; i++ )
2025 {
2026 int idx = responses[sorted_indices[i]];
2027 int lv, rv;
2028 L++; R--;
2029 lv = lc[idx]; rv = rc[idx];
2030 lsum2 += lv*2 + 1;
2031 rsum2 -= rv*2 - 1;
2032 lc[idx] = lv + 1; rc[idx] = rv - 1;
2033
2034 if( values[i] + epsilon < values[i+1] )
2035 {
2036 double val = (lsum2*R + rsum2*L)/((double)L*R);
2037 if( best_val < val )
2038 {
2039 best_val = val;
2040 best_i = i;
2041 }
2042 }
2043 }
2044 }
2045 else
2046 {
2047 double L = 0, R = 0;
2048 for( i = 0; i < m; i++ )
2049 {
2050 double wv = rc[i]*priors[i];
2051 R += wv;
2052 rsum2 += wv*wv;
2053 }
2054
2055 for( i = 0; i < n1 - 1; i++ )
2056 {
2057 int idx = responses[sorted_indices[i]];
2058 int lv, rv;
2059 double p = priors[idx], p2 = p*p;
2060 L += p; R -= p;
2061 lv = lc[idx]; rv = rc[idx];
2062 lsum2 += p2*(lv*2 + 1);
2063 rsum2 -= p2*(rv*2 - 1);
2064 lc[idx] = lv + 1; rc[idx] = rv - 1;
2065
2066 if( values[i] + epsilon < values[i+1] )
2067 {
2068 double val = (lsum2*R + rsum2*L)/((double)L*R);
2069 if( best_val < val )
2070 {
2071 best_val = val;
2072 best_i = i;
2073 }
2074 }
2075 }
2076 }
2077
2078 CvDTreeSplit* split = 0;
2079 if( best_i >= 0 )
2080 {
2081 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2082 split->var_idx = vi;
2083 split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2084 split->ord.split_point = best_i;
2085 split->inversed = 0;
2086 split->quality = (float)best_val;
2087 }
2088 return split;
2089 }
2090
2091
2092 void CvDTree::cluster_categories( const int* vectors, int n, int m,
2093 int* csums, int k, int* labels )
2094 {
2095 // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
2096 int iters = 0, max_iters = 100;
2097 int i, j, idx;
2098 cv::AutoBuffer<double> buf(n + k);
2099 double *v_weights = buf, *c_weights = buf + n;
2100 bool modified = true;
2101 RNG* r = data->rng;
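// what follows is a small k-means in the space of normalized class-frequency
// histograms: each input vector (a row of length m) is compared to the cluster
// sums after both are scaled by the inverse of their totals (v_weights and
// c_weights), and vectors are reassigned to the nearest center until the
// labeling stabilizes or max_iters is reached. This implements the CART
// heuristic of merging a large set of categories into k clusters.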
2102
2103 // assign labels randomly
2104 for( i = 0; i < n; i++ )
2105 {
2106 int sum = 0;
2107 const int* v = vectors + i*m;
2108 labels[i] = i < k ? i : r->uniform(0, k);
2109
2110 // compute weight of each vector
2111 for( j = 0; j < m; j++ )
2112 sum += v[j];
2113 v_weights[i] = sum ? 1./sum : 0.;
2114 }
2115
2116 for( i = 0; i < n; i++ )
2117 {
2118 int i1 = (*r)(n);
2119 int i2 = (*r)(n);
2120 CV_SWAP( labels[i1], labels[i2], j );
2121 }
2122
2123 for( iters = 0; iters <= max_iters; iters++ )
2124 {
2125 // calculate csums
2126 for( i = 0; i < k; i++ )
2127 {
2128 for( j = 0; j < m; j++ )
2129 csums[i*m + j] = 0;
2130 }
2131
2132 for( i = 0; i < n; i++ )
2133 {
2134 const int* v = vectors + i*m;
2135 int* s = csums + labels[i]*m;
2136 for( j = 0; j < m; j++ )
2137 s[j] += v[j];
2138 }
2139
2140 // exit the loop here, once we have up-to-date csums
2141 if( iters == max_iters || !modified )
2142 break;
2143
2144 modified = false;
2145
2146 // calculate weight of each cluster
2147 for( i = 0; i < k; i++ )
2148 {
2149 const int* s = csums + i*m;
2150 int sum = 0;
2151 for( j = 0; j < m; j++ )
2152 sum += s[j];
2153 c_weights[i] = sum ? 1./sum : 0;
2154 }
2155
2156 // now for each vector determine the closest cluster
2157 for( i = 0; i < n; i++ )
2158 {
2159 const int* v = vectors + i*m;
2160 double alpha = v_weights[i];
2161 double min_dist2 = DBL_MAX;
2162 int min_idx = -1;
2163
2164 for( idx = 0; idx < k; idx++ )
2165 {
2166 const int* s = csums + idx*m;
2167 double dist2 = 0., beta = c_weights[idx];
2168 for( j = 0; j < m; j++ )
2169 {
2170 double t = v[j]*alpha - s[j]*beta;
2171 dist2 += t*t;
2172 }
2173 if( min_dist2 > dist2 )
2174 {
2175 min_dist2 = dist2;
2176 min_idx = idx;
2177 }
2178 }
2179
2180 if( min_idx != labels[i] )
2181 modified = true;
2182 labels[i] = min_idx;
2183 }
2184 }
2185 }
2186
2187
2188 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
2189 CvDTreeSplit* _split, uchar* _ext_buf )
2190 {
2191 int ci = data->get_var_type(vi);
2192 int n = node->sample_count;
2193 int m = data->get_num_classes();
2194 int _mi = data->cat_count->data.i[ci], mi = _mi;
2195
2196 int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
2197 if( m > 2 && mi > data->params.max_categories )
2198 base_size += (m*std::min(data->params.max_categories, n) + mi)*sizeof(int);
2199 else
2200 base_size += mi*sizeof(int*);
2201 cv::AutoBuffer<uchar> inn_buf(base_size);
2202 if( !_ext_buf )
2203 inn_buf.allocate(base_size + 2*n*sizeof(int));
2204 uchar* base_buf = (uchar*)inn_buf;
2205 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2206
2207 int* lc = (int*)base_buf;
2208 int* rc = lc + m;
2209 int* _cjk = rc + m*2, *cjk = _cjk;
2210 double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));
2211
2212 int* labels_buf = (int*)ext_buf;
2213 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2214 int* responses_buf = labels_buf + n;
2215 const int* responses = data->get_class_labels(node, responses_buf);
2216
2217 int* cluster_labels = 0;
2218 int** int_ptr = 0;
2219 int i, j, k, idx;
2220 double L = 0, R = 0;
2221 double best_val = init_quality;
2222 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
2223 const double* priors = data->priors_mult->data.db;
2224
2225 // init array of counters:
2226 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
2227 for( j = -1; j < mi; j++ )
2228 for( k = 0; k < m; k++ )
2229 cjk[j*m + k] = 0;
2230
2231 for( i = 0; i < n; i++ )
2232 {
2233 j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
2234 k = responses[i];
2235 cjk[j*m + k]++;
2236 }
2237
2238 if( m > 2 )
2239 {
2240 if( mi > data->params.max_categories )
2241 {
2242 mi = MIN(data->params.max_categories, n);
2243 cjk = (int*)(c_weights + _mi);
2244 cluster_labels = cjk + m*mi;
2245 cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
2246 }
2247 subset_i = 1;
2248 subset_n = 1 << mi;
2249 }
2250 else
2251 {
2252 assert( m == 2 );
2253 int_ptr = (int**)(c_weights + _mi);
2254 for( j = 0; j < mi; j++ )
2255 int_ptr[j] = cjk + j*2 + 1;
2256 std::sort(int_ptr, int_ptr + mi, LessThanPtr<int>());
2257 subset_i = 0;
2258 subset_n = mi;
2259 }
2260
2261 for( k = 0; k < m; k++ )
2262 {
2263 int sum = 0;
2264 for( j = 0; j < mi; j++ )
2265 sum += cjk[j*m + k];
2266 rc[k] = sum;
2267 lc[k] = 0;
2268 }
2269
2270 for( j = 0; j < mi; j++ )
2271 {
2272 double sum = 0;
2273 for( k = 0; k < m; k++ )
2274 sum += cjk[j*m + k]*priors[k];
2275 c_weights[j] = sum;
2276 R += c_weights[j];
2277 }
2278
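// subset enumeration: for m == 2 the categories are ordered (via int_ptr) by
// their class-1 counts and only prefixes of that order are tried (the
// classical CART shortcut); for m > 2 all 2^mi subsets are enumerated in
// Gray-code order, so that each step moves exactly one category across the
// split and the class counters can be updated incrementally.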
2279 for( ; subset_i < subset_n; subset_i++ )
2280 {
2281 double weight;
2282 int* crow;
2283 double lsum2 = 0, rsum2 = 0;
2284
2285 if( m == 2 )
2286 idx = (int)(int_ptr[subset_i] - cjk)/2;
2287 else
2288 {
2289 int graycode = (subset_i>>1)^subset_i;
2290 int diff = graycode ^ prevcode;
2291
2292 // determine index of the changed bit.
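// (the bit position is recovered from the IEEE-754 exponent: the isolated
// changed bit is converted to float and (u.i >> 23) - 127 extracts
// floor(log2) of it; e.g. diff = 8 gives u.f = 8.0f and idx += 3.)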
2293 Cv32suf u;
2294 idx = diff >= (1 << 16) ? 16 : 0;
2295 u.f = (float)(((diff >> 16) | diff) & 65535);
2296 idx += (u.i >> 23) - 127;
2297 subtract = graycode < prevcode;
2298 prevcode = graycode;
2299 }
2300
2301 crow = cjk + idx*m;
2302 weight = c_weights[idx];
2303 if( weight < FLT_EPSILON )
2304 continue;
2305
2306 if( !subtract )
2307 {
2308 for( k = 0; k < m; k++ )
2309 {
2310 int t = crow[k];
2311 int lval = lc[k] + t;
2312 int rval = rc[k] - t;
2313 double p = priors[k], p2 = p*p;
2314 lsum2 += p2*lval*lval;
2315 rsum2 += p2*rval*rval;
2316 lc[k] = lval; rc[k] = rval;
2317 }
2318 L += weight;
2319 R -= weight;
2320 }
2321 else
2322 {
2323 for( k = 0; k < m; k++ )
2324 {
2325 int t = crow[k];
2326 int lval = lc[k] - t;
2327 int rval = rc[k] + t;
2328 double p = priors[k], p2 = p*p;
2329 lsum2 += p2*lval*lval;
2330 rsum2 += p2*rval*rval;
2331 lc[k] = lval; rc[k] = rval;
2332 }
2333 L -= weight;
2334 R += weight;
2335 }
2336
2337 if( L > FLT_EPSILON && R > FLT_EPSILON )
2338 {
2339 double val = (lsum2*R + rsum2*L)/((double)L*R);
2340 if( best_val < val )
2341 {
2342 best_val = val;
2343 best_subset = subset_i;
2344 }
2345 }
2346 }
2347
2348 CvDTreeSplit* split = 0;
2349 if( best_subset >= 0 )
2350 {
2351 split = _split ? _split : data->new_split_cat( 0, -1.0f );
2352 split->var_idx = vi;
2353 split->quality = (float)best_val;
2354 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2355 if( m == 2 )
2356 {
2357 for( i = 0; i <= best_subset; i++ )
2358 {
2359 idx = (int)(int_ptr[i] - cjk) >> 1;
2360 split->subset[idx >> 5] |= 1 << (idx & 31);
2361 }
2362 }
2363 else
2364 {
2365 for( i = 0; i < _mi; i++ )
2366 {
2367 idx = cluster_labels ? cluster_labels[i] : i;
2368 if( best_subset & (1 << idx) )
2369 split->subset[i >> 5] |= 1 << (i & 31);
2370 }
2371 }
2372 }
2373 return split;
2374 }
2375
2376
2377 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2378 {
2379 const float epsilon = FLT_EPSILON*2;
2380 int n = node->sample_count;
2381 int n1 = node->get_num_valid(vi);
2382
2383 cv::AutoBuffer<uchar> inn_buf;
2384 if( !_ext_buf )
2385 inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
2386 uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2387 float* values_buf = (float*)ext_buf;
2388 int* sorted_indices_buf = (int*)(values_buf + n);
2389 int* sample_indices_buf = sorted_indices_buf + n;
2390 const float* values = 0;
2391 const int* sorted_indices = 0;
2392 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2393 float* responses_buf = (float*)(sample_indices_buf + n);
2394 const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
2395
2396 int i, best_i = -1;
2397 double best_val = init_quality, lsum = 0, rsum = node->value*n;
2398 int L = 0, R = n1;
2399
2400 // compensate for missing values
2401 for( i = n1; i < n; i++ )
2402 rsum -= responses[sorted_indices[i]];
2403
2404 // find the optimal split
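// the quality maximized below is lsum^2/L + rsum^2/R; as the total sum of
// squared responses is constant for the node, this is equivalent to
// minimizing the within-child sum of squared errors
// sum_left(y - lsum/L)^2 + sum_right(y - rsum/R)^2.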
2405 for( i = 0; i < n1 - 1; i++ )
2406 {
2407 float t = responses[sorted_indices[i]];
2408 L++; R--;
2409 lsum += t;
2410 rsum -= t;
2411
2412 if( values[i] + epsilon < values[i+1] )
2413 {
2414 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2415 if( best_val < val )
2416 {
2417 best_val = val;
2418 best_i = i;
2419 }
2420 }
2421 }
2422
2423 CvDTreeSplit* split = 0;
2424 if( best_i >= 0 )
2425 {
2426 split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2427 split->var_idx = vi;
2428 split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2429 split->ord.split_point = best_i;
2430 split->inversed = 0;
2431 split->quality = (float)best_val;
2432 }
2433 return split;
2434 }
2435
2436 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2437 {
2438 int ci = data->get_var_type(vi);
2439 int n = node->sample_count;
2440 int mi = data->cat_count->data.i[ci];
2441
2442 int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
2443 cv::AutoBuffer<uchar> inn_buf(base_size);
2444 if( !_ext_buf )
2445 inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
2446 uchar* base_buf = (uchar*)inn_buf;
2447 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2448 int* labels_buf = (int*)ext_buf;
2449 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2450 float* responses_buf = (float*)(labels_buf + n);
2451 int* sample_indices_buf = (int*)(responses_buf + n);
2452 const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
2453
2454 double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2455 int* counts = (int*)(sum + mi) + 1;
2456 double** sum_ptr = (double**)(counts + mi);
2457 int i, L = 0, R = 0;
2458 double best_val = init_quality, lsum = 0, rsum = 0;
2459 int best_subset = -1, subset_i;
2460
2461 for( i = -1; i < mi; i++ )
2462 sum[i] = counts[i] = 0;
2463
2464 // calculate the sum of responses and the sample count of each category of the input var
2465 for( i = 0; i < n; i++ )
2466 {
2467 int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
2468 double s = sum[idx] + responses[i];
2469 int nc = counts[idx] + 1;
2470 sum[idx] = s;
2471 counts[idx] = nc;
2472 }
2473
2474 // calculate average response in each category
2475 for( i = 0; i < mi; i++ )
2476 {
2477 R += counts[i];
2478 rsum += sum[i];
2479 sum[i] /= MAX(counts[i],1);
2480 sum_ptr[i] = sum + i;
2481 }
2482
2483 std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());
2484
2485 // revert to unnormalized sums
2486 // (there should be very little loss of accuracy)
2487 for( i = 0; i < mi; i++ )
2488 sum[i] *= counts[i];
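// scanning only prefixes of the mean-response ordering is sufficient: for
// squared-error regression the optimal binary subset is always such a prefix
// (Breiman et al., "Classification and Regression Trees", 1984).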
2489
2490 for( subset_i = 0; subset_i < mi-1; subset_i++ )
2491 {
2492 int idx = (int)(sum_ptr[subset_i] - sum);
2493 int ni = counts[idx];
2494
2495 if( ni )
2496 {
2497 double s = sum[idx];
2498 lsum += s; L += ni;
2499 rsum -= s; R -= ni;
2500
2501 if( L && R )
2502 {
2503 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2504 if( best_val < val )
2505 {
2506 best_val = val;
2507 best_subset = subset_i;
2508 }
2509 }
2510 }
2511 }
2512
2513 CvDTreeSplit* split = 0;
2514 if( best_subset >= 0 )
2515 {
2516 split = _split ? _split : data->new_split_cat( 0, -1.0f);
2517 split->var_idx = vi;
2518 split->quality = (float)best_val;
2519 memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2520 for( i = 0; i <= best_subset; i++ )
2521 {
2522 int idx = (int)(sum_ptr[i] - sum);
2523 split->subset[idx >> 5] |= 1 << (idx & 31);
2524 }
2525 }
2526 return split;
2527 }
2528
2529 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
2530 {
2531 const float epsilon = FLT_EPSILON*2;
2532 const char* dir = (char*)data->direction->data.ptr;
2533 int n = node->sample_count, n1 = node->get_num_valid(vi);
2534 cv::AutoBuffer<uchar> inn_buf;
2535 if( !_ext_buf )
2536 inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
2537 uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2538 float* values_buf = (float*)ext_buf;
2539 int* sorted_indices_buf = (int*)(values_buf + n);
2540 int* sample_indices_buf = sorted_indices_buf + n;
2541 const float* values = 0;
2542 const int* sorted_indices = 0;
2543 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2544 // LL - number of samples that both the primary and the surrogate splits send to the left
2545 // LR - ... primary split sends to the left and the surrogate split sends to the right
2546 // RL - ... primary split sends to the right and the surrogate split sends to the left
2547 // RR - ... both send to the right
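// the surrogate quality is the (weighted) number of samples it sends to the
// same side as the primary split: LL + RR for a direct surrogate, RL + LR for
// an inversed one. It must exceed node->maxlr, the score of the trivial rule
// that sends every sample to the primary split's majority side.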
2548 int i, best_i = -1, best_inversed = 0;
2549 double best_val;
2550
2551 if( !data->have_priors )
2552 {
2553 int LL = 0, RL = 0, LR, RR;
2554 int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2555 int sum = 0, sum_abs = 0;
2556
2557 for( i = 0; i < n1; i++ )
2558 {
2559 int d = dir[sorted_indices[i]];
2560 sum += d; sum_abs += d & 1;
2561 }
2562
2563 // sum_abs = R + L; sum = R - L
2564 RR = (sum_abs + sum) >> 1;
2565 LR = (sum_abs - sum) >> 1;
2566
2567 // initially all the samples are sent to the right by the surrogate split,
2568 // LR of them are sent to the left by primary split, and RR - to the right.
2569 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2570 for( i = 0; i < n1 - 1; i++ )
2571 {
2572 int d = dir[sorted_indices[i]];
2573
2574 if( d < 0 )
2575 {
2576 LL++; LR--;
2577 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
2578 {
2579 _best_val = LL + RR;
2580 best_i = i; best_inversed = 0;
2581 }
2582 }
2583 else if( d > 0 )
2584 {
2585 RL++; RR--;
2586 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
2587 {
2588 _best_val = RL + LR;
2589 best_i = i; best_inversed = 1;
2590 }
2591 }
2592 }
2593 best_val = _best_val;
2594 }
2595 else
2596 {
2597 double LL = 0, RL = 0, LR, RR;
2598 double worst_val = node->maxlr;
2599 double sum = 0, sum_abs = 0;
2600 const double* priors = data->priors_mult->data.db;
2601 int* responses_buf = sample_indices_buf + n;
2602 const int* responses = data->get_class_labels(node, responses_buf);
2603 best_val = worst_val;
2604
2605 for( i = 0; i < n1; i++ )
2606 {
2607 int idx = sorted_indices[i];
2608 double w = priors[responses[idx]];
2609 int d = dir[idx];
2610 sum += d*w; sum_abs += (d & 1)*w;
2611 }
2612
2613 // sum_abs = R + L; sum = R - L
2614 RR = (sum_abs + sum)*0.5;
2615 LR = (sum_abs - sum)*0.5;
2616
2617 // initially all the samples are sent to the right by the surrogate split,
2618 // LR of them are sent to the left by primary split, and RR - to the right.
2619 // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2620 for( i = 0; i < n1 - 1; i++ )
2621 {
2622 int idx = sorted_indices[i];
2623 double w = priors[responses[idx]];
2624 int d = dir[idx];
2625
2626 if( d < 0 )
2627 {
2628 LL += w; LR -= w;
2629 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
2630 {
2631 best_val = LL + RR;
2632 best_i = i; best_inversed = 0;
2633 }
2634 }
2635 else if( d > 0 )
2636 {
2637 RL += w; RR -= w;
2638 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
2639 {
2640 best_val = RL + LR;
2641 best_i = i; best_inversed = 1;
2642 }
2643 }
2644 }
2645 }
2646 return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2647 (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
2648 }
2649
2650
2651 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
2652 {
2653 const char* dir = (char*)data->direction->data.ptr;
2654 int n = node->sample_count;
2655 int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2656
2657 int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
2658 cv::AutoBuffer<uchar> inn_buf(base_size);
2659 if( !_ext_buf )
2660 inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
2661 uchar* base_buf = (uchar*)inn_buf;
2662 uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2663
2664 int* labels_buf = (int*)ext_buf;
2665 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2666 // LL - number of samples that both the primary and the surrogate splits send to the left
2667 // LR - ... primary split sends to the left and the surrogate split sends to the right
2668 // RL - ... primary split sends to the right and the surrogate split sends to the left
2669 // RR - ... both send to the right
2670 CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2671 double best_val = 0;
2672 double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2673 double* rc = lc + mi + 1;
2674
2675 for( i = -1; i < mi; i++ )
2676 lc[i] = rc[i] = 0;
2677
2678 // for each category calculate the weight of samples
2679 // sent to the left (lc) and to the right (rc) by the primary split
2680 if( !data->have_priors )
2681 {
2682 int* _lc = (int*)rc + 1;
2683 int* _rc = _lc + mi + 1;
2684
2685 for( i = -1; i < mi; i++ )
2686 _lc[i] = _rc[i] = 0;
2687
2688 for( i = 0; i < n; i++ )
2689 {
2690 int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2691 int d = dir[i];
2692 int sum = _lc[idx] + d;
2693 int sum_abs = _rc[idx] + (d & 1);
2694 _lc[idx] = sum; _rc[idx] = sum_abs;
2695 }
2696
2697 for( i = 0; i < mi; i++ )
2698 {
2699 int sum = _lc[i];
2700 int sum_abs = _rc[i];
2701 lc[i] = (sum_abs - sum) >> 1;
2702 rc[i] = (sum_abs + sum) >> 1;
2703 }
2704 }
2705 else
2706 {
2707 const double* priors = data->priors_mult->data.db;
2708 int* responses_buf = labels_buf + n;
2709 const int* responses = data->get_class_labels(node, responses_buf);
2710
2711 for( i = 0; i < n; i++ )
2712 {
2713 int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2714 double w = priors[responses[i]];
2715 int d = dir[i];
2716 double sum = lc[idx] + d*w;
2717 double sum_abs = rc[idx] + (d & 1)*w;
2718 lc[idx] = sum; rc[idx] = sum_abs;
2719 }
2720
2721 for( i = 0; i < mi; i++ )
2722 {
2723 double sum = lc[i];
2724 double sum_abs = rc[i];
2725 lc[i] = (sum_abs - sum) * 0.5;
2726 rc[i] = (sum_abs + sum) * 0.5;
2727 }
2728 }
2729
2730 // 2. now form the split.
2731 // in each category send all the samples to the same direction as majority
2732 for( i = 0; i < mi; i++ )
2733 {
2734 double lval = lc[i], rval = rc[i];
2735 if( lval > rval )
2736 {
2737 split->subset[i >> 5] |= 1 << (i & 31);
2738 best_val += lval;
2739 l_win++;
2740 }
2741 else
2742 best_val += rval;
2743 }
2744
2745 split->quality = (float)best_val;
2746 if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2747 cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2748
2749 return split;
2750 }
2751
2752
2753 void CvDTree::calc_node_value( CvDTreeNode* node )
2754 {
2755 int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2756 int m = data->get_num_classes();
2757
2758 int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
2759 int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
2760 cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
2761 uchar* base_buf = (uchar*)inn_buf;
2762 uchar* ext_buf = base_buf + base_size;
2763
2764 int* cv_labels_buf = (int*)ext_buf;
2765 const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);
2766
2767 if( data->is_classifier )
2768 {
2769 // in case of classification tree:
2770 // * node value is the label of the class that has the largest weight in the node.
2771 // * node risk is the weighted number of misclassified samples,
2772 // * j-th cross-validation fold value and risk are calculated as above,
2773 // but using the samples with cv_labels(*)!=j.
2774 // * j-th cross-validation fold error is calculated as the weighted number of
2775 // misclassified samples with cv_labels(*)==j.
2776
2777 // compute the number of instances of each class
2778 int* cls_count = data->counts->data.i;
2779 int* responses_buf = cv_labels_buf + n;
2780 const int* responses = data->get_class_labels(node, responses_buf);
2781 int* cv_cls_count = (int*)base_buf;
2782 double max_val = -1, total_weight = 0;
2783 int max_k = -1;
2784 double* priors = data->priors_mult->data.db;
2785
2786 for( k = 0; k < m; k++ )
2787 cls_count[k] = 0;
2788
2789 if( cv_n == 0 )
2790 {
2791 for( i = 0; i < n; i++ )
2792 cls_count[responses[i]]++;
2793 }
2794 else
2795 {
2796 for( j = 0; j < cv_n; j++ )
2797 for( k = 0; k < m; k++ )
2798 cv_cls_count[j*m + k] = 0;
2799
2800 for( i = 0; i < n; i++ )
2801 {
2802 j = cv_labels[i]; k = responses[i];
2803 cv_cls_count[j*m + k]++;
2804 }
2805
2806 for( j = 0; j < cv_n; j++ )
2807 for( k = 0; k < m; k++ )
2808 cls_count[k] += cv_cls_count[j*m + k];
2809 }
2810
2811 if( data->have_priors && node->parent == 0 )
2812 {
2813 // compute priors_mult from priors, take the sample ratio into account.
2814 double sum = 0;
2815 for( k = 0; k < m; k++ )
2816 {
2817 int n_k = cls_count[k];
2818 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2819 sum += priors[k];
2820 }
2821 sum = 1./sum;
2822 for( k = 0; k < m; k++ )
2823 priors[k] *= sum;
2824 }
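// after this normalization priors[k] (the priors_mult array) is proportional
// to the user-specified prior divided by the class frequency n_k, so the
// weighted count cls_count[k]*priors[k] used below reflects the requested
// priors regardless of class representation in the training sample.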
2825
2826 for( k = 0; k < m; k++ )
2827 {
2828 double val = cls_count[k]*priors[k];
2829 total_weight += val;
2830 if( max_val < val )
2831 {
2832 max_val = val;
2833 max_k = k;
2834 }
2835 }
2836
2837 node->class_idx = max_k;
2838 node->value = data->cat_map->data.i[
2839 data->cat_ofs->data.i[data->cat_var_count] + max_k];
2840 node->node_risk = total_weight - max_val;
2841
2842 for( j = 0; j < cv_n; j++ )
2843 {
2844 double sum_k = 0, sum = 0, max_val_k = 0;
2845 max_val = -1; max_k = -1;
2846
2847 for( k = 0; k < m; k++ )
2848 {
2849 double w = priors[k];
2850 double val_k = cv_cls_count[j*m + k]*w;
2851 double val = cls_count[k]*w - val_k;
2852 sum_k += val_k;
2853 sum += val;
2854 if( max_val < val )
2855 {
2856 max_val = val;
2857 max_val_k = val_k;
2858 max_k = k;
2859 }
2860 }
2861
2862 node->cv_Tn[j] = INT_MAX;
2863 node->cv_node_risk[j] = sum - max_val;
2864 node->cv_node_error[j] = sum_k - max_val_k;
2865 }
2866 }
2867 else
2868 {
2869 // in case of regression tree:
2870 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2871 // n is the number of samples in the node.
2872 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2873 // * j-th cross-validation fold value and risk are calculated as above,
2874 // but using the samples with cv_labels(*)!=j.
2875 // * j-th cross-validation fold error is calculated
2876 // using samples with cv_labels(*)==j as the test subset:
2877 // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2878 // where node_value_j is the node value calculated
2879 // as described in the previous bullet, and summation is done
2880 // over the samples with cv_labels(*)==j.
2881
2882 double sum = 0, sum2 = 0;
2883 float* values_buf = (float*)(cv_labels_buf + n);
2884 int* sample_indices_buf = (int*)(values_buf + n);
2885 const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
2886 double *cv_sum = 0, *cv_sum2 = 0;
2887 int* cv_count = 0;
2888
2889 if( cv_n == 0 )
2890 {
2891 for( i = 0; i < n; i++ )
2892 {
2893 double t = values[i];
2894 sum += t;
2895 sum2 += t*t;
2896 }
2897 }
2898 else
2899 {
2900 cv_sum = (double*)base_buf;
2901 cv_sum2 = cv_sum + cv_n;
2902 cv_count = (int*)(cv_sum2 + cv_n);
2903
2904 for( j = 0; j < cv_n; j++ )
2905 {
2906 cv_sum[j] = cv_sum2[j] = 0.;
2907 cv_count[j] = 0;
2908 }
2909
2910 for( i = 0; i < n; i++ )
2911 {
2912 j = cv_labels[i];
2913 double t = values[i];
2914 double s = cv_sum[j] + t;
2915 double s2 = cv_sum2[j] + t*t;
2916 int nc = cv_count[j] + 1;
2917 cv_sum[j] = s;
2918 cv_sum2[j] = s2;
2919 cv_count[j] = nc;
2920 }
2921
2922 for( j = 0; j < cv_n; j++ )
2923 {
2924 sum += cv_sum[j];
2925 sum2 += cv_sum2[j];
2926 }
2927 }
2928
2929 node->node_risk = sum2 - (sum/n)*sum;
2930 node->value = sum/n;
2931
2932 for( j = 0; j < cv_n; j++ )
2933 {
2934 double s = cv_sum[j], si = sum - s;
2935 double s2 = cv_sum2[j], s2i = sum2 - s2;
2936 int c = cv_count[j], ci = n - c;
2937 double r = si/MAX(ci,1);
2938 node->cv_node_risk[j] = s2i - r*r*ci;
2939 node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2940 node->cv_Tn[j] = INT_MAX;
2941 }
2942 }
2943 }
2944
2945
2946 void CvDTree::complete_node_dir( CvDTreeNode* node )
2947 {
2948 int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2949 int nz = n - node->get_num_valid(node->split->var_idx);
2950 char* dir = (char*)data->direction->data.ptr;
2951
2952 // try to complete direction using surrogate splits
2953 if( nz && data->params.use_surrogates )
2954 {
2955 cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
2956 CvDTreeSplit* split = node->split->next;
2957 for( ; split != 0 && nz; split = split->next )
2958 {
2959 int inversed_mask = split->inversed ? -1 : 0;
2960 vi = split->var_idx;
2961
2962 if( data->get_var_type(vi) >= 0 ) // split on categorical var
2963 {
2964 int* labels_buf = (int*)(uchar*)inn_buf;
2965 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2966 const int* subset = split->subset;
2967
2968 for( i = 0; i < n; i++ )
2969 {
2970 int idx = labels[i];
2971 if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
2973 {
2974 int d = CV_DTREE_CAT_DIR(idx,subset);
2975 dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2976 if( --nz == 0 ) // stop once every missing direction has been filled
2977 break;
2978 }
2979 }
2980 }
2981 else // split on ordered var
2982 {
2983 float* values_buf = (float*)(uchar*)inn_buf;
2984 int* sorted_indices_buf = (int*)(values_buf + n);
2985 int* sample_indices_buf = sorted_indices_buf + n;
2986 const float* values = 0;
2987 const int* sorted_indices = 0;
2988 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2989 int split_point = split->ord.split_point;
2990 int n1 = node->get_num_valid(vi);
2991
2992 assert( 0 <= split_point && split_point < n-1 );
2993
2994 for( i = 0; i < n1; i++ )
2995 {
2996 int idx = sorted_indices[i];
2997 if( !dir[idx] )
2998 {
2999 int d = i <= split_point ? -1 : 1;
3000 dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
3001 if( --nz == 0 ) // stop once every missing direction has been filled
3002 break;
3003 }
3004 }
3005 }
3006 }
3007 }
3008
3009 // find the default direction for the rest
3010 if( nz )
3011 {
3012 for( i = nr = 0; i < n; i++ )
3013 nr += dir[i] > 0;
3014 nl = n - nr - nz;
3015 d0 = nl > nr ? -1 : nr > nl;
3016 }
3017
3018 // make sure that every sample is directed either to the left or to the right
3019 for( i = 0; i < n; i++ )
3020 {
3021 int d = dir[i];
3022 if( !d )
3023 {
3024 d = d0;
3025 if( !d )
3026 d = d1, d1 = -d1;
3027 }
3028 d = d > 0;
3029 dir[i] = (char)d; // remap (-1,1) to (0,1)
3030 }
3031 }
3032
3033
3034 void CvDTree::split_node_data( CvDTreeNode* node )
3035 {
3036 int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
3037 char* dir = (char*)data->direction->data.ptr;
3038 CvDTreeNode *left = 0, *right = 0;
3039 int* new_idx = data->split_buf->data.i;
3040 int new_buf_idx = data->get_child_buf_idx( node );
3041 int work_var_count = data->get_work_var_count();
3042 CvMat* buf = data->buf;
3043 size_t length_buf_row = data->get_length_subbuf();
3044 cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
3045 int* temp_buf = (int*)(uchar*)inn_buf;
3046
3047 complete_node_dir(node);
3048
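// after complete_node_dir() every dir[i] is 0 (left) or 1 (right); the
// branchless expression below computes new_idx[i] = d ? nr : nl, since
// (d-1) is an all-ones mask when d == 0 and -d is all-ones when d == 1.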
3049 for( i = nl = nr = 0; i < n; i++ )
3050 {
3051 int d = dir[i];
3052 // initialize new indices for splitting ordered variables
3053 new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
3054 nr += d;
3055 nl += d^1;
3056 }
3057
3058 bool split_input_data;
3059 node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
3060 node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
3061
3062 split_input_data = node->depth + 1 < data->params.max_depth &&
3063 (node->left->sample_count > data->params.min_sample_count ||
3064 node->right->sample_count > data->params.min_sample_count);
3065
3066 // split ordered variables, keep both halves sorted.
3067 for( vi = 0; vi < data->var_count; vi++ )
3068 {
3069 int ci = data->get_var_type(vi);
3070
3071 if( ci >= 0 || !split_input_data )
3072 continue;
3073
3074 int n1 = node->get_num_valid(vi);
3075 float* src_val_buf = (float*)(uchar*)(temp_buf + n);
3076 int* src_sorted_idx_buf = (int*)(src_val_buf + n);
3077 int* src_sample_idx_buf = src_sorted_idx_buf + n;
3078 const float* src_val = 0;
3079 const int* src_sorted_idx = 0;
3080 data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);
3081
3082 for(i = 0; i < n; i++)
3083 temp_buf[i] = src_sorted_idx[i];
3084
3085 if (data->is_buf_16u)
3086 {
3087 unsigned short *ldst, *rdst, *ldst0, *rdst0;
3088 //unsigned short tl, tr;
3089 ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
3090 vi*scount + left->offset);
3091 rdst0 = rdst = (unsigned short*)(ldst + nl);
3092
3093 // split sorted
3094 for( i = 0; i < n1; i++ )
3095 {
3096 int idx = temp_buf[i];
3097 int d = dir[idx];
3098 idx = new_idx[idx];
3099 if (d)
3100 {
3101 *rdst = (unsigned short)idx;
3102 rdst++;
3103 }
3104 else
3105 {
3106 *ldst = (unsigned short)idx;
3107 ldst++;
3108 }
3109 }
3110
3111 left->set_num_valid(vi, (int)(ldst - ldst0));
3112 right->set_num_valid(vi, (int)(rdst - rdst0));
3113
3114 // split missing
3115 for( ; i < n; i++ )
3116 {
3117 int idx = temp_buf[i];
3118 int d = dir[idx];
3119 idx = new_idx[idx];
3120 if (d)
3121 {
3122 *rdst = (unsigned short)idx;
3123 rdst++;
3124 }
3125 else
3126 {
3127 *ldst = (unsigned short)idx;
3128 ldst++;
3129 }
3130 }
3131 }
3132 else
3133 {
3134 int *ldst0, *ldst, *rdst0, *rdst;
3135 ldst0 = ldst = buf->data.i + left->buf_idx*length_buf_row +
3136 vi*scount + left->offset;
3137 rdst0 = rdst = buf->data.i + right->buf_idx*length_buf_row +
3138 vi*scount + right->offset;
3139
3140 // split sorted
3141 for( i = 0; i < n1; i++ )
3142 {
3143 int idx = temp_buf[i];
3144 int d = dir[idx];
3145 idx = new_idx[idx];
3146 if (d)
3147 {
3148 *rdst = idx;
3149 rdst++;
3150 }
3151 else
3152 {
3153 *ldst = idx;
3154 ldst++;
3155 }
3156 }
3157
3158 left->set_num_valid(vi, (int)(ldst - ldst0));
3159 right->set_num_valid(vi, (int)(rdst - rdst0));
3160
3161 // split missing
3162 for( ; i < n; i++ )
3163 {
3164 int idx = temp_buf[i];
3165 int d = dir[idx];
3166 idx = new_idx[idx];
3167 if (d)
3168 {
3169 *rdst = idx;
3170 rdst++;
3171 }
3172 else
3173 {
3174 *ldst = idx;
3175 ldst++;
3176 }
3177 }
3178 }
3179 }
3180
3181 // split categorical vars, responses and cv_labels using new_idx relocation table
3182 for( vi = 0; vi < work_var_count; vi++ )
3183 {
3184 int ci = data->get_var_type(vi);
3185 int n1 = node->get_num_valid(vi), nr1 = 0;
3186
3187 if( ci < 0 || (vi < data->var_count && !split_input_data) )
3188 continue;
3189
3190 int *src_lbls_buf = temp_buf + n;
3191 const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);
3192
3193 for(i = 0; i < n; i++)
3194 temp_buf[i] = src_lbls[i];
3195
3196 if (data->is_buf_16u)
3197 {
3198 unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
3199 vi*scount + left->offset);
3200 unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
3201 vi*scount + right->offset);
3202
3203 for( i = 0; i < n; i++ )
3204 {
3205 int d = dir[i];
3206 int idx = temp_buf[i];
3207 if (d)
3208 {
3209 *rdst = (unsigned short)idx;
3210 rdst++;
3211 nr1 += (idx != 65535) & d;
3212 }
3213 else
3214 {
3215 *ldst = (unsigned short)idx;
3216 ldst++;
3217 }
3218 }
3219
3220 if( vi < data->var_count )
3221 {
3222 left->set_num_valid(vi, n1 - nr1);
3223 right->set_num_valid(vi, nr1);
3224 }
3225 }
3226 else
3227 {
3228 int *ldst = buf->data.i + left->buf_idx*length_buf_row +
3229 vi*scount + left->offset;
3230 int *rdst = buf->data.i + right->buf_idx*length_buf_row +
3231 vi*scount + right->offset;
3232
3233 for( i = 0; i < n; i++ )
3234 {
3235 int d = dir[i];
3236 int idx = temp_buf[i];
3237 if (d)
3238 {
3239 *rdst = idx;
3240 rdst++;
3241 nr1 += (idx >= 0) & d;
3242 }
3243 else
3244 {
3245 *ldst = idx;
3246 ldst++;
3247 }
3248
3249 }
3250
3251 if( vi < data->var_count )
3252 {
3253 left->set_num_valid(vi, n1 - nr1);
3254 right->set_num_valid(vi, nr1);
3255 }
3256 }
3257 }
3258
3259
3260 // split sample indices
3261 int *sample_idx_src_buf = temp_buf + n;
3262 const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
3263
3264 for(i = 0; i < n; i++)
3265 temp_buf[i] = sample_idx_src[i];
3266
3267 int pos = data->get_work_var_count();
3268 if (data->is_buf_16u)
3269 {
3270 unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
3271 pos*scount + left->offset);
3272 unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
3273 pos*scount + right->offset);
3274 for (i = 0; i < n; i++)
3275 {
3276 int d = dir[i];
3277 unsigned short idx = (unsigned short)temp_buf[i];
3278 if (d)
3279 {
3280 *rdst = idx;
3281 rdst++;
3282 }
3283 else
3284 {
3285 *ldst = idx;
3286 ldst++;
3287 }
3288 }
3289 }
3290 else
3291 {
3292 int* ldst = buf->data.i + left->buf_idx*length_buf_row +
3293 pos*scount + left->offset;
3294 int* rdst = buf->data.i + right->buf_idx*length_buf_row +
3295 pos*scount + right->offset;
3296 for (i = 0; i < n; i++)
3297 {
3298 int d = dir[i];
3299 int idx = temp_buf[i];
3300 if (d)
3301 {
3302 *rdst = idx;
3303 rdst++;
3304 }
3305 else
3306 {
3307 *ldst = idx;
3308 ldst++;
3309 }
3310 }
3311 }
3312
3313 // deallocate the parent node data that is not needed anymore
3314 data->free_node_data(node);
3315 }
3316
3317 float CvDTree::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
3318 {
3319 float err = 0;
3320 const CvMat* values = _data->get_values();
3321 const CvMat* response = _data->get_responses();
3322 const CvMat* missing = _data->get_missing();
3323 const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
3324 const CvMat* var_types = _data->get_var_types();
3325 int* sidx = sample_idx ? sample_idx->data.i : 0;
3326 int r_step = CV_IS_MAT_CONT(response->type) ?
3327 1 : response->step / CV_ELEM_SIZE(response->type);
3328 bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
3329 int sample_count = sample_idx ? sample_idx->cols : 0;
3330 sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
3331 float* pred_resp = 0;
3332 if( resp && (sample_count > 0) )
3333 {
3334 resp->resize( sample_count );
3335 pred_resp = &((*resp)[0]);
3336 }
3337
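// classification error is reported as the percentage of misclassified
// samples; regression error as the mean squared prediction error.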
3338 if ( is_classifier )
3339 {
3340 for( int i = 0; i < sample_count; i++ )
3341 {
3342 CvMat sample, miss;
3343 int si = sidx ? sidx[i] : i;
3344 cvGetRow( values, &sample, si );
3345 if( missing )
3346 cvGetRow( missing, &miss, si );
3347 float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3348 if( pred_resp )
3349 pred_resp[i] = r;
3350 int d = fabs((double)r - response->data.fl[(size_t)si*r_step]) <= FLT_EPSILON ? 0 : 1;
3351 err += d;
3352 }
3353 err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
3354 }
3355 else
3356 {
3357 for( int i = 0; i < sample_count; i++ )
3358 {
3359 CvMat sample, miss;
3360 int si = sidx ? sidx[i] : i;
3361 cvGetRow( values, &sample, si );
3362 if( missing )
3363 cvGetRow( missing, &miss, si );
3364 float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3365 if( pred_resp )
3366 pred_resp[i] = r;
3367 float d = r - response->data.fl[(size_t)si*r_step];
3368 err += d*d;
3369 }
3370 err = sample_count ? err / (float)sample_count : -FLT_MAX;
3371 }
3372 return err;
3373 }
3374
3375 void CvDTree::prune_cv()
3376 {
3377 CvMat* ab = 0;
3378 CvMat* temp = 0;
3379 CvMat* err_jk = 0;
3380
3381 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
3382 // 2. choose the best tree index (if needed, apply the 1SE rule).
3383 // 3. store the best index and cut the branches.
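// for each internal node alpha = (R(node) - R(subtree))/(#leaves - 1), the
// per-node cost of keeping its subtree; update_tree_rnc() returns the minimal
// alpha over the tree and cut_tree() collapses every node that attains it,
// producing the next tree in the cost-complexity sequence.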
3384
3385 CV_FUNCNAME( "CvDTree::prune_cv" );
3386
3387 __BEGIN__;
3388
3389 int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
3390 // currently, 1SE for regression is not implemented
3391 bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
3392 double* err;
3393 double min_err = 0, min_err_se = 0;
3394 int min_idx = -1;
3395
3396 CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
3397
3398 // build the main tree sequence, calculate alpha's
3399 for(;;tree_count++)
3400 {
3401 double min_alpha = update_tree_rnc(tree_count, -1);
3402 if( cut_tree(tree_count, -1, min_alpha) )
3403 break;
3404
3405 if( ab->cols <= tree_count )
3406 {
3407 CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
3408 for( ti = 0; ti < ab->cols; ti++ )
3409 temp->data.db[ti] = ab->data.db[ti];
3410 cvReleaseMat( &ab );
3411 ab = temp;
3412 temp = 0;
3413 }
3414
3415 ab->data.db[tree_count] = min_alpha;
3416 }
3417
3418 ab->data.db[0] = 0.;
3419
3420 if( tree_count > 0 )
3421 {
3422 for( ti = 1; ti < tree_count-1; ti++ )
3423 ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
3424 ab->data.db[tree_count-1] = DBL_MAX*0.5;
3425
3426 CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
3427 err = err_jk->data.db;
3428
3429 for( j = 0; j < cv_n; j++ )
3430 {
3431 int tj = 0, tk = 0;
3432 for( ; tk < tree_count; tj++ )
3433 {
3434 double min_alpha = update_tree_rnc(tj, j);
3435 if( cut_tree(tj, j, min_alpha) )
3436 min_alpha = DBL_MAX;
3437
3438 for( ; tk < tree_count; tk++ )
3439 {
3440 if( ab->data.db[tk] > min_alpha )
3441 break;
3442 err[j*tree_count + tk] = root->tree_error;
3443 }
3444 }
3445 }
3446
3447 for( ti = 0; ti < tree_count; ti++ )
3448 {
3449 double sum_err = 0;
3450 for( j = 0; j < cv_n; j++ )
3451 sum_err += err[j*tree_count + ti];
3452 if( ti == 0 || sum_err < min_err )
3453 {
3454 min_err = sum_err;
3455 min_idx = ti;
3456 if( use_1se )
3457 min_err_se = sqrt( sum_err*(n - sum_err) );
3458 }
3459 else if( sum_err < min_err + min_err_se )
3460 min_idx = ti;
3461 }
3462 }
3463
3464 pruned_tree_idx = min_idx;
3465 free_prune_data(data->params.truncate_pruned_tree != 0);
3466
3467 __END__;
3468
3469 cvReleaseMat( &err_jk );
3470 cvReleaseMat( &ab );
3471 cvReleaseMat( &temp );
3472 }
3473
3474
3475 double CvDTree::update_tree_rnc( int T, int fold )
3476 {
3477 CvDTreeNode* node = root;
3478 double min_alpha = DBL_MAX;
3479
3480 for(;;)
3481 {
3482 CvDTreeNode* parent;
3483 for(;;)
3484 {
3485 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3486 if( t <= T || !node->left )
3487 {
3488 node->complexity = 1;
3489 node->tree_risk = node->node_risk;
3490 node->tree_error = 0.;
3491 if( fold >= 0 )
3492 {
3493 node->tree_risk = node->cv_node_risk[fold];
3494 node->tree_error = node->cv_node_error[fold];
3495 }
3496 break;
3497 }
3498 node = node->left;
3499 }
3500
3501 for( parent = node->parent; parent && parent->right == node;
3502 node = parent, parent = parent->parent )
3503 {
3504 parent->complexity += node->complexity;
3505 parent->tree_risk += node->tree_risk;
3506 parent->tree_error += node->tree_error;
3507
3508 parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
3509 - parent->tree_risk)/(parent->complexity - 1);
3510 min_alpha = MIN( min_alpha, parent->alpha );
3511 }
3512
3513 if( !parent )
3514 break;
3515
3516 parent->complexity = node->complexity;
3517 parent->tree_risk = node->tree_risk;
3518 parent->tree_error = node->tree_error;
3519 node = parent->right;
3520 }
3521
3522 return min_alpha;
3523 }
3524
3525
3526 int CvDTree::cut_tree( int T, int fold, double min_alpha )
3527 {
3528 CvDTreeNode* node = root;
3529 if( !node->left )
3530 return 1;
3531
3532 for(;;)
3533 {
3534 CvDTreeNode* parent;
3535 for(;;)
3536 {
3537 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3538 if( t <= T || !node->left )
3539 break;
3540 if( node->alpha <= min_alpha + FLT_EPSILON )
3541 {
3542 if( fold >= 0 )
3543 node->cv_Tn[fold] = T;
3544 else
3545 node->Tn = T;
3546 if( node == root )
3547 return 1;
3548 break;
3549 }
3550 node = node->left;
3551 }
3552
3553 for( parent = node->parent; parent && parent->right == node;
3554 node = parent, parent = parent->parent )
3555 ;
3556
3557 if( !parent )
3558 break;
3559
3560 node = parent->right;
3561 }
3562
3563 return 0;
3564 }
3565
3566
3567 void CvDTree::free_prune_data(bool _cut_tree)
3568 {
3569 CvDTreeNode* node = root;
3570
3571 for(;;)
3572 {
3573 CvDTreeNode* parent;
3574 for(;;)
3575 {
3576 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
3577 // as we will clear the whole cross-validation heap at the end
3578 node->cv_Tn = 0;
3579 node->cv_node_error = node->cv_node_risk = 0;
3580 if( !node->left )
3581 break;
3582 node = node->left;
3583 }
3584
3585 for( parent = node->parent; parent && parent->right == node;
3586 node = parent, parent = parent->parent )
3587 {
3588 if( _cut_tree && parent->Tn <= pruned_tree_idx )
3589 {
3590 data->free_node( parent->left );
3591 data->free_node( parent->right );
3592 parent->left = parent->right = 0;
3593 }
3594 }
3595
3596 if( !parent )
3597 break;
3598
3599 node = parent->right;
3600 }
3601
3602 if( data->cv_heap )
3603 cvClearSet( data->cv_heap );
3604 }
3605
3606
3607 void CvDTree::free_tree()
3608 {
3609 if( root && data && data->shared )
3610 {
3611 pruned_tree_idx = INT_MIN;
3612 free_prune_data(true);
3613 data->free_node(root);
3614 root = 0;
3615 }
3616 }
3617
3618 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
3619 const CvMat* _missing, bool preprocessed_input ) const
3620 {
3621 cv::AutoBuffer<int> catbuf;
3622
3623 int i, mstep = 0;
3624 const uchar* m = 0;
3625 CvDTreeNode* node = root;
3626
3627 if( !node )
3628 CV_Error( CV_StsError, "The tree has not been trained yet" );
3629
3630 if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
3631 (_sample->cols != 1 && _sample->rows != 1) ||
3632 (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
3633 (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
3634 CV_Error( CV_StsBadArg,
3635 "the input sample must be 1d floating-point vector with the same "
3636 "number of elements as the total number of variables used for training" );
3637
3638 const float* sample = _sample->data.fl;
3639 int step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
3640
3641 if( data->cat_count && !preprocessed_input ) // cache for categorical variables
3642 {
3643 int n = data->cat_count->cols;
3644 catbuf.allocate(n);
3645 for( i = 0; i < n; i++ )
3646 catbuf[i] = -1;
3647 }
3648
3649 if( _missing )
3650 {
3651 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
3652 !CV_ARE_SIZES_EQ(_missing, _sample) )
3653 CV_Error( CV_StsBadArg,
3654 "the missing data mask must be 8-bit vector of the same size as input sample" );
3655 m = _missing->data.ptr;
3656 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
3657 }
3658
3659 const int* vtype = data->var_type->data.i;
3660 const int* vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
3661 const int* cmap = data->cat_map ? data->cat_map->data.i : 0;
3662 const int* cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
3663
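// descend from the root to a leaf: at each node try the primary split first,
// then the surrogate splits in decreasing-quality order; if the split
// variable is missing for all of them, fall back to the majority direction.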
3664 while( node->Tn > pruned_tree_idx && node->left )
3665 {
3666 CvDTreeSplit* split = node->split;
3667 int dir = 0;
3668 for( ; !dir && split != 0; split = split->next )
3669 {
3670 int vi = split->var_idx;
3671 int ci = vtype[vi];
3672 i = vidx ? vidx[vi] : vi;
3673 float val = sample[(size_t)i*step];
3674 if( m && m[(size_t)i*mstep] )
3675 continue;
3676 if( ci < 0 ) // ordered
3677 dir = val <= split->ord.c ? -1 : 1;
3678 else // categorical
3679 {
3680 int c;
3681 if( preprocessed_input )
3682 c = cvRound(val);
3683 else
3684 {
3685 c = catbuf[ci];
3686 if( c < 0 )
3687 {
3688 int a = c = cofs[ci];
3689 int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
3690
3691 int ival = cvRound(val);
3692 if( ival != val )
3693 CV_Error( CV_StsBadArg,
3694 "one of input categorical variable is not an integer" );
3695
3696 int sh = 0;
3697 while( a < b )
3698 {
3699 sh++;
3700 c = (a + b) >> 1;
3701 if( ival < cmap[c] )
3702 b = c;
3703 else if( ival > cmap[c] )
3704 a = c+1;
3705 else
3706 break;
3707 }
3708
3709 if( c < 0 || ival != cmap[c] )
3710 continue;
3711
3712 catbuf[ci] = c -= cofs[ci];
3713 }
3714 }
3715 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
3716 dir = CV_DTREE_CAT_DIR(c, split->subset);
3717 }
3718
3719 if( split->inversed )
3720 dir = -dir;
3721 }
3722
3723 if( !dir )
3724 {
3725 double diff = node->right->sample_count - node->left->sample_count;
3726 dir = diff < 0 ? -1 : 1;
3727 }
3728 node = dir < 0 ? node->left : node->right;
3729 }
3730
3731 return node;
3732 }
3733
3734
3735 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
3736 {
3737 CvMat sample = _sample, mmask = _missing;
3738 return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
3739 }
3740
3741
3742 const CvMat* CvDTree::get_var_importance()
3743 {
3744 if( !var_importance )
3745 {
3746 CvDTreeNode* node = root;
3747 double* importance;
3748 if( !node )
3749 return 0;
3750 var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3751 cvZero( var_importance );
3752 importance = var_importance->data.db;
3753
3754 for(;;)
3755 {
3756 CvDTreeNode* parent;
3757 for( ;; node = node->left )
3758 {
3759 CvDTreeSplit* split = node->split;
3760
3761 if( !node->left || node->Tn <= pruned_tree_idx )
3762 break;
3763
3764 for( ; split != 0; split = split->next )
3765 importance[split->var_idx] += split->quality;
3766 }
3767
3768 for( parent = node->parent; parent && parent->right == node;
3769 node = parent, parent = parent->parent )
3770 ;
3771
3772 if( !parent )
3773 break;
3774
3775 node = parent->right;
3776 }
3777
3778 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3779 }
3780
3781 return var_importance;
3782 }
3783
3784
3785 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
3786 {
3787 int ci;
3788
3789 cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3790 cvWriteInt( fs, "var", split->var_idx );
3791 cvWriteReal( fs, "quality", split->quality );
3792
3793 ci = data->get_var_type(split->var_idx);
3794 if( ci >= 0 ) // split on a categorical var
3795 {
3796 int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3797 for( i = 0; i < n; i++ )
3798 to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3799
3800 // ad-hoc rule when to use inverse categorical split notation
3801 // to achieve more compact and clear representation
3802 default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3803
3804 cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3805 "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3806
3807 for( i = 0; i < n; i++ )
3808 {
3809 int dir = CV_DTREE_CAT_DIR(i,split->subset);
3810 if( dir*default_dir < 0 )
3811 cvWriteInt( fs, 0, i );
3812 }
3813 cvEndWriteStruct( fs );
3814 }
3815 else
3816 cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3817
3818 cvEndWriteStruct( fs );
3819 }
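
// Illustrative example of the rule above (numbers are hypothetical): with
// n = 10 categories of which a single one, say category 7, has
// CV_DTREE_CAT_DIR > 0, default_dir is -1 and the split is stored as the
// one-element list "[ 7 ]" instead of enumerating the other nine categories.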


void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
{
    CvDTreeSplit* split;

    cvStartWriteStruct( fs, 0, CV_NODE_MAP );

    cvWriteInt( fs, "depth", node->depth );
    cvWriteInt( fs, "sample_count", node->sample_count );
    cvWriteReal( fs, "value", node->value );

    if( data->is_classifier )
        cvWriteInt( fs, "norm_class_idx", node->class_idx );

    cvWriteInt( fs, "Tn", node->Tn );
    cvWriteInt( fs, "complexity", node->complexity );
    cvWriteReal( fs, "alpha", node->alpha );
    cvWriteReal( fs, "node_risk", node->node_risk );
    cvWriteReal( fs, "tree_risk", node->tree_risk );
    cvWriteReal( fs, "tree_error", node->tree_error );

    if( node->left )
    {
        cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );

        for( split = node->split; split != 0; split = split->next )
            write_split( fs, split );

        cvEndWriteStruct( fs );
    }

    cvEndWriteStruct( fs );
}
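
// A sketch of one serialized node (values are hypothetical; the exact layout
// depends on the CvFileStorage backend, YAML shown here):
//
//   { depth:1, sample_count:54, value:1., norm_class_idx:1, Tn:3,
//     complexity:2, alpha:0.5, node_risk:5., tree_risk:6., tree_error:0.07,
//     splits:[ { var:2, quality:41.9, le:4.75 } ] }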


void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
{
    //CV_FUNCNAME( "CvDTree::write_tree_nodes" );

    __BEGIN__;

    CvDTreeNode* node = root;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        CvDTreeNode* parent;
        for(;;)
        {
            write_node( fs, node );
            if( !node->left )
                break;
            node = node->left;
        }

        // climb up while we arrive from a right child; the first ancestor
        // reached from a left child still has an unvisited right subtree
        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
            ;

        if( !parent )
            break;

        node = parent->right;
    }

    __END__;
}


void CvDTree::write( CvFileStorage* fs, const char* name ) const
{
    //CV_FUNCNAME( "CvDTree::write" );

    __BEGIN__;

    cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );

    //get_var_importance();
    data->write_params( fs );
    //if( var_importance )
    //cvWrite( fs, "var_importance", var_importance );
    write( fs );

    cvEndWriteStruct( fs );

    __END__;
}


void CvDTree::write( CvFileStorage* fs ) const
{
    //CV_FUNCNAME( "CvDTree::write" );

    __BEGIN__;

    cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );

    cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
    write_tree_nodes( fs );
    cvEndWriteStruct( fs );

    __END__;
}
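
// A minimal save sketch (illustrative; the file name is a hypothetical
// assumption). CvStatModel::save() opens the storage and invokes the named
// write() overload above:
//
//   CvDTree tree;
//   // ... train the tree ...
//   tree.save( "tree.yml" );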


CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeSplit* split = 0;

    CV_FUNCNAME( "CvDTree::read_split" );

    __BEGIN__;

    int vi, ci;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );

    vi = cvReadIntByName( fs, fnode, "var", -1 );
    if( (unsigned)vi >= (unsigned)data->var_count )
        CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );

    ci = data->get_var_type(vi);
    if( ci >= 0 ) // split on categorical var
    {
        int i, n = data->cat_count->data.i[ci], inversed = 0, val;
        CvSeqReader reader;
        CvFileNode* inseq;
        split = data->new_split_cat( vi, 0 );
        inseq = cvGetFileNodeByName( fs, fnode, "in" );
        if( !inseq )
        {
            inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
            inversed = 1;
        }
        if( !inseq ||
            (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
            CV_ERROR( CV_StsParseError,
                "a categorical split must contain either an 'in' or a 'not_in' tag" );

        if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
        {
            val = inseq->data.i;
            if( (unsigned)val >= (unsigned)n )
                CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );

            split->subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            cvStartReadSeq( inseq->data.seq, &reader );

            for( i = 0; i < reader.seq->total; i++ )
            {
                CvFileNode* inode = (CvFileNode*)reader.ptr;
                val = inode->data.i;
                if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
                    CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );

                split->subset[val >> 5] |= 1 << (val & 31);
                CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
            }
        }

        // for categorical splits we do not use inversed splits;
        // instead, we invert the variable set stored in the split
        if( inversed )
            for( i = 0; i < (n + 31) >> 5; i++ )
                split->subset[i] ^= -1;
    }
    else
    {
        CvFileNode* cmp_node;
        split = data->new_split_ord( vi, 0, 0, 0, 0 );

        cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
        if( !cmp_node )
        {
            cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
            split->inversed = 1;
        }

        split->ord.c = (float)cvReadReal( cmp_node );
    }

    split->quality = (float)cvReadRealByName( fs, fnode, "quality" );

    __END__;

    return split;
}
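
// Note on the subset encoding used above: subset is an array of 32-bit words
// holding one bit per category; "val >> 5" selects the word and "val & 31"
// the bit within it. For example (hypothetical value), val = 37 sets bit 5
// of subset[1], i.e. subset[1] |= 1 << 5.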


CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
{
    CvDTreeNode* node = 0;

    CV_FUNCNAME( "CvDTree::read_node" );

    __BEGIN__;

    CvFileNode* splits;
    int i, depth;

    if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
        CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );

    CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
    depth = cvReadIntByName( fs, fnode, "depth", -1 );
    if( depth != node->depth )
        CV_ERROR( CV_StsParseError, "incorrect node depth" );

    node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
    node->value = cvReadRealByName( fs, fnode, "value" );
    if( data->is_classifier )
        node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );

    node->Tn = cvReadIntByName( fs, fnode, "Tn" );
    node->complexity = cvReadIntByName( fs, fnode, "complexity" );
    node->alpha = cvReadRealByName( fs, fnode, "alpha" );
    node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
    node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
    node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );

    splits = cvGetFileNodeByName( fs, fnode, "splits" );
    if( splits )
    {
        CvSeqReader reader;
        CvDTreeSplit* last_split = 0;

        if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
            CV_ERROR( CV_StsParseError, "the 'splits' tag must be stored as a sequence" );

        cvStartReadSeq( splits->data.seq, &reader );
        for( i = 0; i < reader.seq->total; i++ )
        {
            CvDTreeSplit* split;
            CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
            if( !last_split )
                node->split = last_split = split;
            else
                last_split = last_split->next = split;

            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }

    __END__;

    return node;
}
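
// Note on the split chain rebuilt above: the first split in a node's list is
// the primary split; the following ones are the surrogate splits that
// predict() falls back on when the primary split's variable is missing.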


void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
{
    CV_FUNCNAME( "CvDTree::read_tree_nodes" );

    __BEGIN__;

    CvSeqReader reader;
    CvDTreeNode _root;
    CvDTreeNode* parent = &_root;
    int i;
    parent->left = parent->right = parent->parent = 0;

    cvStartReadSeq( fnode->data.seq, &reader );

    // the nodes were written in depth-first order (see write_tree_nodes),
    // so the tree can be rebuilt using a sentinel pseudo-root whose left
    // child becomes the actual root
    for( i = 0; i < reader.seq->total; i++ )
    {
        CvDTreeNode* node;

        CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
        if( !parent->left )
            parent->left = node;
        else
            parent->right = node;
        if( node->split )
            parent = node;
        else
        {
            // a leaf: go back up to the nearest ancestor that still
            // lacks a right child
            while( parent && parent->right )
                parent = parent->parent;
        }

        CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
    }

    root = _root.left;

    __END__;
}


void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
{
    CvDTreeTrainData* _data = new CvDTreeTrainData();
    _data->read_params( fs, fnode );

    read( fs, fnode, _data );
    get_var_importance(); // precompute importance for the loaded tree
}


// a special entry point for reading weak decision trees from tree ensembles
void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
{
    CV_FUNCNAME( "CvDTree::read" );

    __BEGIN__;

    CvFileNode* tree_nodes;

    clear();
    data = _data;

    tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
    if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
        CV_ERROR( CV_StsParseError, "nodes tag is missing" );

    pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
    read_tree_nodes( fs, tree_nodes );

    __END__;
}
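
// A minimal load round trip (illustrative; the file name is a hypothetical
// assumption). CvStatModel::load() opens the storage and dispatches to
// read() above:
//
//   CvDTree tree;
//   tree.load( "tree.yml" );
//   const CvDTreeNode* top = tree.get_root();  // rebuilt tree, ready to use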

Mat CvDTree::getVarImportance()
{
    return cvarrToMat(get_var_importance());
}

/* End of file. */