• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #include "old_ml_precomp.hpp"
42 #include <ctype.h>
43 
44 using namespace cv;
45 
// Sentinel stored for missing values of ordered variables. set_data rejects
// any real value with |x| >= ord_nan, so after an ascending sort the sentinel
// always lands behind every valid value.
static const float ord_nan = 0.5f * FLT_MAX;
// Minimum block size (bytes) of the CvMemStorage objects created in set_data.
static const int min_block_size = 65536;
// Extra slack (bytes) added to a computed block size before it is clamped
// to min_block_size.
static const int block_size_delta = 1024;
49 
CvDTreeTrainData()50 CvDTreeTrainData::CvDTreeTrainData()
51 {
52     var_idx = var_type = cat_count = cat_ofs = cat_map =
53         priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
54     buf = 0;
55     tree_storage = temp_storage = 0;
56 
57     clear();
58 }
59 
60 
CvDTreeTrainData(const CvMat * _train_data,int _tflag,const CvMat * _responses,const CvMat * _var_idx,const CvMat * _sample_idx,const CvMat * _var_type,const CvMat * _missing_mask,const CvDTreeParams & _params,bool _shared,bool _add_labels)61 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
62                       const CvMat* _responses, const CvMat* _var_idx,
63                       const CvMat* _sample_idx, const CvMat* _var_type,
64                       const CvMat* _missing_mask, const CvDTreeParams& _params,
65                       bool _shared, bool _add_labels )
66 {
67     var_idx = var_type = cat_count = cat_ofs = cat_map =
68         priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
69     buf = 0;
70 
71     tree_storage = temp_storage = 0;
72 
73     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
74               _var_type, _missing_mask, _params, _shared, _add_labels );
75 }
76 
77 
~CvDTreeTrainData()78 CvDTreeTrainData::~CvDTreeTrainData()
79 {
80     clear();
81 }
82 
83 
set_params(const CvDTreeParams & _params)84 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
85 {
86     bool ok = false;
87 
88     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
89 
90     __BEGIN__;
91 
92     // set parameters
93     params = _params;
94 
95     if( params.max_categories < 2 )
96         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
97     params.max_categories = MIN( params.max_categories, 15 );
98 
99     if( params.max_depth < 0 )
100         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
101     params.max_depth = MIN( params.max_depth, 25 );
102 
103     params.min_sample_count = MAX(params.min_sample_count,1);
104 
105     if( params.cv_folds < 0 )
106         CV_ERROR( CV_StsOutOfRange,
107         "params.cv_folds should be =0 (the tree is not pruned) "
108         "or n>0 (tree is pruned using n-fold cross-validation)" );
109 
110     if( params.cv_folds == 1 )
111         params.cv_folds = 0;
112 
113     if( params.regression_accuracy < 0 )
114         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
115 
116     ok = true;
117 
118     __END__;
119 
120     return ok;
121 }
122 
// Strict-weak-ordering comparator that orders pointers by the values they
// point to (used to sort arrays of int* by the referenced ints).
template<typename T>
class LessThanPtr
{
public:
    bool operator()( T* lhs, T* rhs ) const
    {
        const T& left = *lhs;
        const T& right = *rhs;
        return left < right;
    }
};
129 
// Comparator that orders indices by the values they select from `arr`
// (an indirect sort key: sorting indices leaves `arr` untouched).
template<typename T, typename Idx>
class LessThanIdx
{
public:
    LessThanIdx( const T* _arr ) : arr(_arr) {}

    bool operator()( Idx lhs, Idx rhs ) const
    {
        const T& left = arr[lhs];
        const T& right = arr[rhs];
        return left < right;
    }

