• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41 
42 #include "test_precomp.hpp"
43 
44 #include <iostream>
45 #include <fstream>
46 
47 using namespace cv;
48 using namespace std;
49 
CV_SLMLTest(const char * _modelName)50 CV_SLMLTest::CV_SLMLTest( const char* _modelName ) : CV_MLBaseTest( _modelName )
51 {
52     validationFN = "slvalidation.xml";
53 }
54 
run_test_case(int testCaseIdx)55 int CV_SLMLTest::run_test_case( int testCaseIdx )
56 {
57     int code = cvtest::TS::OK;
58     code = prepare_test_case( testCaseIdx );
59 
60     if( code == cvtest::TS::OK )
61     {
62         data->setTrainTestSplit(data->getNTrainSamples(), true);
63         code = train( testCaseIdx );
64         if( code == cvtest::TS::OK )
65         {
66             get_test_error( testCaseIdx, &test_resps1 );
67             fname1 = tempfile(".yml.gz");
68             save( fname1.c_str() );
69             load( fname1.c_str() );
70             get_test_error( testCaseIdx, &test_resps2 );
71             fname2 = tempfile(".yml.gz");
72             save( fname2.c_str() );
73         }
74         else
75             ts->printf( cvtest::TS::LOG, "model can not be trained" );
76     }
77     return code;
78 }
79 
validate_test_results(int testCaseIdx)80 int CV_SLMLTest::validate_test_results( int testCaseIdx )
81 {
82     int code = cvtest::TS::OK;
83 
84     // 1. compare files
85     FILE *fs1 = fopen(fname1.c_str(), "rb"), *fs2 = fopen(fname2.c_str(), "rb");
86     size_t sz1 = 0, sz2 = 0;
87     if( !fs1 || !fs2 )
88         code = cvtest::TS::FAIL_MISSING_TEST_DATA;
89     if( code >= 0 )
90     {
91         fseek(fs1, 0, SEEK_END); fseek(fs2, 0, SEEK_END);
92         sz1 = ftell(fs1);
93         sz2 = ftell(fs2);
94         fseek(fs1, 0, SEEK_SET); fseek(fs2, 0, SEEK_SET);
95     }
96 
97     if( sz1 != sz2 )
98         code = cvtest::TS::FAIL_INVALID_OUTPUT;
99 
100     if( code >= 0 )
101     {
102         const int BUFSZ = 1024;
103         uchar buf1[BUFSZ], buf2[BUFSZ];
104         for( size_t pos = 0; pos < sz1;  )
105         {
106             size_t r1 = fread(buf1, 1, BUFSZ, fs1);
107             size_t r2 = fread(buf2, 1, BUFSZ, fs2);
108             if( r1 != r2 || memcmp(buf1, buf2, r1) != 0 )
109             {
110                 ts->printf( cvtest::TS::LOG,
111                            "in test case %d first (%s) and second (%s) saved files differ in %d-th kb\n",
112                            testCaseIdx, fname1.c_str(), fname2.c_str(),
113                            (int)pos );
114                 code = cvtest::TS::FAIL_INVALID_OUTPUT;
115                 break;
116             }
117             pos += r1;
118         }
119     }
120 
121     if(fs1)
122         fclose(fs1);
123     if(fs2)
124         fclose(fs2);
125 
126     // delete temporary files
127     if( code >= 0 )
128     {
129         remove( fname1.c_str() );
130         remove( fname2.c_str() );
131     }
132 
133     if( code >= 0 )
134     {
135         // 2. compare responses
136         CV_Assert( test_resps1.size() == test_resps2.size() );
137         vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
138         for( ; it1 != test_resps1.end(); ++it1, ++it2 )
139         {
140             if( fabs(*it1 - *it2) > FLT_EPSILON )
141             {
142                 ts->printf( cvtest::TS::LOG, "in test case %d responses predicted before saving and after loading is different", testCaseIdx );
143                 code = cvtest::TS::FAIL_INVALID_OUTPUT;
144                 break;
145             }
146         }
147     }
148     return code;
149 }
150 
// Save/load round-trip tests (CV_SLMLTest), one per ML model type.
TEST(ML_NaiveBayes, save_load)
{
    CV_SLMLTest test( CV_NBAYES );
    test.safe_run();
}

