1 /*
2 * Copyright (c) 2008-2016 Stefan Krah. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28
29 #ifndef UMODARITH_H
30 #define UMODARITH_H
31
32
33 #include "constants.h"
34 #include "mpdecimal.h"
35 #include "typearith.h"
36
37
/*
 * Bignum: Low-level routines for unsigned modular arithmetic. These are
 * used in the fast convolution functions for very large coefficients.
 */
40
41
42 /**************************************************************************/
43 /* ANSI modular arithmetic */
44 /**************************************************************************/
45
46
47 /*
48 * Restrictions: a < m and b < m
49 * ACL2 proof: umodarith.lisp: addmod-correct
50 */
51 static inline mpd_uint_t
addmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)52 addmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
53 {
54 mpd_uint_t s;
55
56 s = a + b;
57 s = (s < a) ? s - m : s;
58 s = (s >= m) ? s - m : s;
59
60 return s;
61 }
62
63 /*
64 * Restrictions: a < m and b < m
65 * ACL2 proof: umodarith.lisp: submod-2-correct
66 */
67 static inline mpd_uint_t
submod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)68 submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
69 {
70 mpd_uint_t d;
71
72 d = a - b;
73 d = (a < b) ? d + m : d;
74
75 return d;
76 }
77
78 /*
79 * Restrictions: a < 2m and b < 2m
80 * ACL2 proof: umodarith.lisp: section ext-submod
81 */
82 static inline mpd_uint_t
ext_submod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)83 ext_submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
84 {
85 mpd_uint_t d;
86
87 a = (a >= m) ? a - m : a;
88 b = (b >= m) ? b - m : b;
89
90 d = a - b;
91 d = (a < b) ? d + m : d;
92
93 return d;
94 }
95
96 /*
97 * Reduce double word modulo m.
98 * Restrictions: m != 0
99 * ACL2 proof: umodarith.lisp: section dw-reduce
100 */
101 static inline mpd_uint_t
dw_reduce(mpd_uint_t hi,mpd_uint_t lo,mpd_uint_t m)102 dw_reduce(mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
103 {
104 mpd_uint_t r1, r2, w;
105
106 _mpd_div_word(&w, &r1, hi, m);
107 _mpd_div_words(&w, &r2, r1, lo, m);
108
109 return r2;
110 }
111
112 /*
113 * Subtract double word from a.
114 * Restrictions: a < m
115 * ACL2 proof: umodarith.lisp: section dw-submod
116 */
117 static inline mpd_uint_t
dw_submod(mpd_uint_t a,mpd_uint_t hi,mpd_uint_t lo,mpd_uint_t m)118 dw_submod(mpd_uint_t a, mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
119 {
120 mpd_uint_t d, r;
121
122 r = dw_reduce(hi, lo, m);
123 d = a - r;
124 d = (a < r) ? d + m : d;
125
126 return d;
127 }
128
129 #ifdef CONFIG_64
130
131 /**************************************************************************/
132 /* 64-bit modular arithmetic */
133 /**************************************************************************/
134
135 /*
136 * A proof of the algorithm is in literature/mulmod-64.txt. An ACL2
137 * proof is in umodarith.lisp: section "Fast modular reduction".
138 *
139 * Algorithm: calculate (a * b) % p:
140 *
141 * a) hi, lo <- a * b # Calculate a * b.
142 *
143 * b) hi, lo <- R(hi, lo) # Reduce modulo p.
144 *
145 * c) Repeat step b) until 0 <= hi * 2**64 + lo < 2*p.
146 *
147 * d) If the result is less than p, return lo. Otherwise return lo - p.
148 */
149
150 static inline mpd_uint_t
x64_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)151 x64_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
152 {
153 mpd_uint_t hi, lo, x, y;
154
155
156 _mpd_mul_words(&hi, &lo, a, b);
157
158 if (m & (1ULL<<32)) { /* P1 */
159
160 /* first reduction */
161 x = y = hi;
162 hi >>= 32;
163
164 x = lo - x;
165 if (x > lo) hi--;
166
167 y <<= 32;
168 lo = y + x;
169 if (lo < y) hi++;
170
171 /* second reduction */
172 x = y = hi;
173 hi >>= 32;
174
175 x = lo - x;
176 if (x > lo) hi--;
177
178 y <<= 32;
179 lo = y + x;
180 if (lo < y) hi++;
181
182 return (hi || lo >= m ? lo - m : lo);
183 }
184 else if (m & (1ULL<<34)) { /* P2 */
185
186 /* first reduction */
187 x = y = hi;
188 hi >>= 30;
189
190 x = lo - x;
191 if (x > lo) hi--;
192
193 y <<= 34;
194 lo = y + x;
195 if (lo < y) hi++;
196
197 /* second reduction */
198 x = y = hi;
199 hi >>= 30;
200
201 x = lo - x;
202 if (x > lo) hi--;
203
204 y <<= 34;
205 lo = y + x;
206 if (lo < y) hi++;
207
208 /* third reduction */
209 x = y = hi;
210 hi >>= 30;
211
212 x = lo - x;
213 if (x > lo) hi--;
214
215 y <<= 34;
216 lo = y + x;
217 if (lo < y) hi++;
218
219 return (hi || lo >= m ? lo - m : lo);
220 }
221 else { /* P3 */
222
223 /* first reduction */
224 x = y = hi;
225 hi >>= 24;
226
227 x = lo - x;
228 if (x > lo) hi--;
229
230 y <<= 40;
231 lo = y + x;
232 if (lo < y) hi++;
233
234 /* second reduction */
235 x = y = hi;
236 hi >>= 24;
237
238 x = lo - x;
239 if (x > lo) hi--;
240
241 y <<= 40;
242 lo = y + x;
243 if (lo < y) hi++;
244
245 /* third reduction */
246 x = y = hi;
247 hi >>= 24;
248
249 x = lo - x;
250 if (x > lo) hi--;
251
252 y <<= 40;
253 lo = y + x;
254 if (lo < y) hi++;
255
256 return (hi || lo >= m ? lo - m : lo);
257 }
258 }
259
260 static inline void
x64_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)261 x64_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
262 {
263 *a = x64_mulmod(*a, w, m);
264 *b = x64_mulmod(*b, w, m);
265 }
266
267 static inline void
x64_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)268 x64_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
269 mpd_uint_t m)
270 {
271 *a0 = x64_mulmod(*a0, b0, m);
272 *a1 = x64_mulmod(*a1, b1, m);
273 }
274
275 static inline mpd_uint_t
x64_powmod(mpd_uint_t base,mpd_uint_t exp,mpd_uint_t umod)276 x64_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
277 {
278 mpd_uint_t r = 1;
279
280 while (exp > 0) {
281 if (exp & 1)
282 r = x64_mulmod(r, base, umod);
283 base = x64_mulmod(base, base, umod);
284 exp >>= 1;
285 }
286
287 return r;
288 }
289
290 /* END CONFIG_64 */
291 #else /* CONFIG_32 */
292
293
294 /**************************************************************************/
295 /* 32-bit modular arithmetic */
296 /**************************************************************************/
297
298 #if defined(ANSI)
299 #if !defined(LEGACY_COMPILER)
300 /* HAVE_UINT64_T */
301 static inline mpd_uint_t
std_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)302 std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
303 {
304 return ((mpd_uuint_t) a * b) % m;
305 }
306
307 static inline void
std_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)308 std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
309 {
310 *a = ((mpd_uuint_t) *a * w) % m;
311 *b = ((mpd_uuint_t) *b * w) % m;
312 }
313
314 static inline void
std_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)315 std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
316 mpd_uint_t m)
317 {
318 *a0 = ((mpd_uuint_t) *a0 * b0) % m;
319 *a1 = ((mpd_uuint_t) *a1 * b1) % m;
320 }
321 /* END HAVE_UINT64_T */
322 #else
323 /* LEGACY_COMPILER */
324 static inline mpd_uint_t
std_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)325 std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
326 {
327 mpd_uint_t hi, lo, q, r;
328 _mpd_mul_words(&hi, &lo, a, b);
329 _mpd_div_words(&q, &r, hi, lo, m);
330 return r;
331 }
332
333 static inline void
std_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)334 std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
335 {
336 *a = std_mulmod(*a, w, m);
337 *b = std_mulmod(*b, w, m);
338 }
339
340 static inline void
std_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)341 std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
342 mpd_uint_t m)
343 {
344 *a0 = std_mulmod(*a0, b0, m);
345 *a1 = std_mulmod(*a1, b1, m);
346 }
347 /* END LEGACY_COMPILER */
348 #endif
349
350 static inline mpd_uint_t
std_powmod(mpd_uint_t base,mpd_uint_t exp,mpd_uint_t umod)351 std_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
352 {
353 mpd_uint_t r = 1;
354
355 while (exp > 0) {
356 if (exp & 1)
357 r = std_mulmod(r, base, umod);
358 base = std_mulmod(base, base, umod);
359 exp >>= 1;
360 }
361
362 return r;
363 }
364 #endif /* ANSI CONFIG_32 */
365
366
367 /**************************************************************************/
368 /* Pentium Pro modular arithmetic */
369 /**************************************************************************/
370
/*
 * A proof of the algorithm is in literature/mulmod-ppro.txt. The FPU
 * control word must be set to 64-bit precision and truncation mode
 * prior to using these functions.
 *
 * Algorithm: calculate (a * b) % p:
 *
 *   p    := prime < 2**31
 *   pinv := (long double)1.0 / p (precalculated)
 *
 *   a) n = a * b                  # Calculate exact product.
 *   b) qest = n * pinv            # Calculate estimate for q = n / p.
 *   c) q = (qest+2**63)-2**63     # Truncate qest to the exact quotient.
 *   d) r = n - q * p              # Calculate remainder.
 *
 * Remarks:
 *
 *   - p = dmod and pinv = dinvmod.
 *   - dinvmod points to an array of three uint32_t, which is interpreted
 *     as an 80-bit long double by fldt.
 *   - Intel compilers prior to version 11 do not seem to handle the
 *     __GNUC__ inline assembly correctly.
 *   - Random tests are provided in tests/extended/ppro_mulmod.c.
 */