    // borrowed, not owned
    const T* arr;
};
138 
139 class LessThanPairs
140 {
141 public:
operator ()(const CvPair16u32s & a,const CvPair16u32s & b) const142     bool operator()(const CvPair16u32s& a, const CvPair16u32s& b) const { return *a.i < *b.i; }
143 };
144 
set_data(const CvMat * _train_data,int _tflag,const CvMat * _responses,const CvMat * _var_idx,const CvMat * _sample_idx,const CvMat * _var_type,const CvMat * _missing_mask,const CvDTreeParams & _params,bool _shared,bool _add_labels,bool _update_data)145 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
146     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
147     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
148     bool _shared, bool _add_labels, bool _update_data )
149 {
150     CvMat* sample_indices = 0;
151     CvMat* var_type0 = 0;
152     CvMat* tmp_map = 0;
153     int** int_ptr = 0;
154     CvPair16u32s* pair16u32s_ptr = 0;
155     CvDTreeTrainData* data = 0;
156     float *_fdst = 0;
157     int *_idst = 0;
158     unsigned short* udst = 0;
159     int* idst = 0;
160 
161     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
162 
163     __BEGIN__;
164 
165     int sample_all = 0, r_type, cv_n;
166     int total_c_count = 0;
167     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
168     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
169     int vi, i, size;
170     char err[100];
171     const int *sidx = 0, *vidx = 0;
172 
173     uint64 effective_buf_size = 0;
174     int effective_buf_height = 0, effective_buf_width = 0;
175 
176     if( _update_data && data_root )
177     {
178         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
179             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
180 
181         // compare new and old train data
182         if( !(data->var_count == var_count &&
183             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
184             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
185             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
186             CV_ERROR( CV_StsBadArg,
187             "The new training data must have the same types and the input and output variables "
188             "and the same categories for categorical variables" );
189 
190         cvReleaseMat( &priors );
191         cvReleaseMat( &priors_mult );
192         cvReleaseMat( &buf );
193         cvReleaseMat( &direction );
194         cvReleaseMat( &split_buf );
195         cvReleaseMemStorage( &temp_storage );
196 
197         priors = data->priors; data->priors = 0;
198         priors_mult = data->priors_mult; data->priors_mult = 0;
199         buf = data->buf; data->buf = 0;
200         buf_count = data->buf_count; buf_size = data->buf_size;
201         sample_count = data->sample_count;
202 
203         direction = data->direction; data->direction = 0;
204         split_buf = data->split_buf; data->split_buf = 0;
205         temp_storage = data->temp_storage; data->temp_storage = 0;
206         nv_heap = data->nv_heap; cv_heap = data->cv_heap;
207 
208         data_root = new_node( 0, sample_count, 0, 0 );
209         EXIT;
210     }
211 
212     clear();
213 
214     var_all = 0;
215     rng = &cv::theRNG();
216 
217     CV_CALL( set_params( _params ));
218 
219     // check parameter types and sizes
220     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
221 
222     train_data = _train_data;
223     responses = _responses;
224 
225     if( _tflag == CV_ROW_SAMPLE )
226     {
227         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
228         dv_step = 1;
229         if( _missing_mask )
230             ms_step = _missing_mask->step, mv_step = 1;
231     }
232     else
233     {
234         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
235         ds_step = 1;
236         if( _missing_mask )
237             mv_step = _missing_mask->step, ms_step = 1;
238     }
239     tflag = _tflag;
240 
241     sample_count = sample_all;
242     var_count = var_all;
243 
244     if( _sample_idx )
245     {
246         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
247         sidx = sample_indices->data.i;
248         sample_count = sample_indices->rows + sample_indices->cols - 1;
249     }
250 
251     if( _var_idx )
252     {
253         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
254         vidx = var_idx->data.i;
255         var_count = var_idx->rows + var_idx->cols - 1;
256     }
257 
258     is_buf_16u = false;
259     if ( sample_count < 65536 )
260         is_buf_16u = true;
261 
262     if( !CV_IS_MAT(_responses) ||
263         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
264          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
265         (_responses->rows != 1 && _responses->cols != 1) ||
266         _responses->rows + _responses->cols - 1 != sample_all )
267         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
268                   "floating-point vector containing as many elements as "
269                   "the total number of samples in the training data matrix" );
270 
271     r_type = CV_VAR_CATEGORICAL;
272     if( _var_type )
273         CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
274 
275     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
276 
277     cat_var_count = 0;
278     ord_var_count = -1;
279 
280     is_classifier = r_type == CV_VAR_CATEGORICAL;
281 
282     // step 0. calc the number of categorical vars
283     for( vi = 0; vi < var_count; vi++ )
284     {
285         char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
286         var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
287     }
288 
289     ord_var_count = ~ord_var_count;
290     cv_n = params.cv_folds;
291     // set the two last elements of var_type array to be able
292     // to locate responses and cross-validation labels using
293     // the corresponding get_* functions.
294     var_type->data.i[var_count] = cat_var_count;
295     var_type->data.i[var_count+1] = cat_var_count+1;
296 
297     // in case of single ordered predictor we need dummy cv_labels
298     // for safe split_node_data() operation
299     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
300 
301     work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
302                                + (have_labels ? 1 : 0); // for cv_labels
303 
304     shared = _shared;
305     buf_count = shared ? 2 : 1;
306 
307     buf_size = -1; // the member buf_size is obsolete
308 
309     effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
310     effective_buf_width = sample_count;
311     effective_buf_height = work_var_count+1;
312 
313     if (effective_buf_width >= effective_buf_height)
314         effective_buf_height *= buf_count;
315     else
316         effective_buf_width *= buf_count;
317 
318     if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
319     {
320         CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
321     }
322 
323 
324 
325     if ( is_buf_16u )
326     {
327         CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
328         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
329     }
330     else
331     {
332         CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
333         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
334     }
335 
336     size = is_classifier ? (cat_var_count+1) : cat_var_count;
337     size = !size ? 1 : size;
338     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
339     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
340 
341     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
342     size = !size ? 1 : size;
343     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
344 
345     // now calculate the maximum size of split,
346     // create memory storage that will keep nodes and splits of the decision tree
347     // allocate root node and the buffer for the whole training data
348     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
349         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
350     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
351     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
352     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
353     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
354 
355     nv_size = var_count*sizeof(int);
356     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
357 
358     temp_block_size = nv_size;
359 
360     if( cv_n )
361     {
362         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
363             CV_ERROR( CV_StsOutOfRange,
364                 "The many folds in cross-validation for such a small dataset" );
365 
366         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
367         temp_block_size = MAX(temp_block_size, cv_size);
368     }
369 
370     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
371     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
372     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
373     if( cv_size )
374         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
375 
376     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
377 
378     max_c_count = 1;
379 
380     _fdst = 0;
381     _idst = 0;
382     if (ord_var_count)
383         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
384     if (is_buf_16u && (cat_var_count || is_classifier))
385         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
386 
387     // transform the training data to convenient representation
388     for( vi = 0; vi <= var_count; vi++ )
389     {
390         int ci;
391         const uchar* mask = 0;
392         int64 m_step = 0, step;
393         const int* idata = 0;
394         const float* fdata = 0;
395         int num_valid = 0;
396 
397         if( vi < var_count ) // analyze i-th input variable
398         {
399             int vi0 = vidx ? vidx[vi] : vi;
400             ci = get_var_type(vi);
401             step = ds_step; m_step = ms_step;
402             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
403                 idata = _train_data->data.i + vi0*dv_step;
404             else
405                 fdata = _train_data->data.fl + vi0*dv_step;
406             if( _missing_mask )
407                 mask = _missing_mask->data.ptr + vi0*mv_step;
408         }
409         else // analyze _responses
410         {
411             ci = cat_var_count;
412             step = CV_IS_MAT_CONT(_responses->type) ?
413                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
414             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
415                 idata = _responses->data.i;
416             else
417                 fdata = _responses->data.fl;
418         }
419 
420         if( (vi < var_count && ci>=0) ||
421             (vi == var_count && is_classifier) ) // process categorical variable or response
422         {
423             int c_count, prev_label;
424             int* c_map;
425 
426             if (is_buf_16u)
427                 udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
428             else
429                 idst = buf->data.i + (size_t)vi*sample_count;
430 
431             // copy data
432             for( i = 0; i < sample_count; i++ )
433             {
434                 int val = INT_MAX, si = sidx ? sidx[i] : i;
435                 if( !mask || !mask[(size_t)si*m_step] )
436                 {
437                     if( idata )
438                         val = idata[(size_t)si*step];
439                     else
440                     {
441                         float t = fdata[(size_t)si*step];
442                         val = cvRound(t);
443                         if( fabs(t - val) > FLT_EPSILON )
444                         {
445                             sprintf( err, "%d-th value of %d-th (categorical) "
446                                 "variable is not an integer", i, vi );
447                             CV_ERROR( CV_StsBadArg, err );
448                         }
449                     }
450 
451                     if( val == INT_MAX )
452                     {
453                         sprintf( err, "%d-th value of %d-th (categorical) "
454                             "variable is too large", i, vi );
455                         CV_ERROR( CV_StsBadArg, err );
456                     }
457                     num_valid++;
458                 }
459                 if (is_buf_16u)
460                 {
461                     _idst[i] = val;
462                     pair16u32s_ptr[i].u = udst + i;
463                     pair16u32s_ptr[i].i = _idst + i;
464                 }
465                 else
466                 {
467                     idst[i] = val;
468                     int_ptr[i] = idst + i;
469                 }
470             }
471 
472             c_count = num_valid > 0;
473             if (is_buf_16u)
474             {
475                 std::sort(pair16u32s_ptr, pair16u32s_ptr + sample_count, LessThanPairs());
476                 // count the categories
477                 for( i = 1; i < num_valid; i++ )
478                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
479                         c_count ++ ;
480             }
481             else
482             {
483                 std::sort(int_ptr, int_ptr + sample_count, LessThanPtr<int>());
484                 // count the categories
485                 for( i = 1; i < num_valid; i++ )
486                     c_count += *int_ptr[i] != *int_ptr[i-1];
487             }
488 
489             if( vi > 0 )
490                 max_c_count = MAX( max_c_count, c_count );
491             cat_count->data.i[ci] = c_count;
492             cat_ofs->data.i[ci] = total_c_count;
493 
494             // resize cat_map, if need
495             if( cat_map->cols < total_c_count + c_count )
496             {
497                 tmp_map = cat_map;
498                 CV_CALL( cat_map = cvCreateMat( 1,
499                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
500                 for( i = 0; i < total_c_count; i++ )
501                     cat_map->data.i[i] = tmp_map->data.i[i];
502                 cvReleaseMat( &tmp_map );
503             }
504 
505             c_map = cat_map->data.i + total_c_count;
506             total_c_count += c_count;
507 
508             c_count = -1;
509             if (is_buf_16u)
510             {
511                 // compact the class indices and build the map
512                 prev_label = ~*pair16u32s_ptr[0].i;
513                 for( i = 0; i < num_valid; i++ )
514                 {
515                     int cur_label = *pair16u32s_ptr[i].i;
516                     if( cur_label != prev_label )
517                         c_map[++c_count] = prev_label = cur_label;
518                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
519                 }
520                 // replace labels for missing values with -1
521                 for( ; i < sample_count; i++ )
522                     *pair16u32s_ptr[i].u = 65535;
523             }
524             else
525             {
526                 // compact the class indices and build the map
527                 prev_label = ~*int_ptr[0];
528                 for( i = 0; i < num_valid; i++ )
529                 {
530                     int cur_label = *int_ptr[i];
531                     if( cur_label != prev_label )
532                         c_map[++c_count] = prev_label = cur_label;
533                     *int_ptr[i] = c_count;
534                 }
535                 // replace labels for missing values with -1
536                 for( ; i < sample_count; i++ )
537                     *int_ptr[i] = -1;
538             }
539         }
540         else if( ci < 0 ) // process ordered variable
541         {
542             if (is_buf_16u)
543                 udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
544             else
545                 idst = buf->data.i + (size_t)vi*sample_count;
546 
547             for( i = 0; i < sample_count; i++ )
548             {
549                 float val = ord_nan;
550                 int si = sidx ? sidx[i] : i;
551                 if( !mask || !mask[(size_t)si*m_step] )
552                 {
553                     if( idata )
554                         val = (float)idata[(size_t)si*step];
555                     else
556                         val = fdata[(size_t)si*step];
557 
558                     if( fabs(val) >= ord_nan )
559                     {
560                         sprintf( err, "%d-th value of %d-th (ordered) "
561                             "variable (=%g) is too large", i, vi, val );
562                         CV_ERROR( CV_StsBadArg, err );
563                     }
564                     num_valid++;
565                 }
566 
567                 if (is_buf_16u)
568                     udst[i] = (unsigned short)i; // TODO: memory corruption may be here
569                 else
570                     idst[i] = i;
571                 _fdst[i] = val;
572 
573             }
574             if (is_buf_16u)
575                 std::sort(udst, udst + sample_count, LessThanIdx<float, unsigned short>(_fdst));
576             else
577                 std::sort(idst, idst + sample_count, LessThanIdx<float, int>(_fdst));
578         }
579 
580         if( vi < var_count )
581             data_root->set_num_valid(vi, num_valid);
582     }
583 
584     // set sample labels
585     if (is_buf_16u)
586         udst = (unsigned short*)(buf->data.s + (size_t)work_var_count*sample_count);
587     else
588         idst = buf->data.i + (size_t)work_var_count*sample_count;
589 
590     for (i = 0; i < sample_count; i++)
591     {
592         if (udst)
593             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
594         else
595             idst[i] = sidx ? sidx[i] : i;
596     }
597 
598     if( cv_n )
599     {
600         unsigned short* usdst = 0;
601         int* idst2 = 0;
602 
603         if (is_buf_16u)
604         {
605             usdst = (unsigned short*)(buf->data.s + (size_t)(get_work_var_count()-1)*sample_count);
606             for( i = vi = 0; i < sample_count; i++ )
607             {
608                 usdst[i] = (unsigned short)vi++;
609                 vi &= vi < cv_n ? -1 : 0;
610             }
611 
612             for( i = 0; i < sample_count; i++ )
613             {
614                 int a = (*rng)(sample_count);
615                 int b = (*rng)(sample_count);
616                 unsigned short unsh = (unsigned short)vi;
617                 CV_SWAP( usdst[a], usdst[b], unsh );
618             }
619         }
620         else
621         {
622             idst2 = buf->data.i + (size_t)(get_work_var_count()-1)*sample_count;
623             for( i = vi = 0; i < sample_count; i++ )
624             {
625                 idst2[i] = vi++;
626                 vi &= vi < cv_n ? -1 : 0;
627             }
628 
629             for( i = 0; i < sample_count; i++ )
630             {
631                 int a = (*rng)(sample_count);
632                 int b = (*rng)(sample_count);
633                 CV_SWAP( idst2[a], idst2[b], vi );
634             }
635         }
636     }
637 
638     if ( cat_map )
639         cat_map->cols = MAX( total_c_count, 1 );
640 
641     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
642         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
643     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
644 
645     have_priors = is_classifier && params.priors;
646     if( is_classifier )
647     {
648         int m = get_num_classes();
649         double sum = 0;
650         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
651         for( i = 0; i < m; i++ )
652         {
653             double val = have_priors ? params.priors[i] : 1.;
654             if( val <= 0 )
655                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
656             priors->data.db[i] = val;
657             sum += val;
658         }
659 
660         // normalize weights
661         if( have_priors )
662             cvScale( priors, priors, 1./sum );
663 
664         CV_CALL( priors_mult = cvCloneMat( priors ));
665         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
666     }
667 
668 
669     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
670     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
671 
672     __END__;
673 
674     if( data )
675         delete data;
676 
677     if (_fdst)
678         cvFree( &_fdst );
679     if (_idst)
680         cvFree( &_idst );
681     cvFree( &int_ptr );
682     cvFree( &pair16u32s_ptr);
683     cvReleaseMat( &var_type0 );
684     cvReleaseMat( &sample_indices );
685     cvReleaseMat( &tmp_map );
686 }
687 
do_responses_copy()688 void CvDTreeTrainData::do_responses_copy()
689 {
690     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
691     cvCopy( responses, responses_copy);
692     responses = responses_copy;
693 }
694 
subsample_data(const CvMat * _subsample_idx)695 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
696 {
697     CvDTreeNode* root = 0;
698     CvMat* isubsample_idx = 0;
699     CvMat* subsample_co = 0;
700 
701     bool isMakeRootCopy = true;
702 
703     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
704 
705     __BEGIN__;
706 
707     if( !data_root )
708         CV_ERROR( CV_StsError, "No training data has been set" );
709 
710     if( _subsample_idx )
711     {
712         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
713 
714         if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
715         {
716             const int* sidx = isubsample_idx->data.i;
717             for( int i = 0; i < sample_count; i++ )
718             {
719                 if( sidx[i] != i )
720                 {
721                     isMakeRootCopy = false;
722                     break;
723                 }
724             }
725         }
726         else
727             isMakeRootCopy = false;
728     }
729 
730     if( isMakeRootCopy )
731     {
732         // make a copy of the root node
733         CvDTreeNode temp;
734         int i;
735         root = new_node( 0, 1, 0, 0 );
736         temp = *root;
737         *root = *data_root;
738         root->num_valid = temp.num_valid;
739         if( root->num_valid )
740         {
741             for( i = 0; i < var_count; i++ )
742                 root->num_valid[i] = data_root->num_valid[i];
743         }
744         root->cv_Tn = temp.cv_Tn;
745         root->cv_node_risk = temp.cv_node_risk;
746         root->cv_node_error = temp.cv_node_error;
747     }
748     else
749     {
750         int* sidx = isubsample_idx->data.i;
751         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
752         int* co, cur_ofs = 0;
753         int vi, i;
754         int workVarCount = get_work_var_count();
755         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
756 
757         root = new_node( 0, count, 1, 0 );
758 
759         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
760         cvZero( subsample_co );
761         co = subsample_co->data.i;
762         for( i = 0; i < count; i++ )
763             co[sidx[i]*2]++;
764         for( i = 0; i < sample_count; i++ )
765         {
766             if( co[i*2] )
767             {
768                 co[i*2+1] = cur_ofs;
769                 cur_ofs += co[i*2];
770             }
771             else
772                 co[i*2+1] = -1;
773         }
774 
775         cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
776         for( vi = 0; vi < workVarCount; vi++ )
777         {
778             int ci = get_var_type(vi);
779 
780             if( ci >= 0 || vi >= var_count )
781             {
782                 int num_valid = 0;
783                 const int* src = CvDTreeTrainData::get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );
784 
785                 if (is_buf_16u)
786                 {
787                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
788                         (size_t)vi*sample_count + root->offset);
789                     for( i = 0; i < count; i++ )
790                     {
791                         int val = src[sidx[i]];
792                         udst[i] = (unsigned short)val;
793                         num_valid += val >= 0;
794                     }
795                 }
796                 else
797                 {
798                     int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
799                         (size_t)vi*sample_count + root->offset;
800                     for( i = 0; i < count; i++ )
801                     {
802                         int val = src[sidx[i]];
803                         idst[i] = val;
804                         num_valid += val >= 0;
805                     }
806                 }
807 
808                 if( vi < var_count )
809                     root->set_num_valid(vi, num_valid);
810             }
811             else
812             {
813                 int *src_idx_buf = (int*)(uchar*)inn_buf;
814                 float *src_val_buf = (float*)(src_idx_buf + sample_count);
815                 int* sample_indices_buf = (int*)(src_val_buf + sample_count);
816                 const int* src_idx = 0;
817                 const float* src_val = 0;
818                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
819                 int j = 0, idx, count_i;
820                 int num_valid = data_root->get_num_valid(vi);
821 
822                 if (is_buf_16u)
823                 {
824                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
825                         (size_t)vi*sample_count + data_root->offset);
826                     for( i = 0; i < num_valid; i++ )
827                     {
828                         idx = src_idx[i];
829                         count_i = co[idx*2];
830                         if( count_i )
831                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
832                                 udst_idx[j] = (unsigned short)cur_ofs;
833                     }
834 
835                     root->set_num_valid(vi, j);
836 
837                     for( ; i < sample_count; i++ )
838                     {
839                         idx = src_idx[i];
840                         count_i = co[idx*2];
841                         if( count_i )
842                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
843                                 udst_idx[j] = (unsigned short)cur_ofs;
844                     }
845                 }
846                 else
847                 {
848                     int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
849                         (size_t)vi*sample_count + root->offset;
850                     for( i = 0; i < num_valid; i++ )
851                     {
852                         idx = src_idx[i];
853                         count_i = co[idx*2];
854                         if( count_i )
855                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
856                                 idst_idx[j] = cur_ofs;
857                     }
858 
859                     root->set_num_valid(vi, j);
860 
861                     for( ; i < sample_count; i++ )
862                     {
863                         idx = src_idx[i];
864                         count_i = co[idx*2];
865                         if( count_i )
866                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
867                                 idst_idx[j] = cur_ofs;
868                     }
869                 }
870             }
871         }
872         // sample indices subsampling
873         const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
874         if (is_buf_16u)
875         {
876             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
877                 (size_t)workVarCount*sample_count + root->offset);
878             for (i = 0; i < count; i++)
879                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
880         }
881         else
882         {
883             int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
884                 (size_t)workVarCount*sample_count + root->offset;
885             for (i = 0; i < count; i++)
886                 sample_idx_dst[i] = sample_idx_src[sidx[i]];
887         }
888     }
889 
890     __END__;
891 
892     cvReleaseMat( &isubsample_idx );
893     cvReleaseMat( &subsample_co );
894 
895     return root;
896 }
897 
898 
get_vectors(const CvMat * _subsample_idx,float * values,uchar * missing,float * _responses,bool get_class_idx)899 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
900                                     float* values, uchar* missing,
901                                     float* _responses, bool get_class_idx )
902 {
903     CvMat* subsample_idx = 0;
904     CvMat* subsample_co = 0;
905 
906     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
907 
908     __BEGIN__;
909 
910     int i, vi, total = sample_count, count = total, cur_ofs = 0;
911     int* sidx = 0;
912     int* co = 0;
913 
914     cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
915     if( _subsample_idx )
916     {
917         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
918         sidx = subsample_idx->data.i;
919         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
920         co = subsample_co->data.i;
921         cvZero( subsample_co );
922         count = subsample_idx->cols + subsample_idx->rows - 1;
923         for( i = 0; i < count; i++ )
924             co[sidx[i]*2]++;
925         for( i = 0; i < total; i++ )
926         {
927             int count_i = co[i*2];
928             if( count_i )
929             {
930                 co[i*2+1] = cur_ofs*var_count;
931                 cur_ofs += count_i;
932             }
933         }
934     }
935 
936     if( missing )
937         memset( missing, 1, count*var_count );
938 
939     for( vi = 0; vi < var_count; vi++ )
940     {
941         int ci = get_var_type(vi);
942         if( ci >= 0 ) // categorical
943         {
944             float* dst = values + vi;
945             uchar* m = missing ? missing + vi : 0;
946             const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);
947 
948             for( i = 0; i < count; i++, dst += var_count )
949             {
950                 int idx = sidx ? sidx[i] : i;
951                 int val = src[idx];
952                 *dst = (float)val;
953                 if( m )
954                 {
955                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
956                     m += var_count;
957                 }
958             }
959         }
960         else // ordered
961         {
962             float* dst = values + vi;
963             uchar* m = missing ? missing + vi : 0;
964             int count1 = data_root->get_num_valid(vi);
965             float *src_val_buf = (float*)(uchar*)inn_buf;
966             int* src_idx_buf = (int*)(src_val_buf + sample_count);
967             int* sample_indices_buf = src_idx_buf + sample_count;
968             const float *src_val = 0;
969             const int* src_idx = 0;
970             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);
971 
972             for( i = 0; i < count1; i++ )
973             {
974                 int idx = src_idx[i];
975                 int count_i = 1;
976                 if( co )
977                 {
978                     count_i = co[idx*2];
979                     cur_ofs = co[idx*2+1];
980                 }
981                 else
982                     cur_ofs = idx*var_count;
983                 if( count_i )
984                 {
985                     float val = src_val[i];
986                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
987                     {
988                         dst[cur_ofs] = val;
989                         if( m )
990                             m[cur_ofs] = 0;
991                     }
992                 }
993             }
994         }
995     }
996 
997     // copy responses
998     if( _responses )
999     {
1000         if( is_classifier )
1001         {
1002             const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
1003             for( i = 0; i < count; i++ )
1004             {
1005                 int idx = sidx ? sidx[i] : i;
1006                 int val = get_class_idx ? src[idx] :
1007                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
1008                 _responses[i] = (float)val;
1009             }
1010         }
1011         else
1012         {
1013             float* val_buf = (float*)(uchar*)inn_buf;
1014             int* sample_idx_buf = (int*)(val_buf + sample_count);
1015             const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
1016             for( i = 0; i < count; i++ )
1017             {
1018                 int idx = sidx ? sidx[i] : i;
1019                 _responses[i] = _values[idx];
1020             }
1021         }
1022     }
1023 
1024     __END__;
1025 
1026     cvReleaseMat( &subsample_idx );
1027     cvReleaseMat( &subsample_co );
1028 }
1029 
1030 
new_node(CvDTreeNode * parent,int count,int storage_idx,int offset)1031 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
1032                                          int storage_idx, int offset )
1033 {
1034     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
1035 
1036     node->sample_count = count;
1037     node->depth = parent ? parent->depth + 1 : 0;
1038     node->parent = parent;
1039     node->left = node->right = 0;
1040     node->split = 0;
1041     node->value = 0;
1042     node->class_idx = 0;
1043     node->maxlr = 0.;
1044 
1045     node->buf_idx = storage_idx;
1046     node->offset = offset;
1047     if( nv_heap )
1048         node->num_valid = (int*)cvSetNew( nv_heap );
1049     else
1050         node->num_valid = 0;
1051     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
1052     node->complexity = 0;
1053 
1054     if( params.cv_folds > 0 && cv_heap )
1055     {
1056         int cv_n = params.cv_folds;
1057         node->Tn = INT_MAX;
1058         node->cv_Tn = (int*)cvSetNew( cv_heap );
1059         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
1060         node->cv_node_error = node->cv_node_risk + cv_n;
1061     }
1062     else
1063     {
1064         node->Tn = 0;
1065         node->cv_Tn = 0;
1066         node->cv_node_risk = 0;
1067         node->cv_node_error = 0;
1068     }
1069 
1070     return node;
1071 }
1072 
1073 
new_split_ord(int vi,float cmp_val,int split_point,int inversed,float quality)1074 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
1075                 int split_point, int inversed, float quality )
1076 {
1077     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1078     split->var_idx = vi;
1079     split->condensed_idx = INT_MIN;
1080     split->ord.c = cmp_val;
1081     split->ord.split_point = split_point;
1082     split->inversed = inversed;
1083     split->quality = quality;
1084     split->next = 0;
1085 
1086     return split;
1087 }
1088 
1089 
new_split_cat(int vi,float quality)1090 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
1091 {
1092     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
1093     int i, n = (max_c_count + 31)/32;
1094 
1095     split->var_idx = vi;
1096     split->condensed_idx = INT_MIN;
1097     split->inversed = 0;
1098     split->quality = quality;
1099     for( i = 0; i < n; i++ )
1100         split->subset[i] = 0;
1101     split->next = 0;
1102 
1103     return split;
1104 }
1105 
1106 
free_node(CvDTreeNode * node)1107 void CvDTreeTrainData::free_node( CvDTreeNode* node )
1108 {
1109     CvDTreeSplit* split = node->split;
1110     free_node_data( node );
1111     while( split )
1112     {
1113         CvDTreeSplit* next = split->next;
1114         cvSetRemoveByPtr( split_heap, split );
1115         split = next;
1116     }
1117     node->split = 0;
1118     cvSetRemoveByPtr( node_heap, node );
1119 }
1120 
1121 
free_node_data(CvDTreeNode * node)1122 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
1123 {
1124     if( node->num_valid )
1125     {
1126         cvSetRemoveByPtr( nv_heap, node->num_valid );
1127         node->num_valid = 0;
1128     }
1129     // do not free cv_* fields, as all the cross-validation related data is released at once.
1130 }
1131 
1132 
free_train_data()1133 void CvDTreeTrainData::free_train_data()
1134 {
1135     cvReleaseMat( &counts );
1136     cvReleaseMat( &buf );
1137     cvReleaseMat( &direction );
1138     cvReleaseMat( &split_buf );
1139     cvReleaseMemStorage( &temp_storage );
1140     cvReleaseMat( &responses_copy );
1141     cv_heap = nv_heap = 0;
1142 }
1143 
1144 
clear()1145 void CvDTreeTrainData::clear()
1146 {
1147     free_train_data();
1148 
1149     cvReleaseMemStorage( &tree_storage );
1150 
1151     cvReleaseMat( &var_idx );
1152     cvReleaseMat( &var_type );
1153     cvReleaseMat( &cat_count );
1154     cvReleaseMat( &cat_ofs );
1155     cvReleaseMat( &cat_map );
1156     cvReleaseMat( &priors );
1157     cvReleaseMat( &priors_mult );
1158 
1159     node_heap = split_heap = 0;
1160 
1161     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
1162     have_labels = have_priors = is_classifier = false;
1163 
1164     buf_count = buf_size = 0;
1165     shared = false;
1166 
1167     data_root = 0;
1168 
1169     rng = &cv::theRNG();
1170 }
1171 
1172 
get_num_classes() const1173 int CvDTreeTrainData::get_num_classes() const
1174 {
1175     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
1176 }
1177 
1178 
get_var_type(int vi) const1179 int CvDTreeTrainData::get_var_type(int vi) const
1180 {
1181     return var_type->data.i[vi];
1182 }
1183 
get_ord_var_data(CvDTreeNode * n,int vi,float * ord_values_buf,int * sorted_indices_buf,const float ** ord_values,const int ** sorted_indices,int * sample_indices_buf)1184 void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
1185                                          const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
1186 {
1187     int vidx = var_idx ? var_idx->data.i[vi] : vi;
1188     int node_sample_count = n->sample_count;
1189     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
1190 
1191     const int* sample_indices = get_sample_indices(n, sample_indices_buf);
1192 
1193     if( !is_buf_16u )
1194         *sorted_indices = buf->data.i + n->buf_idx*get_length_subbuf() +
1195         (size_t)vi*sample_count + n->offset;
1196     else {
1197         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
1198             (size_t)vi*sample_count + n->offset );
1199         for( int i = 0; i < node_sample_count; i++ )
1200             sorted_indices_buf[i] = short_indices[i];
1201         *sorted_indices = sorted_indices_buf;
1202     }
1203 
1204     if( tflag == CV_ROW_SAMPLE )
1205     {
1206         for( int i = 0; i < node_sample_count &&
1207             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
1208         {
1209             int idx = (*sorted_indices)[i];
1210             idx = sample_indices[idx];
1211             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
1212         }
1213     }
1214     else
1215         for( int i = 0; i < node_sample_count &&
1216             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
1217         {
1218             int idx = (*sorted_indices)[i];
1219             idx = sample_indices[idx];
1220             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
1221         }
1222 
1223     *ord_values = ord_values_buf;
1224 }
1225 
1226 
get_class_labels(CvDTreeNode * n,int * labels_buf)1227 const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
1228 {
1229     if (is_classifier)
1230         return get_cat_var_data( n, var_count, labels_buf);
1231     return 0;
1232 }
1233 
get_sample_indices(CvDTreeNode * n,int * indices_buf)1234 const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
1235 {
1236     return get_cat_var_data( n, get_work_var_count(), indices_buf );
1237 }
1238 
get_ord_responses(CvDTreeNode * n,float * values_buf,int * sample_indices_buf)1239 const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int*sample_indices_buf )
1240 {
1241     int _sample_count = n->sample_count;
1242     int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
1243     const int* indices = get_sample_indices(n, sample_indices_buf);
1244 
1245     for( int i = 0; i < _sample_count &&
1246         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
1247     {
1248         int idx = indices[i];
1249         values_buf[i] = *(responses->data.fl + idx * r_step);
1250     }
1251 
1252     return values_buf;
1253 }
1254 
1255 
get_cv_labels(CvDTreeNode * n,int * labels_buf)1256 const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
1257 {
1258     if (have_labels)
1259         return get_cat_var_data( n, get_work_var_count()- 1, labels_buf);
1260     return 0;
1261 }
1262 
1263 
get_cat_var_data(CvDTreeNode * n,int vi,int * cat_values_buf)1264 const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf)
1265 {
1266     const int* cat_values = 0;
1267     if( !is_buf_16u )
1268         cat_values = buf->data.i + n->buf_idx*get_length_subbuf() +
1269             (size_t)vi*sample_count + n->offset;
1270     else {
1271         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
1272             (size_t)vi*sample_count + n->offset);
1273         for( int i = 0; i < n->sample_count; i++ )
1274             cat_values_buf[i] = short_values[i];
1275         cat_values = cat_values_buf;
1276     }
1277     return cat_values;
1278 }
1279 
1280 
get_child_buf_idx(CvDTreeNode * n)1281 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
1282 {
1283     int idx = n->buf_idx + 1;
1284     if( idx >= buf_count )
1285         idx = shared ? 1 : 0;
1286     return idx;
1287 }
1288 
1289 
write_params(CvFileStorage * fs) const1290 void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
1291 {
1292     CV_FUNCNAME( "CvDTreeTrainData::write_params" );
1293 
1294     __BEGIN__;
1295 
1296     int vi, vcount = var_count;
1297 
1298     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
1299     cvWriteInt( fs, "var_all", var_all );
1300     cvWriteInt( fs, "var_count", var_count );
1301     cvWriteInt( fs, "ord_var_count", ord_var_count );
1302     cvWriteInt( fs, "cat_var_count", cat_var_count );
1303 
1304     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
1305     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
1306 
1307     if( is_classifier )
1308     {
1309         cvWriteInt( fs, "max_categories", params.max_categories );
1310     }
1311     else
1312     {
1313         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
1314     }
1315 
1316     cvWriteInt( fs, "max_depth", params.max_depth );
1317     cvWriteInt( fs, "min_sample_count", params.min_sample_count );
1318     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
1319 
1320     if( params.cv_folds > 1 )
1321     {
1322         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
1323         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
1324     }
1325 
1326     if( priors )
1327         cvWrite( fs, "priors", priors );
1328 
1329     cvEndWriteStruct( fs );
1330 
1331     if( var_idx )
1332         cvWrite( fs, "var_idx", var_idx );
1333 
1334     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
1335 
1336     for( vi = 0; vi < vcount; vi++ )
1337         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
1338 
1339     cvEndWriteStruct( fs );
1340 
1341     if( cat_count && (cat_var_count > 0 || is_classifier) )
1342     {
1343         CV_ASSERT( cat_count != 0 );
1344         cvWrite( fs, "cat_count", cat_count );
1345         cvWrite( fs, "cat_map", cat_map );
1346     }
1347 
1348     __END__;
1349 }
1350 
1351 
read_params(CvFileStorage * fs,CvFileNode * node)1352 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
1353 {
1354     CV_FUNCNAME( "CvDTreeTrainData::read_params" );
1355 
1356     __BEGIN__;
1357 
1358     CvFileNode *tparams_node, *vartype_node;
1359     CvSeqReader reader;
1360     int vi, max_split_size, tree_block_size;
1361 
1362     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
1363     var_all = cvReadIntByName( fs, node, "var_all" );
1364     var_count = cvReadIntByName( fs, node, "var_count", var_all );
1365     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
1366     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
1367 
1368     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
1369 
1370     if( tparams_node ) // training parameters are not necessary
1371     {
1372         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
1373 
1374         if( is_classifier )
1375         {
1376             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
1377         }
1378         else
1379         {
1380             params.regression_accuracy =
1381                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
1382         }
1383 
1384         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
1385         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
1386         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
1387 
1388         if( params.cv_folds > 1 )
1389         {
1390             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
1391             params.truncate_pruned_tree =
1392                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
1393         }
1394 
1395         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
1396         if( priors )
1397         {
1398             if( !CV_IS_MAT(priors) )
1399                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
1400             priors_mult = cvCloneMat( priors );
1401         }
1402     }
1403 
1404     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
1405     if( var_idx )
1406     {
1407         if( !CV_IS_MAT(var_idx) ||
1408             (var_idx->cols != 1 && var_idx->rows != 1) ||
1409             var_idx->cols + var_idx->rows - 1 != var_count ||
1410             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
1411             CV_ERROR( CV_StsParseError,
1412                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
1413 
1414         for( vi = 0; vi < var_count; vi++ )
1415             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
1416                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
1417     }
1418 
1419     ////// read var type
1420     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
1421 
1422     cat_var_count = 0;
1423     ord_var_count = -1;
1424     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
1425 
1426     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
1427         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
1428     else
1429     {
1430         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
1431             vartype_node->data.seq->total != var_count )
1432             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1433 
1434         cvStartReadSeq( vartype_node->data.seq, &reader );
1435 
1436         for( vi = 0; vi < var_count; vi++ )
1437         {
1438             CvFileNode* n = (CvFileNode*)reader.ptr;
1439             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
1440                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
1441             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
1442             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
1443         }
1444     }
1445     var_type->data.i[var_count] = cat_var_count;
1446 
1447     ord_var_count = ~ord_var_count;
1448     //////
1449 
1450     if( cat_var_count > 0 || is_classifier )
1451     {
1452         int ccount, total_c_count = 0;
1453         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
1454         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
1455 
1456         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
1457             (cat_count->cols != 1 && cat_count->rows != 1) ||
1458             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
1459             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
1460             (cat_map->cols != 1 && cat_map->rows != 1) ||
1461             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
1462             CV_ERROR( CV_StsParseError,
1463             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
1464 
1465         ccount = cat_var_count + is_classifier;
1466 
1467         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
1468         cat_ofs->data.i[0] = 0;
1469         max_c_count = 1;
1470 
1471         for( vi = 0; vi < ccount; vi++ )
1472         {
1473             int val = cat_count->data.i[vi];
1474             if( val <= 0 )
1475                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
1476             max_c_count = MAX( max_c_count, val );
1477             cat_ofs->data.i[vi+1] = total_c_count += val;
1478         }
1479 
1480         if( cat_map->cols + cat_map->rows - 1 != total_c_count )
1481             CV_ERROR( CV_StsBadSize,
1482             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
1483     }
1484 
1485     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
1486         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
1487 
1488     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
1489     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
1490     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
1491     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
1492             sizeof(CvDTreeNode), tree_storage ));
1493     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
1494             max_split_size, tree_storage ));
1495 
1496     __END__;
1497 }
1498 
1499 /////////////////////// Decision Tree /////////////////////////
CvDTreeParams()1500 CvDTreeParams::CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
1501     cv_folds(10), use_surrogates(true), use_1se_rule(true),
1502     truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
1503 {}
1504 
CvDTreeParams(int _max_depth,int _min_sample_count,float _regression_accuracy,bool _use_surrogates,int _max_categories,int _cv_folds,bool _use_1se_rule,bool _truncate_pruned_tree,const float * _priors)1505 CvDTreeParams::CvDTreeParams( int _max_depth, int _min_sample_count,
1506                               float _regression_accuracy, bool _use_surrogates,
1507                               int _max_categories, int _cv_folds,
1508                               bool _use_1se_rule, bool _truncate_pruned_tree,
1509                               const float* _priors ) :
1510     max_categories(_max_categories), max_depth(_max_depth),
1511     min_sample_count(_min_sample_count), cv_folds (_cv_folds),
1512     use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
1513     truncate_pruned_tree(_truncate_pruned_tree),
1514     regression_accuracy(_regression_accuracy),
1515     priors(_priors)
1516 {}
1517 
CvDTree()1518 CvDTree::CvDTree()
1519 {
1520     data = 0;
1521     var_importance = 0;
1522     default_model_name = "my_tree";
1523 
1524     clear();
1525 }
1526 
1527 
clear()1528 void CvDTree::clear()
1529 {
1530     cvReleaseMat( &var_importance );
1531     if( data )
1532     {
1533         if( !data->shared )
1534             delete data;
1535         else
1536             free_tree();
1537         data = 0;
1538     }
1539     root = 0;
1540     pruned_tree_idx = -1;
1541 }
1542 
1543 
~CvDTree()1544 CvDTree::~CvDTree()
1545 {
1546     clear();
1547 }
1548 
1549 
get_root() const1550 const CvDTreeNode* CvDTree::get_root() const
1551 {
1552     return root;
1553 }
1554 
1555 
get_pruned_tree_idx() const1556 int CvDTree::get_pruned_tree_idx() const
1557 {
1558     return pruned_tree_idx;
1559 }
1560 
1561 
get_data()1562 CvDTreeTrainData* CvDTree::get_data()
1563 {
1564     return data;
1565 }
1566 
1567 
train(const CvMat * _train_data,int _tflag,const CvMat * _responses,const CvMat * _var_idx,const CvMat * _sample_idx,const CvMat * _var_type,const CvMat * _missing_mask,CvDTreeParams _params)1568 bool CvDTree::train( const CvMat* _train_data, int _tflag,
1569                      const CvMat* _responses, const CvMat* _var_idx,
1570                      const CvMat* _sample_idx, const CvMat* _var_type,
1571                      const CvMat* _missing_mask, CvDTreeParams _params )
1572 {
1573     bool result = false;
1574 
1575     CV_FUNCNAME( "CvDTree::train" );
1576 
1577     __BEGIN__;
1578 
1579     clear();
1580     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
1581                                  _var_idx, _sample_idx, _var_type,
1582                                  _missing_mask, _params, false );
1583     CV_CALL( result = do_train(0) );
1584 
1585     __END__;
1586 
1587     return result;
1588 }
1589 
train(const Mat & _train_data,int _tflag,const Mat & _responses,const Mat & _var_idx,const Mat & _sample_idx,const Mat & _var_type,const Mat & _missing_mask,CvDTreeParams _params)1590 bool CvDTree::train( const Mat& _train_data, int _tflag,
1591                     const Mat& _responses, const Mat& _var_idx,
1592                     const Mat& _sample_idx, const Mat& _var_type,
1593                     const Mat& _missing_mask, CvDTreeParams _params )
1594 {
1595     train_data_hdr = _train_data;
1596     train_data_mat = _train_data;
1597     responses_hdr = _responses;
1598     responses_mat = _responses;
1599 
1600     CvMat vidx=_var_idx, sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask;
1601 
1602     return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
1603                  vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
1604 }
1605 
1606 
train(CvMLData * _data,CvDTreeParams _params)1607 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
1608 {
1609    bool result = false;
1610 
1611     CV_FUNCNAME( "CvDTree::train" );
1612 
1613     __BEGIN__;
1614 
1615     const CvMat* values = _data->get_values();
1616     const CvMat* response = _data->get_responses();
1617     const CvMat* missing = _data->get_missing();
1618     const CvMat* var_types = _data->get_var_types();
1619     const CvMat* train_sidx = _data->get_train_sample_idx();
1620     const CvMat* var_idx = _data->get_var_idx();
1621 
1622     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
1623         train_sidx, var_types, missing, _params ) );
1624 
1625     __END__;
1626 
1627     return result;
1628 }
1629 
train(CvDTreeTrainData * _data,const CvMat * _subsample_idx)1630 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
1631 {
1632     bool result = false;
1633 
1634     CV_FUNCNAME( "CvDTree::train" );
1635 
1636     __BEGIN__;
1637 
1638     clear();
1639     data = _data;
1640     data->shared = true;
1641     CV_CALL( result = do_train(_subsample_idx));
1642 
1643     __END__;
1644 
1645     return result;
1646 }
1647 
1648 
do_train(const CvMat * _subsample_idx)1649 bool CvDTree::do_train( const CvMat* _subsample_idx )
1650 {
1651     bool result = false;
1652 
1653     CV_FUNCNAME( "CvDTree::do_train" );
1654 
1655     __BEGIN__;
1656 
1657     root = data->subsample_data( _subsample_idx );
1658 
1659     CV_CALL( try_split_node(root));
1660 
1661     if( root->split )
1662     {
1663         CV_Assert( root->left );
1664         CV_Assert( root->right );
1665 
1666         if( data->params.cv_folds > 0 )
1667             CV_CALL( prune_cv() );
1668 
1669         if( !data->shared )
1670             data->free_train_data();
1671 
1672         result = true;
1673     }
1674 
1675     __END__;
1676 
1677     return result;
1678 }
1679 
1680 
try_split_node(CvDTreeNode * node)1681 void CvDTree::try_split_node( CvDTreeNode* node )
1682 {
1683     CvDTreeSplit* best_split = 0;
1684     int i, n = node->sample_count, vi;
1685     bool can_split = true;
1686     double quality_scale;
1687 
1688     calc_node_value( node );
1689 
1690     if( node->sample_count <= data->params.min_sample_count ||
1691         node->depth >= data->params.max_depth )
1692         can_split = false;
1693 
1694     if( can_split && data->is_classifier )
1695     {
1696         // check if we have a "pure" node,
1697         // we assume that cls_count is filled by calc_node_value()
1698         int* cls_count = data->counts->data.i;
1699         int nz = 0, m = data->get_num_classes();
1700         for( i = 0; i < m; i++ )
1701             nz += cls_count[i] != 0;
1702         if( nz == 1 ) // there is only one class
1703             can_split = false;
1704     }
1705     else if( can_split )
1706     {
1707         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
1708             can_split = false;
1709     }
1710 
1711     if( can_split )
1712     {
1713         best_split = find_best_split(node);
1714         // TODO: check the split quality ...
1715         node->split = best_split;
1716     }
1717     if( !can_split || !best_split )
1718     {
1719         data->free_node_data(node);
1720         return;
1721     }
1722 
1723     quality_scale = calc_node_dir( node );
1724     if( data->params.use_surrogates )
1725     {
1726         // find all the surrogate splits
1727         // and sort them by their similarity to the primary one
1728         for( vi = 0; vi < data->var_count; vi++ )
1729         {
1730             CvDTreeSplit* split;
1731             int ci = data->get_var_type(vi);
1732 
1733             if( vi == best_split->var_idx )
1734                 continue;
1735 
1736             if( ci >= 0 )
1737                 split = find_surrogate_split_cat( node, vi );
1738             else
1739                 split = find_surrogate_split_ord( node, vi );
1740 
1741             if( split )
1742             {
1743                 // insert the split
1744                 CvDTreeSplit* prev_split = node->split;
1745                 split->quality = (float)(split->quality*quality_scale);
1746 
1747                 while( prev_split->next &&
1748                        prev_split->next->quality > split->quality )
1749                     prev_split = prev_split->next;
1750                 split->next = prev_split->next;
1751                 prev_split->next = split;
1752             }
1753         }
1754     }
1755     split_node_data( node );
1756     try_split_node( node->left );
1757     try_split_node( node->right );
1758 }
1759 
1760 
1761 // calculate direction (left(-1),right(1),missing(0))
1762 // for each sample using the best split
1763 // the function returns scale coefficients for surrogate split quality factors.
1764 // the scale is applied to normalize surrogate split quality relatively to the
1765 // best (primary) split quality. That is, if a surrogate split is absolutely
1766 // identical to the primary split, its quality will be set to the maximum value =
1767 // quality of the primary split; otherwise, it will be lower.
1768 // besides, the function compute node->maxlr,
1769 // minimum possible quality (w/o considering the above mentioned scale)
1770 // for a surrogate split. Surrogate splits with quality less than node->maxlr
1771 // are not discarded.
calc_node_dir(CvDTreeNode * node)1772 double CvDTree::calc_node_dir( CvDTreeNode* node )
1773 {
1774     char* dir = (char*)data->direction->data.ptr;
1775     int i, n = node->sample_count, vi = node->split->var_idx;
1776     double L, R;
1777 
1778     assert( !node->split->inversed );
1779 
1780     if( data->get_var_type(vi) >= 0 ) // split on categorical var
1781     {
1782         cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
1783         int* labels_buf = (int*)inn_buf;
1784         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
1785         const int* subset = node->split->subset;
1786         if( !data->have_priors )
1787         {
1788             int sum = 0, sum_abs = 0;
1789 
1790             for( i = 0; i < n; i++ )
1791             {
1792                 int idx = labels[i];
1793                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
1794                     CV_DTREE_CAT_DIR(idx,subset) : 0;
1795                 sum += d; sum_abs += d & 1;
1796                 dir[i] = (char)d;
1797             }
1798 
1799             R = (sum_abs + sum) >> 1;
1800             L = (sum_abs - sum) >> 1;
1801         }
1802         else
1803         {
1804             const double* priors = data->priors_mult->data.db;
1805             double sum = 0, sum_abs = 0;
1806             int* responses_buf = labels_buf + n;
1807             const int* responses = data->get_class_labels(node, responses_buf);
1808 
1809             for( i = 0; i < n; i++ )
1810             {
1811                 int idx = labels[i];
1812                 double w = priors[responses[i]];
1813                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
1814                 sum += d*w; sum_abs += (d & 1)*w;
1815                 dir[i] = (char)d;
1816             }
1817 
1818             R = (sum_abs + sum) * 0.5;
1819             L = (sum_abs - sum) * 0.5;
1820         }
1821     }
1822     else // split on ordered var
1823     {
1824         int split_point = node->split->ord.split_point;
1825         int n1 = node->get_num_valid(vi);
1826         cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
1827         float* val_buf = (float*)(uchar*)inn_buf;
1828         int* sorted_buf = (int*)(val_buf + n);
1829         int* sample_idx_buf = sorted_buf + n;
1830         const float* val = 0;
1831         const int* sorted = 0;
1832         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
1833 
1834         assert( 0 <= split_point && split_point < n1-1 );
1835 
1836         if( !data->have_priors )
1837         {
1838             for( i = 0; i <= split_point; i++ )
1839                 dir[sorted[i]] = (char)-1;
1840             for( ; i < n1; i++ )
1841                 dir[sorted[i]] = (char)1;
1842             for( ; i < n; i++ )
1843                 dir[sorted[i]] = (char)0;
1844 
1845             L = split_point-1;
1846             R = n1 - split_point + 1;
1847         }
1848         else
1849         {
1850             const double* priors = data->priors_mult->data.db;
1851             int* responses_buf = sample_idx_buf + n;
1852             const int* responses = data->get_class_labels(node, responses_buf);
1853             L = R = 0;
1854 
1855             for( i = 0; i <= split_point; i++ )
1856             {
1857                 int idx = sorted[i];
1858                 double w = priors[responses[idx]];
1859                 dir[idx] = (char)-1;
1860                 L += w;
1861             }
1862 
1863             for( ; i < n1; i++ )
1864             {
1865                 int idx = sorted[i];
1866                 double w = priors[responses[idx]];
1867                 dir[idx] = (char)1;
1868                 R += w;
1869             }
1870 
1871             for( ; i < n; i++ )
1872                 dir[sorted[i]] = (char)0;
1873         }
1874     }
1875     node->maxlr = MAX( L, R );
1876     return node->split->quality/(L + R);
1877 }
1878 
1879 
1880 namespace cv
1881 {
1882 
operator ()(CvDTreeSplit * obj) const1883 template<> CV_EXPORTS void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const
1884 {
1885     fastFree(obj);
1886 }
1887 
DTreeBestSplitFinder(CvDTree * _tree,CvDTreeNode * _node)1888 DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
1889 {
1890     tree = _tree;
1891     node = _node;
1892     splitSize = tree->get_data()->split_heap->elem_size;
1893 
1894     bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
1895     memset(bestSplit.get(), 0, splitSize);
1896     bestSplit->quality = -1;
1897     bestSplit->condensed_idx = INT_MIN;
1898     split.reset((CvDTreeSplit*)fastMalloc(splitSize));
1899     memset(split.get(), 0, splitSize);
1900     //haveSplit = false;
1901 }
1902 
DTreeBestSplitFinder(const DTreeBestSplitFinder & finder,Split)1903 DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
1904 {
1905     tree = finder.tree;
1906     node = finder.node;
1907     splitSize = tree->get_data()->split_heap->elem_size;
1908 
1909     bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
1910     memcpy(bestSplit.get(), finder.bestSplit.get(), splitSize);
1911     split.reset((CvDTreeSplit*)fastMalloc(splitSize));
1912     memset(split.get(), 0, splitSize);
1913 }
1914 
operator ()(const BlockedRange & range)1915 void DTreeBestSplitFinder::operator()(const BlockedRange& range)
1916 {
1917     int vi, vi1 = range.begin(), vi2 = range.end();
1918     int n = node->sample_count;
1919     CvDTreeTrainData* data = tree->get_data();
1920     AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
1921 
1922     for( vi = vi1; vi < vi2; vi++ )
1923     {
1924         CvDTreeSplit *res;
1925         int ci = data->get_var_type(vi);
1926         if( node->get_num_valid(vi) <= 1 )
1927             continue;
1928 
1929         if( data->is_classifier )
1930         {
1931             if( ci >= 0 )
1932                 res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1933             else
1934                 res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1935         }
1936         else
1937         {
1938             if( ci >= 0 )
1939                 res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1940             else
1941                 res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
1942         }
1943 
1944         if( res && bestSplit->quality < split->quality )
1945                 memcpy( bestSplit.get(), split.get(), splitSize );
1946     }
1947 }
1948 
join(DTreeBestSplitFinder & rhs)1949 void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
1950 {
1951     if( bestSplit->quality < rhs.bestSplit->quality )
1952         memcpy( bestSplit.get(), rhs.bestSplit.get(), splitSize );
1953 }
1954 }
1955 
1956 
find_best_split(CvDTreeNode * node)1957 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
1958 {
1959     DTreeBestSplitFinder finder( this, node );
1960 
1961     cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
1962 
1963     CvDTreeSplit *bestSplit = 0;
1964     if( finder.bestSplit->quality > 0 )
1965     {
1966         bestSplit = data->new_split_cat( 0, -1.0f );
1967         memcpy( bestSplit, finder.bestSplit, finder.splitSize );
1968     }
1969 
1970     return bestSplit;
1971 }
1972 
find_split_ord_class(CvDTreeNode * node,int vi,float init_quality,CvDTreeSplit * _split,uchar * _ext_buf)1973 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
1974                                              float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
1975 {
1976     const float epsilon = FLT_EPSILON*2;
1977     int n = node->sample_count;
1978     int n1 = node->get_num_valid(vi);
1979     int m = data->get_num_classes();
1980 
1981     int base_size = 2*m*sizeof(int);
1982     cv::AutoBuffer<uchar> inn_buf(base_size);
1983     if( !_ext_buf )
1984       inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
1985     uchar* base_buf = (uchar*)inn_buf;
1986     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
1987     float* values_buf = (float*)ext_buf;
1988     int* sorted_indices_buf = (int*)(values_buf + n);
1989     int* sample_indices_buf = sorted_indices_buf + n;
1990     const float* values = 0;
1991     const int* sorted_indices = 0;
1992     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
1993                             &sorted_indices, sample_indices_buf );
1994     int* responses_buf =  sample_indices_buf + n;
1995     const int* responses = data->get_class_labels( node, responses_buf );
1996 
1997     const int* rc0 = data->counts->data.i;
1998     int* lc = (int*)base_buf;
1999     int* rc = lc + m;
2000     int i, best_i = -1;
2001     double lsum2 = 0, rsum2 = 0, best_val = init_quality;
2002     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
2003 
2004     // init arrays of class instance counters on both sides of the split
2005     for( i = 0; i < m; i++ )
2006     {
2007         lc[i] = 0;
2008         rc[i] = rc0[i];
2009     }
2010 
2011     // compensate for missing values
2012     for( i = n1; i < n; i++ )
2013     {
2014         rc[responses[sorted_indices[i]]]--;
2015     }
2016 
2017     if( !priors )
2018     {
2019         int L = 0, R = n1;
2020 
2021         for( i = 0; i < m; i++ )
2022             rsum2 += (double)rc[i]*rc[i];
2023 
2024         for( i = 0; i < n1 - 1; i++ )
2025         {
2026             int idx = responses[sorted_indices[i]];
2027             int lv, rv;
2028             L++; R--;
2029             lv = lc[idx]; rv = rc[idx];
2030             lsum2 += lv*2 + 1;
2031             rsum2 -= rv*2 - 1;
2032             lc[idx] = lv + 1; rc[idx] = rv - 1;
2033 
2034             if( values[i] + epsilon < values[i+1] )
2035             {
2036                 double val = (lsum2*R + rsum2*L)/((double)L*R);
2037                 if( best_val < val )
2038                 {
2039                     best_val = val;
2040                     best_i = i;
2041                 }
2042             }
2043         }
2044     }
2045     else
2046     {
2047         double L = 0, R = 0;
2048         for( i = 0; i < m; i++ )
2049         {
2050             double wv = rc[i]*priors[i];
2051             R += wv;
2052             rsum2 += wv*wv;
2053         }
2054 
2055         for( i = 0; i < n1 - 1; i++ )
2056         {
2057             int idx = responses[sorted_indices[i]];
2058             int lv, rv;
2059             double p = priors[idx], p2 = p*p;
2060             L += p; R -= p;
2061             lv = lc[idx]; rv = rc[idx];
2062             lsum2 += p2*(lv*2 + 1);
2063             rsum2 -= p2*(rv*2 - 1);
2064             lc[idx] = lv + 1; rc[idx] = rv - 1;
2065 
2066             if( values[i] + epsilon < values[i+1] )
2067             {
2068                 double val = (lsum2*R + rsum2*L)/((double)L*R);
2069                 if( best_val < val )
2070                 {
2071                     best_val = val;
2072                     best_i = i;
2073                 }
2074             }
2075         }
2076     }
2077 
2078     CvDTreeSplit* split = 0;
2079     if( best_i >= 0 )
2080     {
2081         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2082         split->var_idx = vi;
2083         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2084         split->ord.split_point = best_i;
2085         split->inversed = 0;
2086         split->quality = (float)best_val;
2087     }
2088     return split;
2089 }
2090 
2091 
cluster_categories(const int * vectors,int n,int m,int * csums,int k,int * labels)2092 void CvDTree::cluster_categories( const int* vectors, int n, int m,
2093                                 int* csums, int k, int* labels )
2094 {
2095     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
2096     int iters = 0, max_iters = 100;
2097     int i, j, idx;
2098     cv::AutoBuffer<double> buf(n + k);
2099     double *v_weights = buf, *c_weights = buf + n;
2100     bool modified = true;
2101     RNG* r = data->rng;
2102 
2103     // assign labels randomly
2104     for( i = 0; i < n; i++ )
2105     {
2106         int sum = 0;
2107         const int* v = vectors + i*m;
2108         labels[i] = i < k ? i : r->uniform(0, k);
2109 
2110         // compute weight of each vector
2111         for( j = 0; j < m; j++ )
2112             sum += v[j];
2113         v_weights[i] = sum ? 1./sum : 0.;
2114     }
2115 
2116     for( i = 0; i < n; i++ )
2117     {
2118         int i1 = (*r)(n);
2119         int i2 = (*r)(n);
2120         CV_SWAP( labels[i1], labels[i2], j );
2121     }
2122 
2123     for( iters = 0; iters <= max_iters; iters++ )
2124     {
2125         // calculate csums
2126         for( i = 0; i < k; i++ )
2127         {
2128             for( j = 0; j < m; j++ )
2129                 csums[i*m + j] = 0;
2130         }
2131 
2132         for( i = 0; i < n; i++ )
2133         {
2134             const int* v = vectors + i*m;
2135             int* s = csums + labels[i]*m;
2136             for( j = 0; j < m; j++ )
2137                 s[j] += v[j];
2138         }
2139 
2140         // exit the loop here, when we have up-to-date csums
2141         if( iters == max_iters || !modified )
2142             break;
2143 
2144         modified = false;
2145 
2146         // calculate weight of each cluster
2147         for( i = 0; i < k; i++ )
2148         {
2149             const int* s = csums + i*m;
2150             int sum = 0;
2151             for( j = 0; j < m; j++ )
2152                 sum += s[j];
2153             c_weights[i] = sum ? 1./sum : 0;
2154         }
2155 
2156         // now for each vector determine the closest cluster
2157         for( i = 0; i < n; i++ )
2158         {
2159             const int* v = vectors + i*m;
2160             double alpha = v_weights[i];
2161             double min_dist2 = DBL_MAX;
2162             int min_idx = -1;
2163 
2164             for( idx = 0; idx < k; idx++ )
2165             {
2166                 const int* s = csums + idx*m;
2167                 double dist2 = 0., beta = c_weights[idx];
2168                 for( j = 0; j < m; j++ )
2169                 {
2170                     double t = v[j]*alpha - s[j]*beta;
2171                     dist2 += t*t;
2172                 }
2173                 if( min_dist2 > dist2 )
2174                 {
2175                     min_dist2 = dist2;
2176                     min_idx = idx;
2177                 }
2178             }
2179 
2180             if( min_idx != labels[i] )
2181                 modified = true;
2182             labels[i] = min_idx;
2183         }
2184     }
2185 }
2186 
2187 
find_split_cat_class(CvDTreeNode * node,int vi,float init_quality,CvDTreeSplit * _split,uchar * _ext_buf)2188 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
2189                                              CvDTreeSplit* _split, uchar* _ext_buf )
2190 {
2191     int ci = data->get_var_type(vi);
2192     int n = node->sample_count;
2193     int m = data->get_num_classes();
2194     int _mi = data->cat_count->data.i[ci], mi = _mi;
2195 
2196     int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
2197     if( m > 2 && mi > data->params.max_categories )
2198         base_size += (m*std::min(data->params.max_categories, n) + mi)*sizeof(int);
2199     else
2200         base_size += mi*sizeof(int*);
2201     cv::AutoBuffer<uchar> inn_buf(base_size);
2202     if( !_ext_buf )
2203         inn_buf.allocate(base_size + 2*n*sizeof(int));
2204     uchar* base_buf = (uchar*)inn_buf;
2205     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2206 
2207     int* lc = (int*)base_buf;
2208     int* rc = lc + m;
2209     int* _cjk = rc + m*2, *cjk = _cjk;
2210     double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));
2211 
2212     int* labels_buf = (int*)ext_buf;
2213     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2214     int* responses_buf = labels_buf + n;
2215     const int* responses = data->get_class_labels(node, responses_buf);
2216 
2217     int* cluster_labels = 0;
2218     int** int_ptr = 0;
2219     int i, j, k, idx;
2220     double L = 0, R = 0;
2221     double best_val = init_quality;
2222     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
2223     const double* priors = data->priors_mult->data.db;
2224 
2225     // init array of counters:
2226     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
2227     for( j = -1; j < mi; j++ )
2228         for( k = 0; k < m; k++ )
2229             cjk[j*m + k] = 0;
2230 
2231     for( i = 0; i < n; i++ )
2232     {
2233        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
2234        k = responses[i];
2235        cjk[j*m + k]++;
2236     }
2237 
2238     if( m > 2 )
2239     {
2240         if( mi > data->params.max_categories )
2241         {
2242             mi = MIN(data->params.max_categories, n);
2243             cjk = (int*)(c_weights + _mi);
2244             cluster_labels = cjk + m*mi;
2245             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
2246         }
2247         subset_i = 1;
2248         subset_n = 1 << mi;
2249     }
2250     else
2251     {
2252         assert( m == 2 );
2253         int_ptr = (int**)(c_weights + _mi);
2254         for( j = 0; j < mi; j++ )
2255             int_ptr[j] = cjk + j*2 + 1;
2256         std::sort(int_ptr, int_ptr + mi, LessThanPtr<int>());
2257         subset_i = 0;
2258         subset_n = mi;
2259     }
2260 
2261     for( k = 0; k < m; k++ )
2262     {
2263         int sum = 0;
2264         for( j = 0; j < mi; j++ )
2265             sum += cjk[j*m + k];
2266         rc[k] = sum;
2267         lc[k] = 0;
2268     }
2269 
2270     for( j = 0; j < mi; j++ )
2271     {
2272         double sum = 0;
2273         for( k = 0; k < m; k++ )
2274             sum += cjk[j*m + k]*priors[k];
2275         c_weights[j] = sum;
2276         R += c_weights[j];
2277     }
2278 
2279     for( ; subset_i < subset_n; subset_i++ )
2280     {
2281         double weight;
2282         int* crow;
2283         double lsum2 = 0, rsum2 = 0;
2284 
2285         if( m == 2 )
2286             idx = (int)(int_ptr[subset_i] - cjk)/2;
2287         else
2288         {
2289             int graycode = (subset_i>>1)^subset_i;
2290             int diff = graycode ^ prevcode;
2291 
2292             // determine index of the changed bit.
2293             Cv32suf u;
2294             idx = diff >= (1 << 16) ? 16 : 0;
2295             u.f = (float)(((diff >> 16) | diff) & 65535);
2296             idx += (u.i >> 23) - 127;
2297             subtract = graycode < prevcode;
2298             prevcode = graycode;
2299         }
2300 
2301         crow = cjk + idx*m;
2302         weight = c_weights[idx];
2303         if( weight < FLT_EPSILON )
2304             continue;
2305 
2306         if( !subtract )
2307         {
2308             for( k = 0; k < m; k++ )
2309             {
2310                 int t = crow[k];
2311                 int lval = lc[k] + t;
2312                 int rval = rc[k] - t;
2313                 double p = priors[k], p2 = p*p;
2314                 lsum2 += p2*lval*lval;
2315                 rsum2 += p2*rval*rval;
2316                 lc[k] = lval; rc[k] = rval;
2317             }
2318             L += weight;
2319             R -= weight;
2320         }
2321         else
2322         {
2323             for( k = 0; k < m; k++ )
2324             {
2325                 int t = crow[k];
2326                 int lval = lc[k] - t;
2327                 int rval = rc[k] + t;
2328                 double p = priors[k], p2 = p*p;
2329                 lsum2 += p2*lval*lval;
2330                 rsum2 += p2*rval*rval;
2331                 lc[k] = lval; rc[k] = rval;
2332             }
2333             L -= weight;
2334             R += weight;
2335         }
2336 
2337         if( L > FLT_EPSILON && R > FLT_EPSILON )
2338         {
2339             double val = (lsum2*R + rsum2*L)/((double)L*R);
2340             if( best_val < val )
2341             {
2342                 best_val = val;
2343                 best_subset = subset_i;
2344             }
2345         }
2346     }
2347 
2348     CvDTreeSplit* split = 0;
2349     if( best_subset >= 0 )
2350     {
2351         split = _split ? _split : data->new_split_cat( 0, -1.0f );
2352         split->var_idx = vi;
2353         split->quality = (float)best_val;
2354         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2355         if( m == 2 )
2356         {
2357             for( i = 0; i <= best_subset; i++ )
2358             {
2359                 idx = (int)(int_ptr[i] - cjk) >> 1;
2360                 split->subset[idx >> 5] |= 1 << (idx & 31);
2361             }
2362         }
2363         else
2364         {
2365             for( i = 0; i < _mi; i++ )
2366             {
2367                 idx = cluster_labels ? cluster_labels[i] : i;
2368                 if( best_subset & (1 << idx) )
2369                     split->subset[i >> 5] |= 1 << (i & 31);
2370             }
2371         }
2372     }
2373     return split;
2374 }
2375 
2376 
find_split_ord_reg(CvDTreeNode * node,int vi,float init_quality,CvDTreeSplit * _split,uchar * _ext_buf)2377 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2378 {
2379     const float epsilon = FLT_EPSILON*2;
2380     int n = node->sample_count;
2381     int n1 = node->get_num_valid(vi);
2382 
2383     cv::AutoBuffer<uchar> inn_buf;
2384     if( !_ext_buf )
2385         inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
2386     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2387     float* values_buf = (float*)ext_buf;
2388     int* sorted_indices_buf = (int*)(values_buf + n);
2389     int* sample_indices_buf = sorted_indices_buf + n;
2390     const float* values = 0;
2391     const int* sorted_indices = 0;
2392     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2393     float* responses_buf =  (float*)(sample_indices_buf + n);
2394     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
2395 
2396     int i, best_i = -1;
2397     double best_val = init_quality, lsum = 0, rsum = node->value*n;
2398     int L = 0, R = n1;
2399 
2400     // compensate for missing values
2401     for( i = n1; i < n; i++ )
2402         rsum -= responses[sorted_indices[i]];
2403 
2404     // find the optimal split
2405     for( i = 0; i < n1 - 1; i++ )
2406     {
2407         float t = responses[sorted_indices[i]];
2408         L++; R--;
2409         lsum += t;
2410         rsum -= t;
2411 
2412         if( values[i] + epsilon < values[i+1] )
2413         {
2414             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2415             if( best_val < val )
2416             {
2417                 best_val = val;
2418                 best_i = i;
2419             }
2420         }
2421     }
2422 
2423     CvDTreeSplit* split = 0;
2424     if( best_i >= 0 )
2425     {
2426         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
2427         split->var_idx = vi;
2428         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
2429         split->ord.split_point = best_i;
2430         split->inversed = 0;
2431         split->quality = (float)best_val;
2432     }
2433     return split;
2434 }
2435 
find_split_cat_reg(CvDTreeNode * node,int vi,float init_quality,CvDTreeSplit * _split,uchar * _ext_buf)2436 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
2437 {
2438     int ci = data->get_var_type(vi);
2439     int n = node->sample_count;
2440     int mi = data->cat_count->data.i[ci];
2441 
2442     int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
2443     cv::AutoBuffer<uchar> inn_buf(base_size);
2444     if( !_ext_buf )
2445         inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
2446     uchar* base_buf = (uchar*)inn_buf;
2447     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2448     int* labels_buf = (int*)ext_buf;
2449     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2450     float* responses_buf = (float*)(labels_buf + n);
2451     int* sample_indices_buf = (int*)(responses_buf + n);
2452     const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
2453 
2454     double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2455     int* counts = (int*)(sum + mi) + 1;
2456     double** sum_ptr = (double**)(counts + mi);
2457     int i, L = 0, R = 0;
2458     double best_val = init_quality, lsum = 0, rsum = 0;
2459     int best_subset = -1, subset_i;
2460 
2461     for( i = -1; i < mi; i++ )
2462         sum[i] = counts[i] = 0;
2463 
2464     // calculate sum response and weight of each category of the input var
2465     for( i = 0; i < n; i++ )
2466     {
2467         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
2468         double s = sum[idx] + responses[i];
2469         int nc = counts[idx] + 1;
2470         sum[idx] = s;
2471         counts[idx] = nc;
2472     }
2473 
2474     // calculate average response in each category
2475     for( i = 0; i < mi; i++ )
2476     {
2477         R += counts[i];
2478         rsum += sum[i];
2479         sum[i] /= MAX(counts[i],1);
2480         sum_ptr[i] = sum + i;
2481     }
2482 
2483     std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());
2484 
2485     // revert back to unnormalized sums
2486     // (there should be a very little loss of accuracy)
2487     for( i = 0; i < mi; i++ )
2488         sum[i] *= counts[i];
2489 
2490     for( subset_i = 0; subset_i < mi-1; subset_i++ )
2491     {
2492         int idx = (int)(sum_ptr[subset_i] - sum);
2493         int ni = counts[idx];
2494 
2495         if( ni )
2496         {
2497             double s = sum[idx];
2498             lsum += s; L += ni;
2499             rsum -= s; R -= ni;
2500 
2501             if( L && R )
2502             {
2503                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
2504                 if( best_val < val )
2505                 {
2506                     best_val = val;
2507                     best_subset = subset_i;
2508                 }
2509             }
2510         }
2511     }
2512 
2513     CvDTreeSplit* split = 0;
2514     if( best_subset >= 0 )
2515     {
2516         split = _split ? _split : data->new_split_cat( 0, -1.0f);
2517         split->var_idx = vi;
2518         split->quality = (float)best_val;
2519         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
2520         for( i = 0; i <= best_subset; i++ )
2521         {
2522             int idx = (int)(sum_ptr[i] - sum);
2523             split->subset[idx >> 5] |= 1 << (idx & 31);
2524         }
2525     }
2526     return split;
2527 }
2528 
find_surrogate_split_ord(CvDTreeNode * node,int vi,uchar * _ext_buf)2529 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
2530 {
2531     const float epsilon = FLT_EPSILON*2;
2532     const char* dir = (char*)data->direction->data.ptr;
2533     int n = node->sample_count, n1 = node->get_num_valid(vi);
2534     cv::AutoBuffer<uchar> inn_buf;
2535     if( !_ext_buf )
2536         inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
2537     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
2538     float* values_buf = (float*)ext_buf;
2539     int* sorted_indices_buf = (int*)(values_buf + n);
2540     int* sample_indices_buf = sorted_indices_buf + n;
2541     const float* values = 0;
2542     const int* sorted_indices = 0;
2543     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2544     // LL - number of samples that both the primary and the surrogate splits send to the left
2545     // LR - ... primary split sends to the left and the surrogate split sends to the right
2546     // RL - ... primary split sends to the right and the surrogate split sends to the left
2547     // RR - ... both send to the right
2548     int i, best_i = -1, best_inversed = 0;
2549     double best_val;
2550 
2551     if( !data->have_priors )
2552     {
2553         int LL = 0, RL = 0, LR, RR;
2554         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
2555         int sum = 0, sum_abs = 0;
2556 
2557         for( i = 0; i < n1; i++ )
2558         {
2559             int d = dir[sorted_indices[i]];
2560             sum += d; sum_abs += d & 1;
2561         }
2562 
2563         // sum_abs = R + L; sum = R - L
2564         RR = (sum_abs + sum) >> 1;
2565         LR = (sum_abs - sum) >> 1;
2566 
2567         // initially all the samples are sent to the right by the surrogate split,
2568         // LR of them are sent to the left by primary split, and RR - to the right.
2569         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2570         for( i = 0; i < n1 - 1; i++ )
2571         {
2572             int d = dir[sorted_indices[i]];
2573 
2574             if( d < 0 )
2575             {
2576                 LL++; LR--;
2577                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
2578                 {
2579                     best_val = LL + RR;
2580                     best_i = i; best_inversed = 0;
2581                 }
2582             }
2583             else if( d > 0 )
2584             {
2585                 RL++; RR--;
2586                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
2587                 {
2588                     best_val = RL + LR;
2589                     best_i = i; best_inversed = 1;
2590                 }
2591             }
2592         }
2593         best_val = _best_val;
2594     }
2595     else
2596     {
2597         double LL = 0, RL = 0, LR, RR;
2598         double worst_val = node->maxlr;
2599         double sum = 0, sum_abs = 0;
2600         const double* priors = data->priors_mult->data.db;
2601         int* responses_buf = sample_indices_buf + n;
2602         const int* responses = data->get_class_labels(node, responses_buf);
2603         best_val = worst_val;
2604 
2605         for( i = 0; i < n1; i++ )
2606         {
2607             int idx = sorted_indices[i];
2608             double w = priors[responses[idx]];
2609             int d = dir[idx];
2610             sum += d*w; sum_abs += (d & 1)*w;
2611         }
2612 
2613         // sum_abs = R + L; sum = R - L
2614         RR = (sum_abs + sum)*0.5;
2615         LR = (sum_abs - sum)*0.5;
2616 
2617         // initially all the samples are sent to the right by the surrogate split,
2618         // LR of them are sent to the left by primary split, and RR - to the right.
2619         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
2620         for( i = 0; i < n1 - 1; i++ )
2621         {
2622             int idx = sorted_indices[i];
2623             double w = priors[responses[idx]];
2624             int d = dir[idx];
2625 
2626             if( d < 0 )
2627             {
2628                 LL += w; LR -= w;
2629                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
2630                 {
2631                     best_val = LL + RR;
2632                     best_i = i; best_inversed = 0;
2633                 }
2634             }
2635             else if( d > 0 )
2636             {
2637                 RL += w; RR -= w;
2638                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
2639                 {
2640                     best_val = RL + LR;
2641                     best_i = i; best_inversed = 1;
2642                 }
2643             }
2644         }
2645     }
2646     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
2647         (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
2648 }
2649 
2650 
find_surrogate_split_cat(CvDTreeNode * node,int vi,uchar * _ext_buf)2651 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
2652 {
2653     const char* dir = (char*)data->direction->data.ptr;
2654     int n = node->sample_count;
2655     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
2656 
2657     int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
2658     cv::AutoBuffer<uchar> inn_buf(base_size);
2659     if( !_ext_buf )
2660         inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
2661     uchar* base_buf = (uchar*)inn_buf;
2662     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
2663 
2664     int* labels_buf = (int*)ext_buf;
2665     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2666     // LL - number of samples that both the primary and the surrogate splits send to the left
2667     // LR - ... primary split sends to the left and the surrogate split sends to the right
2668     // RL - ... primary split sends to the right and the surrogate split sends to the left
2669     // RR - ... both send to the right
2670     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
2671     double best_val = 0;
2672     double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
2673     double* rc = lc + mi + 1;
2674 
2675     for( i = -1; i < mi; i++ )
2676         lc[i] = rc[i] = 0;
2677 
2678     // for each category calculate the weight of samples
2679     // sent to the left (lc) and to the right (rc) by the primary split
2680     if( !data->have_priors )
2681     {
2682         int* _lc = (int*)rc + 1;
2683         int* _rc = _lc + mi + 1;
2684 
2685         for( i = -1; i < mi; i++ )
2686             _lc[i] = _rc[i] = 0;
2687 
2688         for( i = 0; i < n; i++ )
2689         {
2690             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2691             int d = dir[i];
2692             int sum = _lc[idx] + d;
2693             int sum_abs = _rc[idx] + (d & 1);
2694             _lc[idx] = sum; _rc[idx] = sum_abs;
2695         }
2696 
2697         for( i = 0; i < mi; i++ )
2698         {
2699             int sum = _lc[i];
2700             int sum_abs = _rc[i];
2701             lc[i] = (sum_abs - sum) >> 1;
2702             rc[i] = (sum_abs + sum) >> 1;
2703         }
2704     }
2705     else
2706     {
2707         const double* priors = data->priors_mult->data.db;
2708         int* responses_buf = labels_buf + n;
2709         const int* responses = data->get_class_labels(node, responses_buf);
2710 
2711         for( i = 0; i < n; i++ )
2712         {
2713             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
2714             double w = priors[responses[i]];
2715             int d = dir[i];
2716             double sum = lc[idx] + d*w;
2717             double sum_abs = rc[idx] + (d & 1)*w;
2718             lc[idx] = sum; rc[idx] = sum_abs;
2719         }
2720 
2721         for( i = 0; i < mi; i++ )
2722         {
2723             double sum = lc[i];
2724             double sum_abs = rc[i];
2725             lc[i] = (sum_abs - sum) * 0.5;
2726             rc[i] = (sum_abs + sum) * 0.5;
2727         }
2728     }
2729 
2730     // 2. now form the split.
2731     // in each category send all the samples to the same direction as majority
2732     for( i = 0; i < mi; i++ )
2733     {
2734         double lval = lc[i], rval = rc[i];
2735         if( lval > rval )
2736         {
2737             split->subset[i >> 5] |= 1 << (i & 31);
2738             best_val += lval;
2739             l_win++;
2740         }
2741         else
2742             best_val += rval;
2743     }
2744 
2745     split->quality = (float)best_val;
2746     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
2747         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
2748 
2749     return split;
2750 }
2751 
2752 
calc_node_value(CvDTreeNode * node)2753 void CvDTree::calc_node_value( CvDTreeNode* node )
2754 {
2755     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
2756     int m = data->get_num_classes();
2757 
2758     int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
2759     int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
2760     cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
2761     uchar* base_buf = (uchar*)inn_buf;
2762     uchar* ext_buf = base_buf + base_size;
2763 
2764     int* cv_labels_buf = (int*)ext_buf;
2765     const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);
2766 
2767     if( data->is_classifier )
2768     {
2769         // in case of classification tree:
2770         //  * node value is the label of the class that has the largest weight in the node.
2771         //  * node risk is the weighted number of misclassified samples,
2772         //  * j-th cross-validation fold value and risk are calculated as above,
2773         //    but using the samples with cv_labels(*)!=j.
2774         //  * j-th cross-validation fold error is calculated as the weighted number of
2775         //    misclassified samples with cv_labels(*)==j.
2776 
2777         // compute the number of instances of each class
2778         int* cls_count = data->counts->data.i;
2779         int* responses_buf = cv_labels_buf + n;
2780         const int* responses = data->get_class_labels(node, responses_buf);
2781         int* cv_cls_count = (int*)base_buf;
2782         double max_val = -1, total_weight = 0;
2783         int max_k = -1;
2784         double* priors = data->priors_mult->data.db;
2785 
2786         for( k = 0; k < m; k++ )
2787             cls_count[k] = 0;
2788 
2789         if( cv_n == 0 )
2790         {
2791             for( i = 0; i < n; i++ )
2792                 cls_count[responses[i]]++;
2793         }
2794         else
2795         {
2796             for( j = 0; j < cv_n; j++ )
2797                 for( k = 0; k < m; k++ )
2798                     cv_cls_count[j*m + k] = 0;
2799 
2800             for( i = 0; i < n; i++ )
2801             {
2802                 j = cv_labels[i]; k = responses[i];
2803                 cv_cls_count[j*m + k]++;
2804             }
2805 
2806             for( j = 0; j < cv_n; j++ )
2807                 for( k = 0; k < m; k++ )
2808                     cls_count[k] += cv_cls_count[j*m + k];
2809         }
2810 
2811         if( data->have_priors && node->parent == 0 )
2812         {
2813             // compute priors_mult from priors, take the sample ratio into account.
2814             double sum = 0;
2815             for( k = 0; k < m; k++ )
2816             {
2817                 int n_k = cls_count[k];
2818                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
2819                 sum += priors[k];
2820             }
2821             sum = 1./sum;
2822             for( k = 0; k < m; k++ )
2823                 priors[k] *= sum;
2824         }
2825 
2826         for( k = 0; k < m; k++ )
2827         {
2828             double val = cls_count[k]*priors[k];
2829             total_weight += val;
2830             if( max_val < val )
2831             {
2832                 max_val = val;
2833                 max_k = k;
2834             }
2835         }
2836 
2837         node->class_idx = max_k;
2838         node->value = data->cat_map->data.i[
2839             data->cat_ofs->data.i[data->cat_var_count] + max_k];
2840         node->node_risk = total_weight - max_val;
2841 
2842         for( j = 0; j < cv_n; j++ )
2843         {
2844             double sum_k = 0, sum = 0, max_val_k = 0;
2845             max_val = -1; max_k = -1;
2846 
2847             for( k = 0; k < m; k++ )
2848             {
2849                 double w = priors[k];
2850                 double val_k = cv_cls_count[j*m + k]*w;
2851                 double val = cls_count[k]*w - val_k;
2852                 sum_k += val_k;
2853                 sum += val;
2854                 if( max_val < val )
2855                 {
2856                     max_val = val;
2857                     max_val_k = val_k;
2858                     max_k = k;
2859                 }
2860             }
2861 
2862             node->cv_Tn[j] = INT_MAX;
2863             node->cv_node_risk[j] = sum - max_val;
2864             node->cv_node_error[j] = sum_k - max_val_k;
2865         }
2866     }
2867     else
2868     {
2869         // in case of regression tree:
2870         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
2871         //    n is the number of samples in the node.
2872         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
2873         //  * j-th cross-validation fold value and risk are calculated as above,
2874         //    but using the samples with cv_labels(*)!=j.
2875         //  * j-th cross-validation fold error is calculated
2876         //    using samples with cv_labels(*)==j as the test subset:
2877         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
2878         //    where node_value_j is the node value calculated
2879         //    as described in the previous bullet, and summation is done
2880         //    over the samples with cv_labels(*)==j.
2881 
2882         double sum = 0, sum2 = 0;
2883         float* values_buf = (float*)(cv_labels_buf + n);
2884         int* sample_indices_buf = (int*)(values_buf + n);
2885         const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
2886         double *cv_sum = 0, *cv_sum2 = 0;
2887         int* cv_count = 0;
2888 
2889         if( cv_n == 0 )
2890         {
2891             for( i = 0; i < n; i++ )
2892             {
2893                 double t = values[i];
2894                 sum += t;
2895                 sum2 += t*t;
2896             }
2897         }
2898         else
2899         {
2900             cv_sum = (double*)base_buf;
2901             cv_sum2 = cv_sum + cv_n;
2902             cv_count = (int*)(cv_sum2 + cv_n);
2903 
2904             for( j = 0; j < cv_n; j++ )
2905             {
2906                 cv_sum[j] = cv_sum2[j] = 0.;
2907                 cv_count[j] = 0;
2908             }
2909 
2910             for( i = 0; i < n; i++ )
2911             {
2912                 j = cv_labels[i];
2913                 double t = values[i];
2914                 double s = cv_sum[j] + t;
2915                 double s2 = cv_sum2[j] + t*t;
2916                 int nc = cv_count[j] + 1;
2917                 cv_sum[j] = s;
2918                 cv_sum2[j] = s2;
2919                 cv_count[j] = nc;
2920             }
2921 
2922             for( j = 0; j < cv_n; j++ )
2923             {
2924                 sum += cv_sum[j];
2925                 sum2 += cv_sum2[j];
2926             }
2927         }
2928 
2929         node->node_risk = sum2 - (sum/n)*sum;
2930         node->value = sum/n;
2931 
2932         for( j = 0; j < cv_n; j++ )
2933         {
2934             double s = cv_sum[j], si = sum - s;
2935             double s2 = cv_sum2[j], s2i = sum2 - s2;
2936             int c = cv_count[j], ci = n - c;
2937             double r = si/MAX(ci,1);
2938             node->cv_node_risk[j] = s2i - r*r*ci;
2939             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
2940             node->cv_Tn[j] = INT_MAX;
2941         }
2942     }
2943 }
2944 
2945 
complete_node_dir(CvDTreeNode * node)2946 void CvDTree::complete_node_dir( CvDTreeNode* node )
2947 {
2948     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
2949     int nz = n - node->get_num_valid(node->split->var_idx);
2950     char* dir = (char*)data->direction->data.ptr;
2951 
2952     // try to complete direction using surrogate splits
2953     if( nz && data->params.use_surrogates )
2954     {
2955         cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
2956         CvDTreeSplit* split = node->split->next;
2957         for( ; split != 0 && nz; split = split->next )
2958         {
2959             int inversed_mask = split->inversed ? -1 : 0;
2960             vi = split->var_idx;
2961 
2962             if( data->get_var_type(vi) >= 0 ) // split on categorical var
2963             {
2964                 int* labels_buf = (int*)(uchar*)inn_buf;
2965                 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
2966                 const int* subset = split->subset;
2967 
2968                 for( i = 0; i < n; i++ )
2969                 {
2970                     int idx = labels[i];
2971                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
2972 
2973                     {
2974                         int d = CV_DTREE_CAT_DIR(idx,subset);
2975                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
2976                         if( --nz )
2977                             break;
2978                     }
2979                 }
2980             }
2981             else // split on ordered var
2982             {
2983                 float* values_buf = (float*)(uchar*)inn_buf;
2984                 int* sorted_indices_buf = (int*)(values_buf + n);
2985                 int* sample_indices_buf = sorted_indices_buf + n;
2986                 const float* values = 0;
2987                 const int* sorted_indices = 0;
2988                 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
2989                 int split_point = split->ord.split_point;
2990                 int n1 = node->get_num_valid(vi);
2991 
2992                 assert( 0 <= split_point && split_point < n-1 );
2993 
2994                 for( i = 0; i < n1; i++ )
2995                 {
2996                     int idx = sorted_indices[i];
2997                     if( !dir[idx] )
2998                     {
2999                         int d = i <= split_point ? -1 : 1;
3000                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
3001                         if( --nz )
3002                             break;
3003                     }
3004                 }
3005             }
3006         }
3007     }
3008 
3009     // find the default direction for the rest
3010     if( nz )
3011     {
3012         for( i = nr = 0; i < n; i++ )
3013             nr += dir[i] > 0;
3014         nl = n - nr - nz;
3015         d0 = nl > nr ? -1 : nr > nl;
3016     }
3017 
3018     // make sure that every sample is directed either to the left or to the right
3019     for( i = 0; i < n; i++ )
3020     {
3021         int d = dir[i];
3022         if( !d )
3023         {
3024             d = d0;
3025             if( !d )
3026                 d = d1, d1 = -d1;
3027         }
3028         d = d > 0;
3029         dir[i] = (char)d; // remap (-1,1) to (0,1)
3030     }
3031 }
3032 
3033 
split_node_data(CvDTreeNode * node)3034 void CvDTree::split_node_data( CvDTreeNode* node )
3035 {
3036     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
3037     char* dir = (char*)data->direction->data.ptr;
3038     CvDTreeNode *left = 0, *right = 0;
3039     int* new_idx = data->split_buf->data.i;
3040     int new_buf_idx = data->get_child_buf_idx( node );
3041     int work_var_count = data->get_work_var_count();
3042     CvMat* buf = data->buf;
3043     size_t length_buf_row = data->get_length_subbuf();
3044     cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
3045     int* temp_buf = (int*)(uchar*)inn_buf;
3046 
3047     complete_node_dir(node);
3048 
3049     for( i = nl = nr = 0; i < n; i++ )
3050     {
3051         int d = dir[i];
3052         // initialize new indices for splitting ordered variables
3053         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
3054         nr += d;
3055         nl += d^1;
3056     }
3057 
3058     bool split_input_data;
3059     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
3060     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
3061 
3062     split_input_data = node->depth + 1 < data->params.max_depth &&
3063         (node->left->sample_count > data->params.min_sample_count ||
3064         node->right->sample_count > data->params.min_sample_count);
3065 
3066     // split ordered variables, keep both halves sorted.
3067     for( vi = 0; vi < data->var_count; vi++ )
3068     {
3069         int ci = data->get_var_type(vi);
3070 
3071         if( ci >= 0 || !split_input_data )
3072             continue;
3073 
3074         int n1 = node->get_num_valid(vi);
3075         float* src_val_buf = (float*)(uchar*)(temp_buf + n);
3076         int* src_sorted_idx_buf = (int*)(src_val_buf + n);
3077         int* src_sample_idx_buf = src_sorted_idx_buf + n;
3078         const float* src_val = 0;
3079         const int* src_sorted_idx = 0;
3080         data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);
3081 
3082         for(i = 0; i < n; i++)
3083             temp_buf[i] = src_sorted_idx[i];
3084 
3085         if (data->is_buf_16u)
3086         {
3087             unsigned short *ldst, *rdst, *ldst0, *rdst0;
3088             //unsigned short tl, tr;
3089             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
3090                 vi*scount + left->offset);
3091             rdst0 = rdst = (unsigned short*)(ldst + nl);
3092 
3093             // split sorted
3094             for( i = 0; i < n1; i++ )
3095             {
3096                 int idx = temp_buf[i];
3097                 int d = dir[idx];
3098                 idx = new_idx[idx];
3099                 if (d)
3100                 {
3101                     *rdst = (unsigned short)idx;
3102                     rdst++;
3103                 }
3104                 else
3105                 {
3106                     *ldst = (unsigned short)idx;
3107                     ldst++;
3108                 }
3109             }
3110 
3111             left->set_num_valid(vi, (int)(ldst - ldst0));
3112             right->set_num_valid(vi, (int)(rdst - rdst0));
3113 
3114             // split missing
3115             for( ; i < n; i++ )
3116             {
3117                 int idx = temp_buf[i];
3118                 int d = dir[idx];
3119                 idx = new_idx[idx];
3120                 if (d)
3121                 {
3122                     *rdst = (unsigned short)idx;
3123                     rdst++;
3124                 }
3125                 else
3126                 {
3127                     *ldst = (unsigned short)idx;
3128                     ldst++;
3129                 }
3130             }
3131         }
3132         else
3133         {
3134             int *ldst0, *ldst, *rdst0, *rdst;
3135             ldst0 = ldst = buf->data.i + left->buf_idx*length_buf_row +
3136                 vi*scount + left->offset;
3137             rdst0 = rdst = buf->data.i + right->buf_idx*length_buf_row +
3138                 vi*scount + right->offset;
3139 
3140             // split sorted
3141             for( i = 0; i < n1; i++ )
3142             {
3143                 int idx = temp_buf[i];
3144                 int d = dir[idx];
3145                 idx = new_idx[idx];
3146                 if (d)
3147                 {
3148                     *rdst = idx;
3149                     rdst++;
3150                 }
3151                 else
3152                 {
3153                     *ldst = idx;
3154                     ldst++;
3155                 }
3156             }
3157 
3158             left->set_num_valid(vi, (int)(ldst - ldst0));
3159             right->set_num_valid(vi, (int)(rdst - rdst0));
3160 
3161             // split missing
3162             for( ; i < n; i++ )
3163             {
3164                 int idx = temp_buf[i];
3165                 int d = dir[idx];
3166                 idx = new_idx[idx];
3167                 if (d)
3168                 {
3169                     *rdst = idx;
3170                     rdst++;
3171                 }
3172                 else
3173                 {
3174                     *ldst = idx;
3175                     ldst++;
3176                 }
3177             }
3178         }
3179     }
3180 
3181     // split categorical vars, responses and cv_labels using new_idx relocation table
3182     for( vi = 0; vi < work_var_count; vi++ )
3183     {
3184         int ci = data->get_var_type(vi);
3185         int n1 = node->get_num_valid(vi), nr1 = 0;
3186 
3187         if( ci < 0 || (vi < data->var_count && !split_input_data) )
3188             continue;
3189 
3190         int *src_lbls_buf = temp_buf + n;
3191         const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);
3192 
3193         for(i = 0; i < n; i++)
3194             temp_buf[i] = src_lbls[i];
3195 
3196         if (data->is_buf_16u)
3197         {
3198             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
3199                 vi*scount + left->offset);
3200             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
3201                 vi*scount + right->offset);
3202 
3203             for( i = 0; i < n; i++ )
3204             {
3205                 int d = dir[i];
3206                 int idx = temp_buf[i];
3207                 if (d)
3208                 {
3209                     *rdst = (unsigned short)idx;
3210                     rdst++;
3211                     nr1 += (idx != 65535 )&d;
3212                 }
3213                 else
3214                 {
3215                     *ldst = (unsigned short)idx;
3216                     ldst++;
3217                 }
3218             }
3219 
3220             if( vi < data->var_count )
3221             {
3222                 left->set_num_valid(vi, n1 - nr1);
3223                 right->set_num_valid(vi, nr1);
3224             }
3225         }
3226         else
3227         {
3228             int *ldst = buf->data.i + left->buf_idx*length_buf_row +
3229                 vi*scount + left->offset;
3230             int *rdst = buf->data.i + right->buf_idx*length_buf_row +
3231                 vi*scount + right->offset;
3232 
3233             for( i = 0; i < n; i++ )
3234             {
3235                 int d = dir[i];
3236                 int idx = temp_buf[i];
3237                 if (d)
3238                 {
3239                     *rdst = idx;
3240                     rdst++;
3241                     nr1 += (idx >= 0)&d;
3242                 }
3243                 else
3244                 {
3245                     *ldst = idx;
3246                     ldst++;
3247                 }
3248 
3249             }
3250 
3251             if( vi < data->var_count )
3252             {
3253                 left->set_num_valid(vi, n1 - nr1);
3254                 right->set_num_valid(vi, nr1);
3255             }
3256         }
3257     }
3258 
3259 
3260     // split sample indices
3261     int *sample_idx_src_buf = temp_buf + n;
3262     const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
3263 
3264     for(i = 0; i < n; i++)
3265         temp_buf[i] = sample_idx_src[i];
3266 
3267     int pos = data->get_work_var_count();
3268     if (data->is_buf_16u)
3269     {
3270         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
3271             pos*scount + left->offset);
3272         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
3273             pos*scount + right->offset);
3274         for (i = 0; i < n; i++)
3275         {
3276             int d = dir[i];
3277             unsigned short idx = (unsigned short)temp_buf[i];
3278             if (d)
3279             {
3280                 *rdst = idx;
3281                 rdst++;
3282             }
3283             else
3284             {
3285                 *ldst = idx;
3286                 ldst++;
3287             }
3288         }
3289     }
3290     else
3291     {
3292         int* ldst = buf->data.i + left->buf_idx*length_buf_row +
3293             pos*scount + left->offset;
3294         int* rdst = buf->data.i + right->buf_idx*length_buf_row +
3295             pos*scount + right->offset;
3296         for (i = 0; i < n; i++)
3297         {
3298             int d = dir[i];
3299             int idx = temp_buf[i];
3300             if (d)
3301             {
3302                 *rdst = idx;
3303                 rdst++;
3304             }
3305             else
3306             {
3307                 *ldst = idx;
3308                 ldst++;
3309             }
3310         }
3311     }
3312 
3313     // deallocate the parent node data that is not needed anymore
3314     data->free_node_data(node);
3315 }
3316 
calc_error(CvMLData * _data,int type,std::vector<float> * resp)3317 float CvDTree::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
3318 {
3319     float err = 0;
3320     const CvMat* values = _data->get_values();
3321     const CvMat* response = _data->get_responses();
3322     const CvMat* missing = _data->get_missing();
3323     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
3324     const CvMat* var_types = _data->get_var_types();
3325     int* sidx = sample_idx ? sample_idx->data.i : 0;
3326     int r_step = CV_IS_MAT_CONT(response->type) ?
3327                 1 : response->step / CV_ELEM_SIZE(response->type);
3328     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
3329     int sample_count = sample_idx ? sample_idx->cols : 0;
3330     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
3331     float* pred_resp = 0;
3332     if( resp && (sample_count > 0) )
3333     {
3334         resp->resize( sample_count );
3335         pred_resp = &((*resp)[0]);
3336     }
3337 
3338     if ( is_classifier )
3339     {
3340         for( int i = 0; i < sample_count; i++ )
3341         {
3342             CvMat sample, miss;
3343             int si = sidx ? sidx[i] : i;
3344             cvGetRow( values, &sample, si );
3345             if( missing )
3346                 cvGetRow( missing, &miss, si );
3347             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3348             if( pred_resp )
3349                 pred_resp[i] = r;
3350             int d = fabs((double)r - response->data.fl[(size_t)si*r_step]) <= FLT_EPSILON ? 0 : 1;
3351             err += d;
3352         }
3353         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
3354     }
3355     else
3356     {
3357         for( int i = 0; i < sample_count; i++ )
3358         {
3359             CvMat sample, miss;
3360             int si = sidx ? sidx[i] : i;
3361             cvGetRow( values, &sample, si );
3362             if( missing )
3363                 cvGetRow( missing, &miss, si );
3364             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
3365             if( pred_resp )
3366                 pred_resp[i] = r;
3367             float d = r - response->data.fl[(size_t)si*r_step];
3368             err += d*d;
3369         }
3370         err = sample_count ? err / (float)sample_count : -FLT_MAX;
3371     }
3372     return err;
3373 }
3374 
// Cost-complexity pruning with cross-validation.
//
// Builds the nested sequence of pruned subtrees of the full tree (step index
// tree_count, recorded in node->Tn by cut_tree), estimates the error of every
// subtree from the per-fold node statistics (cv_Tn / cv_node_risk /
// cv_node_error), picks the index with the smallest error summed over folds
// (optionally applying the 1SE rule for classifiers), stores it in
// pruned_tree_idx and finally releases the cross-validation bookkeeping via
// free_prune_data().
//
// NOTE(review): assumes root != 0 and data->params.cv_folds > 0 (err_jk is
// created with cv_n rows) — not checked here; confirm at the call site.
prune_cv()3375 void CvDTree::prune_cv()
3376 {
3377     CvMat* ab = 0;      // per-step alphas, later one representative alpha per tree
3378     CvMat* temp = 0;    // scratch buffer used while growing ab
3379     CvMat* err_jk = 0;  // cv_n x tree_count: error of main-sequence tree k on fold j
3380 
3381     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
3382     // 2. choose the best tree index (if need, apply 1SE rule).
3383     // 3. store the best index and cut the branches.
3384 
3385     CV_FUNCNAME( "CvDTree::prune_cv" );
3386 
3387     __BEGIN__;
3388 
3389     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
3390     // currently, 1SE for regression is not implemented
3391     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
3392     double* err;
3393     double min_err = 0, min_err_se = 0;
3394     int min_idx = -1;
3395 
3396     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
3397 
3398     // build the main tree sequence, calculate alphas:
     // each step recomputes complexities/risks/alphas of the current tree and
     // cuts every node whose alpha is within FLT_EPSILON of the minimum (see
     // cut_tree); the loop ends when cut_tree reports the root was cut (or is a leaf).
3399     for(;;tree_count++)
3400     {
3401         double min_alpha = update_tree_rnc(tree_count, -1);
3402         if( cut_tree(tree_count, -1, min_alpha) )
3403             break;
3404 
             // grow ab by a factor of 3/2 when full, keeping the alphas collected so far
3405         if( ab->cols <= tree_count )
3406         {
3407             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
3408             for( ti = 0; ti < ab->cols; ti++ )
3409                 temp->data.db[ti] = ab->data.db[ti];
3410             cvReleaseMat( &ab );
3411             ab = temp;
3412             temp = 0;
3413         }
3414 
3415         ab->data.db[tree_count] = min_alpha;
3416     }
3417 
     // Turn the per-step alphas into one representative alpha per tree:
     // ab[0] = 0, ab[k] = sqrt(alpha_k*alpha_{k+1}) (geometric mean of
     // neighbouring alphas) and a huge value for the last tree.
3418     ab->data.db[0] = 0.;
3419 
3420     if( tree_count > 0 )
3421     {
3422         for( ti = 1; ti < tree_count-1; ti++ )
3423             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
3424         ab->data.db[tree_count-1] = DBL_MAX*0.5;
3425 
3426         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
3427         err = err_jk->data.db;
3428 
         // For every fold j, rebuild that fold's own tree sequence (tj) and record
         // root->tree_error as the error of each main-sequence tree tk whose
         // representative alpha does not exceed the fold tree's min_alpha.
3429         for( j = 0; j < cv_n; j++ )
3430         {
3431             int tj = 0, tk = 0;
3432             for( ; tk < tree_count; tj++ )
3433             {
3434                 double min_alpha = update_tree_rnc(tj, j);
3435                 if( cut_tree(tj, j, min_alpha) )
3436                     min_alpha = DBL_MAX; // fold tree fully cut: covers all remaining tk
3437 
3438                 for( ; tk < tree_count; tk++ )
3439                 {
3440                     if( ab->data.db[tk] > min_alpha )
3441                         break;
3442                     err[j*tree_count + tk] = root->tree_error;
3443                 }
3444             }
3445         }
3446 
         // Choose the tree with the smallest error summed over folds. With the 1SE
         // rule, a later (larger index = more pruned) tree whose summed error is
         // below min_err + min_err_se replaces the choice; without it min_err_se
         // stays 0 and the else-branch can never fire.
3447         for( ti = 0; ti < tree_count; ti++ )
3448         {
3449             double sum_err = 0;
3450             for( j = 0; j < cv_n; j++ )
3451                 sum_err += err[j*tree_count + ti];
3452             if( ti == 0 || sum_err < min_err )
3453             {
3454                 min_err = sum_err;
3455                 min_idx = ti;
                 // NOTE(review): looks like a binomial standard-error estimate that
                 // assumes sum_err counts misclassified samples out of n — confirm.
3456                 if( use_1se )
3457                     min_err_se = sqrt( sum_err*(n - sum_err) );
3458             }
3459             else if( sum_err < min_err + min_err_se )
3460                 min_idx = ti;
3461         }
3462     }
3463 
     // min_idx stays -1 when the main sequence is empty (tree_count == 0).
3464     pruned_tree_idx = min_idx;
3465     free_prune_data(data->params.truncate_pruned_tree != 0);
3466 
3467     __END__;
3468 
     // Release temporaries; presumably also reached when a CV_CALL fails, since
     // the legacy error macros jump to the label inside __END__ — confirm.
3469     cvReleaseMat( &err_jk );
3470     cvReleaseMat( &ab );
3471     cvReleaseMat( &temp );
3472 }
3473 
3474 
update_tree_rnc(int T,int fold)3475 double CvDTree::update_tree_rnc( int T, int fold )
3476 {
3477     CvDTreeNode* node = root;
3478     double min_alpha = DBL_MAX;
3479 
3480     for(;;)
3481     {
3482         CvDTreeNode* parent;
3483         for(;;)
3484         {
3485             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3486             if( t <= T || !node->left )
3487             {
3488                 node->complexity = 1;
3489                 node->tree_risk = node->node_risk;
3490                 node->tree_error = 0.;
3491                 if( fold >= 0 )
3492                 {
3493                     node->tree_risk = node->cv_node_risk[fold];
3494                     node->tree_error = node->cv_node_error[fold];
3495                 }
3496                 break;
3497             }
3498             node = node->left;
3499         }
3500 
3501         for( parent = node->parent; parent && parent->right == node;
3502             node = parent, parent = parent->parent )
3503         {
3504             parent->complexity += node->complexity;
3505             parent->tree_risk += node->tree_risk;
3506             parent->tree_error += node->tree_error;
3507 
3508             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
3509                 - parent->tree_risk)/(parent->complexity - 1);
3510             min_alpha = MIN( min_alpha, parent->alpha );
3511         }
3512 
3513         if( !parent )
3514             break;
3515 
3516         parent->complexity = node->complexity;
3517         parent->tree_risk = node->tree_risk;
3518         parent->tree_error = node->tree_error;
3519         node = parent->right;
3520     }
3521 
3522     return min_alpha;
3523 }
3524 
3525 
cut_tree(int T,int fold,double min_alpha)3526 int CvDTree::cut_tree( int T, int fold, double min_alpha )
3527 {
3528     CvDTreeNode* node = root;
3529     if( !node->left )
3530         return 1;
3531 
3532     for(;;)
3533     {
3534         CvDTreeNode* parent;
3535         for(;;)
3536         {
3537             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
3538             if( t <= T || !node->left )
3539                 break;
3540             if( node->alpha <= min_alpha + FLT_EPSILON )
3541             {
3542                 if( fold >= 0 )
3543                     node->cv_Tn[fold] = T;
3544                 else
3545                     node->Tn = T;
3546                 if( node == root )
3547                     return 1;
3548                 break;
3549             }
3550             node = node->left;
3551         }
3552 
3553         for( parent = node->parent; parent && parent->right == node;
3554             node = parent, parent = parent->parent )
3555             ;
3556 
3557         if( !parent )
3558             break;
3559 
3560         node = parent->right;
3561     }
3562 
3563     return 0;
3564 }
3565 
3566 
free_prune_data(bool _cut_tree)3567 void CvDTree::free_prune_data(bool _cut_tree)
3568 {
3569     CvDTreeNode* node = root;
3570 
3571     for(;;)
3572     {
3573         CvDTreeNode* parent;
3574         for(;;)
3575         {
3576             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
3577             // as we will clear the whole cross-validation heap at the end
3578             node->cv_Tn = 0;
3579             node->cv_node_error = node->cv_node_risk = 0;
3580             if( !node->left )
3581                 break;
3582             node = node->left;
3583         }
3584 
3585         for( parent = node->parent; parent && parent->right == node;
3586             node = parent, parent = parent->parent )
3587         {
3588             if( _cut_tree && parent->Tn <= pruned_tree_idx )
3589             {
3590                 data->free_node( parent->left );
3591                 data->free_node( parent->right );
3592                 parent->left = parent->right = 0;
3593             }
3594         }
3595 
3596         if( !parent )
3597             break;
3598 
3599         node = parent->right;
3600     }
3601 
3602     if( data->cv_heap )
3603         cvClearSet( data->cv_heap );
3604 }
3605 
3606 
free_tree()3607 void CvDTree::free_tree()
3608 {
3609     if( root && data && data->shared )
3610     {
3611         pruned_tree_idx = INT_MIN;
3612         free_prune_data(true);
3613         data->free_node(root);
3614         root = 0;
3615     }
3616 }
3617 
// Routes one sample down the tree and returns the node where it stops;
// callers such as calc_error() read the prediction from its ->value.
//
// _sample: CV_32FC1 row or column vector. When preprocessed_input is false it
//   must contain data->var_all raw values (model variables are located through
//   data->var_idx and raw categorical values are mapped through data->cat_map);
//   when true it must contain data->var_count values and categorical entries
//   are used directly (after rounding) as category indices.
// _missing: optional 8-bit mask of the same size as _sample; a non-zero entry
//   marks that variable as missing, so splits on it are skipped.
// Descent stops at a leaf or at a node with Tn <= pruned_tree_idx, i.e. a node
// that acts as a leaf in the selected pruned tree.
// Errors: CV_StsError if the tree has no root; CV_StsBadArg for a malformed
// sample/mask or a non-integer raw categorical value.
predict(const CvMat * _sample,const CvMat * _missing,bool preprocessed_input) const3618 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
3619     const CvMat* _missing, bool preprocessed_input ) const
3620 {
3621     cv::AutoBuffer<int> catbuf;
3622 
3623     int i, mstep = 0;
3624     const uchar* m = 0;
3625     CvDTreeNode* node = root;
3626 
3627     if( !node )
3628         CV_Error( CV_StsError, "The tree has not been trained yet" );
3629 
     // the sample must be a 1-D float vector of the expected length
3630     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
3631         (_sample->cols != 1 && _sample->rows != 1) ||
3632         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
3633         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
3634             CV_Error( CV_StsBadArg,
3635         "the input sample must be 1d floating-point vector with the same "
3636         "number of elements as the total number of variables used for training" );
3637 
3638     const float* sample = _sample->data.fl;
     // element stride: 1 for continuous data, otherwise the row step in floats
3639     int step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
3640 
     // catbuf[ci] caches the mapped index of categorical variable ci (-1 = not computed yet)
3641     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
3642     {
3643         int n = data->cat_count->cols;
3644         catbuf.allocate(n);
3645         for( i = 0; i < n; i++ )
3646             catbuf[i] = -1;
3647     }
3648 
3649     if( _missing )
3650     {
3651         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
3652             !CV_ARE_SIZES_EQ(_missing, _sample) )
3653             CV_Error( CV_StsBadArg,
3654         "the missing data mask must be 8-bit vector of the same size as input sample" );
3655         m = _missing->data.ptr;
3656         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
3657     }
3658 
     // vtype[vi] < 0: ordered variable; >= 0: index of the categorical variable
