1 /*
2 * Copyright (c) 2008-2020 Stefan Krah. All rights reserved.
3 *
4 * Redistribution and use in source and binary forms, with or without
5 * modification, are permitted provided that the following conditions
6 * are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright
9 * notice, this list of conditions and the following disclaimer.
10 *
11 * 2. Redistributions in binary form must reproduce the above copyright
12 * notice, this list of conditions and the following disclaimer in the
13 * documentation and/or other materials provided with the distribution.
14 *
15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25 * SUCH DAMAGE.
26 */
27
28
29 #ifndef LIBMPDEC_UMODARITH_H_
30 #define LIBMPDEC_UMODARITH_H_
31
32
33 #include "mpdecimal.h"
34
35 #include "constants.h"
36 #include "typearith.h"
37
38
39 /* Bignum: Low level routines for unsigned modular arithmetic. These are
40 used in the fast convolution functions for very large coefficients. */
41
42
43 /**************************************************************************/
44 /* ANSI modular arithmetic */
45 /**************************************************************************/
46
47
48 /*
49 * Restrictions: a < m and b < m
50 * ACL2 proof: umodarith.lisp: addmod-correct
51 */
52 static inline mpd_uint_t
addmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)53 addmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
54 {
55 mpd_uint_t s;
56
57 s = a + b;
58 s = (s < a) ? s - m : s;
59 s = (s >= m) ? s - m : s;
60
61 return s;
62 }
63
64 /*
65 * Restrictions: a < m and b < m
66 * ACL2 proof: umodarith.lisp: submod-2-correct
67 */
68 static inline mpd_uint_t
submod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)69 submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
70 {
71 mpd_uint_t d;
72
73 d = a - b;
74 d = (a < b) ? d + m : d;
75
76 return d;
77 }
78
79 /*
80 * Restrictions: a < 2m and b < 2m
81 * ACL2 proof: umodarith.lisp: section ext-submod
82 */
83 static inline mpd_uint_t
ext_submod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)84 ext_submod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
85 {
86 mpd_uint_t d;
87
88 a = (a >= m) ? a - m : a;
89 b = (b >= m) ? b - m : b;
90
91 d = a - b;
92 d = (a < b) ? d + m : d;
93
94 return d;
95 }
96
97 /*
98 * Reduce double word modulo m.
99 * Restrictions: m != 0
100 * ACL2 proof: umodarith.lisp: section dw-reduce
101 */
102 static inline mpd_uint_t
dw_reduce(mpd_uint_t hi,mpd_uint_t lo,mpd_uint_t m)103 dw_reduce(mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
104 {
105 mpd_uint_t r1, r2, w;
106
107 _mpd_div_word(&w, &r1, hi, m);
108 _mpd_div_words(&w, &r2, r1, lo, m);
109
110 return r2;
111 }
112
113 /*
114 * Subtract double word from a.
115 * Restrictions: a < m
116 * ACL2 proof: umodarith.lisp: section dw-submod
117 */
118 static inline mpd_uint_t
dw_submod(mpd_uint_t a,mpd_uint_t hi,mpd_uint_t lo,mpd_uint_t m)119 dw_submod(mpd_uint_t a, mpd_uint_t hi, mpd_uint_t lo, mpd_uint_t m)
120 {
121 mpd_uint_t d, r;
122
123 r = dw_reduce(hi, lo, m);
124 d = a - r;
125 d = (a < r) ? d + m : d;
126
127 return d;
128 }
129
130 #ifdef CONFIG_64
131
132 /**************************************************************************/
133 /* 64-bit modular arithmetic */
134 /**************************************************************************/
135
/*
 * Return (a * b) % m for the 64-bit moduli supported by this file.
 *
 * A proof of the algorithm is in literature/mulmod-64.txt. An ACL2
 * proof is in umodarith.lisp: section "Fast modular reduction".
 *
 * Algorithm: calculate (a * b) % p:
 *
 *   a) hi, lo <- a * b       # Calculate a * b.
 *
 *   b) hi, lo <- R(hi, lo)   # Reduce modulo p.
 *
 *   c) Repeat step b) until 0 <= hi * 2**64 + lo < 2*p.
 *
 *   d) If the result is less than p, return lo. Otherwise return lo - p.
 *
 * Each reduction R with shift s replaces hi * 2**64 + lo by
 * hi * (2**s - 1) + lo, computed as
 *
 *     (hi >> (64-s)) * 2**64 + (lo - hi) + (hi << s)
 *
 * with the borrow of (lo - hi) and the carry of the final addition folded
 * into the new hi. This is congruent modulo m only when
 * 2**64 == 2**s - 1 (mod m), i.e. for m == 2**64 - 2**s + 1. The branches
 * below select s = 32, 34 or 40 by testing bits 32 and 34 of m.
 *
 * NOTE(review): m must therefore be one of those three values (presumably
 * the moduli defined in constants.c -- confirm). For any other m the
 * result is wrong. The fixed number of reduction rounds per branch (two
 * for s = 32, three otherwise) presumably relies on a, b < m; see the
 * proofs cited above.
 */

static inline mpd_uint_t
x64_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
{
    mpd_uint_t hi, lo, x, y;


    _mpd_mul_words(&hi, &lo, a, b);

    if (m & (1ULL<<32)) { /* P1: s = 32, m == 2**64 - 2**32 + 1 */

        /* first reduction */
        x = y = hi;
        hi >>= 32;             /* new hi = old_hi >> (64-32) */

        x = lo - x;            /* lo - old_hi ... */
        if (x > lo) hi--;      /* ... a wrap is a borrow out of hi */

        y <<= 32;              /* low word of old_hi * 2**32 */
        lo = y + x;
        if (lo < y) hi++;      /* carry into hi */

        /* second reduction */
        x = y = hi;
        hi >>= 32;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 32;
        lo = y + x;
        if (lo < y) hi++;

        /* Parsed as (hi || lo >= m) ? lo - m : lo. If hi != 0 the value
           is >= 2**64 > m, and lo - m (mod 2**64) equals value - m. */
        return (hi || lo >= m ? lo - m : lo);
    }
    else if (m & (1ULL<<34)) { /* P2: s = 34, m == 2**64 - 2**34 + 1 */

        /* first reduction */
        x = y = hi;
        hi >>= 30;             /* new hi = old_hi >> (64-34) */

        x = lo - x;            /* lo - old_hi, borrow taken from hi */
        if (x > lo) hi--;

        y <<= 34;              /* low word of old_hi * 2**34 */
        lo = y + x;
        if (lo < y) hi++;      /* carry into hi */

        /* second reduction */
        x = y = hi;
        hi >>= 30;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 34;
        lo = y + x;
        if (lo < y) hi++;

        /* third reduction */
        x = y = hi;
        hi >>= 30;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 34;
        lo = y + x;
        if (lo < y) hi++;

        /* Final conditional subtraction, as in the P1 branch. */
        return (hi || lo >= m ? lo - m : lo);
    }
    else { /* P3: s = 40, m == 2**64 - 2**40 + 1 */

        /* first reduction */
        x = y = hi;
        hi >>= 24;             /* new hi = old_hi >> (64-40) */

        x = lo - x;            /* lo - old_hi, borrow taken from hi */
        if (x > lo) hi--;

        y <<= 40;              /* low word of old_hi * 2**40 */
        lo = y + x;
        if (lo < y) hi++;      /* carry into hi */

        /* second reduction */
        x = y = hi;
        hi >>= 24;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 40;
        lo = y + x;
        if (lo < y) hi++;

        /* third reduction */
        x = y = hi;
        hi >>= 24;

        x = lo - x;
        if (x > lo) hi--;

        y <<= 40;
        lo = y + x;
        if (lo < y) hi++;

        /* Final conditional subtraction, as in the P1 branch. */
        return (hi || lo >= m ? lo - m : lo);
    }
}
260
261 static inline void
x64_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)262 x64_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
263 {
264 *a = x64_mulmod(*a, w, m);
265 *b = x64_mulmod(*b, w, m);
266 }
267
268 static inline void
x64_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)269 x64_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
270 mpd_uint_t m)
271 {
272 *a0 = x64_mulmod(*a0, b0, m);
273 *a1 = x64_mulmod(*a1, b1, m);
274 }
275
276 static inline mpd_uint_t
x64_powmod(mpd_uint_t base,mpd_uint_t exp,mpd_uint_t umod)277 x64_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
278 {
279 mpd_uint_t r = 1;
280
281 while (exp > 0) {
282 if (exp & 1)
283 r = x64_mulmod(r, base, umod);
284 base = x64_mulmod(base, base, umod);
285 exp >>= 1;
286 }
287
288 return r;
289 }
290
291 /* END CONFIG_64 */
292 #else /* CONFIG_32 */
293
294
295 /**************************************************************************/
296 /* 32-bit modular arithmetic */
297 /**************************************************************************/
298
299 #if defined(ANSI)
300 #if !defined(LEGACY_COMPILER)
301 /* HAVE_UINT64_T */
302 static inline mpd_uint_t
std_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)303 std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
304 {
305 return ((mpd_uuint_t) a * b) % m;
306 }
307
308 static inline void
std_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)309 std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
310 {
311 *a = ((mpd_uuint_t) *a * w) % m;
312 *b = ((mpd_uuint_t) *b * w) % m;
313 }
314
315 static inline void
std_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)316 std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
317 mpd_uint_t m)
318 {
319 *a0 = ((mpd_uuint_t) *a0 * b0) % m;
320 *a1 = ((mpd_uuint_t) *a1 * b1) % m;
321 }
322 /* END HAVE_UINT64_T */
323 #else
324 /* LEGACY_COMPILER */
325 static inline mpd_uint_t
std_mulmod(mpd_uint_t a,mpd_uint_t b,mpd_uint_t m)326 std_mulmod(mpd_uint_t a, mpd_uint_t b, mpd_uint_t m)
327 {
328 mpd_uint_t hi, lo, q, r;
329 _mpd_mul_words(&hi, &lo, a, b);
330 _mpd_div_words(&q, &r, hi, lo, m);
331 return r;
332 }
333
334 static inline void
std_mulmod2c(mpd_uint_t * a,mpd_uint_t * b,mpd_uint_t w,mpd_uint_t m)335 std_mulmod2c(mpd_uint_t *a, mpd_uint_t *b, mpd_uint_t w, mpd_uint_t m)
336 {
337 *a = std_mulmod(*a, w, m);
338 *b = std_mulmod(*b, w, m);
339 }
340
341 static inline void
std_mulmod2(mpd_uint_t * a0,mpd_uint_t b0,mpd_uint_t * a1,mpd_uint_t b1,mpd_uint_t m)342 std_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
343 mpd_uint_t m)
344 {
345 *a0 = std_mulmod(*a0, b0, m);
346 *a1 = std_mulmod(*a1, b1, m);
347 }
348 /* END LEGACY_COMPILER */
349 #endif
350
351 static inline mpd_uint_t
std_powmod(mpd_uint_t base,mpd_uint_t exp,mpd_uint_t umod)352 std_powmod(mpd_uint_t base, mpd_uint_t exp, mpd_uint_t umod)
353 {
354 mpd_uint_t r = 1;
355
356 while (exp > 0) {
357 if (exp & 1)
358 r = std_mulmod(r, base, umod);
359 base = std_mulmod(base, base, umod);
360 exp >>= 1;
361 }
362
363 return r;
364 }
365 #endif /* ANSI CONFIG_32 */
366
367
368 /**************************************************************************/
369 /* Pentium Pro modular arithmetic */
370 /**************************************************************************/
371
372 /*
373 * A proof of the algorithm is in literature/mulmod-ppro.txt. The FPU
374 * control word must be set to 64-bit precision and truncation mode
375 * prior to using these functions.
376 *
377 * Algorithm: calculate (a * b) % p:
378 *
379 * p := prime < 2**31
380 * pinv := (long double)1.0 / p (precalculated)
381 *
382 * a) n = a * b # Calculate exact product.
383 * b) qest = n * pinv # Calculate estimate for q = n / p.
384 * c) q = (qest+2**63)-2**63 # Truncate qest to the exact quotient.
385 * d) r = n - q * p # Calculate remainder.
386 *
387 * Remarks:
388 *
389 * - p = dmod and pinv = dinvmod.
390 * - dinvmod points to an array of three uint32_t, which is interpreted
391 * as an 80 bit long double by fldt.
392 * - Intel compilers prior to version 11 do not seem to handle the
393 * __GNUC__ inline assembly correctly.
394 * - random tests are provided in tests/extended/ppro_mulmod.c
395 */
396
397 #if defined(PPRO)
398 #if defined(ASM)
399
/*
 * Return (a * b) % dmod  (GCC inline-assembly version).
 *
 * Asm operands: %0 = retval (memory output), %1 = a, %2 = b,
 * %3 = dmod (pointer to the double p), %4 = dinvmod (pointer to the
 * 80-bit long double pinv, loaded with fldt), %5 = MPD_TWO63 (loaded
 * with flds, i.e. as a single-precision value).
 *
 * The FPU control word must be set to 64-bit precision and truncation
 * mode before calling.
 *
 * Note: with two register operands where the source is %st, GNU as
 * encodes fsubp/fsubrp with the historical AT&T operand reversal, so
 * "fsubrp %st, %st(1)" computes st(1) - st(0). The MASM version of this
 * function uses "fsubp st(1), st" at the same steps.
 */