TEST(ML_KNearest, save_load)
{
    CV_SLMLTest test( CV_KNEAREST );
    test.safe_run();
}

TEST(ML_SVM, save_load)
{
    CV_SLMLTest test( CV_SVM );
    test.safe_run();
}

TEST(ML_ANN, save_load)
{
    CV_SLMLTest test( CV_ANN );
    test.safe_run();
}

TEST(ML_DTree, save_load)
{
    CV_SLMLTest test( CV_DTREE );
    test.safe_run();
}

TEST(ML_Boost, save_load)
{
    CV_SLMLTest test( CV_BOOST );
    test.safe_run();
}

TEST(ML_RTrees, save_load)
{
    CV_SLMLTest test( CV_RTREES );
    test.safe_run();
}

// Disabled (DISABLED_ prefix).
TEST(DISABLED_ML_ERTrees, save_load)
{
    CV_SLMLTest test( CV_ERTREES );
    test.safe_run();
}
159 
160 class CV_LegacyTest : public cvtest::BaseTest
161 {
162 public:
CV_LegacyTest(const std::string & _modelName,const std::string & _suffixes=std::string ())163     CV_LegacyTest(const std::string &_modelName, const std::string &_suffixes = std::string())
164         : cvtest::BaseTest(), modelName(_modelName), suffixes(_suffixes)
165     {
166     }
~CV_LegacyTest()167     virtual ~CV_LegacyTest() {}
168 protected:
run(int)169     void run(int)
170     {
171         unsigned int idx = 0;
172         for (;;)
173         {
174             if (idx >= suffixes.size())
175                 break;
176             int found = (int)suffixes.find(';', idx);
177             string piece = suffixes.substr(idx, found - idx);
178             if (piece.empty())
179                 break;
180             oneTest(piece);
181             idx += (unsigned int)piece.size() + 1;
182         }
183     }
oneTest(const string & suffix)184     void oneTest(const string & suffix)
185     {
186         using namespace cv::ml;
187 
188         int code = cvtest::TS::OK;
189         string filename = ts->get_data_path() + "legacy/" + modelName + suffix;
190         bool isTree = modelName == CV_BOOST || modelName == CV_DTREE || modelName == CV_RTREES;
191         Ptr<StatModel> model;
192         if (modelName == CV_BOOST)
193             model = Algorithm::load<Boost>(filename);
194         else if (modelName == CV_ANN)
195             model = Algorithm::load<ANN_MLP>(filename);
196         else if (modelName == CV_DTREE)
197             model = Algorithm::load<DTrees>(filename);
198         else if (modelName == CV_NBAYES)
199             model = Algorithm::load<NormalBayesClassifier>(filename);
200         else if (modelName == CV_SVM)
201             model = Algorithm::load<SVM>(filename);
202         else if (modelName == CV_RTREES)
203             model = Algorithm::load<RTrees>(filename);
204         if (!model)
205         {
206             code = cvtest::TS::FAIL_INVALID_TEST_DATA;
207         }
208         else
209         {
210             Mat input = Mat(isTree ? 10 : 1, model->getVarCount(), CV_32F);
211             ts->get_rng().fill(input, RNG::UNIFORM, 0, 40);
212 
213             if (isTree)
214                 randomFillCategories(filename, input);
215 
216             Mat output;
217             model->predict(input, output, StatModel::RAW_OUTPUT | (isTree ? DTrees::PREDICT_SUM : 0));
218             // just check if no internal assertions or errors thrown
219         }
220         ts->set_failed_test_info(code);
221     }
randomFillCategories(const string & filename,Mat & input)222     void randomFillCategories(const string & filename, Mat & input)
223     {
224         Mat catMap;
225         Mat catCount;
226         std::vector<uchar> varTypes;
227 
228         FileStorage fs(filename, FileStorage::READ);
229         FileNode root = fs.getFirstTopLevelNode();
230         root["cat_map"] >> catMap;
231         root["cat_count"] >> catCount;
232         root["var_type"] >> varTypes;
233 
234         int offset = 0;
235         int countOffset = 0;
236         uint var = 0, varCount = (uint)varTypes.size();
237         for (; var < varCount; ++var)
238         {
239             if (varTypes[var] == ml::VAR_CATEGORICAL)
240             {
241                 int size = catCount.at<int>(0, countOffset);
242                 for (int row = 0; row < input.rows; ++row)
243                 {
244                     int randomChosenIndex = offset + ((uint)ts->get_rng()) % size;
245                     int value = catMap.at<int>(0, randomChosenIndex);
246                     input.at<float>(row, var) = (float)value;
247                 }
248                 offset += size;
249                 ++countOffset;
250             }
251         }
252     }
253     string modelName;
254     string suffixes;
255 };
256 
// Legacy-model loading tests (CV_LegacyTest); each suffix list names files under the
// test-data "legacy/" folder, separated by ';'.
TEST(ML_ANN, legacy_load)
{
    CV_LegacyTest test(CV_ANN, "_waveform.xml");
    test.safe_run();
}

