• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 #include "test_precomp.hpp"
3 
4 #if 0
5 
6 #include <string>
7 #include <fstream>
8 #include <iostream>
9 
10 using namespace std;
11 
12 
13 class CV_GBTreesTest : public cvtest::BaseTest
14 {
15 public:
16     CV_GBTreesTest();
17     ~CV_GBTreesTest();
18 
19 protected:
20     void run(int);
21 
22     int TestTrainPredict(int test_num);
23     int TestSaveLoad();
24 
25     int checkPredictError(int test_num);
26     int checkLoadSave();
27 
28     string model_file_name1;
29     string model_file_name2;
30 
31     string* datasets;
32     string data_path;
33 
34     CvMLData* data;
35     CvGBTrees* gtb;
36 
37     vector<float> test_resps1;
38     vector<float> test_resps2;
39 
40     int64 initSeed;
41 };
42 
43 
44 int _get_len(const CvMat* mat)
45 {
46     return (mat->cols > mat->rows) ? mat->cols : mat->rows;
47 }
48 
49 
50 CV_GBTreesTest::CV_GBTreesTest()
51 {
52     int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
53                       CV_BIG_INT(0x0000a17166072c7c),
54                       CV_BIG_INT(0x0201b32115cd1f9a),
55                       CV_BIG_INT(0x0513cb37abcd1234),
56                       CV_BIG_INT(0x0001a2b3c4d5f678)
57                     };
58 
59     int seedCount = sizeof(seeds)/sizeof(seeds[0]);
60     cv::RNG& rng = cv::theRNG();
61     initSeed = rng.state;
62     rng.state = seeds[rng(seedCount)];
63 
64     datasets = 0;
65     data = 0;
66     gtb = 0;
67 }
68 
69 CV_GBTreesTest::~CV_GBTreesTest()
70 {
71     if (data)
72         delete data;
73     delete[] datasets;
74     cv::theRNG().state = initSeed;
75 }
76 
77 
78 int CV_GBTreesTest::TestTrainPredict(int test_num)
79 {
80     int code = cvtest::TS::OK;
81 
82     int weak_count = 200;
83     float shrinkage = 0.1f;
84     float subsample_portion = 0.5f;
85     int max_depth = 5;
86     bool use_surrogates = false;
87     int loss_function_type = 0;
88     switch (test_num)
89     {
90         case (1) : loss_function_type = CvGBTrees::SQUARED_LOSS; break;
91         case (2) : loss_function_type = CvGBTrees::ABSOLUTE_LOSS; break;
92         case (3) : loss_function_type = CvGBTrees::HUBER_LOSS; break;
93         case (0) : loss_function_type = CvGBTrees::DEVIANCE_LOSS; break;
94         default  :
95             {
96             ts->printf( cvtest::TS::LOG, "Bad test_num value in CV_GBTreesTest::TestTrainPredict(..) function." );
97             return cvtest::TS::FAIL_BAD_ARG_CHECK;
98             }
99     }
100 
101     int dataset_num = test_num == 0 ? 0 : 1;
102     if (!data)
103     {
104         data = new CvMLData();
105         data->set_delimiter(',');
106 
107         if (data->read_csv(datasets[dataset_num].c_str()))
108         {
109             ts->printf( cvtest::TS::LOG, "File reading error." );
110             return cvtest::TS::FAIL_INVALID_TEST_DATA;
111         }
112 
113         if (test_num == 0)
114         {
115             data->set_response_idx(57);
116             data->set_var_types("ord[0-56],cat[57]");
117         }
118         else
119         {
120             data->set_response_idx(13);
121             data->set_var_types("ord[0-2,4-13],cat[3]");
122             subsample_portion = 0.7f;
123         }
124 
125         int train_sample_count = cvFloor(_get_len(data->get_responses())*0.5f);
126         CvTrainTestSplit spl( train_sample_count );
127         data->set_train_test_split( &spl );
128     }
129 
130     data->mix_train_and_test_idx();
131 
132 
133     if (gtb) delete gtb;
134     gtb = new CvGBTrees();
135     bool tmp_code = true;
136     tmp_code = gtb->train(data, CvGBTreesParams(loss_function_type, weak_count,
137                           shrinkage, subsample_portion,
138                           max_depth, use_surrogates));
139 
140     if (!tmp_code)
141     {
142         ts->printf( cvtest::TS::LOG, "Model training was failed.");
143         return cvtest::TS::FAIL_INVALID_OUTPUT;
144     }
145 
146     code = checkPredictError(test_num);
147 
148     return code;
149 
150 }
151 
152 
153 int CV_GBTreesTest::checkPredictError(int test_num)
154 {
155     if (!gtb)
156         return cvtest::TS::FAIL_GENERIC;
157 
158     //float mean[] = {5.430247f, 13.5654f, 12.6569f, 13.1661f};
159     //float sigma[] = {0.4162694f, 3.21161f, 3.43297f, 3.00624f};
160     float mean[] = {5.80226f, 12.68689f, 13.49095f, 13.19628f};
161     float sigma[] = {0.4764534f, 3.166919f, 3.022405f, 2.868722f};
162 
163     float current_error = gtb->calc_error(data, CV_TEST_ERROR);
164 
165     if ( abs( current_error - mean[test_num]) > 6*sigma[test_num] )
166     {
167         ts->printf( cvtest::TS::LOG, "Test error is out of range:\n"
168                     "abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/",
169                     current_error, mean[test_num], 6*sigma[test_num] );
170         return cvtest::TS::FAIL_BAD_ACCURACY;
171     }
172 
173     return cvtest::TS::OK;
174 
175 }
176 
177 
178 int CV_GBTreesTest::TestSaveLoad()
179 {
180     if (!gtb)
181         return cvtest::TS::FAIL_GENERIC;
182 
183     model_file_name1 = cv::tempfile();
184     model_file_name2 = cv::tempfile();
185 
186     gtb->save(model_file_name1.c_str());
187     gtb->calc_error(data, CV_TEST_ERROR, &test_resps1);
188     gtb->load(model_file_name1.c_str());
189     gtb->calc_error(data, CV_TEST_ERROR, &test_resps2);
190     gtb->save(model_file_name2.c_str());
191 
192     return checkLoadSave();
193 
194 }
195 
196 
197 
198 int CV_GBTreesTest::checkLoadSave()
199 {
200     int code = cvtest::TS::OK;
201 
202     // 1. compare files
203     ifstream f1( model_file_name1.c_str() ), f2( model_file_name2.c_str() );
204     string s1, s2;
205     int lineIdx = 0;
206     CV_Assert( f1.is_open() && f2.is_open() );
207     for( ; !f1.eof() && !f2.eof(); lineIdx++ )
208     {
209         getline( f1, s1 );
210         getline( f2, s2 );
211         if( s1.compare(s2) )
212         {
213             ts->printf( cvtest::TS::LOG, "first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
214                lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
215             code = cvtest::TS::FAIL_INVALID_OUTPUT;
216         }
217     }
218     if( !f1.eof() || !f2.eof() )
219     {
220         ts->printf( cvtest::TS::LOG, "First and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
221             lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
222         code = cvtest::TS::FAIL_INVALID_OUTPUT;
223     }
224     f1.close();
225     f2.close();
226     // delete temporary files
227     remove( model_file_name1.c_str() );
228     remove( model_file_name2.c_str() );
229 
230     // 2. compare responses
231     CV_Assert( test_resps1.size() == test_resps2.size() );
232     vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
233     for( ; it1 != test_resps1.end(); ++it1, ++it2 )
234     {
235         if( fabs(*it1 - *it2) > FLT_EPSILON )
236         {
237             ts->printf( cvtest::TS::LOG, "Responses predicted before saving and after loading are different" );
238             code = cvtest::TS::FAIL_INVALID_OUTPUT;
239         }
240     }
241     return code;
242 }
243 
244 
245 
246 void CV_GBTreesTest::run(int)
247 {
248 
249     string dataPath = string(ts->get_data_path());
250     datasets = new string[2];
251     datasets[0] = dataPath + string("spambase.data"); /*string("dataset_classification.csv");*/
252     datasets[1] = dataPath + string("housing_.data");  /*string("dataset_regression.csv");*/
253 
254     int code = cvtest::TS::OK;
255 
256     for (int i = 0; i < 4; i++)
257     {
258 
259         int temp_code = TestTrainPredict(i);
260         if (temp_code != cvtest::TS::OK)
261         {
262             code = temp_code;
263             break;
264         }
265 
266         else if (i==0)
267         {
268             temp_code = TestSaveLoad();
269             if (temp_code != cvtest::TS::OK)
270                 code = temp_code;
271             delete data;
272             data = 0;
273         }
274 
275         delete gtb;
276         gtb = 0;
277     }
278     delete data;
279     data = 0;
280 
281     ts->set_failed_test_info( code );
282 }
283 
284 /////////////////////////////////////////////////////////////////////////////
285 //////////////////// test registration  /////////////////////////////////////
286 /////////////////////////////////////////////////////////////////////////////
287 
288 TEST(ML_GBTrees, regression) { CV_GBTreesTest test; test.safe_run(); }
289 
290 #endif
291