1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 // * Redistribution's of source code must retain the above copyright notice,
21 // this list of conditions and the following disclaimer.
22 //
23 // * Redistribution's in binary form must reproduce the above copyright notice,
24 // this list of conditions and the following disclaimer in the documentation
25 // and/or other materials provided with the distribution.
26 //
27 // * The name of the copyright holders may not be used to endorse or promote products
28 // derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "precomp.hpp"
44 #include "kdtree.hpp"
45
46 /****************************************************************************************\
47 * K-Nearest Neighbors Classifier *
48 \****************************************************************************************/
49
50 namespace cv {
51 namespace ml {
52
// Serialization node names for the two back-ends. getDefaultName() returns
// these, and KNearestImpl::read() uses the node name to pick the back-end.
const String NAME_BRUTE_FORCE = "opencv_ml_knn";
const String NAME_KDTREE = "opencv_ml_knn_kd";
55
56 class Impl
57 {
58 public:
Impl()59 Impl()
60 {
61 defaultK = 10;
62 isclassifier = true;
63 Emax = INT_MAX;
64 }
65
~Impl()66 virtual ~Impl() {}
67 virtual String getModelName() const = 0;
68 virtual int getType() const = 0;
69 virtual float findNearest( InputArray _samples, int k,
70 OutputArray _results,
71 OutputArray _neighborResponses,
72 OutputArray _dists ) const = 0;
73
train(const Ptr<TrainData> & data,int flags)74 bool train( const Ptr<TrainData>& data, int flags )
75 {
76 Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
77 Mat new_responses;
78 data->getTrainResponses().convertTo(new_responses, CV_32F);
79 bool update = (flags & ml::KNearest::UPDATE_MODEL) != 0 && !samples.empty();
80
81 CV_Assert( new_samples.type() == CV_32F );
82
83 if( !update )
84 {
85 clear();
86 }
87 else
88 {
89 CV_Assert( new_samples.cols == samples.cols &&
90 new_responses.cols == responses.cols );
91 }
92
93 samples.push_back(new_samples);
94 responses.push_back(new_responses);
95
96 doTrain(samples);
97
98 return true;
99 }
100
doTrain(InputArray points)101 virtual void doTrain(InputArray points) { (void)points; }
102
clear()103 void clear()
104 {
105 samples.release();
106 responses.release();
107 }
108
read(const FileNode & fn)109 void read( const FileNode& fn )
110 {
111 clear();
112 isclassifier = (int)fn["is_classifier"] != 0;
113 defaultK = (int)fn["default_k"];
114
115 fn["samples"] >> samples;
116 fn["responses"] >> responses;
117 }
118
write(FileStorage & fs) const119 void write( FileStorage& fs ) const
120 {
121 fs << "is_classifier" << (int)isclassifier;
122 fs << "default_k" << defaultK;
123
124 fs << "samples" << samples;
125 fs << "responses" << responses;
126 }
127
128 public:
129 int defaultK;
130 bool isclassifier;
131 int Emax;
132
133 Mat samples;
134 Mat responses;
135 };
136
// Exhaustive-search back-end: compares every test sample against every
// stored training sample (no index structure is built).
class BruteForceImpl : public Impl
{
public:
    String getModelName() const { return NAME_BRUTE_FORCE; }
    int getType() const { return ml::KNearest::BRUTE_FORCE; }

    // Finds the k0 nearest training samples for the test rows
    // [range.start, range.end) of _samples, writing (when non-null):
    //   results            - one prediction per test row
    //   neighbor_responses - k0 neighbor responses per test row
    //   dists              - k0 squared L2 distances per test row
    //   presult            - prediction for test row 0 only
    void findNearestCore( const Mat& _samples, int k0, const Range& range,
                          Mat* results, Mat* neighbor_responses,
                          Mat* dists, float* presult ) const
    {
        int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
        int testcount = range.end - range.start;
        // Cannot return more neighbors than there are training samples.
        int k = std::min(k0, nsamples);

        // Per test row: k best distances (dbuf) + their responses (rbuf).
        AutoBuffer<float> buf(testcount*k*2);
        float* dbuf = buf;
        float* rbuf = dbuf + testcount*k;

        const float* rptr = responses.ptr<float>();

        // Seed with "infinite" distances so real samples displace them.
        for( testidx = 0; testidx < testcount; testidx++ )
        {
            for( i = 0; i < k; i++ )
            {
                dbuf[testidx*k + i] = FLT_MAX;
                rbuf[testidx*k + i] = 0.f;
            }
        }

        for( baseidx = 0; baseidx < nsamples; baseidx++ )
        {
            for( testidx = 0; testidx < testcount; testidx++ )
            {
                const float* v = samples.ptr<float>(baseidx);
                const float* u = _samples.ptr<float>(testidx + range.start);

                // Squared Euclidean distance, 4-way unrolled.
                float s = 0;
                for( i = 0; i <= d - 4; i += 4 )
                {
                    float t0 = u[i] - v[i], t1 = u[i+1] - v[i+1];
                    float t2 = u[i+2] - v[i+2], t3 = u[i+3] - v[i+3];
                    s += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }

                for( ; i < d; i++ )
                {
                    float t0 = u[i] - v[i];
                    s += t0*t0;
                }

                // Non-negative IEEE-754 floats order the same way as their
                // 32-bit integer bit patterns, so the sorted insertion below
                // can compare ints instead of floats.
                Cv32suf si;
                si.f = (float)s;
                Cv32suf* dd = (Cv32suf*)(&dbuf[testidx*k]);
                float* nr = &rbuf[testidx*k];

                // Find the insertion position in the ascending k-best list.
                for( i = k; i > 0; i-- )
                    if( si.i >= dd[i-1].i )
                        break;
                // Not closer than the current k-th best: skip.
                if( i >= k )
                    continue;

                // Shift the tail down one slot and insert distance+response.
                for( j = k-2; j >= i; j-- )
                {
                    dd[j+1].i = dd[j].i;
                    nr[j+1] = nr[j];
                }
                dd[i].i = si.i;
                nr[i] = rptr[baseidx];
            }
        }

        float result = 0.f;
        float inv_scale = 1.f/k;

        for( testidx = 0; testidx < testcount; testidx++ )
        {
            if( neighbor_responses )
            {
                float* nr = neighbor_responses->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    nr[j] = rbuf[testidx*k + j];
                // Zero-pad when k was clipped below the requested k0.
                for( ; j < k0; j++ )
                    nr[j] = 0.f;
            }

            if( dists )
            {
                float* dptr = dists->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    dptr[j] = dbuf[testidx*k + j];
                for( ; j < k0; j++ )
                    dptr[j] = 0.f;
            }

            // The prediction is computed when the caller wants per-row
            // results, or (for presult) for the very first test row.
            if( results || testidx+range.start == 0 )
            {
                if( !isclassifier || k == 1 )
                {
                    // Regression (or a single neighbor): mean response.
                    float s = 0.f;
                    for( j = 0; j < k; j++ )
                        s += rbuf[testidx*k + j];
                    result = (float)(s*inv_scale);
                }
                else
                {
                    // Classification: majority vote. First bubble-sort the k
                    // responses (early exit once a pass makes no swaps)...
                    float* rp = rbuf + testidx*k;
                    for( j = k-1; j > 0; j-- )
                    {
                        bool swap_fl = false;
                        for( i = 0; i < j; i++ )
                        {
                            if( rp[i] > rp[i+1] )
                            {
                                std::swap(rp[i], rp[i+1]);
                                swap_fl = true;
                            }
                        }
                        if( !swap_fl )
                            break;
                    }

                    // ...then take the label of the longest run of equal
                    // values in the sorted list.
                    result = rp[0];
                    int prev_start = 0;
                    int best_count = 0;
                    for( j = 1; j <= k; j++ )
                    {
                        if( j == k || rp[j] != rp[j-1] )
                        {
                            int count = j - prev_start;
                            if( best_count < count )
                            {
                                best_count = count;
                                result = rp[j-1];
                            }
                            prev_start = j;
                        }
                    }
                }
                if( results )
                    results->at<float>(testidx + range.start) = result;
                if( presult && testidx+range.start == 0 )
                    *presult = result;
            }
        }
    }

    // ParallelLoopBody adapter: each worker handles a sub-range of test rows.
    struct findKNearestInvoker : public ParallelLoopBody
    {
        findKNearestInvoker(const BruteForceImpl* _p, int _k, const Mat& __samples,
                            Mat* __results, Mat* __neighbor_responses, Mat* __dists, float* _presult)
        {
            p = _p;
            k = _k;
            _samples = &__samples;
            _results = __results;
            _neighbor_responses = __neighbor_responses;
            _dists = __dists;
            presult = _presult;
        }

        void operator()( const Range& range ) const
        {
            // Process at most 256 test rows per findNearestCore call to keep
            // the per-call scratch buffer (testcount*k*2 floats) bounded.
            int delta = std::min(range.end - range.start, 256);
            for( int start = range.start; start < range.end; start += delta )
            {
                p->findNearestCore( *_samples, k, Range(start, std::min(start + delta, range.end)),
                                    _results, _neighbor_responses, _dists, presult );
            }
        }

        const BruteForceImpl* p;      // back-end being queried (not owned)
        int k;                        // requested neighbor count
        const Mat* _samples;          // test samples, one per row
        Mat* _results;                // optional outputs (may be null)
        Mat* _neighbor_responses;
        Mat* _dists;
        float* presult;               // receives the prediction for row 0
    };

    // Public entry point: validates inputs, allocates the requested outputs,
    // and runs the search in parallel. Returns the prediction for test row 0.
    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const
    {
        float result = 0.f;
        CV_Assert( 0 < k );

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            // No queries: release outputs so nothing stale survives.
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        Mat res, nr, d, *pres = 0, *pnr = 0, *pd = 0;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            pres = &(res = _results.getMat());
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            pnr = &(nr = _neighborResponses.getMat());
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            pd = &(d = _dists.getMat());
        }

        findKNearestInvoker invoker(this, k, test_samples, pres, pnr, pd, &result);
        parallel_for_(Range(0, testcount), invoker);
        //invoker(Range(0, testcount));
        return result;
    }
};
359
360
361 class KDTreeImpl : public Impl
362 {
363 public:
getModelName() const364 String getModelName() const { return NAME_KDTREE; }
getType() const365 int getType() const { return ml::KNearest::KDTREE; }
366
doTrain(InputArray points)367 void doTrain(InputArray points)
368 {
369 tr.build(points);
370 }
371
findNearest(InputArray _samples,int k,OutputArray _results,OutputArray _neighborResponses,OutputArray _dists) const372 float findNearest( InputArray _samples, int k,
373 OutputArray _results,
374 OutputArray _neighborResponses,
375 OutputArray _dists ) const
376 {
377 float result = 0.f;
378 CV_Assert( 0 < k );
379
380 Mat test_samples = _samples.getMat();
381 CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
382 int testcount = test_samples.rows;
383
384 if( testcount == 0 )
385 {
386 _results.release();
387 _neighborResponses.release();
388 _dists.release();
389 return 0.f;
390 }
391
392 Mat res, nr, d;
393 if( _results.needed() )
394 {
395 _results.create(testcount, 1, CV_32F);
396 res = _results.getMat();
397 }
398 if( _neighborResponses.needed() )
399 {
400 _neighborResponses.create(testcount, k, CV_32F);
401 nr = _neighborResponses.getMat();
402 }
403 if( _dists.needed() )
404 {
405 _dists.create(testcount, k, CV_32F);
406 d = _dists.getMat();
407 }
408
409 for (int i=0; i<test_samples.rows; ++i)
410 {
411 Mat _res, _nr, _d;
412 if (res.rows>i)
413 {
414 _res = res.row(i);
415 }
416 if (nr.rows>i)
417 {
418 _nr = nr.row(i);
419 }
420 if (d.rows>i)
421 {
422 _d = d.row(i);
423 }
424 tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray());
425 }
426
427 return result; // currently always 0
428 }
429
430 KDTree tr;
431 };
432
433 //================================================================
434
435 class KNearestImpl : public KNearest
436 {
437 CV_IMPL_PROPERTY(int, DefaultK, impl->defaultK)
438 CV_IMPL_PROPERTY(bool, IsClassifier, impl->isclassifier)
439 CV_IMPL_PROPERTY(int, Emax, impl->Emax)
440
441 public:
getAlgorithmType() const442 int getAlgorithmType() const
443 {
444 return impl->getType();
445 }
setAlgorithmType(int val)446 void setAlgorithmType(int val)
447 {
448 if (val != BRUTE_FORCE && val != KDTREE)
449 val = BRUTE_FORCE;
450 initImpl(val);
451 }
452
453 public:
KNearestImpl()454 KNearestImpl()
455 {
456 initImpl(BRUTE_FORCE);
457 }
~KNearestImpl()458 ~KNearestImpl()
459 {
460 }
461
isClassifier() const462 bool isClassifier() const { return impl->isclassifier; }
isTrained() const463 bool isTrained() const { return !impl->samples.empty(); }
464
getVarCount() const465 int getVarCount() const { return impl->samples.cols; }
466
write(FileStorage & fs) const467 void write( FileStorage& fs ) const
468 {
469 impl->write(fs);
470 }
471
read(const FileNode & fn)472 void read( const FileNode& fn )
473 {
474 int algorithmType = BRUTE_FORCE;
475 if (fn.name() == NAME_KDTREE)
476 algorithmType = KDTREE;
477 initImpl(algorithmType);
478 impl->read(fn);
479 }
480
findNearest(InputArray samples,int k,OutputArray results,OutputArray neighborResponses=noArray (),OutputArray dist=noArray ()) const481 float findNearest( InputArray samples, int k,
482 OutputArray results,
483 OutputArray neighborResponses=noArray(),
484 OutputArray dist=noArray() ) const
485 {
486 return impl->findNearest(samples, k, results, neighborResponses, dist);
487 }
488
predict(InputArray inputs,OutputArray outputs,int) const489 float predict(InputArray inputs, OutputArray outputs, int) const
490 {
491 return impl->findNearest( inputs, impl->defaultK, outputs, noArray(), noArray() );
492 }
493
train(const Ptr<TrainData> & data,int flags)494 bool train( const Ptr<TrainData>& data, int flags )
495 {
496 return impl->train(data, flags);
497 }
498
getDefaultName() const499 String getDefaultName() const { return impl->getModelName(); }
500
501 protected:
initImpl(int algorithmType)502 void initImpl(int algorithmType)
503 {
504 if (algorithmType != KDTREE)
505 impl = makePtr<BruteForceImpl>();
506 else
507 impl = makePtr<KDTreeImpl>();
508 }
509 Ptr<Impl> impl;
510 };
511
create()512 Ptr<KNearest> KNearest::create()
513 {
514 return makePtr<KNearestImpl>();
515 }
516
517 }
518 }
519
520 /* End of file */
521