1 /*
2 * Copyright (c) 2008-2024 Stefan Krah. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright
11 * notice, this list of conditions and the following disclaimer in the
12 * documentation and/or other materials provided with the distribution.
13 *
14 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24 * SUCH DAMAGE.
25 */
26
27
28 #ifndef LIBMPDEC_UMODARITH_H_
29 #define LIBMPDEC_UMODARITH_H_
30
31
32 #include "constants.h"
33 #include "mpdecimal.h"
34 #include "typearith.h"
35
36
37 /* Bignum: Low level routines for unsigned modular arithmetic. These are
38 used in the fast convolution functions for very large coefficients. */
39
40
41 /**************************************************************************/
42 /* ANSI modular arithmetic */
43 /**************************************************************************/
44
45 /*
46 * Restrictions: a < m and b < m
47 * ACL2 proof: umodarith.lisp: addmod-correct
48 */
49 static inline mpd_uint_t
addmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)50 addmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
51 {
52 mpd_uint_t s;
53
54 s = a + b;
55 s = (s < a) ? s - m : s;
56 s = (s >= m) ? s - m : s;
57
58 return s;
59 }
60
61 /*
62 * Restrictions: a < m and b < m
63 * ACL2 proof: umodarith.lisp: submod-2-correct
64 */
65 static inline mpd_uint_t
submod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)66 submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
67 {
68 mpd_uint_t d;
69
70 d = a - b;
71 d = (a < b) ? d + m : d;
72
73 return d;
74 }
75
76 /*
77 * Restrictions: a < 2m and b < 2m
78 * ACL2 proof: umodarith.lisp: section ext-submod
79 */
80 static inline mpd_uint_t
ext_submod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)81 ext_submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
82 {
83 mpd_uint_t d;
84
85 a = (a >= m) ? a - m : a;
86 b = (b >= m) ? b - m : b;
87
88 d = a - b;
89 d = (a < b) ? d + m : d;
90
91 return d;
92 }
93
94 /*
95 * Reduce double word modulo m.
96 * Restrictions: m != 0
97 * ACL2 proof: umodarith.lisp: section dw-reduce
98 */
99 static inline mpd_uint_t
dw_reduce(mpd_uint_t hi,mpd_uint_t lo,mpd_uint_t m)100 dw_reduce(mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
101 {
102 mpd_uint_t r1, r2, w;
103
104 _mpd_div_word(&w, &r1, hi, m);
105 _mpd_div_words(&w, &r2, r1, lo, m);
106
107 return r2;
108 }
109
110 /*
111 * Subtract double word from a.
112 * Restrictions: a < m
113 * ACL2 proof: umodarith.lisp: section dw-submod
114 */
115 static inline mpd_uint_t
dw_submod(mpd_uint_t a,mpd_uint_t hi,mpd_uint_t lo,mpd_uint_t m)116 dw_submod(mpd_uint_t a, mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
117 {
118 mpd_uint_t d, r;
119
120 r = dw_reduce(hi, lo, m);
121 d = a - r;
122 d = (a < r) ? d + m : d;
123
124 return d;
125 }
126
127 #ifdef CONFIG_64
128
129 /**************************************************************************/
130 /* 64-bit modular arithmetic */
131 /**************************************************************************/
132
133 /*
134 * A proof of the algorithm is in literature/mulmod-64.txt. An ACL2
135 * proof is in umodarith.lisp: section "Fast modular reduction".
136 *
137 * Algorithm: calculate (a * b) % p:
138 *
139 * a) hi, lo <- a * b # Calculate a * b.
140 *
141 * b) hi, lo <- R(hi, lo) # Reduce modulo p.
142 *
143 * c) Repeat step b) until 0 <= hi * 2**64 + lo < 2*p.
144 *
145 * d) If the result is less than p, return lo. Otherwise return lo - p.
146 */
147
148 static inline mpd_uint_t
x64_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)149 x64_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
150 {
151 mpd_uint_t hi, lo, x, y;
152
153
154 _mpd_mul_words(&hi, &lo, a, b);
155
156 if (m & (1ULL<<32)) { /* P1 */
157
158 /* first reduction */
159 x = y = hi;
160 hi >>= 32;
161
162 x = lo - x;
163 if (x > lo) hi--;
164
165 y <<= 32;
166 lo = y + x;
167 if (lo < y) hi++;
168
169 /* second reduction */
170 x = y = hi;
171 hi >>= 32;
172
173 x = lo - x;
174 if (x > lo) hi--;
175
176 y <<= 32;
177 lo = y + x;
178 if (lo < y) hi++;
179
180 return (hi || lo >= m ? lo - m : lo);
181 }
182 else if (m & (1ULL<<34)) { /* P2 */
183
184 /* first reduction */
185 x = y = hi;
186 hi >>= 30;
187
188 x = lo - x;
189 if (x > lo) hi--;
190
191 y <<= 34;
192 lo = y + x;
193 if (lo < y) hi++;
194
195 /* second reduction */
196 x = y = hi;
197 hi >>= 30;
198
199 x = lo - x;
200 if (x > lo) hi--;
201
202 y <<= 34;
203 lo = y + x;
204 if (lo < y) hi++;
205
206 /* third reduction */
207 x = y = hi;
208 hi >>= 30;
209
210 x = lo - x;
211 if (x > lo) hi--;
212
213 y <<= 34;
214 lo = y + x;
215 if (lo < y) hi++;
216
217 return (hi || lo >= m ? lo - m : lo);
218 }
219 else { /* P3 */
220
221 /* first reduction */
222 x = y = hi;
223 hi >>= 24;
224
225 x = lo - x;
226 if (x > lo) hi--;
227
228 y <<= 40;
229 lo = y + x;
230 if (lo < y) hi++;
231
232 /* second reduction */
233 x = y = hi;
234 hi >>= 24;
235
236 x = lo - x;
237 if (x > lo) hi--;
238
239 y <<= 40;
240 lo = y + x;
241 if (lo < y) hi++;
242
243 /* third reduction */
244 x = y = hi;
245 hi >>= 24;
246
247 x = lo - x;
248 if (x > lo) hi--;
249
250 y <<= 40;
251 lo = y + x;
252 if (lo < y) hi++;
253
254 return (hi || lo >= m ? lo - m : lo);
255 }
256 }
257
258 static inline void
x64_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)259 x64_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
260 {
261 *a = x64_mulmod(*a, w, m);
262 *b = x64_mulmod(*b, w, m);
263 }
264
265 static inline void
x64_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)266 x64_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
267 mpd_uint_t m)
268 {
269 *a0 = x64_mulmod(*a0, b0, m);
270 *a1 = x64_mulmod(*a1, b1, m);
271 }
272
273 static inline mpd_uint_t
x64_powmod(mpd_uint_t base,mpd_uint_t exp,mpd_uint_t umod)274 x64_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
275 {
276 mpd_uint_t r = 1;
277
278 while (exp > 0) {
279 if (exp & 1)
280 r = x64_mulmod(r, base, umod);
281 base = x64_mulmod(base, base, umod);
282 exp >>= 1;
283 }
284
285 return r;
286 }
287
288 /* END CONFIG_64 */
289 #else /* CONFIG_32 */
290
291
292 /**************************************************************************/
293 /* 32-bit modular arithmetic */
294 /**************************************************************************/
295
296 #if defined(ANSI)
297 #if !defined(LEGACY_COMPILER)
298 /* HAVE_UINT64_T */
299 static inline mpd_uint_t
std_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)300 std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
301 {
302 return ((mpd_uuint_t) a * b) % m;
303 }
304
305 static inline void
std_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)306 std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
307 {
308 *a = ((mpd_uuint_t) *a * w) % m;
309 *b = ((mpd_uuint_t) *b * w) % m;
310 }
311
312 static inline void
std_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)313 std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
314 mpd_uint_t m)
315 {
316 *a0 = ((mpd_uuint_t) *a0 * b0) % m;
317 *a1 = ((mpd_uuint_t) *a1 * b1) % m;
318 }
319 /* END HAVE_UINT64_T */
320 #else
321 /* LEGACY_COMPILER */
322 static inline mpd_uint_t
std_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)323 std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
324 {
325 mpd_uint_t hi, lo, q, r;
326 _mpd_mul_words(&hi, &lo, a, b);
327 _mpd_div_words(&q, &r, hi, lo, m);
328 return r;
329 }
330
331 static inline void
std_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)332 std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
333 {
334 *a = std_mulmod(*a, w, m);
335 *b = std_mulmod(*b, w, m);
336 }
337
338 static inline void
std_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)339 std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
340 mpd_uint_t m)
341 {
342 *a0 = std_mulmod(*a0, b0, m);
343 *a1 = std_mulmod(*a1, b1, m);
344 }
345 /* END LEGACY_COMPILER */
346 #endif
347
348 static inline mpd_uint_t
std_powmod(mpd_uint_t base,mpd_uint_t exp,mpd_uint_t umod)349 std_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
350 {
351 mpd_uint_t r = 1;
352
353 while (exp > 0) {
354 if (exp & 1)
355 r = std_mulmod(r, base, umod);
356 base = std_mulmod(base, base, umod);
357 exp >>= 1;
358 }
359
360 return r;
361 }
362 #endif /* ANSI CONFIG_32 */
363
364
365 /**************************************************************************/
366 /* Pentium Pro modular arithmetic */
367 /**************************************************************************/
368
369 /*
370 * A proof of the algorithm is in literature/mulmod-ppro.txt. The FPU
371 * control word must be set to 64-bit precision and truncation mode
372 * prior to using these functions.
373 *
374 * Algorithm: calculate (a * b) % p:
375 *
376 * p := prime < 2**31
377 * pinv := (long double)1.0 / p (precalculated)
378 *
379 * a) n = a * b # Calculate exact product.
380 * b) qest = n * pinv # Calculate estimate for q = n / p.
381 * c) q = (qest+2**63)-2**63 # Truncate qest to the exact quotient.
382 * d) r = n - q * p # Calculate remainder.
383 *
384 * Remarks:
385 *
386 * - p = dmod and pinv = dinvmod.
387 * - dinvmod points to an array of three uint32_t, which is interpreted
388 * as an 80 bit long double by fldt.
389 * - Intel compilers prior to version 11 do not seem to handle the
390 * __GNUC__ inline assembly correctly.
391 * - random tests are provided in tests/extended/ppro_mulmod.c
392 */
393
394 #if defined(PPRO)
395 #if defined(ASM)
396
/*
 * Return (a * b) % dmod (GCC inline assembly).
 *
 * The FPU control word must be set to 64-bit precision and truncation
 * mode before calling.  *dinvmod holds the 80-bit reciprocal loaded by
 * fldt.
 *
 * Operands: %0 = retval, %1 = a, %2 = b, %3 = dmod, %4 = dinvmod,
 * %5 = MPD_TWO63.  Note that fsubrp with %st(i) as destination follows
 * the GNU as AT&T operand order (it computes st(i) - st, as the MASM
 * variant's "fsubp st(1), st" does).
 */
