1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
21 //
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
25 //
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 // (including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort (including negligence or otherwise) arising in any way out of
38 // the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41
42 /*
43 Partially based on Yossi Rubner code:
44 =========================================================================
45 emd.c
46
47 Last update: 3/14/98
48
49 An implementation of the Earth Movers Distance.
Based on the solution for the Transportation problem as described in
51 "Introduction to Mathematical Programming" by F. S. Hillier and
52 G. J. Lieberman, McGraw-Hill, 1990.
53
54 Copyright (C) 1998 Yossi Rubner
55 Computer Science Department, Stanford University
56 E-Mail: rubner@cs.stanford.edu URL: http://vision.stanford.edu/~rubner
57 ==========================================================================
58 */
#include "_cv.h"
#include <math.h>
60
/* Upper bound on the number of transportation-simplex iterations in cvCalcEMD2. */
#define MAX_ITERATIONS 500

/* Sentinel used as "+infinity" when searching for minima / maxima. */
#define CV_EMD_INF     ((float)1e20)

/* Relative tolerance; callers scale it by the maximal cost, the total weight
   or the total supply before comparing. */
#define CV_EMD_EPS     ((float)1e-5)

/*
 * Node of a singly linked list threaded through a dense array.  The row (u) and
 * column (v) arrays are chained this way so that rows/columns can be dropped from
 * consideration without moving data.
 */
typedef struct CvNode1D
{
    float            val;    /* potential, or current row/column maximum */
    struct CvNode1D *next;   /* following node in the list, NULL at the end */
}
CvNode1D;

/*
 * Node of a sparse 2D matrix holding one basic variable of the transportation
 * problem.  Each node sits simultaneously in the list of its row and of its column.
 */
