• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #include "precomp.hpp"
42 #include <ctype.h>
43 #include <algorithm>
44 #include <iterator>
45 
46 namespace cv { namespace ml {
47 
// Sentinel value stored in place of an absent sample element; obtained from TrainData::missingValue().
static const float MISSED_VAL = TrainData::missingValue();
// Type tag reported for an absent element. It is an alias of VAR_ORDERED, so the two compare equal.
static const int VAR_MISSED = VAR_ORDERED;
50 
~TrainData()51 TrainData::~TrainData() {}
52 
getSubVector(const Mat & vec,const Mat & idx)53 Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
54 {
55     if( idx.empty() )
56         return vec;
57     int i, j, n = idx.checkVector(1, CV_32S);
58     int type = vec.type();
59     CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
60     int dims = 1, m;
61 
62     if( vec.cols == 1 || vec.rows == 1 )
63     {
64         dims = 1;
65         m = vec.cols + vec.rows - 1;
66     }
67     else
68     {
69         dims = vec.cols;
70         m = vec.rows;
71     }
72 
73     Mat subvec;
74 
75     if( vec.cols == m )
76         subvec.create(dims, n, type);
77     else
78         subvec.create(n, dims, type);
79     if( type == CV_32S )
80         for( i = 0; i < n; i++ )
81         {
82             int k = idx.at<int>(i);
83             CV_Assert( 0 <= k && k < m );
84             if( dims == 1 )
85                 subvec.at<int>(i) = vec.at<int>(k);
86             else
87                 for( j = 0; j < dims; j++ )
88                     subvec.at<int>(i, j) = vec.at<int>(k, j);
89         }
90     else if( type == CV_32F )
91         for( i = 0; i < n; i++ )
92         {
93             int k = idx.at<int>(i);
94             CV_Assert( 0 <= k && k < m );
95             if( dims == 1 )
96                 subvec.at<float>(i) = vec.at<float>(k);
97             else
98                 for( j = 0; j < dims; j++ )
99                     subvec.at<float>(i, j) = vec.at<float>(k, j);
100         }
101     else
102         for( i = 0; i < n; i++ )
103         {
104             int k = idx.at<int>(i);
105             CV_Assert( 0 <= k && k < m );
106             if( dims == 1 )
107                 subvec.at<double>(i) = vec.at<double>(k);
108             else
109                 for( j = 0; j < dims; j++ )
110                     subvec.at<double>(i, j) = vec.at<double>(k, j);
111         }
112     return subvec;
113 }
114 
115 class TrainDataImpl : public TrainData
116 {
117 public:
118     typedef std::map<String, int> MapType;
119 
TrainDataImpl()120     TrainDataImpl()
121     {
122         file = 0;
123         clear();
124     }
125 
~TrainDataImpl()126     virtual ~TrainDataImpl() { closeFile(); }
127 
getLayout() const128     int getLayout() const { return layout; }
getNSamples() const129     int getNSamples() const
130     {
131         return !sampleIdx.empty() ? (int)sampleIdx.total() :
132                layout == ROW_SAMPLE ? samples.rows : samples.cols;
133     }
getNTrainSamples() const134     int getNTrainSamples() const
135     {
136         return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
137     }
getNTestSamples() const138     int getNTestSamples() const
139     {
140         return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
141     }
getNVars() const142     int getNVars() const
143     {
144         return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
145     }
getNAllVars() const146     int getNAllVars() const
147     {
148         return layout == ROW_SAMPLE ? samples.cols : samples.rows;
149     }
150 
getSamples() const151     Mat getSamples() const { return samples; }
getResponses() const152     Mat getResponses() const { return responses; }
getMissing() const153     Mat getMissing() const { return missing; }
getVarIdx() const154     Mat getVarIdx() const { return varIdx; }
getVarType() const155     Mat getVarType() const { return varType; }
getResponseType() const156     int getResponseType() const
157     {
158         return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
159     }
getTrainSampleIdx() const160     Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
getTestSampleIdx() const161     Mat getTestSampleIdx() const { return testSampleIdx; }
getSampleWeights() const162     Mat getSampleWeights() const
163     {
164         return sampleWeights;
165     }
getTrainSampleWeights() const166     Mat getTrainSampleWeights() const
167     {
168         return getSubVector(sampleWeights, getTrainSampleIdx());
169     }
getTestSampleWeights() const170     Mat getTestSampleWeights() const
171     {
172         Mat idx = getTestSampleIdx();
173         return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
174     }
getTrainResponses() const175     Mat getTrainResponses() const
176     {
177         return getSubVector(responses, getTrainSampleIdx());
178     }
getTrainNormCatResponses() const179     Mat getTrainNormCatResponses() const
180     {
181         return getSubVector(normCatResponses, getTrainSampleIdx());
182     }
getTestResponses() const183     Mat getTestResponses() const
184     {
185         Mat idx = getTestSampleIdx();
186         return idx.empty() ? Mat() : getSubVector(responses, idx);
187     }
getTestNormCatResponses() const188     Mat getTestNormCatResponses() const
189     {
190         Mat idx = getTestSampleIdx();
191         return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
192     }
getNormCatResponses() const193     Mat getNormCatResponses() const { return normCatResponses; }
getClassLabels() const194     Mat getClassLabels() const { return classLabels; }
getClassCounters() const195     Mat getClassCounters() const { return classCounters; }
getCatCount(int vi) const196     int getCatCount(int vi) const
197     {
198         int n = (int)catOfs.total();
199         CV_Assert( 0 <= vi && vi < n );
200         Vec2i ofs = catOfs.at<Vec2i>(vi);
201         return ofs[1] - ofs[0];
202     }
203 
getCatOfs() const204     Mat getCatOfs() const { return catOfs; }
getCatMap() const205     Mat getCatMap() const { return catMap; }
206 
getDefaultSubstValues() const207     Mat getDefaultSubstValues() const { return missingSubst; }
208 
closeFile()209     void closeFile() { if(file) fclose(file); file=0; }
clear()210     void clear()
211     {
212         closeFile();
213         samples.release();
214         missing.release();
215         varType.release();
216         responses.release();
217         sampleIdx.release();
218         trainSampleIdx.release();
219         testSampleIdx.release();
220         normCatResponses.release();
221         classLabels.release();
222         classCounters.release();
223         catMap.release();
224         catOfs.release();
225         nameMap = MapType();
226         layout = ROW_SAMPLE;
227     }
228 
229     typedef std::map<int, int> CatMapHash;
230 
setData(InputArray _samples,int _layout,InputArray _responses,InputArray _varIdx,InputArray _sampleIdx,InputArray _sampleWeights,InputArray _varType,InputArray _missing)231     void setData(InputArray _samples, int _layout, InputArray _responses,
232                  InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
233                  InputArray _varType, InputArray _missing)
234     {
235         clear();
236 
237         CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
238         samples = _samples.getMat();
239         layout = _layout;
240         responses = _responses.getMat();
241         varIdx = _varIdx.getMat();
242         sampleIdx = _sampleIdx.getMat();
243         sampleWeights = _sampleWeights.getMat();
244         varType = _varType.getMat();
245         missing = _missing.getMat();
246 
247         int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
248         int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
249         int i, noutputvars = 0;
250 
251         CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
252 
253         if( !sampleIdx.empty() )
254         {
255             CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
256                        checkRange(sampleIdx, true, 0, 0, nsamples-1)) ||
257                        sampleIdx.checkVector(1, CV_8U, true) == nsamples );
258             if( sampleIdx.type() == CV_8U )
259                 sampleIdx = convertMaskToIdx(sampleIdx);
260         }
261 
262         if( !sampleWeights.empty() )
263         {
264             CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
265         }
266         else
267         {
268             sampleWeights = Mat::ones(nsamples, 1, CV_32F);
269         }
270 
271         if( !varIdx.empty() )
272         {
273             CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
274                        checkRange(varIdx, true, 0, 0, ninputvars)) ||
275                        varIdx.checkVector(1, CV_8U, true) == ninputvars );
276             if( varIdx.type() == CV_8U )
277                 varIdx = convertMaskToIdx(varIdx);
278             varIdx = varIdx.clone();
279             std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
280         }
281 
282         if( !responses.empty() )
283         {
284             CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
285             if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
286                 noutputvars = 1;
287             else
288             {
289                 CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
290                            (layout == COL_SAMPLE && responses.cols == nsamples) );
291                 noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
292             }
293             if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
294             {
295                 Mat temp;
296                 transpose(responses, temp);
297                 responses = temp;
298             }
299         }
300 
301         int nvars = ninputvars + noutputvars;
302 
303         if( !varType.empty() )
304         {
305             CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
306                        checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
307         }
308         else
309         {
310             varType.create(1, nvars, CV_8U);
311             varType = Scalar::all(VAR_ORDERED);
312             if( noutputvars == 1 )
313                 varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
314         }
315 
316         if( noutputvars > 1 )
317         {
318             for( i = 0; i < noutputvars; i++ )
319                 CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
320         }
321 
322         catOfs = Mat::zeros(1, nvars, CV_32SC2);
323         missingSubst = Mat::zeros(1, nvars, CV_32F);
324 
325         vector<int> labels, counters, sortbuf, tempCatMap;
326         vector<Vec2i> tempCatOfs;
327         CatMapHash ofshash;
328 
329         AutoBuffer<uchar> buf(nsamples);
330         Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, (uchar*)buf);
331         bool haveMissing = !missing.empty();
332         if( haveMissing )
333         {
334             CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
335         }
336 
337         // we iterate through all the variables. For each categorical variable we build a map
338         // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
339         // often many categorical variables are similar, so we compress the map - try to re-use
340         // maps for different variables if they are identical
341         for( i = 0; i < ninputvars; i++ )
342         {
343             Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
344 
345             if( varType.at<uchar>(i) == VAR_CATEGORICAL )
346             {
347                 preprocessCategorical(values_i, 0, labels, 0, sortbuf);
348                 missingSubst.at<float>(i) = -1.f;
349                 int j, m = (int)labels.size();
350                 CV_Assert( m > 0 );
351                 int a = labels.front(), b = labels.back();
352                 const int* currmap = &labels[0];
353                 int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
354                 CatMapHash::iterator it = ofshash.find(hashval);
355                 if( it != ofshash.end() )
356                 {
357                     int vi = it->second;
358                     Vec2i ofs0 = tempCatOfs[vi];
359                     int m0 = ofs0[1] - ofs0[0];
360                     const int* map0 = &tempCatMap[ofs0[0]];
361                     if( m0 == m && map0[0] == a && map0[m0-1] == b )
362                     {
363                         for( j = 0; j < m; j++ )
364                             if( map0[j] != currmap[j] )
365                                 break;
366                         if( j == m )
367                         {
368                             // re-use the map
369                             tempCatOfs.push_back(ofs0);
370                             continue;
371                         }
372                     }
373                 }
374                 else
375                     ofshash[hashval] = i;
376                 Vec2i ofs;
377                 ofs[0] = (int)tempCatMap.size();
378                 ofs[1] = ofs[0] + m;
379                 tempCatOfs.push_back(ofs);
380                 std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
381             }
382             else
383             {
384                 tempCatOfs.push_back(Vec2i(0, 0));
385                 /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
386                 compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
387                 missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
388                 missingSubst.at<float>(i) = 0.f;
389             }
390         }
391 
392         if( !tempCatOfs.empty() )
393         {
394             Mat(tempCatOfs).copyTo(catOfs);
395             Mat(tempCatMap).copyTo(catMap);
396         }
397 
398         if( varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
399         {
400             preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
401             Mat(labels).copyTo(classLabels);
402             Mat(counters).copyTo(classCounters);
403         }
404     }
405 
convertMaskToIdx(const Mat & mask)406     Mat convertMaskToIdx(const Mat& mask)
407     {
408         int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
409         Mat idx(1, nz, CV_32S);
410         for( i = j = 0; i < n; i++ )
411             if( mask.at<uchar>(i) )
412                 idx.at<int>(j++) = i;
413         return idx;
414     }
415 
416     struct CmpByIdx
417     {
CmpByIdxcv::ml::TrainDataImpl::CmpByIdx418         CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}
operator ()cv::ml::TrainDataImpl::CmpByIdx419         bool operator ()(int i, int j) const { return data[i*step] < data[j*step]; }
420         const int* data;
421         int step;
422     };
423 
preprocessCategorical(const Mat & data,Mat * normdata,vector<int> & labels,vector<int> * counters,vector<int> & sortbuf)424     void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
425                                vector<int>* counters, vector<int>& sortbuf)
426     {
427         CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
428         int* odata = 0;
429         int ostep = 0;
430 
431         if(normdata)
432         {
433             normdata->create(data.size(), CV_32S);
434             odata = normdata->ptr<int>();
435             ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
436         }
437 
438         int i, n = data.cols + data.rows - 1;
439         sortbuf.resize(n*2);
440         int* idx = &sortbuf[0];
441         int* idata = (int*)data.ptr<int>();
442         int istep = data.isContinuous() ? 1 : (int)data.step1();
443 
444         if( data.type() == CV_32F )
445         {
446             idata = idx + n;
447             const float* fdata = data.ptr<float>();
448             for( i = 0; i < n; i++ )
449             {
450                 if( fdata[i*istep] == MISSED_VAL )
451                     idata[i] = -1;
452                 else
453                 {
454                     idata[i] = cvRound(fdata[i*istep]);
455                     CV_Assert( (float)idata[i] == fdata[i*istep] );
456                 }
457             }
458             istep = 1;
459         }
460 
461         for( i = 0; i < n; i++ )
462             idx[i] = i;
463 
464         std::sort(idx, idx + n, CmpByIdx(idata, istep));
465 
466         int clscount = 1;
467         for( i = 1; i < n; i++ )
468             clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
469 
470         int clslabel = -1;
471         int prev = ~idata[idx[0]*istep];
472         int previdx = 0;
473 
474         labels.resize(clscount);
475         if(counters)
476             counters->resize(clscount);
477 
478         for( i = 0; i < n; i++ )
479         {
480             int l = idata[idx[i]*istep];
481             if( l != prev )
482             {
483                 clslabel++;
484                 labels[clslabel] = l;
485                 int k = i - previdx;
486                 if( clslabel > 0 && counters )
487                     counters->at(clslabel-1) = k;
488                 prev = l;
489                 previdx = i;
490             }
491             if(odata)
492                 odata[idx[i]*ostep] = clslabel;
493         }
494         if(counters)
495             counters->at(clslabel) = i - previdx;
496     }
497 
loadCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)498     bool loadCSV(const String& filename, int headerLines,
499                  int responseStartIdx, int responseEndIdx,
500                  const String& varTypeSpec, char delimiter, char missch)
501     {
502         const int M = 1000000;
503         const char delimiters[3] = { ' ', delimiter, '\0' };
504         int nvars = 0;
505         bool varTypesSet = false;
506 
507         clear();
508 
509         file = fopen( filename.c_str(), "rt" );
510 
511         if( !file )
512             return false;
513 
514         std::vector<char> _buf(M);
515         std::vector<float> allresponses;
516         std::vector<float> rowvals;
517         std::vector<uchar> vtypes, rowtypes;
518         bool haveMissed = false;
519         char* buf = &_buf[0];
520 
521         int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
522         int ninputvars = 0, noutputvars = 0;
523 
524         Mat tempSamples, tempMissing, tempResponses;
525         MapType tempNameMap;
526         int catCounter = 1;
527 
528         // skip header lines
529         int lineno = 0;
530         for(;;lineno++)
531         {
532             if( !fgets(buf, M, file) )
533                 break;
534             if(lineno < headerLines )
535                 continue;
536             // trim trailing spaces
537             int idx = (int)strlen(buf)-1;
538             while( idx >= 0 && isspace(buf[idx]) )
539                 buf[idx--] = '\0';
540             // skip spaces in the beginning
541             char* ptr = buf;
542             while( *ptr != '\0' && isspace(*ptr) )
543                 ptr++;
544             // skip commented off lines
545             if(*ptr == '#')
546                 continue;
547             rowvals.clear();
548             rowtypes.clear();
549 
550             char* token = strtok(buf, delimiters);
551             if (!token)
552                 break;
553 
554             for(;;)
555             {
556                 float val=0.f; int tp = 0;
557                 decodeElem( token, val, tp, missch, tempNameMap, catCounter );
558                 if( tp == VAR_MISSED )
559                     haveMissed = true;
560                 rowvals.push_back(val);
561                 rowtypes.push_back((uchar)tp);
562                 token = strtok(NULL, delimiters);
563                 if (!token)
564                     break;
565             }
566 
567             if( nvars == 0 )
568             {
569                 if( rowvals.empty() )
570                     CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
571                 nvars = (int)rowvals.size();
572                 if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
573                 {
574                     setVarTypes(varTypeSpec, nvars, vtypes);
575                     varTypesSet = true;
576                 }
577                 else
578                     vtypes = rowtypes;
579 
580                 ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
581                 ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
582                 CV_Assert(ridx1 > ridx0);
583                 noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
584                 ninputvars = nvars - noutputvars;
585             }
586             else
587                 CV_Assert( nvars == (int)rowvals.size() );
588 
589             // check var types
590             for( i = 0; i < nvars; i++ )
591             {
592                 CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
593                            (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
594             }
595 
596             if( ridx0 >= 0 )
597             {
598                 for( i = ridx1; i < nvars; i++ )
599                     std::swap(rowvals[i], rowvals[i-noutputvars]);
600                 for( i = ninputvars; i < nvars; i++ )
601                     allresponses.push_back(rowvals[i]);
602                 rowvals.pop_back();
603             }
604             Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
605             tempSamples.push_back(rmat);
606         }
607 
608         closeFile();
609 
610         int nsamples = tempSamples.rows;
611         if( nsamples == 0 )
612             return false;
613 
614         if( haveMissed )
615             compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
616 
617         if( ridx0 >= 0 )
618         {
619             for( i = ridx1; i < nvars; i++ )
620                 std::swap(vtypes[i], vtypes[i-noutputvars]);
621             if( noutputvars > 1 )
622             {
623                 for( i = ninputvars; i < nvars; i++ )
624                     if( vtypes[i] == VAR_CATEGORICAL )
625                         CV_Error(CV_StsBadArg,
626                                  "If responses are vector values, not scalars, they must be marked as ordered responses");
627             }
628         }
629 
630         if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
631         {
632             for( i = 0; i < nsamples; i++ )
633                 if( allresponses[i] != cvRound(allresponses[i]) )
634                     break;
635             if( i == nsamples )
636                 vtypes[ninputvars] = VAR_CATEGORICAL;
637         }
638 
639         Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
640         setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
641                 noArray(), Mat(vtypes).clone(), tempMissing);
642         bool ok = !samples.empty();
643         if(ok)
644             std::swap(tempNameMap, nameMap);
645         return ok;
646     }
647 
decodeElem(const char * token,float & elem,int & type,char missch,MapType & namemap,int & counter) const648     void decodeElem( const char* token, float& elem, int& type,
649                      char missch, MapType& namemap, int& counter ) const
650     {
651         char* stopstring = NULL;
652         elem = (float)strtod( token, &stopstring );
653         if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
654         {
655             elem = MISSED_VAL;
656             type = VAR_MISSED;
657         }
658         else if( *stopstring != '\0' )
659         {
660             MapType::iterator it = namemap.find(token);
661             if( it == namemap.end() )
662             {
663                 elem = (float)counter;
664                 namemap[token] = counter++;
665             }
666             else
667                 elem = (float)it->second;
668             type = VAR_CATEGORICAL;
669         }
670         else
671             type = VAR_ORDERED;
672     }
673 
setVarTypes(const String & s,int nvars,std::vector<uchar> & vtypes) const674     void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
675     {
676         const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
677           "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
678         const char* str = s.c_str();
679         int specCounter = 0;
680 
681         vtypes.resize(nvars);
682 
683         for( int k = 0; k < 2; k++ )
684         {
685             const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
686             int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
687             if( ptr ) // parse ord/cat str
688             {
689                 char* stopstring = NULL;
690 
691                 if( ptr[3] == '\0' )
692                 {
693                     for( int i = 0; i < nvars; i++ )
694                         vtypes[i] = (uchar)tp;
695                     specCounter = nvars;
696                     break;
697                 }
698 
699                 if ( ptr[3] != '[')
700                     CV_Error( CV_StsBadArg, errmsg );
701 
702                 ptr += 4; // pass "ord["
703                 do
704                 {
705                     int b1 = (int)strtod( ptr, &stopstring );
706                     if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
707                         CV_Error( CV_StsBadArg, errmsg );
708                     ptr = stopstring + 1;
709                     if( (stopstring[0] == ',') || (stopstring[0] == ']'))
710                     {
711                         CV_Assert( 0 <= b1 && b1 < nvars );
712                         vtypes[b1] = (uchar)tp;
713                         specCounter++;
714                     }
715                     else
716                     {
717                         if( stopstring[0] == '-')
718                         {
719                             int b2 = (int)strtod( ptr, &stopstring);
720                             if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
721                                 CV_Error( CV_StsBadArg, errmsg );
722                             ptr = stopstring + 1;
723                             CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
724                             for (int i = b1; i <= b2; i++)
725                                 vtypes[i] = (uchar)tp;
726                             specCounter += b2 - b1 + 1;
727                         }
728                         else
729                             CV_Error( CV_StsBadArg, errmsg );
730 
731                     }
732                 }
733                 while(*stopstring != ']');
734 
735                 if( stopstring[1] != '\0' && stopstring[1] != ',')
736                     CV_Error( CV_StsBadArg, errmsg );
737             }
738         }
739 
740         if( specCounter != nvars )
741             CV_Error( CV_StsBadArg, "type of some variables is not specified" );
742     }
743 
setTrainTestSplitRatio(double ratio,bool shuffle)744     void setTrainTestSplitRatio(double ratio, bool shuffle)
745     {
746         CV_Assert( 0. <= ratio && ratio <= 1. );
747         setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
748     }
749 
setTrainTestSplit(int count,bool shuffle)750     void setTrainTestSplit(int count, bool shuffle)
751     {
752         int i, nsamples = getNSamples();
753         CV_Assert( 0 <= count && count < nsamples );
754 
755         trainSampleIdx.release();
756         testSampleIdx.release();
757 
758         if( count == 0 )
759             trainSampleIdx = sampleIdx;
760         else if( count == nsamples )
761             testSampleIdx = sampleIdx;
762         else
763         {
764             Mat mask(1, nsamples, CV_8U);
765             uchar* mptr = mask.ptr();
766             for( i = 0; i < nsamples; i++ )
767                 mptr[i] = (uchar)(i < count);
768             trainSampleIdx.create(1, count, CV_32S);
769             testSampleIdx.create(1, nsamples - count, CV_32S);
770             int j0 = 0, j1 = 0;
771             const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
772             int* trainptr = trainSampleIdx.ptr<int>();
773             int* testptr = testSampleIdx.ptr<int>();
774             for( i = 0; i < nsamples; i++ )
775             {
776                 int idx = sptr ? sptr[i] : i;
777                 if( mptr[i] )
778                     trainptr[j0++] = idx;
779                 else
780                     testptr[j1++] = idx;
781             }
782             if( shuffle )
783                 shuffleTrainTest();
784         }
785     }
786 
shuffleTrainTest()787     void shuffleTrainTest()
788     {
789         if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
790         {
791             int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
792             int* trainIdx = trainSampleIdx.ptr<int>();
793             int* testIdx = testSampleIdx.ptr<int>();
794             RNG& rng = theRNG();
795 
796             for( i = 0; i < nsamples; i++)
797             {
798                 int a = rng.uniform(0, nsamples);
799                 int b = rng.uniform(0, nsamples);
800                 int* ptra = trainIdx;
801                 int* ptrb = trainIdx;
802                 if( a >= ntrain )
803                 {
804                     ptra = testIdx;
805                     a -= ntrain;
806                     CV_Assert( a < ntest );
807                 }
808                 if( b >= ntrain )
809                 {
810                     ptrb = testIdx;
811                     b -= ntrain;
812                     CV_Assert( b < ntest );
813                 }
814                 std::swap(ptra[a], ptrb[b]);
815             }
816         }
817     }
818 
getTrainSamples(int _layout,bool compressSamples,bool compressVars) const819     Mat getTrainSamples(int _layout,
820                         bool compressSamples,
821                         bool compressVars) const
822     {
823         if( samples.empty() )
824             return samples;
825 
826         if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
827             (!compressVars || varIdx.empty()) &&
828             layout == _layout )
829             return samples;
830 
831         int drows = getNTrainSamples(), dcols = getNVars();
832         Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
833         const float* src0 = samples.ptr<float>();
834         const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
835         const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
836         size_t sstep0 = samples.step/samples.elemSize();
837         size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
838         size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
839 
840         if( _layout == COL_SAMPLE )
841         {
842             std::swap(drows, dcols);
843             std::swap(sptr, vptr);
844             std::swap(sstep, vstep);
845         }
846 
847         Mat dsamples(drows, dcols, CV_32F);
848 
849         for( int i = 0; i < drows; i++ )
850         {
851             const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
852             float* dst = dsamples.ptr<float>(i);
853 
854             for( int j = 0; j < dcols; j++ )
855                 dst[j] = src[(vptr ? vptr[j] : j)*vstep];
856         }
857 
858         return dsamples;
859     }
860 
getValues(int vi,InputArray _sidx,float * values) const861     void getValues( int vi, InputArray _sidx, float* values ) const
862     {
863         Mat sidx = _sidx.getMat();
864         int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
865         CV_Assert( 0 <= vi && vi < getNAllVars() );
866         CV_Assert( n >= 0 );
867         const int* s = n > 0 ? sidx.ptr<int>() : 0;
868         if( n == 0 )
869             n = nsamples;
870 
871         size_t step = samples.step/samples.elemSize();
872         size_t sstep = layout == ROW_SAMPLE ? step : 1;
873         size_t vstep = layout == ROW_SAMPLE ? 1 : step;
874 
875         const float* src = samples.ptr<float>() + vi*vstep;
876         float subst = missingSubst.at<float>(vi);
877         for( i = 0; i < n; i++ )
878         {
879             int j = i;
880             if( s )
881             {
882                 j = s[i];
883                 CV_Assert( 0 <= j && j < nsamples );
884             }
885             values[i] = src[j*sstep];
886             if( values[i] == MISSED_VAL )
887                 values[i] = subst;
888         }
889     }
890 
getNormCatValues(int vi,InputArray _sidx,int * values) const891     void getNormCatValues( int vi, InputArray _sidx, int* values ) const
892     {
893         float* fvalues = (float*)values;
894         getValues(vi, _sidx, fvalues);
895         int i, n = (int)_sidx.total();
896         Vec2i ofs = catOfs.at<Vec2i>(vi);
897         int m = ofs[1] - ofs[0];
898 
899         CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
900         const int* cmap = &catMap.at<int>(ofs[0]);
901         bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
902 
903         if( fastMap )
904         {
905             for( i = 0; i < n; i++ )
906             {
907                 int val = cvRound(fvalues[i]);
908                 int idx = val - cmap[0];
909                 CV_Assert(cmap[idx] == val);
910                 values[i] = idx;
911             }
912         }
913         else
914         {
915             for( i = 0; i < n; i++ )
916             {
917                 int val = cvRound(fvalues[i]);
918                 int a = 0, b = m, c = -1;
919 
920                 while( a < b )
921                 {
922                     c = (a + b) >> 1;
923                     if( val < cmap[c] )
924                         b = c;
925                     else if( val > cmap[c] )
926                         a = c+1;
927                     else
928                         break;
929                 }
930 
931                 CV_DbgAssert( c >= 0 && val == cmap[c] );
932                 values[i] = c;
933             }
934         }
935     }
936 
getSample(InputArray _vidx,int sidx,float * buf) const937     void getSample(InputArray _vidx, int sidx, float* buf) const
938     {
939         CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
940         Mat vidx = _vidx.getMat();
941         int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
942         CV_Assert( n >= 0 );
943         const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
944         if( n == 0 )
945             n = nvars;
946 
947         size_t step = samples.step/samples.elemSize();
948         size_t sstep = layout == ROW_SAMPLE ? step : 1;
949         size_t vstep = layout == ROW_SAMPLE ? 1 : step;
950 
951         const float* src = samples.ptr<float>() + sidx*sstep;
952         for( i = 0; i < n; i++ )
953         {
954             int j = i;
955             if( vptr )
956             {
957                 j = vptr[i];
958                 CV_Assert( 0 <= j && j < nvars );
959             }
960             buf[i] = src[j*vstep];
961         }
962     }
963 
964     FILE* file;
965     int layout;
966     Mat samples, missing, varType, varIdx, responses, missingSubst;
967     Mat sampleIdx, trainSampleIdx, testSampleIdx;
968     Mat sampleWeights, catMap, catOfs;
969     Mat normCatResponses, classLabels, classCounters;
970     MapType nameMap;
971 };
972 
loadFromCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)973 Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
974                                       int headerLines,
975                                       int responseStartIdx,
976                                       int responseEndIdx,
977                                       const String& varTypeSpec,
978                                       char delimiter, char missch)
979 {
980     Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
981     if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
982         td.release();
983     return td;
984 }
985 
create(InputArray samples,int layout,InputArray responses,InputArray varIdx,InputArray sampleIdx,InputArray sampleWeights,InputArray varType)986 Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
987                                  InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
988                                  InputArray varType)
989 {
990     Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
991     td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
992     return td;
993 }
994 
995 }}
996 
997 /* End of file. */
998