static inline mpd_uint_t
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm__ (
            "fildl %2\n\t"              /* st: b */
            "fildl %1\n\t"              /* st: a, b */
            "fmulp %%st, %%st(1)\n\t"   /* a) st: n = a * b */
            "fldt (%4)\n\t"             /* st: pinv, n */
            "fmul %%st(1), %%st\n\t"    /* b) st: qest = n * pinv, n */
            "flds %5\n\t"               /* st: 2**63, qest, n */
            "fadd %%st, %%st(1)\n\t"    /* c) st: 2**63, qest + 2**63, n */
            "fsubrp %%st, %%st(1)\n\t"  /*    st: q = (qest + 2**63) - 2**63, n */
            "fldl (%3)\n\t"             /* st: p, q, n */
            "fmulp %%st, %%st(1)\n\t"   /* st: q * p, n */
            "fsubrp %%st, %%st(1)\n\t"  /* d) st: r = n - q * p */
            "fistpl %0\n\t"             /* retval = r */
            : "=m" (retval)
            : "m" (a), "m" (b), "r" (dmod), "r" (dinvmod), "m" (MPD_TWO63)
            : "st", "memory"
    );

    return retval;
}
423
/*
 * Two modular multiplications in parallel (GCC inline assembly):
 *      *a0 = (*a0 * w) % dmod
 *      *a1 = (*a1 * w) % dmod
 *
 * Operands: %0 = a0, %1 = a1, %2 = w, %3 = dmod, %4 = dinvmod,
 * %5 = MPD_TWO63.  Results are stored through a0 and a1 (hence the
 * "memory" clobber).  In the stack comments, T = 2**63, pinv = *dinvmod
 * and p = *dmod.  fsubp/fsubrp with %st(i) as destination follow the
 * GNU as AT&T operand order.
 */
