• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Karatsuba convolution
2  *
3  *  Copyright (C) 1999 Ralph Loader <suckfish@ihug.co.nz>
4  *
5  * This library is free software; you can redistribute it and/or
6  * modify it under the terms of the GNU Library General Public
7  * License as published by the Free Software Foundation; either
8  * version 2 of the License, or (at your option) any later version.
9  *
10  * This library is distributed in the hope that it will be useful,
11  * but WITHOUT ANY WARRANTY; without even the implied warranty of
12  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13  * Library General Public License for more details.
14  *
15  * You should have received a copy of the GNU Library General Public
16  * License along with this library; if not, write to the
17  * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
18  * Boston, MA 02110-1301, USA.
19  *
20  *
21  * Note: 7th December 2004: This file used to be licensed under the GPL,
22  *       but we got permission from Ralph Loader to relicense it to LGPL.
23  *
24  *  $Id$
25  *
26  */
27 
28 /* The algorithm is based on the following.  For the convolution of a pair
29  * of pairs, (a,b) * (c,d) = (0, a.c, a.d+b.c, b.d), we can reduce the four
30  * multiplications to three, by the formulae a.d+b.c = (a+b).(c+d) - a.c -
31  * b.d.  A similar relation enables us to compute a 2n by 2n convolution
32  * using 3 n by n convolutions, and thus a 2^n by 2^n convolution using 3^n
33  * multiplications (as opposed to the 4^n that the quadratic algorithm
34  * takes). */
35 
36 /* For large n, this is slower than the O(n log n) that the FFT method
37  * takes, but we avoid using complex numbers, and we only have to compute
38  * one convolution, as opposed to 3 FFTs.  We have good locality-of-
39  * reference as well, which will help on CPUs with tiny caches.  */
40 
41 /* E.g., for a 512 x 512 convolution, the FFT method takes 55 * 512 = 28160
42  * (real) multiplications, as opposed to 3^9 = 19683 for the Karatsuba
43  * algorithm.  We actually want 257 outputs of a 256 x 512 convolution;
44  * that doesn't appear to give an easy advantage for the FFT algorithm, but
45  * for the Karatsuba algorithm, it's easy to use two 256 x 256
46  * convolutions, taking 2 x 3^8 = 13122 multiplications.  [The difference
47  * is that the FFT method "wraps" the arrays, doing a 2^n x 2^n -> 2^n,
48  * while the Karatsuba algorithm pads with zeros, doing 2^n x 2^n ->
49  * 2^{n+1} - 1]. */
50 
51 /* There's a big lie above, actually... for a 4x4 convolution, it's quicker
52  * to do it using 16 multiplications than the more complex Karatsuba
53  * algorithm...  So the recursion bottoms out at 4x4s.  This increases the
54  * number of multiplications by a factor of 16/9, but reduces the overheads
55  * dramatically. */
56 
57 /* The convolution algorithm is implemented as a stack machine.  We have a
58  * stack of commands, each in one of the forms "do a 2^n x 2^n
59  * convolution", or "combine these three length 2^n outputs into one
60  * 2^{n+1} output." */
61 
62 #ifdef HAVE_CONFIG_H
63 #include "config.h"
64 #endif
65 
#include <limits.h>
#include <stdlib.h>
#include "convolve.h"
68 
/* One slot of the command stack interpreted by convolve_run().  A slot is
 * either a convolution command (the v view) or a combine command (the b
 * view).  The two views share storage: b.main overlays v.left and b.null
 * overlays v.right.  convolve_run() depends on this overlay.  It keeps
 * combining while the current slot's b.null is NULL, and it stops at a
 * slot whose b.main is NULL (the end-of-stack marker). */