static inline mpd_uint_t
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm__ (
            "fildl %2\n\t"              /* st0 = b */
            "fildl %1\n\t"              /* st0 = a, st1 = b */
            "fmulp %%st, %%st(1)\n\t"   /* st0 = n = a * b */
            "fldt (%4)\n\t"             /* st0 = pinv, st1 = n */
            "fmul %%st(1), %%st\n\t"    /* st0 = qest = n * pinv, st1 = n */
            "flds %5\n\t"               /* st0 = 2**63, st1 = qest, st2 = n */
            "fadd %%st, %%st(1)\n\t"    /* st1 = qest + 2**63 */
            "fsubrp %%st, %%st(1)\n\t"  /* st0 = q = (qest + 2**63) - 2**63, st1 = n */
            "fldl (%3)\n\t"             /* st0 = p, st1 = q, st2 = n */
            "fmulp %%st, %%st(1)\n\t"   /* st0 = q * p, st1 = n */
            "fsubrp %%st, %%st(1)\n\t"  /* st0 = r = n - q * p */
            "fistpl %0\n\t"             /* retval = r; stack empty */
            : "=m" (retval)
            : "m" (a), "m" (b), "r" (dmod), "r" (dinvmod), "m" (MPD_TWO63)
            : "st", "memory"
    );

    return retval;
}
426
/*
 * Two modular multiplications in parallel (GCC inline-assembly version):
 *   *a0 = (*a0 * w) % dmod
 *   *a1 = (*a1 * w) % dmod
 *
 * Asm operands: %0 = a0, %1 = a1 (pointers, written via fistpl),
 * %2 = w, %3 = dmod (pointer to the double p), %4 = dinvmod (pointer to
 * the 80-bit pinv), %5 = MPD_TWO63. There are no asm outputs; the stores
 * are covered by the "memory" clobber.
 *
 * Each result follows the same steps as ppro_mulmod(): n = x * w,
 * qest = n * pinv, q = (qest + 2**63) - 2**63, r = n - q * p. The FPU
 * control word must be set to 64-bit precision and truncation mode.
 * (GNU as applies the AT&T operand reversal to fsubp/fsubrp with %st as
 * the source, which the stack comments below take into account.)
 */
