#include "test_precomp.hpp"

#if 0

#include <string>
#include <fstream>
#include <iostream>

using namespace std;

// Accuracy and save/load test for CvGBTrees, run on the spambase (classification)
// and housing (regression) datasets from the test data directory.
class CV_GBTreesTest : public cvtest::BaseTest
{
public:
    CV_GBTreesTest();
    ~CV_GBTreesTest();

protected:
    void run(int);

    int TestTrainPredict(int test_num);
    int TestSaveLoad();

    int checkPredictError(int test_num);
    int checkLoadSave();

    string model_file_name1;
    string model_file_name2;

    string* datasets;
    string data_path;

    CvMLData* data;
    CvGBTrees* gtb;

    vector<float> test_resps1;
    vector<float> test_resps2;

    int64 initSeed;
};


int _get_len(const CvMat* mat)
{
    return (mat->cols > mat->rows) ? mat->cols : mat->rows;
}


CV_GBTreesTest::CV_GBTreesTest()
{
    int64 seeds[] = { CV_BIG_INT(0x00009fff4f9c8d52),
                      CV_BIG_INT(0x0000a17166072c7c),
                      CV_BIG_INT(0x0201b32115cd1f9a),
                      CV_BIG_INT(0x0513cb37abcd1234),
                      CV_BIG_INT(0x0001a2b3c4d5f678) };

    // Pick one of the fixed seeds for this run; the original RNG state
    // is restored in the destructor.
    int seedCount = sizeof(seeds)/sizeof(seeds[0]);
    cv::RNG& rng = cv::theRNG();
    initSeed = rng.state;
    rng.state = seeds[rng(seedCount)];

    datasets = 0;
    data = 0;
    gtb = 0;
}

CV_GBTreesTest::~CV_GBTreesTest()
{
    if (data)
        delete data;
    delete[] datasets;
    cv::theRNG().state = initSeed;
}


int CV_GBTreesTest::TestTrainPredict(int test_num)
{
    int code = cvtest::TS::OK;

    int weak_count = 200;
    float shrinkage = 0.1f;
    float subsample_portion = 0.5f;
    int max_depth = 5;
    bool use_surrogates = false;
    int loss_function_type = 0;
    switch (test_num)
    {
        case (0) : loss_function_type = CvGBTrees::DEVIANCE_LOSS; break;
        case (1) : loss_function_type = CvGBTrees::SQUARED_LOSS; break;
        case (2) : loss_function_type = CvGBTrees::ABSOLUTE_LOSS; break;
        case (3) : loss_function_type = CvGBTrees::HUBER_LOSS; break;
        default :
        {
            ts->printf( cvtest::TS::LOG, "Bad test_num value in CV_GBTreesTest::TestTrainPredict(..) function." );
            return cvtest::TS::FAIL_BAD_ARG_CHECK;
        }
    }

    // Test 0 uses the classification dataset (deviance loss); tests 1-3 use the regression dataset.
    int dataset_num = test_num == 0 ? 0 : 1;
    if (!data)
    {
        data = new CvMLData();
        data->set_delimiter(',');

        if (data->read_csv(datasets[dataset_num].c_str()))
        {
            ts->printf( cvtest::TS::LOG, "File reading error." );
            return cvtest::TS::FAIL_INVALID_TEST_DATA;
        }

        if (test_num == 0)
        {
            data->set_response_idx(57);
            data->set_var_types("ord[0-56],cat[57]");
        }
        else
        {
            data->set_response_idx(13);
            data->set_var_types("ord[0-2,4-13],cat[3]");
            subsample_portion = 0.7f;
        }

        int train_sample_count = cvFloor(_get_len(data->get_responses())*0.5f);
        CvTrainTestSplit spl( train_sample_count );
        data->set_train_test_split( &spl );
    }

    data->mix_train_and_test_idx();

    if (gtb) delete gtb;
    gtb = new CvGBTrees();
    bool tmp_code = gtb->train(data, CvGBTreesParams(loss_function_type, weak_count,
                                                     shrinkage, subsample_portion,
                                                     max_depth, use_surrogates));

    if (!tmp_code)
    {
        ts->printf( cvtest::TS::LOG, "Model training failed." );
        return cvtest::TS::FAIL_INVALID_OUTPUT;
    }

    code = checkPredictError(test_num);

    return code;
}