typedef union stack_entry_s {
  /* Convolve left with right and store the result in out. */
  struct {
    const double *left;
    const double *right;
    double *out;
  } v;
  /* Fold the pending sub-results into main; null is NULL for a combine. */
  struct {
    double *main;
    double *null;
  } b;
} stack_entry;
85 
86 struct _struct_convolve_state
87 {
88   int depth, small, big, stack_size;
89   double *left;
90   double *right;
91   double *scratch;
92   stack_entry *stack;
93 };
94 
95 /*
96  * Initialisation routine - sets up tables and space to work in.
97  * Returns a pointer to internal state, to be used when performing calls.
98  * On error, returns NULL.
99  * The pointer should be freed when it is finished with, by convolve_close().
100  */
101 convolve_state *
convolve_init(int depth)102 convolve_init (int depth)
103 {
104   convolve_state *state;
105 
106   state = malloc (sizeof (convolve_state));
107   state->depth = depth;
108   state->small = (1 << depth);
109   state->big = (2 << depth);
110   state->stack_size = depth * 3;
111   state->left = calloc (state->big, sizeof (double));
112   state->right = calloc (state->small * 3, sizeof (double));
113   state->scratch = calloc (state->small * 3, sizeof (double));
114   state->stack = calloc (state->stack_size + 1, sizeof (stack_entry));
115   return state;
116 }
117 
118 /*
119  * Free the state allocated with convolve_init().
120  */
121 void
convolve_close(convolve_state * state)122 convolve_close (convolve_state * state)
123 {
124   free (state->left);
125   free (state->right);
126   free (state->scratch);
127   free (state->stack);
128   free (state);
129 }
130 
/* Direct 4x4 -> 7 convolution.  Each output is
 *   out[k] = sum of left[i] * right[k - i] for 0 <= i, k - i <= 3.
 * convolve_run() may pass an out buffer that overlaps left and/or right
 * (its mid entry uses the same buffer for right and out).  All eight
 * inputs are therefore loaded into locals before the first store.  Terms
 * are summed in increasing order of the left index. */