static inline void
ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
              double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            "fildl %2\n\t"              /* st0 = w */
            "fildl (%1)\n\t"            /* st0 = *a1, st1 = w */
            "fmul %%st(1), %%st\n\t"    /* st0 = n1 = *a1 * w */
            "fxch %%st(1)\n\t"          /* st0 = w, st1 = n1 */
            "fildl (%0)\n\t"            /* st0 = *a0, st1 = w, st2 = n1 */
            "fmulp %%st, %%st(1) \n\t"  /* st0 = n0 = *a0 * w, st1 = n1 */
            "fldt (%4)\n\t"             /* st0 = pinv, st1 = n0, st2 = n1 */
            "flds %5\n\t"               /* st0 = 2**63, st1 = pinv, st2 = n0, st3 = n1 */
            "fld %%st(2)\n\t"           /* st0 = n0 (copy), st1 = 2**63, st2 = pinv, ... */
            "fmul %%st(2)\n\t"          /* st0 = qest0 = n0 * pinv */
            "fadd %%st(1)\n\t"          /* st0 = qest0 + 2**63 */
            "fsub %%st(1)\n\t"          /* st0 = q0 */
            "fmull (%3)\n\t"            /* st0 = q0 * p */
            "fsubrp %%st, %%st(3)\n\t"  /* r0 = n0 - q0 * p; now st0 = 2**63, st1 = pinv, st2 = r0, st3 = n1 */
            "fxch %%st(2)\n\t"          /* st0 = r0, st2 = 2**63 */
            "fistpl (%0)\n\t"           /* *a0 = r0; st0 = pinv, st1 = 2**63, st2 = n1 */
            "fmul %%st(2)\n\t"          /* st0 = qest1 = pinv * n1 */
            "fadd %%st(1)\n\t"          /* st0 = qest1 + 2**63 */
            "fsubp %%st, %%st(1)\n\t"   /* st0 = q1 = (qest1 + 2**63) - 2**63, st1 = n1 */
            "fmull (%3)\n\t"            /* st0 = q1 * p */
            "fsubrp %%st, %%st(1)\n\t"  /* st0 = r1 = n1 - q1 * p */
            "fistpl (%1)\n\t"           /* *a1 = r1; stack empty */
            : : "r" (a0), "r" (a1), "m" (w),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
