• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string.h>
#include <string>
#include <map>

using namespace cv;
using namespace cv::ml;
10 
/// Writes the sample description and command-line usage to stdout.
static void help()
{
    // The text contains no '%' conversions, so writing it verbatim with fputs
    // produces exactly what printf(text) would.
    static const char usage_text[] =
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n";

    fputs(usage_text, stdout);
}
20 
train_and_print_errs(Ptr<StatModel> model,const Ptr<TrainData> & data)21 static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
22 {
23     bool ok = model->train(data);
24     if( !ok )
25     {
26         printf("Training failed\n");
27     }
28     else
29     {
30         printf( "train error: %f\n", model->calcError(data, false, noArray()) );
31         printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
32     }
33 }
34 
main(int argc,char ** argv)35 int main(int argc, char** argv)
36 {
37     if(argc < 2)
38     {
39         help();
40         return 0;
41     }
42     const char* filename = 0;
43     int response_idx = 0;
44     std::string typespec;
45 
46     for(int i = 1; i < argc; i++)
47     {
48         if(strcmp(argv[i], "-r") == 0)
49             sscanf(argv[++i], "%d", &response_idx);
50         else if(strcmp(argv[i], "-ts") == 0)
51             typespec = argv[++i];
52         else if(argv[i][0] != '-' )
53             filename = argv[i];
54         else
55         {
56             printf("Error. Invalid option %s\n", argv[i]);
57             help();
58             return -1;
59         }
60     }
61 
62     printf("\nReading in %s...\n\n",filename);
63     const double train_test_split_ratio = 0.5;
64 
65     Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
66 
67     if( data.empty() )
68     {
69         printf("ERROR: File %s can not be read\n", filename);
70         return 0;
71     }
72 
73     data->setTrainTestSplitRatio(train_test_split_ratio);
74 
75     printf("======DTREE=====\n");
76     Ptr<DTrees> dtree = DTrees::create();
77     dtree->setMaxDepth(10);
78     dtree->setMinSampleCount(2);
79     dtree->setRegressionAccuracy(0);
80     dtree->setUseSurrogates(false);
81     dtree->setMaxCategories(16);
82     dtree->setCVFolds(0);
83     dtree->setUse1SERule(false);
84     dtree->setTruncatePrunedTree(false);
85     dtree->setPriors(Mat());
86     train_and_print_errs(dtree, data);
87 
88     if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
89     {
90         printf("======BOOST=====\n");
91         Ptr<Boost> boost = Boost::create();
92         boost->setBoostType(Boost::GENTLE);
93         boost->setWeakCount(100);
94         boost->setWeightTrimRate(0.95);
95         boost->setMaxDepth(2);
96         boost->setUseSurrogates(false);
97         boost->setPriors(Mat());
98         train_and_print_errs(boost, data);
99     }
100 
101     printf("======RTREES=====\n");
102     Ptr<RTrees> rtrees = RTrees::create();
103     rtrees->setMaxDepth(10);
104     rtrees->setMinSampleCount(2);
105     rtrees->setRegressionAccuracy(0);
106     rtrees->setUseSurrogates(false);
107     rtrees->setMaxCategories(16);
108     rtrees->setPriors(Mat());
109     rtrees->setCalculateVarImportance(false);
110     rtrees->setActiveVarCount(0);
111     rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
112     train_and_print_errs(rtrees, data);
113 
114     return 0;
115 }
116