static void
convolve_4 (double *out, const double *left, const double *right)
{
  const double a0 = left[0], a1 = left[1], a2 = left[2], a3 = left[3];
  const double b0 = right[0], b1 = right[1], b2 = right[2], b3 = right[3];

  out[0] = a0 * b0;
  out[1] = (a0 * b1) + (a1 * b0);
  out[2] = (a0 * b2) + (a1 * b1) + (a2 * b0);
  out[3] = (a0 * b3) + (a1 * b2) + (a2 * b1) + (a3 * b0);
  out[4] = (a1 * b3) + (a2 * b2) + (a3 * b1);
  out[5] = (a2 * b3) + (a3 * b2);
  out[6] = a3 * b3;
}
159 
160 static void
convolve_run(stack_entry * top,unsigned size,double * scratch)161 convolve_run (stack_entry * top, unsigned size, double *scratch)
162 /* Interpret a stack of commands.  The stack starts with two entries; the
163  * convolution to do, and an illegal entry used to mark the stack top.  The
164  * size is the number of entries in each input, and must be a power of 2,
165  * and at least 8.  It is OK to have out equal to left and/or right.
166  * scratch must have length 3*size.  The number of stack entries needed is
167  * 3n-4 where size=2^n. */
168 {
169   do {
170     const double *left;
171     const double *right;
172     double *out;
173 
174     /* When we get here, the stack top is always a convolve,
175      * with size > 4.  So we will split it.  We repeatedly split
176      * the top entry until we get to size = 4. */
177 
178     left = top->v.left;
179     right = top->v.right;
180     out = top->v.out;
181     top++;
182 
183     do {
184       double *s_left, *s_right;
185       int i;
186 
187       /* Halve the size. */
188       size >>= 1;
189 
190       /* Allocate the scratch areas. */
191       s_left = scratch + size * 3;
192       /* s_right is a length 2*size buffer also used for
193        * intermediate output. */
194       s_right = scratch + size * 4;
195 
196       /* Create the intermediate factors. */
197       for (i = 0; i < size; i++) {
198         double l = left[i] + left[i + size];
199         double r = right[i] + right[i + size];
200 
201         s_left[i + size] = r;
202         s_left[i] = l;
203       }
204 
205       /* Push the combine entry onto the stack. */
206       top -= 3;
207       top[2].b.main = out;
208       top[2].b.null = NULL;
209 
210       /* Push the low entry onto the stack.  This must be
211        * the last of the three sub-convolutions, because
212        * it may overwrite the arguments. */
213       top[1].v.left = left;
214       top[1].v.right = right;
215       top[1].v.out = out;
216 
217       /* Push the mid entry onto the stack. */
218       top[0].v.left = s_left;
219       top[0].v.right = s_right;
220       top[0].v.out = s_right;
221 
222       /* Leave the high entry in variables. */
223       left += size;
224       right += size;
225       out += size * 2;
226 
227     } while (size > 4);
228 
229     /* When we get here, the stack top is a group of 3
230      * convolves, with size = 4, followed by some combines.  */
231     convolve_4 (out, left, right);
232     convolve_4 (top[0].v.out, top[0].v.left, top[0].v.right);
233     convolve_4 (top[1].v.out, top[1].v.left, top[1].v.right);
234     top += 2;
235 
236     /* Now process combines. */
237     do {
238       /* b.main is the output buffer, mid is the middle
239        * part which needs to be adjusted in place, and
240        * then folded back into the output.  We do this in
241        * a slightly strange way, so as to avoid having
242        * two loops. */
243       double *out = top->b.main;
244       double *mid = scratch + size * 4;
245       unsigned int i;
246 
247       top++;
248       out[size * 2 - 1] = 0;
249       for (i = 0; i < size - 1; i++) {
250         double lo;
251         double hi;
252 
253         lo = mid[0] - (out[0] + out[2 * size]) + out[size];
254         hi = mid[size] - (out[size] + out[3 * size]) + out[2 * size];
255         out[size] = lo;
256         out[2 * size] = hi;
257         out++;
258         mid++;
259       }
260       size <<= 1;
261     } while (top->b.null == NULL);
262   } while (top->b.main != NULL);
263 }
264 
265 /*
266  * convolve_match:
267  * @lastchoice: an array of size SMALL.
268  * @input: an array of size BIG (2*SMALL)
269  * @state: a (non-NULL) pointer returned by convolve_init.
270  *
271  * We find the contiguous SMALL-size sub-array of input that best matches
272  * lastchoice. A measure of how good a sub-array is compared with the lastchoice
273  * is given by the sum of the products of each pair of entries.  We maximise
274  * that, by taking an appropriate convolution, and then finding the maximum
275  * entry in the convolutions.
276  *
277  * Return: the position of the best match
278  */
279 int
convolve_match(const int * lastchoice,const short * input,convolve_state * state)280 convolve_match (const int *lastchoice, const short *input,
281     convolve_state * state)
282 {
283   double avg = 0;
284   double best;
285   int p = 0;
286   int i;
287   double *left = state->left;
288   double *right = state->right;
289   double *scratch = state->scratch;
290   stack_entry *top = state->stack + (state->stack_size - 1);
291 
292   for (i = 0; i < state->big; i++)
293     left[i] = input[i];
294 
295   for (i = 0; i < state->small; i++) {
296     double a = lastchoice[(state->small - 1) - i];
297 
298     right[i] = a;
299     avg += a;
300   }
301 
302   /* We adjust the smaller of the two input arrays to have average
303    * value 0.  This makes the eventual result insensitive to both
304    * constant offsets and positive multipliers of the inputs. */
305   avg /= state->small;
306   for (i = 0; i < state->small; i++)
307     right[i] -= avg;
308   /* End-of-stack marker. */
309   top[1].b.null = scratch;
310   top[1].b.main = NULL;
311   /* The low (small x small) part, of which we want the high outputs. */
312   top->v.left = left;
313   top->v.right = right;
314   top->v.out = right + state->small;
315   convolve_run (top, state->small, scratch);
316 
317   /* The high (small x small) part, of which we want the low outputs. */
318   top->v.left = left + state->small;
319   top->v.right = right;
320   top->v.out = right;
321   convolve_run (top, state->small, scratch);
322 
323   /* Now find the best position amoungs this.  Apart from the first
324    * and last, the required convolution outputs are formed by adding
325    * outputs from the two convolutions above. */
326   best = right[state->big - 1];
327   right[state->big + state->small - 1] = 0;
328   p = -1;
329   for (i = 0; i < state->small; i++) {
330     double a = right[i] + right[i + state->big];
331 
332     if (a > best) {
333       best = a;
334       p = i;
335     }
336   }
337   p++;
338 
339 #if 0
340   {
341     /* This is some debugging code... */
342     best = 0;
343     for (i = 0; i < state->small; i++)
344       best += ((double) input[i + p]) * ((double) lastchoice[i] - avg);
345 
346     for (i = 0; i <= state->small; i++) {
347       double tot = 0;
348       unsigned int j;
349 
350       for (j = 0; j < state->small; j++)
351         tot += ((double) input[i + j]) * ((double) lastchoice[j] - avg);
352       if (tot > best)
353         printf ("(%i)", i);
354       if (tot != left[i + (state->small - 1)])
355         printf ("!");
356     }
357 
358     printf ("%i\n", p);
359   }
360 #endif
361 
362   return p;
363 }
364