465
/*
 * Two modular multiplications in parallel (GCC inline-assembly version):
 *   *a0 = (*a0 * b0) % dmod
 *   *a1 = (*a1 * b1) % dmod
 *
 * Asm operands: %0 = a0, %1 = b0, %2 = a1, %3 = b1, %4 = dmod (pointer
 * to the double p), %5 = dinvmod (pointer to the 80-bit pinv),
 * %6 = MPD_TWO63. Results are stored through a0/a1 via fistpl. There are
 * no asm outputs; the stores are covered by the "memory" clobber.
 *
 * The two reductions (same steps as ppro_mulmod()) are interleaved to
 * keep the FPU busy. The FPU control word must be set to 64-bit
 * precision and truncation mode. (GNU as applies the AT&T operand
 * reversal to fsubp/fsubrp with %st as the source, which the stack
 * comments below take into account.)
 */
static inline void
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm__ (
            "fildl %3\n\t"              /* st0 = b1 */
            "fildl (%2)\n\t"            /* st0 = *a1, st1 = b1 */
            "fmulp %%st, %%st(1)\n\t"   /* st0 = n1 = *a1 * b1 */
            "fildl %1\n\t"              /* st0 = b0, st1 = n1 */
            "fildl (%0)\n\t"            /* st0 = *a0, st1 = b0, st2 = n1 */
            "fmulp %%st, %%st(1)\n\t"   /* st0 = n0 = *a0 * b0, st1 = n1 */
            "fldt (%5)\n\t"             /* st0 = pinv, st1 = n0, st2 = n1 */
            "fld %%st(2)\n\t"           /* st0 = n1, st1 = pinv, st2 = n0, st3 = n1 */
            "fmul %%st(1), %%st\n\t"    /* st0 = qest1 = n1 * pinv */
            "fxch %%st(1)\n\t"          /* st0 = pinv, st1 = qest1 */
            "fmul %%st(2), %%st\n\t"    /* st0 = qest0 = pinv * n0 */
            "flds %6\n\t"               /* st0 = 2**63, st1 = qest0, st2 = qest1, st3 = n0, st4 = n1 */
            "fldl (%4)\n\t"             /* st0 = p, st1 = 2**63, st2 = qest0, st3 = qest1, ... */
            "fxch %%st(3)\n\t"          /* st0 = qest1, st3 = p */
            "fadd %%st(1), %%st\n\t"    /* st0 = qest1 + 2**63 */
            "fxch %%st(2)\n\t"          /* st0 = qest0, st2 = qest1 + 2**63 */
            "fadd %%st(1), %%st\n\t"    /* st0 = qest0 + 2**63 */
            "fxch %%st(2)\n\t"          /* st0 = qest1 + 2**63, st2 = qest0 + 2**63 */
            "fsub %%st(1), %%st\n\t"    /* st0 = q1 */
            "fxch %%st(2)\n\t"          /* st0 = qest0 + 2**63, st1 = 2**63, st2 = q1 */
            "fsubp %%st, %%st(1)\n\t"   /* st0 = q0, st1 = q1, st2 = p, st3 = n0, st4 = n1 */
            "fxch %%st(1)\n\t"          /* st0 = q1, st1 = q0 */
            "fmul %%st(2), %%st\n\t"    /* st0 = q1 * p */
            "fxch %%st(1)\n\t"          /* st0 = q0, st1 = q1 * p */
            "fmulp %%st, %%st(2)\n\t"   /* st0 = q1 * p, st1 = q0 * p, st2 = n0, st3 = n1 */
            "fsubrp %%st, %%st(3)\n\t"  /* r1 = n1 - q1 * p; st0 = q0 * p, st1 = n0, st2 = r1 */
            "fsubrp %%st, %%st(1)\n\t"  /* st0 = r0 = n0 - q0 * p, st1 = r1 */
            "fxch %%st(1)\n\t"          /* st0 = r1, st1 = r0 */
            "fistpl (%2)\n\t"           /* *a1 = r1 */
            "fistpl (%0)\n\t"           /* *a0 = r0; stack empty */
            : : "r" (a0), "m" (b0), "r" (a1), "m" (b1),
                "r" (dmod), "r" (dinvmod),
                "m" (MPD_TWO63)
            : "st", "memory"
    );
}
512 /* END PPRO GCC ASM */
513 #elif defined(MASM)
514
/*
 * Return (a * b) % dmod  (MSVC inline-assembly version).
 *
 * Mirrors the GCC ppro_mulmod() step for step, in Intel operand order:
 * n = a * b, qest = n * pinv, q = (qest + 2**63) - 2**63, r = n - q * p,
 * and stores r to retval with fistp. eax holds dinvmod (the 80-bit pinv,
 * loaded as TBYTE) and edx holds dmod (the double p, loaded as QWORD).
 *
 * The FPU control word must be set to 64-bit precision and truncation
 * mode before calling.
 */