int CV_GBTreesTest::checkPredictError(int test_num)
{
    if (!gtb)
        return cvtest::TS::FAIL_GENERIC;

    // Reference test errors (mean) and their spread (sigma) for the four loss functions.
    //float mean[] = {5.430247f, 13.5654f, 12.6569f, 13.1661f};
    //float sigma[] = {0.4162694f, 3.21161f, 3.43297f, 3.00624f};
    float mean[] = {5.80226f, 12.68689f, 13.49095f, 13.19628f};
    float sigma[] = {0.4764534f, 3.166919f, 3.022405f, 2.868722f};

    float current_error = gtb->calc_error(data, CV_TEST_ERROR);

    if ( fabs( current_error - mean[test_num] ) > 6*sigma[test_num] )
    {
        ts->printf( cvtest::TS::LOG, "Test error is out of range:\n"
                    "abs(%f/*curEr*/ - %f/*mean*/) > %f/*6*sigma*/",
                    current_error, mean[test_num], 6*sigma[test_num] );
        return cvtest::TS::FAIL_BAD_ACCURACY;
    }

    return cvtest::TS::OK;
}


int CV_GBTreesTest::TestSaveLoad()
{
    if (!gtb)
        return cvtest::TS::FAIL_GENERIC;

    model_file_name1 = cv::tempfile();
    model_file_name2 = cv::tempfile();

    // Save, predict, reload, predict again and save once more;
    // checkLoadSave() compares both the files and the responses.
    gtb->save(model_file_name1.c_str());
    gtb->calc_error(data, CV_TEST_ERROR, &test_resps1);
    gtb->load(model_file_name1.c_str());
    gtb->calc_error(data, CV_TEST_ERROR, &test_resps2);
    gtb->save(model_file_name2.c_str());

    return checkLoadSave();
}


int CV_GBTreesTest::checkLoadSave()
{
    int code = cvtest::TS::OK;

    // 1. compare files
    ifstream f1( model_file_name1.c_str() ), f2( model_file_name2.c_str() );
    string s1, s2;
    int lineIdx = 0;
    CV_Assert( f1.is_open() && f2.is_open() );
    for( ; !f1.eof() && !f2.eof(); lineIdx++ )
    {
        getline( f1, s1 );
        getline( f2, s2 );
        if( s1.compare(s2) )
        {
            ts->printf( cvtest::TS::LOG, "First and second saved files differ in line %d; first file line: %s; second file line: %s",
                        lineIdx, s1.c_str(), s2.c_str() );
            code = cvtest::TS::FAIL_INVALID_OUTPUT;
        }
    }
    if( !f1.eof() || !f2.eof() )
    {
        ts->printf( cvtest::TS::LOG, "First and second saved files differ in length (mismatch at line %d)",
                    lineIdx );
        code = cvtest::TS::FAIL_INVALID_OUTPUT;
    }
    f1.close();
    f2.close();
    // delete temporary files
    remove( model_file_name1.c_str() );
    remove( model_file_name2.c_str() );

    // 2. compare responses
    CV_Assert( test_resps1.size() == test_resps2.size() );
    vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
    for( ; it1 != test_resps1.end(); ++it1, ++it2 )
    {
        if( fabs(*it1 - *it2) > FLT_EPSILON )
        {
            ts->printf( cvtest::TS::LOG, "Responses predicted before saving and after loading are different" );
            code = cvtest::TS::FAIL_INVALID_OUTPUT;
        }
    }
    return code;
}


void CV_GBTreesTest::run(int)
{
    string dataPath = string(ts->get_data_path());
    datasets = new string[2];
    datasets[0] = dataPath + string("spambase.data"); /*string("dataset_classification.csv");*/
    datasets[1] = dataPath + string("housing_.data"); /*string("dataset_regression.csv");*/

    int code = cvtest::TS::OK;

    for (int i = 0; i < 4; i++)
    {
        int temp_code = TestTrainPredict(i);
        if (temp_code != cvtest::TS::OK)
        {
            code = temp_code;
            break;
        }
        else if (i == 0)
        {
            temp_code = TestSaveLoad();
            if (temp_code != cvtest::TS::OK)
                code = temp_code;
            delete data;
            data = 0;
        }

        delete gtb;
        gtb = 0;
    }
    delete data;
    data = 0;

    ts->set_failed_test_info( code );
}

/////////////////////////////////////////////////////////////////////////////
//////////////////// test registration ///////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////

TEST(ML_GBTrees, regression) { CV_GBTreesTest test; test.safe_run(); }

#endif