1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
20 //
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
24 //
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "precomp.hpp"
42 #include <ctype.h>
43 #include <algorithm>
44 #include <iterator>
45
46 namespace cv { namespace ml {
47
48 static const float MISSED_VAL = TrainData::missingValue();
49 static const int VAR_MISSED = VAR_ORDERED;
50
~TrainData()51 TrainData::~TrainData() {}
52
getSubVector(const Mat & vec,const Mat & idx)53 Mat TrainData::getSubVector(const Mat& vec, const Mat& idx)
54 {
55 if( idx.empty() )
56 return vec;
57 int i, j, n = idx.checkVector(1, CV_32S);
58 int type = vec.type();
59 CV_Assert( type == CV_32S || type == CV_32F || type == CV_64F );
60 int dims = 1, m;
61
62 if( vec.cols == 1 || vec.rows == 1 )
63 {
64 dims = 1;
65 m = vec.cols + vec.rows - 1;
66 }
67 else
68 {
69 dims = vec.cols;
70 m = vec.rows;
71 }
72
73 Mat subvec;
74
75 if( vec.cols == m )
76 subvec.create(dims, n, type);
77 else
78 subvec.create(n, dims, type);
79 if( type == CV_32S )
80 for( i = 0; i < n; i++ )
81 {
82 int k = idx.at<int>(i);
83 CV_Assert( 0 <= k && k < m );
84 if( dims == 1 )
85 subvec.at<int>(i) = vec.at<int>(k);
86 else
87 for( j = 0; j < dims; j++ )
88 subvec.at<int>(i, j) = vec.at<int>(k, j);
89 }
90 else if( type == CV_32F )
91 for( i = 0; i < n; i++ )
92 {
93 int k = idx.at<int>(i);
94 CV_Assert( 0 <= k && k < m );
95 if( dims == 1 )
96 subvec.at<float>(i) = vec.at<float>(k);
97 else
98 for( j = 0; j < dims; j++ )
99 subvec.at<float>(i, j) = vec.at<float>(k, j);
100 }
101 else
102 for( i = 0; i < n; i++ )
103 {
104 int k = idx.at<int>(i);
105 CV_Assert( 0 <= k && k < m );
106 if( dims == 1 )
107 subvec.at<double>(i) = vec.at<double>(k);
108 else
109 for( j = 0; j < dims; j++ )
110 subvec.at<double>(i, j) = vec.at<double>(k, j);
111 }
112 return subvec;
113 }
114
115 class TrainDataImpl : public TrainData
116 {
117 public:
118 typedef std::map<String, int> MapType;
119
TrainDataImpl()120 TrainDataImpl()
121 {
122 file = 0;
123 clear();
124 }
125
~TrainDataImpl()126 virtual ~TrainDataImpl() { closeFile(); }
127
getLayout() const128 int getLayout() const { return layout; }
getNSamples() const129 int getNSamples() const
130 {
131 return !sampleIdx.empty() ? (int)sampleIdx.total() :
132 layout == ROW_SAMPLE ? samples.rows : samples.cols;
133 }
getNTrainSamples() const134 int getNTrainSamples() const
135 {
136 return !trainSampleIdx.empty() ? (int)trainSampleIdx.total() : getNSamples();
137 }
getNTestSamples() const138 int getNTestSamples() const
139 {
140 return !testSampleIdx.empty() ? (int)testSampleIdx.total() : 0;
141 }
getNVars() const142 int getNVars() const
143 {
144 return !varIdx.empty() ? (int)varIdx.total() : getNAllVars();
145 }
getNAllVars() const146 int getNAllVars() const
147 {
148 return layout == ROW_SAMPLE ? samples.cols : samples.rows;
149 }
150
getSamples() const151 Mat getSamples() const { return samples; }
getResponses() const152 Mat getResponses() const { return responses; }
getMissing() const153 Mat getMissing() const { return missing; }
getVarIdx() const154 Mat getVarIdx() const { return varIdx; }
getVarType() const155 Mat getVarType() const { return varType; }
getResponseType() const156 int getResponseType() const
157 {
158 return classLabels.empty() ? VAR_ORDERED : VAR_CATEGORICAL;
159 }
getTrainSampleIdx() const160 Mat getTrainSampleIdx() const { return !trainSampleIdx.empty() ? trainSampleIdx : sampleIdx; }
getTestSampleIdx() const161 Mat getTestSampleIdx() const { return testSampleIdx; }
getSampleWeights() const162 Mat getSampleWeights() const
163 {
164 return sampleWeights;
165 }
getTrainSampleWeights() const166 Mat getTrainSampleWeights() const
167 {
168 return getSubVector(sampleWeights, getTrainSampleIdx());
169 }
getTestSampleWeights() const170 Mat getTestSampleWeights() const
171 {
172 Mat idx = getTestSampleIdx();
173 return idx.empty() ? Mat() : getSubVector(sampleWeights, idx);
174 }
getTrainResponses() const175 Mat getTrainResponses() const
176 {
177 return getSubVector(responses, getTrainSampleIdx());
178 }
getTrainNormCatResponses() const179 Mat getTrainNormCatResponses() const
180 {
181 return getSubVector(normCatResponses, getTrainSampleIdx());
182 }
getTestResponses() const183 Mat getTestResponses() const
184 {
185 Mat idx = getTestSampleIdx();
186 return idx.empty() ? Mat() : getSubVector(responses, idx);
187 }
getTestNormCatResponses() const188 Mat getTestNormCatResponses() const
189 {
190 Mat idx = getTestSampleIdx();
191 return idx.empty() ? Mat() : getSubVector(normCatResponses, idx);
192 }
getNormCatResponses() const193 Mat getNormCatResponses() const { return normCatResponses; }
getClassLabels() const194 Mat getClassLabels() const { return classLabels; }
getClassCounters() const195 Mat getClassCounters() const { return classCounters; }
getCatCount(int vi) const196 int getCatCount(int vi) const
197 {
198 int n = (int)catOfs.total();
199 CV_Assert( 0 <= vi && vi < n );
200 Vec2i ofs = catOfs.at<Vec2i>(vi);
201 return ofs[1] - ofs[0];
202 }
203
getCatOfs() const204 Mat getCatOfs() const { return catOfs; }
getCatMap() const205 Mat getCatMap() const { return catMap; }
206
getDefaultSubstValues() const207 Mat getDefaultSubstValues() const { return missingSubst; }
208
closeFile()209 void closeFile() { if(file) fclose(file); file=0; }
clear()210 void clear()
211 {
212 closeFile();
213 samples.release();
214 missing.release();
215 varType.release();
216 responses.release();
217 sampleIdx.release();
218 trainSampleIdx.release();
219 testSampleIdx.release();
220 normCatResponses.release();
221 classLabels.release();
222 classCounters.release();
223 catMap.release();
224 catOfs.release();
225 nameMap = MapType();
226 layout = ROW_SAMPLE;
227 }
228
229 typedef std::map<int, int> CatMapHash;
230
setData(InputArray _samples,int _layout,InputArray _responses,InputArray _varIdx,InputArray _sampleIdx,InputArray _sampleWeights,InputArray _varType,InputArray _missing)231 void setData(InputArray _samples, int _layout, InputArray _responses,
232 InputArray _varIdx, InputArray _sampleIdx, InputArray _sampleWeights,
233 InputArray _varType, InputArray _missing)
234 {
235 clear();
236
237 CV_Assert(_layout == ROW_SAMPLE || _layout == COL_SAMPLE );
238 samples = _samples.getMat();
239 layout = _layout;
240 responses = _responses.getMat();
241 varIdx = _varIdx.getMat();
242 sampleIdx = _sampleIdx.getMat();
243 sampleWeights = _sampleWeights.getMat();
244 varType = _varType.getMat();
245 missing = _missing.getMat();
246
247 int nsamples = layout == ROW_SAMPLE ? samples.rows : samples.cols;
248 int ninputvars = layout == ROW_SAMPLE ? samples.cols : samples.rows;
249 int i, noutputvars = 0;
250
251 CV_Assert( samples.type() == CV_32F || samples.type() == CV_32S );
252
253 if( !sampleIdx.empty() )
254 {
255 CV_Assert( (sampleIdx.checkVector(1, CV_32S, true) > 0 &&
256 checkRange(sampleIdx, true, 0, 0, nsamples-1)) ||
257 sampleIdx.checkVector(1, CV_8U, true) == nsamples );
258 if( sampleIdx.type() == CV_8U )
259 sampleIdx = convertMaskToIdx(sampleIdx);
260 }
261
262 if( !sampleWeights.empty() )
263 {
264 CV_Assert( sampleWeights.checkVector(1, CV_32F, true) == nsamples );
265 }
266 else
267 {
268 sampleWeights = Mat::ones(nsamples, 1, CV_32F);
269 }
270
271 if( !varIdx.empty() )
272 {
273 CV_Assert( (varIdx.checkVector(1, CV_32S, true) > 0 &&
274 checkRange(varIdx, true, 0, 0, ninputvars)) ||
275 varIdx.checkVector(1, CV_8U, true) == ninputvars );
276 if( varIdx.type() == CV_8U )
277 varIdx = convertMaskToIdx(varIdx);
278 varIdx = varIdx.clone();
279 std::sort(varIdx.ptr<int>(), varIdx.ptr<int>() + varIdx.total());
280 }
281
282 if( !responses.empty() )
283 {
284 CV_Assert( responses.type() == CV_32F || responses.type() == CV_32S );
285 if( (responses.cols == 1 || responses.rows == 1) && (int)responses.total() == nsamples )
286 noutputvars = 1;
287 else
288 {
289 CV_Assert( (layout == ROW_SAMPLE && responses.rows == nsamples) ||
290 (layout == COL_SAMPLE && responses.cols == nsamples) );
291 noutputvars = layout == ROW_SAMPLE ? responses.cols : responses.rows;
292 }
293 if( !responses.isContinuous() || (layout == COL_SAMPLE && noutputvars > 1) )
294 {
295 Mat temp;
296 transpose(responses, temp);
297 responses = temp;
298 }
299 }
300
301 int nvars = ninputvars + noutputvars;
302
303 if( !varType.empty() )
304 {
305 CV_Assert( varType.checkVector(1, CV_8U, true) == nvars &&
306 checkRange(varType, true, 0, VAR_ORDERED, VAR_CATEGORICAL+1) );
307 }
308 else
309 {
310 varType.create(1, nvars, CV_8U);
311 varType = Scalar::all(VAR_ORDERED);
312 if( noutputvars == 1 )
313 varType.at<uchar>(ninputvars) = (uchar)(responses.type() < CV_32F ? VAR_CATEGORICAL : VAR_ORDERED);
314 }
315
316 if( noutputvars > 1 )
317 {
318 for( i = 0; i < noutputvars; i++ )
319 CV_Assert( varType.at<uchar>(ninputvars + i) == VAR_ORDERED );
320 }
321
322 catOfs = Mat::zeros(1, nvars, CV_32SC2);
323 missingSubst = Mat::zeros(1, nvars, CV_32F);
324
325 vector<int> labels, counters, sortbuf, tempCatMap;
326 vector<Vec2i> tempCatOfs;
327 CatMapHash ofshash;
328
329 AutoBuffer<uchar> buf(nsamples);
330 Mat non_missing(layout == ROW_SAMPLE ? Size(1, nsamples) : Size(nsamples, 1), CV_8U, (uchar*)buf);
331 bool haveMissing = !missing.empty();
332 if( haveMissing )
333 {
334 CV_Assert( missing.size() == samples.size() && missing.type() == CV_8U );
335 }
336
337 // we iterate through all the variables. For each categorical variable we build a map
338 // in order to convert input values of the variable into normalized values (0..catcount_vi-1)
339 // often many categorical variables are similar, so we compress the map - try to re-use
340 // maps for different variables if they are identical
341 for( i = 0; i < ninputvars; i++ )
342 {
343 Mat values_i = layout == ROW_SAMPLE ? samples.col(i) : samples.row(i);
344
345 if( varType.at<uchar>(i) == VAR_CATEGORICAL )
346 {
347 preprocessCategorical(values_i, 0, labels, 0, sortbuf);
348 missingSubst.at<float>(i) = -1.f;
349 int j, m = (int)labels.size();
350 CV_Assert( m > 0 );
351 int a = labels.front(), b = labels.back();
352 const int* currmap = &labels[0];
353 int hashval = ((unsigned)a*127 + (unsigned)b)*127 + m;
354 CatMapHash::iterator it = ofshash.find(hashval);
355 if( it != ofshash.end() )
356 {
357 int vi = it->second;
358 Vec2i ofs0 = tempCatOfs[vi];
359 int m0 = ofs0[1] - ofs0[0];
360 const int* map0 = &tempCatMap[ofs0[0]];
361 if( m0 == m && map0[0] == a && map0[m0-1] == b )
362 {
363 for( j = 0; j < m; j++ )
364 if( map0[j] != currmap[j] )
365 break;
366 if( j == m )
367 {
368 // re-use the map
369 tempCatOfs.push_back(ofs0);
370 continue;
371 }
372 }
373 }
374 else
375 ofshash[hashval] = i;
376 Vec2i ofs;
377 ofs[0] = (int)tempCatMap.size();
378 ofs[1] = ofs[0] + m;
379 tempCatOfs.push_back(ofs);
380 std::copy(labels.begin(), labels.end(), std::back_inserter(tempCatMap));
381 }
382 else
383 {
384 tempCatOfs.push_back(Vec2i(0, 0));
385 /*Mat missing_i = layout == ROW_SAMPLE ? missing.col(i) : missing.row(i);
386 compare(missing_i, Scalar::all(0), non_missing, CMP_EQ);
387 missingSubst.at<float>(i) = (float)(mean(values_i, non_missing)[0]);*/
388 missingSubst.at<float>(i) = 0.f;
389 }
390 }
391
392 if( !tempCatOfs.empty() )
393 {
394 Mat(tempCatOfs).copyTo(catOfs);
395 Mat(tempCatMap).copyTo(catMap);
396 }
397
398 if( varType.at<uchar>(ninputvars) == VAR_CATEGORICAL )
399 {
400 preprocessCategorical(responses, &normCatResponses, labels, &counters, sortbuf);
401 Mat(labels).copyTo(classLabels);
402 Mat(counters).copyTo(classCounters);
403 }
404 }
405
convertMaskToIdx(const Mat & mask)406 Mat convertMaskToIdx(const Mat& mask)
407 {
408 int i, j, nz = countNonZero(mask), n = mask.cols + mask.rows - 1;
409 Mat idx(1, nz, CV_32S);
410 for( i = j = 0; i < n; i++ )
411 if( mask.at<uchar>(i) )
412 idx.at<int>(j++) = i;
413 return idx;
414 }
415
416 struct CmpByIdx
417 {
CmpByIdxcv::ml::TrainDataImpl::CmpByIdx418 CmpByIdx(const int* _data, int _step) : data(_data), step(_step) {}
operator ()cv::ml::TrainDataImpl::CmpByIdx419 bool operator ()(int i, int j) const { return data[i*step] < data[j*step]; }
420 const int* data;
421 int step;
422 };
423
preprocessCategorical(const Mat & data,Mat * normdata,vector<int> & labels,vector<int> * counters,vector<int> & sortbuf)424 void preprocessCategorical(const Mat& data, Mat* normdata, vector<int>& labels,
425 vector<int>* counters, vector<int>& sortbuf)
426 {
427 CV_Assert((data.cols == 1 || data.rows == 1) && (data.type() == CV_32S || data.type() == CV_32F));
428 int* odata = 0;
429 int ostep = 0;
430
431 if(normdata)
432 {
433 normdata->create(data.size(), CV_32S);
434 odata = normdata->ptr<int>();
435 ostep = normdata->isContinuous() ? 1 : (int)normdata->step1();
436 }
437
438 int i, n = data.cols + data.rows - 1;
439 sortbuf.resize(n*2);
440 int* idx = &sortbuf[0];
441 int* idata = (int*)data.ptr<int>();
442 int istep = data.isContinuous() ? 1 : (int)data.step1();
443
444 if( data.type() == CV_32F )
445 {
446 idata = idx + n;
447 const float* fdata = data.ptr<float>();
448 for( i = 0; i < n; i++ )
449 {
450 if( fdata[i*istep] == MISSED_VAL )
451 idata[i] = -1;
452 else
453 {
454 idata[i] = cvRound(fdata[i*istep]);
455 CV_Assert( (float)idata[i] == fdata[i*istep] );
456 }
457 }
458 istep = 1;
459 }
460
461 for( i = 0; i < n; i++ )
462 idx[i] = i;
463
464 std::sort(idx, idx + n, CmpByIdx(idata, istep));
465
466 int clscount = 1;
467 for( i = 1; i < n; i++ )
468 clscount += idata[idx[i]*istep] != idata[idx[i-1]*istep];
469
470 int clslabel = -1;
471 int prev = ~idata[idx[0]*istep];
472 int previdx = 0;
473
474 labels.resize(clscount);
475 if(counters)
476 counters->resize(clscount);
477
478 for( i = 0; i < n; i++ )
479 {
480 int l = idata[idx[i]*istep];
481 if( l != prev )
482 {
483 clslabel++;
484 labels[clslabel] = l;
485 int k = i - previdx;
486 if( clslabel > 0 && counters )
487 counters->at(clslabel-1) = k;
488 prev = l;
489 previdx = i;
490 }
491 if(odata)
492 odata[idx[i]*ostep] = clslabel;
493 }
494 if(counters)
495 counters->at(clslabel) = i - previdx;
496 }
497
loadCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)498 bool loadCSV(const String& filename, int headerLines,
499 int responseStartIdx, int responseEndIdx,
500 const String& varTypeSpec, char delimiter, char missch)
501 {
502 const int M = 1000000;
503 const char delimiters[3] = { ' ', delimiter, '\0' };
504 int nvars = 0;
505 bool varTypesSet = false;
506
507 clear();
508
509 file = fopen( filename.c_str(), "rt" );
510
511 if( !file )
512 return false;
513
514 std::vector<char> _buf(M);
515 std::vector<float> allresponses;
516 std::vector<float> rowvals;
517 std::vector<uchar> vtypes, rowtypes;
518 bool haveMissed = false;
519 char* buf = &_buf[0];
520
521 int i, ridx0 = responseStartIdx, ridx1 = responseEndIdx;
522 int ninputvars = 0, noutputvars = 0;
523
524 Mat tempSamples, tempMissing, tempResponses;
525 MapType tempNameMap;
526 int catCounter = 1;
527
528 // skip header lines
529 int lineno = 0;
530 for(;;lineno++)
531 {
532 if( !fgets(buf, M, file) )
533 break;
534 if(lineno < headerLines )
535 continue;
536 // trim trailing spaces
537 int idx = (int)strlen(buf)-1;
538 while( idx >= 0 && isspace(buf[idx]) )
539 buf[idx--] = '\0';
540 // skip spaces in the beginning
541 char* ptr = buf;
542 while( *ptr != '\0' && isspace(*ptr) )
543 ptr++;
544 // skip commented off lines
545 if(*ptr == '#')
546 continue;
547 rowvals.clear();
548 rowtypes.clear();
549
550 char* token = strtok(buf, delimiters);
551 if (!token)
552 break;
553
554 for(;;)
555 {
556 float val=0.f; int tp = 0;
557 decodeElem( token, val, tp, missch, tempNameMap, catCounter );
558 if( tp == VAR_MISSED )
559 haveMissed = true;
560 rowvals.push_back(val);
561 rowtypes.push_back((uchar)tp);
562 token = strtok(NULL, delimiters);
563 if (!token)
564 break;
565 }
566
567 if( nvars == 0 )
568 {
569 if( rowvals.empty() )
570 CV_Error(CV_StsBadArg, "invalid CSV format; no data found");
571 nvars = (int)rowvals.size();
572 if( !varTypeSpec.empty() && varTypeSpec.size() > 0 )
573 {
574 setVarTypes(varTypeSpec, nvars, vtypes);
575 varTypesSet = true;
576 }
577 else
578 vtypes = rowtypes;
579
580 ridx0 = ridx0 >= 0 ? ridx0 : ridx0 == -1 ? nvars - 1 : -1;
581 ridx1 = ridx1 >= 0 ? ridx1 : ridx0 >= 0 ? ridx0+1 : -1;
582 CV_Assert(ridx1 > ridx0);
583 noutputvars = ridx0 >= 0 ? ridx1 - ridx0 : 0;
584 ninputvars = nvars - noutputvars;
585 }
586 else
587 CV_Assert( nvars == (int)rowvals.size() );
588
589 // check var types
590 for( i = 0; i < nvars; i++ )
591 {
592 CV_Assert( (!varTypesSet && vtypes[i] == rowtypes[i]) ||
593 (varTypesSet && (vtypes[i] == rowtypes[i] || rowtypes[i] == VAR_ORDERED)) );
594 }
595
596 if( ridx0 >= 0 )
597 {
598 for( i = ridx1; i < nvars; i++ )
599 std::swap(rowvals[i], rowvals[i-noutputvars]);
600 for( i = ninputvars; i < nvars; i++ )
601 allresponses.push_back(rowvals[i]);
602 rowvals.pop_back();
603 }
604 Mat rmat(1, ninputvars, CV_32F, &rowvals[0]);
605 tempSamples.push_back(rmat);
606 }
607
608 closeFile();
609
610 int nsamples = tempSamples.rows;
611 if( nsamples == 0 )
612 return false;
613
614 if( haveMissed )
615 compare(tempSamples, MISSED_VAL, tempMissing, CMP_EQ);
616
617 if( ridx0 >= 0 )
618 {
619 for( i = ridx1; i < nvars; i++ )
620 std::swap(vtypes[i], vtypes[i-noutputvars]);
621 if( noutputvars > 1 )
622 {
623 for( i = ninputvars; i < nvars; i++ )
624 if( vtypes[i] == VAR_CATEGORICAL )
625 CV_Error(CV_StsBadArg,
626 "If responses are vector values, not scalars, they must be marked as ordered responses");
627 }
628 }
629
630 if( !varTypesSet && noutputvars == 1 && vtypes[ninputvars] == VAR_ORDERED )
631 {
632 for( i = 0; i < nsamples; i++ )
633 if( allresponses[i] != cvRound(allresponses[i]) )
634 break;
635 if( i == nsamples )
636 vtypes[ninputvars] = VAR_CATEGORICAL;
637 }
638
639 Mat(nsamples, noutputvars, CV_32F, &allresponses[0]).copyTo(tempResponses);
640 setData(tempSamples, ROW_SAMPLE, tempResponses, noArray(), noArray(),
641 noArray(), Mat(vtypes).clone(), tempMissing);
642 bool ok = !samples.empty();
643 if(ok)
644 std::swap(tempNameMap, nameMap);
645 return ok;
646 }
647
decodeElem(const char * token,float & elem,int & type,char missch,MapType & namemap,int & counter) const648 void decodeElem( const char* token, float& elem, int& type,
649 char missch, MapType& namemap, int& counter ) const
650 {
651 char* stopstring = NULL;
652 elem = (float)strtod( token, &stopstring );
653 if( *stopstring == missch && strlen(stopstring) == 1 ) // missed value
654 {
655 elem = MISSED_VAL;
656 type = VAR_MISSED;
657 }
658 else if( *stopstring != '\0' )
659 {
660 MapType::iterator it = namemap.find(token);
661 if( it == namemap.end() )
662 {
663 elem = (float)counter;
664 namemap[token] = counter++;
665 }
666 else
667 elem = (float)it->second;
668 type = VAR_CATEGORICAL;
669 }
670 else
671 type = VAR_ORDERED;
672 }
673
setVarTypes(const String & s,int nvars,std::vector<uchar> & vtypes) const674 void setVarTypes( const String& s, int nvars, std::vector<uchar>& vtypes ) const
675 {
676 const char* errmsg = "type spec is not correct; it should have format \"cat\", \"ord\" or "
677 "\"ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\", where n's and m's are 0-based variable indices";
678 const char* str = s.c_str();
679 int specCounter = 0;
680
681 vtypes.resize(nvars);
682
683 for( int k = 0; k < 2; k++ )
684 {
685 const char* ptr = strstr(str, k == 0 ? "ord" : "cat");
686 int tp = k == 0 ? VAR_ORDERED : VAR_CATEGORICAL;
687 if( ptr ) // parse ord/cat str
688 {
689 char* stopstring = NULL;
690
691 if( ptr[3] == '\0' )
692 {
693 for( int i = 0; i < nvars; i++ )
694 vtypes[i] = (uchar)tp;
695 specCounter = nvars;
696 break;
697 }
698
699 if ( ptr[3] != '[')
700 CV_Error( CV_StsBadArg, errmsg );
701
702 ptr += 4; // pass "ord["
703 do
704 {
705 int b1 = (int)strtod( ptr, &stopstring );
706 if( *stopstring == 0 || (*stopstring != ',' && *stopstring != ']' && *stopstring != '-') )
707 CV_Error( CV_StsBadArg, errmsg );
708 ptr = stopstring + 1;
709 if( (stopstring[0] == ',') || (stopstring[0] == ']'))
710 {
711 CV_Assert( 0 <= b1 && b1 < nvars );
712 vtypes[b1] = (uchar)tp;
713 specCounter++;
714 }
715 else
716 {
717 if( stopstring[0] == '-')
718 {
719 int b2 = (int)strtod( ptr, &stopstring);
720 if ( (*stopstring == 0) || (*stopstring != ',' && *stopstring != ']') )
721 CV_Error( CV_StsBadArg, errmsg );
722 ptr = stopstring + 1;
723 CV_Assert( 0 <= b1 && b1 <= b2 && b2 < nvars );
724 for (int i = b1; i <= b2; i++)
725 vtypes[i] = (uchar)tp;
726 specCounter += b2 - b1 + 1;
727 }
728 else
729 CV_Error( CV_StsBadArg, errmsg );
730
731 }
732 }
733 while(*stopstring != ']');
734
735 if( stopstring[1] != '\0' && stopstring[1] != ',')
736 CV_Error( CV_StsBadArg, errmsg );
737 }
738 }
739
740 if( specCounter != nvars )
741 CV_Error( CV_StsBadArg, "type of some variables is not specified" );
742 }
743
setTrainTestSplitRatio(double ratio,bool shuffle)744 void setTrainTestSplitRatio(double ratio, bool shuffle)
745 {
746 CV_Assert( 0. <= ratio && ratio <= 1. );
747 setTrainTestSplit(cvRound(getNSamples()*ratio), shuffle);
748 }
749
setTrainTestSplit(int count,bool shuffle)750 void setTrainTestSplit(int count, bool shuffle)
751 {
752 int i, nsamples = getNSamples();
753 CV_Assert( 0 <= count && count < nsamples );
754
755 trainSampleIdx.release();
756 testSampleIdx.release();
757
758 if( count == 0 )
759 trainSampleIdx = sampleIdx;
760 else if( count == nsamples )
761 testSampleIdx = sampleIdx;
762 else
763 {
764 Mat mask(1, nsamples, CV_8U);
765 uchar* mptr = mask.ptr();
766 for( i = 0; i < nsamples; i++ )
767 mptr[i] = (uchar)(i < count);
768 trainSampleIdx.create(1, count, CV_32S);
769 testSampleIdx.create(1, nsamples - count, CV_32S);
770 int j0 = 0, j1 = 0;
771 const int* sptr = !sampleIdx.empty() ? sampleIdx.ptr<int>() : 0;
772 int* trainptr = trainSampleIdx.ptr<int>();
773 int* testptr = testSampleIdx.ptr<int>();
774 for( i = 0; i < nsamples; i++ )
775 {
776 int idx = sptr ? sptr[i] : i;
777 if( mptr[i] )
778 trainptr[j0++] = idx;
779 else
780 testptr[j1++] = idx;
781 }
782 if( shuffle )
783 shuffleTrainTest();
784 }
785 }
786
shuffleTrainTest()787 void shuffleTrainTest()
788 {
789 if( !trainSampleIdx.empty() && !testSampleIdx.empty() )
790 {
791 int i, nsamples = getNSamples(), ntrain = getNTrainSamples(), ntest = getNTestSamples();
792 int* trainIdx = trainSampleIdx.ptr<int>();
793 int* testIdx = testSampleIdx.ptr<int>();
794 RNG& rng = theRNG();
795
796 for( i = 0; i < nsamples; i++)
797 {
798 int a = rng.uniform(0, nsamples);
799 int b = rng.uniform(0, nsamples);
800 int* ptra = trainIdx;
801 int* ptrb = trainIdx;
802 if( a >= ntrain )
803 {
804 ptra = testIdx;
805 a -= ntrain;
806 CV_Assert( a < ntest );
807 }
808 if( b >= ntrain )
809 {
810 ptrb = testIdx;
811 b -= ntrain;
812 CV_Assert( b < ntest );
813 }
814 std::swap(ptra[a], ptrb[b]);
815 }
816 }
817 }
818
getTrainSamples(int _layout,bool compressSamples,bool compressVars) const819 Mat getTrainSamples(int _layout,
820 bool compressSamples,
821 bool compressVars) const
822 {
823 if( samples.empty() )
824 return samples;
825
826 if( (!compressSamples || (trainSampleIdx.empty() && sampleIdx.empty())) &&
827 (!compressVars || varIdx.empty()) &&
828 layout == _layout )
829 return samples;
830
831 int drows = getNTrainSamples(), dcols = getNVars();
832 Mat sidx = getTrainSampleIdx(), vidx = getVarIdx();
833 const float* src0 = samples.ptr<float>();
834 const int* sptr = !sidx.empty() ? sidx.ptr<int>() : 0;
835 const int* vptr = !vidx.empty() ? vidx.ptr<int>() : 0;
836 size_t sstep0 = samples.step/samples.elemSize();
837 size_t sstep = layout == ROW_SAMPLE ? sstep0 : 1;
838 size_t vstep = layout == ROW_SAMPLE ? 1 : sstep0;
839
840 if( _layout == COL_SAMPLE )
841 {
842 std::swap(drows, dcols);
843 std::swap(sptr, vptr);
844 std::swap(sstep, vstep);
845 }
846
847 Mat dsamples(drows, dcols, CV_32F);
848
849 for( int i = 0; i < drows; i++ )
850 {
851 const float* src = src0 + (sptr ? sptr[i] : i)*sstep;
852 float* dst = dsamples.ptr<float>(i);
853
854 for( int j = 0; j < dcols; j++ )
855 dst[j] = src[(vptr ? vptr[j] : j)*vstep];
856 }
857
858 return dsamples;
859 }
860
getValues(int vi,InputArray _sidx,float * values) const861 void getValues( int vi, InputArray _sidx, float* values ) const
862 {
863 Mat sidx = _sidx.getMat();
864 int i, n = sidx.checkVector(1, CV_32S), nsamples = getNSamples();
865 CV_Assert( 0 <= vi && vi < getNAllVars() );
866 CV_Assert( n >= 0 );
867 const int* s = n > 0 ? sidx.ptr<int>() : 0;
868 if( n == 0 )
869 n = nsamples;
870
871 size_t step = samples.step/samples.elemSize();
872 size_t sstep = layout == ROW_SAMPLE ? step : 1;
873 size_t vstep = layout == ROW_SAMPLE ? 1 : step;
874
875 const float* src = samples.ptr<float>() + vi*vstep;
876 float subst = missingSubst.at<float>(vi);
877 for( i = 0; i < n; i++ )
878 {
879 int j = i;
880 if( s )
881 {
882 j = s[i];
883 CV_Assert( 0 <= j && j < nsamples );
884 }
885 values[i] = src[j*sstep];
886 if( values[i] == MISSED_VAL )
887 values[i] = subst;
888 }
889 }
890
getNormCatValues(int vi,InputArray _sidx,int * values) const891 void getNormCatValues( int vi, InputArray _sidx, int* values ) const
892 {
893 float* fvalues = (float*)values;
894 getValues(vi, _sidx, fvalues);
895 int i, n = (int)_sidx.total();
896 Vec2i ofs = catOfs.at<Vec2i>(vi);
897 int m = ofs[1] - ofs[0];
898
899 CV_Assert( m > 0 ); // if m==0, vi is an ordered variable
900 const int* cmap = &catMap.at<int>(ofs[0]);
901 bool fastMap = (m == cmap[m - 1] - cmap[0] + 1);
902
903 if( fastMap )
904 {
905 for( i = 0; i < n; i++ )
906 {
907 int val = cvRound(fvalues[i]);
908 int idx = val - cmap[0];
909 CV_Assert(cmap[idx] == val);
910 values[i] = idx;
911 }
912 }
913 else
914 {
915 for( i = 0; i < n; i++ )
916 {
917 int val = cvRound(fvalues[i]);
918 int a = 0, b = m, c = -1;
919
920 while( a < b )
921 {
922 c = (a + b) >> 1;
923 if( val < cmap[c] )
924 b = c;
925 else if( val > cmap[c] )
926 a = c+1;
927 else
928 break;
929 }
930
931 CV_DbgAssert( c >= 0 && val == cmap[c] );
932 values[i] = c;
933 }
934 }
935 }
936
getSample(InputArray _vidx,int sidx,float * buf) const937 void getSample(InputArray _vidx, int sidx, float* buf) const
938 {
939 CV_Assert(buf != 0 && 0 <= sidx && sidx < getNSamples());
940 Mat vidx = _vidx.getMat();
941 int i, n = vidx.checkVector(1, CV_32S), nvars = getNAllVars();
942 CV_Assert( n >= 0 );
943 const int* vptr = n > 0 ? vidx.ptr<int>() : 0;
944 if( n == 0 )
945 n = nvars;
946
947 size_t step = samples.step/samples.elemSize();
948 size_t sstep = layout == ROW_SAMPLE ? step : 1;
949 size_t vstep = layout == ROW_SAMPLE ? 1 : step;
950
951 const float* src = samples.ptr<float>() + sidx*sstep;
952 for( i = 0; i < n; i++ )
953 {
954 int j = i;
955 if( vptr )
956 {
957 j = vptr[i];
958 CV_Assert( 0 <= j && j < nvars );
959 }
960 buf[i] = src[j*vstep];
961 }
962 }
963
964 FILE* file;
965 int layout;
966 Mat samples, missing, varType, varIdx, responses, missingSubst;
967 Mat sampleIdx, trainSampleIdx, testSampleIdx;
968 Mat sampleWeights, catMap, catOfs;
969 Mat normCatResponses, classLabels, classCounters;
970 MapType nameMap;
971 };
972
loadFromCSV(const String & filename,int headerLines,int responseStartIdx,int responseEndIdx,const String & varTypeSpec,char delimiter,char missch)973 Ptr<TrainData> TrainData::loadFromCSV(const String& filename,
974 int headerLines,
975 int responseStartIdx,
976 int responseEndIdx,
977 const String& varTypeSpec,
978 char delimiter, char missch)
979 {
980 Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
981 if(!td->loadCSV(filename, headerLines, responseStartIdx, responseEndIdx, varTypeSpec, delimiter, missch))
982 td.release();
983 return td;
984 }
985
create(InputArray samples,int layout,InputArray responses,InputArray varIdx,InputArray sampleIdx,InputArray sampleWeights,InputArray varType)986 Ptr<TrainData> TrainData::create(InputArray samples, int layout, InputArray responses,
987 InputArray varIdx, InputArray sampleIdx, InputArray sampleWeights,
988 InputArray varType)
989 {
990 Ptr<TrainDataImpl> td = makePtr<TrainDataImpl>();
991 td->setData(samples, layout, responses, varIdx, sampleIdx, sampleWeights, varType, noArray());
992 return td;
993 }
994
995 }}
996
997 /* End of file. */
998