static inline void
ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
              double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            "fildl %2\n\t"                /* st: w */
            "fildl (%1)\n\t"              /* st: *a1, w */
            "fmul %%st(1), %%st\n\t"      /* st: n1 = *a1 * w, w */
            "fxch %%st(1)\n\t"            /* st: w, n1 */
            "fildl (%0)\n\t"              /* st: *a0, w, n1 */
            "fmulp %%st, %%st(1) \n\t"    /* st: n0 = *a0 * w, n1 */
            "fldt (%4)\n\t"               /* st: pinv, n0, n1 */
            "flds %5\n\t"                 /* st: T, pinv, n0, n1 */
            "fld %%st(2)\n\t"             /* st: n0, T, pinv, n0, n1 */
            "fmul %%st(2)\n\t"            /* st: n0*pinv, T, pinv, n0, n1 */
            "fadd %%st(1)\n\t"            /* st: n0*pinv + T, T, pinv, n0, n1 */
            "fsub %%st(1)\n\t"            /* st: q0, T, pinv, n0, n1 */
            "fmull (%3)\n\t"              /* st: q0*p, T, pinv, n0, n1 */
            "fsubrp %%st, %%st(3)\n\t"    /* st: T, pinv, r0 = n0 - q0*p, n1 */
            "fxch %%st(2)\n\t"            /* st: r0, pinv, T, n1 */
            "fistpl (%0)\n\t"             /* *a0 = r0; st: pinv, T, n1 */
            "fmul %%st(2)\n\t"            /* st: n1*pinv, T, n1 */
            "fadd %%st(1)\n\t"            /* st: n1*pinv + T, T, n1 */
            "fsubp %%st, %%st(1)\n\t"     /* st: q1, n1 */
            "fmull (%3)\n\t"              /* st: q1*p, n1 */
            "fsubrp %%st, %%st(1)\n\t"    /* st: r1 = n1 - q1*p */
            "fistpl (%1)\n\t"             /* *a1 = r1 */
            : : "r" (a0), "r" (a1), "m" (w),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