typedef struct CvNode2D
{
    float            val;       /* amount of flow in cell (i, j) */
    struct CvNode2D *next[2];   /* next[0]: same row, next[1]: same column */
    int              i;         /* row index (supplier) */
    int              j;         /* column index (consumer) */
}
CvNode2D;
81
82
83 typedef struct CvEMDState
84 {
85 int ssize, dsize;
86
87 float **cost;
88 CvNode2D *_x;
89 CvNode2D *end_x;
90 CvNode2D *enter_x;
91 char **is_x;
92
93 CvNode2D **rows_x;
94 CvNode2D **cols_x;
95
96 CvNode1D *u;
97 CvNode1D *v;
98
99 int* idx1;
100 int* idx2;
101
102 /* find_loop buffers */
103 CvNode2D **loop;
104 char *is_used;
105
106 /* russel buffers */
107 float *s;
108 float *d;
109 float **delta;
110
111 float weight, max_cost;
112 char *buffer;
113 }
114 CvEMDState;
115
116 /* static function declaration */
117 static CvStatus icvInitEMD( const float *signature1, int size1,
118 const float *signature2, int size2,
119 int dims, CvDistanceFunction dist_func, void *user_param,
120 const float* cost, int cost_step,
121 CvEMDState * state, float *lower_bound,
122 char *local_buffer, int local_buffer_size );
123
124 static CvStatus icvFindBasicVariables( float **cost, char **is_x,
125 CvNode1D * u, CvNode1D * v, int ssize, int dsize );
126
127 static float icvIsOptimal( float **cost, char **is_x,
128 CvNode1D * u, CvNode1D * v,
129 int ssize, int dsize, CvNode2D * enter_x );
130
131 static void icvRussel( CvEMDState * state );
132
133
134 static CvStatus icvNewSolution( CvEMDState * state );
135 static int icvFindLoop( CvEMDState * state );
136
137 static void icvAddBasicVariable( CvEMDState * state,
138 int min_i, int min_j,
139 CvNode1D * prev_u_min_i,
140 CvNode1D * prev_v_min_j,
141 CvNode1D * u_head );
142
143 static float icvDistL2( const float *x, const float *y, void *user_param );
144 static float icvDistL1( const float *x, const float *y, void *user_param );
145 static float icvDistC( const float *x, const float *y, void *user_param );
146
147 /* The main function */
148 CV_IMPL float
cvCalcEMD2(const CvArr * signature_arr1,const CvArr * signature_arr2,int dist_type,CvDistanceFunction dist_func,const CvArr * cost_matrix,CvArr * flow_matrix,float * lower_bound,void * user_param)149 cvCalcEMD2( const CvArr* signature_arr1,
150 const CvArr* signature_arr2,
151 int dist_type,
152 CvDistanceFunction dist_func,
153 const CvArr* cost_matrix,
154 CvArr* flow_matrix,
155 float *lower_bound,
156 void *user_param )
157 {
158 char local_buffer[16384];
159 char *local_buffer_ptr = (char *)cvAlignPtr(local_buffer,16);
160 CvEMDState state;
161 float emd = 0;
162
163 CV_FUNCNAME( "cvCalcEMD2" );
164
165 memset( &state, 0, sizeof(state));
166
167 __BEGIN__;
168
169 double total_cost = 0;
170 CvStatus result = CV_NO_ERR;
171 float eps, min_delta;
172 CvNode2D *xp = 0;
173 CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
174 CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
175 CvMat cost_stub, *cost = &cost_stub;
176 CvMat flow_stub, *flow = (CvMat*)flow_matrix;
177 int dims, size1, size2;
178
179 CV_CALL( signature1 = cvGetMat( signature1, &sign_stub1 ));
180 CV_CALL( signature2 = cvGetMat( signature2, &sign_stub2 ));
181
182 if( signature1->cols != signature2->cols )
183 CV_ERROR( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
184
185 dims = signature1->cols - 1;
186 size1 = signature1->rows;
187 size2 = signature2->rows;
188
189 if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
190 CV_ERROR( CV_StsUnmatchedFormats, "The array must have equal types" );
191
192 if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
193 CV_ERROR( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
194
195 if( flow )
196 {
197 CV_CALL( flow = cvGetMat( flow, &flow_stub ));
198
199 if( flow->rows != size1 || flow->cols != size2 )
200 CV_ERROR( CV_StsUnmatchedSizes,
201 "The flow matrix size does not match to the signatures' sizes" );
202
203 if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
204 CV_ERROR( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
205 }
206
207 cost->data.fl = 0;
208 cost->step = 0;
209
210 if( dist_type < 0 )
211 {
212 if( cost_matrix )
213 {
214 if( dist_func )
215 CV_ERROR( CV_StsBadArg,
216 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
217
218 if( lower_bound )
219 CV_ERROR( CV_StsBadArg,
220 "The lower boundary can not be calculated if the cost matrix is used" );
221
222 CV_CALL( cost = cvGetMat( cost_matrix, &cost_stub ));
223 if( cost->rows != size1 || cost->cols != size2 )
224 CV_ERROR( CV_StsUnmatchedSizes,
225 "The cost matrix size does not match to the signatures' sizes" );
226
227 if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
228 CV_ERROR( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
229 }
230 else if( !dist_func )
231 CV_ERROR( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
232 }
233 else
234 {
235 if( dims == 0 )
236 CV_ERROR( CV_StsBadSize,
237 "Number of dimensions can be 0 only if a user-defined metric is used" );
238 user_param = (void *) (size_t)dims;
239 switch (dist_type)
240 {
241 case CV_DIST_L1:
242 dist_func = icvDistL1;
243 break;
244 case CV_DIST_L2:
245 dist_func = icvDistL2;
246 break;
247 case CV_DIST_C:
248 dist_func = icvDistC;
249 break;
250 default:
251 CV_ERROR( CV_StsBadFlag, "Bad or unsupported metric type" );
252 }
253 }
254
255 IPPI_CALL( result = icvInitEMD( signature1->data.fl, size1,
256 signature2->data.fl, size2,
257 dims, dist_func, user_param,
258 cost->data.fl, cost->step,
259 &state, lower_bound, local_buffer_ptr,
260 sizeof( local_buffer ) - 16 ));
261
262 if( result > 0 && lower_bound )
263 {
264 emd = *lower_bound;
265 EXIT;
266 }
267
268 eps = CV_EMD_EPS * state.max_cost;
269
270 /* if ssize = 1 or dsize = 1 then we are done, else ... */
271 if( state.ssize > 1 && state.dsize > 1 )
272 {
273 int itr;
274
275 for( itr = 1; itr < MAX_ITERATIONS; itr++ )
276 {
277 /* find basic variables */
278 result = icvFindBasicVariables( state.cost, state.is_x,
279 state.u, state.v, state.ssize, state.dsize );
280 if( result < 0 )
281 break;
282
283 /* check for optimality */
284 min_delta = icvIsOptimal( state.cost, state.is_x,
285 state.u, state.v,
286 state.ssize, state.dsize, state.enter_x );
287
288 if( min_delta == CV_EMD_INF )
289 {
290 CV_ERROR( CV_StsNoConv, "" );
291 }
292
293 /* if no negative deltamin, we found the optimal solution */
294 if( min_delta >= -eps )
295 break;
296
297 /* improve solution */
298 IPPI_CALL( icvNewSolution( &state ));
299 }
300 }
301
302 /* compute the total flow */
303 for( xp = state._x; xp < state.end_x; xp++ )
304 {
305 float val = xp->val;
306 int i = xp->i;
307 int j = xp->j;
308 int ci = state.idx1[i];
309 int cj = state.idx2[j];
310
311 if( xp != state.enter_x && ci >= 0 && cj >= 0 )
312 {
313 total_cost += (double)val * state.cost[i][j];
314 if( flow )
315 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
316 }
317 }
318
319 emd = (float) (total_cost / state.weight);
320
321 __END__;
322
323 if( state.buffer && state.buffer != local_buffer_ptr )
324 cvFree( &state.buffer );
325
326 return emd;
327 }
328
329
330 /************************************************************************************\
331 * initialize structure, allocate buffers and generate initial golution *
332 \************************************************************************************/
333 static CvStatus
icvInitEMD(const float * signature1,int size1,const float * signature2,int size2,int dims,CvDistanceFunction dist_func,void * user_param,const float * cost,int cost_step,CvEMDState * state,float * lower_bound,char * local_buffer,int local_buffer_size)334 icvInitEMD( const float* signature1, int size1,
335 const float* signature2, int size2,
336 int dims, CvDistanceFunction dist_func, void* user_param,
337 const float* cost, int cost_step,
338 CvEMDState* state, float* lower_bound,
339 char* local_buffer, int local_buffer_size )
340 {
341 float s_sum = 0, d_sum = 0, diff;
342 int i, j;
343 int ssize = 0, dsize = 0;
344 int equal_sums = 1;
345 int buffer_size;
346 float max_cost = 0;
347 char *buffer, *buffer_end;
348
349 memset( state, 0, sizeof( *state ));
350 assert( cost_step % sizeof(float) == 0 );
351 cost_step /= sizeof(float);
352
353 /* calculate buffer size */
354 buffer_size = (size1+1) * (size2+1) * (sizeof( float ) + /* cost */
355 sizeof( char ) + /* is_x */
356 sizeof( float )) + /* delta matrix */
357 (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
358 sizeof( CvNode2D * ) + /* cols_x & rows_x */
359 sizeof( CvNode1D ) + /* u & v */
360 sizeof( float ) + /* s & d */
361 sizeof( int ) + sizeof(CvNode2D*)) + /* idx1 & idx2 */
362 (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
363 sizeof( float * )) + 256; /* cost, is_x and delta */
364
365 if( buffer_size < (int) (dims * 2 * sizeof( float )))
366 {
367 buffer_size = dims * 2 * sizeof( float );
368 }
369
370 /* allocate buffers */
371 if( local_buffer != 0 && local_buffer_size >= buffer_size )
372 {
373 buffer = local_buffer;
374 }
375 else
376 {
377 buffer = (char*)cvAlloc( buffer_size );
378 if( !buffer )
379 return CV_OUTOFMEM_ERR;
380 }
381
382 state->buffer = buffer;
383 buffer_end = buffer + buffer_size;
384
385 state->idx1 = (int*) buffer;
386 buffer += (size1 + 1) * sizeof( int );
387
388 state->idx2 = (int*) buffer;
389 buffer += (size2 + 1) * sizeof( int );
390
391 state->s = (float *) buffer;
392 buffer += (size1 + 1) * sizeof( float );
393
394 state->d = (float *) buffer;
395 buffer += (size2 + 1) * sizeof( float );
396
397 /* sum up the supply and demand */
398 for( i = 0; i < size1; i++ )
399 {
400 float weight = signature1[i * (dims + 1)];
401
402 if( weight > 0 )
403 {
404 s_sum += weight;
405 state->s[ssize] = weight;
406 state->idx1[ssize++] = i;
407
408 }
409 else if( weight < 0 )
410 return CV_BADRANGE_ERR;
411 }
412
413 for( i = 0; i < size2; i++ )
414 {
415 float weight = signature2[i * (dims + 1)];
416
417 if( weight > 0 )
418 {
419 d_sum += weight;
420 state->d[dsize] = weight;
421 state->idx2[dsize++] = i;
422 }
423 else if( weight < 0 )
424 return CV_BADRANGE_ERR;
425 }
426
427 if( ssize == 0 || dsize == 0 )
428 return CV_BADRANGE_ERR;
429
430 /* if supply different than the demand, add a zero-cost dummy cluster */
431 diff = s_sum - d_sum;
432 if( fabs( diff ) >= CV_EMD_EPS * s_sum )
433 {
434 equal_sums = 0;
435 if( diff < 0 )
436 {
437 state->s[ssize] = -diff;
438 state->idx1[ssize++] = -1;
439 }
440 else
441 {
442 state->d[dsize] = diff;
443 state->idx2[dsize++] = -1;
444 }
445 }
446
447 state->ssize = ssize;
448 state->dsize = dsize;
449 state->weight = s_sum > d_sum ? s_sum : d_sum;
450
451 if( lower_bound && equal_sums ) /* check lower bound */
452 {
453 int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
454 float lb = 0;
455
456 float* xs = (float *) buffer;
457 float* xd = xs + dims;
458
459 memset( xs, 0, dims*sizeof(xs[0]));
460 memset( xd, 0, dims*sizeof(xd[0]));
461
462 for( j = 0; j < sz1; j += dims + 1 )
463 {
464 float weight = signature1[j];
465 for( i = 0; i < dims; i++ )
466 xs[i] += signature1[j + i + 1] * weight;
467 }
468
469 for( j = 0; j < sz2; j += dims + 1 )
470 {
471 float weight = signature2[j];
472 for( i = 0; i < dims; i++ )
473 xd[i] += signature2[j + i + 1] * weight;
474 }
475
476 lb = dist_func( xs, xd, user_param ) / state->weight;
477 i = *lower_bound <= lb;
478 *lower_bound = lb;
479 if( i )
480 return ( CvStatus ) 1;
481 }
482
483 /* assign pointers */
484 state->is_used = (char *) buffer;
485 /* init delta matrix */
486 state->delta = (float **) buffer;
487 buffer += ssize * sizeof( float * );
488
489 for( i = 0; i < ssize; i++ )
490 {
491 state->delta[i] = (float *) buffer;
492 buffer += dsize * sizeof( float );
493 }
494
495 state->loop = (CvNode2D **) buffer;
496 buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
497
498 state->_x = state->end_x = (CvNode2D *) buffer;
499 buffer += (ssize + dsize) * sizeof( CvNode2D );
500
501 /* init cost matrix */
502 state->cost = (float **) buffer;
503 buffer += ssize * sizeof( float * );
504
505 /* compute the distance matrix */
506 for( i = 0; i < ssize; i++ )
507 {
508 int ci = state->idx1[i];
509
510 state->cost[i] = (float *) buffer;
511 buffer += dsize * sizeof( float );
512
513 if( ci >= 0 )
514 {
515 for( j = 0; j < dsize; j++ )
516 {
517 int cj = state->idx2[j];
518 if( cj < 0 )
519 state->cost[i][j] = 0;
520 else
521 {
522 float val;
523 if( dist_func )
524 {
525 val = dist_func( signature1 + ci * (dims + 1) + 1,
526 signature2 + cj * (dims + 1) + 1,
527 user_param );
528 }
529 else
530 {
531 assert( cost );
532 val = cost[cost_step*ci + cj];
533 }
534 state->cost[i][j] = val;
535 if( max_cost < val )
536 max_cost = val;
537 }
538 }
539 }
540 else
541 {
542 for( j = 0; j < dsize; j++ )
543 state->cost[i][j] = 0;
544 }
545 }
546
547 state->max_cost = max_cost;
548
549 memset( buffer, 0, buffer_end - buffer );
550
551 state->rows_x = (CvNode2D **) buffer;
552 buffer += ssize * sizeof( CvNode2D * );
553
554 state->cols_x = (CvNode2D **) buffer;
555 buffer += dsize * sizeof( CvNode2D * );
556
557 state->u = (CvNode1D *) buffer;
558 buffer += ssize * sizeof( CvNode1D );
559
560 state->v = (CvNode1D *) buffer;
561 buffer += dsize * sizeof( CvNode1D );
562
563 /* init is_x matrix */
564 state->is_x = (char **) buffer;
565 buffer += ssize * sizeof( char * );
566
567 for( i = 0; i < ssize; i++ )
568 {
569 state->is_x[i] = buffer;
570 buffer += dsize;
571 }
572
573 assert( buffer <= buffer_end );
574
575 icvRussel( state );
576
577 state->enter_x = (state->end_x)++;
578 return CV_NO_ERR;
579 }
580
581
582 /****************************************************************************************\
583 * icvFindBasicVariables *
584 \****************************************************************************************/
585 static CvStatus
icvFindBasicVariables(float ** cost,char ** is_x,CvNode1D * u,CvNode1D * v,int ssize,int dsize)586 icvFindBasicVariables( float **cost, char **is_x,
587 CvNode1D * u, CvNode1D * v, int ssize, int dsize )
588 {
589 int i, j, found;
590 int u_cfound, v_cfound;
591 CvNode1D u0_head, u1_head, *cur_u, *prev_u;
592 CvNode1D v0_head, v1_head, *cur_v, *prev_v;
593
594 /* initialize the rows list (u) and the columns list (v) */
595 u0_head.next = u;
596 for( i = 0; i < ssize; i++ )
597 {
598 u[i].next = u + i + 1;
599 }
600 u[ssize - 1].next = 0;
601 u1_head.next = 0;
602
603 v0_head.next = ssize > 1 ? v + 1 : 0;
604 for( i = 1; i < dsize; i++ )
605 {
606 v[i].next = v + i + 1;
607 }
608 v[dsize - 1].next = 0;
609 v1_head.next = 0;
610
611 /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
612 so set v[0]=0 */
613 v[0].val = 0;
614 v1_head.next = v;
615 v1_head.next->next = 0;
616
617 /* loop until all variables are found */
618 u_cfound = v_cfound = 0;
619 while( u_cfound < ssize || v_cfound < dsize )
620 {
621 found = 0;
622 if( v_cfound < dsize )
623 {
624 /* loop over all marked columns */
625 prev_v = &v1_head;
626
627 for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
628 {
629 float cur_v_val = cur_v->val;
630
631 j = (int)(cur_v - v);
632 /* find the variables in column j */
633 prev_u = &u0_head;
634 for( cur_u = u0_head.next; cur_u != 0; )
635 {
636 i = (int)(cur_u - u);
637 if( is_x[i][j] )
638 {
639 /* compute u[i] */
640 cur_u->val = cost[i][j] - cur_v_val;
641 /* ...and add it to the marked list */
642 prev_u->next = cur_u->next;
643 cur_u->next = u1_head.next;
644 u1_head.next = cur_u;
645 cur_u = prev_u->next;
646 }
647 else
648 {
649 prev_u = cur_u;
650 cur_u = cur_u->next;
651 }
652 }
653 prev_v->next = cur_v->next;
654 v_cfound++;
655 }
656 }
657
658 if( u_cfound < ssize )
659 {
660 /* loop over all marked rows */
661 prev_u = &u1_head;
662 for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
663 {
664 float cur_u_val = cur_u->val;
665 float *_cost;
666 char *_is_x;
667
668 i = (int)(cur_u - u);
669 _cost = cost[i];
670 _is_x = is_x[i];
671 /* find the variables in rows i */
672 prev_v = &v0_head;
673 for( cur_v = v0_head.next; cur_v != 0; )
674 {
675 j = (int)(cur_v - v);
676 if( _is_x[j] )
677 {
678 /* compute v[j] */
679 cur_v->val = _cost[j] - cur_u_val;
680 /* ...and add it to the marked list */
681 prev_v->next = cur_v->next;
682 cur_v->next = v1_head.next;
683 v1_head.next = cur_v;
684 cur_v = prev_v->next;
685 }
686 else
687 {
688 prev_v = cur_v;
689 cur_v = cur_v->next;
690 }
691 }
692 prev_u->next = cur_u->next;
693 u_cfound++;
694 }
695 }
696
697 if( !found )
698 {
699 return CV_NOTDEFINED_ERR;
700 }
701 }
702
703 return CV_NO_ERR;
704 }
705
706
707 /****************************************************************************************\
708 * icvIsOptimal *
709 \****************************************************************************************/
710 static float
icvIsOptimal(float ** cost,char ** is_x,CvNode1D * u,CvNode1D * v,int ssize,int dsize,CvNode2D * enter_x)711 icvIsOptimal( float **cost, char **is_x,
712 CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
713 {
714 float delta, min_delta = CV_EMD_INF;
715 int i, j, min_i = 0, min_j = 0;
716
717 /* find the minimal cij-ui-vj over all i,j */
718 for( i = 0; i < ssize; i++ )
719 {
720 float u_val = u[i].val;
721 float *_cost = cost[i];
722 char *_is_x = is_x[i];
723
724 for( j = 0; j < dsize; j++ )
725 {
726 if( !_is_x[j] )
727 {
728 delta = _cost[j] - u_val - v[j].val;
729 if( min_delta > delta )
730 {
731 min_delta = delta;
732 min_i = i;
733 min_j = j;
734 }
735 }
736 }
737 }
738
739 enter_x->i = min_i;
740 enter_x->j = min_j;
741
742 return min_delta;
743 }
744
745 /****************************************************************************************\
746 * icvNewSolution *
747 \****************************************************************************************/
748 static CvStatus
icvNewSolution(CvEMDState * state)749 icvNewSolution( CvEMDState * state )
750 {
751 int i, j;
752 float min_val = CV_EMD_INF;
753 int steps;
754 CvNode2D head, *cur_x, *next_x, *leave_x = 0;
755 CvNode2D *enter_x = state->enter_x;
756 CvNode2D **loop = state->loop;
757
758 /* enter the new basic variable */
759 i = enter_x->i;
760 j = enter_x->j;
761 state->is_x[i][j] = 1;
762 enter_x->next[0] = state->rows_x[i];
763 enter_x->next[1] = state->cols_x[j];
764 enter_x->val = 0;
765 state->rows_x[i] = enter_x;
766 state->cols_x[j] = enter_x;
767
768 /* find a chain reaction */
769 steps = icvFindLoop( state );
770
771 if( steps == 0 )
772 return CV_NOTDEFINED_ERR;
773
774 /* find the largest value in the loop */
775 for( i = 1; i < steps; i += 2 )
776 {
777 float temp = loop[i]->val;
778
779 if( min_val > temp )
780 {
781 leave_x = loop[i];
782 min_val = temp;
783 }
784 }
785
786 /* update the loop */
787 for( i = 0; i < steps; i += 2 )
788 {
789 float temp0 = loop[i]->val + min_val;
790 float temp1 = loop[i + 1]->val - min_val;
791
792 loop[i]->val = temp0;
793 loop[i + 1]->val = temp1;
794 }
795
796 /* remove the leaving basic variable */
797 i = leave_x->i;
798 j = leave_x->j;
799 state->is_x[i][j] = 0;
800
801 head.next[0] = state->rows_x[i];
802 cur_x = &head;
803 while( (next_x = cur_x->next[0]) != leave_x )
804 {
805 cur_x = next_x;
806 assert( cur_x );
807 }
808 cur_x->next[0] = next_x->next[0];
809 state->rows_x[i] = head.next[0];
810
811 head.next[1] = state->cols_x[j];
812 cur_x = &head;
813 while( (next_x = cur_x->next[1]) != leave_x )
814 {
815 cur_x = next_x;
816 assert( cur_x );
817 }
818 cur_x->next[1] = next_x->next[1];
819 state->cols_x[j] = head.next[1];
820
821 /* set enter_x to be the new empty slot */
822 state->enter_x = leave_x;
823
824 return CV_NO_ERR;
825 }
826
827
828
829 /****************************************************************************************\
830 * icvFindLoop *
831 \****************************************************************************************/
832 static int
icvFindLoop(CvEMDState * state)833 icvFindLoop( CvEMDState * state )
834 {
835 int i, steps = 1;
836 CvNode2D *new_x;
837 CvNode2D **loop = state->loop;
838 CvNode2D *enter_x = state->enter_x, *_x = state->_x;
839 char *is_used = state->is_used;
840
841 memset( is_used, 0, state->ssize + state->dsize );
842
843 new_x = loop[0] = enter_x;
844 is_used[enter_x - _x] = 1;
845 steps = 1;
846
847 do
848 {
849 if( (steps & 1) == 1 )
850 {
851 /* find an unused x in the row */
852 new_x = state->rows_x[new_x->i];
853 while( new_x != 0 && is_used[new_x - _x] )
854 new_x = new_x->next[0];
855 }
856 else
857 {
858 /* find an unused x in the column, or the entering x */
859 new_x = state->cols_x[new_x->j];
860 while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
861 new_x = new_x->next[1];
862 if( new_x == enter_x )
863 break;
864 }
865
866 if( new_x != 0 ) /* found the next x */
867 {
868 /* add x to the loop */
869 loop[steps++] = new_x;
870 is_used[new_x - _x] = 1;
871 }
872 else /* didn't find the next x */
873 {
874 /* backtrack */
875 do
876 {
877 i = steps & 1;
878 new_x = loop[steps - 1];
879 do
880 {
881 new_x = new_x->next[i];
882 }
883 while( new_x != 0 && is_used[new_x - _x] );
884
885 if( new_x == 0 )
886 {
887 is_used[loop[--steps] - _x] = 0;
888 }
889 }
890 while( new_x == 0 && steps > 0 );
891
892 is_used[loop[steps - 1] - _x] = 0;
893 loop[steps - 1] = new_x;
894 is_used[new_x - _x] = 1;
895 }
896 }
897 while( steps > 0 );
898
899 return steps;
900 }
901
902
903
/****************************************************************************************\
*                                        icvRussel                                       *
\****************************************************************************************/
/*
 * Builds the initial basic feasible solution (Russell's approximation method).
 *
 * u[i].val / v[j].val are first set to the maximal cost of row i / column j, and
 * delta[i][j] = cost[i][j] - u[i].val - v[j].val.  Then, repeatedly, the cell with
 * the smallest delta among the rows and columns still in the u/v lists is made
 * basic by icvAddBasicVariable, which removes either its row or its column from
 * the lists.  The loop stops when no candidate cell with delta below CV_EMD_INF
 * remains, or when both lists are empty.
 *
 * Feasibility comes from the supply/demand bookkeeping in icvAddBasicVariable;
 * delta only steers which cell is picked next.
 */
static void
icvRussel( CvEMDState * state )
{
    int i, j, min_i = -1, min_j = -1;
    float min_delta, diff;
    CvNode1D u_head, *cur_u, *prev_u;
    CvNode1D v_head, *cur_v, *prev_v;
    CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
    CvNode1D *u = state->u, *v = state->v;
    int ssize = state->ssize, dsize = state->dsize;
    float eps = CV_EMD_EPS * state->max_cost;
    float **cost = state->cost;
    float **delta = state->delta;

    /* initialize the rows list (ur), and the columns list (vr) */
    u_head.next = u;
    for( i = 0; i < ssize; i++ )
    {
        u[i].next = u + i + 1;
    }
    u[ssize - 1].next = 0;

    v_head.next = v;
    for( i = 0; i < dsize; i++ )
    {
        v[i].val = -CV_EMD_INF;
        v[i].next = v + i + 1;
    }
    v[dsize - 1].next = 0;

    /* find the maximum row and column values (ur[i] and vr[j]) */
    for( i = 0; i < ssize; i++ )
    {
        float u_val = -CV_EMD_INF;
        float *cost_row = cost[i];

        for( j = 0; j < dsize; j++ )
        {
            float temp = cost_row[j];

            if( u_val < temp )
                u_val = temp;
            if( v[j].val < temp )
                v[j].val = temp;
        }
        u[i].val = u_val;
    }

    /* compute the delta matrix */
    for( i = 0; i < ssize; i++ )
    {
        float u_val = u[i].val;
        float *delta_row = delta[i];
        float *cost_row = cost[i];

        for( j = 0; j < dsize; j++ )
        {
            delta_row[j] = cost_row[j] - u_val - v[j].val;
        }
    }

    /* find the basic variables */
    do
    {
        /* find the smallest delta[i][j]; the predecessors of its row and column
           nodes are remembered so icvAddBasicVariable can unlink them */
        min_i = -1;
        min_delta = CV_EMD_INF;
        prev_u = &u_head;
        for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
        {
            i = (int)(cur_u - u);
            float *delta_row = delta[i];

            prev_v = &v_head;
            for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
            {
                j = (int)(cur_v - v);
                if( min_delta > delta_row[j] )
                {
                    min_delta = delta_row[j];
                    min_i = i;
                    min_j = j;
                    prev_u_min_i = prev_u;
                    prev_v_min_j = prev_v;
                }
                prev_v = cur_v;
            }
            prev_u = cur_u;
        }

        if( min_i < 0 )
            break;

        /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
        /* icvAddBasicVariable unlinks either row min_i (which changes
           prev_u_min_i->next) or column min_j; comparing with `remember`
           afterwards tells which one happened. */
        remember = prev_u_min_i->next;
        icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );

        /* update the necessary delta[][] */
        /* NOTE(review): the comparison below is true when row min_i was NOT
           unlinked, i.e. when column min_j was removed -- the opposite of the
           branch comments.  Also, delta = cost - u - v changes by (old - new)
           when a maximum changes, but (new - old) is added, and only when
           |diff| < eps.  In effect delta is hardly ever refreshed.  This looks
           inherited from the reference code named in the file header, and it only
           affects the starting basis (the simplex in cvCalcEMD2 still optimizes).
           Confirm before changing it. */
        if( remember == prev_u_min_i->next )  /* line min_i was deleted */
        {
            for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
            {
                j = (int)(cur_v - v);
                if( cur_v->val == cost[min_i][j] )  /* column j needs updating */
                {
                    float max_val = -CV_EMD_INF;

                    /* find the new maximum value in the column */
                    for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
                    {
                        float temp = cost[cur_u - u][j];

                        if( max_val < temp )
                            max_val = temp;
                    }

                    /* if needed, adjust the relevant delta[*][j] */
                    diff = max_val - cur_v->val;
                    cur_v->val = max_val;
                    if( fabs( diff ) < eps )
                    {
                        for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
                            delta[cur_u - u][j] += diff;
                    }
                }
            }
        }
        else  /* column min_j was deleted */
        {
            for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
            {
                i = (int)(cur_u - u);
                if( cur_u->val == cost[i][min_j] )  /* row i needs updating */
                {
                    float max_val = -CV_EMD_INF;

                    /* find the new maximum value in the row */
                    for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
                    {
                        float temp = cost[i][cur_v - v];

                        if( max_val < temp )
                            max_val = temp;
                    }

                    /* if needed, adjust the relevant delta[i][*] */
                    diff = max_val - cur_u->val;
                    cur_u->val = max_val;

                    if( fabs( diff ) < eps )
                    {
                        for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
                            delta[i][cur_v - v] += diff;
                    }
                }
            }
        }
    }
    while( u_head.next != 0 || v_head.next != 0 );
}
1067
1068
1069
1070 /****************************************************************************************\
1071 * icvAddBasicVariable *
1072 \****************************************************************************************/
1073 static void
icvAddBasicVariable(CvEMDState * state,int min_i,int min_j,CvNode1D * prev_u_min_i,CvNode1D * prev_v_min_j,CvNode1D * u_head)1074 icvAddBasicVariable( CvEMDState * state,
1075 int min_i, int min_j,
1076 CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
1077 {
1078 float temp;
1079 CvNode2D *end_x = state->end_x;
1080
1081 if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
1082 { /* supply exhausted */
1083 temp = state->s[min_i];
1084 state->s[min_i] = 0;
1085 state->d[min_j] -= temp;
1086 }
1087 else /* demand exhausted */
1088 {
1089 temp = state->d[min_j];
1090 state->d[min_j] = 0;
1091 state->s[min_i] -= temp;
1092 }
1093
1094 /* x(min_i,min_j) is a basic variable */
1095 state->is_x[min_i][min_j] = 1;
1096
1097 end_x->val = temp;
1098 end_x->i = min_i;
1099 end_x->j = min_j;
1100 end_x->next[0] = state->rows_x[min_i];
1101 end_x->next[1] = state->cols_x[min_j];
1102 state->rows_x[min_i] = end_x;
1103 state->cols_x[min_j] = end_x;
1104 state->end_x = end_x + 1;
1105
1106 /* delete supply row only if the empty, and if not last row */
1107 if( state->s[min_i] == 0 && u_head->next->next != 0 )
1108 prev_u_min_i->next = prev_u_min_i->next->next; /* remove row from list */
1109 else
1110 prev_v_min_j->next = prev_v_min_j->next->next; /* remove column from list */
1111 }
1112
1113
/****************************************************************************************\
*                                     standard metrics                                   *
\****************************************************************************************/
/*
 * Built-in ground distances.  user_param is not a pointer: it carries the number
 * of coordinates, stored as (void*)(size_t)dims by cvCalcEMD2.
 */

/* L1 (Manhattan) distance: sum of |x[k] - y[k]|, accumulated in double. */
static float
icvDistL1( const float *x, const float *y, void *user_param )
{
    const int count = (int)(size_t)user_param;
    double total = 0;
    int k = 0;

    while( k < count )
    {
        double d = x[k] - y[k];
        total += fabs( d );
        ++k;
    }
    return (float)total;
}
1131
1132 static float
icvDistL2(const float * x,const float * y,void * user_param)1133 icvDistL2( const float *x, const float *y, void *user_param )
1134 {
1135 int i, dims = (int)(size_t)user_param;
1136 double s = 0;
1137
1138 for( i = 0; i < dims; i++ )
1139 {
1140 double t = x[i] - y[i];
1141
1142 s += t * t;
1143 }
1144 return cvSqrt( (float)s );
1145 }
1146
/* Chebyshev (L-infinity) distance: the largest |x[k] - y[k]| over the
   (int)(size_t)user_param coordinates; 0 when there are none. */
static float
icvDistC( const float *x, const float *y, void *user_param )
{
    int k = (int)(size_t)user_param;
    double largest = 0;

    /* the maximum does not depend on the visiting order */
    while( k-- > 0 )
    {
        double d = fabs( x[k] - y[k] );

        if( d > largest )
            largest = d;
    }
    return (float)largest;
}
1162
1163 /* End of file. */
1164
1165