static inline mpd_uint_t __cdecl
ppro_mulmod(mpd_uint_t a, mpd_uint_t b, double *dmod, uint32_t *dinvmod)
{
    mpd_uint_t retval;

    __asm {
        mov     eax, dinvmod
        mov     edx, dmod
        fild    b
        fild    a
        fmulp   st(1), st
        fld     TBYTE PTR [eax]
        fmul    st, st(1)
        fld     MPD_TWO63
        fadd    st(1), st
        fsubp   st(1), st
        fld     QWORD PTR [edx]
        fmulp   st(1), st
        fsubp   st(1), st
        fistp   retval
    }

    return retval;
}
540
541 /*
542 * Two modular multiplications in parallel:
543 * *a0 = (*a0 * w) % dmod
544 * *a1 = (*a1 * w) % dmod
545 */
546 static inline mpd_uint_t __cdecl
ppro_mulmod2c(mpd_uint_t * a0,mpd_uint_t * a1,mpd_uint_t w,double * dmod,uint32_t * dinvmod)547 ppro_mulmod2c(mpd_uint_t *a0, mpd_uint_t *a1, mpd_uint_t w,
548 double *dmod, uint32_t *dinvmod)
549 {
550 __asm {
551 mov ecx, dmod
552 mov edx, a1
553 mov ebx, dinvmod
554 mov eax, a0
555 fild w
556 fild DWORD PTR [edx]
557 fmul st, st(1)
558 fxch st(1)
559 fild DWORD PTR [eax]
560 fmulp st(1), st
561 fld TBYTE PTR [ebx]
562 fld MPD_TWO63
563 fld st(2)
564 fmul st, st(2)
565 fadd st, st(1)
566 fsub st, st(1)
567 fmul QWORD PTR [ecx]
568 fsubp st(3), st
569 fxch st(2)
570 fistp DWORD PTR [eax]
571 fmul st, st(2)
572 fadd st, st(1)
573 fsubrp st(1), st
574 fmul QWORD PTR [ecx]
575 fsubp st(1), st
576 fistp DWORD PTR [edx]
577 }
578 }
579
/*
 * Two modular multiplications in parallel (MSVC inline-assembly version):
 *   *a0 = (*a0 * b0) % dmod
 *   *a1 = (*a1 * b1) % dmod
 *
 * Register use: ecx = dmod (double p), edx = a1, ebx = dinvmod (80-bit
 * pinv), eax = a0. The instruction sequence mirrors the GCC ppro_mulmod2()
 * in Intel operand order: the two reductions
 * (n = x * y, qest = n * pinv, q = (qest + 2**63) - 2**63, r = n - q * p)
 * are interleaved, and the results are stored through a1 and a0 with
 * fistp.
 *
 * The FPU control word must be set to 64-bit precision and truncation
 * mode before calling.
 */
