1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 // * Redistribution's of source code must retain the above copyright notice,
19 // this list of conditions and the following disclaimer.
20 //
21 // * Redistribution's in binary form must reproduce the above copyright notice,
22 // this list of conditions and the following disclaimer in the documentation
23 // and/or other materials provided with the distribution.
24 //
25 // * The name of Intel Corporation may not be used to endorse or promote products
26 // derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40
41 #include "precomp.hpp"
42
43 namespace cv {
44 namespace ml {
45
46
47 class NormalBayesClassifierImpl : public NormalBayesClassifier
48 {
49 public:
NormalBayesClassifierImpl()50 NormalBayesClassifierImpl()
51 {
52 nallvars = 0;
53 }
54
train(const Ptr<TrainData> & trainData,int flags)55 bool train( const Ptr<TrainData>& trainData, int flags )
56 {
57 const float min_variation = FLT_EPSILON;
58 Mat responses = trainData->getNormCatResponses();
59 Mat __cls_labels = trainData->getClassLabels();
60 Mat __var_idx = trainData->getVarIdx();
61 Mat samples = trainData->getTrainSamples();
62 int nclasses = (int)__cls_labels.total();
63
64 int nvars = trainData->getNVars();
65 int s, c1, c2, cls;
66
67 int __nallvars = trainData->getNAllVars();
68 bool update = (flags & UPDATE_MODEL) != 0;
69
70 if( !update )
71 {
72 nallvars = __nallvars;
73 count.resize(nclasses);
74 sum.resize(nclasses);
75 productsum.resize(nclasses);
76 avg.resize(nclasses);
77 inv_eigen_values.resize(nclasses);
78 cov_rotate_mats.resize(nclasses);
79
80 for( cls = 0; cls < nclasses; cls++ )
81 {
82 count[cls] = Mat::zeros( 1, nvars, CV_32SC1 );
83 sum[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
84 productsum[cls] = Mat::zeros( nvars, nvars, CV_64FC1 );
85 avg[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
86 inv_eigen_values[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
87 cov_rotate_mats[cls] = Mat::zeros( nvars, nvars, CV_64FC1 );
88 }
89
90 var_idx = __var_idx;
91 cls_labels = __cls_labels;
92
93 c.create(1, nclasses, CV_64FC1);
94 }
95 else
96 {
97 // check that the new training data has the same dimensionality etc.
98 if( nallvars != __nallvars ||
99 var_idx.size() != __var_idx.size() ||
100 norm(var_idx, __var_idx, NORM_INF) != 0 ||
101 cls_labels.size() != __cls_labels.size() ||
102 norm(cls_labels, __cls_labels, NORM_INF) != 0 )
103 CV_Error( CV_StsBadArg,
104 "The new training data is inconsistent with the original training data; varIdx and the class labels should be the same" );
105 }
106
107 Mat cov( nvars, nvars, CV_64FC1 );
108 int nsamples = samples.rows;
109
110 // process train data (count, sum , productsum)
111 for( s = 0; s < nsamples; s++ )
112 {
113 cls = responses.at<int>(s);
114 int* count_data = count[cls].ptr<int>();
115 double* sum_data = sum[cls].ptr<double>();
116 double* prod_data = productsum[cls].ptr<double>();
117 const float* train_vec = samples.ptr<float>(s);
118
119 for( c1 = 0; c1 < nvars; c1++, prod_data += nvars )
120 {
121 double val1 = train_vec[c1];
122 sum_data[c1] += val1;
123 count_data[c1]++;
124 for( c2 = c1; c2 < nvars; c2++ )
125 prod_data[c2] += train_vec[c2]*val1;
126 }
127 }
128
129 Mat vt;
130
131 // calculate avg, covariance matrix, c
132 for( cls = 0; cls < nclasses; cls++ )
133 {
134 double det = 1;
135 int i, j;
136 Mat& w = inv_eigen_values[cls];
137 int* count_data = count[cls].ptr<int>();
138 double* avg_data = avg[cls].ptr<double>();
139 double* sum1 = sum[cls].ptr<double>();
140
141 completeSymm(productsum[cls], 0);
142
143 for( j = 0; j < nvars; j++ )
144 {
145 int n = count_data[j];
146 avg_data[j] = n ? sum1[j] / n : 0.;
147 }
148
149 count_data = count[cls].ptr<int>();
150 avg_data = avg[cls].ptr<double>();
151 sum1 = sum[cls].ptr<double>();
152
153 for( i = 0; i < nvars; i++ )
154 {
155 double* avg2_data = avg[cls].ptr<double>();
156 double* sum2 = sum[cls].ptr<double>();
157 double* prod_data = productsum[cls].ptr<double>(i);
158 double* cov_data = cov.ptr<double>(i);
159 double s1val = sum1[i];
160 double avg1 = avg_data[i];
161 int _count = count_data[i];
162
163 for( j = 0; j <= i; j++ )
164 {
165 double avg2 = avg2_data[j];
166 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * _count;
167 cov_val = (_count > 1) ? cov_val / (_count - 1) : cov_val;
168 cov_data[j] = cov_val;
169 }
170 }
171
172 completeSymm( cov, 1 );
173
174 SVD::compute(cov, w, cov_rotate_mats[cls], noArray());
175 transpose(cov_rotate_mats[cls], cov_rotate_mats[cls]);
176 cv::max(w, min_variation, w);
177 for( j = 0; j < nvars; j++ )
178 det *= w.at<double>(j);
179
180 divide(1., w, w);
181 c.at<double>(cls) = det > 0 ? log(det) : -700;
182 }
183
184 return true;
185 }
186
187 class NBPredictBody : public ParallelLoopBody
188 {
189 public:
NBPredictBody(const Mat & _c,const vector<Mat> & _cov_rotate_mats,const vector<Mat> & _inv_eigen_values,const vector<Mat> & _avg,const Mat & _samples,const Mat & _vidx,const Mat & _cls_labels,Mat & _results,Mat & _results_prob,bool _rawOutput)190 NBPredictBody( const Mat& _c, const vector<Mat>& _cov_rotate_mats,
191 const vector<Mat>& _inv_eigen_values,
192 const vector<Mat>& _avg,
193 const Mat& _samples, const Mat& _vidx, const Mat& _cls_labels,
194 Mat& _results, Mat& _results_prob, bool _rawOutput )
195 {
196 c = &_c;
197 cov_rotate_mats = &_cov_rotate_mats;
198 inv_eigen_values = &_inv_eigen_values;
199 avg = &_avg;
200 samples = &_samples;
201 vidx = &_vidx;
202 cls_labels = &_cls_labels;
203 results = &_results;
204 results_prob = !_results_prob.empty() ? &_results_prob : 0;
205 rawOutput = _rawOutput;
206 }
207
208 const Mat* c;
209 const vector<Mat>* cov_rotate_mats;
210 const vector<Mat>* inv_eigen_values;
211 const vector<Mat>* avg;
212 const Mat* samples;
213 const Mat* vidx;
214 const Mat* cls_labels;
215
216 Mat* results_prob;
217 Mat* results;
218 float* value;
219 bool rawOutput;
220
operator ()(const Range & range) const221 void operator()( const Range& range ) const
222 {
223 int cls = -1;
224 int rtype = 0, rptype = 0;
225 size_t rstep = 0, rpstep = 0;
226 int nclasses = (int)cls_labels->total();
227 int nvars = avg->at(0).cols;
228 double probability = 0;
229 const int* vptr = vidx && !vidx->empty() ? vidx->ptr<int>() : 0;
230
231 if (results)
232 {
233 rtype = results->type();
234 rstep = results->isContinuous() ? 1 : results->step/results->elemSize();
235 }
236 if (results_prob)
237 {
238 rptype = results_prob->type();
239 rpstep = results_prob->isContinuous() ? 1 : results_prob->step/results_prob->elemSize();
240 }
241 // allocate memory and initializing headers for calculating
242 cv::AutoBuffer<double> _buffer(nvars*2);
243 double* _diffin = _buffer;
244 double* _diffout = _buffer + nvars;
245 Mat diffin( 1, nvars, CV_64FC1, _diffin );
246 Mat diffout( 1, nvars, CV_64FC1, _diffout );
247
248 for(int k = range.start; k < range.end; k++ )
249 {
250 double opt = FLT_MAX;
251
252 for(int i = 0; i < nclasses; i++ )
253 {
254 double cur = c->at<double>(i);
255 const Mat& u = cov_rotate_mats->at(i);
256 const Mat& w = inv_eigen_values->at(i);
257
258 const double* avg_data = avg->at(i).ptr<double>();
259 const float* x = samples->ptr<float>(k);
260
261 // cov = u w u' --> cov^(-1) = u w^(-1) u'
262 for(int j = 0; j < nvars; j++ )
263 _diffin[j] = avg_data[j] - x[vptr ? vptr[j] : j];
264
265 gemm( diffin, u, 1, noArray(), 0, diffout, GEMM_2_T );
266 for(int j = 0; j < nvars; j++ )
267 {
268 double d = _diffout[j];
269 cur += d*d*w.ptr<double>()[j];
270 }
271
272 if( cur < opt )
273 {
274 cls = i;
275 opt = cur;
276 }
277 probability = exp( -0.5 * cur );
278
279 if( results_prob )
280 {
281 if ( rptype == CV_32FC1 )
282 results_prob->ptr<float>()[k*rpstep + i] = (float)probability;
283 else
284 results_prob->ptr<double>()[k*rpstep + i] = probability;
285 }
286 }
287
288 int ival = rawOutput ? cls : cls_labels->at<int>(cls);
289 if( results )
290 {
291 if( rtype == CV_32SC1 )
292 results->ptr<int>()[k*rstep] = ival;
293 else
294 results->ptr<float>()[k*rstep] = (float)ival;
295 }
296 }
297 }
298 };
299
predict(InputArray _samples,OutputArray _results,int flags) const300 float predict( InputArray _samples, OutputArray _results, int flags ) const
301 {
302 return predictProb(_samples, _results, noArray(), flags);
303 }
304
predictProb(InputArray _samples,OutputArray _results,OutputArray _resultsProb,int flags) const305 float predictProb( InputArray _samples, OutputArray _results, OutputArray _resultsProb, int flags ) const
306 {
307 int value=0;
308 Mat samples = _samples.getMat(), results, resultsProb;
309 int nsamples = samples.rows, nclasses = (int)cls_labels.total();
310 bool rawOutput = (flags & RAW_OUTPUT) != 0;
311
312 if( samples.type() != CV_32F || samples.cols != nallvars )
313 CV_Error( CV_StsBadArg,
314 "The input samples must be 32f matrix with the number of columns = nallvars" );
315
316 if( samples.rows > 1 && _results.needed() )
317 CV_Error( CV_StsNullPtr,
318 "When the number of input samples is >1, the output vector of results must be passed" );
319
320 if( _results.needed() )
321 {
322 _results.create(nsamples, 1, CV_32S);
323 results = _results.getMat();
324 }
325 else
326 results = Mat(1, 1, CV_32S, &value);
327
328 if( _resultsProb.needed() )
329 {
330 _resultsProb.create(nsamples, nclasses, CV_32F);
331 resultsProb = _resultsProb.getMat();
332 }
333
334 cv::parallel_for_(cv::Range(0, nsamples),
335 NBPredictBody(c, cov_rotate_mats, inv_eigen_values, avg, samples,
336 var_idx, cls_labels, results, resultsProb, rawOutput));
337
338 return (float)value;
339 }
340
write(FileStorage & fs) const341 void write( FileStorage& fs ) const
342 {
343 int nclasses = (int)cls_labels.total(), i;
344
345 fs << "var_count" << (var_idx.empty() ? nallvars : (int)var_idx.total());
346 fs << "var_all" << nallvars;
347
348 if( !var_idx.empty() )
349 fs << "var_idx" << var_idx;
350 fs << "cls_labels" << cls_labels;
351
352 fs << "count" << "[";
353 for( i = 0; i < nclasses; i++ )
354 fs << count[i];
355
356 fs << "]" << "sum" << "[";
357 for( i = 0; i < nclasses; i++ )
358 fs << sum[i];
359
360 fs << "]" << "productsum" << "[";
361 for( i = 0; i < nclasses; i++ )
362 fs << productsum[i];
363
364 fs << "]" << "avg" << "[";
365 for( i = 0; i < nclasses; i++ )
366 fs << avg[i];
367
368 fs << "]" << "inv_eigen_values" << "[";
369 for( i = 0; i < nclasses; i++ )
370 fs << inv_eigen_values[i];
371
372 fs << "]" << "cov_rotate_mats" << "[";
373 for( i = 0; i < nclasses; i++ )
374 fs << cov_rotate_mats[i];
375
376 fs << "]";
377
378 fs << "c" << c;
379 }
380
read(const FileNode & fn)381 void read( const FileNode& fn )
382 {
383 clear();
384
385 fn["var_all"] >> nallvars;
386
387 if( nallvars <= 0 )
388 CV_Error( CV_StsParseError,
389 "The field \"var_count\" of NBayes classifier is missing or non-positive" );
390
391 fn["var_idx"] >> var_idx;
392 fn["cls_labels"] >> cls_labels;
393
394 int nclasses = (int)cls_labels.total(), i;
395
396 if( cls_labels.empty() || nclasses < 1 )
397 CV_Error( CV_StsParseError, "No or invalid \"cls_labels\" in NBayes classifier" );
398
399 FileNodeIterator
400 count_it = fn["count"].begin(),
401 sum_it = fn["sum"].begin(),
402 productsum_it = fn["productsum"].begin(),
403 avg_it = fn["avg"].begin(),
404 inv_eigen_values_it = fn["inv_eigen_values"].begin(),
405 cov_rotate_mats_it = fn["cov_rotate_mats"].begin();
406
407 count.resize(nclasses);
408 sum.resize(nclasses);
409 productsum.resize(nclasses);
410 avg.resize(nclasses);
411 inv_eigen_values.resize(nclasses);
412 cov_rotate_mats.resize(nclasses);
413
414 for( i = 0; i < nclasses; i++, ++count_it, ++sum_it, ++productsum_it, ++avg_it,
415 ++inv_eigen_values_it, ++cov_rotate_mats_it )
416 {
417 *count_it >> count[i];
418 *sum_it >> sum[i];
419 *productsum_it >> productsum[i];
420 *avg_it >> avg[i];
421 *inv_eigen_values_it >> inv_eigen_values[i];
422 *cov_rotate_mats_it >> cov_rotate_mats[i];
423 }
424
425 fn["c"] >> c;
426 }
427
clear()428 void clear()
429 {
430 count.clear();
431 sum.clear();
432 productsum.clear();
433 avg.clear();
434 inv_eigen_values.clear();
435 cov_rotate_mats.clear();
436
437 var_idx.release();
438 cls_labels.release();
439 c.release();
440 nallvars = 0;
441 }
442
isTrained() const443 bool isTrained() const { return !avg.empty(); }
isClassifier() const444 bool isClassifier() const { return true; }
getVarCount() const445 int getVarCount() const { return nallvars; }
getDefaultName() const446 String getDefaultName() const { return "opencv_ml_nbayes"; }
447
448 int nallvars;
449 Mat var_idx, cls_labels, c;
450 vector<Mat> count, sum, productsum, avg, inv_eigen_values, cov_rotate_mats;
451 };
452
453
create()454 Ptr<NormalBayesClassifier> NormalBayesClassifier::create()
455 {
456 Ptr<NormalBayesClassifierImpl> p = makePtr<NormalBayesClassifierImpl>();
457 return p;
458 }
459
460 }
461 }
462
463 /* End of file. */
464