• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //            Intel License Agreement
11 //
12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
13 // Third party copyrights are property of their respective owners.
14 //
15 // Redistribution and use in source and binary forms, with or without modification,
16 // are permitted provided that the following conditions are met:
17 //
18 //   * Redistribution's of source code must retain the above copyright notice,
19 //     this list of conditions and the following disclaimer.
20 //
21 //   * Redistribution's in binary form must reproduce the above copyright notice,
22 //     this list of conditions and the following disclaimer in the documentation
23 //     and/or other materials provided with the distribution.
24 //
25 //   * The name of Intel Corporation may not be used to endorse or promote products
26 //     derived from this software without specific prior written permission.
27 //
28 // This software is provided by the copyright holders and contributors "as is" and
29 // any express or implied warranties, including, but not limited to, the implied
30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
31 // In no event shall the Intel Corporation or contributors be liable for any direct,
32 // indirect, incidental, special, exemplary, or consequential damages
33 // (including, but not limited to, procurement of substitute goods or services;
34 // loss of use, data, or profits; or business interruption) however caused
35 // and on any theory of liability, whether in contract, strict liability,
36 // or tort (including negligence or otherwise) arising in any way out of
37 // the use of this software, even if advised of the possibility of such damage.
38 //
39 //M*/
40 
41 #include "precomp.hpp"
42 
43 namespace cv {
44 namespace ml {
45 
46 
47 class NormalBayesClassifierImpl : public NormalBayesClassifier
48 {
49 public:
NormalBayesClassifierImpl()50     NormalBayesClassifierImpl()
51     {
52         nallvars = 0;
53     }
54 
train(const Ptr<TrainData> & trainData,int flags)55     bool train( const Ptr<TrainData>& trainData, int flags )
56     {
57         const float min_variation = FLT_EPSILON;
58         Mat responses = trainData->getNormCatResponses();
59         Mat __cls_labels = trainData->getClassLabels();
60         Mat __var_idx = trainData->getVarIdx();
61         Mat samples = trainData->getTrainSamples();
62         int nclasses = (int)__cls_labels.total();
63 
64         int nvars = trainData->getNVars();
65         int s, c1, c2, cls;
66 
67         int __nallvars = trainData->getNAllVars();
68         bool update = (flags & UPDATE_MODEL) != 0;
69 
70         if( !update )
71         {
72             nallvars = __nallvars;
73             count.resize(nclasses);
74             sum.resize(nclasses);
75             productsum.resize(nclasses);
76             avg.resize(nclasses);
77             inv_eigen_values.resize(nclasses);
78             cov_rotate_mats.resize(nclasses);
79 
80             for( cls = 0; cls < nclasses; cls++ )
81             {
82                 count[cls]            = Mat::zeros( 1, nvars, CV_32SC1 );
83                 sum[cls]              = Mat::zeros( 1, nvars, CV_64FC1 );
84                 productsum[cls]       = Mat::zeros( nvars, nvars, CV_64FC1 );
85                 avg[cls]              = Mat::zeros( 1, nvars, CV_64FC1 );
86                 inv_eigen_values[cls] = Mat::zeros( 1, nvars, CV_64FC1 );
87                 cov_rotate_mats[cls]  = Mat::zeros( nvars, nvars, CV_64FC1 );
88             }
89 
90             var_idx = __var_idx;
91             cls_labels = __cls_labels;
92 
93             c.create(1, nclasses, CV_64FC1);
94         }
95         else
96         {
97             // check that the new training data has the same dimensionality etc.
98             if( nallvars != __nallvars ||
99                 var_idx.size() != __var_idx.size() ||
100                 norm(var_idx, __var_idx, NORM_INF) != 0 ||
101                 cls_labels.size() != __cls_labels.size() ||
102                 norm(cls_labels, __cls_labels, NORM_INF) != 0 )
103                 CV_Error( CV_StsBadArg,
104                 "The new training data is inconsistent with the original training data; varIdx and the class labels should be the same" );
105         }
106 
107         Mat cov( nvars, nvars, CV_64FC1 );
108         int nsamples = samples.rows;
109 
110         // process train data (count, sum , productsum)
111         for( s = 0; s < nsamples; s++ )
112         {
113             cls = responses.at<int>(s);
114             int* count_data = count[cls].ptr<int>();
115             double* sum_data = sum[cls].ptr<double>();
116             double* prod_data = productsum[cls].ptr<double>();
117             const float* train_vec = samples.ptr<float>(s);
118 
119             for( c1 = 0; c1 < nvars; c1++, prod_data += nvars )
120             {
121                 double val1 = train_vec[c1];
122                 sum_data[c1] += val1;
123                 count_data[c1]++;
124                 for( c2 = c1; c2 < nvars; c2++ )
125                     prod_data[c2] += train_vec[c2]*val1;
126             }
127         }
128 
129         Mat vt;
130 
131         // calculate avg, covariance matrix, c
132         for( cls = 0; cls < nclasses; cls++ )
133         {
134             double det = 1;
135             int i, j;
136             Mat& w = inv_eigen_values[cls];
137             int* count_data = count[cls].ptr<int>();
138             double* avg_data = avg[cls].ptr<double>();
139             double* sum1 = sum[cls].ptr<double>();
140 
141             completeSymm(productsum[cls], 0);
142 
143             for( j = 0; j < nvars; j++ )
144             {
145                 int n = count_data[j];
146                 avg_data[j] = n ? sum1[j] / n : 0.;
147             }
148 
149             count_data = count[cls].ptr<int>();
150             avg_data = avg[cls].ptr<double>();
151             sum1 = sum[cls].ptr<double>();
152 
153             for( i = 0; i < nvars; i++ )
154             {
155                 double* avg2_data = avg[cls].ptr<double>();
156                 double* sum2 = sum[cls].ptr<double>();
157                 double* prod_data = productsum[cls].ptr<double>(i);
158                 double* cov_data = cov.ptr<double>(i);
159                 double s1val = sum1[i];
160                 double avg1 = avg_data[i];
161                 int _count = count_data[i];
162 
163                 for( j = 0; j <= i; j++ )
164                 {
165                     double avg2 = avg2_data[j];
166                     double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * _count;
167                     cov_val = (_count > 1) ? cov_val / (_count - 1) : cov_val;
168                     cov_data[j] = cov_val;
169                 }
170             }
171 
172             completeSymm( cov, 1 );
173 
174             SVD::compute(cov, w, cov_rotate_mats[cls], noArray());
175             transpose(cov_rotate_mats[cls], cov_rotate_mats[cls]);
176             cv::max(w, min_variation, w);
177             for( j = 0; j < nvars; j++ )
178                 det *= w.at<double>(j);
179 
180             divide(1., w, w);
181             c.at<double>(cls) = det > 0 ? log(det) : -700;
182         }
183 
184         return true;
185     }
186 
187     class NBPredictBody : public ParallelLoopBody
188     {
189     public:
NBPredictBody(const Mat & _c,const vector<Mat> & _cov_rotate_mats,const vector<Mat> & _inv_eigen_values,const vector<Mat> & _avg,const Mat & _samples,const Mat & _vidx,const Mat & _cls_labels,Mat & _results,Mat & _results_prob,bool _rawOutput)190         NBPredictBody( const Mat& _c, const vector<Mat>& _cov_rotate_mats,
191                        const vector<Mat>& _inv_eigen_values,
192                        const vector<Mat>& _avg,
193                        const Mat& _samples, const Mat& _vidx, const Mat& _cls_labels,
194                        Mat& _results, Mat& _results_prob, bool _rawOutput )
195         {
196             c = &_c;
197             cov_rotate_mats = &_cov_rotate_mats;
198             inv_eigen_values = &_inv_eigen_values;
199             avg = &_avg;
200             samples = &_samples;
201             vidx = &_vidx;
202             cls_labels = &_cls_labels;
203             results = &_results;
204             results_prob = !_results_prob.empty() ? &_results_prob : 0;
205             rawOutput = _rawOutput;
206         }
207 
208         const Mat* c;
209         const vector<Mat>* cov_rotate_mats;
210         const vector<Mat>* inv_eigen_values;
211         const vector<Mat>* avg;
212         const Mat* samples;
213         const Mat* vidx;
214         const Mat* cls_labels;
215 
216         Mat* results_prob;
217         Mat* results;
218         float* value;
219         bool rawOutput;
220 
operator ()(const Range & range) const221         void operator()( const Range& range ) const
222         {
223             int cls = -1;
224             int rtype = 0, rptype = 0;
225             size_t rstep = 0, rpstep = 0;
226             int nclasses = (int)cls_labels->total();
227             int nvars = avg->at(0).cols;
228             double probability = 0;
229             const int* vptr = vidx && !vidx->empty() ? vidx->ptr<int>() : 0;
230 
231             if (results)
232             {
233                 rtype = results->type();
234                 rstep = results->isContinuous() ? 1 : results->step/results->elemSize();
235             }
236             if (results_prob)
237             {
238                 rptype = results_prob->type();
239                 rpstep = results_prob->isContinuous() ? 1 : results_prob->step/results_prob->elemSize();
240             }
241             // allocate memory and initializing headers for calculating
242             cv::AutoBuffer<double> _buffer(nvars*2);
243             double* _diffin = _buffer;
244             double* _diffout = _buffer + nvars;
245             Mat diffin( 1, nvars, CV_64FC1, _diffin );
246             Mat diffout( 1, nvars, CV_64FC1, _diffout );
247 
248             for(int k = range.start; k < range.end; k++ )
249             {
250                 double opt = FLT_MAX;
251 
252                 for(int i = 0; i < nclasses; i++ )
253                 {
254                     double cur = c->at<double>(i);
255                     const Mat& u = cov_rotate_mats->at(i);
256                     const Mat& w = inv_eigen_values->at(i);
257 
258                     const double* avg_data = avg->at(i).ptr<double>();
259                     const float* x = samples->ptr<float>(k);
260 
261                     // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
262                     for(int j = 0; j < nvars; j++ )
263                         _diffin[j] = avg_data[j] - x[vptr ? vptr[j] : j];
264 
265                     gemm( diffin, u, 1, noArray(), 0, diffout, GEMM_2_T );
266                     for(int j = 0; j < nvars; j++ )
267                     {
268                         double d = _diffout[j];
269                         cur += d*d*w.ptr<double>()[j];
270                     }
271 
272                     if( cur < opt )
273                     {
274                         cls = i;
275                         opt = cur;
276                     }
277                     probability = exp( -0.5 * cur );
278 
279                     if( results_prob )
280                     {
281                         if ( rptype == CV_32FC1 )
282                             results_prob->ptr<float>()[k*rpstep + i] = (float)probability;
283                         else
284                             results_prob->ptr<double>()[k*rpstep + i] = probability;
285                     }
286                 }
287 
288                 int ival = rawOutput ? cls : cls_labels->at<int>(cls);
289                 if( results )
290                 {
291                     if( rtype == CV_32SC1 )
292                         results->ptr<int>()[k*rstep] = ival;
293                     else
294                         results->ptr<float>()[k*rstep] = (float)ival;
295                 }
296             }
297         }
298     };
299 
predict(InputArray _samples,OutputArray _results,int flags) const300     float predict( InputArray _samples, OutputArray _results, int flags ) const
301     {
302         return predictProb(_samples, _results, noArray(), flags);
303     }
304 
predictProb(InputArray _samples,OutputArray _results,OutputArray _resultsProb,int flags) const305     float predictProb( InputArray _samples, OutputArray _results, OutputArray _resultsProb, int flags ) const
306     {
307         int value=0;
308         Mat samples = _samples.getMat(), results, resultsProb;
309         int nsamples = samples.rows, nclasses = (int)cls_labels.total();
310         bool rawOutput = (flags & RAW_OUTPUT) != 0;
311 
312         if( samples.type() != CV_32F || samples.cols != nallvars )
313             CV_Error( CV_StsBadArg,
314                      "The input samples must be 32f matrix with the number of columns = nallvars" );
315 
316         if( samples.rows > 1 && _results.needed() )
317             CV_Error( CV_StsNullPtr,
318                      "When the number of input samples is >1, the output vector of results must be passed" );
319 
320         if( _results.needed() )
321         {
322             _results.create(nsamples, 1, CV_32S);
323             results = _results.getMat();
324         }
325         else
326             results = Mat(1, 1, CV_32S, &value);
327 
328         if( _resultsProb.needed() )
329         {
330             _resultsProb.create(nsamples, nclasses, CV_32F);
331             resultsProb = _resultsProb.getMat();
332         }
333 
334         cv::parallel_for_(cv::Range(0, nsamples),
335                           NBPredictBody(c, cov_rotate_mats, inv_eigen_values, avg, samples,
336                                        var_idx, cls_labels, results, resultsProb, rawOutput));
337 
338         return (float)value;
339     }
340 
write(FileStorage & fs) const341     void write( FileStorage& fs ) const
342     {
343         int nclasses = (int)cls_labels.total(), i;
344 
345         fs << "var_count" << (var_idx.empty() ? nallvars : (int)var_idx.total());
346         fs << "var_all" << nallvars;
347 
348         if( !var_idx.empty() )
349             fs << "var_idx" << var_idx;
350         fs << "cls_labels" << cls_labels;
351 
352         fs << "count" << "[";
353         for( i = 0; i < nclasses; i++ )
354             fs << count[i];
355 
356         fs << "]" << "sum" << "[";
357         for( i = 0; i < nclasses; i++ )
358             fs << sum[i];
359 
360         fs << "]" << "productsum" << "[";
361         for( i = 0; i < nclasses; i++ )
362             fs << productsum[i];
363 
364         fs << "]" << "avg" << "[";
365         for( i = 0; i < nclasses; i++ )
366             fs << avg[i];
367 
368         fs << "]" << "inv_eigen_values" << "[";
369         for( i = 0; i < nclasses; i++ )
370             fs << inv_eigen_values[i];
371 
372         fs << "]" << "cov_rotate_mats" << "[";
373         for( i = 0; i < nclasses; i++ )
374             fs << cov_rotate_mats[i];
375 
376         fs << "]";
377 
378         fs << "c" << c;
379     }
380 
read(const FileNode & fn)381     void read( const FileNode& fn )
382     {
383         clear();
384 
385         fn["var_all"] >> nallvars;
386 
387         if( nallvars <= 0 )
388             CV_Error( CV_StsParseError,
389                      "The field \"var_count\" of NBayes classifier is missing or non-positive" );
390 
391         fn["var_idx"] >> var_idx;
392         fn["cls_labels"] >> cls_labels;
393 
394         int nclasses = (int)cls_labels.total(), i;
395 
396         if( cls_labels.empty() || nclasses < 1 )
397             CV_Error( CV_StsParseError, "No or invalid \"cls_labels\" in NBayes classifier" );
398 
399         FileNodeIterator
400             count_it = fn["count"].begin(),
401             sum_it = fn["sum"].begin(),
402             productsum_it = fn["productsum"].begin(),
403             avg_it = fn["avg"].begin(),
404             inv_eigen_values_it = fn["inv_eigen_values"].begin(),
405             cov_rotate_mats_it = fn["cov_rotate_mats"].begin();
406 
407         count.resize(nclasses);
408         sum.resize(nclasses);
409         productsum.resize(nclasses);
410         avg.resize(nclasses);
411         inv_eigen_values.resize(nclasses);
412         cov_rotate_mats.resize(nclasses);
413 
414         for( i = 0; i < nclasses; i++, ++count_it, ++sum_it, ++productsum_it, ++avg_it,
415                                     ++inv_eigen_values_it, ++cov_rotate_mats_it )
416         {
417             *count_it >> count[i];
418             *sum_it >> sum[i];
419             *productsum_it >> productsum[i];
420             *avg_it >> avg[i];
421             *inv_eigen_values_it >> inv_eigen_values[i];
422             *cov_rotate_mats_it >> cov_rotate_mats[i];
423         }
424 
425         fn["c"] >> c;
426     }
427 
clear()428     void clear()
429     {
430         count.clear();
431         sum.clear();
432         productsum.clear();
433         avg.clear();
434         inv_eigen_values.clear();
435         cov_rotate_mats.clear();
436 
437         var_idx.release();
438         cls_labels.release();
439         c.release();
440         nallvars = 0;
441     }
442 
isTrained() const443     bool isTrained() const { return !avg.empty(); }
isClassifier() const444     bool isClassifier() const { return true; }
getVarCount() const445     int getVarCount() const { return nallvars; }
getDefaultName() const446     String getDefaultName() const { return "opencv_ml_nbayes"; }
447 
448     int nallvars;
449     Mat var_idx, cls_labels, c;
450     vector<Mat> count, sum, productsum, avg, inv_eigen_values, cov_rotate_mats;
451 };
452 
453 
create()454 Ptr<NormalBayesClassifier> NormalBayesClassifier::create()
455 {
456     Ptr<NormalBayesClassifierImpl> p = makePtr<NormalBayesClassifierImpl>();
457     return p;
458 }
459 
460 }
461 }
462 
463 /* End of file. */
464