• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                           License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 //   * Redistribution's of source code must retain the above copyright notice,
21 //     this list of conditions and the following disclaimer.
22 //
23 //   * Redistribution's in binary form must reproduce the above copyright notice,
24 //     this list of conditions and the following disclaimer in the documentation
25 //     and/or other materials provided with the distribution.
26 //
27 //   * The name of the copyright holders may not be used to endorse or promote products
28 //     derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42 
43 #include "precomp.hpp"
44 #include "kdtree.hpp"
45 
46 /****************************************************************************************\
47 *                              K-Nearest Neighbors Classifier                            *
48 \****************************************************************************************/
49 
50 namespace cv {
51 namespace ml {
52 
// Model identifiers reported by getModelName()/getDefaultName() and matched
// against the FileNode name in KNearestImpl::read() to pick the back-end.
const String NAME_BRUTE_FORCE = "opencv_ml_knn";
const String NAME_KDTREE = "opencv_ml_knn_kd";
55 
56 class Impl
57 {
58 public:
Impl()59     Impl()
60     {
61         defaultK = 10;
62         isclassifier = true;
63         Emax = INT_MAX;
64     }
65 
~Impl()66     virtual ~Impl() {}
67     virtual String getModelName() const = 0;
68     virtual int getType() const = 0;
69     virtual float findNearest( InputArray _samples, int k,
70                                OutputArray _results,
71                                OutputArray _neighborResponses,
72                                OutputArray _dists ) const = 0;
73 
train(const Ptr<TrainData> & data,int flags)74     bool train( const Ptr<TrainData>& data, int flags )
75     {
76         Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
77         Mat new_responses;
78         data->getTrainResponses().convertTo(new_responses, CV_32F);
79         bool update = (flags & ml::KNearest::UPDATE_MODEL) != 0 && !samples.empty();
80 
81         CV_Assert( new_samples.type() == CV_32F );
82 
83         if( !update )
84         {
85             clear();
86         }
87         else
88         {
89             CV_Assert( new_samples.cols == samples.cols &&
90                        new_responses.cols == responses.cols );
91         }
92 
93         samples.push_back(new_samples);
94         responses.push_back(new_responses);
95 
96         doTrain(samples);
97 
98         return true;
99     }
100 
doTrain(InputArray points)101     virtual void doTrain(InputArray points) { (void)points; }
102 
clear()103     void clear()
104     {
105         samples.release();
106         responses.release();
107     }
108 
read(const FileNode & fn)109     void read( const FileNode& fn )
110     {
111         clear();
112         isclassifier = (int)fn["is_classifier"] != 0;
113         defaultK = (int)fn["default_k"];
114 
115         fn["samples"] >> samples;
116         fn["responses"] >> responses;
117     }
118 
write(FileStorage & fs) const119     void write( FileStorage& fs ) const
120     {
121         fs << "is_classifier" << (int)isclassifier;
122         fs << "default_k" << defaultK;
123 
124         fs << "samples" << samples;
125         fs << "responses" << responses;
126     }
127 
128 public:
129     int defaultK;
130     bool isclassifier;
131     int Emax;
132 
133     Mat samples;
134     Mat responses;
135 };
136 
// Exhaustive-search back-end: compares every test sample against every
// stored training sample using squared Euclidean distance.
class BruteForceImpl : public Impl
{
public:
    String getModelName() const { return NAME_BRUTE_FORCE; }
    int getType() const { return ml::KNearest::BRUTE_FORCE; }

    // Finds the k nearest training samples for test rows [range.start, range.end)
    // of _samples, writing (when the pointers are non-null) per-row results,
    // neighbor responses and squared distances. k0 is the caller-requested k;
    // the effective k is clamped to the number of stored samples. *presult
    // receives the response computed for global test row 0 (if in range).
    void findNearestCore( const Mat& _samples, int k0, const Range& range,
                          Mat* results, Mat* neighbor_responses,
                          Mat* dists, float* presult ) const
    {
        int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
        int testcount = range.end - range.start;
        int k = std::min(k0, nsamples);

        // Two scratch arrays of testcount*k floats each: sorted best distances
        // (dbuf) and the corresponding responses (rbuf).
        AutoBuffer<float> buf(testcount*k*2);
        float* dbuf = buf;
        float* rbuf = dbuf + testcount*k;

        const float* rptr = responses.ptr<float>();

        // Initialize every candidate slot to "worse than anything".
        for( testidx = 0; testidx < testcount; testidx++ )
        {
            for( i = 0; i < k; i++ )
            {
                dbuf[testidx*k + i] = FLT_MAX;
                rbuf[testidx*k + i] = 0.f;
            }
        }

        for( baseidx = 0; baseidx < nsamples; baseidx++ )
        {
            for( testidx = 0; testidx < testcount; testidx++ )
            {
                const float* v = samples.ptr<float>(baseidx);
                const float* u = _samples.ptr<float>(testidx + range.start);

                // Squared Euclidean distance, manually unrolled by 4.
                float s = 0;
                for( i = 0; i <= d - 4; i += 4 )
                {
                    float t0 = u[i] - v[i], t1 = u[i+1] - v[i+1];
                    float t2 = u[i+2] - v[i+2], t3 = u[i+3] - v[i+3];
                    s += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }

                for( ; i < d; i++ )
                {
                    float t0 = u[i] - v[i];
                    s += t0*t0;
                }

                // Compare distances through their IEEE-754 bit patterns:
                // for non-negative floats (distances are >= 0) the integer
                // ordering of the bits matches the float ordering, so the
                // insertion below can use cheap integer comparisons.
                Cv32suf si;
                si.f = (float)s;
                Cv32suf* dd = (Cv32suf*)(&dbuf[testidx*k]);
                float* nr = &rbuf[testidx*k];

                // Find the insertion position i in the sorted k-best list.
                for( i = k; i > 0; i-- )
                    if( si.i >= dd[i-1].i )
                        break;
                if( i >= k )
                    continue;  // not better than the current worst — skip

                // Shift worse entries down and insert the new candidate.
                for( j = k-2; j >= i; j-- )
                {
                    dd[j+1].i = dd[j].i;
                    nr[j+1] = nr[j];
                }
                dd[i].i = si.i;
                nr[i] = rptr[baseidx];
            }
        }

        float result = 0.f;
        float inv_scale = 1.f/k;

        // Emit the collected neighbors and compute the per-row response.
        for( testidx = 0; testidx < testcount; testidx++ )
        {
            if( neighbor_responses )
            {
                float* nr = neighbor_responses->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    nr[j] = rbuf[testidx*k + j];
                // Pad with zeros when the requested k0 exceeds nsamples.
                for( ; j < k0; j++ )
                    nr[j] = 0.f;
            }

            if( dists )
            {
                float* dptr = dists->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    dptr[j] = dbuf[testidx*k + j];
                for( ; j < k0; j++ )
                    dptr[j] = 0.f;
            }

            // The response for global row 0 is always computed so that
            // *presult can be filled even when no results matrix was asked for.
            if( results || testidx+range.start == 0 )
            {
                if( !isclassifier || k == 1 )
                {
                    // Regression (or trivial k==1): mean of neighbor responses.
                    float s = 0.f;
                    for( j = 0; j < k; j++ )
                        s += rbuf[testidx*k + j];
                    result = (float)(s*inv_scale);
                }
                else
                {
                    // Classification: majority vote. Bubble-sort the k
                    // responses (k is small), then scan equal runs for the
                    // longest one.
                    float* rp = rbuf + testidx*k;
                    for( j = k-1; j > 0; j-- )
                    {
                        bool swap_fl = false;
                        for( i = 0; i < j; i++ )
                        {
                            if( rp[i] > rp[i+1] )
                            {
                                std::swap(rp[i], rp[i+1]);
                                swap_fl = true;
                            }
                        }
                        if( !swap_fl )
                            break;
                    }

                    result = rp[0];
                    int prev_start = 0;
                    int best_count = 0;
                    for( j = 1; j <= k; j++ )
                    {
                        if( j == k || rp[j] != rp[j-1] )
                        {
                            int count = j - prev_start;
                            if( best_count < count )
                            {
                                best_count = count;
                                result = rp[j-1];
                            }
                            prev_start = j;
                        }
                    }
                }
                if( results )
                    results->at<float>(testidx + range.start) = result;
                if( presult && testidx+range.start == 0 )
                    *presult = result;
            }
        }
    }

    // Parallel driver: splits the test-row range across threads and forwards
    // each sub-range (in chunks of at most 256 rows) to findNearestCore.
    struct findKNearestInvoker : public ParallelLoopBody
    {
        findKNearestInvoker(const BruteForceImpl* _p, int _k, const Mat& __samples,
                            Mat* __results, Mat* __neighbor_responses, Mat* __dists, float* _presult)
        {
            p = _p;
            k = _k;
            _samples = &__samples;
            _results = __results;
            _neighbor_responses = __neighbor_responses;
            _dists = __dists;
            presult = _presult;
        }

        void operator()( const Range& range ) const
        {
            // Cap chunk size at 256 rows to bound the scratch-buffer size
            // allocated inside findNearestCore.
            int delta = std::min(range.end - range.start, 256);
            for( int start = range.start; start < range.end; start += delta )
            {
                p->findNearestCore( *_samples, k, Range(start, std::min(start + delta, range.end)),
                                    _results, _neighbor_responses, _dists, presult );
            }
        }

        const BruteForceImpl* p;
        int k;
        const Mat* _samples;
        Mat* _results;
        Mat* _neighbor_responses;
        Mat* _dists;
        float* presult;  // written only by the chunk containing row 0
    };

    // Public entry point: validates inputs, allocates the requested output
    // matrices, and runs findKNearestInvoker over all test rows. Returns the
    // response computed for the first test sample.
    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const
    {
        float result = 0.f;
        CV_Assert( 0 < k );

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        // Null pointers signal "output not requested" to findNearestCore.
        Mat res, nr, d, *pres = 0, *pnr = 0, *pd = 0;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            pres = &(res = _results.getMat());
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            pnr = &(nr = _neighborResponses.getMat());
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            pd = &(d = _dists.getMat());
        }

        findKNearestInvoker invoker(this, k, test_samples, pres, pnr, pd, &result);
        parallel_for_(Range(0, testcount), invoker);
        //invoker(Range(0, testcount));
        return result;
    }
};
359 
360 
361 class KDTreeImpl : public Impl
362 {
363 public:
getModelName() const364     String getModelName() const { return NAME_KDTREE; }
getType() const365     int getType() const { return ml::KNearest::KDTREE; }
366 
doTrain(InputArray points)367     void doTrain(InputArray points)
368     {
369         tr.build(points);
370     }
371 
findNearest(InputArray _samples,int k,OutputArray _results,OutputArray _neighborResponses,OutputArray _dists) const372     float findNearest( InputArray _samples, int k,
373                        OutputArray _results,
374                        OutputArray _neighborResponses,
375                        OutputArray _dists ) const
376     {
377         float result = 0.f;
378         CV_Assert( 0 < k );
379 
380         Mat test_samples = _samples.getMat();
381         CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
382         int testcount = test_samples.rows;
383 
384         if( testcount == 0 )
385         {
386             _results.release();
387             _neighborResponses.release();
388             _dists.release();
389             return 0.f;
390         }
391 
392         Mat res, nr, d;
393         if( _results.needed() )
394         {
395             _results.create(testcount, 1, CV_32F);
396             res = _results.getMat();
397         }
398         if( _neighborResponses.needed() )
399         {
400             _neighborResponses.create(testcount, k, CV_32F);
401             nr = _neighborResponses.getMat();
402         }
403         if( _dists.needed() )
404         {
405             _dists.create(testcount, k, CV_32F);
406             d = _dists.getMat();
407         }
408 
409         for (int i=0; i<test_samples.rows; ++i)
410         {
411             Mat _res, _nr, _d;
412             if (res.rows>i)
413             {
414                 _res = res.row(i);
415             }
416             if (nr.rows>i)
417             {
418                 _nr = nr.row(i);
419             }
420             if (d.rows>i)
421             {
422                 _d = d.row(i);
423             }
424             tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray());
425         }
426 
427         return result; // currently always 0
428     }
429 
430     KDTree tr;
431 };
432 
433 //================================================================
434 
435 class KNearestImpl : public KNearest
436 {
437     CV_IMPL_PROPERTY(int, DefaultK, impl->defaultK)
438     CV_IMPL_PROPERTY(bool, IsClassifier, impl->isclassifier)
439     CV_IMPL_PROPERTY(int, Emax, impl->Emax)
440 
441 public:
getAlgorithmType() const442     int getAlgorithmType() const
443     {
444         return impl->getType();
445     }
setAlgorithmType(int val)446     void setAlgorithmType(int val)
447     {
448         if (val != BRUTE_FORCE && val != KDTREE)
449             val = BRUTE_FORCE;
450         initImpl(val);
451     }
452 
453 public:
KNearestImpl()454     KNearestImpl()
455     {
456         initImpl(BRUTE_FORCE);
457     }
~KNearestImpl()458     ~KNearestImpl()
459     {
460     }
461 
isClassifier() const462     bool isClassifier() const { return impl->isclassifier; }
isTrained() const463     bool isTrained() const { return !impl->samples.empty(); }
464 
getVarCount() const465     int getVarCount() const { return impl->samples.cols; }
466 
write(FileStorage & fs) const467     void write( FileStorage& fs ) const
468     {
469         impl->write(fs);
470     }
471 
read(const FileNode & fn)472     void read( const FileNode& fn )
473     {
474         int algorithmType = BRUTE_FORCE;
475         if (fn.name() == NAME_KDTREE)
476             algorithmType = KDTREE;
477         initImpl(algorithmType);
478         impl->read(fn);
479     }
480 
findNearest(InputArray samples,int k,OutputArray results,OutputArray neighborResponses=noArray (),OutputArray dist=noArray ()) const481     float findNearest( InputArray samples, int k,
482                        OutputArray results,
483                        OutputArray neighborResponses=noArray(),
484                        OutputArray dist=noArray() ) const
485     {
486         return impl->findNearest(samples, k, results, neighborResponses, dist);
487     }
488 
predict(InputArray inputs,OutputArray outputs,int) const489     float predict(InputArray inputs, OutputArray outputs, int) const
490     {
491         return impl->findNearest( inputs, impl->defaultK, outputs, noArray(), noArray() );
492     }
493 
train(const Ptr<TrainData> & data,int flags)494     bool train( const Ptr<TrainData>& data, int flags )
495     {
496         return impl->train(data, flags);
497     }
498 
getDefaultName() const499     String getDefaultName() const { return impl->getModelName(); }
500 
501 protected:
initImpl(int algorithmType)502     void initImpl(int algorithmType)
503     {
504         if (algorithmType != KDTREE)
505             impl = makePtr<BruteForceImpl>();
506         else
507             impl = makePtr<KDTreeImpl>();
508     }
509     Ptr<Impl> impl;
510 };
511 
// Factory: constructs a KNearest model with the default brute-force
// back-end (switchable afterwards via setAlgorithmType()).
Ptr<KNearest> KNearest::create()
{
    return makePtr<KNearestImpl>();
}

}
}

/* End of file */
521