462
/*
 * Two modular multiplications in parallel (GCC inline assembly):
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * Operands: %0 = a0, %1 = b0, %2 = a1, %3 = b1, %4 = dmod,
 * %5 = dinvmod, %6 = MPD_TWO63.  In the stack comments, T = 2**63,
 * pinv = *dinvmod, p = *dmod, and e0/e1 are the quotient estimates.
 * fsubp/fsubrp with %st(i) as destination follow the GNU as AT&T
 * operand order.
 */
static inline void
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            "fildl %3\n\t"                /* st: b1 */
            "fildl (%2)\n\t"              /* st: *a1, b1 */
            "fmulp %%st, %%st(1)\n\t"     /* st: n1 = *a1 * b1 */
            "fildl %1\n\t"                /* st: b0, n1 */
            "fildl (%0)\n\t"              /* st: *a0, b0, n1 */
            "fmulp %%st, %%st(1)\n\t"     /* st: n0 = *a0 * b0, n1 */
            "fldt (%5)\n\t"               /* st: pinv, n0, n1 */
            "fld %%st(2)\n\t"             /* st: n1, pinv, n0, n1 */
            "fmul %%st(1), %%st\n\t"      /* st: e1 = n1*pinv, pinv, n0, n1 */
            "fxch %%st(1)\n\t"            /* st: pinv, e1, n0, n1 */
            "fmul %%st(2), %%st\n\t"      /* st: e0 = n0*pinv, e1, n0, n1 */
            "flds %6\n\t"                 /* st: T, e0, e1, n0, n1 */
            "fldl (%4)\n\t"               /* st: p, T, e0, e1, n0, n1 */
            "fxch %%st(3)\n\t"            /* st: e1, T, e0, p, n0, n1 */
            "fadd %%st(1), %%st\n\t"      /* st: e1+T, T, e0, p, n0, n1 */
            "fxch %%st(2)\n\t"            /* st: e0, T, e1+T, p, n0, n1 */
            "fadd %%st(1), %%st\n\t"      /* st: e0+T, T, e1+T, p, n0, n1 */
            "fxch %%st(2)\n\t"            /* st: e1+T, T, e0+T, p, n0, n1 */
            "fsub %%st(1), %%st\n\t"      /* st: q1, T, e0+T, p, n0, n1 */
            "fxch %%st(2)\n\t"            /* st: e0+T, T, q1, p, n0, n1 */
            "fsubp %%st, %%st(1)\n\t"     /* st: q0, q1, p, n0, n1 */
            "fxch %%st(1)\n\t"            /* st: q1, q0, p, n0, n1 */
            "fmul %%st(2), %%st\n\t"      /* st: q1*p, q0, p, n0, n1 */
            "fxch %%st(1)\n\t"            /* st: q0, q1*p, p, n0, n1 */
            "fmulp %%st, %%st(2)\n\t"     /* st: q1*p, q0*p, n0, n1 */
            "fsubrp %%st, %%st(3)\n\t"    /* st: q0*p, n0, r1 = n1 - q1*p */
            "fsubrp %%st, %%st(1)\n\t"    /* st: r0 = n0 - q0*p, r1 */
            "fxch %%st(1)\n\t"            /* st: r1, r0 */
            "fistpl (%2)\n\t"             /* *a1 = r1; st: r0 */
            "fistpl (%0)\n\t"             /* *a0 = r0 */
            : : "r" (a0), "m" (b0), "r" (a1), "m" (b1),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