TEST(ML_Boost, legacy_load)
{
    CV_LegacyTest test(CV_BOOST, "_adult.xml;_1.xml;_2.xml;_3.xml");
    test.safe_run();
}

TEST(ML_DTree, legacy_load)
{
    CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushroom.xml");
    test.safe_run();
}

TEST(ML_NBayes, legacy_load)
{
    CV_LegacyTest test(CV_NBAYES, "_waveform.xml");
    test.safe_run();
}

TEST(ML_SVM, legacy_load)
{
    CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml");
    test.safe_run();
}

TEST(ML_RTrees, legacy_load)
{
    CV_LegacyTest test(CV_RTREES, "_waveform.xml");
    test.safe_run();
}
263 
264 /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
265 {
266     Ptr<cv::ml::SVM> svm;
267     string filename = tempfile("svm.xml");
268     ASSERT_THROW(svm.save(filename.c_str()), Exception);
269     remove(filename.c_str());
270 }*/
271 
// Loads two SVM models, re-saves the second one through a temporary file, and checks that all
// three models agree in variable count and predict (within 1e-4) on random samples.
// NOTE(review): the two input files are opened relative to the current working directory,
// not the test-data path.
TEST(DISABLED_ML_SVM, linear_save_load)
{
    Ptr<cv::ml::SVM> svm1, svm2, svm3;

    svm1 = Algorithm::load<SVM>("SVM45_X_38-1.xml");
    svm2 = Algorithm::load<SVM>("SVM45_X_38-2.xml");
    // Fail cleanly instead of dereferencing an empty Ptr when the files are missing.
    ASSERT_FALSE(svm1.empty());
    ASSERT_FALSE(svm2.empty());

    string tname = tempfile("a.xml");
    svm2->save(tname);
    svm3 = Algorithm::load<SVM>(tname);
    // Remove the temporary file before any ASSERT below can return early and leak it.
    remove(tname.c_str());
    ASSERT_FALSE(svm3.empty());

    ASSERT_EQ(svm1->getVarCount(), svm2->getVarCount());
    ASSERT_EQ(svm1->getVarCount(), svm3->getVarCount());

    int m = 10000, n = svm1->getVarCount();
    Mat samples(m, n, CV_32F), r1, r2, r3;
    randu(samples, 0., 1.);

    svm1->predict(samples, r1);
    svm2->predict(samples, r2);
    svm3->predict(samples, r3);

    double eps = 1e-4;
    EXPECT_LE(norm(r1, r2, NORM_INF), eps);
    EXPECT_LE(norm(r1, r3, NORM_INF), eps);
}
299 
300 /* End of file. */
301