static inline void __cdecl
ppro_mulmod2(mpd_uint_t *a0, mpd_uint_t b0, mpd_uint_t *a1, mpd_uint_t b1,
             double *dmod, uint32_t *dinvmod)
{
    __asm {
        mov     ecx, dmod
        mov     edx, a1
        mov     ebx, dinvmod
        mov     eax, a0
        fild    b1
        fild    DWORD PTR [edx]
        fmulp   st(1), st
        fild    b0
        fild    DWORD PTR [eax]
        fmulp   st(1), st
        fld     TBYTE PTR [ebx]
        fld     st(2)
        fmul    st, st(1)
        fxch    st(1)
        fmul    st, st(2)
        fld     DWORD PTR MPD_TWO63
        fld     QWORD PTR [ecx]
        fxch    st(3)
        fadd    st, st(1)
        fxch    st(2)
        fadd    st, st(1)
        fxch    st(2)
        fsub    st, st(1)
        fxch    st(2)
        fsubrp  st(1), st
        fxch    st(1)
        fmul    st, st(2)
        fxch    st(1)
        fmulp   st(2), st
        fsubp   st(3), st
        fsubp   st(1), st
        fxch    st(1)
        fistp   DWORD PTR [edx]
        fistp   DWORD PTR [eax]
    }
}
626 #endif /* PPRO MASM (_MSC_VER) */
627
628
629 /* Return (base ** exp) % dmod */
630 static inline mpd_uint_t
ppro_powmod(mpd_uint_t base,mpd_uint_t exp,double * dmod,uint32_t * dinvmod)631 ppro_powmod(mpd_uint_t base, mpd_uint_t exp, double *dmod, uint32_t *dinvmod)
632 {
633 mpd_uint_t r = 1;
634
635 while (exp > 0) {
636 if (exp & 1)
637 r = ppro_mulmod(r, base, dmod, dinvmod);
638 base = ppro_mulmod(base, base, dmod, dinvmod);
639 exp >>= 1;
640 }
641
642 return r;
643 }
644 #endif /* PPRO */
645 #endif /* CONFIG_32 */
646
647
648 #endif /* LIBMPDEC_UMODARITH_H_ */
649