509 /* END PPRO GCC ASM */
510 #elif defined(MASM)
511
/*
 * Return (a * b) % dmod (MSVC inline assembly).
 *
 * Same algorithm as the GCC variant: eax points to the 80-bit
 * reciprocal *dinvmod, edx to the double *dmod.  The FPU control word
 * must be set to 64-bit precision and truncation mode before calling.
 */
static inline mpd_uint_t __cdecl
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm {
        mov eax, dinvmod
        mov edx, dmod
        fild b                  /* st: b */
        fild a                  /* st: a, b */
        fmulp st(1), st         /* a) st: n = a * b */
        fld TBYTE PTR [eax]     /* st: pinv, n */
        fmul st, st(1)          /* b) st: qest = n * pinv, n */
        fld MPD_TWO63           /* st: 2**63, qest, n */
        fadd st(1), st          /* c) st: 2**63, qest + 2**63, n */
        fsubp st(1), st         /*    st: q = (qest + 2**63) - 2**63, n */
        fld QWORD PTR [edx]     /* st: p, q, n */
        fmulp st(1), st         /* st: q * p, n */
        fsubp st(1), st         /* d) st: r = n - q * p */
        fistp retval            /* retval = r */
    }

    return retval;
}
537
538 /*
539 * Two modular multiplications in parallel:
540 * *a0 = (*a0 * w) % dmod
541 * *a1 = (*a1 * w) % dmod
542 */
543 static inline mpd_uint_t __cdecl
ppro_mulmod2c(mpd_uint_t * a0,mpd_uint_t * a1,mpd_uint_t w,double * dmod,uint32_t * dinvmod)544 ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
545 double *dmod, uint32_t *dinvmod)
546 {
547 __asm {
548 mov ecx, dmod
549 mov edx, a1
550 mov ebx, dinvmod
551 mov eax, a0
552 fild w
553 fild DWORD PTR [edx]
554 fmul st, st(1)
555 fxch st(1)
556 fild DWORD PTR [eax]
557 fmulp st(1), st
558 fld TBYTE PTR [ebx]
559 fld MPD_TWO63
560 fld st(2)
561 fmul st, st(2)
562 fadd st, st(1)
563 fsub st, st(1)
564 fmul QWORD PTR [ecx]
565 fsubp st(3), st
566 fxch st(2)
567 fistp DWORD PTR [eax]
568 fmul st, st(2)
569 fadd st, st(1)
570 fsubrp st(1), st
571 fmul QWORD PTR [ecx]
572 fsubp st(1), st
573 fistp DWORD PTR [edx]
574 }
575 }
576
/*
 * Two modular multiplications in parallel (MSVC inline assembly):
 *      *a0 = (*a0 * b0) % dmod
 *      *a1 = (*a1 * b1) % dmod
 *
 * Register use: ecx = dmod, edx = a1, ebx = dinvmod, eax = a0.
 * The FPU control word must be set to 64-bit precision and truncation
 * mode before calling.
 */