395
396 #if defined(PPRO)
397 #if defined(ASM)
398
399 /* Return (a * b) % dmod */
400 static inline mpd_uint_t
ppro_mulmod(mpd_uint_t a,mpd_uint_t b,double * dmod,uint32_t * dinvmod)401 ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
402 {
403 mpd_uint_t retval;
404
405 __asm__ (
406 "fildl %2\n\t"
407 "fildl %1\n\t"
408 "fmulp %%st, %%st(1)\n\t"
409 "fldt (%4)\n\t"
410 "fmul %%st(1), %%st\n\t"
411 "flds %5\n\t"
412 "fadd %%st, %%st(1)\n\t"
413 "fsubrp %%st, %%st(1)\n\t"
414 "fldl (%3)\n\t"
415 "fmulp %%st, %%st(1)\n\t"
416 "fsubrp %%st, %%st(1)\n\t"
417 "fistpl %0\n\t"
418 : "=m" (retval)
419 : "m" (a), "m" (b), "r" (dmod), "r" (dinvmod), "m" (MPD_TWO63)
420 : "st", "memory"
421 );
422
423 return retval;
424 }
425
426 /*
427 * Two modular multiplications in parallel:
428 * *a0 = (*a0 * w) % dmod
429 * *a1 = (*a1 * w) % dmod
430 */
431 static inline void
ppro_mulmod2c(mpd_uint_t * a0,mpd_uint_t * a1,mpd_uint_t w,double * dmod,uint32_t * dinvmod)432 ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
433 double *dmod, uint32_t *dinvmod)
434 {
435 __asm__ (
436 "fildl %2\n\t"
437 "fildl (%1)\n\t"
438 "fmul %%st(1), %%st\n\t"
439 "fxch %%st(1)\n\t"
440 "fildl (%0)\n\t"
441 "fmulp %%st, %%st(1) \n\t"
442 "fldt (%4)\n\t"
443 "flds %5\n\t"
444 "fld %%st(2)\n\t"
445 "fmul %%st(2)\n\t"
446 "fadd %%st(1)\n\t"
447 "fsub %%st(1)\n\t"
448 "fmull (%3)\n\t"
449 "fsubrp %%st, %%st(3)\n\t"
450 "fxch %%st(2)\n\t"
451 "fistpl (%0)\n\t"
452 "fmul %%st(2)\n\t"
453 "fadd %%st(1)\n\t"
454 "fsubp %%st, %%st(1)\n\t"
455 "fmull (%3)\n\t"
456 "fsubrp %%st, %%st(1)\n\t"
457 "fistpl (%1)\n\t"
458 : : "r" (a0), "r" (a1), "m" (w),
459 "r" (dmod), "r" (dinvmod),
460 "m" (MPD_TWO63)
461 : "st", "memory"
462 );
463 }
464
465 /*
466 * Two modular multiplications in parallel:
467 * *a0 = (*a0 * b0) % dmod
468 * *a1 = (*a1 * b1) % dmod
469 */
470 static inline void
ppro_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,double * dmod,uint32_t * dinvmod)471 ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
472 double *dmod, uint32_t *dinvmod)
473 {
474 __asm__ (
475 "fildl %3\n\t"
476 "fildl (%2)\n\t"
477 "fmulp %%st, %%st(1)\n\t"
478 "fildl %1\n\t"
479 "fildl (%0)\n\t"
480 "fmulp %%st, %%st(1)\n\t"
481 "fldt (%5)\n\t"
482 "fld %%st(2)\n\t"
483 "fmul %%st(1), %%st\n\t"
484 "fxch %%st(1)\n\t"
485 "fmul %%st(2), %%st\n\t"
486 "flds %6\n\t"
487 "fldl (%4)\n\t"
488 "fxch %%st(3)\n\t"
489 "fadd %%st(1), %%st\n\t"
490 "fxch %%st(2)\n\t"
491 "fadd %%st(1), %%st\n\t"
492 "fxch %%st(2)\n\t"
493 "fsub %%st(1), %%st\n\t"
494 "fxch %%st(2)\n\t"
495 "fsubp %%st, %%st(1)\n\t"
496 "fxch %%st(1)\n\t"
497 "fmul %%st(2), %%st\n\t"
498 "fxch %%st(1)\n\t"
499 "fmulp %%st, %%st(2)\n\t"
500 "fsubrp %%st, %%st(3)\n\t"
501 "fsubrp %%st, %%st(1)\n\t"
502 "fxch %%st(1)\n\t"
503 "fistpl (%2)\n\t"
504 "fistpl (%0)\n\t"
505 : : "r" (a0), "m" (b0), "r" (a1), "m" (b1),
506 "r" (dmod), "r" (dinvmod),
507 "m" (MPD_TWO63)
508 : "st", "memory"
509 );
510 }
511 /* END PPRO GCC ASM */
512 #elif defined(MASM)
513
514 /* Return (a * b) % dmod */
515 static inline mpd_uint_t __cdecl
ppro_mulmod(mpd_uint_t a,mpd_uint_t b,double * dmod,uint32_t * dinvmod)516 ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
517 {
518 mpd_uint_t retval;
519
520 __asm {
521 mov eax, dinvmod
522 mov edx, dmod
523 fild b
524 fild a
525 fmulp st(1), st
526 fld TBYTE PTR [eax]
527 fmul st, st(1)
528 fld MPD_TWO63
529 fadd st(1), st
530 fsubp st(1), st
531 fld QWORD PTR [edx]
532 fmulp st(1), st
533 fsubp st(1), st
534 fistp retval
535 }
536
537 return retval;
538 }
539
540 /*
541 * Two modular multiplications in parallel:
542 * *a0 = (*a0 * w) % dmod
543 * *a1 = (*a1 * w) % dmod
544 */
545 static inline mpd_uint_t __cdecl
ppro_mulmod2c(mpd_uint_t * a0,mpd_uint_t * a1,mpd_uint_t w,double * dmod,uint32_t * dinvmod)546 ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
547 double *dmod, uint32_t *dinvmod)
548 {
549 __asm {
550 mov ecx, dmod
551 mov edx, a1
552 mov ebx, dinvmod
553 mov eax, a0
554 fild w
555 fild DWORD PTR [edx]
556 fmul st, st(1)
557 fxch st(1)
558 fild DWORD PTR [eax]
559 fmulp st(1), st
560 fld TBYTE PTR [ebx]
561 fld MPD_TWO63
562 fld st(2)
563 fmul st, st(2)
564 fadd st, st(1)
565 fsub st, st(1)
566 fmul QWORD PTR [ecx]
567 fsubp st(3), st
568 fxch st(2)
569 fistp DWORD PTR [eax]
570 fmul st, st(2)
571 fadd st, st(1)
572 fsubrp st(1), st
573 fmul QWORD PTR [ecx]
574 fsubp st(1), st
575 fistp DWORD PTR [edx]
576 }
577 }
578
579 /*
580 * Two modular multiplications in parallel:
581 * *a0 = (*a0 * b0) % dmod
582 * *a1 = (*a1 * b1) % dmod
583 */
584 static inline void __cdecl
ppro_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,double * dmod,uint32_t * dinvmod)585 ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
586 double *dmod, uint32_t *dinvmod)
587 {
588 __asm {
589 mov ecx, dmod
590 mov edx, a1
591 mov ebx, dinvmod
592 mov eax, a0
593 fild b1
594 fild DWORD PTR [edx]
595 fmulp st(1), st
596 fild b0
597 fild DWORD PTR [eax]
598 fmulp st(1), st
599 fld TBYTE PTR [ebx]
600 fld st(2)
601 fmul st, st(1)
602 fxch st(1)
603 fmul st, st(2)
604 fld DWORD PTR MPD_TWO63
605 fld QWORD PTR [ecx]
606 fxch st(3)
607 fadd st, st(1)
608 fxch st(2)
609 fadd st, st(1)
610 fxch st(2)
611 fsub st, st(1)
612 fxch st(2)
613 fsubrp st(1), st
614 fxch st(1)
615 fmul st, st(2)
616 fxch st(1)
617 fmulp st(2), st
618 fsubp st(3), st
619 fsubp st(1), st
620 fxch st(1)
621 fistp DWORD PTR [edx]
622 fistp DWORD PTR [eax]
623 }
624 }
625 #endif /* PPRO MASM (_MSC_VER) */
626
627
628 /* Return (base ** exp) % dmod */
629 static inline mpd_uint_t
ppro_powmod(mpd_uint_t base,mpd_uint_t exp,double * dmod,uint32_t * dinvmod)630 ppro_powmod(mpd_uint_t base, mpd_uint_t exp, double *dmod, uint32_t *dinvmod)
631 {
632 mpd_uint_t r = 1;
633
634 while (exp > 0) {
635 if (exp & 1)
636 r = ppro_mulmod(r, base, dmod, dinvmod);
637 base = ppro_mulmod(base, base, dmod, dinvmod);
638 exp >>= 1;
639 }
640
641 return r;
642 }
643 #endif /* PPRO */
644 #endif /* CONFIG_32 */
645
646
647 #endif /* UMODARITH_H */
648
649
650
651