3659     const int* vtype = data->var_type->data.i;
3660     const int* vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
3661     const int* cmap = data->cat_map ? data->cat_map->data.i : 0;
3662     const int* cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
3663 
3664     while( node->Tn > pruned_tree_idx && node->left )
3665     {
         // try the splits of this node in list order until one yields a direction
3666         CvDTreeSplit* split = node->split;
3667         int dir = 0;
3668         for( ; !dir && split != 0; split = split->next )
3669         {
3670             int vi = split->var_idx;
3671             int ci = vtype[vi];
3672             i = vidx ? vidx[vi] : vi;
3673             float val = sample[(size_t)i*step];
                 // missing value: fall through to the next split
3674             if( m && m[(size_t)i*mstep] )
3675                 continue;
3676             if( ci < 0 ) // ordered
3677                 dir = val <= split->ord.c ? -1 : 1;
3678             else // categorical
3679             {
3680                 int c;
3681                 if( preprocessed_input )
3682                     c = cvRound(val);
3683                 else
3684                 {
3685                     c = catbuf[ci];
3686                     if( c < 0 )
3687                     {
                             // binary search for the raw value in cat_map[cofs[ci] .. b)
3688                         int a = c = cofs[ci];
3689                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
3690 
3691                         int ival = cvRound(val);
3692                         if( ival != val )
3693                             CV_Error( CV_StsBadArg,
3694                             "one of input categorical variable is not an integer" );
3695 
                             // NOTE(review): sh is incremented but never read
3696                         int sh = 0;
3697                         while( a < b )
3698                         {
3699                             sh++;
3700                             c = (a + b) >> 1;
3701                             if( ival < cmap[c] )
3702                                 b = c;
3703                             else if( ival > cmap[c] )
3704                                 a = c+1;
3705                             else
3706                                 break;
3707                         }
3708 
                             // value not seen during training: treat it like a missing value
3709                         if( c < 0 || ival != cmap[c] )
3710                             continue;
3711 
                             // cache the index relative to this variable's block of categories
3712                         catbuf[ci] = c -= cofs[ci];
3713                     }
3714                 }
                     // NOTE(review): with 16-bit buffers 65535 is mapped to -1 (presumably the
                     // "missing category" code); CV_DTREE_CAT_DIR(-1, subset) would then read
                     // subset[-1] — confirm this path is intended.
3715                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
3716                 dir = CV_DTREE_CAT_DIR(c, split->subset);
3717             }
3718 
                 // an inversed split swaps the left/right decision
3719             if( split->inversed )
3720                 dir = -dir;
3721         }
3722 
             // no usable split: follow the child with more training samples (ties go right)
3723         if( !dir )
3724         {
3725             double diff = node->right->sample_count - node->left->sample_count;
3726             dir = diff < 0 ? -1 : 1;
3727         }
3728         node = dir < 0 ? node->left : node->right;
3729     }
3730 
3731     return node;
3732 }
3733 
3734 
predict(const Mat & _sample,const Mat & _missing,bool preprocessed_input) const3735 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
3736 {
3737     CvMat sample = _sample, mmask = _missing;
3738     return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
3739 }
3740 
3741 
get_var_importance()3742 const CvMat* CvDTree::get_var_importance()
3743 {
3744     if( !var_importance )
3745     {
3746         CvDTreeNode* node = root;
3747         double* importance;
3748         if( !node )
3749             return 0;
3750         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
3751         cvZero( var_importance );
3752         importance = var_importance->data.db;
3753 
3754         for(;;)
3755         {
3756             CvDTreeNode* parent;
3757             for( ;; node = node->left )
3758             {
3759                 CvDTreeSplit* split = node->split;
3760 
3761                 if( !node->left || node->Tn <= pruned_tree_idx )
3762                     break;
3763 
3764                 for( ; split != 0; split = split->next )
3765                     importance[split->var_idx] += split->quality;
3766             }
3767 
3768             for( parent = node->parent; parent && parent->right == node;
3769                 node = parent, parent = parent->parent )
3770                 ;
3771 
3772             if( !parent )
3773                 break;
3774 
3775             node = parent->right;
3776         }
3777 
3778         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
3779     }
3780 
3781     return var_importance;
3782 }
3783 
3784 
write_split(CvFileStorage * fs,CvDTreeSplit * split) const3785 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
3786 {
3787     int ci;
3788 
3789     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
3790     cvWriteInt( fs, "var", split->var_idx );
3791     cvWriteReal( fs, "quality", split->quality );
3792 
3793     ci = data->get_var_type(split->var_idx);
3794     if( ci >= 0 ) // split on a categorical var
3795     {
3796         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
3797         for( i = 0; i < n; i++ )
3798             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
3799 
3800         // ad-hoc rule when to use inverse categorical split notation
3801         // to achieve more compact and clear representation
3802         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
3803 
3804         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
3805                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
3806 
3807         for( i = 0; i < n; i++ )
3808         {
3809             int dir = CV_DTREE_CAT_DIR(i,split->subset);
3810             if( dir*default_dir < 0 )
3811                 cvWriteInt( fs, 0, i );
3812         }
3813         cvEndWriteStruct( fs );
3814     }
3815     else
3816         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
3817 
3818     cvEndWriteStruct( fs );
3819 }
3820 
3821 
write_node(CvFileStorage * fs,CvDTreeNode * node) const3822 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
3823 {
3824     CvDTreeSplit* split;
3825 
3826     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
3827 
3828     cvWriteInt( fs, "depth", node->depth );
3829     cvWriteInt( fs, "sample_count", node->sample_count );
3830     cvWriteReal( fs, "value", node->value );
3831 
3832     if( data->is_classifier )
3833         cvWriteInt( fs, "norm_class_idx", node->class_idx );
3834 
3835     cvWriteInt( fs, "Tn", node->Tn );
3836     cvWriteInt( fs, "complexity", node->complexity );
3837     cvWriteReal( fs, "alpha", node->alpha );
3838     cvWriteReal( fs, "node_risk", node->node_risk );
3839     cvWriteReal( fs, "tree_risk", node->tree_risk );
3840     cvWriteReal( fs, "tree_error", node->tree_error );
3841 
3842     if( node->left )
3843     {
3844         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
3845 
3846         for( split = node->split; split != 0; split = split->next )
3847             write_split( fs, split );
3848 
3849         cvEndWriteStruct( fs );
3850     }
3851 
3852     cvEndWriteStruct( fs );
3853 }
3854 
3855 
write_tree_nodes(CvFileStorage * fs) const3856 void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
3857 {
3858     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
3859 
3860     __BEGIN__;
3861 
3862     CvDTreeNode* node = root;
3863 
3864     // traverse the tree and save all the nodes in depth-first order
3865     for(;;)
3866     {
3867         CvDTreeNode* parent;
3868         for(;;)
3869         {
3870             write_node( fs, node );
3871             if( !node->left )
3872                 break;
3873             node = node->left;
3874         }
3875 
3876         for( parent = node->parent; parent && parent->right == node;
3877             node = parent, parent = parent->parent )
3878             ;
3879 
3880         if( !parent )
3881             break;
3882 
3883         node = parent->right;
3884     }
3885 
3886     __END__;
3887 }
3888 
3889 
write(CvFileStorage * fs,const char * name) const3890 void CvDTree::write( CvFileStorage* fs, const char* name ) const
3891 {
3892     //CV_FUNCNAME( "CvDTree::write" );
3893 
3894     __BEGIN__;
3895 
3896     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
3897 
3898     //get_var_importance();
3899     data->write_params( fs );
3900     //if( var_importance )
3901     //cvWrite( fs, "var_importance", var_importance );
3902     write( fs );
3903 
3904     cvEndWriteStruct( fs );
3905 
3906     __END__;
3907 }
3908 
3909 
write(CvFileStorage * fs) const3910 void CvDTree::write( CvFileStorage* fs ) const
3911 {
3912     //CV_FUNCNAME( "CvDTree::write" );
3913 
3914     __BEGIN__;
3915 
3916     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
3917 
3918     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
3919     write_tree_nodes( fs );
3920     cvEndWriteStruct( fs );
3921 
3922     __END__;
3923 }
3924 
3925 
read_split(CvFileStorage * fs,CvFileNode * fnode)3926 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
3927 {
3928     CvDTreeSplit* split = 0;
3929 
3930     CV_FUNCNAME( "CvDTree::read_split" );
3931 
3932     __BEGIN__;
3933 
3934     int vi, ci;
3935 
3936     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
3937         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
3938 
3939     vi = cvReadIntByName( fs, fnode, "var", -1 );
3940     if( (unsigned)vi >= (unsigned)data->var_count )
3941         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
3942 
3943     ci = data->get_var_type(vi);
3944     if( ci >= 0 ) // split on categorical var
3945     {
3946         int i, n = data->cat_count->data.i[ci], inversed = 0, val;
3947         CvSeqReader reader;
3948         CvFileNode* inseq;
3949         split = data->new_split_cat( vi, 0 );
3950         inseq = cvGetFileNodeByName( fs, fnode, "in" );
3951         if( !inseq )
3952         {
3953             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
3954             inversed = 1;
3955         }
3956         if( !inseq ||
3957             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
3958             CV_ERROR( CV_StsParseError,
3959             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
3960 
3961         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
3962         {
3963             val = inseq->data.i;
3964             if( (unsigned)val >= (unsigned)n )
3965                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3966 
3967             split->subset[val >> 5] |= 1 << (val & 31);
3968         }
3969         else
3970         {
3971             cvStartReadSeq( inseq->data.seq, &reader );
3972 
3973             for( i = 0; i < reader.seq->total; i++ )
3974             {
3975                 CvFileNode* inode = (CvFileNode*)reader.ptr;
3976                 val = inode->data.i;
3977                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
3978                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
3979 
3980                 split->subset[val >> 5] |= 1 << (val & 31);
3981                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
3982             }
3983         }
3984 
3985         // for categorical splits we do not use inversed splits,
3986         // instead we inverse the variable set in the split
3987         if( inversed )
3988             for( i = 0; i < (n + 31) >> 5; i++ )
3989                 split->subset[i] ^= -1;
3990     }
3991     else
3992     {
3993         CvFileNode* cmp_node;
3994         split = data->new_split_ord( vi, 0, 0, 0, 0 );
3995 
3996         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
3997         if( !cmp_node )
3998         {
3999             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
4000             split->inversed = 1;
4001         }
4002 
4003         split->ord.c = (float)cvReadReal( cmp_node );
4004     }
4005 
4006     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
4007 
4008     __END__;
4009 
4010     return split;
4011 }
4012 
4013 
read_node(CvFileStorage * fs,CvFileNode * fnode,CvDTreeNode * parent)4014 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
4015 {
4016     CvDTreeNode* node = 0;
4017 
4018     CV_FUNCNAME( "CvDTree::read_node" );
4019 
4020     __BEGIN__;
4021 
4022     CvFileNode* splits;
4023     int i, depth;
4024 
4025     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
4026         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
4027 
4028     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
4029     depth = cvReadIntByName( fs, fnode, "depth", -1 );
4030     if( depth != node->depth )
4031         CV_ERROR( CV_StsParseError, "incorrect node depth" );
4032 
4033     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
4034     node->value = cvReadRealByName( fs, fnode, "value" );
4035     if( data->is_classifier )
4036         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
4037 
4038     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
4039     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
4040     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
4041     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
4042     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
4043     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
4044 
4045     splits = cvGetFileNodeByName( fs, fnode, "splits" );
4046     if( splits )
4047     {
4048         CvSeqReader reader;
4049         CvDTreeSplit* last_split = 0;
4050 
4051         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
4052             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
4053 
4054         cvStartReadSeq( splits->data.seq, &reader );
4055         for( i = 0; i < reader.seq->total; i++ )
4056         {
4057             CvDTreeSplit* split;
4058             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
4059             if( !last_split )
4060                 node->split = last_split = split;
4061             else
4062                 last_split = last_split->next = split;
4063 
4064             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4065         }
4066     }
4067 
4068     __END__;
4069 
4070     return node;
4071 }
4072 
4073 
read_tree_nodes(CvFileStorage * fs,CvFileNode * fnode)4074 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
4075 {
4076     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
4077 
4078     __BEGIN__;
4079 
4080     CvSeqReader reader;
4081     CvDTreeNode _root;
4082     CvDTreeNode* parent = &_root;
4083     int i;
4084     parent->left = parent->right = parent->parent = 0;
4085 
4086     cvStartReadSeq( fnode->data.seq, &reader );
4087 
4088     for( i = 0; i < reader.seq->total; i++ )
4089     {
4090         CvDTreeNode* node;
4091 
4092         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
4093         if( !parent->left )
4094             parent->left = node;
4095         else
4096             parent->right = node;
4097         if( node->split )
4098             parent = node;
4099         else
4100         {
4101             while( parent && parent->right )
4102                 parent = parent->parent;
4103         }
4104 
4105         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
4106     }
4107 
4108     root = _root.left;
4109 
4110     __END__;
4111 }
4112 
4113 
read(CvFileStorage * fs,CvFileNode * fnode)4114 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
4115 {
4116     CvDTreeTrainData* _data = new CvDTreeTrainData();
4117     _data->read_params( fs, fnode );
4118 
4119     read( fs, fnode, _data );
4120     get_var_importance();
4121 }
4122 
4123 
4124 // a special entry point for reading weak decision trees from the tree ensembles
read(CvFileStorage * fs,CvFileNode * node,CvDTreeTrainData * _data)4125 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
4126 {
4127     CV_FUNCNAME( "CvDTree::read" );
4128 
4129     __BEGIN__;
4130 
4131     CvFileNode* tree_nodes;
4132 
4133     clear();
4134     data = _data;
4135 
4136     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
4137     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
4138         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
4139 
4140     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
4141     read_tree_nodes( fs, tree_nodes );
4142 
4143     __END__;
4144 }
4145 
getVarImportance()4146 Mat CvDTree::getVarImportance()
4147 {
4148     return cvarrToMat(get_var_importance());
4149 }
4150 
4151 /* End of file. */
4152