static inline void __cdecl
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm {
        mov ecx, dmod
        mov edx, a1
        mov ebx, dinvmod
        mov eax, a0
        /* n1 = *a1 * b1, then n0 = *a0 * b0 */
        fild b1
        fild DWORD PTR [edx]
        fmulp st(1), st
        fild b0
        fild DWORD PTR [eax]
        fmulp st(1), st
        /* quotient estimates e1 = n1 * pinv and e0 = n0 * pinv */
        fld TBYTE PTR [ebx]
        fld st(2)
        fmul st, st(1)
        fxch st(1)
        fmul st, st(2)
        /* truncate both estimates via (e + 2**63) - 2**63 */
        fld DWORD PTR MPD_TWO63
        fld QWORD PTR [ecx]
        fxch st(3)
        fadd st, st(1)
        fxch st(2)
        fadd st, st(1)
        fxch st(2)
        fsub st, st(1)
        fxch st(2)
        fsubrp st(1), st
        /* remainders r1 = n1 - q1*p and r0 = n0 - q0*p */
        fxch st(1)
        fmul st, st(2)
        fxch st(1)
        fmulp st(2), st
        fsubp st(3), st
        fsubp st(1), st
        /* store r1 to *a1 and r0 to *a0 */
        fxch st(1)
        fistp DWORD PTR [edx]
        fistp DWORD PTR [eax]
    }
}
623 #endif /* PPRO MASM (_MSC_VER) */
624
625
626 /* Return (base ** exp) % dmod */
627 static inline mpd_uint_t
ppro_powmod(mpd_uint_t base,mpd_uint_t exp,double * dmod,uint32_t * dinvmod)628 ppro_powmod(mpd_uint_t base, mpd_uint_t exp, double *dmod, uint32_t *dinvmod)
629 {
630 mpd_uint_t r = 1;
631
632 while (exp > 0) {
633 if (exp & 1)
634 r = ppro_mulmod(r, base, dmod, dinvmod);
635 base = ppro_mulmod(base, base, dmod, dinvmod);
636 exp >>= 1;
637 }
638
639 return r;
640 }
641 #endif /* PPRO */
642 #endif /* CONFIG_32 */
643
644
645 #endif /* LIBMPDEC_UMODARITH_H_ */
646