/*
 * Tiny arbitrary precision floating point library
 *
 * Copyright (c) 2017-2020 Fabrice Bellard
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <stdlib.h>
#include <stdio.h>
#include <inttypes.h>
#include <math.h>
#include <string.h>
#include <assert.h>

#ifdef __AVX2__
#include <immintrin.h>
#endif

#include "cutils.h"
#include "libbf.h"

/* enable it to check the multiplication result */
//#define USE_MUL_CHECK
/* enable it to use FFT/NTT multiplication */
#define USE_FFT_MUL
/* enable decimal floating point support */
#define USE_BF_DEC

//#define inline __attribute__((always_inline))

#ifdef __AVX2__
#define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
#else
#define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
#endif

/* XXX: adjust */
#define DIVNORM_LARGE_THRESHOLD 50
#define UDIV1NORM_THRESHOLD 3

#if LIMB_BITS == 64
#define FMT_LIMB1 "%" PRIx64
#define FMT_LIMB "%016" PRIx64
#define PRId_LIMB PRId64
#define PRIu_LIMB PRIu64

#else

#define FMT_LIMB1 "%x"
#define FMT_LIMB "%08x"
#define PRId_LIMB "d"
#define PRIu_LIMB "u"

#endif

typedef intptr_t mp_size_t;

typedef int bf_op2_func_t(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                          bf_flags_t flags);

#ifdef USE_FFT_MUL

#define FFT_MUL_R_OVERLAP_A (1 << 0)
#define FFT_MUL_R_OVERLAP_B (1 << 1)
#define FFT_MUL_R_NORESIZE  (1 << 2)

static no_inline int fft_mul(bf_context_t *s,
                             bf_t *res, limb_t *a_tab, limb_t a_len,
                             limb_t *b_tab, limb_t b_len, int mul_flags);
static void fft_clear_cache(bf_context_t *s);
#endif
#ifdef USE_BF_DEC
static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos);
#endif


/* count leading zeros */
static inline int clz(limb_t a)
{
    if (a == 0) {
        return LIMB_BITS;
    } else {
#if LIMB_BITS == 64
        return clz64(a);
#else
        return clz32(a);
#endif
    }
}

static inline int ctz(limb_t a)
{
    if (a == 0) {
        return LIMB_BITS;
    } else {
#if LIMB_BITS == 64
        return ctz64(a);
#else
        return ctz32(a);
#endif
    }
}

static inline int ceil_log2(limb_t a)
{
    if (a <= 1)
        return 0;
    else
        return LIMB_BITS - clz(a - 1);
}

/* b must be >= 1 */
static inline slimb_t ceil_div(slimb_t a, slimb_t b)
{
    if (a >= 0)
        return (a + b - 1) / b;
    else
        return a / b;
}

/* b must be >= 1 */
static inline slimb_t floor_div(slimb_t a, slimb_t b)
{
    if (a >= 0) {
        return a / b;
    } else {
        return (a - b + 1) / b;
    }
}
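
/* Illustrative examples (added): C integer division truncates toward
   zero, so the negative branches above restore the mathematical
   ceiling/floor:
     ceil_div(7, 4)  == 2    ceil_div(-7, 4)  == -1   (= ceil(-1.75))
     floor_div(7, 4) == 1    floor_div(-7, 4) == -2   (= floor(-1.75))
*/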

/* return r = a modulo b (0 <= r <= b - 1). b must be >= 1 */
static inline limb_t smod(slimb_t a, slimb_t b)
{
    a = a % (slimb_t)b;
    if (a < 0)
        a += b;
    return a;
}
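
/* Example (added): smod(-7, 3) == 2, since -7 % 3 == -1 with C
   truncating division and the adjustment then adds b once. */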

/* signed addition with saturation */
static inline slimb_t sat_add(slimb_t a, slimb_t b)
{
    slimb_t r;
    r = a + b;
    /* overflow ? */
    if (((a ^ r) & (b ^ r)) < 0)
        r = (a >> (LIMB_BITS - 1)) ^ (((limb_t)1 << (LIMB_BITS - 1)) - 1);
    return r;
}
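
/* Note (added): the test ((a ^ r) & (b ^ r)) < 0 detects signed
   overflow: it is true iff a and b have the same sign and r has the
   opposite one. On overflow, (a >> (LIMB_BITS - 1)) is 0 or -1, so the
   XOR yields the largest slimb_t value when a >= 0 and the smallest
   when a < 0, i.e. saturation in the direction of the operands. */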

#define malloc(s) malloc_is_forbidden(s)
#define free(p) free_is_forbidden(p)
#define realloc(p, s) realloc_is_forbidden(p, s)

void bf_context_init(bf_context_t *s, bf_realloc_func_t *realloc_func,
                     void *realloc_opaque)
{
    memset(s, 0, sizeof(*s));
    s->realloc_func = realloc_func;
    s->realloc_opaque = realloc_opaque;
}

void bf_context_end(bf_context_t *s)
{
    bf_clear_cache(s);
}

void bf_init(bf_context_t *s, bf_t *r)
{
    r->ctx = s;
    r->sign = 0;
    r->expn = BF_EXP_ZERO;
    r->len = 0;
    r->tab = NULL;
}

/* return 0 if OK, -1 if alloc error */
int bf_resize(bf_t *r, limb_t len)
{
    limb_t *tab;

    if (len != r->len) {
        tab = bf_realloc(r->ctx, r->tab, len * sizeof(limb_t));
        if (!tab && len != 0)
            return -1;
        r->tab = tab;
        r->len = len;
    }
    return 0;
}

/* return 0 or BF_ST_MEM_ERROR */
int bf_set_ui(bf_t *r, uint64_t a)
{
    r->sign = 0;
    if (a == 0) {
        r->expn = BF_EXP_ZERO;
        bf_resize(r, 0); /* cannot fail */
    }
#if LIMB_BITS == 32
    else if (a <= 0xffffffff)
#else
    else
#endif
    {
        int shift;
        if (bf_resize(r, 1))
            goto fail;
        shift = clz(a);
        r->tab[0] = a << shift;
        r->expn = LIMB_BITS - shift;
    }
#if LIMB_BITS == 32
    else {
        uint32_t a1, a0;
        int shift;
        if (bf_resize(r, 2))
            goto fail;
        a0 = a;
        a1 = a >> 32;
        shift = clz(a1);
        r->tab[0] = a0 << shift;
        r->tab[1] = (a1 << shift) | (a0 >> (LIMB_BITS - shift));
        r->expn = 2 * LIMB_BITS - shift;
    }
#endif
    return 0;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

/* return 0 or BF_ST_MEM_ERROR */
int bf_set_si(bf_t *r, int64_t a)
{
    int ret;

    if (a < 0) {
        ret = bf_set_ui(r, -a);
        r->sign = 1;
    } else {
        ret = bf_set_ui(r, a);
    }
    return ret;
}

void bf_set_nan(bf_t *r)
{
    bf_resize(r, 0); /* cannot fail */
    r->expn = BF_EXP_NAN;
    r->sign = 0;
}

void bf_set_zero(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* cannot fail */
    r->expn = BF_EXP_ZERO;
    r->sign = is_neg;
}

void bf_set_inf(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* cannot fail */
    r->expn = BF_EXP_INF;
    r->sign = is_neg;
}

/* return 0 or BF_ST_MEM_ERROR */
int bf_set(bf_t *r, const bf_t *a)
{
    if (r == a)
        return 0;
    if (bf_resize(r, a->len)) {
        bf_set_nan(r);
        return BF_ST_MEM_ERROR;
    }
    r->sign = a->sign;
    r->expn = a->expn;
    memcpy(r->tab, a->tab, a->len * sizeof(limb_t));
    return 0;
}

/* equivalent to bf_set(r, a); bf_delete(a) */
void bf_move(bf_t *r, bf_t *a)
{
    bf_context_t *s = r->ctx;
    if (r == a)
        return;
    bf_free(s, r->tab);
    *r = *a;
}

static limb_t get_limbz(const bf_t *a, limb_t idx)
{
    if (idx >= a->len)
        return 0;
    else
        return a->tab[idx];
}

/* get LIMB_BITS bits at bit position 'pos' in tab */
static inline limb_t get_bits(const limb_t *tab, limb_t len, slimb_t pos)
{
    limb_t i, a0, a1;
    int p;

    i = pos >> LIMB_LOG2_BITS;
    p = pos & (LIMB_BITS - 1);
    if (i < len)
        a0 = tab[i];
    else
        a0 = 0;
    if (p == 0) {
        return a0;
    } else {
        i++;
        if (i < len)
            a1 = tab[i];
        else
            a1 = 0;
        return (a0 >> p) | (a1 << (LIMB_BITS - p));
    }
}
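
/* Illustration (added): with LIMB_BITS == 64, get_bits(tab, len, 100)
   returns bits 100..163 of the bit string stored in 'tab': the top 28
   bits of tab[1] in the low positions, completed by the low 36 bits of
   tab[2]. Limbs at or beyond 'len' read as zero. */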

static inline limb_t get_bit(const limb_t *tab, limb_t len, slimb_t pos)
{
    slimb_t i;
    i = pos >> LIMB_LOG2_BITS;
    if (i < 0 || i >= len)
        return 0;
    return (tab[i] >> (pos & (LIMB_BITS - 1))) & 1;
}

static inline limb_t limb_mask(int start, int last)
{
    limb_t v;
    int n;
    n = last - start + 1;
    if (n == LIMB_BITS)
        v = -1;
    else
        v = (((limb_t)1 << n) - 1) << start;
    return v;
}
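
/* Example (added): limb_mask(4, 7) == 0xf0 (bits 4..7 inclusive). The
   n == LIMB_BITS special case avoids the undefined behaviour of
   shifting a limb_t by LIMB_BITS. */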

static limb_t mp_scan_nz(const limb_t *tab, mp_size_t n)
{
    mp_size_t i;
    for(i = 0; i < n; i++) {
        if (tab[i] != 0)
            return 1;
    }
    return 0;
}

/* return != 0 if one bit between 0 and bit_pos inclusive is not zero. */
static inline limb_t scan_bit_nz(const bf_t *r, slimb_t bit_pos)
{
    slimb_t pos;
    limb_t v;

    pos = bit_pos >> LIMB_LOG2_BITS;
    if (pos < 0)
        return 0;
    v = r->tab[pos] & limb_mask(0, bit_pos & (LIMB_BITS - 1));
    if (v != 0)
        return 1;
    pos--;
    while (pos >= 0) {
        if (r->tab[pos] != 0)
            return 1;
        pos--;
    }
    return 0;
}

/* return the addend for rounding. Note that prec can be <= 0 (for
   BF_FLAG_RADPNT_PREC) */
static int bf_get_rnd_add(int *pret, const bf_t *r, limb_t l,
                          slimb_t prec, int rnd_mode)
{
    int add_one, inexact;
    limb_t bit1, bit0;

    if (rnd_mode == BF_RNDF) {
        bit0 = 1; /* faithful rounding does not honor the INEXACT flag */
    } else {
        /* starting limb for bit 'prec + 1' */
        bit0 = scan_bit_nz(r, l * LIMB_BITS - 1 - bf_max(0, prec + 1));
    }

    /* get the bit at 'prec' */
    bit1 = get_bit(r->tab, l, l * LIMB_BITS - 1 - prec);
    inexact = (bit1 | bit0) != 0;

    add_one = 0;
    switch(rnd_mode) {
    case BF_RNDZ:
        break;
    case BF_RNDN:
        if (bit1) {
            if (bit0) {
                add_one = 1;
            } else {
                /* round to even */
                add_one =
                    get_bit(r->tab, l, l * LIMB_BITS - 1 - (prec - 1));
            }
        }
        break;
    case BF_RNDD:
    case BF_RNDU:
        if (r->sign == (rnd_mode == BF_RNDD))
            add_one = inexact;
        break;
    case BF_RNDA:
        add_one = inexact;
        break;
    case BF_RNDNA:
    case BF_RNDF:
        add_one = bit1;
        break;
    default:
        abort();
    }

    if (inexact)
        *pret |= BF_ST_INEXACT;
    return add_one;
}
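
/* Terminology note (added): viewing the mantissa as a fixed point
   fraction, 'bit1' is the first bit below the kept precision (the
   rounding bit) and 'bit0' is the OR of all bits below it (the sticky
   bit). For BF_RNDN this gives the usual round-to-nearest-even rule:
   round up iff bit1 && (bit0 || the last kept bit is 1). */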

static int bf_set_overflow(bf_t *r, int sign, limb_t prec, bf_flags_t flags)
{
    slimb_t i, l, e_max;
    int rnd_mode;

    rnd_mode = flags & BF_RND_MASK;
    if (prec == BF_PREC_INF ||
        rnd_mode == BF_RNDN ||
        rnd_mode == BF_RNDNA ||
        rnd_mode == BF_RNDA ||
        (rnd_mode == BF_RNDD && sign == 1) ||
        (rnd_mode == BF_RNDU && sign == 0)) {
        bf_set_inf(r, sign);
    } else {
        /* set to maximum finite number */
        l = (prec + LIMB_BITS - 1) / LIMB_BITS;
        if (bf_resize(r, l)) {
            bf_set_nan(r);
            return BF_ST_MEM_ERROR;
        }
        r->tab[0] = limb_mask((-prec) & (LIMB_BITS - 1),
                              LIMB_BITS - 1);
        for(i = 1; i < l; i++)
            r->tab[i] = (limb_t)-1;
        e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
        r->expn = e_max;
        r->sign = sign;
    }
    return BF_ST_OVERFLOW | BF_ST_INEXACT;
}

/* round to prec1 bits assuming 'r' is non zero and finite. 'r' is
   assumed to have length 'l' (1 <= l <= r->len). Note: 'prec1' can be
   infinite (BF_PREC_INF). 'ret' is 0 or BF_ST_INEXACT if the result
   is known to be inexact. Can fail with BF_ST_MEM_ERROR in case of
   overflow not returning infinity. */
static int __bf_round(bf_t *r, limb_t prec1, bf_flags_t flags, limb_t l,
                      int ret)
{
    limb_t v, a;
    int shift, add_one, rnd_mode;
    slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;

    /* e_min and e_max are computed to match the IEEE 754 conventions */
    e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
    e_min = -e_range + 3;
    e_max = e_range;

    if (flags & BF_FLAG_RADPNT_PREC) {
        /* 'prec' is the precision after the radix point */
        if (prec1 != BF_PREC_INF)
            prec = r->expn + prec1;
        else
            prec = prec1;
    } else if (unlikely(r->expn < e_min) && (flags & BF_FLAG_SUBNORMAL)) {
        /* restrict the precision in case of potentially subnormal
           result */
        assert(prec1 != BF_PREC_INF);
        prec = prec1 - (e_min - r->expn);
    } else {
        prec = prec1;
    }

    /* round to prec bits */
    rnd_mode = flags & BF_RND_MASK;
    add_one = bf_get_rnd_add(&ret, r, l, prec, rnd_mode);

    if (prec <= 0) {
        if (add_one) {
            bf_resize(r, 1); /* cannot fail */
            r->tab[0] = (limb_t)1 << (LIMB_BITS - 1);
            r->expn += 1 - prec;
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            return ret;
        } else {
            goto underflow;
        }
    } else if (add_one) {
        limb_t carry;

        /* add one starting at digit 'prec - 1' */
        bit_pos = l * LIMB_BITS - 1 - (prec - 1);
        pos = bit_pos >> LIMB_LOG2_BITS;
        carry = (limb_t)1 << (bit_pos & (LIMB_BITS - 1));

        for(i = pos; i < l; i++) {
            v = r->tab[i] + carry;
            carry = (v < carry);
            r->tab[i] = v;
            if (carry == 0)
                break;
        }
        if (carry) {
            /* shift right by one digit */
            v = 1;
            for(i = l - 1; i >= pos; i--) {
                a = r->tab[i];
                r->tab[i] = (a >> 1) | (v << (LIMB_BITS - 1));
                v = a;
            }
            r->expn++;
        }
    }

    /* check underflow */
    if (unlikely(r->expn < e_min)) {
        if (flags & BF_FLAG_SUBNORMAL) {
            /* if inexact, also set the underflow flag */
            if (ret & BF_ST_INEXACT)
                ret |= BF_ST_UNDERFLOW;
        } else {
        underflow:
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            bf_set_zero(r, r->sign);
            return ret;
        }
    }

    /* check overflow */
    if (unlikely(r->expn > e_max))
        return bf_set_overflow(r, r->sign, prec1, flags);

    /* keep the bits starting at 'prec - 1' */
    bit_pos = l * LIMB_BITS - 1 - (prec - 1);
    i = bit_pos >> LIMB_LOG2_BITS;
    if (i >= 0) {
        shift = bit_pos & (LIMB_BITS - 1);
        if (shift != 0)
            r->tab[i] &= limb_mask(shift, LIMB_BITS - 1);
    } else {
        i = 0;
    }
    /* remove trailing zeros */
    while (r->tab[i] == 0)
        i++;
    if (i > 0) {
        l -= i;
        memmove(r->tab, r->tab + i, l * sizeof(limb_t));
    }
    bf_resize(r, l); /* cannot fail */
    return ret;
}

/* 'r' must be a finite number. */
int bf_normalize_and_round(bf_t *r, limb_t prec1, bf_flags_t flags)
{
    limb_t l, v, a;
    int shift, ret;
    slimb_t i;

    //    bf_print_str("bf_renorm", r);
    l = r->len;
    while (l > 0 && r->tab[l - 1] == 0)
        l--;
    if (l == 0) {
        /* zero */
        r->expn = BF_EXP_ZERO;
        bf_resize(r, 0); /* cannot fail */
        ret = 0;
    } else {
        r->expn -= (r->len - l) * LIMB_BITS;
        /* shift to have the MSB set to '1' */
        v = r->tab[l - 1];
        shift = clz(v);
        if (shift != 0) {
            v = 0;
            for(i = 0; i < l; i++) {
                a = r->tab[i];
                r->tab[i] = (a << shift) | (v >> (LIMB_BITS - shift));
                v = a;
            }
            r->expn -= shift;
        }
        ret = __bf_round(r, prec1, flags, l, 0);
    }
    //    bf_print_str("r_final", r);
    return ret;
}

/* return true if rounding can be done at precision 'prec' assuming
   the exact result r is such that |r-a| <= 2^(EXP(a)-k). */
/* XXX: check the case where the exponent would be incremented by the
   rounding */
int bf_can_round(const bf_t *a, slimb_t prec, bf_rnd_t rnd_mode, slimb_t k)
{
    BOOL is_rndn;
    slimb_t bit_pos, n;
    limb_t bit;

    if (a->expn == BF_EXP_INF || a->expn == BF_EXP_NAN)
        return FALSE;
    if (rnd_mode == BF_RNDF) {
        return (k >= (prec + 1));
    }
    if (a->expn == BF_EXP_ZERO)
        return FALSE;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA);
    if (k < (prec + 2))
        return FALSE;
    bit_pos = a->len * LIMB_BITS - 1 - prec;
    n = k - prec;
    /* bit pattern for RNDN or RNDNA: 0111.. or 1000...
       for other rounding modes: 000... or 111...
    */
    bit = get_bit(a->tab, a->len, bit_pos);
    bit_pos--;
    n--;
    bit ^= is_rndn;
    /* XXX: slow, but a few iterations on average */
    while (n != 0) {
        if (get_bit(a->tab, a->len, bit_pos) != bit)
            return TRUE;
        bit_pos--;
        n--;
    }
    return FALSE;
}
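
/* Rationale (added): the result can be rounded safely iff the bits
   just below position 'prec' prove that an error of at most
   2^(EXP(a)-k) cannot change the rounding decision. The only ambiguous
   patterns are constant runs 000... / 111... (or 1000... / 0111... for
   round-to-nearest); any other bit breaks the run and makes rounding
   safe. Callers typically retry at a higher working precision while
   this function returns FALSE (Ziv-style strategy). */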

/* Cannot fail with BF_ST_MEM_ERROR. */
int bf_round(bf_t *r, limb_t prec, bf_flags_t flags)
{
    if (r->len == 0)
        return 0;
    return __bf_round(r, prec, flags, r->len, 0);
}

/* for debugging */
static __maybe_unused void dump_limbs(const char *str, const limb_t *tab, limb_t n)
{
    limb_t i;
    printf("%s: len=%" PRId_LIMB "\n", str, n);
    for(i = 0; i < n; i++) {
        printf("%" PRId_LIMB ": " FMT_LIMB "\n",
               i, tab[i]);
    }
}

void mp_print_str(const char *str, const limb_t *tab, limb_t n)
{
    slimb_t i;
    printf("%s= 0x", str);
    for(i = n - 1; i >= 0; i--) {
        if (i != (n - 1))
            printf("_");
        printf(FMT_LIMB, tab[i]);
    }
    printf("\n");
}

static __maybe_unused void mp_print_str_h(const char *str,
                                          const limb_t *tab, limb_t n,
                                          limb_t high)
{
    slimb_t i;
    printf("%s= 0x", str);
    printf(FMT_LIMB, high);
    for(i = n - 1; i >= 0; i--) {
        printf("_");
        printf(FMT_LIMB, tab[i]);
    }
    printf("\n");
}

/* for debugging */
void bf_print_str(const char *str, const bf_t *a)
{
    slimb_t i;
    printf("%s=", str);

    if (a->expn == BF_EXP_NAN) {
        printf("NaN");
    } else {
        if (a->sign)
            putchar('-');
        if (a->expn == BF_EXP_ZERO) {
            putchar('0');
        } else if (a->expn == BF_EXP_INF) {
            printf("Inf");
        } else {
            printf("0x0.");
            for(i = a->len - 1; i >= 0; i--)
                printf(FMT_LIMB, a->tab[i]);
            printf("p%" PRId_LIMB, a->expn);
        }
    }
    printf("\n");
}

/* compare the absolute value of 'a' and 'b'. Return < 0 if a < b, 0
   if a = b and > 0 otherwise. */
int bf_cmpu(const bf_t *a, const bf_t *b)
{
    slimb_t i;
    limb_t len, v1, v2;

    if (a->expn != b->expn) {
        if (a->expn < b->expn)
            return -1;
        else
            return 1;
    }
    len = bf_max(a->len, b->len);
    for(i = len - 1; i >= 0; i--) {
        v1 = get_limbz(a, a->len - len + i);
        v2 = get_limbz(b, b->len - len + i);
        if (v1 != v2) {
            if (v1 < v2)
                return -1;
            else
                return 1;
        }
    }
    return 0;
}

/* Full order: -0 < 0, NaN == NaN and NaN is larger than all other numbers */
int bf_cmp_full(const bf_t *a, const bf_t *b)
{
    int res;

    if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
        if (a->expn == b->expn)
            res = 0;
        else if (a->expn == BF_EXP_NAN)
            res = 1;
        else
            res = -1;
    } else if (a->sign != b->sign) {
        res = 1 - 2 * a->sign;
    } else {
        res = bf_cmpu(a, b);
        if (a->sign)
            res = -res;
    }
    return res;
}

/* Standard floating point comparison: return 2 if one of the operands
   is NaN (unordered) or -1, 0, 1 depending on the ordering assuming
   -0 == +0 */
int bf_cmp(const bf_t *a, const bf_t *b)
{
    int res;

    if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
        res = 2;
    } else if (a->sign != b->sign) {
        if (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_ZERO)
            res = 0;
        else
            res = 1 - 2 * a->sign;
    } else {
        res = bf_cmpu(a, b);
        if (a->sign)
            res = -res;
    }
    return res;
}

/* Compute the number of bits 'n' matching the pattern:
   a= X1000..0
   b= X0111..1

   When computing a-b, the result will have at least n leading zero
   bits.

   Precondition: a > b and a.expn - b.expn = 0 or 1
*/
static limb_t count_cancelled_bits(const bf_t *a, const bf_t *b)
{
    slimb_t bit_offset, b_offset, n;
    int p, p1;
    limb_t v1, v2, mask;

    bit_offset = a->len * LIMB_BITS - 1;
    b_offset = (b->len - a->len) * LIMB_BITS - (LIMB_BITS - 1) +
        a->expn - b->expn;
    n = 0;

    /* first search the equal bits */
    for(;;) {
        v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
        v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
        //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
        if (v1 != v2)
            break;
        n += LIMB_BITS;
        bit_offset -= LIMB_BITS;
    }
    /* find the position of the first different bit */
    p = clz(v1 ^ v2) + 1;
    n += p;
    /* then search for '0' in a and '1' in b */
    p = LIMB_BITS - p;
    if (p > 0) {
        /* search in the trailing p bits of v1 and v2 */
        mask = limb_mask(0, p - 1);
        p1 = bf_min(clz(v1 & mask), clz((~v2) & mask)) - (LIMB_BITS - p);
        n += p1;
        if (p1 != p)
            goto done;
    }
    bit_offset -= LIMB_BITS;
    for(;;) {
        v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
        v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
        //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
        if (v1 != 0 || v2 != -1) {
            /* different: count the matching bits */
            p1 = bf_min(clz(v1), clz(~v2));
            n += p1;
            break;
        }
        n += LIMB_BITS;
        bit_offset -= LIMB_BITS;
    }
 done:
    return n;
}
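
/* Worked example (added): for a = 0.1000...b * 2^(e+1) and
   b = 0.1111...b * 2^e, the aligned bits match the pattern above
   (a reads X1000..., b reads X0111...), so a - b is tiny and almost
   all leading bits cancel. bf_add_internal() uses the returned count
   to enlarge the working precision before such a close subtraction. */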

static int bf_add_internal(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                           bf_flags_t flags, int b_neg)
{
    const bf_t *tmp;
    int is_sub, ret, cmp_res, a_sign, b_sign;

    a_sign = a->sign;
    b_sign = b->sign ^ b_neg;
    is_sub = a_sign ^ b_sign;
    cmp_res = bf_cmpu(a, b);
    if (cmp_res < 0) {
        tmp = a;
        a = b;
        b = tmp;
        a_sign = b_sign; /* b_sign is never used later */
    }
    /* abs(a) >= abs(b) */
    if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
        /* zero result */
        bf_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
        ret = 0;
    } else if (a->len == 0 || b->len == 0) {
        ret = 0;
        if (a->expn >= BF_EXP_INF) {
            if (a->expn == BF_EXP_NAN) {
                /* at least one operand is NaN */
                bf_set_nan(r);
            } else if (b->expn == BF_EXP_INF && is_sub) {
                /* infinities with different signs */
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, a_sign);
            }
        } else {
            /* at least one zero and not subtract */
            bf_set(r, a);
            r->sign = a_sign;
            goto renorm;
        }
    } else {
        slimb_t d, a_offset, b_bit_offset, i, cancelled_bits;
        limb_t carry, v1, v2, u, r_len, carry1, precl, tot_len, z, sub_mask;

        r->sign = a_sign;
        r->expn = a->expn;
        d = a->expn - b->expn;
        /* must add more precision for the leading cancelled bits in
           subtraction */
        if (is_sub) {
            if (d <= 1)
                cancelled_bits = count_cancelled_bits(a, b);
            else
                cancelled_bits = 1;
        } else {
            cancelled_bits = 0;
        }

        /* add two extra bits for rounding */
        precl = (cancelled_bits + prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
        tot_len = bf_max(a->len, b->len + (d + LIMB_BITS - 1) / LIMB_BITS);
        r_len = bf_min(precl, tot_len);
        if (bf_resize(r, r_len))
            goto fail;
        a_offset = a->len - r_len;
        b_bit_offset = (b->len - r_len) * LIMB_BITS + d;

        /* compute the bits before for the rounding */
        carry = is_sub;
        z = 0;
        sub_mask = -is_sub;
        i = r_len - tot_len;
        while (i < 0) {
            slimb_t ap, bp;
            BOOL inflag;

            ap = a_offset + i;
            bp = b_bit_offset + i * LIMB_BITS;
            inflag = FALSE;
            if (ap >= 0 && ap < a->len) {
                v1 = a->tab[ap];
                inflag = TRUE;
            } else {
                v1 = 0;
            }
            if (bp + LIMB_BITS > 0 && bp < (slimb_t)(b->len * LIMB_BITS)) {
                v2 = get_bits(b->tab, b->len, bp);
                inflag = TRUE;
            } else {
                v2 = 0;
            }
            if (!inflag) {
                /* outside 'a' and 'b': go directly to the next value
                   inside a or b so that the running time does not
                   depend on the exponent difference */
                i = 0;
                if (ap < 0)
                    i = bf_min(i, -a_offset);
                /* b_bit_offset + i * LIMB_BITS + LIMB_BITS >= 1
                   equivalent to
                   i >= ceil((-b_bit_offset + 1 - LIMB_BITS) / LIMB_BITS)
                */
                if (bp + LIMB_BITS <= 0)
                    i = bf_min(i, (-b_bit_offset) >> LIMB_LOG2_BITS);
            } else {
                i++;
            }
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            z |= u;
        }
        /* and the result */
        for(i = 0; i < r_len; i++) {
            v1 = get_limbz(a, a_offset + i);
            v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS);
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            r->tab[i] = u;
        }
        /* set the extra bits for the rounding */
        r->tab[0] |= (z != 0);

        /* carry is only possible in add case */
        if (!is_sub && carry) {
            if (bf_resize(r, r_len + 1))
                goto fail;
            r->tab[r_len] = 1;
            r->expn += LIMB_BITS;
        }
    renorm:
        ret = bf_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

static int __bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    return bf_add_internal(r, a, b, prec, flags, 0);
}

static int __bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    return bf_add_internal(r, a, b, prec, flags, 1);
}

limb_t mp_add(limb_t *res, const limb_t *op1, const limb_t *op2,
              limb_t n, limb_t carry)
{
    slimb_t i;
    limb_t k, a, v, k1;

    k = carry;
    for(i=0;i<n;i++) {
        v = op1[i];
        a = v + op2[i];
        k1 = a < v;
        a = a + k;
        k = (a < k) | k1;
        res[i] = a;
    }
    return k;
}

limb_t mp_add_ui(limb_t *tab, limb_t b, size_t n)
{
    size_t i;
    limb_t k, a;

    k=b;
    for(i=0;i<n;i++) {
        if (k == 0)
            break;
        a = tab[i] + k;
        k = (a < k);
        tab[i] = a;
    }
    return k;
}

limb_t mp_sub(limb_t *res, const limb_t *op1, const limb_t *op2,
              mp_size_t n, limb_t carry)
{
    int i;
    limb_t k, a, v, k1;

    k = carry;
    for(i=0;i<n;i++) {
        v = op1[i];
        a = v - op2[i];
        k1 = a > v;
        v = a - k;
        k = (v > a) | k1;
        res[i] = v;
    }
    return k;
}

/* compute 0 - op2 */
static limb_t mp_neg(limb_t *res, const limb_t *op2, mp_size_t n, limb_t carry)
{
    int i;
    limb_t k, a, v, k1;

    k = carry;
    for(i=0;i<n;i++) {
        v = 0;
        a = v - op2[i];
        k1 = a > v;
        v = a - k;
        k = (v > a) | k1;
        res[i] = v;
    }
    return k;
}

limb_t mp_sub_ui(limb_t *tab, limb_t b, mp_size_t n)
{
    mp_size_t i;
    limb_t k, a, v;

    k=b;
    for(i=0;i<n;i++) {
        v = tab[i];
        a = v - k;
        k = a > v;
        tab[i] = a;
        if (k == 0)
            break;
    }
    return k;
}

/* tab_r[] = (tab[] + high*B^n) >> shift. Return the remainder
   (0 <= remainder < 2^shift). 1 <= shift <= LIMB_BITS - 1 */
static limb_t mp_shr(limb_t *tab_r, const limb_t *tab, mp_size_t n,
                     int shift, limb_t high)
{
    mp_size_t i;
    limb_t l, a;

    assert(shift >= 1 && shift < LIMB_BITS);
    l = high;
    for(i = n - 1; i >= 0; i--) {
        a = tab[i];
        tab_r[i] = (a >> shift) | (l << (LIMB_BITS - shift));
        l = a;
    }
    return l & (((limb_t)1 << shift) - 1);
}

/* tabr[] = taba[] * b + l. Return the high carry */
static limb_t mp_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
                      limb_t b, limb_t l)
{
    limb_t i;
    dlimb_t t;

    for(i = 0; i < n; i++) {
        t = (dlimb_t)taba[i] * (dlimb_t)b + l;
        tabr[i] = t;
        l = t >> LIMB_BITS;
    }
    return l;
}

/* tabr[] += taba[] * b, return the high word. */
static limb_t mp_add_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
                          limb_t b)
{
    limb_t i, l;
    dlimb_t t;

    l = 0;
    for(i = 0; i < n; i++) {
        t = (dlimb_t)taba[i] * (dlimb_t)b + l + tabr[i];
        tabr[i] = t;
        l = t >> LIMB_BITS;
    }
    return l;
}

/* size of the result : op1_size + op2_size. */
static void mp_mul_basecase(limb_t *result,
                            const limb_t *op1, limb_t op1_size,
                            const limb_t *op2, limb_t op2_size)
{
    limb_t i, r;

    result[op1_size] = mp_mul1(result, op1, op1_size, op2[0], 0);
    for(i=1;i<op2_size;i++) {
        r = mp_add_mul1(result + i, op1, op1_size, op2[i]);
        result[i + op1_size] = r;
    }
}

/* return 0 if OK, -1 if memory error */
/* XXX: change API so that result can be allocated */
int mp_mul(bf_context_t *s, limb_t *result,
           const limb_t *op1, limb_t op1_size,
           const limb_t *op2, limb_t op2_size)
{
#ifdef USE_FFT_MUL
    if (unlikely(bf_min(op1_size, op2_size) >= FFT_MUL_THRESHOLD)) {
        bf_t r_s, *r = &r_s;
        r->tab = result;
        /* XXX: optimize memory usage in API */
        if (fft_mul(s, r, (limb_t *)op1, op1_size,
                    (limb_t *)op2, op2_size, FFT_MUL_R_NORESIZE))
            return -1;
    } else
#endif
    {
        mp_mul_basecase(result, op1, op1_size, op2, op2_size);
    }
    return 0;
}

/* tabr[] -= taba[] * b. Return the value to subtract from the high
   word. */
static limb_t mp_sub_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
                          limb_t b)
{
    limb_t i, l;
    dlimb_t t;

    l = 0;
    for(i = 0; i < n; i++) {
        t = tabr[i] - (dlimb_t)taba[i] * (dlimb_t)b - l;
        tabr[i] = t;
        l = -(t >> LIMB_BITS);
    }
    return l;
}

/* WARNING: d must be >= 2^(LIMB_BITS-1) */
static inline limb_t udiv1norm_init(limb_t d)
{
    limb_t a0, a1;
    a1 = -d - 1;
    a0 = -1;
    return (((dlimb_t)a1 << LIMB_BITS) | a0) / d;
}

/* return the quotient of 'a1*2^LIMB_BITS+a0 / d' and store the
   remainder in '*pr', with 0 <= a1 < d. */
static inline limb_t udiv1norm(limb_t *pr, limb_t a1, limb_t a0,
                               limb_t d, limb_t d_inv)
{
    limb_t n1m, n_adj, q, r, ah;
    dlimb_t a;
    n1m = ((slimb_t)a0 >> (LIMB_BITS - 1));
    n_adj = a0 + (n1m & d);
    a = (dlimb_t)d_inv * (a1 - n1m) + n_adj;
    q = (a >> LIMB_BITS) + a1;
    /* compute a - q * d and update q so that the remainder is
       between 0 and d - 1 */
    a = ((dlimb_t)a1 << LIMB_BITS) | a0;
    a = a - (dlimb_t)q * d - d;
    ah = a >> LIMB_BITS;
    q += 1 + ah;
    r = (limb_t)a + (ah & d);
    *pr = r;
    return q;
}
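
/* Background (added, hedged): udiv1norm_init computes the normalized
   reciprocal floor((B^2 - 1) / d) - B with B = 2^LIMB_BITS (note that
   a1 = B - d - 1 and a0 = B - 1 as unsigned values). udiv1norm then
   follows the classic division-by-invariant-integer scheme (in the
   style of Granlund-Montgomery): one wide multiply by the precomputed
   reciprocal gives a quotient estimate that the final adjustment turns
   into the exact quotient and remainder, avoiding a hardware
   2-limb/1-limb division in the loops of mp_div1norm() and
   mp_divnorm(). */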

/* b must be >= 1 << (LIMB_BITS - 1) */
static limb_t mp_div1norm(limb_t *tabr, const limb_t *taba, limb_t n,
                          limb_t b, limb_t r)
{
    slimb_t i;

    if (n >= UDIV1NORM_THRESHOLD) {
        limb_t b_inv;
        b_inv = udiv1norm_init(b);
        for(i = n - 1; i >= 0; i--) {
            tabr[i] = udiv1norm(&r, r, taba[i], b, b_inv);
        }
    } else {
        dlimb_t a1;
        for(i = n - 1; i >= 0; i--) {
            a1 = ((dlimb_t)r << LIMB_BITS) | taba[i];
            tabr[i] = a1 / b;
            r = a1 % b;
        }
    }
    return r;
}

static int mp_divnorm_large(bf_context_t *s,
                            limb_t *tabq, limb_t *taba, limb_t na,
                            const limb_t *tabb, limb_t nb);

/* base case division: divides taba[0..na-1] by tabb[0..nb-1]. tabb[nb
   - 1] must be >= 1 << (LIMB_BITS - 1). na - nb must be >= 0. 'taba'
   is modified and contains the remainder (nb limbs). tabq[0..na-nb]
   contains the quotient with tabq[na - nb] <= 1. */
static int mp_divnorm(bf_context_t *s, limb_t *tabq, limb_t *taba, limb_t na,
                      const limb_t *tabb, limb_t nb)
{
    limb_t r, a, c, q, v, b1, b1_inv, n, dummy_r;
    slimb_t i, j;

    b1 = tabb[nb - 1];
    if (nb == 1) {
        taba[0] = mp_div1norm(tabq, taba, na, b1, 0);
        return 0;
    }
    n = na - nb;
    if (bf_min(n, nb) >= DIVNORM_LARGE_THRESHOLD) {
        return mp_divnorm_large(s, tabq, taba, na, tabb, nb);
    }

    if (n >= UDIV1NORM_THRESHOLD)
        b1_inv = udiv1norm_init(b1);
    else
        b1_inv = 0;

    /* first iteration: the quotient is only 0 or 1 */
    q = 1;
    for(j = nb - 1; j >= 0; j--) {
        if (taba[n + j] != tabb[j]) {
            if (taba[n + j] < tabb[j])
                q = 0;
            break;
        }
    }
    tabq[n] = q;
    if (q) {
        mp_sub(taba + n, taba + n, tabb, nb, 0);
    }

    for(i = n - 1; i >= 0; i--) {
        if (unlikely(taba[i + nb] >= b1)) {
            q = -1;
        } else if (b1_inv) {
            q = udiv1norm(&dummy_r, taba[i + nb], taba[i + nb - 1], b1, b1_inv);
        } else {
            dlimb_t al;
            al = ((dlimb_t)taba[i + nb] << LIMB_BITS) | taba[i + nb - 1];
            q = al / b1;
            r = al % b1;
        }
        r = mp_sub_mul1(taba + i, tabb, nb, q);

        v = taba[i + nb];
        a = v - r;
        c = (a > v);
        taba[i + nb] = a;

        if (c != 0) {
            /* negative result */
            for(;;) {
                q--;
                c = mp_add(taba + i, taba + i, tabb, nb, 0);
                /* propagate carry and test if positive result */
                if (c != 0) {
                    if (++taba[i + nb] == 0) {
                        break;
                    }
                }
            }
        }
        tabq[i] = q;
    }
    return 0;
}
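
/* Usage note (added): this is the schoolbook division of Knuth's
   Algorithm D, and it requires a normalized divisor (top bit of
   tabb[nb-1] set). A caller with an arbitrary divisor would first
   shift both operands left by clz(tabb[nb - 1]) and shift the
   remainder back afterwards. The bf_t callers below (e.g. __bf_div())
   get this for free because bf_t mantissas always keep their most
   significant bit set. */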

/* compute r=B^(2*n)/a such that a*r < B^(2*n) < a*r + 2 with n >= 1. 'a'
   has n limbs with a[n-1] >= B/2 and 'r' has n+1 limbs with r[n] = 1.

   See Modern Computer Arithmetic by Richard P. Brent and Paul
   Zimmermann, algorithm 3.5 */
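/* Recursion sketch (added, hedged): this is the divide-and-conquer
   Newton iteration for a reciprocal. The high h limbs of the result
   come from a recursive call on the high half of 'a'; one full-width
   multiplication measures the residual and a second one applies the
   correction, roughly x' = x + x*(1 - a*x) in fixed point. Each level
   doubles the number of correct limbs, so the total cost is a constant
   multiple of the top-level multiplication cost. */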
int mp_recip(bf_context_t *s, limb_t *tabr, const limb_t *taba, limb_t n)
{
    mp_size_t l, h, k, i;
    limb_t *tabxh, *tabt, c, *tabu;

    if (n <= 2) {
        /* return ceil(B^(2*n)/a) - 1 */
        /* XXX: could avoid allocation */
        tabu = bf_malloc(s, sizeof(limb_t) * (2 * n + 1));
        tabt = bf_malloc(s, sizeof(limb_t) * (n + 2));
        if (!tabt || !tabu)
            goto fail;
        for(i = 0; i < 2 * n; i++)
            tabu[i] = 0;
        tabu[2 * n] = 1;
        if (mp_divnorm(s, tabt, tabu, 2 * n + 1, taba, n))
            goto fail;
        for(i = 0; i < n + 1; i++)
            tabr[i] = tabt[i];
        if (mp_scan_nz(tabu, n) == 0) {
            /* only happens for a=B^n/2 */
            mp_sub_ui(tabr, 1, n + 1);
        }
    } else {
        l = (n - 1) / 2;
        h = n - l;
        /* n=2p ->  l=p-1, h = p + 1, k = p + 3
           n=2p+1-> l=p,   h = p + 1, k = p + 2
        */
        tabt = bf_malloc(s, sizeof(limb_t) * (n + h + 1));
        tabu = bf_malloc(s, sizeof(limb_t) * (n + 2 * h - l + 2));
        if (!tabt || !tabu)
            goto fail;
        tabxh = tabr + l;
        if (mp_recip(s, tabxh, taba + l, h))
            goto fail;
        if (mp_mul(s, tabt, taba, n, tabxh, h + 1)) /* n + h + 1 limbs */
            goto fail;
        while (tabt[n + h] != 0) {
            mp_sub_ui(tabxh, 1, h + 1);
            c = mp_sub(tabt, tabt, taba, n, 0);
            mp_sub_ui(tabt + n, c, h + 1);
        }
        /* T = B^(n+h) - T */
        mp_neg(tabt, tabt, n + h + 1, 0);
        tabt[n + h]++;
        if (mp_mul(s, tabu, tabt + l, n + h + 1 - l, tabxh, h + 1))
            goto fail;
        /* n + 2*h - l + 2 limbs */
        k = 2 * h - l;
        for(i = 0; i < l; i++)
            tabr[i] = tabu[i + k];
        mp_add(tabr + l, tabr + l, tabu + 2 * h, h, 0);
    }
    bf_free(s, tabt);
    bf_free(s, tabu);
    return 0;
 fail:
    bf_free(s, tabt);
    bf_free(s, tabu);
    return -1;
}

/* return -1, 0 or 1 */
static int mp_cmp(const limb_t *taba, const limb_t *tabb, mp_size_t n)
{
    mp_size_t i;
    for(i = n - 1; i >= 0; i--) {
        if (taba[i] != tabb[i]) {
            if (taba[i] < tabb[i])
                return -1;
            else
                return 1;
        }
    }
    return 0;
}

//#define DEBUG_DIVNORM_LARGE
//#define DEBUG_DIVNORM_LARGE2

/* subquadratic divnorm */
static int mp_divnorm_large(bf_context_t *s,
                            limb_t *tabq, limb_t *taba, limb_t na,
                            const limb_t *tabb, limb_t nb)
{
    limb_t *tabb_inv, nq, *tabt, i, n;
    nq = na - nb;
#ifdef DEBUG_DIVNORM_LARGE
    printf("na=%d nb=%d nq=%d\n", (int)na, (int)nb, (int)nq);
    mp_print_str("a", taba, na);
    mp_print_str("b", tabb, nb);
#endif
    assert(nq >= 1);
    n = nq;
    if (nq < nb)
        n++;
    tabb_inv = bf_malloc(s, sizeof(limb_t) * (n + 1));
    tabt = bf_malloc(s, sizeof(limb_t) * 2 * (n + 1));
    if (!tabb_inv || !tabt)
        goto fail;

    if (n >= nb) {
        for(i = 0; i < n - nb; i++)
            tabt[i] = 0;
        for(i = 0; i < nb; i++)
            tabt[i + n - nb] = tabb[i];
    } else {
        /* truncate B: need to increment it so that the approximate
           inverse is smaller than the exact inverse */
        for(i = 0; i < n; i++)
            tabt[i] = tabb[i + nb - n];
        if (mp_add_ui(tabt, 1, n)) {
            /* tabt = B^n : tabb_inv = B^n */
            memset(tabb_inv, 0, n * sizeof(limb_t));
            tabb_inv[n] = 1;
            goto recip_done;
        }
    }
    if (mp_recip(s, tabb_inv, tabt, n))
        goto fail;
 recip_done:
    /* Q=A*B^-1 */
    if (mp_mul(s, tabt, tabb_inv, n + 1, taba + na - (n + 1), n + 1))
        goto fail;

    for(i = 0; i < nq + 1; i++)
        tabq[i] = tabt[i + 2 * (n + 1) - (nq + 1)];
#ifdef DEBUG_DIVNORM_LARGE
    mp_print_str("q", tabq, nq + 1);
#endif

    bf_free(s, tabt);
    bf_free(s, tabb_inv);
    tabb_inv = NULL;

    /* R=A-B*Q */
    tabt = bf_malloc(s, sizeof(limb_t) * (na + 1));
    if (!tabt)
        goto fail;
    if (mp_mul(s, tabt, tabq, nq + 1, tabb, nb))
        goto fail;
    /* we add one more limb for the result */
    mp_sub(taba, taba, tabt, nb + 1, 0);
    bf_free(s, tabt);
    /* the approximated quotient is smaller than the exact one,
       hence we may have to increment it */
#ifdef DEBUG_DIVNORM_LARGE2
    int cnt = 0;
    static int cnt_max;
#endif
    for(;;) {
        if (taba[nb] == 0 && mp_cmp(taba, tabb, nb) < 0)
            break;
        taba[nb] -= mp_sub(taba, taba, tabb, nb, 0);
        mp_add_ui(tabq, 1, nq + 1);
#ifdef DEBUG_DIVNORM_LARGE2
        cnt++;
#endif
    }
#ifdef DEBUG_DIVNORM_LARGE2
    if (cnt > cnt_max) {
        cnt_max = cnt;
        printf("\ncnt=%d nq=%d nb=%d\n", cnt_max, (int)nq, (int)nb);
    }
#endif
    return 0;
 fail:
    bf_free(s, tabb_inv);
    bf_free(s, tabt);
    return -1;
}

int bf_mul(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    int ret, r_sign;

    if (a->len < b->len) {
        const bf_t *tmp = a;
        a = b;
        b = tmp;
    }
    r_sign = a->sign ^ b->sign;
    /* here b->len <= a->len */
    if (b->len == 0) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            ret = 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
            if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
                (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, r_sign);
                ret = 0;
            }
        } else {
            bf_set_zero(r, r_sign);
            ret = 0;
        }
    } else {
        bf_t tmp, *r1 = NULL;
        limb_t a_len, b_len, precl;
        limb_t *a_tab, *b_tab;

        a_len = a->len;
        b_len = b->len;

        if ((flags & BF_RND_MASK) == BF_RNDF) {
            /* faithful rounding does not require using the full inputs */
            precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
            a_len = bf_min(a_len, precl);
            b_len = bf_min(b_len, precl);
        }
        a_tab = a->tab + a->len - a_len;
        b_tab = b->tab + b->len - b_len;

#ifdef USE_FFT_MUL
        if (b_len >= FFT_MUL_THRESHOLD) {
            int mul_flags = 0;
            if (r == a)
                mul_flags |= FFT_MUL_R_OVERLAP_A;
            if (r == b)
                mul_flags |= FFT_MUL_R_OVERLAP_B;
            if (fft_mul(r->ctx, r, a_tab, a_len, b_tab, b_len, mul_flags))
                goto fail;
        } else
#endif
        {
            if (r == a || r == b) {
                bf_init(r->ctx, &tmp);
                r1 = r;
                r = &tmp;
            }
            if (bf_resize(r, a_len + b_len)) {
            fail:
                bf_set_nan(r);
                ret = BF_ST_MEM_ERROR;
                goto done;
            }
            mp_mul_basecase(r->tab, a_tab, a_len, b_tab, b_len);
        }
        r->sign = r_sign;
        r->expn = a->expn + b->expn;
        ret = bf_normalize_and_round(r, prec, flags);
    done:
        if (r == &tmp)
            bf_move(r1, &tmp);
    }
    return ret;
}

/* multiply 'r' by 2^e */
int bf_mul_2exp(bf_t *r, slimb_t e, limb_t prec, bf_flags_t flags)
{
    slimb_t e_max;
    if (r->len == 0)
        return 0;
    e_max = ((limb_t)1 << BF_EXT_EXP_BITS_MAX) - 1;
    e = bf_max(e, -e_max);
    e = bf_min(e, e_max);
    r->expn += e;
    return __bf_round(r, prec, flags, r->len, 0);
}

/* Return e such that a = m*2^e with m an odd integer. Return 0 if a is
   zero, Inf or NaN. */
slimb_t bf_get_exp_min(const bf_t *a)
{
    slimb_t i;
    limb_t v;
    int k;

    for(i = 0; i < a->len; i++) {
        v = a->tab[i];
        if (v != 0) {
            k = ctz(v);
            return a->expn - (a->len - i) * LIMB_BITS + k;
        }
    }
    return 0;
}

/* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
   integer defined as floor(a/b) and r = a - q * b. */
static void bf_tdivremu(bf_t *q, bf_t *r,
                        const bf_t *a, const bf_t *b)
{
    if (bf_cmpu(a, b) < 0) {
        bf_set_ui(q, 0);
        bf_set(r, a);
    } else {
        bf_div(q, a, b, bf_max(a->expn - b->expn + 1, 2), BF_RNDZ);
        bf_rint(q, BF_RNDZ);
        bf_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
        bf_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
    }
}

static int __bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    int ret, r_sign;
    limb_t n, nb, precl;

    r_sign = a->sign ^ b->sign;
    if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else if (a->expn == BF_EXP_INF) {
            bf_set_inf(r, r_sign);
            return 0;
        } else {
            bf_set_zero(r, r_sign);
            return 0;
        }
    } else if (a->expn == BF_EXP_ZERO) {
        if (b->expn == BF_EXP_ZERO) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bf_set_zero(r, r_sign);
            return 0;
        }
    } else if (b->expn == BF_EXP_ZERO) {
        bf_set_inf(r, r_sign);
        return BF_ST_DIVIDE_ZERO;
    }

    /* number of limbs of the quotient (2 extra bits for rounding) */
    precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
    nb = b->len;
    n = bf_max(a->len, precl);

    {
        limb_t *taba, na;
        slimb_t d;

        na = n + nb;
        taba = bf_malloc(s, (na + 1) * sizeof(limb_t));
        if (!taba)
            goto fail;
        d = na - a->len;
        memset(taba, 0, d * sizeof(limb_t));
        memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
        if (bf_resize(r, n + 1))
            goto fail1;
        if (mp_divnorm(s, r->tab, taba, na, b->tab, nb)) {
        fail1:
            bf_free(s, taba);
            goto fail;
        }
        /* see if non zero remainder */
        if (mp_scan_nz(taba, nb))
            r->tab[0] |= 1;
        bf_free(r->ctx, taba);
        r->expn = a->expn - b->expn + LIMB_BITS;
        r->sign = r_sign;
        ret = bf_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

/* division and remainder.

   rnd_mode is the rounding mode for the quotient. The additional
   rounding mode BF_DIVREM_EUCLIDIAN is supported.

   'q' is an integer. 'r' is rounded with prec and flags (prec can be
   BF_PREC_INF).
*/
int bf_divrem(bf_t *q, bf_t *r, const bf_t *a, const bf_t *b,
              limb_t prec, bf_flags_t flags, int rnd_mode)
{
    bf_t a1_s, *a1 = &a1_s;
    bf_t b1_s, *b1 = &b1_s;
    int q_sign, ret;
    BOOL is_ceil, is_rndn;

    assert(q != a && q != b);
    assert(r != a && r != b);
    assert(q != r);

    if (a->len == 0 || b->len == 0) {
        bf_set_zero(q, 0);
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bf_set(r, a);
            return bf_round(r, prec, flags);
        }
    }

    q_sign = a->sign ^ b->sign;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA);
    switch(rnd_mode) {
    default:
    case BF_RNDZ:
    case BF_RNDN:
    case BF_RNDNA:
        is_ceil = FALSE;
        break;
    case BF_RNDD:
        is_ceil = q_sign;
        break;
    case BF_RNDU:
        is_ceil = q_sign ^ 1;
        break;
    case BF_RNDA:
        is_ceil = TRUE;
        break;
    case BF_DIVREM_EUCLIDIAN:
        is_ceil = a->sign;
        break;
    }

    a1->expn = a->expn;
    a1->tab = a->tab;
    a1->len = a->len;
    a1->sign = 0;

    b1->expn = b->expn;
    b1->tab = b->tab;
    b1->len = b->len;
    b1->sign = 0;

    /* XXX: could improve to avoid having a large 'q' */
    bf_tdivremu(q, r, a1, b1);
    if (bf_is_nan(q) || bf_is_nan(r))
        goto fail;

    if (r->len != 0) {
        if (is_rndn) {
            int res;
            b1->expn--;
            res = bf_cmpu(r, b1);
            b1->expn++;
            if (res > 0 ||
                (res == 0 &&
                 (rnd_mode == BF_RNDNA ||
                  get_bit(q->tab, q->len, q->len * LIMB_BITS - q->expn)))) {
                goto do_sub_r;
            }
        } else if (is_ceil) {
        do_sub_r:
            ret = bf_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
            ret |= bf_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
            if (ret & BF_ST_MEM_ERROR)
                goto fail;
        }
    }

    r->sign ^= a->sign;
    q->sign = q_sign;
    return bf_round(r, prec, flags);
 fail:
    bf_set_nan(q);
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

int bf_rem(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags, int rnd_mode)
{
    bf_t q_s, *q = &q_s;
    int ret;

    bf_init(r->ctx, q);
    ret = bf_divrem(q, r, a, b, prec, flags, rnd_mode);
    bf_delete(q);
    return ret;
}

static inline int bf_get_limb(slimb_t *pres, const bf_t *a, int flags)
{
#if LIMB_BITS == 32
    return bf_get_int32(pres, a, flags);
#else
    return bf_get_int64(pres, a, flags);
#endif
}

int bf_remquo(slimb_t *pq, bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
              bf_flags_t flags, int rnd_mode)
{
    bf_t q_s, *q = &q_s;
    int ret;

    bf_init(r->ctx, q);
    ret = bf_divrem(q, r, a, b, prec, flags, rnd_mode);
    bf_get_limb(pq, q, BF_GET_INT_MOD);
    bf_delete(q);
    return ret;
}

static __maybe_unused inline limb_t mul_mod(limb_t a, limb_t b, limb_t m)
{
    dlimb_t t;
    t = (dlimb_t)a * (dlimb_t)b;
    return t % m;
}

#if defined(USE_MUL_CHECK)
static limb_t mp_mod1(const limb_t *tab, limb_t n, limb_t m, limb_t r)
{
    slimb_t i;
    dlimb_t t;

    for(i = n - 1; i >= 0; i--) {
        t = ((dlimb_t)r << LIMB_BITS) | tab[i];
        r = t % m;
    }
    return r;
}
#endif

static const uint16_t sqrt_table[192] = {
128,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,144,145,146,147,148,149,150,150,151,152,153,154,155,155,156,157,158,159,160,160,161,162,163,163,164,165,166,167,167,168,169,170,170,171,172,173,173,174,175,176,176,177,178,178,179,180,181,181,182,183,183,184,185,185,186,187,187,188,189,189,190,191,192,192,193,193,194,195,195,196,197,197,198,199,199,200,201,201,202,203,203,204,204,205,206,206,207,208,208,209,209,210,211,211,212,212,213,214,214,215,215,216,217,217,218,218,219,219,220,221,221,222,222,223,224,224,225,225,226,226,227,227,228,229,229,230,230,231,231,232,232,233,234,234,235,235,236,236,237,237,238,238,239,240,240,241,241,242,242,243,243,244,244,245,245,246,246,247,247,248,248,249,249,250,250,251,251,252,252,253,253,254,254,255,
};
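
/* Note (added, inferred from the values): sqrt_table[i] appears to
   hold floor(sqrt((i + 64) << 8)), i.e. an 8-bit square root estimate
   indexed by the top 8 bits (in [64, 255]) of a normalized operand;
   mp_sqrtrem1() then refines it with digit-doubling steps. */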

/* a >= 2^(LIMB_BITS - 2). Return (s, r) with s=floor(sqrt(a)) and
   r=a-s^2. 0 <= r <= 2 * s */
static limb_t mp_sqrtrem1(limb_t *pr, limb_t a)
{
    limb_t s1, r1, s, r, q, u, num;

    /* use a table for the 16 -> 8 bit sqrt */
    s1 = sqrt_table[(a >> (LIMB_BITS - 8)) - 64];
    r1 = (a >> (LIMB_BITS - 16)) - s1 * s1;
    if (r1 > 2 * s1) {
        r1 -= 2 * s1 + 1;
        s1++;
    }

    /* one iteration to get a 32 -> 16 bit sqrt */
    num = (r1 << 8) | ((a >> (LIMB_BITS - 32 + 8)) & 0xff);
    q = num / (2 * s1); /* q <= 2^8 */
    u = num % (2 * s1);
    s = (s1 << 8) + q;
    r = (u << 8) | ((a >> (LIMB_BITS - 32)) & 0xff);
    r -= q * q;
    if ((slimb_t)r < 0) {
        s--;
        r += 2 * s + 1;
    }

#if LIMB_BITS == 64
    s1 = s;
    r1 = r;
    /* one more iteration for 64 -> 32 bit sqrt */
    num = (r1 << 16) | ((a >> (LIMB_BITS - 64 + 16)) & 0xffff);
    q = num / (2 * s1); /* q <= 2^16 */
    u = num % (2 * s1);
    s = (s1 << 16) + q;
    r = (u << 16) | ((a >> (LIMB_BITS - 64)) & 0xffff);
    r -= q * q;
    if ((slimb_t)r < 0) {
        s--;
        r += 2 * s + 1;
    }
#endif
    *pr = r;
    return s;
}

/* return floor(sqrt(a)) */
limb_t bf_isqrt(limb_t a)
{
    limb_t s, r;
    int k;

    if (a == 0)
        return 0;
    k = clz(a) & ~1;
    s = mp_sqrtrem1(&r, a << k);
    s >>= (k >> 1);
    return s;
}

static limb_t mp_sqrtrem2(limb_t *tabs, limb_t *taba)
{
    limb_t s1, r1, s, q, u, a0, a1;
    dlimb_t r, num;
    int l;

    a0 = taba[0];
    a1 = taba[1];
    s1 = mp_sqrtrem1(&r1, a1);
    l = LIMB_BITS / 2;
    num = ((dlimb_t)r1 << l) | (a0 >> l);
    q = num / (2 * s1);
    u = num % (2 * s1);
    s = (s1 << l) + q;
    r = ((dlimb_t)u << l) | (a0 & (((limb_t)1 << l) - 1));
    if (unlikely((q >> l) != 0))
        r -= (dlimb_t)1 << LIMB_BITS; /* special case when q=2^l */
    else
        r -= q * q;
    if ((slimb_t)(r >> LIMB_BITS) < 0) {
        s--;
        r += 2 * (dlimb_t)s + 1;
    }
    tabs[0] = s;
    taba[0] = r;
    return r >> LIMB_BITS;
}

//#define DEBUG_SQRTREM

/* tmp_buf must contain (n / 2 + 1) limbs. *prh contains the highest
   limb of the remainder. */
static int mp_sqrtrem_rec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n,
                          limb_t *tmp_buf, limb_t *prh)
{
    limb_t l, h, rh, ql, qh, c, i;

    if (n == 1) {
        *prh = mp_sqrtrem2(tabs, taba);
        return 0;
    }
#ifdef DEBUG_SQRTREM
    mp_print_str("a", taba, 2 * n);
#endif
    l = n / 2;
    h = n - l;
    if (mp_sqrtrem_rec(s, tabs + l, taba + 2 * l, h, tmp_buf, &qh))
        return -1;
#ifdef DEBUG_SQRTREM
    mp_print_str("s1", tabs + l, h);
    mp_print_str_h("r1", taba + 2 * l, h, qh);
    mp_print_str_h("r2", taba + l, n, qh);
#endif

    /* the remainder is in taba + 2 * l. Its high bit is in qh */
    if (qh) {
        mp_sub(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
    }
    /* instead of dividing by 2*s, divide by s (which is normalized)
       and update q and r */
    if (mp_divnorm(s, tmp_buf, taba + l, n, tabs + l, h))
        return -1;
    qh += tmp_buf[l];
    for(i = 0; i < l; i++)
        tabs[i] = tmp_buf[i];
    ql = mp_shr(tabs, tabs, l, 1, qh & 1);
    qh = qh >> 1; /* 0 or 1 */
    if (ql)
        rh = mp_add(taba + l, taba + l, tabs + l, h, 0);
    else
        rh = 0;
#ifdef DEBUG_SQRTREM
    mp_print_str_h("q", tabs, l, qh);
    mp_print_str_h("u", taba + l, h, rh);
#endif

    mp_add_ui(tabs + l, qh, h);
#ifdef DEBUG_SQRTREM
    mp_print_str("s2", tabs, n);
#endif

    /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
    /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
    if (qh) {
        c = qh;
    } else {
        if (mp_mul(s, taba + n, tabs, l, tabs, l))
            return -1;
        c = mp_sub(taba, taba, taba + n, 2 * l, 0);
    }
    rh -= mp_sub_ui(taba + 2 * l, c, n - 2 * l);
    if ((slimb_t)rh < 0) {
        mp_sub_ui(tabs, 1, n);
        rh += mp_add_mul1(taba, tabs, n, 2);
        rh += mp_add_ui(taba, 1, n);
    }
    *prh = rh;
    return 0;
}

/* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= 2 ^ (LIMB_BITS
   - 2). Return (s, r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2
   * s. tabs has n limbs. r is returned in the lower n limbs of
   taba. Its highest limb, r[n], is the return value of the function. */
/* Algorithm from the article "Karatsuba Square Root" by Paul Zimmermann
   and inspired by its GMP implementation */
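/* Structure (added summary): split a into half-size blocks
   a = a3*B^3 + a2*B^2 + a1*B + a0. The recursive call on the high part
   a3*B + a2 yields s' and r'; dividing (r'*B + a1) by 2*s' (done here
   as a division by the normalized s' followed by a one-bit shift)
   extends the root, and subtracting q^2 with a possible final downward
   correction gives floor(sqrt(a)), as in Zimmermann's paper. */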
int mp_sqrtrem(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
{
    limb_t tmp_buf1[8];
    limb_t *tmp_buf;
    mp_size_t n2;
    int ret;
    n2 = n / 2 + 1;
    if (n2 <= countof(tmp_buf1)) {
        tmp_buf = tmp_buf1;
    } else {
        tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
        if (!tmp_buf)
            return -1;
    }
    ret = mp_sqrtrem_rec(s, tabs, taba, n, tmp_buf, taba + n);
    if (tmp_buf != tmp_buf1)
        bf_free(s, tmp_buf);
    return ret;
}
2069
2070 /* Integer square root with remainder. 'a' must be an integer. r =
2071 floor(sqrt(a)) and rem = a - r^2. BF_ST_INEXACT is set if the result
2072 is inexact. 'rem' can be NULL if the remainder is not needed. */
bf_sqrtrem(bf_t * r,bf_t * rem1,const bf_t * a)2073 int bf_sqrtrem(bf_t *r, bf_t *rem1, const bf_t *a)
2074 {
2075 int ret;
2076
2077 if (a->len == 0) {
2078 if (a->expn == BF_EXP_NAN) {
2079 bf_set_nan(r);
2080 } else if (a->expn == BF_EXP_INF && a->sign) {
2081 goto invalid_op;
2082 } else {
2083 bf_set(r, a);
2084 }
2085 if (rem1)
2086 bf_set_ui(rem1, 0);
2087 ret = 0;
2088 } else if (a->sign) {
2089 invalid_op:
2090 bf_set_nan(r);
2091 if (rem1)
2092 bf_set_ui(rem1, 0);
2093 ret = BF_ST_INVALID_OP;
2094 } else {
2095 bf_t rem_s, *rem;
2096
2097 bf_sqrt(r, a, (a->expn + 1) / 2, BF_RNDZ);
2098 bf_rint(r, BF_RNDZ);
2099 /* see if the result is exact by computing the remainder */
2100 if (rem1) {
2101 rem = rem1;
2102 } else {
2103 rem = &rem_s;
2104 bf_init(r->ctx, rem);
2105 }
2106 /* XXX: could avoid recomputing the remainder */
2107 bf_mul(rem, r, r, BF_PREC_INF, BF_RNDZ);
2108 bf_neg(rem);
2109 bf_add(rem, rem, a, BF_PREC_INF, BF_RNDZ);
2110 if (bf_is_nan(rem)) {
2111 ret = BF_ST_MEM_ERROR;
2112 goto done;
2113 }
2114 if (rem->len != 0) {
2115 ret = BF_ST_INEXACT;
2116 } else {
2117 ret = 0;
2118 }
2119 done:
2120 if (!rem1)
2121 bf_delete(rem);
2122 }
2123 return ret;
2124 }
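/* Minimal usage sketch (illustrative, not part of the library);
   'ctx' is an assumed, already initialized bf_context_t:

       bf_t x, r, rem;
       bf_init(&ctx, &x); bf_init(&ctx, &r); bf_init(&ctx, &rem);
       bf_set_ui(&x, 1000);
       bf_sqrtrem(&r, &rem, &x);   // r = 31, rem = 39 (31^2 + 39 = 1000)
       bf_delete(&x); bf_delete(&r); bf_delete(&rem);
*/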
2125
2126 int bf_sqrt(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
2127 {
2128 bf_context_t *s = a->ctx;
2129 int ret;
2130
2131 assert(r != a);
2132
2133 if (a->len == 0) {
2134 if (a->expn == BF_EXP_NAN) {
2135 bf_set_nan(r);
2136 } else if (a->expn == BF_EXP_INF && a->sign) {
2137 goto invalid_op;
2138 } else {
2139 bf_set(r, a);
2140 }
2141 ret = 0;
2142 } else if (a->sign) {
2143 invalid_op:
2144 bf_set_nan(r);
2145 ret = BF_ST_INVALID_OP;
2146 } else {
2147 limb_t *a1;
2148 slimb_t n, n1;
2149 limb_t res;
2150
2151 /* convert the mantissa to an integer with at least 2 *
2152 prec + 4 bits */
2153 n = (2 * (prec + 2) + 2 * LIMB_BITS - 1) / (2 * LIMB_BITS);
2154 if (bf_resize(r, n))
2155 goto fail;
2156 a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
2157 if (!a1)
2158 goto fail;
2159 n1 = bf_min(2 * n, a->len);
2160 memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
2161 memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
2162 if (a->expn & 1) {
2163 res = mp_shr(a1, a1, 2 * n, 1, 0);
2164 } else {
2165 res = 0;
2166 }
2167 if (mp_sqrtrem(s, r->tab, a1, n)) {
2168 bf_free(s, a1);
2169 goto fail;
2170 }
2171 if (!res) {
2172 res = mp_scan_nz(a1, n + 1);
2173 }
2174 bf_free(s, a1);
2175 if (!res) {
2176 res = mp_scan_nz(a->tab, a->len - n1);
2177 }
2178 if (res != 0)
2179 r->tab[0] |= 1;
2180 r->sign = 0;
2181 r->expn = (a->expn + 1) >> 1;
2182 ret = bf_round(r, prec, flags);
2183 }
2184 return ret;
2185 fail:
2186 bf_set_nan(r);
2187 return BF_ST_MEM_ERROR;
2188 }
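/* Usage sketch (illustrative, 'ctx' assumed initialized): sqrt(2)
   rounded to nearest with a 113-bit significand; note that r and a
   must be distinct objects (see the assert above):

       bf_t two, r;
       bf_init(&ctx, &two); bf_init(&ctx, &r);
       bf_set_ui(&two, 2);
       bf_sqrt(&r, &two, 113, BF_RNDN);
*/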
2189
2190 static no_inline int bf_op2(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2191 bf_flags_t flags, bf_op2_func_t *func)
2192 {
2193 bf_t tmp;
2194 int ret;
2195
2196 if (r == a || r == b) {
2197 bf_init(r->ctx, &tmp);
2198 ret = func(&tmp, a, b, prec, flags);
2199 bf_move(r, &tmp);
2200 } else {
2201 ret = func(r, a, b, prec, flags);
2202 }
2203 return ret;
2204 }
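/* Because bf_op2() routes aliased calls through a temporary, in-place
   updates are safe with the wrappers below, e.g. (illustrative):

       bf_add(&x, &x, &y, 113, BF_RNDN);   // x = x + y
*/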
2205
2206 int bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2207 bf_flags_t flags)
2208 {
2209 return bf_op2(r, a, b, prec, flags, __bf_add);
2210 }
2211
2212 int bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2213 bf_flags_t flags)
2214 {
2215 return bf_op2(r, a, b, prec, flags, __bf_sub);
2216 }
2217
2218 int bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2219 bf_flags_t flags)
2220 {
2221 return bf_op2(r, a, b, prec, flags, __bf_div);
2222 }
2223
2224 int bf_mul_ui(bf_t *r, const bf_t *a, uint64_t b1, limb_t prec,
2225 bf_flags_t flags)
2226 {
2227 bf_t b;
2228 int ret;
2229 bf_init(r->ctx, &b);
2230 ret = bf_set_ui(&b, b1);
2231 ret |= bf_mul(r, a, &b, prec, flags);
2232 bf_delete(&b);
2233 return ret;
2234 }
2235
2236 int bf_mul_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
2237 bf_flags_t flags)
2238 {
2239 bf_t b;
2240 int ret;
2241 bf_init(r->ctx, &b);
2242 ret = bf_set_si(&b, b1);
2243 ret |= bf_mul(r, a, &b, prec, flags);
2244 bf_delete(&b);
2245 return ret;
2246 }
2247
2248 int bf_add_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
2249 bf_flags_t flags)
2250 {
2251 bf_t b;
2252 int ret;
2253
2254 bf_init(r->ctx, &b);
2255 ret = bf_set_si(&b, b1);
2256 ret |= bf_add(r, a, &b, prec, flags);
2257 bf_delete(&b);
2258 return ret;
2259 }
2260
2261 static int bf_pow_ui(bf_t *r, const bf_t *a, limb_t b, limb_t prec,
2262 bf_flags_t flags)
2263 {
2264 int ret, n_bits, i;
2265
2266 assert(r != a);
2267 if (b == 0)
2268 return bf_set_ui(r, 1);
2269 ret = bf_set(r, a);
2270 n_bits = LIMB_BITS - clz(b);
2271 for(i = n_bits - 2; i >= 0; i--) {
2272 ret |= bf_mul(r, r, r, prec, flags);
2273 if ((b >> i) & 1)
2274 ret |= bf_mul(r, r, a, prec, flags);
2275 }
2276 return ret;
2277 }
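/* Worked trace of the square-and-multiply loop above for b = 13
   (binary 1101, so n_bits = 4), starting from r = a:
       i = 2: r = r^2 = a^2,  bit 1 -> r = r * a = a^3
       i = 1: r = r^2 = a^6,  bit 0
       i = 0: r = r^2 = a^12, bit 1 -> r = r * a = a^13
   i.e. O(log2(b)) multiplications instead of b - 1. */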
2278
2279 static int bf_pow_ui_ui(bf_t *r, limb_t a1, limb_t b,
2280 limb_t prec, bf_flags_t flags)
2281 {
2282 bf_t a;
2283 int ret;
2284
2285 if (a1 == 10 && b <= LIMB_DIGITS) {
2286 /* use precomputed powers. We do not round at this point
2287 because we expect the caller to do it */
2288 ret = bf_set_ui(r, mp_pow_dec[b]);
2289 } else {
2290 bf_init(r->ctx, &a);
2291 ret = bf_set_ui(&a, a1);
2292 ret |= bf_pow_ui(r, &a, b, prec, flags);
2293 bf_delete(&a);
2294 }
2295 return ret;
2296 }
2297
2298 /* convert to integer (infinite precision) */
2299 int bf_rint(bf_t *r, int rnd_mode)
2300 {
2301 return bf_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
2302 }
2303
2304 /* logical operations */
2305 #define BF_LOGIC_OR 0
2306 #define BF_LOGIC_XOR 1
2307 #define BF_LOGIC_AND 2
2308
2309 static inline limb_t bf_logic_op1(limb_t a, limb_t b, int op)
2310 {
2311 switch(op) {
2312 case BF_LOGIC_OR:
2313 return a | b;
2314 case BF_LOGIC_XOR:
2315 return a ^ b;
2316 default:
2317 case BF_LOGIC_AND:
2318 return a & b;
2319 }
2320 }
2321
2322 static int bf_logic_op(bf_t *r, const bf_t *a1, const bf_t *b1, int op)
2323 {
2324 bf_t b1_s, a1_s, *a, *b;
2325 limb_t a_sign, b_sign, r_sign;
2326 slimb_t l, i, a_bit_offset, b_bit_offset;
2327 limb_t v1, v2, v1_mask, v2_mask, r_mask;
2328 int ret;
2329
2330 assert(r != a1 && r != b1);
2331
2332 if (a1->expn <= 0)
2333         a_sign = 0; /* minus zero is considered positive */
2334 else
2335 a_sign = a1->sign;
2336
2337 if (b1->expn <= 0)
2338         b_sign = 0; /* minus zero is considered positive */
2339 else
2340 b_sign = b1->sign;
2341
2342 if (a_sign) {
2343 a = &a1_s;
2344 bf_init(r->ctx, a);
2345 if (bf_add_si(a, a1, 1, BF_PREC_INF, BF_RNDZ)) {
2346 b = NULL;
2347 goto fail;
2348 }
2349 } else {
2350 a = (bf_t *)a1;
2351 }
2352
2353 if (b_sign) {
2354 b = &b1_s;
2355 bf_init(r->ctx, b);
2356 if (bf_add_si(b, b1, 1, BF_PREC_INF, BF_RNDZ))
2357 goto fail;
2358 } else {
2359 b = (bf_t *)b1;
2360 }
2361
2362 r_sign = bf_logic_op1(a_sign, b_sign, op);
2363 if (op == BF_LOGIC_AND && r_sign == 0) {
2364         /* no need to compute extra zeros for AND */
2365 if (a_sign == 0 && b_sign == 0)
2366 l = bf_min(a->expn, b->expn);
2367 else if (a_sign == 0)
2368 l = a->expn;
2369 else
2370 l = b->expn;
2371 } else {
2372 l = bf_max(a->expn, b->expn);
2373 }
2374 /* Note: a or b can be zero */
2375 l = (bf_max(l, 1) + LIMB_BITS - 1) / LIMB_BITS;
2376 if (bf_resize(r, l))
2377 goto fail;
2378 a_bit_offset = a->len * LIMB_BITS - a->expn;
2379 b_bit_offset = b->len * LIMB_BITS - b->expn;
2380 v1_mask = -a_sign;
2381 v2_mask = -b_sign;
2382 r_mask = -r_sign;
2383 for(i = 0; i < l; i++) {
2384 v1 = get_bits(a->tab, a->len, a_bit_offset + i * LIMB_BITS) ^ v1_mask;
2385 v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS) ^ v2_mask;
2386 r->tab[i] = bf_logic_op1(v1, v2, op) ^ r_mask;
2387 }
2388 r->expn = l * LIMB_BITS;
2389 r->sign = r_sign;
2390 bf_normalize_and_round(r, BF_PREC_INF, BF_RNDZ); /* cannot fail */
2391 if (r_sign) {
2392 if (bf_add_si(r, r, -1, BF_PREC_INF, BF_RNDZ))
2393 goto fail;
2394 }
2395 ret = 0;
2396 done:
2397 if (a == &a1_s)
2398 bf_delete(a);
2399 if (b == &b1_s)
2400 bf_delete(b);
2401 return ret;
2402 fail:
2403 bf_set_nan(r);
2404 ret = BF_ST_MEM_ERROR;
2405 goto done;
2406 }
2407
2408 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
2409 int bf_logic_or(bf_t *r, const bf_t *a, const bf_t *b)
2410 {
2411 return bf_logic_op(r, a, b, BF_LOGIC_OR);
2412 }
2413
2414 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
2415 int bf_logic_xor(bf_t *r, const bf_t *a, const bf_t *b)
2416 {
2417 return bf_logic_op(r, a, b, BF_LOGIC_XOR);
2418 }
2419
2420 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
2421 int bf_logic_and(bf_t *r, const bf_t *a, const bf_t *b)
2422 {
2423 return bf_logic_op(r, a, b, BF_LOGIC_AND);
2424 }
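/* Illustrative example ('ctx', 'a' and 'b' assumed initialized): the
   operands behave as infinite two's complement integers, e.g. with
   a = 6 (...00110) and b = -4 (...11100):

       bf_t r;
       bf_init(&ctx, &r);           // r must differ from the operands
       bf_logic_and(&r, &a, &b);    // r = 4  (...00100)
       bf_logic_or(&r, &a, &b);     // r = -2 (...11110)
       bf_logic_xor(&r, &a, &b);    // r = -6 (...11010)
*/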
2425
2426 /* conversion between fixed size types */
2427
2428 typedef union {
2429 double d;
2430 uint64_t u;
2431 } Float64Union;
2432
2433 int bf_get_float64(const bf_t *a, double *pres, bf_rnd_t rnd_mode)
2434 {
2435 Float64Union u;
2436 int e, ret;
2437 uint64_t m;
2438
2439 ret = 0;
2440 if (a->expn == BF_EXP_NAN) {
2441 u.u = 0x7ff8000000000000; /* quiet nan */
2442 } else {
2443 bf_t b_s, *b = &b_s;
2444
2445 bf_init(a->ctx, b);
2446 bf_set(b, a);
2447 if (bf_is_finite(b)) {
2448 ret = bf_round(b, 53, rnd_mode | BF_FLAG_SUBNORMAL | bf_set_exp_bits(11));
2449 }
2450 if (b->expn == BF_EXP_INF) {
2451 e = (1 << 11) - 1;
2452 m = 0;
2453 } else if (b->expn == BF_EXP_ZERO) {
2454 e = 0;
2455 m = 0;
2456 } else {
2457 e = b->expn + 1023 - 1;
2458 #if LIMB_BITS == 32
2459 if (b->len == 2) {
2460 m = ((uint64_t)b->tab[1] << 32) | b->tab[0];
2461 } else {
2462 m = ((uint64_t)b->tab[0] << 32);
2463 }
2464 #else
2465 m = b->tab[0];
2466 #endif
2467 if (e <= 0) {
2468 /* subnormal */
2469 m = m >> (12 - e);
2470 e = 0;
2471 } else {
2472 m = (m << 1) >> 12;
2473 }
2474 }
2475 u.u = m | ((uint64_t)e << 52) | ((uint64_t)b->sign << 63);
2476 bf_delete(b);
2477 }
2478 *pres = u.d;
2479 return ret;
2480 }
2481
2482 int bf_set_float64(bf_t *a, double d)
2483 {
2484 Float64Union u;
2485 uint64_t m;
2486 int shift, e, sgn;
2487
2488 u.d = d;
2489 sgn = u.u >> 63;
2490 e = (u.u >> 52) & ((1 << 11) - 1);
2491 m = u.u & (((uint64_t)1 << 52) - 1);
2492 if (e == ((1 << 11) - 1)) {
2493 if (m != 0) {
2494 bf_set_nan(a);
2495 } else {
2496 bf_set_inf(a, sgn);
2497 }
2498 } else if (e == 0) {
2499 if (m == 0) {
2500 bf_set_zero(a, sgn);
2501 } else {
2502 /* subnormal number */
2503 m <<= 12;
2504 shift = clz64(m);
2505 m <<= shift;
2506 e = -shift;
2507 goto norm;
2508 }
2509 } else {
2510 m = (m << 11) | ((uint64_t)1 << 63);
2511 norm:
2512 a->expn = e - 1023 + 1;
2513 #if LIMB_BITS == 32
2514 if (bf_resize(a, 2))
2515 goto fail;
2516 a->tab[0] = m;
2517 a->tab[1] = m >> 32;
2518 #else
2519 if (bf_resize(a, 1))
2520 goto fail;
2521 a->tab[0] = m;
2522 #endif
2523 a->sign = sgn;
2524 }
2525 return 0;
2526 fail:
2527 bf_set_nan(a);
2528 return BF_ST_MEM_ERROR;
2529 }
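/* Illustrative round-trip ('ctx' assumed initialized): any finite
   double converts exactly (its 53-bit significand fits in one or two
   limbs), so converting back at 53 bits returns the same value:

       bf_t x; double d;
       bf_init(&ctx, &x);
       bf_set_float64(&x, 0.1);          // the exact binary64 value of "0.1"
       bf_get_float64(&x, &d, BF_RNDN);  // d == 0.1 again
*/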
2530
2531 /* The rounding mode is always BF_RNDZ. Return BF_ST_INVALID_OP if there
2532 is an overflow and 0 otherwise. */
2533 int bf_get_int32(int *pres, const bf_t *a, int flags)
2534 {
2535 uint32_t v;
2536 int ret;
2537 if (a->expn >= BF_EXP_INF) {
2538 ret = BF_ST_INVALID_OP;
2539 if (flags & BF_GET_INT_MOD) {
2540 v = 0;
2541 } else if (a->expn == BF_EXP_INF) {
2542 v = (uint32_t)INT32_MAX + a->sign;
2543 } else {
2544 v = INT32_MAX;
2545 }
2546 } else if (a->expn <= 0) {
2547 v = 0;
2548 ret = 0;
2549 } else if (a->expn <= 31) {
2550 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2551 if (a->sign)
2552 v = -v;
2553 ret = 0;
2554 } else if (!(flags & BF_GET_INT_MOD)) {
2555 ret = BF_ST_INVALID_OP;
2556 if (a->sign) {
2557 v = (uint32_t)INT32_MAX + 1;
2558 if (a->expn == 32 &&
2559 (a->tab[a->len - 1] >> (LIMB_BITS - 32)) == v) {
2560 ret = 0;
2561 }
2562 } else {
2563 v = INT32_MAX;
2564 }
2565 } else {
2566 v = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
2567 if (a->sign)
2568 v = -v;
2569 ret = 0;
2570 }
2571 *pres = v;
2572 return ret;
2573 }
2574
2575 /* The rounding mode is always BF_RNDZ. Return BF_ST_INVALID_OP if there
2576 is an overflow and 0 otherwise. */
2577 int bf_get_int64(int64_t *pres, const bf_t *a, int flags)
2578 {
2579 uint64_t v;
2580 int ret;
2581 if (a->expn >= BF_EXP_INF) {
2582 ret = BF_ST_INVALID_OP;
2583 if (flags & BF_GET_INT_MOD) {
2584 v = 0;
2585 } else if (a->expn == BF_EXP_INF) {
2586 v = (uint64_t)INT64_MAX + a->sign;
2587 } else {
2588 v = INT64_MAX;
2589 }
2590 } else if (a->expn <= 0) {
2591 v = 0;
2592 ret = 0;
2593 } else if (a->expn <= 63) {
2594 #if LIMB_BITS == 32
2595 if (a->expn <= 32)
2596 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2597 else
2598 v = (((uint64_t)a->tab[a->len - 1] << 32) |
2599 get_limbz(a, a->len - 2)) >> (64 - a->expn);
2600 #else
2601 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2602 #endif
2603 if (a->sign)
2604 v = -v;
2605 ret = 0;
2606 } else if (!(flags & BF_GET_INT_MOD)) {
2607 ret = BF_ST_INVALID_OP;
2608 if (a->sign) {
2609 uint64_t v1;
2610 v = (uint64_t)INT64_MAX + 1;
2611 if (a->expn == 64) {
2612 v1 = a->tab[a->len - 1];
2613 #if LIMB_BITS == 32
2614 v1 = (v1 << 32) | get_limbz(a, a->len - 2);
2615 #endif
2616 if (v1 == v)
2617 ret = 0;
2618 }
2619 } else {
2620 v = INT64_MAX;
2621 }
2622 } else {
2623 slimb_t bit_pos = a->len * LIMB_BITS - a->expn;
2624 v = get_bits(a->tab, a->len, bit_pos);
2625 #if LIMB_BITS == 32
2626 v |= (uint64_t)get_bits(a->tab, a->len, bit_pos + 32) << 32;
2627 #endif
2628 if (a->sign)
2629 v = -v;
2630 ret = 0;
2631 }
2632 *pres = v;
2633 return ret;
2634 }
2635
2636 /* The rounding mode is always BF_RNDZ. Return BF_ST_INVALID_OP if there
2637 is an overflow and 0 otherwise. */
2638 int bf_get_uint64(uint64_t *pres, const bf_t *a)
2639 {
2640 uint64_t v;
2641 int ret;
2642 if (a->expn == BF_EXP_NAN) {
2643 goto overflow;
2644 } else if (a->expn <= 0) {
2645 v = 0;
2646 ret = 0;
2647 } else if (a->sign) {
2648 v = 0;
2649 ret = BF_ST_INVALID_OP;
2650 } else if (a->expn <= 64) {
2651 #if LIMB_BITS == 32
2652 if (a->expn <= 32)
2653 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2654 else
2655 v = (((uint64_t)a->tab[a->len - 1] << 32) |
2656 get_limbz(a, a->len - 2)) >> (64 - a->expn);
2657 #else
2658 v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2659 #endif
2660 ret = 0;
2661 } else {
2662 overflow:
2663 v = UINT64_MAX;
2664 ret = BF_ST_INVALID_OP;
2665 }
2666 *pres = v;
2667 return ret;
2668 }
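/* Illustrative example of the extraction functions ('x' assumed
   initialized): overflow yields BF_ST_INVALID_OP and a saturated
   value, unless BF_GET_INT_MOD requests reduction modulo 2^32 / 2^64:

       int64_t v; int st;
       bf_set_si(&x, -7);
       st = bf_get_int64(&v, &x, 0);               // v = -7, st = 0
       bf_set_float64(&x, 1e30);
       st = bf_get_int64(&v, &x, 0);               // v = INT64_MAX, st = BF_ST_INVALID_OP
       st = bf_get_int64(&v, &x, BF_GET_INT_MOD);  // v = x mod 2^64, st = 0
*/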
2669
2670 /* base conversion from radix */
2671
2672 static const uint8_t digits_per_limb_table[BF_RADIX_MAX - 1] = {
2673 #if LIMB_BITS == 32
2674 32,20,16,13,12,11,10,10, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
2675 #else
2676 64,40,32,27,24,22,21,20,19,18,17,17,16,16,16,15,15,15,14,14,14,14,13,13,13,13,13,13,13,12,12,12,12,12,12,
2677 #endif
2678 };
2679
2680 static limb_t get_limb_radix(int radix)
2681 {
2682 int i, k;
2683 limb_t radixl;
2684
2685 k = digits_per_limb_table[radix - 2];
2686 radixl = radix;
2687 for(i = 1; i < k; i++)
2688 radixl *= radix;
2689 return radixl;
2690 }
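/* Worked example: with 64-bit limbs and radix = 10, the table above
   gives k = 19, so radixl = 10^19 = 10000000000000000000, the largest
   power of ten that fits in a limb (10^19 < 2^64 < 10^20). */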
2691
2692 /* return != 0 if error */
2693 static int bf_integer_from_radix_rec(bf_t *r, const limb_t *tab,
2694 limb_t n, int level, limb_t n0,
2695 limb_t radix, bf_t *pow_tab)
2696 {
2697 int ret;
2698 if (n == 1) {
2699 ret = bf_set_ui(r, tab[0]);
2700 } else {
2701 bf_t T_s, *T = &T_s, *B;
2702 limb_t n1, n2;
2703
2704 n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
2705 n1 = n - n2;
2706 // printf("level=%d n0=%ld n1=%ld n2=%ld\n", level, n0, n1, n2);
2707 B = &pow_tab[level];
2708 if (B->len == 0) {
2709 ret = bf_pow_ui_ui(B, radix, n2, BF_PREC_INF, BF_RNDZ);
2710 if (ret)
2711 return ret;
2712 }
2713 ret = bf_integer_from_radix_rec(r, tab + n2, n1, level + 1, n0,
2714 radix, pow_tab);
2715 if (ret)
2716 return ret;
2717 ret = bf_mul(r, r, B, BF_PREC_INF, BF_RNDZ);
2718 if (ret)
2719 return ret;
2720 bf_init(r->ctx, T);
2721 ret = bf_integer_from_radix_rec(T, tab, n2, level + 1, n0,
2722 radix, pow_tab);
2723 if (!ret)
2724 ret = bf_add(r, r, T, BF_PREC_INF, BF_RNDZ);
2725 bf_delete(T);
2726 }
2727 return ret;
2728 // bf_print_str(" r=", r);
2729 }
2730
2731 /* return 0 if OK, != 0 if memory error */
2732 static int bf_integer_from_radix(bf_t *r, const limb_t *tab,
2733 limb_t n, limb_t radix)
2734 {
2735 bf_context_t *s = r->ctx;
2736 int pow_tab_len, i, ret;
2737 limb_t radixl;
2738 bf_t *pow_tab;
2739
2740 radixl = get_limb_radix(radix);
2741 pow_tab_len = ceil_log2(n) + 2; /* XXX: check */
2742 pow_tab = bf_malloc(s, sizeof(pow_tab[0]) * pow_tab_len);
2743 if (!pow_tab)
2744 return -1;
2745 for(i = 0; i < pow_tab_len; i++)
2746 bf_init(r->ctx, &pow_tab[i]);
2747 ret = bf_integer_from_radix_rec(r, tab, n, 0, n, radixl, pow_tab);
2748 for(i = 0; i < pow_tab_len; i++) {
2749 bf_delete(&pow_tab[i]);
2750 }
2751 bf_free(s, pow_tab);
2752 return ret;
2753 }
2754
2755 /* compute and round T * radix^expn. */
2756 int bf_mul_pow_radix(bf_t *r, const bf_t *T, limb_t radix,
2757 slimb_t expn, limb_t prec, bf_flags_t flags)
2758 {
2759 int ret, expn_sign, overflow;
2760 slimb_t e, extra_bits, prec1, ziv_extra_bits;
2761 bf_t B_s, *B = &B_s;
2762
2763 if (T->len == 0) {
2764 return bf_set(r, T);
2765 } else if (expn == 0) {
2766 ret = bf_set(r, T);
2767 ret |= bf_round(r, prec, flags);
2768 return ret;
2769 }
2770
2771 e = expn;
2772 expn_sign = 0;
2773 if (e < 0) {
2774 e = -e;
2775 expn_sign = 1;
2776 }
2777 bf_init(r->ctx, B);
2778 if (prec == BF_PREC_INF) {
2779 /* infinite precision: only used if the result is known to be exact */
2780 ret = bf_pow_ui_ui(B, radix, e, BF_PREC_INF, BF_RNDN);
2781 if (expn_sign) {
2782 ret |= bf_div(r, T, B, T->len * LIMB_BITS, BF_RNDN);
2783 } else {
2784 ret |= bf_mul(r, T, B, BF_PREC_INF, BF_RNDN);
2785 }
2786 } else {
2787 ziv_extra_bits = 16;
2788 for(;;) {
2789 prec1 = prec + ziv_extra_bits;
2790 /* XXX: correct overflow/underflow handling */
2791 /* XXX: rigorous error analysis needed */
2792 extra_bits = ceil_log2(e) * 2 + 1;
2793 ret = bf_pow_ui_ui(B, radix, e, prec1 + extra_bits, BF_RNDN | BF_FLAG_EXT_EXP);
2794 overflow = !bf_is_finite(B);
2795 /* XXX: if bf_pow_ui_ui returns an exact result, can stop
2796 after the next operation */
2797 if (expn_sign)
2798 ret |= bf_div(r, T, B, prec1 + extra_bits, BF_RNDN | BF_FLAG_EXT_EXP);
2799 else
2800 ret |= bf_mul(r, T, B, prec1 + extra_bits, BF_RNDN | BF_FLAG_EXT_EXP);
2801 if (ret & BF_ST_MEM_ERROR)
2802 break;
2803 if ((ret & BF_ST_INEXACT) &&
2804 !bf_can_round(r, prec, flags & BF_RND_MASK, prec1) &&
2805 !overflow) {
2806                 /* add more precision and retry */
2807 ziv_extra_bits = ziv_extra_bits + (ziv_extra_bits / 2);
2808 } else {
2809 /* XXX: need to use __bf_round() to pass the inexact
2810 flag for the subnormal case */
2811 ret = bf_round(r, prec, flags) | (ret & BF_ST_INEXACT);
2812 break;
2813 }
2814 }
2815 }
2816 bf_delete(B);
2817 return ret;
2818 }
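/* Usage sketch (illustrative; 'r' and 'T' assumed initialized):
   r = T * 10^-3 correctly rounded to 113 bits. Since 10^-3 is not a
   binary fraction, the Ziv loop above retries with growing precision
   until the result can be rounded:

       bf_mul_pow_radix(&r, &T, 10, -3, 113, BF_RNDN);
*/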
2819
2820 static inline int to_digit(int c)
2821 {
2822 if (c >= '0' && c <= '9')
2823 return c - '0';
2824 else if (c >= 'A' && c <= 'Z')
2825 return c - 'A' + 10;
2826 else if (c >= 'a' && c <= 'z')
2827 return c - 'a' + 10;
2828 else
2829 return 36;
2830 }
2831
2832 /* add a limb at 'pos' and decrement pos. New space is created if
2833 needed. Return 0 if OK, -1 if memory error */
2834 static int bf_add_limb(bf_t *a, slimb_t *ppos, limb_t v)
2835 {
2836 slimb_t pos;
2837 pos = *ppos;
2838 if (unlikely(pos < 0)) {
2839 limb_t new_size, d, *new_tab;
2840 new_size = bf_max(a->len + 1, a->len * 3 / 2);
2841 new_tab = bf_realloc(a->ctx, a->tab, sizeof(limb_t) * new_size);
2842 if (!new_tab)
2843 return -1;
2844 a->tab = new_tab;
2845 d = new_size - a->len;
2846 memmove(a->tab + d, a->tab, a->len * sizeof(limb_t));
2847 a->len = new_size;
2848 pos += d;
2849 }
2850 a->tab[pos--] = v;
2851 *ppos = pos;
2852 return 0;
2853 }
2854
2855 static int bf_tolower(int c)
2856 {
2857 if (c >= 'A' && c <= 'Z')
2858 c = c - 'A' + 'a';
2859 return c;
2860 }
2861
2862 static int strcasestart(const char *str, const char *val, const char **ptr)
2863 {
2864 const char *p, *q;
2865 p = str;
2866 q = val;
2867 while (*q != '\0') {
2868 if (bf_tolower(*p) != *q)
2869 return 0;
2870 p++;
2871 q++;
2872 }
2873 if (ptr)
2874 *ptr = p;
2875 return 1;
2876 }
2877
2878 static int bf_atof_internal(bf_t *r, slimb_t *pexponent,
2879 const char *str, const char **pnext, int radix,
2880 limb_t prec, bf_flags_t flags, BOOL is_dec)
2881 {
2882 const char *p, *p_start;
2883 int is_neg, radix_bits, exp_is_neg, ret, digits_per_limb, shift;
2884 limb_t cur_limb;
2885 slimb_t pos, expn, int_len, digit_count;
2886 BOOL has_decpt, is_bin_exp;
2887 bf_t a_s, *a;
2888
2889 *pexponent = 0;
2890 p = str;
2891 if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
2892 strcasestart(p, "nan", &p)) {
2893 bf_set_nan(r);
2894 ret = 0;
2895 goto done;
2896 }
2897 is_neg = 0;
2898
2899 if (p[0] == '+') {
2900 p++;
2901 p_start = p;
2902 } else if (p[0] == '-') {
2903 is_neg = 1;
2904 p++;
2905 p_start = p;
2906 } else {
2907 p_start = p;
2908 }
2909 if (p[0] == '0') {
2910 if ((p[1] == 'x' || p[1] == 'X') &&
2911 (radix == 0 || radix == 16) &&
2912 !(flags & BF_ATOF_NO_HEX)) {
2913 radix = 16;
2914 p += 2;
2915 } else if ((p[1] == 'o' || p[1] == 'O') &&
2916 radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
2917 p += 2;
2918 radix = 8;
2919 } else if ((p[1] == 'b' || p[1] == 'B') &&
2920 radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
2921 p += 2;
2922 radix = 2;
2923 } else {
2924 goto no_prefix;
2925 }
2926 /* there must be a digit after the prefix */
2927 if (to_digit((uint8_t)*p) >= radix) {
2928 bf_set_nan(r);
2929 ret = 0;
2930 goto done;
2931 }
2932 no_prefix: ;
2933 } else {
2934 if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
2935 strcasestart(p, "inf", &p)) {
2936 bf_set_inf(r, is_neg);
2937 ret = 0;
2938 goto done;
2939 }
2940 }
2941
2942 if (radix == 0)
2943 radix = 10;
2944 if (is_dec) {
2945 assert(radix == 10);
2946 radix_bits = 0;
2947 a = r;
2948 } else if ((radix & (radix - 1)) != 0) {
2949 radix_bits = 0; /* base is not a power of two */
2950 a = &a_s;
2951 bf_init(r->ctx, a);
2952 } else {
2953 radix_bits = ceil_log2(radix);
2954 a = r;
2955 }
2956
2957 /* skip leading zeros */
2958 /* XXX: could also skip zeros after the decimal point */
2959 while (*p == '0')
2960 p++;
2961
2962 if (radix_bits) {
2963 shift = digits_per_limb = LIMB_BITS;
2964 } else {
2965 radix_bits = 0;
2966 shift = digits_per_limb = digits_per_limb_table[radix - 2];
2967 }
2968 cur_limb = 0;
2969 bf_resize(a, 1);
2970 pos = 0;
2971 has_decpt = FALSE;
2972 int_len = digit_count = 0;
2973 for(;;) {
2974 limb_t c;
2975 if (*p == '.' && (p > p_start || to_digit(p[1]) < radix)) {
2976 if (has_decpt)
2977 break;
2978 has_decpt = TRUE;
2979 int_len = digit_count;
2980 p++;
2981 }
2982 c = to_digit(*p);
2983 if (c >= radix)
2984 break;
2985 digit_count++;
2986 p++;
2987 if (radix_bits) {
2988 shift -= radix_bits;
2989 if (shift <= 0) {
2990 cur_limb |= c >> (-shift);
2991 if (bf_add_limb(a, &pos, cur_limb))
2992 goto mem_error;
2993 if (shift < 0)
2994 cur_limb = c << (LIMB_BITS + shift);
2995 else
2996 cur_limb = 0;
2997 shift += LIMB_BITS;
2998 } else {
2999 cur_limb |= c << shift;
3000 }
3001 } else {
3002 cur_limb = cur_limb * radix + c;
3003 shift--;
3004 if (shift == 0) {
3005 if (bf_add_limb(a, &pos, cur_limb))
3006 goto mem_error;
3007 shift = digits_per_limb;
3008 cur_limb = 0;
3009 }
3010 }
3011 }
3012 if (!has_decpt)
3013 int_len = digit_count;
3014
3015 /* add the last limb and pad with zeros */
3016 if (shift != digits_per_limb) {
3017 if (radix_bits == 0) {
3018 while (shift != 0) {
3019 cur_limb *= radix;
3020 shift--;
3021 }
3022 }
3023 if (bf_add_limb(a, &pos, cur_limb)) {
3024 mem_error:
3025 ret = BF_ST_MEM_ERROR;
3026 if (!radix_bits)
3027 bf_delete(a);
3028 bf_set_nan(r);
3029 goto done;
3030 }
3031 }
3032
3033 /* reset the next limbs to zero (we prefer to reallocate in the
3034 renormalization) */
3035 memset(a->tab, 0, (pos + 1) * sizeof(limb_t));
3036
3037 if (p == p_start) {
3038 ret = 0;
3039 if (!radix_bits)
3040 bf_delete(a);
3041 bf_set_nan(r);
3042 goto done;
3043 }
3044
3045 /* parse the exponent, if any */
3046 expn = 0;
3047 is_bin_exp = FALSE;
3048 if (((radix == 10 && (*p == 'e' || *p == 'E')) ||
3049 (radix != 10 && (*p == '@' ||
3050 (radix_bits && (*p == 'p' || *p == 'P'))))) &&
3051 p > p_start) {
3052 is_bin_exp = (*p == 'p' || *p == 'P');
3053 p++;
3054 exp_is_neg = 0;
3055 if (*p == '+') {
3056 p++;
3057 } else if (*p == '-') {
3058 exp_is_neg = 1;
3059 p++;
3060 }
3061 for(;;) {
3062 int c;
3063 c = to_digit(*p);
3064 if (c >= 10)
3065 break;
3066 if (unlikely(expn > ((BF_RAW_EXP_MAX - 2 - 9) / 10))) {
3067 /* exponent overflow */
3068 if (exp_is_neg) {
3069 bf_set_zero(r, is_neg);
3070 ret = BF_ST_UNDERFLOW | BF_ST_INEXACT;
3071 } else {
3072 bf_set_inf(r, is_neg);
3073 ret = BF_ST_OVERFLOW | BF_ST_INEXACT;
3074 }
3075 goto done;
3076 }
3077 p++;
3078 expn = expn * 10 + c;
3079 }
3080 if (exp_is_neg)
3081 expn = -expn;
3082 }
3083 if (is_dec) {
3084 a->expn = expn + int_len;
3085 a->sign = is_neg;
3086 ret = bfdec_normalize_and_round((bfdec_t *)a, prec, flags);
3087 } else if (radix_bits) {
3088 /* XXX: may overflow */
3089 if (!is_bin_exp)
3090 expn *= radix_bits;
3091 a->expn = expn + (int_len * radix_bits);
3092 a->sign = is_neg;
3093 ret = bf_normalize_and_round(a, prec, flags);
3094 } else {
3095 limb_t l;
3096 pos++;
3097 l = a->len - pos; /* number of limbs */
3098 if (l == 0) {
3099 bf_set_zero(r, is_neg);
3100 ret = 0;
3101 } else {
3102 bf_t T_s, *T = &T_s;
3103
3104 expn -= l * digits_per_limb - int_len;
3105 bf_init(r->ctx, T);
3106 if (bf_integer_from_radix(T, a->tab + pos, l, radix)) {
3107 bf_set_nan(r);
3108 ret = BF_ST_MEM_ERROR;
3109 } else {
3110 T->sign = is_neg;
3111 if (flags & BF_ATOF_EXPONENT) {
3112 /* return the exponent */
3113 *pexponent = expn;
3114 ret = bf_set(r, T);
3115 } else {
3116 ret = bf_mul_pow_radix(r, T, radix, expn, prec, flags);
3117 }
3118 }
3119 bf_delete(T);
3120 }
3121 bf_delete(a);
3122 }
3123 done:
3124 if (pnext)
3125 *pnext = p;
3126 return ret;
3127 }
3128
3129 /*
3130 Return (status, n, exp). 'status' is the floating point status. 'n'
3131 is the parsed number.
3132
3133 If (flags & BF_ATOF_EXPONENT) and if the radix is not a power of
3134 two, the parsed number is equal to r *
3135    radix^(*pexponent). Otherwise *pexponent = 0.
3136 */
3137 int bf_atof2(bf_t *r, slimb_t *pexponent,
3138 const char *str, const char **pnext, int radix,
3139 limb_t prec, bf_flags_t flags)
3140 {
3141 return bf_atof_internal(r, pexponent, str, pnext, radix, prec, flags,
3142 FALSE);
3143 }
3144
3145 int bf_atof(bf_t *r, const char *str, const char **pnext, int radix,
3146 limb_t prec, bf_flags_t flags)
3147 {
3148 slimb_t dummy_exp;
3149 return bf_atof_internal(r, &dummy_exp, str, pnext, radix, prec, flags, FALSE);
3150 }
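/* Usage sketch (illustrative, 'ctx' assumed initialized): parsing
   with 113 bits of precision; radix 0 auto-detects a "0x" prefix,
   after which 'p' introduces a binary exponent:

       bf_t x; const char *next;
       bf_init(&ctx, &x);
       bf_atof(&x, "3.14159", &next, 10, 113, BF_RNDN);
       bf_atof(&x, "0x1.8p3", &next, 0, 113, BF_RNDN);   // x = 12
*/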
3151
3152 /* base conversion to radix */
3153
3154 #if LIMB_BITS == 64
3155 #define RADIXL_10 UINT64_C(10000000000000000000)
3156 #else
3157 #define RADIXL_10 UINT64_C(1000000000)
3158 #endif
3159
3160 static const uint32_t inv_log2_radix[BF_RADIX_MAX - 1][LIMB_BITS / 32 + 1] = {
3161 #if LIMB_BITS == 32
3162 { 0x80000000, 0x00000000,},
3163 { 0x50c24e60, 0xd4d4f4a7,},
3164 { 0x40000000, 0x00000000,},
3165 { 0x372068d2, 0x0a1ee5ca,},
3166 { 0x3184648d, 0xb8153e7a,},
3167 { 0x2d983275, 0x9d5369c4,},
3168 { 0x2aaaaaaa, 0xaaaaaaab,},
3169 { 0x28612730, 0x6a6a7a54,},
3170 { 0x268826a1, 0x3ef3fde6,},
3171 { 0x25001383, 0xbac8a744,},
3172 { 0x23b46706, 0x82c0c709,},
3173 { 0x229729f1, 0xb2c83ded,},
3174 { 0x219e7ffd, 0xa5ad572b,},
3175 { 0x20c33b88, 0xda7c29ab,},
3176 { 0x20000000, 0x00000000,},
3177 { 0x1f50b57e, 0xac5884b3,},
3178 { 0x1eb22cc6, 0x8aa6e26f,},
3179 { 0x1e21e118, 0x0c5daab2,},
3180 { 0x1d9dcd21, 0x439834e4,},
3181 { 0x1d244c78, 0x367a0d65,},
3182 { 0x1cb40589, 0xac173e0c,},
3183 { 0x1c4bd95b, 0xa8d72b0d,},
3184 { 0x1bead768, 0x98f8ce4c,},
3185 { 0x1b903469, 0x050f72e5,},
3186 { 0x1b3b433f, 0x2eb06f15,},
3187 { 0x1aeb6f75, 0x9c46fc38,},
3188 { 0x1aa038eb, 0x0e3bfd17,},
3189 { 0x1a593062, 0xb38d8c56,},
3190 { 0x1a15f4c3, 0x2b95a2e6,},
3191 { 0x19d630dc, 0xcc7ddef9,},
3192 { 0x19999999, 0x9999999a,},
3193 { 0x195fec80, 0x8a609431,},
3194 { 0x1928ee7b, 0x0b4f22f9,},
3195 { 0x18f46acf, 0x8c06e318,},
3196 { 0x18c23246, 0xdc0a9f3d,},
3197 #else
3198 { 0x80000000, 0x00000000, 0x00000000,},
3199 { 0x50c24e60, 0xd4d4f4a7, 0x021f57bc,},
3200 { 0x40000000, 0x00000000, 0x00000000,},
3201 { 0x372068d2, 0x0a1ee5ca, 0x19ea911b,},
3202 { 0x3184648d, 0xb8153e7a, 0x7fc2d2e1,},
3203 { 0x2d983275, 0x9d5369c4, 0x4dec1661,},
3204 { 0x2aaaaaaa, 0xaaaaaaaa, 0xaaaaaaab,},
3205 { 0x28612730, 0x6a6a7a53, 0x810fabde,},
3206 { 0x268826a1, 0x3ef3fde6, 0x23e2566b,},
3207 { 0x25001383, 0xbac8a744, 0x385a3349,},
3208 { 0x23b46706, 0x82c0c709, 0x3f891718,},
3209 { 0x229729f1, 0xb2c83ded, 0x15fba800,},
3210 { 0x219e7ffd, 0xa5ad572a, 0xe169744b,},
3211 { 0x20c33b88, 0xda7c29aa, 0x9bddee52,},
3212 { 0x20000000, 0x00000000, 0x00000000,},
3213 { 0x1f50b57e, 0xac5884b3, 0x70e28eee,},
3214 { 0x1eb22cc6, 0x8aa6e26f, 0x06d1a2a2,},
3215 { 0x1e21e118, 0x0c5daab1, 0x81b4f4bf,},
3216 { 0x1d9dcd21, 0x439834e3, 0x81667575,},
3217 { 0x1d244c78, 0x367a0d64, 0xc8204d6d,},
3218 { 0x1cb40589, 0xac173e0c, 0x3b7b16ba,},
3219 { 0x1c4bd95b, 0xa8d72b0d, 0x5879f25a,},
3220 { 0x1bead768, 0x98f8ce4c, 0x66cc2858,},
3221 { 0x1b903469, 0x050f72e5, 0x0cf5488e,},
3222 { 0x1b3b433f, 0x2eb06f14, 0x8c89719c,},
3223 { 0x1aeb6f75, 0x9c46fc37, 0xab5fc7e9,},
3224 { 0x1aa038eb, 0x0e3bfd17, 0x1bd62080,},
3225 { 0x1a593062, 0xb38d8c56, 0x7998ab45,},
3226 { 0x1a15f4c3, 0x2b95a2e6, 0x46aed6a0,},
3227 { 0x19d630dc, 0xcc7ddef9, 0x5aadd61b,},
3228 { 0x19999999, 0x99999999, 0x9999999a,},
3229 { 0x195fec80, 0x8a609430, 0xe1106014,},
3230 { 0x1928ee7b, 0x0b4f22f9, 0x5f69791d,},
3231 { 0x18f46acf, 0x8c06e318, 0x4d2aeb2c,},
3232 { 0x18c23246, 0xdc0a9f3d, 0x3fe16970,},
3233 #endif
3234 };
3235
3236 static const limb_t log2_radix[BF_RADIX_MAX - 1] = {
3237 #if LIMB_BITS == 32
3238 0x20000000,
3239 0x32b80347,
3240 0x40000000,
3241 0x4a4d3c26,
3242 0x52b80347,
3243 0x59d5d9fd,
3244 0x60000000,
3245 0x6570068e,
3246 0x6a4d3c26,
3247 0x6eb3a9f0,
3248 0x72b80347,
3249 0x766a008e,
3250 0x79d5d9fd,
3251 0x7d053f6d,
3252 0x80000000,
3253 0x82cc7edf,
3254 0x8570068e,
3255 0x87ef05ae,
3256 0x8a4d3c26,
3257 0x8c8ddd45,
3258 0x8eb3a9f0,
3259 0x90c10501,
3260 0x92b80347,
3261 0x949a784c,
3262 0x966a008e,
3263 0x982809d6,
3264 0x99d5d9fd,
3265 0x9b74948f,
3266 0x9d053f6d,
3267 0x9e88c6b3,
3268 0xa0000000,
3269 0xa16bad37,
3270 0xa2cc7edf,
3271 0xa4231623,
3272 0xa570068e,
3273 #else
3274 0x2000000000000000,
3275 0x32b803473f7ad0f4,
3276 0x4000000000000000,
3277 0x4a4d3c25e68dc57f,
3278 0x52b803473f7ad0f4,
3279 0x59d5d9fd5010b366,
3280 0x6000000000000000,
3281 0x6570068e7ef5a1e8,
3282 0x6a4d3c25e68dc57f,
3283 0x6eb3a9f01975077f,
3284 0x72b803473f7ad0f4,
3285 0x766a008e4788cbcd,
3286 0x79d5d9fd5010b366,
3287 0x7d053f6d26089673,
3288 0x8000000000000000,
3289 0x82cc7edf592262d0,
3290 0x8570068e7ef5a1e8,
3291 0x87ef05ae409a0289,
3292 0x8a4d3c25e68dc57f,
3293 0x8c8ddd448f8b845a,
3294 0x8eb3a9f01975077f,
3295 0x90c10500d63aa659,
3296 0x92b803473f7ad0f4,
3297 0x949a784bcd1b8afe,
3298 0x966a008e4788cbcd,
3299 0x982809d5be7072dc,
3300 0x99d5d9fd5010b366,
3301 0x9b74948f5532da4b,
3302 0x9d053f6d26089673,
3303 0x9e88c6b3626a72aa,
3304 0xa000000000000000,
3305 0xa16bad3758efd873,
3306 0xa2cc7edf592262d0,
3307 0xa4231623369e78e6,
3308 0xa570068e7ef5a1e8,
3309 #endif
3310 };
3311
3312 /* compute floor(a*b) or ceil(a*b) with b = log2(radix) or
3313 b=1/log2(radix). For is_inv = 0, strict accuracy is not guaranteed
3314 when radix is not a power of two. */
3315 slimb_t bf_mul_log2_radix(slimb_t a1, unsigned int radix, int is_inv,
3316 int is_ceil1)
3317 {
3318 int is_neg;
3319 limb_t a;
3320 BOOL is_ceil;
3321
3322 is_ceil = is_ceil1;
3323 a = a1;
3324 if (a1 < 0) {
3325 a = -a;
3326 is_neg = 1;
3327 } else {
3328 is_neg = 0;
3329 }
3330 is_ceil ^= is_neg;
3331 if ((radix & (radix - 1)) == 0) {
3332 int radix_bits;
3333 /* radix is a power of two */
3334 radix_bits = ceil_log2(radix);
3335 if (is_inv) {
3336 if (is_ceil)
3337 a += radix_bits - 1;
3338 a = a / radix_bits;
3339 } else {
3340 a = a * radix_bits;
3341 }
3342 } else {
3343 const uint32_t *tab;
3344 limb_t b0, b1;
3345 dlimb_t t;
3346
3347 if (is_inv) {
3348 tab = inv_log2_radix[radix - 2];
3349 #if LIMB_BITS == 32
3350 b1 = tab[0];
3351 b0 = tab[1];
3352 #else
3353 b1 = ((limb_t)tab[0] << 32) | tab[1];
3354 b0 = (limb_t)tab[2] << 32;
3355 #endif
3356 t = (dlimb_t)b0 * (dlimb_t)a;
3357 t = (dlimb_t)b1 * (dlimb_t)a + (t >> LIMB_BITS);
3358 a = t >> (LIMB_BITS - 1);
3359 } else {
3360 b0 = log2_radix[radix - 2];
3361 t = (dlimb_t)b0 * (dlimb_t)a;
3362 a = t >> (LIMB_BITS - 3);
3363 }
3364 /* a = floor(result) and 'result' cannot be an integer */
3365 a += is_ceil;
3366 }
3367 if (is_neg)
3368 a = -a;
3369 return a;
3370 }
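/* Worked example: digits needed for a 53-bit significand in base 10.
   With a1 = 53, radix = 10, is_inv = 1, is_ceil1 = 1 the function
   returns ceil(53 / log2(10)) = ceil(15.95...) = 16; bf_ftoa_internal()
   adds one more digit, giving the familiar 17 significant decimal
   digits of an IEEE-754 double. */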
3371
3372 /* 'n' is the number of output limbs */
3373 static int bf_integer_to_radix_rec(bf_t *pow_tab,
3374 limb_t *out, const bf_t *a, limb_t n,
3375 int level, limb_t n0, limb_t radixl,
3376 unsigned int radixl_bits)
3377 {
3378 limb_t n1, n2, q_prec;
3379 int ret;
3380
3381 assert(n >= 1);
3382 if (n == 1) {
3383 out[0] = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
3384 } else if (n == 2) {
3385 dlimb_t t;
3386 slimb_t pos;
3387 pos = a->len * LIMB_BITS - a->expn;
3388 t = ((dlimb_t)get_bits(a->tab, a->len, pos + LIMB_BITS) << LIMB_BITS) |
3389 get_bits(a->tab, a->len, pos);
3390 if (likely(radixl == RADIXL_10)) {
3391 /* use division by a constant when possible */
3392 out[0] = t % RADIXL_10;
3393 out[1] = t / RADIXL_10;
3394 } else {
3395 out[0] = t % radixl;
3396 out[1] = t / radixl;
3397 }
3398 } else {
3399 bf_t Q, R, *B, *B_inv;
3400 int q_add;
3401 bf_init(a->ctx, &Q);
3402 bf_init(a->ctx, &R);
3403 n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
3404 n1 = n - n2;
3405 B = &pow_tab[2 * level];
3406 B_inv = &pow_tab[2 * level + 1];
3407 ret = 0;
3408 if (B->len == 0) {
3409 /* compute BASE^n2 */
3410 ret |= bf_pow_ui_ui(B, radixl, n2, BF_PREC_INF, BF_RNDZ);
3411 /* we use enough bits for the maximum possible 'n1' value,
3412 i.e. n2 + 1 */
3413 ret |= bf_set_ui(&R, 1);
3414 ret |= bf_div(B_inv, &R, B, (n2 + 1) * radixl_bits + 2, BF_RNDN);
3415 }
3416 // printf("%d: n1=% " PRId64 " n2=%" PRId64 "\n", level, n1, n2);
3417 q_prec = n1 * radixl_bits;
3418 ret |= bf_mul(&Q, a, B_inv, q_prec, BF_RNDN);
3419 ret |= bf_rint(&Q, BF_RNDZ);
3420
3421 ret |= bf_mul(&R, &Q, B, BF_PREC_INF, BF_RNDZ);
3422 ret |= bf_sub(&R, a, &R, BF_PREC_INF, BF_RNDZ);
3423
3424 if (ret & BF_ST_MEM_ERROR)
3425 goto fail;
3426 /* adjust if necessary */
3427 q_add = 0;
3428 while (R.sign && R.len != 0) {
3429 if (bf_add(&R, &R, B, BF_PREC_INF, BF_RNDZ))
3430 goto fail;
3431 q_add--;
3432 }
3433 while (bf_cmpu(&R, B) >= 0) {
3434 if (bf_sub(&R, &R, B, BF_PREC_INF, BF_RNDZ))
3435 goto fail;
3436 q_add++;
3437 }
3438 if (q_add != 0) {
3439 if (bf_add_si(&Q, &Q, q_add, BF_PREC_INF, BF_RNDZ))
3440 goto fail;
3441 }
3442 if (bf_integer_to_radix_rec(pow_tab, out + n2, &Q, n1, level + 1, n0,
3443 radixl, radixl_bits))
3444 goto fail;
3445 if (bf_integer_to_radix_rec(pow_tab, out, &R, n2, level + 1, n0,
3446 radixl, radixl_bits)) {
3447 fail:
3448 bf_delete(&Q);
3449 bf_delete(&R);
3450 return -1;
3451 }
3452 bf_delete(&Q);
3453 bf_delete(&R);
3454 }
3455 return 0;
3456 }
3457
3458 /* return 0 if OK, != 0 if memory error */
3459 static int bf_integer_to_radix(bf_t *r, const bf_t *a, limb_t radixl)
3460 {
3461 bf_context_t *s = r->ctx;
3462 limb_t r_len;
3463 bf_t *pow_tab;
3464 int i, pow_tab_len, ret;
3465
3466 r_len = r->len;
3467 pow_tab_len = (ceil_log2(r_len) + 2) * 2; /* XXX: check */
3468 pow_tab = bf_malloc(s, sizeof(pow_tab[0]) * pow_tab_len);
3469 if (!pow_tab)
3470 return -1;
3471 for(i = 0; i < pow_tab_len; i++)
3472 bf_init(r->ctx, &pow_tab[i]);
3473
3474 ret = bf_integer_to_radix_rec(pow_tab, r->tab, a, r_len, 0, r_len, radixl,
3475 ceil_log2(radixl));
3476
3477 for(i = 0; i < pow_tab_len; i++) {
3478 bf_delete(&pow_tab[i]);
3479 }
3480 bf_free(s, pow_tab);
3481 return ret;
3482 }
3483
3484 /* a must be >= 0. 'P' is the wanted number of digits in radix
3485 'radix'. 'r' is the mantissa represented as an integer. *pE
3486 contains the exponent. Return != 0 if memory error. */
3487 static int bf_convert_to_radix(bf_t *r, slimb_t *pE,
3488 const bf_t *a, int radix,
3489 limb_t P, bf_rnd_t rnd_mode,
3490 BOOL is_fixed_exponent)
3491 {
3492 slimb_t E, e, prec, extra_bits, ziv_extra_bits, prec0;
3493 bf_t B_s, *B = &B_s;
3494 int e_sign, ret, res;
3495
3496 if (a->len == 0) {
3497 /* zero case */
3498 *pE = 0;
3499 return bf_set(r, a);
3500 }
3501
3502 if (is_fixed_exponent) {
3503 E = *pE;
3504 } else {
3505 /* compute the new exponent */
3506 E = 1 + bf_mul_log2_radix(a->expn - 1, radix, TRUE, FALSE);
3507 }
3508 // bf_print_str("a", a);
3509 // printf("E=%ld P=%ld radix=%d\n", E, P, radix);
3510
3511 for(;;) {
3512 e = P - E;
3513 e_sign = 0;
3514 if (e < 0) {
3515 e = -e;
3516 e_sign = 1;
3517 }
3518 /* Note: precision for log2(radix) is not critical here */
3519 prec0 = bf_mul_log2_radix(P, radix, FALSE, TRUE);
3520 ziv_extra_bits = 16;
3521 for(;;) {
3522 prec = prec0 + ziv_extra_bits;
3523 /* XXX: rigorous error analysis needed */
3524 extra_bits = ceil_log2(e) * 2 + 1;
3525 ret = bf_pow_ui_ui(r, radix, e, prec + extra_bits,
3526 BF_RNDN | BF_FLAG_EXT_EXP);
3527 if (!e_sign)
3528 ret |= bf_mul(r, r, a, prec + extra_bits,
3529 BF_RNDN | BF_FLAG_EXT_EXP);
3530 else
3531 ret |= bf_div(r, a, r, prec + extra_bits,
3532 BF_RNDN | BF_FLAG_EXT_EXP);
3533 if (ret & BF_ST_MEM_ERROR)
3534 return BF_ST_MEM_ERROR;
3535 /* if the result is not exact, check that it can be safely
3536 rounded to an integer */
3537 if ((ret & BF_ST_INEXACT) &&
3538 !bf_can_round(r, r->expn, rnd_mode, prec)) {
3539                 /* add more precision and retry */
3540 ziv_extra_bits = ziv_extra_bits + (ziv_extra_bits / 2);
3541 continue;
3542 } else {
3543 ret = bf_rint(r, rnd_mode);
3544 if (ret & BF_ST_MEM_ERROR)
3545 return BF_ST_MEM_ERROR;
3546 break;
3547 }
3548 }
3549 if (is_fixed_exponent)
3550 break;
3551 /* check that the result is < B^P */
3552 /* XXX: do a fast approximate test first ? */
3553 bf_init(r->ctx, B);
3554 ret = bf_pow_ui_ui(B, radix, P, BF_PREC_INF, BF_RNDZ);
3555 if (ret) {
3556 bf_delete(B);
3557 return ret;
3558 }
3559 res = bf_cmpu(r, B);
3560 bf_delete(B);
3561 if (res < 0)
3562 break;
3563 /* try a larger exponent */
3564 E++;
3565 }
3566 *pE = E;
3567 return 0;
3568 }
3569
3570 static void limb_to_a(char *buf, limb_t n, unsigned int radix, int len)
3571 {
3572 int digit, i;
3573
3574 if (radix == 10) {
3575 /* specific case with constant divisor */
3576 for(i = len - 1; i >= 0; i--) {
3577 digit = (limb_t)n % 10;
3578 n = (limb_t)n / 10;
3579 buf[i] = digit + '0';
3580 }
3581 } else {
3582 for(i = len - 1; i >= 0; i--) {
3583 digit = (limb_t)n % radix;
3584 n = (limb_t)n / radix;
3585 if (digit < 10)
3586 digit += '0';
3587 else
3588 digit += 'a' - 10;
3589 buf[i] = digit;
3590 }
3591 }
3592 }
3593
3594 /* for power of 2 radixes */
3595 static void limb_to_a2(char *buf, limb_t n, unsigned int radix_bits, int len)
3596 {
3597 int digit, i;
3598 unsigned int mask;
3599
3600 mask = (1 << radix_bits) - 1;
3601 for(i = len - 1; i >= 0; i--) {
3602 digit = n & mask;
3603 n >>= radix_bits;
3604 if (digit < 10)
3605 digit += '0';
3606 else
3607 digit += 'a' - 10;
3608 buf[i] = digit;
3609 }
3610 }
3611
3612 /* 'a' must be an integer if is_dec = FALSE or if the radix is not
3613 a power of two. A dot is added before the 'dot_pos' digit. dot_pos
3614 = n_digits does not display the dot. 0 <= dot_pos <=
3615 n_digits. n_digits >= 1. */
3616 static void output_digits(DynBuf *s, const bf_t *a1, int radix, limb_t n_digits,
3617 limb_t dot_pos, BOOL is_dec)
3618 {
3619 limb_t i, v, l;
3620 slimb_t pos, pos_incr;
3621 int digits_per_limb, buf_pos, radix_bits, first_buf_pos;
3622 char buf[65];
3623 bf_t a_s, *a;
3624
3625 if (is_dec) {
3626 digits_per_limb = LIMB_DIGITS;
3627 a = (bf_t *)a1;
3628 radix_bits = 0;
3629 pos = a->len;
3630 pos_incr = 1;
3631 first_buf_pos = 0;
3632 } else if ((radix & (radix - 1)) == 0) {
3633 a = (bf_t *)a1;
3634 radix_bits = ceil_log2(radix);
3635 digits_per_limb = LIMB_BITS / radix_bits;
3636 pos_incr = digits_per_limb * radix_bits;
3637 /* digits are aligned relative to the radix point */
3638 pos = a->len * LIMB_BITS + smod(-a->expn, radix_bits);
3639 first_buf_pos = 0;
3640 } else {
3641 limb_t n, radixl;
3642
3643 digits_per_limb = digits_per_limb_table[radix - 2];
3644 radixl = get_limb_radix(radix);
3645 a = &a_s;
3646 bf_init(a1->ctx, a);
3647 n = (n_digits + digits_per_limb - 1) / digits_per_limb;
3648 if (bf_resize(a, n)) {
3649 dbuf_set_error(s);
3650 goto done;
3651 }
3652 if (bf_integer_to_radix(a, a1, radixl)) {
3653 dbuf_set_error(s);
3654 goto done;
3655 }
3656 radix_bits = 0;
3657 pos = n;
3658 pos_incr = 1;
3659 first_buf_pos = pos * digits_per_limb - n_digits;
3660 }
3661 buf_pos = digits_per_limb;
3662 i = 0;
3663 while (i < n_digits) {
3664 if (buf_pos == digits_per_limb) {
3665 pos -= pos_incr;
3666 if (radix_bits == 0) {
3667 v = get_limbz(a, pos);
3668 limb_to_a(buf, v, radix, digits_per_limb);
3669 } else {
3670 v = get_bits(a->tab, a->len, pos);
3671 limb_to_a2(buf, v, radix_bits, digits_per_limb);
3672 }
3673 buf_pos = first_buf_pos;
3674 first_buf_pos = 0;
3675 }
3676 if (i < dot_pos) {
3677 l = dot_pos;
3678 } else {
3679 if (i == dot_pos)
3680 dbuf_putc(s, '.');
3681 l = n_digits;
3682 }
3683 l = bf_min(digits_per_limb - buf_pos, l - i);
3684 dbuf_put(s, (uint8_t *)(buf + buf_pos), l);
3685 buf_pos += l;
3686 i += l;
3687 }
3688 done:
3689 if (a != a1)
3690 bf_delete(a);
3691 }
3692
3693 static void *bf_dbuf_realloc(void *opaque, void *ptr, size_t size)
3694 {
3695 bf_context_t *s = opaque;
3696 return bf_realloc(s, ptr, size);
3697 }
3698
3699 /* return the length in bytes. A trailing '\0' is added */
3700 static char *bf_ftoa_internal(size_t *plen, const bf_t *a2, int radix,
3701 limb_t prec, bf_flags_t flags, BOOL is_dec)
3702 {
3703 bf_context_t *ctx = a2->ctx;
3704 DynBuf s_s, *s = &s_s;
3705 int radix_bits;
3706
3707 // bf_print_str("ftoa", a2);
3708 // printf("radix=%d\n", radix);
3709 dbuf_init2(s, ctx, bf_dbuf_realloc);
3710 if (a2->expn == BF_EXP_NAN) {
3711 dbuf_putstr(s, "NaN");
3712 } else {
3713 if (a2->sign)
3714 dbuf_putc(s, '-');
3715 if (a2->expn == BF_EXP_INF) {
3716 if (flags & BF_FTOA_JS_QUIRKS)
3717 dbuf_putstr(s, "Infinity");
3718 else
3719 dbuf_putstr(s, "Inf");
3720 } else {
3721 int fmt, ret;
3722 slimb_t n_digits, n, i, n_max, n1;
3723 bf_t a1_s, *a1 = &a1_s;
3724
3725 if ((radix & (radix - 1)) != 0)
3726 radix_bits = 0;
3727 else
3728 radix_bits = ceil_log2(radix);
3729
3730 fmt = flags & BF_FTOA_FORMAT_MASK;
3731 bf_init(ctx, a1);
3732 if (fmt == BF_FTOA_FORMAT_FRAC) {
3733 if (is_dec || radix_bits != 0) {
3734 if (bf_set(a1, a2))
3735 goto fail1;
3736 #ifdef USE_BF_DEC
3737 if (is_dec) {
3738 if (bfdec_round((bfdec_t *)a1, prec, (flags & BF_RND_MASK) | BF_FLAG_RADPNT_PREC) & BF_ST_MEM_ERROR)
3739 goto fail1;
3740 n = a1->expn;
3741 } else
3742 #endif
3743 {
3744 if (bf_round(a1, prec * radix_bits, (flags & BF_RND_MASK) | BF_FLAG_RADPNT_PREC) & BF_ST_MEM_ERROR)
3745 goto fail1;
3746 n = ceil_div(a1->expn, radix_bits);
3747 }
3748 if (flags & BF_FTOA_ADD_PREFIX) {
3749 if (radix == 16)
3750 dbuf_putstr(s, "0x");
3751 else if (radix == 8)
3752 dbuf_putstr(s, "0o");
3753 else if (radix == 2)
3754 dbuf_putstr(s, "0b");
3755 }
3756 if (a1->expn == BF_EXP_ZERO) {
3757 dbuf_putstr(s, "0");
3758 if (prec > 0) {
3759 dbuf_putstr(s, ".");
3760 for(i = 0; i < prec; i++) {
3761 dbuf_putc(s, '0');
3762 }
3763 }
3764 } else {
3765 n_digits = prec + n;
3766 if (n <= 0) {
3767 /* 0.x */
3768 dbuf_putstr(s, "0.");
3769 for(i = 0; i < -n; i++) {
3770 dbuf_putc(s, '0');
3771 }
3772 if (n_digits > 0) {
3773 output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3774 }
3775 } else {
3776 output_digits(s, a1, radix, n_digits, n, is_dec);
3777 }
3778 }
3779 } else {
3780 size_t pos, start;
3781 bf_t a_s, *a = &a_s;
3782
3783 /* make a positive number */
3784 a->tab = a2->tab;
3785 a->len = a2->len;
3786 a->expn = a2->expn;
3787 a->sign = 0;
3788
3789 /* one more digit for the rounding */
3790 n = 1 + bf_mul_log2_radix(bf_max(a->expn, 0), radix, TRUE, TRUE);
3791 n_digits = n + prec;
3792 n1 = n;
3793 if (bf_convert_to_radix(a1, &n1, a, radix, n_digits,
3794 flags & BF_RND_MASK, TRUE))
3795 goto fail1;
3796 start = s->size;
3797 output_digits(s, a1, radix, n_digits, n, is_dec);
3798 /* remove leading zeros because we allocated one more digit */
3799 pos = start;
3800 while ((pos + 1) < s->size && s->buf[pos] == '0' &&
3801 s->buf[pos + 1] != '.')
3802 pos++;
3803 if (pos > start) {
3804 memmove(s->buf + start, s->buf + pos, s->size - pos);
3805 s->size -= (pos - start);
3806 }
3807 }
3808 } else {
3809 #ifdef USE_BF_DEC
3810 if (is_dec) {
3811 if (bf_set(a1, a2))
3812 goto fail1;
3813 if (fmt == BF_FTOA_FORMAT_FIXED) {
3814 n_digits = prec;
3815 n_max = n_digits;
3816 if (bfdec_round((bfdec_t *)a1, prec, (flags & BF_RND_MASK)) & BF_ST_MEM_ERROR)
3817 goto fail1;
3818 } else {
3819 /* prec is ignored */
3820 prec = n_digits = a1->len * LIMB_DIGITS;
3821 /* remove the trailing zero digits */
3822 while (n_digits > 1 &&
3823 get_digit(a1->tab, a1->len, prec - n_digits) == 0) {
3824 n_digits--;
3825 }
3826 n_max = n_digits + 4;
3827 }
3828 n = a1->expn;
3829 } else
3830 #endif
3831 if (radix_bits != 0) {
3832 if (bf_set(a1, a2))
3833 goto fail1;
3834 if (fmt == BF_FTOA_FORMAT_FIXED) {
3835 slimb_t prec_bits;
3836 n_digits = prec;
3837 n_max = n_digits;
3838 /* align to the radix point */
3839 prec_bits = prec * radix_bits -
3840 smod(-a1->expn, radix_bits);
3841 if (bf_round(a1, prec_bits,
3842 (flags & BF_RND_MASK)) & BF_ST_MEM_ERROR)
3843 goto fail1;
3844 } else {
3845 limb_t digit_mask;
3846 slimb_t pos;
3847 /* position of the digit before the most
3848 significant digit in bits */
3849 pos = a1->len * LIMB_BITS +
3850 smod(-a1->expn, radix_bits);
3851 n_digits = ceil_div(pos, radix_bits);
3852 /* remove the trailing zero digits */
3853 digit_mask = ((limb_t)1 << radix_bits) - 1;
3854 while (n_digits > 1 &&
3855 (get_bits(a1->tab, a1->len, pos - n_digits * radix_bits) & digit_mask) == 0) {
3856 n_digits--;
3857 }
3858 n_max = n_digits + 4;
3859 }
3860 n = ceil_div(a1->expn, radix_bits);
3861 } else {
3862 bf_t a_s, *a = &a_s;
3863
3864 /* make a positive number */
3865 a->tab = a2->tab;
3866 a->len = a2->len;
3867 a->expn = a2->expn;
3868 a->sign = 0;
3869
3870 if (fmt == BF_FTOA_FORMAT_FIXED) {
3871 n_digits = prec;
3872 n_max = n_digits;
3873 } else {
3874 slimb_t n_digits_max, n_digits_min;
3875
3876 assert(prec != BF_PREC_INF);
3877 n_digits = 1 + bf_mul_log2_radix(prec, radix, TRUE, TRUE);
3878                     /* max number of digits for non-exponential
3879                        notation. The rationale is to have the same rule
3880                        as JS, i.e. n_max = 21 for a 64-bit float in base 10. */
3881 n_max = n_digits + 4;
3882 if (fmt == BF_FTOA_FORMAT_FREE_MIN) {
3883 bf_t b_s, *b = &b_s;
3884
3885 /* find the minimum number of digits by
3886 dichotomy. */
3887 /* XXX: inefficient */
3888 n_digits_max = n_digits;
3889 n_digits_min = 1;
3890 bf_init(ctx, b);
3891 while (n_digits_min < n_digits_max) {
3892 n_digits = (n_digits_min + n_digits_max) / 2;
3893 if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3894 flags & BF_RND_MASK, FALSE)) {
3895 bf_delete(b);
3896 goto fail1;
3897 }
3898 /* convert back to a number and compare */
3899 ret = bf_mul_pow_radix(b, a1, radix, n - n_digits,
3900 prec,
3901 (flags & ~BF_RND_MASK) |
3902 BF_RNDN);
3903 if (ret & BF_ST_MEM_ERROR) {
3904 bf_delete(b);
3905 goto fail1;
3906 }
3907 if (bf_cmpu(b, a) == 0) {
3908 n_digits_max = n_digits;
3909 } else {
3910 n_digits_min = n_digits + 1;
3911 }
3912 }
3913 bf_delete(b);
3914 n_digits = n_digits_max;
3915 }
3916 }
3917 if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3918 flags & BF_RND_MASK, FALSE)) {
3919 fail1:
3920 bf_delete(a1);
3921 goto fail;
3922 }
3923 }
3924 if (a1->expn == BF_EXP_ZERO &&
3925 fmt != BF_FTOA_FORMAT_FIXED &&
3926 !(flags & BF_FTOA_FORCE_EXP)) {
3927 /* just output zero */
3928 dbuf_putstr(s, "0");
3929 } else {
3930 if (flags & BF_FTOA_ADD_PREFIX) {
3931 if (radix == 16)
3932 dbuf_putstr(s, "0x");
3933 else if (radix == 8)
3934 dbuf_putstr(s, "0o");
3935 else if (radix == 2)
3936 dbuf_putstr(s, "0b");
3937 }
3938 if (a1->expn == BF_EXP_ZERO)
3939 n = 1;
3940 if ((flags & BF_FTOA_FORCE_EXP) ||
3941 n <= -6 || n > n_max) {
3942 const char *fmt;
3943 /* exponential notation */
3944 output_digits(s, a1, radix, n_digits, 1, is_dec);
3945 if (radix_bits != 0 && radix <= 16) {
3946 if (flags & BF_FTOA_JS_QUIRKS)
3947 fmt = "p%+" PRId_LIMB;
3948 else
3949 fmt = "p%" PRId_LIMB;
3950 dbuf_printf(s, fmt, (n - 1) * radix_bits);
3951 } else {
3952 if (flags & BF_FTOA_JS_QUIRKS)
3953 fmt = "%c%+" PRId_LIMB;
3954 else
3955 fmt = "%c%" PRId_LIMB;
3956 dbuf_printf(s, fmt,
3957 radix <= 10 ? 'e' : '@', n - 1);
3958 }
3959 } else if (n <= 0) {
3960 /* 0.x */
3961 dbuf_putstr(s, "0.");
3962 for(i = 0; i < -n; i++) {
3963 dbuf_putc(s, '0');
3964 }
3965 output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3966 } else {
3967 if (n_digits <= n) {
3968 /* no dot */
3969 output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3970 for(i = 0; i < (n - n_digits); i++)
3971 dbuf_putc(s, '0');
3972 } else {
3973 output_digits(s, a1, radix, n_digits, n, is_dec);
3974 }
3975 }
3976 }
3977 }
3978 bf_delete(a1);
3979 }
3980 }
3981 dbuf_putc(s, '\0');
3982 if (dbuf_error(s))
3983 goto fail;
3984 if (plen)
3985 *plen = s->size - 1;
3986 return (char *)s->buf;
3987 fail:
3988 bf_free(ctx, s->buf);
3989 if (plen)
3990 *plen = 0;
3991 return NULL;
3992 }
3993
3994 char *bf_ftoa(size_t *plen, const bf_t *a, int radix, limb_t prec,
3995 bf_flags_t flags)
3996 {
3997 return bf_ftoa_internal(plen, a, radix, prec, flags, FALSE);
3998 }
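/* Usage sketch (illustrative; 'ctx' and 'x' assumed initialized):
   20 significant decimal digits, rounded to nearest. The string is
   allocated with the context allocator; releasing it with bf_free()
   is assumed here:

       size_t len;
       char *str = bf_ftoa(&len, &x, 10, 20, BF_FTOA_FORMAT_FIXED | BF_RNDN);
       if (str) {
           printf("%s\n", str);
           bf_free(&ctx, str);
       }
*/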
3999
4000 /***************************************************************/
4001 /* transcendental functions */
4002
4003 /* Note: the algorithm is from MPFR */
4004 static void bf_const_log2_rec(bf_t *T, bf_t *P, bf_t *Q, limb_t n1,
4005 limb_t n2, BOOL need_P)
4006 {
4007 bf_context_t *s = T->ctx;
4008 if ((n2 - n1) == 1) {
4009 if (n1 == 0) {
4010 bf_set_ui(P, 3);
4011 } else {
4012 bf_set_ui(P, n1);
4013 P->sign = 1;
4014 }
4015 bf_set_ui(Q, 2 * n1 + 1);
4016 Q->expn += 2;
4017 bf_set(T, P);
4018 } else {
4019 limb_t m;
4020 bf_t T1_s, *T1 = &T1_s;
4021 bf_t P1_s, *P1 = &P1_s;
4022 bf_t Q1_s, *Q1 = &Q1_s;
4023
4024 m = n1 + ((n2 - n1) >> 1);
4025 bf_const_log2_rec(T, P, Q, n1, m, TRUE);
4026 bf_init(s, T1);
4027 bf_init(s, P1);
4028 bf_init(s, Q1);
4029 bf_const_log2_rec(T1, P1, Q1, m, n2, need_P);
4030 bf_mul(T, T, Q1, BF_PREC_INF, BF_RNDZ);
4031 bf_mul(T1, T1, P, BF_PREC_INF, BF_RNDZ);
4032 bf_add(T, T, T1, BF_PREC_INF, BF_RNDZ);
4033 if (need_P)
4034 bf_mul(P, P, P1, BF_PREC_INF, BF_RNDZ);
4035 bf_mul(Q, Q, Q1, BF_PREC_INF, BF_RNDZ);
4036 bf_delete(T1);
4037 bf_delete(P1);
4038 bf_delete(Q1);
4039 }
4040 }
4041
4042 /* compute log(2) with faithful rounding at precision 'prec' */
4043 static void bf_const_log2_internal(bf_t *T, limb_t prec)
4044 {
4045 limb_t w, N;
4046 bf_t P_s, *P = &P_s;
4047 bf_t Q_s, *Q = &Q_s;
4048
4049 w = prec + 15;
4050 N = w / 3 + 1;
4051 bf_init(T->ctx, P);
4052 bf_init(T->ctx, Q);
4053 bf_const_log2_rec(T, P, Q, 0, N, FALSE);
4054 bf_div(T, T, Q, prec, BF_RNDN);
4055 bf_delete(P);
4056 bf_delete(Q);
4057 }
4058
4059 /* PI constant */
4060
4061 #define CHUD_A 13591409
4062 #define CHUD_B 545140134
4063 #define CHUD_C 640320
4064 #define CHUD_BITS_PER_TERM 47
4065
4066 static void chud_bs(bf_t *P, bf_t *Q, bf_t *G, int64_t a, int64_t b, int need_g,
4067 limb_t prec)
4068 {
4069 bf_context_t *s = P->ctx;
4070 int64_t c;
4071
4072 if (a == (b - 1)) {
4073 bf_t T0, T1;
4074
4075 bf_init(s, &T0);
4076 bf_init(s, &T1);
4077 bf_set_ui(G, 2 * b - 1);
4078 bf_mul_ui(G, G, 6 * b - 1, prec, BF_RNDN);
4079 bf_mul_ui(G, G, 6 * b - 5, prec, BF_RNDN);
4080 bf_set_ui(&T0, CHUD_B);
4081 bf_mul_ui(&T0, &T0, b, prec, BF_RNDN);
4082 bf_set_ui(&T1, CHUD_A);
4083 bf_add(&T0, &T0, &T1, prec, BF_RNDN);
4084 bf_mul(P, G, &T0, prec, BF_RNDN);
4085 P->sign = b & 1;
4086
4087 bf_set_ui(Q, b);
4088 bf_mul_ui(Q, Q, b, prec, BF_RNDN);
4089 bf_mul_ui(Q, Q, b, prec, BF_RNDN);
4090 bf_mul_ui(Q, Q, (uint64_t)CHUD_C * CHUD_C * CHUD_C / 24, prec, BF_RNDN);
4091 bf_delete(&T0);
4092 bf_delete(&T1);
4093 } else {
4094 bf_t P2, Q2, G2;
4095
4096 bf_init(s, &P2);
4097 bf_init(s, &Q2);
4098 bf_init(s, &G2);
4099
4100 c = (a + b) / 2;
4101 chud_bs(P, Q, G, a, c, 1, prec);
4102 chud_bs(&P2, &Q2, &G2, c, b, need_g, prec);
4103
4104 /* Q = Q1 * Q2 */
4105 /* G = G1 * G2 */
4106 /* P = P1 * Q2 + P2 * G1 */
4107 bf_mul(&P2, &P2, G, prec, BF_RNDN);
4108 if (!need_g)
4109 bf_set_ui(G, 0);
4110 bf_mul(P, P, &Q2, prec, BF_RNDN);
4111 bf_add(P, P, &P2, prec, BF_RNDN);
4112 bf_delete(&P2);
4113
4114 bf_mul(Q, Q, &Q2, prec, BF_RNDN);
4115 bf_delete(&Q2);
4116 if (need_g)
4117 bf_mul(G, G, &G2, prec, BF_RNDN);
4118 bf_delete(&G2);
4119 }
4120 }
4121
4122 /* compute Pi with faithful rounding at precision 'prec' using the
4123 Chudnovsky formula */
4124 static void bf_const_pi_internal(bf_t *Q, limb_t prec)
4125 {
4126 bf_context_t *s = Q->ctx;
4127 int64_t n, prec1;
4128 bf_t P, G;
4129
4130     /* number of series terms */
4131 n = prec / CHUD_BITS_PER_TERM + 1;
4132 /* XXX: precision analysis */
4133 prec1 = prec + 32;
4134
4135 bf_init(s, &P);
4136 bf_init(s, &G);
4137
4138 chud_bs(&P, Q, &G, 0, n, 0, BF_PREC_INF);
4139
4140 bf_mul_ui(&G, Q, CHUD_A, prec1, BF_RNDN);
4141 bf_add(&P, &G, &P, prec1, BF_RNDN);
4142 bf_div(Q, Q, &P, prec1, BF_RNDF);
4143
4144 bf_set_ui(&P, CHUD_C);
4145 bf_sqrt(&G, &P, prec1, BF_RNDF);
4146 bf_mul_ui(&G, &G, (uint64_t)CHUD_C / 12, prec1, BF_RNDF);
4147 bf_mul(Q, Q, &G, prec, BF_RNDN);
4148 bf_delete(&P);
4149 bf_delete(&G);
4150 }
4151
4152 static int bf_const_get(bf_t *T, limb_t prec, bf_flags_t flags,
4153 BFConstCache *c,
4154 void (*func)(bf_t *res, limb_t prec), int sign)
4155 {
4156 limb_t ziv_extra_bits, prec1;
4157
4158 ziv_extra_bits = 32;
4159 for(;;) {
4160 prec1 = prec + ziv_extra_bits;
4161 if (c->prec < prec1) {
4162 if (c->val.len == 0)
4163 bf_init(T->ctx, &c->val);
4164 func(&c->val, prec1);
4165 c->prec = prec1;
4166 } else {
4167 prec1 = c->prec;
4168 }
4169 bf_set(T, &c->val);
4170 T->sign = sign;
4171 if (!bf_can_round(T, prec, flags & BF_RND_MASK, prec1)) {
4172 /* add more precision and retry */
4173 ziv_extra_bits = ziv_extra_bits + (ziv_extra_bits / 2);
4174 } else {
4175 break;
4176 }
4177 }
4178 return bf_round(T, prec, flags);
4179 }
4180
4181 static void bf_const_free(BFConstCache *c)
4182 {
4183 bf_delete(&c->val);
4184 memset(c, 0, sizeof(*c));
4185 }
4186
4187 int bf_const_log2(bf_t *T, limb_t prec, bf_flags_t flags)
4188 {
4189 bf_context_t *s = T->ctx;
4190 return bf_const_get(T, prec, flags, &s->log2_cache, bf_const_log2_internal, 0);
4191 }
4192
4193 /* return rounded pi * (1 - 2 * sign) */
4194 static int bf_const_pi_signed(bf_t *T, int sign, limb_t prec, bf_flags_t flags)
4195 {
4196 bf_context_t *s = T->ctx;
4197 return bf_const_get(T, prec, flags, &s->pi_cache, bf_const_pi_internal,
4198 sign);
4199 }
4200
4201 int bf_const_pi(bf_t *T, limb_t prec, bf_flags_t flags)
4202 {
4203 return bf_const_pi_signed(T, 0, prec, flags);
4204 }
4205
4206 void bf_clear_cache(bf_context_t *s)
4207 {
4208 #ifdef USE_FFT_MUL
4209 fft_clear_cache(s);
4210 #endif
4211 bf_const_free(&s->log2_cache);
4212 bf_const_free(&s->pi_cache);
4213 }
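
#if 0
/* Illustrative usage sketch, not part of the library: compute pi and
   log(2) to 256 bits, then release the cached constants. The realloc
   callback shape follows bf_realloc_func_t from libbf.h. */
static void *bf_example_realloc(void *opaque, void *ptr, size_t size)
{
    return realloc(ptr, size);
}

static void bf_example_constants(void)
{
    bf_context_t ctx;
    bf_t x;
    bf_context_init(&ctx, bf_example_realloc, NULL);
    bf_init(&ctx, &x);
    bf_const_pi(&x, 256, BF_RNDN);   /* cached in ctx->pi_cache */
    bf_const_log2(&x, 256, BF_RNDN); /* cached in ctx->log2_cache */
    bf_delete(&x);
    bf_clear_cache(&ctx); /* frees both constant caches */
    bf_context_end(&ctx);
}
#endif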
4214
4215 /* ZivFunc should compute the result 'r' with faithful rounding at
4216 precision 'prec'. For efficiency purposes, the final bf_round()
4217 does not need to be done in the function. */
4218 typedef int ZivFunc(bf_t *r, const bf_t *a, limb_t prec, void *opaque);
4219
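/* Ziv's strategy: evaluate f() at 'prec' plus a growing number of
   extra bits until bf_can_round() guarantees that rounding the
   approximation to 'prec' bits yields the correctly rounded result
   (faithful rounding BF_RNDF needs a single evaluation) */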
4220 static int bf_ziv_rounding(bf_t *r, const bf_t *a,
4221 limb_t prec, bf_flags_t flags,
4222 ZivFunc *f, void *opaque)
4223 {
4224 int rnd_mode, ret;
4225 slimb_t prec1, ziv_extra_bits;
4226
4227 rnd_mode = flags & BF_RND_MASK;
4228 if (rnd_mode == BF_RNDF) {
4229 /* no need to iterate */
4230 f(r, a, prec, opaque);
4231 ret = 0;
4232 } else {
4233 ziv_extra_bits = 32;
4234 for(;;) {
4235 prec1 = prec + ziv_extra_bits;
4236 ret = f(r, a, prec1, opaque);
4237 if (ret & (BF_ST_OVERFLOW | BF_ST_UNDERFLOW | BF_ST_MEM_ERROR)) {
4238 /* overflow or underflow should never happen here, since
4239 either would mean the rounding cannot be done correctly,
4240 but not all such cases are caught */
4241 return ret;
4242 }
4243 /* if the result is exact, we can stop */
4244 if (!(ret & BF_ST_INEXACT)) {
4245 ret = 0;
4246 break;
4247 }
4248 if (bf_can_round(r, prec, rnd_mode, prec1)) {
4249 ret = BF_ST_INEXACT;
4250 break;
4251 }
4252 ziv_extra_bits = ziv_extra_bits * 2;
4253 // printf("ziv_extra_bits=%" PRId64 "\n", (int64_t)ziv_extra_bits);
4254 }
4255 }
4256 if (r->len == 0)
4257 return ret;
4258 else
4259 return __bf_round(r, prec, flags, r->len, ret);
4260 }
4261
4262 /* add (1 - 2*e_sign) * 2^e */
4263 static int bf_add_epsilon(bf_t *r, const bf_t *a, slimb_t e, int e_sign,
4264 limb_t prec, int flags)
4265 {
4266 bf_t T_s, *T = &T_s;
4267 int ret;
4269 bf_init(a->ctx, T);
4270 bf_set_ui(T, 1);
4271 T->sign = e_sign;
4272 T->expn += e;
4273 ret = bf_add(r, r, T, prec, flags);
4274 bf_delete(T);
4275 return ret;
4276 }
4277
4278 /* Compute the exponential using faithful rounding at precision 'prec'.
4279 Note: the algorithm is from MPFR */
4280 static int bf_exp_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4281 {
4282 bf_context_t *s = r->ctx;
4283 bf_t T_s, *T = &T_s;
4284 slimb_t n, K, l, i, prec1;
4285
4286 assert(r != a);
4287
4288 /* argument reduction:
4289 T = a - n*log(2) with 0 <= T < log(2) and n integer.
4290 */
4291 bf_init(s, T);
4292 if (a->expn <= -1) {
4293 /* 0 <= abs(a) <= 0.5 */
4294 if (a->sign)
4295 n = -1;
4296 else
4297 n = 0;
4298 } else {
4299 bf_const_log2(T, LIMB_BITS, BF_RNDZ);
4300 bf_div(T, a, T, LIMB_BITS, BF_RNDD);
4301 bf_get_limb(&n, T, 0);
4302 }
4303
4304 K = bf_isqrt((prec + 1) / 2);
4305 l = (prec - 1) / K + 1;
4306 /* XXX: precision analysis ? */
4307 prec1 = prec + (K + 2 * l + 18) + K + 8;
4308 if (a->expn > 0)
4309 prec1 += a->expn;
4310 // printf("n=%ld K=%ld prec1=%ld\n", n, K, prec1);
4311
4312 bf_const_log2(T, prec1, BF_RNDF);
4313 bf_mul_si(T, T, n, prec1, BF_RNDN);
4314 bf_sub(T, a, T, prec1, BF_RNDN);
4315
4316 /* reduce the range of T */
4317 bf_mul_2exp(T, -K, BF_PREC_INF, BF_RNDZ);
4318
4319 /* Taylor expansion around zero :
4320 1 + x + x^2/2 + ... + x^n/n!
4321 = (1 + x * (1 + x/2 * (1 + ... (x/n))))
4322 */
4323 {
4324 bf_t U_s, *U = &U_s;
4325
4326 bf_init(s, U);
4327 bf_set_ui(r, 1);
4328 for(i = l ; i >= 1; i--) {
4329 bf_set_ui(U, i);
4330 bf_div(U, T, U, prec1, BF_RNDN);
4331 bf_mul(r, r, U, prec1, BF_RNDN);
4332 bf_add_si(r, r, 1, prec1, BF_RNDN);
4333 }
4334 bf_delete(U);
4335 }
4336 bf_delete(T);
4337
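/* at this point r ~ exp(T/2^K): squaring K times below yields exp(T),
   and exp(a) = exp(T) * 2^n since a ~ T + n*log(2) */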
4338 /* undo the range reduction */
4339 for(i = 0; i < K; i++) {
4340 bf_mul(r, r, r, prec1, BF_RNDN | BF_FLAG_EXT_EXP);
4341 }
4342
4343 /* undo the argument reduction */
4344 bf_mul_2exp(r, n, BF_PREC_INF, BF_RNDZ | BF_FLAG_EXT_EXP);
4345
4346 return BF_ST_INEXACT;
4347 }
4348
4349 /* crude overflow and underflow tests for exp(a). a_low <= a <= a_high */
4350 static int check_exp_underflow_overflow(bf_context_t *s, bf_t *r,
4351 const bf_t *a_low, const bf_t *a_high,
4352 limb_t prec, bf_flags_t flags)
4353 {
4354 bf_t T_s, *T = &T_s;
4355 bf_t log2_s, *log2 = &log2_s;
4356 slimb_t e_min, e_max;
4357
4358 if (a_high->expn <= 0)
4359 return 0;
4360
4361 e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
4362 e_min = -e_max + 3;
4363 if (flags & BF_FLAG_SUBNORMAL)
4364 e_min -= (prec - 1);
4365
4366 bf_init(s, T);
4367 bf_init(s, log2);
4368 bf_const_log2(log2, LIMB_BITS, BF_RNDU);
4369 bf_mul_ui(T, log2, e_max, LIMB_BITS, BF_RNDU);
4370 /* a_low > e_max * log(2) implies exp(a) > 2^e_max */
4371 if (bf_cmp_lt(T, a_low) > 0) {
4372 /* overflow */
4373 bf_delete(T);
4374 bf_delete(log2);
4375 return bf_set_overflow(r, 0, prec, flags);
4376 }
4377 /* a_high < (e_min - 2) * log(2) implies exp(a) < 2^(e_min - 2) */
4378 bf_const_log2(log2, LIMB_BITS, BF_RNDD);
4379 bf_mul_si(T, log2, e_min - 2, LIMB_BITS, BF_RNDD);
4380 if (bf_cmp_lt(a_high, T)) {
4381 int rnd_mode = flags & BF_RND_MASK;
4382
4383 /* underflow */
4384 bf_delete(T);
4385 bf_delete(log2);
4386 if (rnd_mode == BF_RNDU) {
4387 /* set the smallest value */
4388 bf_set_ui(r, 1);
4389 r->expn = e_min;
4390 } else {
4391 bf_set_zero(r, 0);
4392 }
4393 return BF_ST_UNDERFLOW | BF_ST_INEXACT;
4394 }
4395 bf_delete(log2);
4396 bf_delete(T);
4397 return 0;
4398 }
4399
4400 int bf_exp(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4401 {
4402 bf_context_t *s = r->ctx;
4403 int ret;
4404 assert(r != a);
4405 if (a->len == 0) {
4406 if (a->expn == BF_EXP_NAN) {
4407 bf_set_nan(r);
4408 } else if (a->expn == BF_EXP_INF) {
4409 if (a->sign)
4410 bf_set_zero(r, 0);
4411 else
4412 bf_set_inf(r, 0);
4413 } else {
4414 bf_set_ui(r, 1);
4415 }
4416 return 0;
4417 }
4418
4419 ret = check_exp_underflow_overflow(s, r, a, a, prec, flags);
4420 if (ret)
4421 return ret;
4422 if (a->expn < 0 && (-a->expn) >= (prec + 2)) {
4423 /* small argument case: result = 1 + epsilon * sign(x) */
4424 bf_set_ui(r, 1);
4425 return bf_add_epsilon(r, r, -(prec + 2), a->sign, prec, flags);
4426 }
4427
4428 return bf_ziv_rounding(r, a, prec, flags, bf_exp_internal, NULL);
4429 }
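
#if 0
/* Illustrative sketch, not part of the library: e = exp(1) rounded to
   nearest at 'prec' bits. bf_exp() requires r != a; the context is
   assumed initialized as in the constants sketch above. */
static void bf_example_exp(bf_context_t *ctx, limb_t prec)
{
    bf_t x, r;
    bf_init(ctx, &x);
    bf_init(ctx, &r);
    bf_set_ui(&x, 1);
    bf_exp(&r, &x, prec, BF_RNDN); /* r ~ 2.7182818284... */
    bf_delete(&x);
    bf_delete(&r);
}
#endif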
4430
4431 static int bf_log_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4432 {
4433 bf_context_t *s = r->ctx;
4434 bf_t T_s, *T = &T_s;
4435 bf_t U_s, *U = &U_s;
4436 bf_t V_s, *V = &V_s;
4437 slimb_t n, prec1, l, i, K;
4438
4439 assert(r != a);
4440
4441 bf_init(s, T);
4442 /* argument reduction 1 */
4443 /* T=a*2^n with 2/3 <= T <= 4/3 */
4444 {
4445 bf_t U_s, *U = &U_s;
4446 bf_set(T, a);
4447 n = T->expn;
4448 T->expn = 0;
4449 /* U= ~ 2/3 */
4450 bf_init(s, U);
4451 bf_set_ui(U, 0xaaaaaaaa);
4452 U->expn = 0;
4453 if (bf_cmp_lt(T, U)) {
4454 T->expn++;
4455 n--;
4456 }
4457 bf_delete(U);
4458 }
4459 // printf("n=%ld\n", n);
4460 // bf_print_str("T", T);
4461
4462 /* XXX: precision analysis */
4463 /* number of iterations for argument reduction 2 */
4464 K = bf_isqrt((prec + 1) / 2);
4465 /* order of Taylor expansion */
4466 l = prec / (2 * K) + 1;
4467 /* precision of the intermediate computations */
4468 prec1 = prec + K + 2 * l + 32;
4469
4470 bf_init(s, U);
4471 bf_init(s, V);
4472
4473 /* Note: cancellation occurs here, so we use more precision (XXX:
4474 reduce the precision by computing the exact cancellation) */
4475 bf_add_si(T, T, -1, BF_PREC_INF, BF_RNDN);
4476
4477 /* argument reduction 2 */
4478 for(i = 0; i < K; i++) {
4479 /* T = T / (1 + sqrt(1 + T)) */
4480 bf_add_si(U, T, 1, prec1, BF_RNDN);
4481 bf_sqrt(V, U, prec1, BF_RNDF);
4482 bf_add_si(U, V, 1, prec1, BF_RNDN);
4483 bf_div(T, T, U, prec1, BF_RNDN);
4484 }
4485
4486 {
4487 bf_t Y_s, *Y = &Y_s;
4488 bf_t Y2_s, *Y2 = &Y2_s;
4489 bf_init(s, Y);
4490 bf_init(s, Y2);
4491
4492 /* compute ln(1+x) = ln((1+y)/(1-y)) = 2*atanh(y) with y=x/(2+x);
4493 the series evaluated here is y + y^3/3 + ... + y^(2*l+1)/(2*l+1)
4494 with Y=y^2:
4495 = y*(1+Y/3+Y^2/5+...) = y*(1+Y*(1/3+Y*(1/5 + ...)));
4496 the missing factor 2 is applied together with the 2^K below */
4497 bf_add_si(Y, T, 2, prec1, BF_RNDN);
4498 bf_div(Y, T, Y, prec1, BF_RNDN);
4499
4500 bf_mul(Y2, Y, Y, prec1, BF_RNDN);
4501 bf_set_ui(r, 0);
4502 for(i = l; i >= 1; i--) {
4503 bf_set_ui(U, 1);
4504 bf_set_ui(V, 2 * i + 1);
4505 bf_div(U, U, V, prec1, BF_RNDN);
4506 bf_add(r, r, U, prec1, BF_RNDN);
4507 bf_mul(r, r, Y2, prec1, BF_RNDN);
4508 }
4509 bf_add_si(r, r, 1, prec1, BF_RNDN);
4510 bf_mul(r, r, Y, prec1, BF_RNDN);
4511 bf_delete(Y);
4512 bf_delete(Y2);
4513 }
4514 bf_delete(V);
4515 bf_delete(U);
4516
4517 /* multiplication by 2 for the Taylor expansion and undo the
4518 argument reduction 2 */
4519 bf_mul_2exp(r, K + 1, BF_PREC_INF, BF_RNDZ);
4520
4521 /* undo the argument reduction 1 */
4522 bf_const_log2(T, prec1, BF_RNDF);
4523 bf_mul_si(T, T, n, prec1, BF_RNDN);
4524 bf_add(r, r, T, prec1, BF_RNDN);
4525
4526 bf_delete(T);
4527 return BF_ST_INEXACT;
4528 }
4529
4530 int bf_log(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4531 {
4532 bf_context_t *s = r->ctx;
4533 bf_t T_s, *T = &T_s;
4534
4535 assert(r != a);
4536 if (a->len == 0) {
4537 if (a->expn == BF_EXP_NAN) {
4538 bf_set_nan(r);
4539 return 0;
4540 } else if (a->expn == BF_EXP_INF) {
4541 if (a->sign) {
4542 bf_set_nan(r);
4543 return BF_ST_INVALID_OP;
4544 } else {
4545 bf_set_inf(r, 0);
4546 return 0;
4547 }
4548 } else {
4549 bf_set_inf(r, 1);
4550 return 0;
4551 }
4552 }
4553 if (a->sign) {
4554 bf_set_nan(r);
4555 return BF_ST_INVALID_OP;
4556 }
4557 bf_init(s, T);
4558 bf_set_ui(T, 1);
4559 if (bf_cmp_eq(a, T)) {
4560 bf_set_zero(r, 0);
4561 bf_delete(T);
4562 return 0;
4563 }
4564 bf_delete(T);
4565
4566 return bf_ziv_rounding(r, a, prec, flags, bf_log_internal, NULL);
4567 }
4568
4569 /* x and y finite and x > 0 */
4570 static int bf_pow_generic(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4571 {
4572 bf_context_t *s = r->ctx;
4573 const bf_t *y = opaque;
4574 bf_t T_s, *T = &T_s;
4575 limb_t prec1;
4576
4577 bf_init(s, T);
4578 /* XXX: proof for the added precision */
4579 prec1 = prec + 32;
4580 bf_log(T, x, prec1, BF_RNDF | BF_FLAG_EXT_EXP);
4581 bf_mul(T, T, y, prec1, BF_RNDF | BF_FLAG_EXT_EXP);
4582 if (bf_is_nan(T))
4583 bf_set_nan(r);
4584 else
4585 bf_exp_internal(r, T, prec1, NULL); /* no overflow/underflow test needed */
4586 bf_delete(T);
4587 return BF_ST_INEXACT;
4588 }
4589
4590 /* x and y finite, x > 0, y integer and y fits on one limb */
4591 static int bf_pow_int(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4592 {
4593 bf_context_t *s = r->ctx;
4594 const bf_t *y = opaque;
4595 bf_t T_s, *T = &T_s;
4596 limb_t prec1;
4597 int ret;
4598 slimb_t y1;
4599
4600 bf_get_limb(&y1, y, 0);
4601 if (y1 < 0)
4602 y1 = -y1;
4603 /* XXX: proof for the added precision */
4604 prec1 = prec + ceil_log2(y1) * 2 + 8;
4605 ret = bf_pow_ui(r, x, y1, prec1, BF_RNDN | BF_FLAG_EXT_EXP); /* y1 >= 0 here */
4606 if (y->sign) {
4607 bf_init(s, T);
4608 bf_set_ui(T, 1);
4609 ret |= bf_div(r, T, r, prec1, BF_RNDN | BF_FLAG_EXT_EXP);
4610 bf_delete(T);
4611 }
4612 return ret;
4613 }
4614
4615 /* x must be a finite non-zero float. Return TRUE if there is a
4616 floating point number r such that x=r^(2^n) and return this floating
4617 point number 'r'. Otherwise return FALSE and r is undefined. */
4618 static BOOL check_exact_power2n(bf_t *r, const bf_t *x, slimb_t n)
4619 {
4620 bf_context_t *s = r->ctx;
4621 bf_t T_s, *T = &T_s;
4622 slimb_t e, i, er;
4623 limb_t v;
4624
4625 /* x = m*2^e with m odd integer */
4626 e = bf_get_exp_min(x);
4627 /* fast check on the exponent */
4628 if (n > (LIMB_BITS - 1)) {
4629 if (e != 0)
4630 return FALSE;
4631 er = 0;
4632 } else {
4633 if ((e & (((limb_t)1 << n) - 1)) != 0)
4634 return FALSE;
4635 er = e >> n;
4636 }
4637 /* every odd perfect square is 1 modulo 8: (2k+1)^2 = 4k(k+1)+1 and k(k+1) is even */
4638 v = get_bits(x->tab, x->len, x->len * LIMB_BITS - x->expn + e);
4639 if ((v & 7) != 1)
4640 return FALSE;
4641
4642 bf_init(s, T);
4643 bf_set(T, x);
4644 T->expn -= e;
4645 for(i = 0; i < n; i++) {
4646 if (i != 0)
4647 bf_set(T, r);
4648 if (bf_sqrtrem(r, NULL, T) != 0)
4649 return FALSE;
4650 }
4651 r->expn += er;
4652 return TRUE;
4653 }
4654
4655 /* prec = BF_PREC_INF is accepted for x and y integers and y >= 0 */
4656 int bf_pow(bf_t *r, const bf_t *x, const bf_t *y, limb_t prec, bf_flags_t flags)
4657 {
4658 bf_context_t *s = r->ctx;
4659 bf_t T_s, *T = &T_s;
4660 bf_t ytmp_s;
4661 BOOL y_is_int, y_is_odd;
4662 int r_sign, ret, rnd_mode;
4663 slimb_t y_emin;
4664
4665 if (x->len == 0 || y->len == 0) {
4666 if (y->expn == BF_EXP_ZERO) {
4667 /* pow(x, 0) = 1 */
4668 bf_set_ui(r, 1);
4669 } else if (x->expn == BF_EXP_NAN) {
4670 bf_set_nan(r);
4671 } else {
4672 int cmp_x_abs_1;
4673 bf_set_ui(r, 1);
4674 cmp_x_abs_1 = bf_cmpu(x, r);
4675 if (cmp_x_abs_1 == 0 && (flags & BF_POW_JS_QUIRKS) &&
4676 (y->expn >= BF_EXP_INF)) {
4677 bf_set_nan(r);
4678 } else if (cmp_x_abs_1 == 0 &&
4679 (!x->sign || y->expn != BF_EXP_NAN)) {
4680 /* pow(1, y) = 1 even if y = NaN */
4681 /* pow(-1, +/-inf) = 1 */
4682 } else if (y->expn == BF_EXP_NAN) {
4683 bf_set_nan(r);
4684 } else if (y->expn == BF_EXP_INF) {
4685 if (y->sign == (cmp_x_abs_1 > 0)) {
4686 bf_set_zero(r, 0);
4687 } else {
4688 bf_set_inf(r, 0);
4689 }
4690 } else {
4691 y_emin = bf_get_exp_min(y);
4692 y_is_odd = (y_emin == 0);
4693 if (y->sign == (x->expn == BF_EXP_ZERO)) {
4694 bf_set_inf(r, y_is_odd & x->sign);
4695 if (y->sign) {
4696 /* pow(0, y) with y < 0 */
4697 return BF_ST_DIVIDE_ZERO;
4698 }
4699 } else {
4700 bf_set_zero(r, y_is_odd & x->sign);
4701 }
4702 }
4703 }
4704 return 0;
4705 }
4706 bf_init(s, T);
4707 bf_set(T, x);
4708 y_emin = bf_get_exp_min(y);
4709 y_is_int = (y_emin >= 0);
4710 rnd_mode = flags & BF_RND_MASK;
4711 if (x->sign) {
4712 if (!y_is_int) {
4713 bf_set_nan(r);
4714 bf_delete(T);
4715 return BF_ST_INVALID_OP;
4716 }
4717 y_is_odd = (y_emin == 0);
4718 r_sign = y_is_odd;
4719 /* change the directed rounding mode if the sign of the result
4720 is changed */
4721 if (r_sign && (rnd_mode == BF_RNDD || rnd_mode == BF_RNDU))
4722 flags ^= 1;
4723 bf_neg(T);
4724 } else {
4725 r_sign = 0;
4726 }
4727
4728 bf_set_ui(r, 1);
4729 if (bf_cmp_eq(T, r)) {
4730 /* abs(x) = 1: nothing more to do */
4731 ret = 0;
4732 } else {
4733 /* check the overflow/underflow cases */
4734 {
4735 bf_t al_s, *al = &al_s;
4736 bf_t ah_s, *ah = &ah_s;
4737 limb_t precl = LIMB_BITS;
4738
4739 bf_init(s, al);
4740 bf_init(s, ah);
4741 /* compute bounds of log(abs(x)) * y with a low precision */
4742 /* XXX: compute bf_log() once */
4743 /* XXX: add a fast test before this slow test */
4744 bf_log(al, T, precl, BF_RNDD);
4745 bf_log(ah, T, precl, BF_RNDU);
4746 bf_mul(al, al, y, precl, BF_RNDD ^ y->sign);
4747 bf_mul(ah, ah, y, precl, BF_RNDU ^ y->sign);
4748 ret = check_exp_underflow_overflow(s, r, al, ah, prec, flags);
4749 bf_delete(al);
4750 bf_delete(ah);
4751 if (ret)
4752 goto done;
4753 }
4754
4755 if (y_is_int) {
4756 slimb_t T_bits, e;
4757 int_pow:
4758 T_bits = T->expn - bf_get_exp_min(T);
4759 if (T_bits == 1) {
4760 /* pow(2^b, y) = 2^(b*y) */
4761 bf_mul_si(T, y, T->expn - 1, LIMB_BITS, BF_RNDZ);
4762 bf_get_limb(&e, T, 0);
4763 bf_set_ui(r, 1);
4764 ret = bf_mul_2exp(r, e, prec, flags);
4765 } else if (prec == BF_PREC_INF) {
4766 slimb_t y1;
4767 /* specific case for infinite precision (integer case) */
4768 bf_get_limb(&y1, y, 0);
4769 assert(!y->sign);
4770 /* x must be an integer, so abs(x) >= 2 */
4771 if (y1 >= ((slimb_t)1 << BF_EXP_BITS_MAX)) {
4772 bf_delete(T);
4773 return bf_set_overflow(r, 0, BF_PREC_INF, flags);
4774 }
4775 ret = bf_pow_ui(r, T, y1, BF_PREC_INF, BF_RNDZ);
4776 } else {
4777 if (y->expn <= 31) {
4778 /* small enough power: use exponentiation in all cases */
4779 } else if (y->sign) {
4780 /* cannot be exact */
4781 goto general_case;
4782 } else {
4783 if (rnd_mode == BF_RNDF)
4784 goto general_case; /* no need to track exact results */
4785 /* see if the result has a chance to be exact:
4786 if x=a*2^b (a odd), x^y=a^y*2^(b*y)
4787 x^y needs a precision of at least floor_log2(a)*y bits
4788 */
4789 bf_mul_si(r, y, T_bits - 1, LIMB_BITS, BF_RNDZ);
4790 bf_get_limb(&e, r, 0);
4791 if (prec < e)
4792 goto general_case;
4793 }
4794 ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_int, (void *)y);
4795 }
4796 } else {
4797 if (rnd_mode != BF_RNDF) {
4798 bf_t *y1;
4799 if (y_emin < 0 && check_exact_power2n(r, T, -y_emin)) {
4800 /* the problem is reduced to a power to an integer */
4801 #if 0
4802 printf("\nn=%" PRId64 "\n", -(int64_t)y_emin);
4803 bf_print_str("T", T);
4804 bf_print_str("r", r);
4805 #endif
4806 bf_set(T, r);
4807 y1 = &ytmp_s;
4808 y1->tab = y->tab;
4809 y1->len = y->len;
4810 y1->sign = y->sign;
4811 y1->expn = y->expn - y_emin;
4812 y = y1;
4813 goto int_pow;
4814 }
4815 }
4816 general_case:
4817 ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_generic, (void *)y);
4818 }
4819 }
4820 done:
4821 bf_delete(T);
4822 r->sign = r_sign;
4823 return ret;
4824 }
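
#if 0
/* Illustrative sketch, not part of the library: 3^100 computed exactly
   through the BF_PREC_INF integer path documented above (x and y
   integers, y >= 0); the result is a 159-bit integer. */
static void bf_example_pow(bf_context_t *ctx)
{
    bf_t x, y, r;
    bf_init(ctx, &x);
    bf_init(ctx, &y);
    bf_init(ctx, &r);
    bf_set_ui(&x, 3);
    bf_set_ui(&y, 100);
    bf_pow(&r, &x, &y, BF_PREC_INF, BF_RNDZ);
    bf_delete(&x);
    bf_delete(&y);
    bf_delete(&r);
}
#endif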
4825
4826 /* compute sqrt(-2*x-x^2) to get |sin(x)| from cos(x) - 1. */
4827 static void bf_sqrt_sin(bf_t *r, const bf_t *x, limb_t prec1)
4828 {
4829 bf_context_t *s = r->ctx;
4830 bf_t T_s, *T = &T_s;
4831 bf_init(s, T);
4832 bf_set(T, x);
4833 bf_mul(r, T, T, prec1, BF_RNDN);
4834 bf_mul_2exp(T, 1, BF_PREC_INF, BF_RNDZ);
4835 bf_add(T, T, r, prec1, BF_RNDN);
4836 bf_neg(T);
4837 bf_sqrt(r, T, prec1, BF_RNDF);
4838 bf_delete(T);
4839 }
4840
4841 static int bf_sincos(bf_t *s, bf_t *c, const bf_t *a, limb_t prec)
4842 {
4843 bf_context_t *s1 = a->ctx;
4844 bf_t T_s, *T = &T_s;
4845 bf_t U_s, *U = &U_s;
4846 bf_t r_s, *r = &r_s;
4847 slimb_t K, prec1, i, l, mod, prec2;
4848 int is_neg;
4849
4850 assert(c != a && s != a);
4851
4852 bf_init(s1, T);
4853 bf_init(s1, U);
4854 bf_init(s1, r);
4855
4856 /* XXX: precision analysis */
4857 K = bf_isqrt(prec / 2);
4858 l = prec / (2 * K) + 1;
4859 prec1 = prec + 2 * K + l + 8;
4860
4861 /* after the modulo reduction, -pi/4 <= T <= pi/4 */
4862 if (a->expn <= -1) {
4863 /* abs(a) <= 0.25: no modulo reduction needed */
4864 bf_set(T, a);
4865 mod = 0;
4866 } else {
4867 slimb_t cancel;
4868 cancel = 0;
4869 for(;;) {
4870 prec2 = prec1 + a->expn + cancel;
4871 bf_const_pi(U, prec2, BF_RNDF);
4872 bf_mul_2exp(U, -1, BF_PREC_INF, BF_RNDZ);
4873 bf_remquo(&mod, T, a, U, prec2, BF_RNDN, BF_RNDN);
4874 // printf("T.expn=%ld prec2=%ld\n", T->expn, prec2);
4875 if (mod == 0 || (T->expn != BF_EXP_ZERO &&
4876 (T->expn + prec2) >= (prec1 - 1)))
4877 break;
4878 /* increase the number of bits until the precision is good enough */
4879 cancel = bf_max(-T->expn, (cancel + 1) * 3 / 2);
4880 }
4881 mod &= 3;
4882 }
4883
4884 is_neg = T->sign;
4885
4886 /* compute cosm1(x) = cos(x) - 1 */
4887 bf_mul(T, T, T, prec1, BF_RNDN);
4888 bf_mul_2exp(T, -2 * K, BF_PREC_INF, BF_RNDZ);
4889
4890 /* Taylor expansion:
4891 -x^2/2 + x^4/4! - x^6/6! + ...
4892 */
4893 bf_set_ui(r, 1);
4894 for(i = l ; i >= 1; i--) {
4895 bf_set_ui(U, 2 * i - 1);
4896 bf_mul_ui(U, U, 2 * i, BF_PREC_INF, BF_RNDZ);
4897 bf_div(U, T, U, prec1, BF_RNDN);
4898 bf_mul(r, r, U, prec1, BF_RNDN);
4899 bf_neg(r);
4900 if (i != 1)
4901 bf_add_si(r, r, 1, prec1, BF_RNDN);
4902 }
4903 bf_delete(U);
4904
4905 /* undo argument reduction:
4906 cosm1(2*x)= 2*(2*cosm1(x)+cosm1(x)^2)
4907 */
4908 for(i = 0; i < K; i++) {
4909 bf_mul(T, r, r, prec1, BF_RNDN);
4910 bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
4911 bf_add(r, r, T, prec1, BF_RNDN);
4912 bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
4913 }
4914 bf_delete(T);
4915
4916 if (c) {
4917 if ((mod & 1) == 0) {
4918 bf_add_si(c, r, 1, prec1, BF_RNDN);
4919 } else {
4920 bf_sqrt_sin(c, r, prec1);
4921 c->sign = is_neg ^ 1;
4922 }
4923 c->sign ^= mod >> 1;
4924 }
4925 if (s) {
4926 if ((mod & 1) == 0) {
4927 bf_sqrt_sin(s, r, prec1);
4928 s->sign = is_neg;
4929 } else {
4930 bf_add_si(s, r, 1, prec1, BF_RNDN);
4931 }
4932 s->sign ^= mod >> 1;
4933 }
4934 bf_delete(r);
4935 return BF_ST_INEXACT;
4936 }
4937
4938 static int bf_cos_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4939 {
4940 return bf_sincos(NULL, r, a, prec);
4941 }
4942
4943 int bf_cos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4944 {
4945 if (a->len == 0) {
4946 if (a->expn == BF_EXP_NAN) {
4947 bf_set_nan(r);
4948 return 0;
4949 } else if (a->expn == BF_EXP_INF) {
4950 bf_set_nan(r);
4951 return BF_ST_INVALID_OP;
4952 } else {
4953 bf_set_ui(r, 1);
4954 return 0;
4955 }
4956 }
4957
4958 /* small argument case: result = 1+r(x) with r(x) = -x^2/2 +
4959 O(x^4). We assume |r(x)| < 2^(2*EXP(x) - 1). */
4960 if (a->expn < 0) {
4961 slimb_t e;
4962 e = 2 * a->expn - 1;
4963 if (e < -(prec + 2)) {
4964 bf_set_ui(r, 1);
4965 return bf_add_epsilon(r, r, e, 1, prec, flags);
4966 }
4967 }
4968
4969 return bf_ziv_rounding(r, a, prec, flags, bf_cos_internal, NULL);
4970 }
4971
4972 static int bf_sin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4973 {
4974 return bf_sincos(r, NULL, a, prec);
4975 }
4976
4977 int bf_sin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
4978 {
4979 if (a->len == 0) {
4980 if (a->expn == BF_EXP_NAN) {
4981 bf_set_nan(r);
4982 return 0;
4983 } else if (a->expn == BF_EXP_INF) {
4984 bf_set_nan(r);
4985 return BF_ST_INVALID_OP;
4986 } else {
4987 bf_set_zero(r, a->sign);
4988 return 0;
4989 }
4990 }
4991
4992 /* small argument case: result = x+r(x) with r(x) = -x^3/6 +
4993 O(x^5). We assume |r(x)| < 2^(3*EXP(x) - 2). */
4994 if (a->expn < 0) {
4995 slimb_t e;
4996 e = sat_add(2 * a->expn, a->expn - 2);
4997 if (e < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
4998 bf_set(r, a);
4999 return bf_add_epsilon(r, r, e, 1 - a->sign, prec, flags);
5000 }
5001 }
5002
5003 return bf_ziv_rounding(r, a, prec, flags, bf_sin_internal, NULL);
5004 }
5005
5006 static int bf_tan_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
5007 {
5008 bf_context_t *s = r->ctx;
5009 bf_t T_s, *T = &T_s;
5010 limb_t prec1;
5011
5012 /* XXX: precision analysis */
5013 prec1 = prec + 8;
5014 bf_init(s, T);
5015 bf_sincos(r, T, a, prec1);
5016 bf_div(r, r, T, prec1, BF_RNDF);
5017 bf_delete(T);
5018 return BF_ST_INEXACT;
5019 }
5020
5021 int bf_tan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
5022 {
5023 assert(r != a);
5024 if (a->len == 0) {
5025 if (a->expn == BF_EXP_NAN) {
5026 bf_set_nan(r);
5027 return 0;
5028 } else if (a->expn == BF_EXP_INF) {
5029 bf_set_nan(r);
5030 return BF_ST_INVALID_OP;
5031 } else {
5032 bf_set_zero(r, a->sign);
5033 return 0;
5034 }
5035 }
5036
5037 /* small argument case: result = x+r(x) with r(x) = x^3/3 +
5038 O(x^5). We assume |r(x)| < 2^(3*EXP(x) - 1). */
5039 if (a->expn < 0) {
5040 slimb_t e;
5041 e = sat_add(2 * a->expn, a->expn - 1);
5042 if (e < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
5043 bf_set(r, a);
5044 return bf_add_epsilon(r, r, e, a->sign, prec, flags);
5045 }
5046 }
5047
5048 return bf_ziv_rounding(r, a, prec, flags, bf_tan_internal, NULL);
5049 }
5050
5051 /* if add_pi2 is true, add pi/2 to the result (used for acos(x) to
5052 avoid cancellation) */
5053 static int bf_atan_internal(bf_t *r, const bf_t *a, limb_t prec,
5054 void *opaque)
5055 {
5056 bf_context_t *s = r->ctx;
5057 BOOL add_pi2 = (BOOL)(intptr_t)opaque;
5058 bf_t T_s, *T = &T_s;
5059 bf_t U_s, *U = &U_s;
5060 bf_t V_s, *V = &V_s;
5061 bf_t X2_s, *X2 = &X2_s;
5062 int cmp_1;
5063 slimb_t prec1, i, K, l;
5064
5065 /* XXX: precision analysis */
5066 K = bf_isqrt((prec + 1) / 2);
5067 l = prec / (2 * K) + 1;
5068 prec1 = prec + K + 2 * l + 32;
5069 // printf("prec=%d K=%d l=%d prec1=%d\n", (int)prec, (int)K, (int)l, (int)prec1);
5070
5071 bf_init(s, T);
5072 cmp_1 = (a->expn >= 1); /* abs(a) >= 1 */
5073 if (cmp_1) {
5074 bf_set_ui(T, 1);
5075 bf_div(T, T, a, prec1, BF_RNDN);
5076 } else {
5077 bf_set(T, a);
5078 }
5079
5080 /* abs(T) <= 1 */
5081
5082 /* argument reduction */
5083
5084 bf_init(s, U);
5085 bf_init(s, V);
5086 bf_init(s, X2);
5087 for(i = 0; i < K; i++) {
5088 /* T = T / (1 + sqrt(1 + T^2)) */
5089 bf_mul(U, T, T, prec1, BF_RNDN);
5090 bf_add_si(U, U, 1, prec1, BF_RNDN);
5091 bf_sqrt(V, U, prec1, BF_RNDN);
5092 bf_add_si(V, V, 1, prec1, BF_RNDN);
5093 bf_div(T, T, V, prec1, BF_RNDN);
5094 }
5095
5096 /* Taylor series:
5097 x - x^3/3 + ... + (-1)^l * x^(2*l+1) / (2*l+1)
5098 */
5099 bf_mul(X2, T, T, prec1, BF_RNDN);
5100 bf_set_ui(r, 0);
5101 for(i = l; i >= 1; i--) {
5102 bf_set_si(U, 1);
5103 bf_set_ui(V, 2 * i + 1);
5104 bf_div(U, U, V, prec1, BF_RNDN);
5105 bf_neg(r);
5106 bf_add(r, r, U, prec1, BF_RNDN);
5107 bf_mul(r, r, X2, prec1, BF_RNDN);
5108 }
5109 bf_neg(r);
5110 bf_add_si(r, r, 1, prec1, BF_RNDN);
5111 bf_mul(r, r, T, prec1, BF_RNDN);
5112
5113 /* undo the argument reduction */
5114 bf_mul_2exp(r, K, BF_PREC_INF, BF_RNDZ);
5115
5116 bf_delete(U);
5117 bf_delete(V);
5118 bf_delete(X2);
5119
5120 i = add_pi2;
5121 if (cmp_1 > 0) {
5122 /* undo the inversion : r = sign(a)*PI/2 - r */
5123 bf_neg(r);
5124 i += 1 - 2 * a->sign;
5125 }
5126 /* add i*(pi/2) with -1 <= i <= 2 */
5127 if (i != 0) {
5128 bf_const_pi(T, prec1, BF_RNDF);
5129 if (i != 2)
5130 bf_mul_2exp(T, -1, BF_PREC_INF, BF_RNDZ);
5131 T->sign = (i < 0);
5132 bf_add(r, T, r, prec1, BF_RNDN);
5133 }
5134
5135 bf_delete(T);
5136 return BF_ST_INEXACT;
5137 }
5138
5139 int bf_atan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
5140 {
5141 bf_context_t *s = r->ctx;
5142 bf_t T_s, *T = &T_s;
5143 int res;
5144
5145 if (a->len == 0) {
5146 if (a->expn == BF_EXP_NAN) {
5147 bf_set_nan(r);
5148 return 0;
5149 } else if (a->expn == BF_EXP_INF) {
5150 /* -PI/2 or PI/2 */
5151 bf_const_pi_signed(r, a->sign, prec, flags);
5152 bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
5153 return BF_ST_INEXACT;
5154 } else {
5155 bf_set_zero(r, a->sign);
5156 return 0;
5157 }
5158 }
5159
5160 bf_init(s, T);
5161 bf_set_ui(T, 1);
5162 res = bf_cmpu(a, T);
5163 bf_delete(T);
5164 if (res == 0) {
5165 /* short cut: abs(a) == 1 -> +/-pi/4 */
5166 bf_const_pi_signed(r, a->sign, prec, flags);
5167 bf_mul_2exp(r, -2, BF_PREC_INF, BF_RNDZ);
5168 return BF_ST_INEXACT;
5169 }
5170
5171 /* small argument case: result = x+r(x) with r(x) = -x^3/3 +
5172 O(x^5). We assume |r(x)| < 2^(3*EXP(x) - 1). */
5173 if (a->expn < 0) {
5174 slimb_t e;
5175 e = sat_add(2 * a->expn, a->expn - 1);
5176 if (e < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
5177 bf_set(r, a);
5178 return bf_add_epsilon(r, r, e, 1 - a->sign, prec, flags);
5179 }
5180 }
5181
5182 return bf_ziv_rounding(r, a, prec, flags, bf_atan_internal, (void *)FALSE);
5183 }
5184
5185 static int bf_atan2_internal(bf_t *r, const bf_t *y, limb_t prec, void *opaque)
5186 {
5187 bf_context_t *s = r->ctx;
5188 const bf_t *x = opaque;
5189 bf_t T_s, *T = &T_s;
5190 limb_t prec1;
5191 int ret;
5192
5193 if (y->expn == BF_EXP_NAN || x->expn == BF_EXP_NAN) {
5194 bf_set_nan(r);
5195 return 0;
5196 }
5197
5198 /* compute atan(y/x) assuming inf/inf = 1 and 0/0 = 0 */
5199 bf_init(s, T);
5200 prec1 = prec + 32;
5201 if (y->expn == BF_EXP_INF && x->expn == BF_EXP_INF) {
5202 bf_set_ui(T, 1);
5203 T->sign = y->sign ^ x->sign;
5204 } else if (y->expn == BF_EXP_ZERO && x->expn == BF_EXP_ZERO) {
5205 bf_set_zero(T, y->sign ^ x->sign);
5206 } else {
5207 bf_div(T, y, x, prec1, BF_RNDF);
5208 }
5209 ret = bf_atan(r, T, prec1, BF_RNDF);
5210
5211 if (x->sign) {
5212 /* if x < 0 (including -0), return sign(y)*pi + atan(y/x) */
5213 bf_const_pi(T, prec1, BF_RNDF);
5214 T->sign = y->sign;
5215 bf_add(r, r, T, prec1, BF_RNDN);
5216 ret |= BF_ST_INEXACT;
5217 }
5218
5219 bf_delete(T);
5220 return ret;
5221 }
5222
5223 int bf_atan2(bf_t *r, const bf_t *y, const bf_t *x,
5224 limb_t prec, bf_flags_t flags)
5225 {
5226 return bf_ziv_rounding(r, y, prec, flags, bf_atan2_internal, (void *)x);
5227 }
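
#if 0
/* Illustrative sketch, not part of the library: bf_atan2 keeps the
   quadrant information, e.g. atan2(1, -1) = 3*pi/4 while
   atan(1 / -1) = -pi/4. Context assumed initialized as above. */
static void bf_example_atan2(bf_context_t *ctx, limb_t prec)
{
    bf_t y, x, r;
    bf_init(ctx, &y);
    bf_init(ctx, &x);
    bf_init(ctx, &r);
    bf_set_si(&y, 1);
    bf_set_si(&x, -1);
    bf_atan2(&r, &y, &x, prec, BF_RNDN); /* r ~ 2.3561944901... */
    bf_delete(&y);
    bf_delete(&x);
    bf_delete(&r);
}
#endif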
5228
5229 static int bf_asin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
5230 {
5231 bf_context_t *s = r->ctx;
5232 BOOL is_acos = (BOOL)(intptr_t)opaque;
5233 bf_t T_s, *T = &T_s;
5234 limb_t prec1, prec2;
5235
5236 /* asin(x) = atan(x/sqrt(1-x^2))
5237 acos(x) = pi/2 - asin(x) */
5238 prec1 = prec + 8;
5239 /* increase the precision of x^2 to compensate for the cancellation
5240 in (1-x^2) when x is close to 1 */
5241 /* XXX: use less precision when possible */
5242 if (a->expn >= 0)
5243 prec2 = BF_PREC_INF;
5244 else
5245 prec2 = prec1;
5246 bf_init(s, T);
5247 bf_mul(T, a, a, prec2, BF_RNDN);
5248 bf_neg(T);
5249 bf_add_si(T, T, 1, prec2, BF_RNDN);
5250
5251 bf_sqrt(r, T, prec1, BF_RNDN);
5252 bf_div(T, a, r, prec1, BF_RNDN);
5253 if (is_acos)
5254 bf_neg(T);
5255 bf_atan_internal(r, T, prec1, (void *)(intptr_t)is_acos);
5256 bf_delete(T);
5257 return BF_ST_INEXACT;
5258 }
5259
5260 int bf_asin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
5261 {
5262 bf_context_t *s = r->ctx;
5263 bf_t T_s, *T = &T_s;
5264 int res;
5265
5266 if (a->len == 0) {
5267 if (a->expn == BF_EXP_NAN) {
5268 bf_set_nan(r);
5269 return 0;
5270 } else if (a->expn == BF_EXP_INF) {
5271 bf_set_nan(r);
5272 return BF_ST_INVALID_OP;
5273 } else {
5274 bf_set_zero(r, a->sign);
5275 return 0;
5276 }
5277 }
5278 bf_init(s, T);
5279 bf_set_ui(T, 1);
5280 res = bf_cmpu(a, T);
5281 bf_delete(T);
5282 if (res > 0) {
5283 bf_set_nan(r);
5284 return BF_ST_INVALID_OP;
5285 }
5286
5287 /* small argument case: result = x+r(x) with r(x) = x^3/6 +
5288 O(x^5). We assume |r(x)| < 2^(3*EXP(x) - 2). */
5289 if (a->expn < 0) {
5290 slimb_t e;
5291 e = sat_add(2 * a->expn, a->expn - 2);
5292 if (e < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
5293 bf_set(r, a);
5294 return bf_add_epsilon(r, r, e, a->sign, prec, flags);
5295 }
5296 }
5297
5298 return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal, (void *)FALSE);
5299 }
5300
5301 int bf_acos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
5302 {
5303 bf_context_t *s = r->ctx;
5304 bf_t T_s, *T = &T_s;
5305 int res;
5306
5307 if (a->len == 0) {
5308 if (a->expn == BF_EXP_NAN) {
5309 bf_set_nan(r);
5310 return 0;
5311 } else if (a->expn == BF_EXP_INF) {
5312 bf_set_nan(r);
5313 return BF_ST_INVALID_OP;
5314 } else {
5315 bf_const_pi(r, prec, flags);
5316 bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
5317 return BF_ST_INEXACT;
5318 }
5319 }
5320 bf_init(s, T);
5321 bf_set_ui(T, 1);
5322 res = bf_cmpu(a, T);
5323 bf_delete(T);
5324 if (res > 0) {
5325 bf_set_nan(r);
5326 return BF_ST_INVALID_OP;
5327 } else if (res == 0 && a->sign == 0) {
5328 bf_set_zero(r, 0);
5329 return 0;
5330 }
5331
5332 return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal, (void *)TRUE);
5333 }
5334
5335 /***************************************************************/
5336 /* decimal floating point numbers */
5337
5338 #ifdef USE_BF_DEC
5339
5340 #define adddq(r1, r0, a1, a0) \
5341 do { \
5342 limb_t __t = r0; \
5343 r0 += (a0); \
5344 r1 += (a1) + (r0 < __t); \
5345 } while (0)
5346
5347 #define subdq(r1, r0, a1, a0) \
5348 do { \
5349 limb_t __t = r0; \
5350 r0 -= (a0); \
5351 r1 -= (a1) + (r0 > __t); \
5352 } while (0)
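
/* in adddq/subdq the carry (resp. borrow) out of the low limb is
   detected through unsigned wraparound: after r0 += a0, a carry
   occurred iff the new r0 is smaller than its saved value __t */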
5353
5354 #if LIMB_BITS == 64
5355
5356 /* Note: we assume __int128 is available */
5357 #define muldq(r1, r0, a, b) \
5358 do { \
5359 unsigned __int128 __t; \
5360 __t = (unsigned __int128)(a) * (unsigned __int128)(b); \
5361 r0 = __t; \
5362 r1 = __t >> 64; \
5363 } while (0)
5364
5365 #define divdq(q, r, a1, a0, b) \
5366 do { \
5367 unsigned __int128 __t; \
5368 limb_t __b = (b); \
5369 __t = ((unsigned __int128)(a1) << 64) | (a0); \
5370 q = __t / __b; \
5371 r = __t % __b; \
5372 } while (0)
5373
5374 #else
5375
5376 #define muldq(r1, r0, a, b) \
5377 do { \
5378 uint64_t __t; \
5379 __t = (uint64_t)(a) * (uint64_t)(b); \
5380 r0 = __t; \
5381 r1 = __t >> 32; \
5382 } while (0)
5383
5384 #define divdq(q, r, a1, a0, b) \
5385 do { \
5386 uint64_t __t; \
5387 limb_t __b = (b); \
5388 __t = ((uint64_t)(a1) << 32) | (a0); \
5389 q = __t / __b; \
5390 r = __t % __b; \
5391 } while (0)
5392
5393 #endif /* LIMB_BITS != 64 */
5394
5395 static inline __maybe_unused limb_t shrd(limb_t low, limb_t high, long shift)
5396 {
5397 if (shift != 0)
5398 low = (low >> shift) | (high << (LIMB_BITS - shift));
5399 return low;
5400 }
5401
5402 static inline __maybe_unused limb_t shld(limb_t a1, limb_t a0, long shift)
5403 {
5404 if (shift != 0)
5405 return (a1 << shift) | (a0 >> (LIMB_BITS - shift));
5406 else
5407 return a1;
5408 }
5409
5410 #if LIMB_DIGITS == 19
5411
5412 /* WARNING: hardcoded for b = 1e19. It is assumed that:
5413 0 <= a1 < 2^63 */
5414 #define divdq_base(q, r, a1, a0)\
5415 do {\
5416 uint64_t __a0, __a1, __t0, __t1, __b = BF_DEC_BASE; \
5417 __a0 = a0;\
5418 __a1 = a1;\
5419 __t0 = __a1;\
5420 __t0 = shld(__t0, __a0, 1);\
5421 muldq(q, __t1, __t0, UINT64_C(17014118346046923173)); \
5422 muldq(__t1, __t0, q, __b);\
5423 subdq(__a1, __a0, __t1, __t0);\
5424 subdq(__a1, __a0, 1, __b * 2); \
5425 __t0 = (slimb_t)__a1 >> 1; \
5426 q += 2 + __t0;\
5427 adddq(__a1, __a0, 0, __b & __t0);\
5428 q += __a1; \
5429 __a0 += __b & __a1; \
5430 r = __a0;\
5431 } while(0)
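
/* Note: 17014118346046923173 = floor(2^127 / 10^19), so the first
   muldq computes q ~ (2*a*m) >> 128 = a / 10^19; the following lines
   correct this estimate by at most a few units without branching */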
5432
5433 #elif LIMB_DIGITS == 9
5434
5435 /* WARNING: hardcoded for b = 1e9. It is assumed that:
5436 0 <= a1 < 2^29 */
5437 #define divdq_base(q, r, a1, a0)\
5438 do {\
5439 uint32_t __t0, __t1, __b = BF_DEC_BASE; \
5440 __t0 = a1;\
5441 __t1 = a0;\
5442 __t0 = (__t0 << 3) | (__t1 >> (32 - 3)); \
5443 muldq(q, __t1, __t0, 2305843009U);\
5444 r = a0 - q * __b;\
5445 __t1 = (r >= __b);\
5446 q += __t1;\
5447 if (__t1)\
5448 r -= __b;\
5449 } while(0)
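
/* Note: 2305843009 = floor(2^61 / 10^9): the multiply gives a quotient
   estimate at most one too small, fixed up by the final comparison */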
5450
5451 #endif
5452
5453 /* fast integer division by a fixed constant */
5454
5455 typedef struct FastDivData {
5456 limb_t m1; /* multiplier */
5457 int8_t shift1;
5458 int8_t shift2;
5459 } FastDivData;
5460
5461 /* From "Division by Invariant Integers using Multiplication" by
5462 Torbjörn Granlund and Peter L. Montgomery */
5463 /* d must be != 0 */
5464 static inline __maybe_unused void fast_udiv_init(FastDivData *s, limb_t d)
5465 {
5466 int l;
5467 limb_t q, r, m1;
5468 if (d == 1)
5469 l = 0;
5470 else
5471 l = 64 - clz64(d - 1);
5472 divdq(q, r, ((limb_t)1 << l) - d, 0, d);
5473 (void)r;
5474 m1 = q + 1;
5475 // printf("d=%lu l=%d m1=0x%016lx\n", d, l, m1);
5476 s->m1 = m1;
5477 s->shift1 = l;
5478 if (s->shift1 > 1)
5479 s->shift1 = 1;
5480 s->shift2 = l - 1;
5481 if (s->shift2 < 0)
5482 s->shift2 = 0;
5483 }
5484
5485 static inline limb_t fast_udiv(limb_t a, const FastDivData *s)
5486 {
5487 limb_t t0, t1;
5488 muldq(t1, t0, s->m1, a);
5489 t0 = (a - t1) >> s->shift1;
5490 return (t1 + t0) >> s->shift2;
5491 }
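
#if 0
/* Illustrative check, not part of the library: a divisor initialized
   with fast_udiv_init() must make fast_udiv() agree with the hardware
   division for every dividend. */
static void bf_example_fast_udiv(void)
{
    FastDivData d;
    limb_t a;
    fast_udiv_init(&d, 10);
    for(a = 0; a < 1000000; a++)
        assert(fast_udiv(a, &d) == a / 10);
}
#endif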
5492
5493 /* contains 10^i */
5494 const limb_t mp_pow_dec[LIMB_DIGITS + 1] = {
5495 1U,
5496 10U,
5497 100U,
5498 1000U,
5499 10000U,
5500 100000U,
5501 1000000U,
5502 10000000U,
5503 100000000U,
5504 1000000000U,
5505 #if LIMB_BITS == 64
5506 10000000000U,
5507 100000000000U,
5508 1000000000000U,
5509 10000000000000U,
5510 100000000000000U,
5511 1000000000000000U,
5512 10000000000000000U,
5513 100000000000000000U,
5514 1000000000000000000U,
5515 10000000000000000000U,
5516 #endif
5517 };
5518
5519 /* precomputed from fast_udiv_init(10^i) */
5520 static const FastDivData mp_pow_div[LIMB_DIGITS + 1] = {
5521 #if LIMB_BITS == 32
5522 { 0x00000001, 0, 0 },
5523 { 0x9999999a, 1, 3 },
5524 { 0x47ae147b, 1, 6 },
5525 { 0x0624dd30, 1, 9 },
5526 { 0xa36e2eb2, 1, 13 },
5527 { 0x4f8b588f, 1, 16 },
5528 { 0x0c6f7a0c, 1, 19 },
5529 { 0xad7f29ac, 1, 23 },
5530 { 0x5798ee24, 1, 26 },
5531 { 0x12e0be83, 1, 29 },
5532 #else
5533 { 0x0000000000000001, 0, 0 },
5534 { 0x999999999999999a, 1, 3 },
5535 { 0x47ae147ae147ae15, 1, 6 },
5536 { 0x0624dd2f1a9fbe77, 1, 9 },
5537 { 0xa36e2eb1c432ca58, 1, 13 },
5538 { 0x4f8b588e368f0847, 1, 16 },
5539 { 0x0c6f7a0b5ed8d36c, 1, 19 },
5540 { 0xad7f29abcaf48579, 1, 23 },
5541 { 0x5798ee2308c39dfa, 1, 26 },
5542 { 0x12e0be826d694b2f, 1, 29 },
5543 { 0xb7cdfd9d7bdbab7e, 1, 33 },
5544 { 0x5fd7fe17964955fe, 1, 36 },
5545 { 0x19799812dea11198, 1, 39 },
5546 { 0xc25c268497681c27, 1, 43 },
5547 { 0x6849b86a12b9b01f, 1, 46 },
5548 { 0x203af9ee756159b3, 1, 49 },
5549 { 0xcd2b297d889bc2b7, 1, 53 },
5550 { 0x70ef54646d496893, 1, 56 },
5551 { 0x2725dd1d243aba0f, 1, 59 },
5552 { 0xd83c94fb6d2ac34d, 1, 63 },
5553 #endif
5554 };
5555
5556 /* divide by 10^shift with 0 <= shift <= LIMB_DIGITS */
5557 static inline limb_t fast_shr_dec(limb_t a, int shift)
5558 {
5559 return fast_udiv(a, &mp_pow_div[shift]);
5560 }
5561
5562 /* division and remainder by 10^shift */
5563 #define fast_shr_rem_dec(q, r, a, shift) q = fast_shr_dec(a, shift), r = a - q * mp_pow_dec[shift]
5564
5565 limb_t mp_add_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
5566 mp_size_t n, limb_t carry)
5567 {
5568 limb_t base = BF_DEC_BASE;
5569 mp_size_t i;
5570 limb_t k, a, v;
5571
5572 k=carry;
5573 for(i=0;i<n;i++) {
5574 /* XXX: reuse the trick in add_mod */
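/* branchless carry: a = v + op2[i] + k - base wraps above v exactly
   when the true sum is < base (no carry, base is added back below);
   otherwise a <= v and k becomes 1 */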
5575 v = op1[i];
5576 a = v + op2[i] + k - base;
5577 k = a <= v;
5578 if (!k)
5579 a += base;
5580 res[i]=a;
5581 }
5582 return k;
5583 }
5584
5585 limb_t mp_add_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
5586 {
5587 limb_t base = BF_DEC_BASE;
5588 mp_size_t i;
5589 limb_t k, a, v;
5590
5591 k=b;
5592 for(i=0;i<n;i++) {
5593 v = tab[i];
5594 a = v + k - base;
5595 k = a <= v;
5596 if (!k)
5597 a += base;
5598 tab[i] = a;
5599 if (k == 0)
5600 break;
5601 }
5602 return k;
5603 }
5604
5605 limb_t mp_sub_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
5606 mp_size_t n, limb_t carry)
5607 {
5608 limb_t base = BF_DEC_BASE;
5609 mp_size_t i;
5610 limb_t k, v, a;
5611
5612 k=carry;
5613 for(i=0;i<n;i++) {
5614 v = op1[i];
5615 a = v - op2[i] - k;
5616 k = a > v;
5617 if (k)
5618 a += base;
5619 res[i] = a;
5620 }
5621 return k;
5622 }
5623
5624 limb_t mp_sub_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
5625 {
5626 limb_t base = BF_DEC_BASE;
5627 mp_size_t i;
5628 limb_t k, v, a;
5629
5630 k=b;
5631 for(i=0;i<n;i++) {
5632 v = tab[i];
5633 a = v - k;
5634 k = a > v;
5635 if (k)
5636 a += base;
5637 tab[i]=a;
5638 if (k == 0)
5639 break;
5640 }
5641 return k;
5642 }
5643
5644 /* tabr[] = taba[] * b + l. 0 <= b, l <= base - 1. Return the high carry */
5645 limb_t mp_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5646 limb_t b, limb_t l)
5647 {
5648 mp_size_t i;
5649 limb_t t0, t1, r;
5650
5651 for(i = 0; i < n; i++) {
5652 muldq(t1, t0, taba[i], b);
5653 adddq(t1, t0, 0, l);
5654 divdq_base(l, r, t1, t0);
5655 tabr[i] = r;
5656 }
5657 return l;
5658 }
5659
5660 /* tabr[] += taba[] * b. 0 <= b <= base - 1. Return the value to add
5661 to the high word */
5662 limb_t mp_add_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5663 limb_t b)
5664 {
5665 mp_size_t i;
5666 limb_t l, t0, t1, r;
5667
5668 l = 0;
5669 for(i = 0; i < n; i++) {
5670 muldq(t1, t0, taba[i], b);
5671 adddq(t1, t0, 0, l);
5672 adddq(t1, t0, 0, tabr[i]);
5673 divdq_base(l, r, t1, t0);
5674 tabr[i] = r;
5675 }
5676 return l;
5677 }
5678
5679 /* tabr[] -= taba[] * b. 0 <= b <= base - 1. Return the value to
5680 subtract from the high word. */
5681 limb_t mp_sub_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5682 limb_t b)
5683 {
5684 limb_t base = BF_DEC_BASE;
5685 mp_size_t i;
5686 limb_t l, t0, t1, r, a, v, c;
5687
5688 /* XXX: optimize */
5689 l = 0;
5690 for(i = 0; i < n; i++) {
5691 muldq(t1, t0, taba[i], b);
5692 adddq(t1, t0, 0, l);
5693 divdq_base(l, r, t1, t0);
5694 v = tabr[i];
5695 a = v - r;
5696 c = a > v;
5697 if (c)
5698 a += base;
5699 /* never bigger than base because r = 0 when l = base - 1 */
5700 l += c;
5701 tabr[i] = a;
5702 }
5703 return l;
5704 }
5705
5706 /* size of the result: op1_size + op2_size. */
5707 void mp_mul_basecase_dec(limb_t *result,
5708 const limb_t *op1, mp_size_t op1_size,
5709 const limb_t *op2, mp_size_t op2_size)
5710 {
5711 mp_size_t i;
5712 limb_t r;
5713
5714 result[op1_size] = mp_mul1_dec(result, op1, op1_size, op2[0], 0);
5715
5716 for(i=1;i<op2_size;i++) {
5717 r = mp_add_mul1_dec(result + i, op1, op1_size, op2[i]);
5718 result[i + op1_size] = r;
5719 }
5720 }
5721
5722 /* tabr[] = (taba[] + r*base^na) / b. 0 <= b < base. 0 <= r <
5723 b. Return the remainder. */
5724 limb_t mp_div1_dec(limb_t *tabr, const limb_t *taba, mp_size_t na,
5725 limb_t b, limb_t r)
5726 {
5727 limb_t base = BF_DEC_BASE;
5728 mp_size_t i;
5729 limb_t t0, t1, q;
5730 int shift;
5731
5732 #if (BF_DEC_BASE % 2) == 0
5733 if (b == 2) {
5734 limb_t base_div2;
5735 /* Note: only works if base is even */
5736 base_div2 = base >> 1;
5737 if (r)
5738 r = base_div2;
5739 for(i = na - 1; i >= 0; i--) {
5740 t0 = taba[i];
5741 tabr[i] = (t0 >> 1) + r;
5742 r = 0;
5743 if (t0 & 1)
5744 r = base_div2;
5745 }
5746 if (r)
5747 r = 1;
5748 } else
5749 #endif
5750 if (na >= UDIV1NORM_THRESHOLD) {
5751 shift = clz(b);
5752 if (shift == 0) {
5753 /* normalized case: b >= 2^(LIMB_BITS-1) */
5754 limb_t b_inv;
5755 b_inv = udiv1norm_init(b);
5756 for(i = na - 1; i >= 0; i--) {
5757 muldq(t1, t0, r, base);
5758 adddq(t1, t0, 0, taba[i]);
5759 q = udiv1norm(&r, t1, t0, b, b_inv);
5760 tabr[i] = q;
5761 }
5762 } else {
5763 limb_t b_inv;
5764 b <<= shift;
5765 b_inv = udiv1norm_init(b);
5766 for(i = na - 1; i >= 0; i--) {
5767 muldq(t1, t0, r, base);
5768 adddq(t1, t0, 0, taba[i]);
5769 t1 = (t1 << shift) | (t0 >> (LIMB_BITS - shift));
5770 t0 <<= shift;
5771 q = udiv1norm(&r, t1, t0, b, b_inv);
5772 r >>= shift;
5773 tabr[i] = q;
5774 }
5775 }
5776 } else {
5777 for(i = na - 1; i >= 0; i--) {
5778 muldq(t1, t0, r, base);
5779 adddq(t1, t0, 0, taba[i]);
5780 divdq(q, r, t1, t0, b);
5781 tabr[i] = q;
5782 }
5783 }
5784 return r;
5785 }
5786
5787 static __maybe_unused void mp_print_str_dec(const char *str,
5788 const limb_t *tab, slimb_t n)
5789 {
5790 slimb_t i;
5791 printf("%s=", str);
5792 for(i = n - 1; i >= 0; i--) {
5793 if (i != n - 1)
5794 printf("_");
5795 printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5796 }
5797 printf("\n");
5798 }
5799
5800 static __maybe_unused void mp_print_str_h_dec(const char *str,
5801 const limb_t *tab, slimb_t n,
5802 limb_t high)
5803 {
5804 slimb_t i;
5805 printf("%s=", str);
5806 printf("%0*" PRIu_LIMB, LIMB_DIGITS, high);
5807 for(i = n - 1; i >= 0; i--) {
5808 printf("_");
5809 printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5810 }
5811 printf("\n");
5812 }
5813
5814 //#define DEBUG_DIV_SLOW
5815
5816 #define DIV_STATIC_ALLOC_LEN 16
5817
5818 /* return q = a / b and r = a % b.
5819
5820 taba[na] must be allocated if tabb1[nb - 1] < B / 2. tabb1[nb - 1]
5821 must be != zero. na must be >= nb. 's' can be NULL if tabb1[nb - 1]
5822 >= B / 2.
5823
5824 The remainder is returned in taba and contains nb limbs. tabq
5825 contains na - nb + 1 limbs. No overlap is permitted.
5826
5827 Running time of the standard method: (na - nb + 1) * nb
5828 Return 0 if OK, -1 if memory alloc error
5829 */
5830 /* XXX: optimize */
5831 static int mp_div_dec(bf_context_t *s, limb_t *tabq,
5832 limb_t *taba, mp_size_t na,
5833 const limb_t *tabb1, mp_size_t nb)
5834 {
5835 limb_t base = BF_DEC_BASE;
5836 limb_t r, mult, t0, t1, a, c, q, v, *tabb;
5837 mp_size_t i, j;
5838 limb_t static_tabb[DIV_STATIC_ALLOC_LEN];
5839
5840 #ifdef DEBUG_DIV_SLOW
5841 mp_print_str_dec("a", taba, na);
5842 mp_print_str_dec("b", tabb1, nb);
5843 #endif
5844
5845 /* normalize tabb */
5846 r = tabb1[nb - 1];
5847 assert(r != 0);
5848 i = na - nb;
5849 if (r >= BF_DEC_BASE / 2) {
5850 mult = 1;
5851 tabb = (limb_t *)tabb1;
5852 q = 1;
5853 for(j = nb - 1; j >= 0; j--) {
5854 if (taba[i + j] != tabb[j]) {
5855 if (taba[i + j] < tabb[j])
5856 q = 0;
5857 break;
5858 }
5859 }
5860 tabq[i] = q;
5861 if (q) {
5862 mp_sub_dec(taba + i, taba + i, tabb, nb, 0);
5863 }
5864 i--;
5865 } else {
5866 mult = base / (r + 1);
5867 if (likely(nb <= DIV_STATIC_ALLOC_LEN)) {
5868 tabb = static_tabb;
5869 } else {
5870 tabb = bf_malloc(s, sizeof(limb_t) * nb);
5871 if (!tabb)
5872 return -1;
5873 }
5874 mp_mul1_dec(tabb, tabb1, nb, mult, 0);
5875 taba[na] = mp_mul1_dec(taba, taba, na, mult, 0);
5876 }
5877
5878 #ifdef DEBUG_DIV_SLOW
5879 printf("mult=" FMT_LIMB "\n", mult);
5880 mp_print_str_dec("a_norm", taba, na + 1);
5881 mp_print_str_dec("b_norm", tabb, nb);
5882 #endif
5883
5884 for(; i >= 0; i--) {
5885 if (unlikely(taba[i + nb] >= tabb[nb - 1])) {
5886 /* XXX: check if it is really possible */
5887 q = base - 1;
5888 } else {
5889 muldq(t1, t0, taba[i + nb], base);
5890 adddq(t1, t0, 0, taba[i + nb - 1]);
5891 divdq(q, r, t1, t0, tabb[nb - 1]);
5892 }
5893 // printf("i=%d q1=%ld\n", i, q);
5894
5895 r = mp_sub_mul1_dec(taba + i, tabb, nb, q);
5896 // mp_dump("r1", taba + i, nb, bd);
5897 // printf("r2=%ld\n", r);
5898
5899 v = taba[i + nb];
5900 a = v - r;
5901 c = a > v;
5902 if (c)
5903 a += base;
5904 taba[i + nb] = a;
5905
5906 if (c != 0) {
5907 /* negative result */
5908 for(;;) {
5909 q--;
5910 c = mp_add_dec(taba + i, taba + i, tabb, nb, 0);
5911 /* propagate carry and test if positive result */
5912 if (c != 0) {
5913 if (++taba[i + nb] == base) {
5914 break;
5915 }
5916 }
5917 }
5918 }
5919 tabq[i] = q;
5920 }
5921
5922 #ifdef DEBUG_DIV_SLOW
5923 mp_print_str_dec("q", tabq, na - nb + 1);
5924 mp_print_str_dec("r", taba, nb);
5925 #endif
5926
5927 /* remove the normalization */
5928 if (mult != 1) {
5929 mp_div1_dec(taba, taba, nb, mult, 0);
5930 if (unlikely(tabb != static_tabb))
5931 bf_free(s, tabb);
5932 }
5933 return 0;
5934 }
5935
5936 /* divide by 10^shift */
5937 static limb_t mp_shr_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5938 limb_t shift, limb_t high)
5939 {
5940 mp_size_t i;
5941 limb_t l, a, q, r;
5942
5943 assert(shift >= 1 && shift < LIMB_DIGITS);
5944 l = high;
5945 for(i = n - 1; i >= 0; i--) {
5946 a = tab[i];
5947 fast_shr_rem_dec(q, r, a, shift);
5948 tab_r[i] = q + l * mp_pow_dec[LIMB_DIGITS - shift];
5949 l = r;
5950 }
5951 return l;
5952 }
5953
5954 /* multiply by 10^shift */
5955 static limb_t mp_shl_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5956 limb_t shift, limb_t low)
5957 {
5958 mp_size_t i;
5959 limb_t l, a, q, r;
5960
5961 assert(shift >= 1 && shift < LIMB_DIGITS);
5962 l = low;
5963 for(i = 0; i < n; i++) {
5964 a = tab[i];
5965 fast_shr_rem_dec(q, r, a, LIMB_DIGITS - shift);
5966 tab_r[i] = r * mp_pow_dec[shift] + l;
5967 l = q;
5968 }
5969 return l;
5970 }
5971
5972 static limb_t mp_sqrtrem2_dec(limb_t *tabs, limb_t *taba)
5973 {
5974 int k;
5975 dlimb_t a, b, r;
5976 limb_t taba1[2], s, r0, r1;
5977
5978 /* convert to binary and normalize */
5979 a = (dlimb_t)taba[1] * BF_DEC_BASE + taba[0];
5980 k = clz(a >> LIMB_BITS) & ~1;
5981 b = a << k;
5982 taba1[0] = b;
5983 taba1[1] = b >> LIMB_BITS;
5984 mp_sqrtrem2(&s, taba1);
5985 s >>= (k >> 1);
5986 /* convert the remainder back to decimal */
5987 r = a - (dlimb_t)s * (dlimb_t)s;
5988 divdq_base(r1, r0, r >> LIMB_BITS, r);
5989 taba[0] = r0;
5990 tabs[0] = s;
5991 return r1;
5992 }
5993
5994 //#define DEBUG_SQRTREM_DEC
5995
5996 /* tmp_buf must contain (n / 2 + 1) limbs */
5997 static limb_t mp_sqrtrem_rec_dec(limb_t *tabs, limb_t *taba, limb_t n,
5998 limb_t *tmp_buf)
5999 {
6000 limb_t l, h, rh, ql, qh, c, i;
6001
6002 if (n == 1)
6003 return mp_sqrtrem2_dec(tabs, taba);
6004 #ifdef DEBUG_SQRTREM_DEC
6005 mp_print_str_dec("a", taba, 2 * n);
6006 #endif
6007 l = n / 2;
6008 h = n - l;
6009 qh = mp_sqrtrem_rec_dec(tabs + l, taba + 2 * l, h, tmp_buf);
6010 #ifdef DEBUG_SQRTREM_DEC
6011 mp_print_str_dec("s1", tabs + l, h);
6012 mp_print_str_h_dec("r1", taba + 2 * l, h, qh);
6013 mp_print_str_h_dec("r2", taba + l, n, qh);
6014 #endif
6015
6016 /* the remainder is in taba + 2 * l. Its high bit is in qh */
6017 if (qh) {
6018 mp_sub_dec(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
6019 }
6020 /* instead of dividing by 2*s, divide by s (which is normalized)
6021 and update q and r */
6022 mp_div_dec(NULL, tmp_buf, taba + l, n, tabs + l, h);
6023 qh += tmp_buf[l];
6024 for(i = 0; i < l; i++)
6025 tabs[i] = tmp_buf[i];
6026 ql = mp_div1_dec(tabs, tabs, l, 2, qh & 1);
6027 qh = qh >> 1; /* 0 or 1 */
6028 if (ql)
6029 rh = mp_add_dec(taba + l, taba + l, tabs + l, h, 0);
6030 else
6031 rh = 0;
6032 #ifdef DEBUG_SQRTREM_DEC
6033 mp_print_str_h_dec("q", tabs, l, qh);
6034 mp_print_str_h_dec("u", taba + l, h, rh);
6035 #endif
6036
6037 mp_add_ui_dec(tabs + l, qh, h);
6038 #ifdef DEBUG_SQRTREM_DEC
6039 mp_print_str_dec("s2", tabs, n);
6040 #endif
6041
6042 /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
6043 /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
6044 if (qh) {
6045 c = qh;
6046 } else {
6047 mp_mul_basecase_dec(taba + n, tabs, l, tabs, l);
6048 c = mp_sub_dec(taba, taba, taba + n, 2 * l, 0);
6049 }
6050 rh -= mp_sub_ui_dec(taba + 2 * l, c, n - 2 * l);
6051 if ((slimb_t)rh < 0) {
6052 mp_sub_ui_dec(tabs, 1, n);
6053 rh += mp_add_mul1_dec(taba, tabs, n, 2);
6054 rh += mp_add_ui_dec(taba, 1, n);
6055 }
6056 return rh;
6057 }
6058
6059 /* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= B/4. Compute (s,
6060 r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2 * s. tabs has n
6061 limbs. r is returned in the lower n limbs of taba and its high limb
6062 is stored in taba[n]. Return 0 if OK, -1 in case of memory error. */
6063 int mp_sqrtrem_dec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
6064 {
6065 limb_t tmp_buf1[8];
6066 limb_t *tmp_buf;
6067 mp_size_t n2;
6068 n2 = n / 2 + 1;
6069 if (n2 <= countof(tmp_buf1)) {
6070 tmp_buf = tmp_buf1;
6071 } else {
6072 tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
6073 if (!tmp_buf)
6074 return -1;
6075 }
6076 taba[n] = mp_sqrtrem_rec_dec(tabs, taba, n, tmp_buf);
6077 if (tmp_buf != tmp_buf1)
6078 bf_free(s, tmp_buf);
6079 return 0;
6080 }
6081
6082 /* return the number of leading zero digits, from 0 to LIMB_DIGITS */
6083 static int clz_dec(limb_t a)
6084 {
6085 if (a == 0)
6086 return LIMB_DIGITS;
6087 switch(LIMB_BITS - 1 - clz(a)) {
6088 case 0: /* 1-1 */
6089 return LIMB_DIGITS - 1;
6090 case 1: /* 2-3 */
6091 return LIMB_DIGITS - 1;
6092 case 2: /* 4-7 */
6093 return LIMB_DIGITS - 1;
6094 case 3: /* 8-15 */
6095 if (a < 10)
6096 return LIMB_DIGITS - 1;
6097 else
6098 return LIMB_DIGITS - 2;
6099 case 4: /* 16-31 */
6100 return LIMB_DIGITS - 2;
6101 case 5: /* 32-63 */
6102 return LIMB_DIGITS - 2;
6103 case 6: /* 64-127 */
6104 if (a < 100)
6105 return LIMB_DIGITS - 2;
6106 else
6107 return LIMB_DIGITS - 3;
6108 case 7: /* 128-255 */
6109 return LIMB_DIGITS - 3;
6110 case 8: /* 256-511 */
6111 return LIMB_DIGITS - 3;
6112 case 9: /* 512-1023 */
6113 if (a < 1000)
6114 return LIMB_DIGITS - 3;
6115 else
6116 return LIMB_DIGITS - 4;
6117 case 10: /* 1024-2047 */
6118 return LIMB_DIGITS - 4;
6119 case 11: /* 2048-4095 */
6120 return LIMB_DIGITS - 4;
6121 case 12: /* 4096-8191 */
6122 return LIMB_DIGITS - 4;
6123 case 13: /* 8192-16383 */
6124 if (a < 10000)
6125 return LIMB_DIGITS - 4;
6126 else
6127 return LIMB_DIGITS - 5;
6128 case 14: /* 16384-32767 */
6129 return LIMB_DIGITS - 5;
6130 case 15: /* 32768-65535 */
6131 return LIMB_DIGITS - 5;
6132 case 16: /* 65536-131071 */
6133 if (a < 100000)
6134 return LIMB_DIGITS - 5;
6135 else
6136 return LIMB_DIGITS - 6;
6137 case 17: /* 131072-262143 */
6138 return LIMB_DIGITS - 6;
6139 case 18: /* 262144-524287 */
6140 return LIMB_DIGITS - 6;
6141 case 19: /* 524288-1048575 */
6142 if (a < 1000000)
6143 return LIMB_DIGITS - 6;
6144 else
6145 return LIMB_DIGITS - 7;
6146 case 20: /* 1048576-2097151 */
6147 return LIMB_DIGITS - 7;
6148 case 21: /* 2097152-4194303 */
6149 return LIMB_DIGITS - 7;
6150 case 22: /* 4194304-8388607 */
6151 return LIMB_DIGITS - 7;
6152 case 23: /* 8388608-16777215 */
6153 if (a < 10000000)
6154 return LIMB_DIGITS - 7;
6155 else
6156 return LIMB_DIGITS - 8;
6157 case 24: /* 16777216-33554431 */
6158 return LIMB_DIGITS - 8;
6159 case 25: /* 33554432-67108863 */
6160 return LIMB_DIGITS - 8;
6161 case 26: /* 67108864-134217727 */
6162 if (a < 100000000)
6163 return LIMB_DIGITS - 8;
6164 else
6165 return LIMB_DIGITS - 9;
6166 #if LIMB_BITS == 64
6167 case 27: /* 134217728-268435455 */
6168 return LIMB_DIGITS - 9;
6169 case 28: /* 268435456-536870911 */
6170 return LIMB_DIGITS - 9;
6171 case 29: /* 536870912-1073741823 */
6172 if (a < 1000000000)
6173 return LIMB_DIGITS - 9;
6174 else
6175 return LIMB_DIGITS - 10;
6176 case 30: /* 1073741824-2147483647 */
6177 return LIMB_DIGITS - 10;
6178 case 31: /* 2147483648-4294967295 */
6179 return LIMB_DIGITS - 10;
6180 case 32: /* 4294967296-8589934591 */
6181 return LIMB_DIGITS - 10;
6182 case 33: /* 8589934592-17179869183 */
6183 if (a < 10000000000)
6184 return LIMB_DIGITS - 10;
6185 else
6186 return LIMB_DIGITS - 11;
6187 case 34: /* 17179869184-34359738367 */
6188 return LIMB_DIGITS - 11;
6189 case 35: /* 34359738368-68719476735 */
6190 return LIMB_DIGITS - 11;
6191 case 36: /* 68719476736-137438953471 */
6192 if (a < 100000000000)
6193 return LIMB_DIGITS - 11;
6194 else
6195 return LIMB_DIGITS - 12;
6196 case 37: /* 137438953472-274877906943 */
6197 return LIMB_DIGITS - 12;
6198 case 38: /* 274877906944-549755813887 */
6199 return LIMB_DIGITS - 12;
6200 case 39: /* 549755813888-1099511627775 */
6201 if (a < 1000000000000)
6202 return LIMB_DIGITS - 12;
6203 else
6204 return LIMB_DIGITS - 13;
6205 case 40: /* 1099511627776-2199023255551 */
6206 return LIMB_DIGITS - 13;
6207 case 41: /* 2199023255552-4398046511103 */
6208 return LIMB_DIGITS - 13;
6209 case 42: /* 4398046511104-8796093022207 */
6210 return LIMB_DIGITS - 13;
6211 case 43: /* 8796093022208-17592186044415 */
6212 if (a < 10000000000000)
6213 return LIMB_DIGITS - 13;
6214 else
6215 return LIMB_DIGITS - 14;
6216 case 44: /* 17592186044416-35184372088831 */
6217 return LIMB_DIGITS - 14;
6218 case 45: /* 35184372088832-70368744177663 */
6219 return LIMB_DIGITS - 14;
6220 case 46: /* 70368744177664-140737488355327 */
6221 if (a < 100000000000000)
6222 return LIMB_DIGITS - 14;
6223 else
6224 return LIMB_DIGITS - 15;
6225 case 47: /* 140737488355328-281474976710655 */
6226 return LIMB_DIGITS - 15;
6227 case 48: /* 281474976710656-562949953421311 */
6228 return LIMB_DIGITS - 15;
6229 case 49: /* 562949953421312-1125899906842623 */
6230 if (a < 1000000000000000)
6231 return LIMB_DIGITS - 15;
6232 else
6233 return LIMB_DIGITS - 16;
6234 case 50: /* 1125899906842624-2251799813685247 */
6235 return LIMB_DIGITS - 16;
6236 case 51: /* 2251799813685248-4503599627370495 */
6237 return LIMB_DIGITS - 16;
6238 case 52: /* 4503599627370496-9007199254740991 */
6239 return LIMB_DIGITS - 16;
6240 case 53: /* 9007199254740992-18014398509481983 */
6241 if (a < 10000000000000000)
6242 return LIMB_DIGITS - 16;
6243 else
6244 return LIMB_DIGITS - 17;
6245 case 54: /* 18014398509481984-36028797018963967 */
6246 return LIMB_DIGITS - 17;
6247 case 55: /* 36028797018963968-72057594037927935 */
6248 return LIMB_DIGITS - 17;
6249 case 56: /* 72057594037927936-144115188075855871 */
6250 if (a < 100000000000000000)
6251 return LIMB_DIGITS - 17;
6252 else
6253 return LIMB_DIGITS - 18;
6254 case 57: /* 144115188075855872-288230376151711743 */
6255 return LIMB_DIGITS - 18;
6256 case 58: /* 288230376151711744-576460752303423487 */
6257 return LIMB_DIGITS - 18;
6258 case 59: /* 576460752303423488-1152921504606846975 */
6259 if (a < 1000000000000000000)
6260 return LIMB_DIGITS - 18;
6261 else
6262 return LIMB_DIGITS - 19;
6263 #endif
6264 default:
6265 return 0;
6266 }
6267 }
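
/* A minimal reference sketch (hypothetical helper, for illustration
   only): the unrolled switch above is equivalent to counting the
   decimal digits with a division loop, just much faster. */
#if 0
static int clz_dec_ref(limb_t a)
{
    int n = 0;
    while (a != 0) {
        a /= 10; /* count the decimal digits of 'a' */
        n++;
    }
    return LIMB_DIGITS - n; /* a == 0 gives LIMB_DIGITS, as above */
}
#endif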
6268
6269 /* for debugging */
6270 void bfdec_print_str(const char *str, const bfdec_t *a)
6271 {
6272 slimb_t i;
6273 printf("%s=", str);
6274
6275 if (a->expn == BF_EXP_NAN) {
6276 printf("NaN");
6277 } else {
6278 if (a->sign)
6279 putchar('-');
6280 if (a->expn == BF_EXP_ZERO) {
6281 putchar('0');
6282 } else if (a->expn == BF_EXP_INF) {
6283 printf("Inf");
6284 } else {
6285 printf("0.");
6286 for(i = a->len - 1; i >= 0; i--)
6287 printf("%0*" PRIu_LIMB, LIMB_DIGITS, a->tab[i]);
6288 printf("e%" PRId_LIMB, a->expn);
6289 }
6290 }
6291 printf("\n");
6292 }
6293
6294 /* return != 0 if at least one digit between positions 0 and bit_pos (inclusive) is non-zero. */
6295 static inline limb_t scan_digit_nz(const bfdec_t *r, slimb_t bit_pos)
6296 {
6297 slimb_t pos;
6298 limb_t v, q;
6299 int shift;
6300
6301 if (bit_pos < 0)
6302 return 0;
6303 pos = (limb_t)bit_pos / LIMB_DIGITS;
6304 shift = (limb_t)bit_pos % LIMB_DIGITS;
6305 fast_shr_rem_dec(q, v, r->tab[pos], shift + 1);
6306 (void)q;
6307 if (v != 0)
6308 return 1;
6309 pos--;
6310 while (pos >= 0) {
6311 if (r->tab[pos] != 0)
6312 return 1;
6313 pos--;
6314 }
6315 return 0;
6316 }
6317
6318 static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos)
6319 {
6320 slimb_t i;
6321 int shift;
6322 i = floor_div(pos, LIMB_DIGITS);
6323 if (i < 0 || i >= len)
6324 return 0;
6325 shift = pos - i * LIMB_DIGITS;
6326 return fast_shr_dec(tab[i], shift) % 10;
6327 }
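
/* A hedged usage sketch of get_digit(): digit positions count from the
   least significant digit of tab[0] (position 0); out of range
   positions read as 0. */
#if 0
static void get_digit_example(void)
{
    limb_t tab[1] = { 1234 };
    assert(get_digit(tab, 1, 0) == 4);
    assert(get_digit(tab, 1, 3) == 1);
    assert(get_digit(tab, 1, 4) == 0);  /* beyond the stored digits */
    assert(get_digit(tab, 1, -1) == 0); /* negative position */
}
#endif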
6328
6329 #if 0
6330 static limb_t get_digits(const limb_t *tab, limb_t len, slimb_t pos)
6331 {
6332 limb_t a0, a1;
6333 int shift;
6334 slimb_t i;
6335
6336 i = floor_div(pos, LIMB_DIGITS);
6337 shift = pos - i * LIMB_DIGITS;
6338 if (i >= 0 && i < len)
6339 a0 = tab[i];
6340 else
6341 a0 = 0;
6342 if (shift == 0) {
6343 return a0;
6344 } else {
6345 i++;
6346 if (i >= 0 && i < len)
6347 a1 = tab[i];
6348 else
6349 a1 = 0;
6350 return fast_shr_dec(a0, shift) +
6351 fast_urem(a1, &mp_pow_div[LIMB_DIGITS - shift]) *
6352 mp_pow_dec[shift];
6353 }
6354 }
6355 #endif
6356
6357 /* return the addend for rounding. Note that prec can be <= 0 for bfdec_rint() */
6358 static int bfdec_get_rnd_add(int *pret, const bfdec_t *r, limb_t l,
6359 slimb_t prec, int rnd_mode)
6360 {
6361 int add_one, inexact;
6362 limb_t digit1, digit0;
6363
6364 // bfdec_print_str("get_rnd_add", r);
6365 if (rnd_mode == BF_RNDF) {
6366 digit0 = 1; /* faithful rounding does not honor the INEXACT flag */
6367 } else {
6368         /* starting limb for digit 'prec + 1' */
6369 digit0 = scan_digit_nz(r, l * LIMB_DIGITS - 1 - bf_max(0, prec + 1));
6370 }
6371
6372 /* get the digit at 'prec' */
6373 digit1 = get_digit(r->tab, l, l * LIMB_DIGITS - 1 - prec);
6374 inexact = (digit1 | digit0) != 0;
6375
6376 add_one = 0;
6377 switch(rnd_mode) {
6378 case BF_RNDZ:
6379 break;
6380 case BF_RNDN:
6381 if (digit1 == 5) {
6382 if (digit0) {
6383 add_one = 1;
6384 } else {
6385 /* round to even */
6386 add_one =
6387 get_digit(r->tab, l, l * LIMB_DIGITS - 1 - (prec - 1)) & 1;
6388 }
6389 } else if (digit1 > 5) {
6390 add_one = 1;
6391 }
6392 break;
6393 case BF_RNDD:
6394 case BF_RNDU:
6395 if (r->sign == (rnd_mode == BF_RNDD))
6396 add_one = inexact;
6397 break;
6398 case BF_RNDNA:
6399 case BF_RNDF:
6400 add_one = (digit1 >= 5);
6401 break;
6402 case BF_RNDA:
6403 add_one = inexact;
6404 break;
6405 default:
6406 abort();
6407 }
6408
6409 if (inexact)
6410 *pret |= BF_ST_INEXACT;
6411 return add_one;
6412 }
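
/* Worked example (a sketch, not part of the original comments): for
   the mantissa digits 12345 rounded to prec = 3, digit1 is the digit
   at position 'prec' counted from the most significant digit ('4')
   and digit0 is set because a non-zero digit ('5') follows it; under
   BF_RNDN, digit1 < 5 so add_one = 0 and only BF_ST_INEXACT is
   reported, keeping the digits 123. */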
6413
6414 /* round to prec1 digits assuming 'r' is non-zero and finite. 'r' is
6415 assumed to have length 'l' (1 <= l <= r->len). prec1 can be
6416 BF_PREC_INF. BF_FLAG_SUBNORMAL is not supported. Cannot fail with
6417 BF_ST_MEM_ERROR.
6418 */
6419 static int __bfdec_round(bfdec_t *r, limb_t prec1, bf_flags_t flags, limb_t l)
6420 {
6421 int shift, add_one, rnd_mode, ret;
6422 slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;
6423
6424 /* XXX: align to IEEE 754 2008 for decimal numbers ? */
6425 e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
6426 e_min = -e_range + 3;
6427 e_max = e_range;
6428
6429 if (flags & BF_FLAG_RADPNT_PREC) {
6430 /* 'prec' is the precision after the decimal point */
6431 if (prec1 != BF_PREC_INF)
6432 prec = r->expn + prec1;
6433 else
6434 prec = prec1;
6435 } else if (unlikely(r->expn < e_min) && (flags & BF_FLAG_SUBNORMAL)) {
6436 /* restrict the precision in case of potentially subnormal
6437 result */
6438 assert(prec1 != BF_PREC_INF);
6439 prec = prec1 - (e_min - r->expn);
6440 } else {
6441 prec = prec1;
6442 }
6443
6444     /* round to prec digits */
6445 rnd_mode = flags & BF_RND_MASK;
6446 ret = 0;
6447 add_one = bfdec_get_rnd_add(&ret, r, l, prec, rnd_mode);
6448
6449 if (prec <= 0) {
6450 if (add_one) {
6451 bfdec_resize(r, 1); /* cannot fail because r is non zero */
6452 r->tab[0] = BF_DEC_BASE / 10;
6453 r->expn += 1 - prec;
6454 ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
6455 return ret;
6456 } else {
6457 goto underflow;
6458 }
6459 } else if (add_one) {
6460 limb_t carry;
6461
6462 /* add one starting at digit 'prec - 1' */
6463 bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
6464 pos = bit_pos / LIMB_DIGITS;
6465 carry = mp_pow_dec[bit_pos % LIMB_DIGITS];
6466 carry = mp_add_ui_dec(r->tab + pos, carry, l - pos);
6467 if (carry) {
6468 /* shift right by one digit */
6469 mp_shr_dec(r->tab + pos, r->tab + pos, l - pos, 1, 1);
6470 r->expn++;
6471 }
6472 }
6473
6474 /* check underflow */
6475 if (unlikely(r->expn < e_min)) {
6476 if (flags & BF_FLAG_SUBNORMAL) {
6477 /* if inexact, also set the underflow flag */
6478 if (ret & BF_ST_INEXACT)
6479 ret |= BF_ST_UNDERFLOW;
6480 } else {
6481 underflow:
6482 bfdec_set_zero(r, r->sign);
6483 ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
6484 return ret;
6485 }
6486 }
6487
6488 /* check overflow */
6489 if (unlikely(r->expn > e_max)) {
6490 bfdec_set_inf(r, r->sign);
6491 ret |= BF_ST_OVERFLOW | BF_ST_INEXACT;
6492 return ret;
6493 }
6494
6495     /* keep the digits starting at 'prec - 1' */
6496 bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
6497 i = floor_div(bit_pos, LIMB_DIGITS);
6498 if (i >= 0) {
6499 shift = smod(bit_pos, LIMB_DIGITS);
6500 if (shift != 0) {
6501 r->tab[i] = fast_shr_dec(r->tab[i], shift) *
6502 mp_pow_dec[shift];
6503 }
6504 } else {
6505 i = 0;
6506 }
6507 /* remove trailing zeros */
6508 while (r->tab[i] == 0)
6509 i++;
6510 if (i > 0) {
6511 l -= i;
6512 memmove(r->tab, r->tab + i, l * sizeof(limb_t));
6513 }
6514 bfdec_resize(r, l); /* cannot fail */
6515 return ret;
6516 }
6517
6518 /* Cannot fail with BF_ST_MEM_ERROR. */
6519 int bfdec_round(bfdec_t *r, limb_t prec, bf_flags_t flags)
6520 {
6521 if (r->len == 0)
6522 return 0;
6523 return __bfdec_round(r, prec, flags, r->len);
6524 }
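
/* A minimal usage sketch (assuming an initialized bf_context_t 'ctx';
   error checking omitted): parse a constant with infinite precision,
   then round it in place to 3 significant digits. */
#if 0
static void bfdec_round_example(bf_context_t *ctx)
{
    bfdec_t x;
    bfdec_init(ctx, &x);
    bfdec_atof(&x, "3.14159", NULL, BF_PREC_INF, BF_RNDZ);
    bfdec_round(&x, 3, BF_RNDN); /* x is now 3.14, BF_ST_INEXACT set */
    bfdec_delete(&x);
}
#endif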
6525
6526 /* 'r' must be a finite number. Cannot fail with BF_ST_MEM_ERROR. */
6527 int bfdec_normalize_and_round(bfdec_t *r, limb_t prec1, bf_flags_t flags)
6528 {
6529 limb_t l, v;
6530 int shift, ret;
6531
6532 // bfdec_print_str("bf_renorm", r);
6533 l = r->len;
6534 while (l > 0 && r->tab[l - 1] == 0)
6535 l--;
6536 if (l == 0) {
6537 /* zero */
6538 r->expn = BF_EXP_ZERO;
6539 bfdec_resize(r, 0); /* cannot fail */
6540 ret = 0;
6541 } else {
6542 r->expn -= (r->len - l) * LIMB_DIGITS;
6543 /* shift to have the MSB set to '1' */
6544 v = r->tab[l - 1];
6545 shift = clz_dec(v);
6546 if (shift != 0) {
6547 mp_shl_dec(r->tab, r->tab, l, shift, 0);
6548 r->expn -= shift;
6549 }
6550 ret = __bfdec_round(r, prec1, flags, l);
6551 }
6552 // bf_print_str("r_final", r);
6553 return ret;
6554 }
6555
6556 int bfdec_set_ui(bfdec_t *r, uint64_t v)
6557 {
6558 #if LIMB_BITS == 32
6559 if (v >= BF_DEC_BASE * BF_DEC_BASE) {
6560 if (bfdec_resize(r, 3))
6561 goto fail;
6562 r->tab[0] = v % BF_DEC_BASE;
6563 v /= BF_DEC_BASE;
6564 r->tab[1] = v % BF_DEC_BASE;
6565 r->tab[2] = v / BF_DEC_BASE;
6566 r->expn = 3 * LIMB_DIGITS;
6567 } else
6568 #endif
6569 if (v >= BF_DEC_BASE) {
6570 if (bfdec_resize(r, 2))
6571 goto fail;
6572 r->tab[0] = v % BF_DEC_BASE;
6573 r->tab[1] = v / BF_DEC_BASE;
6574 r->expn = 2 * LIMB_DIGITS;
6575 } else {
6576 if (bfdec_resize(r, 1))
6577 goto fail;
6578 r->tab[0] = v;
6579 r->expn = LIMB_DIGITS;
6580 }
6581 r->sign = 0;
6582 return bfdec_normalize_and_round(r, BF_PREC_INF, 0);
6583 fail:
6584 bfdec_set_nan(r);
6585 return BF_ST_MEM_ERROR;
6586 }
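
/* Worked example (LIMB_BITS == 64, so BF_DEC_BASE == 10^19):
   v = 12345678901234567890 is first stored as tab[1] = 1,
   tab[0] = 2345678901234567890 with expn = 38; the final
   bfdec_normalize_and_round() then shifts the most significant digit
   into place and adjusts expn to 20. */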
6587
6588 int bfdec_set_si(bfdec_t *r, int64_t v)
6589 {
6590 int ret;
6591 if (v < 0) {
6592 ret = bfdec_set_ui(r, -v);
6593 r->sign = 1;
6594 } else {
6595 ret = bfdec_set_ui(r, v);
6596 }
6597 return ret;
6598 }
6599
6600 static int bfdec_add_internal(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec, bf_flags_t flags, int b_neg)
6601 {
6602 bf_context_t *s = r->ctx;
6603 int is_sub, cmp_res, a_sign, b_sign, ret;
6604
6605 a_sign = a->sign;
6606 b_sign = b->sign ^ b_neg;
6607 is_sub = a_sign ^ b_sign;
6608 cmp_res = bfdec_cmpu(a, b);
6609 if (cmp_res < 0) {
6610 const bfdec_t *tmp;
6611 tmp = a;
6612 a = b;
6613 b = tmp;
6614 a_sign = b_sign; /* b_sign is never used later */
6615 }
6616 /* abs(a) >= abs(b) */
6617 if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
6618 /* zero result */
6619 bfdec_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
6620 ret = 0;
6621 } else if (a->len == 0 || b->len == 0) {
6622 ret = 0;
6623 if (a->expn >= BF_EXP_INF) {
6624 if (a->expn == BF_EXP_NAN) {
6625 /* at least one operand is NaN */
6626 bfdec_set_nan(r);
6627 ret = 0;
6628 } else if (b->expn == BF_EXP_INF && is_sub) {
6629 /* infinities with different signs */
6630 bfdec_set_nan(r);
6631 ret = BF_ST_INVALID_OP;
6632 } else {
6633 bfdec_set_inf(r, a_sign);
6634 }
6635 } else {
6636 /* at least one zero and not subtract */
6637 if (bfdec_set(r, a))
6638 return BF_ST_MEM_ERROR;
6639 r->sign = a_sign;
6640 goto renorm;
6641 }
6642 } else {
6643 slimb_t d, a_offset, b_offset, i, r_len;
6644 limb_t carry;
6645 limb_t *b1_tab;
6646 int b_shift;
6647 mp_size_t b1_len;
6648
6649 d = a->expn - b->expn;
6650
6651 /* XXX: not efficient in time and memory if the precision is
6652 not infinite */
6653 r_len = bf_max(a->len, b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);
6654 if (bfdec_resize(r, r_len))
6655 goto fail;
6656 r->sign = a_sign;
6657 r->expn = a->expn;
6658
6659 a_offset = r_len - a->len;
6660 for(i = 0; i < a_offset; i++)
6661 r->tab[i] = 0;
6662 for(i = 0; i < a->len; i++)
6663 r->tab[a_offset + i] = a->tab[i];
6664
6665 b_shift = d % LIMB_DIGITS;
6666 if (b_shift == 0) {
6667 b1_len = b->len;
6668 b1_tab = (limb_t *)b->tab;
6669 } else {
6670 b1_len = b->len + 1;
6671 b1_tab = bf_malloc(s, sizeof(limb_t) * b1_len);
6672 if (!b1_tab)
6673 goto fail;
6674 b1_tab[0] = mp_shr_dec(b1_tab + 1, b->tab, b->len, b_shift, 0) *
6675 mp_pow_dec[LIMB_DIGITS - b_shift];
6676 }
6677 b_offset = r_len - (b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);
6678
6679 if (is_sub) {
6680 carry = mp_sub_dec(r->tab + b_offset, r->tab + b_offset,
6681 b1_tab, b1_len, 0);
6682 if (carry != 0) {
6683 carry = mp_sub_ui_dec(r->tab + b_offset + b1_len, carry,
6684 r_len - (b_offset + b1_len));
6685 assert(carry == 0);
6686 }
6687 } else {
6688 carry = mp_add_dec(r->tab + b_offset, r->tab + b_offset,
6689 b1_tab, b1_len, 0);
6690 if (carry != 0) {
6691 carry = mp_add_ui_dec(r->tab + b_offset + b1_len, carry,
6692 r_len - (b_offset + b1_len));
6693 }
6694 if (carry != 0) {
6695 if (bfdec_resize(r, r_len + 1)) {
6696 if (b_shift != 0)
6697 bf_free(s, b1_tab);
6698 goto fail;
6699 }
6700 r->tab[r_len] = 1;
6701 r->expn += LIMB_DIGITS;
6702 }
6703 }
6704 if (b_shift != 0)
6705 bf_free(s, b1_tab);
6706 renorm:
6707 ret = bfdec_normalize_and_round(r, prec, flags);
6708 }
6709 return ret;
6710 fail:
6711 bfdec_set_nan(r);
6712 return BF_ST_MEM_ERROR;
6713 }
6714
6715 static int __bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6716 bf_flags_t flags)
6717 {
6718 return bfdec_add_internal(r, a, b, prec, flags, 0);
6719 }
6720
6721 static int __bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6722 bf_flags_t flags)
6723 {
6724 return bfdec_add_internal(r, a, b, prec, flags, 1);
6725 }
6726
6727 int bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6728 bf_flags_t flags)
6729 {
6730 return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
6731 (bf_op2_func_t *)__bfdec_add);
6732 }
6733
6734 int bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6735 bf_flags_t flags)
6736 {
6737 return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
6738 (bf_op2_func_t *)__bfdec_sub);
6739 }
6740
6741 int bfdec_mul(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6742 bf_flags_t flags)
6743 {
6744 int ret, r_sign;
6745
6746 if (a->len < b->len) {
6747 const bfdec_t *tmp = a;
6748 a = b;
6749 b = tmp;
6750 }
6751 r_sign = a->sign ^ b->sign;
6752 /* here b->len <= a->len */
6753 if (b->len == 0) {
6754 if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
6755 bfdec_set_nan(r);
6756 ret = 0;
6757 } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
6758 if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
6759 (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
6760 bfdec_set_nan(r);
6761 ret = BF_ST_INVALID_OP;
6762 } else {
6763 bfdec_set_inf(r, r_sign);
6764 ret = 0;
6765 }
6766 } else {
6767 bfdec_set_zero(r, r_sign);
6768 ret = 0;
6769 }
6770 } else {
6771 bfdec_t tmp, *r1 = NULL;
6772 limb_t a_len, b_len;
6773 limb_t *a_tab, *b_tab;
6774
6775 a_len = a->len;
6776 b_len = b->len;
6777 a_tab = a->tab;
6778 b_tab = b->tab;
6779
6780 if (r == a || r == b) {
6781 bfdec_init(r->ctx, &tmp);
6782 r1 = r;
6783 r = &tmp;
6784 }
6785 if (bfdec_resize(r, a_len + b_len)) {
6786 bfdec_set_nan(r);
6787 ret = BF_ST_MEM_ERROR;
6788 goto done;
6789 }
6790 mp_mul_basecase_dec(r->tab, a_tab, a_len, b_tab, b_len);
6791 r->sign = r_sign;
6792 r->expn = a->expn + b->expn;
6793 ret = bfdec_normalize_and_round(r, prec, flags);
6794 done:
6795 if (r == &tmp)
6796 bfdec_move(r1, &tmp);
6797 }
6798 return ret;
6799 }
6800
6801 int bfdec_mul_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
6802 bf_flags_t flags)
6803 {
6804 bfdec_t b;
6805 int ret;
6806 bfdec_init(r->ctx, &b);
6807 ret = bfdec_set_si(&b, b1);
6808 ret |= bfdec_mul(r, a, &b, prec, flags);
6809 bfdec_delete(&b);
6810 return ret;
6811 }
6812
6813 int bfdec_add_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
6814 bf_flags_t flags)
6815 {
6816 bfdec_t b;
6817 int ret;
6818
6819 bfdec_init(r->ctx, &b);
6820 ret = bfdec_set_si(&b, b1);
6821 ret |= bfdec_add(r, a, &b, prec, flags);
6822 bfdec_delete(&b);
6823 return ret;
6824 }
6825
6826 static int __bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
6827 limb_t prec, bf_flags_t flags)
6828 {
6829 int ret, r_sign;
6830 limb_t n, nb, precl;
6831
6832 r_sign = a->sign ^ b->sign;
6833 if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
6834 if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
6835 bfdec_set_nan(r);
6836 return 0;
6837 } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
6838 bfdec_set_nan(r);
6839 return BF_ST_INVALID_OP;
6840 } else if (a->expn == BF_EXP_INF) {
6841 bfdec_set_inf(r, r_sign);
6842 return 0;
6843 } else {
6844 bfdec_set_zero(r, r_sign);
6845 return 0;
6846 }
6847 } else if (a->expn == BF_EXP_ZERO) {
6848 if (b->expn == BF_EXP_ZERO) {
6849 bfdec_set_nan(r);
6850 return BF_ST_INVALID_OP;
6851 } else {
6852 bfdec_set_zero(r, r_sign);
6853 return 0;
6854 }
6855 } else if (b->expn == BF_EXP_ZERO) {
6856 bfdec_set_inf(r, r_sign);
6857 return BF_ST_DIVIDE_ZERO;
6858 }
6859
6860 nb = b->len;
6861 if (prec == BF_PREC_INF) {
6862 /* infinite precision: return BF_ST_INVALID_OP if not an exact
6863 result */
6864 /* XXX: check */
6865 precl = nb + 1;
6866 } else if (flags & BF_FLAG_RADPNT_PREC) {
6867 /* number of digits after the decimal point */
6868 /* XXX: check (2 extra digits for rounding + 2 digits) */
6869 precl = (bf_max(a->expn - b->expn, 0) + 2 +
6870 prec + 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
6871 } else {
6872 /* number of limbs of the quotient (2 extra digits for rounding) */
6873 precl = (prec + 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
6874 }
6875 n = bf_max(a->len, precl);
6876
6877 {
6878 limb_t *taba, na, i;
6879 slimb_t d;
6880
6881 na = n + nb;
6882 taba = bf_malloc(r->ctx, (na + 1) * sizeof(limb_t));
6883 if (!taba)
6884 goto fail;
6885 d = na - a->len;
6886 memset(taba, 0, d * sizeof(limb_t));
6887 memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
6888 if (bfdec_resize(r, n + 1))
6889 goto fail1;
6890 if (mp_div_dec(r->ctx, r->tab, taba, na, b->tab, nb)) {
6891 fail1:
6892 bf_free(r->ctx, taba);
6893 goto fail;
6894 }
6895 /* see if non zero remainder */
6896 for(i = 0; i < nb; i++) {
6897 if (taba[i] != 0)
6898 break;
6899 }
6900 bf_free(r->ctx, taba);
6901 if (i != nb) {
6902 if (prec == BF_PREC_INF) {
6903 bfdec_set_nan(r);
6904 return BF_ST_INVALID_OP;
6905 } else {
6906 r->tab[0] |= 1;
6907 }
6908 }
6909 r->expn = a->expn - b->expn + LIMB_DIGITS;
6910 r->sign = r_sign;
6911 ret = bfdec_normalize_and_round(r, prec, flags);
6912 }
6913 return ret;
6914 fail:
6915 bfdec_set_nan(r);
6916 return BF_ST_MEM_ERROR;
6917 }
6918
6919 int bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
6920 bf_flags_t flags)
6921 {
6922 return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags,
6923 (bf_op2_func_t *)__bfdec_div);
6924 }
6925
6926 /* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
6927 integer defined as floor(a/b) and r = a - q * b. */
6928 static void bfdec_tdivremu(bf_context_t *s, bfdec_t *q, bfdec_t *r,
6929 const bfdec_t *a, const bfdec_t *b)
6930 {
6931 if (bfdec_cmpu(a, b) < 0) {
6932 bfdec_set_ui(q, 0);
6933 bfdec_set(r, a);
6934 } else {
6935 bfdec_div(q, a, b, 0, BF_RNDZ | BF_FLAG_RADPNT_PREC);
6936 bfdec_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
6937 bfdec_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
6938 }
6939 }
6940
6941 /* division and remainder.
6942
6943    rnd_mode is the rounding mode for the quotient. The additional
6944    rounding mode BF_DIVREM_EUCLIDIAN is supported.
6945
6946 'q' is an integer. 'r' is rounded with prec and flags (prec can be
6947 BF_PREC_INF).
6948 */
6949 int bfdec_divrem(bfdec_t *q, bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
6950 limb_t prec, bf_flags_t flags, int rnd_mode)
6951 {
6952 bf_context_t *s = q->ctx;
6953 bfdec_t a1_s, *a1 = &a1_s;
6954 bfdec_t b1_s, *b1 = &b1_s;
6955 bfdec_t r1_s, *r1 = &r1_s;
6956 int q_sign, res;
6957 BOOL is_ceil, is_rndn;
6958
6959 assert(q != a && q != b);
6960 assert(r != a && r != b);
6961 assert(q != r);
6962
6963 if (a->len == 0 || b->len == 0) {
6964 bfdec_set_zero(q, 0);
6965 if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
6966 bfdec_set_nan(r);
6967 return 0;
6968 } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
6969 bfdec_set_nan(r);
6970 return BF_ST_INVALID_OP;
6971 } else {
6972 bfdec_set(r, a);
6973 return bfdec_round(r, prec, flags);
6974 }
6975 }
6976
6977 q_sign = a->sign ^ b->sign;
6978 is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA);
6979 switch(rnd_mode) {
6980 default:
6981 case BF_RNDZ:
6982 case BF_RNDN:
6983 case BF_RNDNA:
6984 is_ceil = FALSE;
6985 break;
6986 case BF_RNDD:
6987 is_ceil = q_sign;
6988 break;
6989 case BF_RNDU:
6990 is_ceil = q_sign ^ 1;
6991 break;
6992 case BF_RNDA:
6993 is_ceil = TRUE;
6994 break;
6995 case BF_DIVREM_EUCLIDIAN:
6996 is_ceil = a->sign;
6997 break;
6998 }
6999
7000 a1->expn = a->expn;
7001 a1->tab = a->tab;
7002 a1->len = a->len;
7003 a1->sign = 0;
7004
7005 b1->expn = b->expn;
7006 b1->tab = b->tab;
7007 b1->len = b->len;
7008 b1->sign = 0;
7009
7010 // bfdec_print_str("a1", a1);
7011 // bfdec_print_str("b1", b1);
7012 /* XXX: could improve to avoid having a large 'q' */
7013 bfdec_tdivremu(s, q, r, a1, b1);
7014 if (bfdec_is_nan(q) || bfdec_is_nan(r))
7015 goto fail;
7016 // bfdec_print_str("q", q);
7017 // bfdec_print_str("r", r);
7018
7019 if (r->len != 0) {
7020 if (is_rndn) {
7021 bfdec_init(s, r1);
7022 if (bfdec_set(r1, r))
7023 goto fail;
7024 if (bfdec_mul_si(r1, r1, 2, BF_PREC_INF, BF_RNDZ)) {
7025 bfdec_delete(r1);
7026 goto fail;
7027 }
7028 res = bfdec_cmpu(r1, b);
7029 bfdec_delete(r1);
7030 if (res > 0 ||
7031 (res == 0 &&
7032 (rnd_mode == BF_RNDNA ||
7033 (get_digit(q->tab, q->len, q->len * LIMB_DIGITS - q->expn) & 1) != 0))) {
7034 goto do_sub_r;
7035 }
7036 } else if (is_ceil) {
7037 do_sub_r:
7038 res = bfdec_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
7039 res |= bfdec_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
7040 if (res & BF_ST_MEM_ERROR)
7041 goto fail;
7042 }
7043 }
7044
7045 r->sign ^= a->sign;
7046 q->sign = q_sign;
7047 return bfdec_round(r, prec, flags);
7048 fail:
7049 bfdec_set_nan(q);
7050 bfdec_set_nan(r);
7051 return BF_ST_MEM_ERROR;
7052 }
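
/* A minimal usage sketch (assuming an initialized bf_context_t 'ctx';
   error checking omitted): euclidean division of -7 by 3 yields
   q = -3 and r = 2, so that a = q * b + r with 0 <= r < |b|. */
#if 0
static void bfdec_divrem_example(bf_context_t *ctx)
{
    bfdec_t q, r, a, b;
    bfdec_init(ctx, &q);
    bfdec_init(ctx, &r);
    bfdec_init(ctx, &a);
    bfdec_init(ctx, &b);
    bfdec_set_si(&a, -7);
    bfdec_set_si(&b, 3);
    bfdec_divrem(&q, &r, &a, &b, BF_PREC_INF, BF_RNDZ,
                 BF_DIVREM_EUCLIDIAN);
    bfdec_delete(&q);
    bfdec_delete(&r);
    bfdec_delete(&a);
    bfdec_delete(&b);
}
#endif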
7053
7054 int bfdec_rem(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
7055 bf_flags_t flags, int rnd_mode)
7056 {
7057 bfdec_t q_s, *q = &q_s;
7058 int ret;
7059
7060 bfdec_init(r->ctx, q);
7061 ret = bfdec_divrem(q, r, a, b, prec, flags, rnd_mode);
7062 bfdec_delete(q);
7063 return ret;
7064 }
7065
7066 /* convert to integer (infinite precision) */
7067 int bfdec_rint(bfdec_t *r, int rnd_mode)
7068 {
7069 return bfdec_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
7070 }
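
/* This works because with BF_FLAG_RADPNT_PREC the precision counts
   digits after the radix point, so prec = 0 keeps none; e.g. under
   BF_RNDN, 2.5 becomes 2 and 3.5 becomes 4 (ties to even). */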
7071
7072 int bfdec_sqrt(bfdec_t *r, const bfdec_t *a, limb_t prec, bf_flags_t flags)
7073 {
7074 bf_context_t *s = a->ctx;
7075 int ret, k;
7076 limb_t *a1, v;
7077 slimb_t n, n1, prec1;
7078 limb_t res;
7079
7080 assert(r != a);
7081
7082 if (a->len == 0) {
7083 if (a->expn == BF_EXP_NAN) {
7084 bfdec_set_nan(r);
7085 } else if (a->expn == BF_EXP_INF && a->sign) {
7086 goto invalid_op;
7087 } else {
7088 bfdec_set(r, a);
7089 }
7090 ret = 0;
7091 } else if (a->sign || prec == BF_PREC_INF) {
7092 invalid_op:
7093 bfdec_set_nan(r);
7094 ret = BF_ST_INVALID_OP;
7095 } else {
7096 if (flags & BF_FLAG_RADPNT_PREC) {
7097 prec1 = bf_max(floor_div(a->expn + 1, 2) + prec, 1);
7098 } else {
7099 prec1 = prec;
7100 }
7101 /* convert the mantissa to an integer with at least 2 *
7102 prec + 4 digits */
7103 n = (2 * (prec1 + 2) + 2 * LIMB_DIGITS - 1) / (2 * LIMB_DIGITS);
7104 if (bfdec_resize(r, n))
7105 goto fail;
7106 a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
7107 if (!a1)
7108 goto fail;
7109 n1 = bf_min(2 * n, a->len);
7110 memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
7111 memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
7112 if (a->expn & 1) {
7113 res = mp_shr_dec(a1, a1, 2 * n, 1, 0);
7114 } else {
7115 res = 0;
7116 }
7117         /* normalize so that a1 >= B^(2*n)/4. Not needed for n = 1
7118            because mp_sqrtrem2_dec already does it */
7119 k = 0;
7120 if (n > 1) {
7121 v = a1[2 * n - 1];
7122 while (v < BF_DEC_BASE / 4) {
7123 k++;
7124 v *= 4;
7125 }
7126 if (k != 0)
7127 mp_mul1_dec(a1, a1, 2 * n, 1 << (2 * k), 0);
7128 }
7129 if (mp_sqrtrem_dec(s, r->tab, a1, n)) {
7130 bf_free(s, a1);
7131 goto fail;
7132 }
7133 if (k != 0)
7134 mp_div1_dec(r->tab, r->tab, n, 1 << k, 0);
7135 if (!res) {
7136 res = mp_scan_nz(a1, n + 1);
7137 }
7138 bf_free(s, a1);
7139 if (!res) {
7140 res = mp_scan_nz(a->tab, a->len - n1);
7141 }
7142 if (res != 0)
7143 r->tab[0] |= 1;
7144 r->sign = 0;
7145 r->expn = (a->expn + 1) >> 1;
7146 ret = bfdec_round(r, prec, flags);
7147 }
7148 return ret;
7149 fail:
7150 bfdec_set_nan(r);
7151 return BF_ST_MEM_ERROR;
7152 }
7153
7154 /* The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
7155 is an overflow and 0 otherwise. No memory error is possible. */
7156 int bfdec_get_int32(int *pres, const bfdec_t *a)
7157 {
7158 uint32_t v;
7159 int ret;
7160 if (a->expn >= BF_EXP_INF) {
7161 ret = 0;
7162 if (a->expn == BF_EXP_INF) {
7163 v = (uint32_t)INT32_MAX + a->sign;
7164 /* XXX: return overflow ? */
7165 } else {
7166 v = INT32_MAX;
7167 }
7168 } else if (a->expn <= 0) {
7169 v = 0;
7170 ret = 0;
7171 } else if (a->expn <= 9) {
7172 v = fast_shr_dec(a->tab[a->len - 1], LIMB_DIGITS - a->expn);
7173 if (a->sign)
7174 v = -v;
7175 ret = 0;
7176 } else if (a->expn == 10) {
7177 uint64_t v1;
7178 uint32_t v_max;
7179 #if LIMB_BITS == 64
7180 v1 = fast_shr_dec(a->tab[a->len - 1], LIMB_DIGITS - a->expn);
7181 #else
7182 v1 = (uint64_t)a->tab[a->len - 1] * 10 +
7183 get_digit(a->tab, a->len, (a->len - 1) * LIMB_DIGITS - 1);
7184 #endif
7185 v_max = (uint32_t)INT32_MAX + a->sign;
7186 if (v1 > v_max) {
7187 v = v_max;
7188 ret = BF_ST_OVERFLOW;
7189 } else {
7190 v = v1;
7191 if (a->sign)
7192 v = -v;
7193 ret = 0;
7194 }
7195 } else {
7196 v = (uint32_t)INT32_MAX + a->sign;
7197 ret = BF_ST_OVERFLOW;
7198 }
7199 *pres = v;
7200 return ret;
7201 }
7202
7203 /* power to an integer with infinite precision */
7204 int bfdec_pow_ui(bfdec_t *r, const bfdec_t *a, limb_t b)
7205 {
7206 int ret, n_bits, i;
7207
7208 assert(r != a);
7209 if (b == 0)
7210 return bfdec_set_ui(r, 1);
7211 ret = bfdec_set(r, a);
7212 n_bits = LIMB_BITS - clz(b);
7213 for(i = n_bits - 2; i >= 0; i--) {
7214 ret |= bfdec_mul(r, r, r, BF_PREC_INF, BF_RNDZ);
7215 if ((b >> i) & 1)
7216 ret |= bfdec_mul(r, r, a, BF_PREC_INF, BF_RNDZ);
7217 }
7218 return ret;
7219 }
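
/* The loop above is standard left-to-right binary exponentiation:
   e.g. for b = 13 (binary 1101), starting from r = a it computes
   r = a^2 * a = a^3, then r = a^6, then r = a^12 * a = a^13, i.e. one
   squaring per remaining bit plus one multiply per set bit. */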
7220
7221 char *bfdec_ftoa(size_t *plen, const bfdec_t *a, limb_t prec, bf_flags_t flags)
7222 {
7223 return bf_ftoa_internal(plen, (const bf_t *)a, 10, prec, flags, TRUE);
7224 }
7225
7226 int bfdec_atof(bfdec_t *r, const char *str, const char **pnext,
7227 limb_t prec, bf_flags_t flags)
7228 {
7229 slimb_t dummy_exp;
7230 return bf_atof_internal((bf_t *)r, &dummy_exp, str, pnext, 10, prec,
7231 flags, TRUE);
7232 }
7233
7234 #endif /* USE_BF_DEC */
7235
7236 #ifdef USE_FFT_MUL
7237 /***************************************************************/
7238 /* Integer multiplication with FFT */
7239
7240 /* bitwise or of 'val' (LIMB_BITS wide) at bit position 'pos' in tab */
7241 static inline void put_bits(limb_t *tab, limb_t len, slimb_t pos, limb_t val)
7242 {
7243 limb_t i;
7244 int p;
7245
7246 i = pos >> LIMB_LOG2_BITS;
7247 p = pos & (LIMB_BITS - 1);
7248 if (i < len)
7249 tab[i] |= val << p;
7250 if (p != 0) {
7251 i++;
7252 if (i < len) {
7253 tab[i] |= val >> (LIMB_BITS - p);
7254 }
7255 }
7256 }
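
/* Worked example (LIMB_BITS == 64): put_bits(tab, len, 100, val) ors
   the low 28 bits of 'val' into the top of tab[1] (bit offset
   100 & 63 == 36) and the remaining 36 bits into the bottom of
   tab[2]. */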
7257
7258 #if defined(__AVX2__)
7259
7260 typedef double NTTLimb;
7261
7262 /* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
7263 #define NTT_MOD_LOG2_MIN 50
7264 #define NTT_MOD_LOG2_MAX 51
7265 #define NB_MODS 5
7266 #define NTT_PROOT_2EXP 39
7267 static const int ntt_int_bits[NB_MODS] = { 254, 203, 152, 101, 50, };
7268
7269 static const limb_t ntt_mods[NB_MODS] = { 0x00073a8000000001, 0x0007858000000001, 0x0007a38000000001, 0x0007a68000000001, 0x0007fd8000000001,
7270 };
7271
7272 static const limb_t ntt_proot[2][NB_MODS] = {
7273 { 0x00056198d44332c8, 0x0002eb5d640aad39, 0x00047e31eaa35fd0, 0x0005271ac118a150, 0x00075e0ce8442bd5, },
7274 { 0x000461169761bcc5, 0x0002dac3cb2da688, 0x0004abc97751e3bf, 0x000656778fc8c485, 0x0000dc6469c269fa, },
7275 };
7276
7277 static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
7278 0x00020e4da740da8e, 0x0004c3dc09c09c1d, 0x000063bd097b4271, 0x000799d8f18f18fd,
7279 0x0005384222222264, 0x000572b07c1f07fe, 0x00035cd08888889a,
7280 0x00066015555557e3, 0x000725960b60b623,
7281 0x0002fc1fa1d6ce12,
7282 };
7283
7284 #else
7285
7286 typedef limb_t NTTLimb;
7287
7288 #if LIMB_BITS == 64
7289
7290 #define NTT_MOD_LOG2_MIN 61
7291 #define NTT_MOD_LOG2_MAX 62
7292 #define NB_MODS 5
7293 #define NTT_PROOT_2EXP 51
7294 static const int ntt_int_bits[NB_MODS] = { 307, 246, 185, 123, 61, };
7295
7296 static const limb_t ntt_mods[NB_MODS] = { 0x28d8000000000001, 0x2a88000000000001, 0x2ed8000000000001, 0x3508000000000001, 0x3aa8000000000001,
7297 };
7298
7299 static const limb_t ntt_proot[2][NB_MODS] = {
7300 { 0x1b8ea61034a2bea7, 0x21a9762de58206fb, 0x02ca782f0756a8ea, 0x278384537a3e50a1, 0x106e13fee74ce0ab, },
7301 { 0x233513af133e13b8, 0x1d13140d1c6f75f1, 0x12cde57f97e3eeda, 0x0d6149e23cbe654f, 0x36cd204f522a1379, },
7302 };
7303
7304 static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
7305 0x08a9ed097b425eea, 0x18a44aaaaaaaaab3, 0x2493f57f57f57f5d, 0x126b8d0649a7f8d4,
7306 0x09d80ed7303b5ccc, 0x25b8bcf3cf3cf3d5, 0x2ce6ce63398ce638,
7307 0x0e31fad40a57eb59, 0x02a3529fd4a7f52f,
7308 0x3a5493e93e93e94a,
7309 };
7310
7311 #elif LIMB_BITS == 32
7312
7313 /* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
7314 #define NTT_MOD_LOG2_MIN 29
7315 #define NTT_MOD_LOG2_MAX 30
7316 #define NB_MODS 5
7317 #define NTT_PROOT_2EXP 20
7318 static const int ntt_int_bits[NB_MODS] = { 148, 119, 89, 59, 29, };
7319
7320 static const limb_t ntt_mods[NB_MODS] = { 0x0000000032b00001, 0x0000000033700001, 0x0000000036d00001, 0x0000000037300001, 0x000000003e500001,
7321 };
7322
7323 static const limb_t ntt_proot[2][NB_MODS] = {
7324 { 0x0000000032525f31, 0x0000000005eb3b37, 0x00000000246eda9f, 0x0000000035f25901, 0x00000000022f5768, },
7325 { 0x00000000051eba1a, 0x00000000107be10e, 0x000000001cd574e0, 0x00000000053806e6, 0x000000002cd6bf98, },
7326 };
7327
7328 static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
7329 0x000000000449559a, 0x000000001eba6ca9, 0x000000002ec18e46, 0x000000000860160b,
7330 0x000000000d321307, 0x000000000bf51120, 0x000000000f662938,
7331 0x000000000932ab3e, 0x000000002f40eef8,
7332 0x000000002e760905,
7333 };
7334
7335 #endif /* LIMB_BITS */
7336
7337 #endif /* !AVX2 */
7338
7339 #if defined(__AVX2__)
7340 #define NTT_TRIG_K_MAX 18
7341 #else
7342 #define NTT_TRIG_K_MAX 19
7343 #endif
7344
7345 typedef struct BFNTTState {
7346 bf_context_t *ctx;
7347
7348 /* used for mul_mod_fast() */
7349 limb_t ntt_mods_div[NB_MODS];
7350
7351 limb_t ntt_proot_pow[NB_MODS][2][NTT_PROOT_2EXP + 1];
7352 limb_t ntt_proot_pow_inv[NB_MODS][2][NTT_PROOT_2EXP + 1];
7353 NTTLimb *ntt_trig[NB_MODS][2][NTT_TRIG_K_MAX + 1];
7354 /* 1/2^n mod m */
7355 limb_t ntt_len_inv[NB_MODS][NTT_PROOT_2EXP + 1][2];
7356 #if defined(__AVX2__)
7357 __m256d ntt_mods_cr_vec[NB_MODS * (NB_MODS - 1) / 2];
7358 __m256d ntt_mods_vec[NB_MODS];
7359 __m256d ntt_mods_inv_vec[NB_MODS];
7360 #else
7361 limb_t ntt_mods_cr_inv[NB_MODS * (NB_MODS - 1) / 2];
7362 #endif
7363 } BFNTTState;
7364
7365 static NTTLimb *get_trig(BFNTTState *s, int k, int inverse, int m_idx);
7366
7367 /* add modulo with up to (LIMB_BITS-1) bit modulo */
7368 static inline limb_t add_mod(limb_t a, limb_t b, limb_t m)
7369 {
7370 limb_t r;
7371 r = a + b;
7372 if (r >= m)
7373 r -= m;
7374 return r;
7375 }
7376
7377 /* sub modulo with up to LIMB_BITS bit modulo */
7378 static inline limb_t sub_mod(limb_t a, limb_t b, limb_t m)
7379 {
7380 limb_t r;
7381 r = a - b;
7382 if (r > a)
7383 r += m;
7384 return r;
7385 }
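
/* In sub_mod(), 'r > a' detects the unsigned wrap-around of a - b
   (i.e. a < b), in which case m is added back; e.g. sub_mod(2, 5, 7)
   computes 2 - 5 == 2^LIMB_BITS - 3, sees the wrap and returns 4. */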
7386
7387 /* return (r0+r1*B) mod m
7388 precondition: 0 <= r0+r1*B < 2^(64+NTT_MOD_LOG2_MIN)
7389 */
7390 static inline limb_t mod_fast(dlimb_t r,
7391 limb_t m, limb_t m_inv)
7392 {
7393 limb_t a1, q, t0, r1, r0;
7394
7395 a1 = r >> NTT_MOD_LOG2_MIN;
7396
7397 q = ((dlimb_t)a1 * m_inv) >> LIMB_BITS;
7398 r = r - (dlimb_t)q * m - m * 2;
7399 r1 = r >> LIMB_BITS;
7400 t0 = (slimb_t)r1 >> 1;
7401 r += m & t0;
7402 r0 = r;
7403 r1 = r >> LIMB_BITS;
7404 r0 += m & r1;
7405 return r0;
7406 }
7407
7408 /* faster version using precomputed modulo inverse.
7409 precondition: 0 <= a * b < 2^(64+NTT_MOD_LOG2_MIN) */
7410 static inline limb_t mul_mod_fast(limb_t a, limb_t b,
7411 limb_t m, limb_t m_inv)
7412 {
7413 dlimb_t r;
7414 r = (dlimb_t)a * (dlimb_t)b;
7415 return mod_fast(r, m, m_inv);
7416 }
7417
7418 static inline limb_t init_mul_mod_fast(limb_t m)
7419 {
7420 dlimb_t t;
7421 assert(m < (limb_t)1 << NTT_MOD_LOG2_MAX);
7422 assert(m >= (limb_t)1 << NTT_MOD_LOG2_MIN);
7423 t = (dlimb_t)1 << (LIMB_BITS + NTT_MOD_LOG2_MIN);
7424 return t / m;
7425 }
7426
7427 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7428 0 <= b < m. */
7429 static inline limb_t mul_mod_fast2(limb_t a, limb_t b,
7430 limb_t m, limb_t b_inv)
7431 {
7432 limb_t r, q;
7433
7434 q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7435 r = a * b - q * m;
7436 if (r >= m)
7437 r -= m;
7438 return r;
7439 }
7440
7441 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7442 0 <= b < m. Let r = a * b mod m. The return value is 'r' or 'r +
7443 m'. */
7444 static inline limb_t mul_mod_fast3(limb_t a, limb_t b,
7445 limb_t m, limb_t b_inv)
7446 {
7447 limb_t r, q;
7448
7449 q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7450 r = a * b - q * m;
7451 return r;
7452 }
7453
7454 static inline limb_t init_mul_mod_fast2(limb_t b, limb_t m)
7455 {
7456 return ((dlimb_t)b << LIMB_BITS) / m;
7457 }
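
/* mul_mod_fast2() is a Barrett-style reduction with the precomputed
   reciprocal b_inv = floor(b * 2^LIMB_BITS / m): q underestimates
   floor(a * b / m) by at most one, so a * b - q * m lies in [0, 2*m)
   and a single conditional subtraction suffices (mul_mod_fast3()
   skips it and may return r + m). A worked sketch with m = 10, b = 7:
   b_inv = floor(7 * 2^64 / 10), and a = 9 gives
   q = floor(9 * b_inv / 2^64) = 6, r = 63 - 60 = 3 = (9 * 7) % 10. */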
7458
7459 #ifdef __AVX2__
7460
7461 static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
7462 {
7463 slimb_t v;
7464 v = a;
7465 if (v < 0)
7466 v += m;
7467 if (v >= m)
7468 v -= m;
7469 return v;
7470 }
7471
7472 static inline NTTLimb int_to_ntt_limb(limb_t a, limb_t m)
7473 {
7474 return (slimb_t)a;
7475 }
7476
7477 static inline NTTLimb int_to_ntt_limb2(limb_t a, limb_t m)
7478 {
7479 if (a >= (m / 2))
7480 a -= m;
7481 return (slimb_t)a;
7482 }
7483
7484 /* return r + m if r < 0 otherwise r. */
7485 static inline __m256d ntt_mod1(__m256d r, __m256d m)
7486 {
7487 return _mm256_blendv_pd(r, r + m, r);
7488 }
7489
7490 /* input: abs(r) < 2 * m. Output: abs(r) < m */
7491 static inline __m256d ntt_mod(__m256d r, __m256d mf, __m256d m2f)
7492 {
7493 return _mm256_blendv_pd(r, r + m2f, r) - mf;
7494 }
7495
7496 /* input: abs(a*b) < 2 * m^2, output: abs(r) < m */
7497 static inline __m256d ntt_mul_mod(__m256d a, __m256d b, __m256d mf,
7498 __m256d m_inv)
7499 {
7500 __m256d r, q, ab1, ab0, qm0, qm1;
7501 ab1 = a * b;
7502 q = _mm256_round_pd(ab1 * m_inv, 0); /* round to nearest */
7503 qm1 = q * mf;
7504 qm0 = _mm256_fmsub_pd(q, mf, qm1); /* low part */
7505 ab0 = _mm256_fmsub_pd(a, b, ab1); /* low part */
7506 r = (ab1 - qm1) + (ab0 - qm0);
7507 return r;
7508 }
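
/* ntt_mul_mod() relies on the classical FMA "two product" trick:
   qm0 = fma(q, mf, -qm1) and ab0 = fma(a, b, -ab1) recover exactly
   the low-order halves of q * m and a * b, so the residue
   a * b - q * m is formed without losing the bits below double
   rounding precision. */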
7509
7510 static void *bf_aligned_malloc(bf_context_t *s, size_t size, size_t align)
7511 {
7512 void *ptr;
7513 void **ptr1;
7514 ptr = bf_malloc(s, size + sizeof(void *) + align - 1);
7515 if (!ptr)
7516 return NULL;
7517 ptr1 = (void **)(((uintptr_t)ptr + sizeof(void *) + align - 1) &
7518 ~(align - 1));
7519 ptr1[-1] = ptr;
7520 return ptr1;
7521 }
7522
7523 static void bf_aligned_free(bf_context_t *s, void *ptr)
7524 {
7525 if (!ptr)
7526 return;
7527 bf_free(s, ((void **)ptr)[-1]);
7528 }
7529
7530 static void *ntt_malloc(BFNTTState *s, size_t size)
7531 {
7532 return bf_aligned_malloc(s->ctx, size, 64);
7533 }
7534
7535 static void ntt_free(BFNTTState *s, void *ptr)
7536 {
7537 bf_aligned_free(s->ctx, ptr);
7538 }
7539
7540 static no_inline int ntt_fft(BFNTTState *s,
7541 NTTLimb *out_buf, NTTLimb *in_buf,
7542 NTTLimb *tmp_buf, int fft_len_log2,
7543 int inverse, int m_idx)
7544 {
7545 limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j;
7546 NTTLimb *tab_in, *tab_out, *tmp, *trig;
7547 __m256d m_inv, mf, m2f, c, a0, a1, b0, b1;
7548 limb_t m;
7549 int l;
7550
7551 m = ntt_mods[m_idx];
7552
7553 m_inv = _mm256_set1_pd(1.0 / (double)m);
7554 mf = _mm256_set1_pd(m);
7555 m2f = _mm256_set1_pd(m * 2);
7556
7557 n = (limb_t)1 << fft_len_log2;
7558 assert(n >= 8);
7559 stride_in = n / 2;
7560
7561 tab_in = in_buf;
7562 tab_out = tmp_buf;
7563 trig = get_trig(s, fft_len_log2, inverse, m_idx);
7564 if (!trig)
7565 return -1;
7566 p = 0;
7567 for(k = 0; k < stride_in; k += 4) {
7568 a0 = _mm256_load_pd(&tab_in[k]);
7569 a1 = _mm256_load_pd(&tab_in[k + stride_in]);
7570 c = _mm256_load_pd(trig);
7571 trig += 4;
7572 b0 = ntt_mod(a0 + a1, mf, m2f);
7573 b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
7574 a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
7575 a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
7576 a0 = _mm256_permute4x64_pd(a0, 0xd8);
7577 a1 = _mm256_permute4x64_pd(a1, 0xd8);
7578 _mm256_store_pd(&tab_out[p], a0);
7579 _mm256_store_pd(&tab_out[p + 4], a1);
7580 p += 2 * 4;
7581 }
7582 tmp = tab_in;
7583 tab_in = tab_out;
7584 tab_out = tmp;
7585
7586 trig = get_trig(s, fft_len_log2 - 1, inverse, m_idx);
7587 if (!trig)
7588 return -1;
7589 p = 0;
7590 for(k = 0; k < stride_in; k += 4) {
7591 a0 = _mm256_load_pd(&tab_in[k]);
7592 a1 = _mm256_load_pd(&tab_in[k + stride_in]);
7593 c = _mm256_setr_pd(trig[0], trig[0], trig[1], trig[1]);
7594 trig += 2;
7595 b0 = ntt_mod(a0 + a1, mf, m2f);
7596 b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
7597 a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
7598 a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
7599 _mm256_store_pd(&tab_out[p], a0);
7600 _mm256_store_pd(&tab_out[p + 4], a1);
7601 p += 2 * 4;
7602 }
7603 tmp = tab_in;
7604 tab_in = tab_out;
7605 tab_out = tmp;
7606
7607 nb_blocks = n / 4;
7608 fft_per_block = 4;
7609
7610 l = fft_len_log2 - 2;
7611 while (nb_blocks != 2) {
7612 nb_blocks >>= 1;
7613 p = 0;
7614 k = 0;
7615 trig = get_trig(s, l, inverse, m_idx);
7616 if (!trig)
7617 return -1;
7618 for(i = 0; i < nb_blocks; i++) {
7619 c = _mm256_set1_pd(trig[0]);
7620 trig++;
7621 for(j = 0; j < fft_per_block; j += 4) {
7622 a0 = _mm256_load_pd(&tab_in[k + j]);
7623 a1 = _mm256_load_pd(&tab_in[k + j + stride_in]);
7624 b0 = ntt_mod(a0 + a1, mf, m2f);
7625 b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
7626 _mm256_store_pd(&tab_out[p + j], b0);
7627 _mm256_store_pd(&tab_out[p + j + fft_per_block], b1);
7628 }
7629 k += fft_per_block;
7630 p += 2 * fft_per_block;
7631 }
7632 fft_per_block <<= 1;
7633 l--;
7634 tmp = tab_in;
7635 tab_in = tab_out;
7636 tab_out = tmp;
7637 }
7638
7639 tab_out = out_buf;
7640 for(k = 0; k < stride_in; k += 4) {
7641 a0 = _mm256_load_pd(&tab_in[k]);
7642 a1 = _mm256_load_pd(&tab_in[k + stride_in]);
7643 b0 = ntt_mod(a0 + a1, mf, m2f);
7644 b1 = ntt_mod(a0 - a1, mf, m2f);
7645 _mm256_store_pd(&tab_out[k], b0);
7646 _mm256_store_pd(&tab_out[k + stride_in], b1);
7647 }
7648 return 0;
7649 }
7650
7651 static void ntt_vec_mul(BFNTTState *s,
7652 NTTLimb *tab1, NTTLimb *tab2, limb_t fft_len_log2,
7653 int k_tot, int m_idx)
7654 {
7655 limb_t i, c_inv, n, m;
7656 __m256d m_inv, mf, a, b, c;
7657
7658 m = ntt_mods[m_idx];
7659 c_inv = s->ntt_len_inv[m_idx][k_tot][0];
7660 m_inv = _mm256_set1_pd(1.0 / (double)m);
7661 mf = _mm256_set1_pd(m);
7662 c = _mm256_set1_pd(int_to_ntt_limb(c_inv, m));
7663 n = (limb_t)1 << fft_len_log2;
7664 for(i = 0; i < n; i += 4) {
7665 a = _mm256_load_pd(&tab1[i]);
7666 b = _mm256_load_pd(&tab2[i]);
7667 a = ntt_mul_mod(a, b, mf, m_inv);
7668 a = ntt_mul_mod(a, c, mf, m_inv);
7669 _mm256_store_pd(&tab1[i], a);
7670 }
7671 }
7672
7673 static no_inline void mul_trig(NTTLimb *buf,
7674 limb_t n, limb_t c1, limb_t m, limb_t m_inv1)
7675 {
7676 limb_t i, c2, c3, c4;
7677 __m256d c, c_mul, a0, mf, m_inv;
7678 assert(n >= 2);
7679
7680 mf = _mm256_set1_pd(m);
7681 m_inv = _mm256_set1_pd(1.0 / (double)m);
7682
7683 c2 = mul_mod_fast(c1, c1, m, m_inv1);
7684 c3 = mul_mod_fast(c2, c1, m, m_inv1);
7685 c4 = mul_mod_fast(c2, c2, m, m_inv1);
7686 c = _mm256_setr_pd(1, int_to_ntt_limb(c1, m),
7687 int_to_ntt_limb(c2, m), int_to_ntt_limb(c3, m));
7688 c_mul = _mm256_set1_pd(int_to_ntt_limb(c4, m));
7689 for(i = 0; i < n; i += 4) {
7690 a0 = _mm256_load_pd(&buf[i]);
7691 a0 = ntt_mul_mod(a0, c, mf, m_inv);
7692 _mm256_store_pd(&buf[i], a0);
7693 c = ntt_mul_mod(c, c_mul, mf, m_inv);
7694 }
7695 }
7696
7697 #else
7698
7699 static void *ntt_malloc(BFNTTState *s, size_t size)
7700 {
7701 return bf_malloc(s->ctx, size);
7702 }
7703
7704 static void ntt_free(BFNTTState *s, void *ptr)
7705 {
7706 bf_free(s->ctx, ptr);
7707 }
7708
7709 static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
7710 {
7711 if (a >= m)
7712 a -= m;
7713 return a;
7714 }
7715
7716 static inline NTTLimb int_to_ntt_limb(slimb_t a, limb_t m)
7717 {
7718 return a;
7719 }
7720
7721 static no_inline int ntt_fft(BFNTTState *s, NTTLimb *out_buf, NTTLimb *in_buf,
7722 NTTLimb *tmp_buf, int fft_len_log2,
7723 int inverse, int m_idx)
7724 {
7725 limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j, m, m2;
7726 NTTLimb *tab_in, *tab_out, *tmp, a0, a1, b0, b1, c, *trig, c_inv;
7727 int l;
7728
7729 m = ntt_mods[m_idx];
7730 m2 = 2 * m;
7731 n = (limb_t)1 << fft_len_log2;
7732 nb_blocks = n;
7733 fft_per_block = 1;
7734 stride_in = n / 2;
7735 tab_in = in_buf;
7736 tab_out = tmp_buf;
7737 l = fft_len_log2;
7738 while (nb_blocks != 2) {
7739 nb_blocks >>= 1;
7740 p = 0;
7741 k = 0;
7742 trig = get_trig(s, l, inverse, m_idx);
7743 if (!trig)
7744 return -1;
7745 for(i = 0; i < nb_blocks; i++) {
7746 c = trig[0];
7747 c_inv = trig[1];
7748 trig += 2;
7749 for(j = 0; j < fft_per_block; j++) {
7750 a0 = tab_in[k + j];
7751 a1 = tab_in[k + j + stride_in];
7752 b0 = add_mod(a0, a1, m2);
7753 b1 = a0 - a1 + m2;
7754 b1 = mul_mod_fast3(b1, c, m, c_inv);
7755 tab_out[p + j] = b0;
7756 tab_out[p + j + fft_per_block] = b1;
7757 }
7758 k += fft_per_block;
7759 p += 2 * fft_per_block;
7760 }
7761 fft_per_block <<= 1;
7762 l--;
7763 tmp = tab_in;
7764 tab_in = tab_out;
7765 tab_out = tmp;
7766 }
7767 /* no twiddle in last step */
7768 tab_out = out_buf;
7769 for(k = 0; k < stride_in; k++) {
7770 a0 = tab_in[k];
7771 a1 = tab_in[k + stride_in];
7772 b0 = add_mod(a0, a1, m2);
7773 b1 = sub_mod(a0, a1, m2);
7774 tab_out[k] = b0;
7775 tab_out[k + stride_in] = b1;
7776 }
7777 return 0;
7778 }
7779
7780 static void ntt_vec_mul(BFNTTState *s,
7781 NTTLimb *tab1, NTTLimb *tab2, int fft_len_log2,
7782 int k_tot, int m_idx)
7783 {
7784 limb_t i, norm, norm_inv, a, n, m, m_inv;
7785
7786 m = ntt_mods[m_idx];
7787 m_inv = s->ntt_mods_div[m_idx];
7788 norm = s->ntt_len_inv[m_idx][k_tot][0];
7789 norm_inv = s->ntt_len_inv[m_idx][k_tot][1];
7790 n = (limb_t)1 << fft_len_log2;
7791 for(i = 0; i < n; i++) {
7792 a = tab1[i];
7793 /* need to reduce the range so that the product is <
7794 2^(LIMB_BITS+NTT_MOD_LOG2_MIN) */
7795 if (a >= m)
7796 a -= m;
7797 a = mul_mod_fast(a, tab2[i], m, m_inv);
7798 a = mul_mod_fast3(a, norm, m, norm_inv);
7799 tab1[i] = a;
7800 }
7801 }
7802
7803 static no_inline void mul_trig(NTTLimb *buf,
7804 limb_t n, limb_t c_mul, limb_t m, limb_t m_inv)
7805 {
7806 limb_t i, c0, c_mul_inv;
7807
7808 c0 = 1;
7809 c_mul_inv = init_mul_mod_fast2(c_mul, m);
7810 for(i = 0; i < n; i++) {
7811 buf[i] = mul_mod_fast(buf[i], c0, m, m_inv);
7812 c0 = mul_mod_fast2(c0, c_mul, m, c_mul_inv);
7813 }
7814 }
7815
7816 #endif /* !AVX2 */
7817
7818 static no_inline NTTLimb *get_trig(BFNTTState *s,
7819 int k, int inverse, int m_idx)
7820 {
7821 NTTLimb *tab;
7822 limb_t i, n2, c, c_mul, m, c_mul_inv;
7823
7824 if (k > NTT_TRIG_K_MAX)
7825 return NULL;
7826
7827 tab = s->ntt_trig[m_idx][inverse][k];
7828 if (tab)
7829 return tab;
7830 n2 = (limb_t)1 << (k - 1);
7831 m = ntt_mods[m_idx];
7832 #ifdef __AVX2__
7833 tab = ntt_malloc(s, sizeof(NTTLimb) * n2);
7834 #else
7835 tab = ntt_malloc(s, sizeof(NTTLimb) * n2 * 2);
7836 #endif
7837 if (!tab)
7838 return NULL;
7839 c = 1;
7840 c_mul = s->ntt_proot_pow[m_idx][inverse][k];
7841 c_mul_inv = s->ntt_proot_pow_inv[m_idx][inverse][k];
7842 for(i = 0; i < n2; i++) {
7843 #ifdef __AVX2__
7844 tab[i] = int_to_ntt_limb2(c, m);
7845 #else
7846 tab[2 * i] = int_to_ntt_limb(c, m);
7847 tab[2 * i + 1] = init_mul_mod_fast2(c, m);
7848 #endif
7849 c = mul_mod_fast2(c, c_mul, m, c_mul_inv);
7850 }
7851 s->ntt_trig[m_idx][inverse][k] = tab;
7852 return tab;
7853 }
7854
7855 static void fft_clear_cache(bf_context_t *s1)
7856 {
7857 int m_idx, inverse, k;
7858 BFNTTState *s = s1->ntt_state;
7859 if (s) {
7860 for(m_idx = 0; m_idx < NB_MODS; m_idx++) {
7861 for(inverse = 0; inverse < 2; inverse++) {
7862 for(k = 0; k < NTT_TRIG_K_MAX + 1; k++) {
7863 if (s->ntt_trig[m_idx][inverse][k]) {
7864 ntt_free(s, s->ntt_trig[m_idx][inverse][k]);
7865 s->ntt_trig[m_idx][inverse][k] = NULL;
7866 }
7867 }
7868 }
7869 }
7870 #if defined(__AVX2__)
7871 bf_aligned_free(s1, s);
7872 #else
7873 bf_free(s1, s);
7874 #endif
7875 s1->ntt_state = NULL;
7876 }
7877 }
7878
7879 #define STRIP_LEN 16
7880
7881 /* dst = buf1, src = buf2 */
7882 static int ntt_fft_partial(BFNTTState *s, NTTLimb *buf1,
7883 int k1, int k2, limb_t n1, limb_t n2, int inverse,
7884 limb_t m_idx)
7885 {
7886 limb_t i, j, c_mul, c0, m, m_inv, strip_len, l;
7887 NTTLimb *buf2, *buf3;
7888
7889 buf2 = NULL;
7890 buf3 = ntt_malloc(s, sizeof(NTTLimb) * n1);
7891 if (!buf3)
7892 goto fail;
7893 if (k2 == 0) {
7894 if (ntt_fft(s, buf1, buf1, buf3, k1, inverse, m_idx))
7895 goto fail;
7896 } else {
7897 strip_len = STRIP_LEN;
7898 buf2 = ntt_malloc(s, sizeof(NTTLimb) * n1 * strip_len);
7899 if (!buf2)
7900 goto fail;
7901 m = ntt_mods[m_idx];
7902 m_inv = s->ntt_mods_div[m_idx];
7903 c0 = s->ntt_proot_pow[m_idx][inverse][k1 + k2];
7904 c_mul = 1;
7905 assert((n2 % strip_len) == 0);
7906 for(j = 0; j < n2; j += strip_len) {
7907 for(i = 0; i < n1; i++) {
7908 for(l = 0; l < strip_len; l++) {
7909 buf2[i + l * n1] = buf1[i * n2 + (j + l)];
7910 }
7911 }
7912 for(l = 0; l < strip_len; l++) {
7913 if (inverse)
7914 mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
7915 if (ntt_fft(s, buf2 + l * n1, buf2 + l * n1, buf3, k1, inverse, m_idx))
7916 goto fail;
7917 if (!inverse)
7918 mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
7919 c_mul = mul_mod_fast(c_mul, c0, m, m_inv);
7920 }
7921
7922 for(i = 0; i < n1; i++) {
7923 for(l = 0; l < strip_len; l++) {
7924 buf1[i * n2 + (j + l)] = buf2[i + l *n1];
7925 }
7926 }
7927 }
7928 ntt_free(s, buf2);
7929 }
7930 ntt_free(s, buf3);
7931 return 0;
7932 fail:
7933 ntt_free(s, buf2);
7934 ntt_free(s, buf3);
7935 return -1;
7936 }
7937
7938
7939 /* dst = buf1, src = buf2, tmp = buf3 */
7940 static int ntt_conv(BFNTTState *s, NTTLimb *buf1, NTTLimb *buf2,
7941 int k, int k_tot, limb_t m_idx)
7942 {
7943 limb_t n1, n2, i;
7944 int k1, k2;
7945
7946 if (k <= NTT_TRIG_K_MAX) {
7947 k1 = k;
7948 } else {
7949 /* recursive split of the FFT */
7950 k1 = bf_min(k / 2, NTT_TRIG_K_MAX);
7951 }
7952 k2 = k - k1;
7953 n1 = (limb_t)1 << k1;
7954 n2 = (limb_t)1 << k2;
7955
7956 if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 0, m_idx))
7957 return -1;
7958 if (ntt_fft_partial(s, buf2, k1, k2, n1, n2, 0, m_idx))
7959 return -1;
7960 if (k2 == 0) {
7961 ntt_vec_mul(s, buf1, buf2, k, k_tot, m_idx);
7962 } else {
7963 for(i = 0; i < n1; i++) {
7964 ntt_conv(s, buf1 + i * n2, buf2 + i * n2, k2, k_tot, m_idx);
7965 }
7966 }
7967 if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 1, m_idx))
7968 return -1;
7969 return 0;
7970 }
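
/* A note on the structure (an interpretation, not from the original
   comments): this is the classical "four step" decomposition. The
   length 2^k transform is viewed as an n1 x n2 matrix;
   ntt_fft_partial() runs column FFTs of length n1 = 2^k1 with twiddle
   factor multiplications, the n1 rows of length n2 are convolved
   recursively, and a final partial inverse FFT recombines the
   columns, keeping every basic FFT within the precomputed trig tables
   (k1 <= NTT_TRIG_K_MAX). */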
7971
7972
7973 static no_inline void limb_to_ntt(BFNTTState *s,
7974 NTTLimb *tabr, limb_t fft_len,
7975 const limb_t *taba, limb_t a_len, int dpl,
7976 int first_m_idx, int nb_mods)
7977 {
7978 slimb_t i, n;
7979 dlimb_t a, b;
7980 int j, shift;
7981 limb_t base_mask1, a0, a1, a2, r, m, m_inv;
7982
7983 #if 0
7984 for(i = 0; i < a_len; i++) {
7985 printf("%" PRId64 ": " FMT_LIMB "\n",
7986 (int64_t)i, taba[i]);
7987 }
7988 #endif
7989 memset(tabr, 0, sizeof(NTTLimb) * fft_len * nb_mods);
7990 shift = dpl & (LIMB_BITS - 1);
7991 if (shift == 0)
7992 base_mask1 = -1;
7993 else
7994 base_mask1 = ((limb_t)1 << shift) - 1;
7995 n = bf_min(fft_len, (a_len * LIMB_BITS + dpl - 1) / dpl);
7996 for(i = 0; i < n; i++) {
7997 a0 = get_bits(taba, a_len, i * dpl);
7998 if (dpl <= LIMB_BITS) {
7999 a0 &= base_mask1;
8000 a = a0;
8001 } else {
8002 a1 = get_bits(taba, a_len, i * dpl + LIMB_BITS);
8003 if (dpl <= (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
8004 a = a0 | ((dlimb_t)(a1 & base_mask1) << LIMB_BITS);
8005 } else {
8006 if (dpl > 2 * LIMB_BITS) {
8007 a2 = get_bits(taba, a_len, i * dpl + LIMB_BITS * 2) &
8008 base_mask1;
8009 } else {
8010 a1 &= base_mask1;
8011 a2 = 0;
8012 }
8013 // printf("a=0x%016lx%016lx%016lx\n", a2, a1, a0);
8014 a = (a0 >> (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) |
8015 ((dlimb_t)a1 << (NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN)) |
8016 ((dlimb_t)a2 << (LIMB_BITS + NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN));
8017 a0 &= ((limb_t)1 << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) - 1;
8018 }
8019 }
8020 for(j = 0; j < nb_mods; j++) {
8021 m = ntt_mods[first_m_idx + j];
8022 m_inv = s->ntt_mods_div[first_m_idx + j];
8023 r = mod_fast(a, m, m_inv);
8024 if (dpl > (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
8025 b = ((dlimb_t)r << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) | a0;
8026 r = mod_fast(b, m, m_inv);
8027 }
8028 tabr[i + j * fft_len] = int_to_ntt_limb(r, m);
8029 }
8030 }
8031 }
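
/* limb_to_ntt() slices the input into fft_len chunks of dpl bits and
   reduces each chunk modulo the nb_mods NTT primes; a chunk wider
   than LIMB_BITS + NTT_MOD_LOG2_MIN is folded in two mod_fast() steps
   so that its precondition (argument < 2^(LIMB_BITS+NTT_MOD_LOG2_MIN))
   holds at each step. */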
8032
8033 #if defined(__AVX2__)
8034
8035 #define VEC_LEN 4
8036
8037 typedef union {
8038 __m256d v;
8039 double d[4];
8040 } VecUnion;
8041
8042 static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
8043 const NTTLimb *buf, int fft_len_log2, int dpl,
8044 int nb_mods)
8045 {
8046 const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
8047 const __m256d *mods_cr_vec, *mf, *m_inv;
8048 VecUnion y[NB_MODS];
8049 limb_t u[NB_MODS], carry[NB_MODS], fft_len, base_mask1, r;
8050 slimb_t i, len, pos;
8051 int j, k, l, shift, n_limb1, p;
8052 dlimb_t t;
8053
8054 j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
8055 mods_cr_vec = s->ntt_mods_cr_vec + j;
8056 mf = s->ntt_mods_vec + NB_MODS - nb_mods;
8057 m_inv = s->ntt_mods_inv_vec + NB_MODS - nb_mods;
8058
8059 shift = dpl & (LIMB_BITS - 1);
8060 if (shift == 0)
8061 base_mask1 = -1;
8062 else
8063 base_mask1 = ((limb_t)1 << shift) - 1;
8064 n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS;
8065 for(j = 0; j < NB_MODS; j++)
8066 carry[j] = 0;
8067 for(j = 0; j < NB_MODS; j++)
8068 u[j] = 0; /* avoid warnings */
8069 memset(tabr, 0, sizeof(limb_t) * r_len);
8070 fft_len = (limb_t)1 << fft_len_log2;
8071 len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
8072 len = (len + VEC_LEN - 1) & ~(VEC_LEN - 1);
8073 i = 0;
8074 while (i < len) {
8075 for(j = 0; j < nb_mods; j++)
8076 y[j].v = *(__m256d *)&buf[i + fft_len * j];
8077
8078 /* Chinese remainder to get mixed radix representation */
8079 l = 0;
8080 for(j = 0; j < nb_mods - 1; j++) {
8081 y[j].v = ntt_mod1(y[j].v, mf[j]);
8082 for(k = j + 1; k < nb_mods; k++) {
8083 y[k].v = ntt_mul_mod(y[k].v - y[j].v,
8084 mods_cr_vec[l], mf[k], m_inv[k]);
8085 l++;
8086 }
8087 }
8088 y[j].v = ntt_mod1(y[j].v, mf[j]);
8089
8090 for(p = 0; p < VEC_LEN; p++) {
8091 /* back to normal representation */
8092 u[0] = (int64_t)y[nb_mods - 1].d[p];
8093 l = 1;
8094 for(j = nb_mods - 2; j >= 1; j--) {
8095 r = (int64_t)y[j].d[p];
8096 for(k = 0; k < l; k++) {
8097 t = (dlimb_t)u[k] * mods[j] + r;
8098 r = t >> LIMB_BITS;
8099 u[k] = t;
8100 }
8101 u[l] = r;
8102 l++;
8103 }
8104 /* XXX: for nb_mods = 5, l should be 4 */
8105
8106 /* last step adds the carry */
8107 r = (int64_t)y[0].d[p];
8108 for(k = 0; k < l; k++) {
8109 t = (dlimb_t)u[k] * mods[j] + r + carry[k];
8110 r = t >> LIMB_BITS;
8111 u[k] = t;
8112 }
8113 u[l] = r + carry[l];
8114
8115 #if 0
8116 printf("%" PRId64 ": ", i);
8117 for(j = nb_mods - 1; j >= 0; j--) {
8118 printf(" %019" PRIu64, u[j]);
8119 }
8120 printf("\n");
8121 #endif
8122
8123 /* write the digits */
8124 pos = i * dpl;
8125 for(j = 0; j < n_limb1; j++) {
8126 put_bits(tabr, r_len, pos, u[j]);
8127 pos += LIMB_BITS;
8128 }
8129 put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
8130 /* shift by dpl digits and set the carry */
8131 if (shift == 0) {
8132 for(j = n_limb1 + 1; j < nb_mods; j++)
8133 carry[j - (n_limb1 + 1)] = u[j];
8134 } else {
8135 for(j = n_limb1; j < nb_mods - 1; j++) {
8136 carry[j - n_limb1] = (u[j] >> shift) |
8137 (u[j + 1] << (LIMB_BITS - shift));
8138 }
8139 carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
8140 }
8141 i++;
8142 }
8143 }
8144 }
#else
static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
                                  const NTTLimb *buf, int fft_len_log2, int dpl,
                                  int nb_mods)
{
    const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
    const limb_t *mods_cr, *mods_cr_inv;
    limb_t y[NB_MODS], u[NB_MODS], carry[NB_MODS], fft_len, base_mask1, r;
    slimb_t i, len, pos;
    int j, k, l, shift, n_limb1;
    dlimb_t t;

    j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
    mods_cr = ntt_mods_cr + j;
    mods_cr_inv = s->ntt_mods_cr_inv + j;

    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS;
    for(j = 0; j < NB_MODS; j++)
        carry[j] = 0;
    for(j = 0; j < NB_MODS; j++)
        u[j] = 0; /* avoid warnings */
    memset(tabr, 0, sizeof(limb_t) * r_len);
    fft_len = (limb_t)1 << fft_len_log2;
    len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
    for(i = 0; i < len; i++) {
        for(j = 0; j < nb_mods; j++) {
            y[j] = ntt_limb_to_int(buf[i + fft_len * j], mods[j]);
        }

        /* Chinese remainder to get mixed radix representation */
        l = 0;
        for(j = 0; j < nb_mods - 1; j++) {
            for(k = j + 1; k < nb_mods; k++) {
                limb_t m;
                m = mods[k];
                /* Note: there is no overflow in the sub_mod() because
                   the moduli are sorted in increasing order */
                y[k] = mul_mod_fast2(y[k] - y[j] + m,
                                     mods_cr[l], m, mods_cr_inv[l]);
                l++;
            }
        }

        /* back to normal representation */
        u[0] = y[nb_mods - 1];
        l = 1;
        for(j = nb_mods - 2; j >= 1; j--) {
            r = y[j];
            for(k = 0; k < l; k++) {
                t = (dlimb_t)u[k] * mods[j] + r;
                r = t >> LIMB_BITS;
                u[k] = t;
            }
            u[l] = r;
            l++;
        }

        /* last step adds the carry */
        r = y[0];
        for(k = 0; k < l; k++) {
            t = (dlimb_t)u[k] * mods[j] + r + carry[k];
            r = t >> LIMB_BITS;
            u[k] = t;
        }
        u[l] = r + carry[l];

#if 0
        printf("%" PRId64 ": ", (int64_t)i);
        for(j = nb_mods - 1; j >= 0; j--) {
            printf(" " FMT_LIMB, u[j]);
        }
        printf("\n");
#endif

        /* write the digits */
        pos = i * dpl;
        for(j = 0; j < n_limb1; j++) {
            put_bits(tabr, r_len, pos, u[j]);
            pos += LIMB_BITS;
        }
        put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
        /* shift by dpl digits and set the carry */
        if (shift == 0) {
            for(j = n_limb1 + 1; j < nb_mods; j++)
                carry[j - (n_limb1 + 1)] = u[j];
        } else {
            for(j = n_limb1; j < nb_mods - 1; j++) {
                carry[j - n_limb1] = (u[j] >> shift) |
                    (u[j + 1] << (LIMB_BITS - shift));
            }
            carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
        }
    }
}
#endif
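
/* Added note (not from the original source): both ntt_to_limb()
   variants convert each coefficient from its residues to a mixed radix
   representation y[] such that

       x = y[0] + y[1]*m[0] + y[2]*m[0]*m[1] + ...
         + y[n-1]*m[0]*...*m[n-2]      (m[j] = mods[j], n = nb_mods)

   and then evaluate this expression in base 2^LIMB_BITS by Horner
   multiply-and-add into the u[] array. A minimal sketch of that second
   step for n = 3, assuming y[] already holds the mixed radix digits
   (illustrative only, not compiled): */
#if 0
static void mixed_radix_to_limbs_example(limb_t u[3], const limb_t y[3],
                                         const limb_t mods[3])
{
    dlimb_t t;
    limb_t r;

    u[0] = y[2];                          /* u = y[2] */
    t = (dlimb_t)u[0] * mods[1] + y[1];   /* u = u * m[1] + y[1] */
    u[0] = t;
    u[1] = t >> LIMB_BITS;
    t = (dlimb_t)u[0] * mods[0] + y[0];   /* u = u * m[0] + y[0] */
    u[0] = t;
    r = t >> LIMB_BITS;
    t = (dlimb_t)u[1] * mods[0] + r;
    u[1] = t;
    u[2] = t >> LIMB_BITS;
}
#endif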

static int ntt_static_init(bf_context_t *s1)
{
    BFNTTState *s;
    int inverse, i, j, k, l;
    limb_t c, c_inv, c_inv2, m, m_inv;

    if (s1->ntt_state)
        return 0;
#if defined(__AVX2__)
    s = bf_aligned_malloc(s1, sizeof(*s), 64);
#else
    s = bf_malloc(s1, sizeof(*s));
#endif
    if (!s)
        return -1;
    memset(s, 0, sizeof(*s));
    s1->ntt_state = s;
    s->ctx = s1;

    for(j = 0; j < NB_MODS; j++) {
        m = ntt_mods[j];
        m_inv = init_mul_mod_fast(m);
        s->ntt_mods_div[j] = m_inv;
#if defined(__AVX2__)
        s->ntt_mods_vec[j] = _mm256_set1_pd(m);
        s->ntt_mods_inv_vec[j] = _mm256_set1_pd(1.0 / (double)m);
#endif
        c_inv2 = (m + 1) / 2; /* 1/2 */
        c_inv = 1;
        for(i = 0; i <= NTT_PROOT_2EXP; i++) {
            s->ntt_len_inv[j][i][0] = c_inv;
            s->ntt_len_inv[j][i][1] = init_mul_mod_fast2(c_inv, m);
            c_inv = mul_mod_fast(c_inv, c_inv2, m, m_inv);
        }

        for(inverse = 0; inverse < 2; inverse++) {
            c = ntt_proot[inverse][j];
            for(i = 0; i < NTT_PROOT_2EXP; i++) {
                s->ntt_proot_pow[j][inverse][NTT_PROOT_2EXP - i] = c;
                s->ntt_proot_pow_inv[j][inverse][NTT_PROOT_2EXP - i] =
                    init_mul_mod_fast2(c, m);
                c = mul_mod_fast(c, c, m, m_inv);
            }
        }
    }

    l = 0;
    for(j = 0; j < NB_MODS - 1; j++) {
        for(k = j + 1; k < NB_MODS; k++) {
#if defined(__AVX2__)
            s->ntt_mods_cr_vec[l] = _mm256_set1_pd(int_to_ntt_limb2(ntt_mods_cr[l],
                                                                    ntt_mods[k]));
#else
            s->ntt_mods_cr_inv[l] = init_mul_mod_fast2(ntt_mods_cr[l],
                                                       ntt_mods[k]);
#endif
            l++;
        }
    }
    return 0;
}
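
/* Added note (not from the original source): ntt_static_init()
   precomputes, for each modulus m = ntt_mods[j]:
   - ntt_mods_div[j]: the reciprocal used by mod_fast()/mul_mod_fast();
   - ntt_len_inv[j][i]: 2^-i mod m, the final 1/N scaling of an inverse
     transform of length N = 2^i;
   - ntt_proot_pow[j][inverse][e]: a primitive 2^e-th root of unity
     (or its inverse), obtained by repeatedly squaring the
     2^NTT_PROOT_2EXP-th root ntt_proot[inverse][j];
   - the triangular table of cross-modulus constants (ntt_mods_cr_vec
     or ntt_mods_cr_inv) consumed by the Chinese remainder step in
     ntt_to_limb(). */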

int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
{
    int dpl, fft_len_log2, n_bits, nb_mods, dpl_found, fft_len_log2_found;
    int int_bits, nb_mods_found;
    limb_t cost, min_cost;

    min_cost = -1; /* maximum limb_t value */
    dpl_found = 0;
    nb_mods_found = 4;
    fft_len_log2_found = 0;
    for(nb_mods = 3; nb_mods <= NB_MODS; nb_mods++) {
        int_bits = ntt_int_bits[NB_MODS - nb_mods];
        dpl = bf_min((int_bits - 4) / 2,
                     2 * LIMB_BITS + 2 * NTT_MOD_LOG2_MIN - NTT_MOD_LOG2_MAX);
        for(;;) {
            fft_len_log2 = ceil_log2((len * LIMB_BITS + dpl - 1) / dpl);
            if (fft_len_log2 > NTT_PROOT_2EXP)
                goto next;
            n_bits = fft_len_log2 + 2 * dpl;
            if (n_bits <= int_bits) {
                cost = ((limb_t)(fft_len_log2 + 1) << fft_len_log2) * nb_mods;
                //  printf("n=%d dpl=%d: cost=%" PRId64 "\n", nb_mods, dpl, (int64_t)cost);
                if (cost < min_cost) {
                    min_cost = cost;
                    dpl_found = dpl;
                    nb_mods_found = nb_mods;
                    fft_len_log2_found = fft_len_log2;
                }
                break;
            }
            dpl--;
            if (dpl == 0)
                break;
        }
    next: ;
    }
    if (!dpl_found)
        abort();
    /* limit dpl if possible to reduce the fixed cost of the limb/NTT
       conversion */
    if (dpl_found > (LIMB_BITS + NTT_MOD_LOG2_MIN) &&
        ((limb_t)(LIMB_BITS + NTT_MOD_LOG2_MIN) << fft_len_log2_found) >=
        len * LIMB_BITS) {
        dpl_found = LIMB_BITS + NTT_MOD_LOG2_MIN;
    }
    *pnb_mods = nb_mods_found;
    *pdpl = dpl_found;
    return fft_len_log2_found;
}
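
/* Added note (not from the original source): each coefficient of the
   linear convolution is a sum of at most 2^fft_len_log2 products of
   two dpl-bit digits, so it needs about fft_len_log2 + 2*dpl bits;
   the test 'n_bits <= int_bits' checks that this fits in the combined
   moduli (ntt_int_bits[] leaves a few bits of margin, hence the
   'int_bits - 4' above). The cost estimate
   (fft_len_log2 + 1) << fft_len_log2 is proportional to the number of
   butterfly operations of one transform, scaled by the number of
   moduli. */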

/* return 0 if OK, -1 if memory error */
static no_inline int fft_mul(bf_context_t *s1,
                             bf_t *res, limb_t *a_tab, limb_t a_len,
                             limb_t *b_tab, limb_t b_len, int mul_flags)
{
    BFNTTState *s;
    int dpl, fft_len_log2, j, nb_mods, reduced_mem;
    slimb_t len, fft_len;
    NTTLimb *buf1, *buf2, *ptr;
#if defined(USE_MUL_CHECK)
    limb_t ha, hb, hr, h_ref;
#endif

    if (ntt_static_init(s1))
        return -1;
    s = s1->ntt_state;

    /* find the optimal number of digits per limb (dpl) */
    len = a_len + b_len;
    fft_len_log2 = bf_get_fft_size(&dpl, &nb_mods, len);
    fft_len = (uint64_t)1 << fft_len_log2;
    //  printf("len=%" PRId64 " fft_len_log2=%d dpl=%d\n", len, fft_len_log2, dpl);
#if defined(USE_MUL_CHECK)
    ha = mp_mod1(a_tab, a_len, BF_CHKSUM_MOD, 0);
    hb = mp_mod1(b_tab, b_len, BF_CHKSUM_MOD, 0);
#endif
    if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) == 0) {
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0);
    } else if (mul_flags & FFT_MUL_R_OVERLAP_B) {
        limb_t *tmp_tab, tmp_len;
        /* it is better to free 'b' first */
        tmp_tab = a_tab;
        a_tab = b_tab;
        b_tab = tmp_tab;
        tmp_len = a_len;
        a_len = b_len;
        b_len = tmp_len;
    }
    buf1 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
    if (!buf1)
        return -1;
    limb_to_ntt(s, buf1, fft_len, a_tab, a_len, dpl,
                NB_MODS - nb_mods, nb_mods);
    if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) ==
        FFT_MUL_R_OVERLAP_A) {
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0);
    }
    /* for large transforms, convert 'b' one modulus at a time to
       reduce the peak memory usage */
    reduced_mem = (fft_len_log2 >= 14);
    if (!reduced_mem) {
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
        if (!buf2)
            goto fail;
        limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
                    NB_MODS - nb_mods, nb_mods);
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0); /* in case res == b */
    } else {
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len);
        if (!buf2)
            goto fail;
    }
    for(j = 0; j < nb_mods; j++) {
        if (reduced_mem) {
            limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
                        NB_MODS - nb_mods + j, 1);
            ptr = buf2;
        } else {
            ptr = buf2 + fft_len * j;
        }
        if (ntt_conv(s, buf1 + fft_len * j, ptr,
                     fft_len_log2, fft_len_log2, j + NB_MODS - nb_mods))
            goto fail;
    }
    if (!(mul_flags & FFT_MUL_R_NORESIZE))
        bf_resize(res, 0); /* in case res == b and reduced mem */
    ntt_free(s, buf2);
    buf2 = NULL;
    if (!(mul_flags & FFT_MUL_R_NORESIZE)) {
        if (bf_resize(res, len))
            goto fail;
    }
    ntt_to_limb(s, res->tab, len, buf1, fft_len_log2, dpl, nb_mods);
    ntt_free(s, buf1);
#if defined(USE_MUL_CHECK)
    hr = mp_mod1(res->tab, len, BF_CHKSUM_MOD, 0);
    h_ref = mul_mod(ha, hb, BF_CHKSUM_MOD);
    if (hr != h_ref) {
        printf("ntt_mul_error: len=%" PRId_LIMB " fft_len_log2=%d dpl=%d nb_mods=%d\n",
               len, fft_len_log2, dpl, nb_mods);
        //  printf("ha=0x" FMT_LIMB " hb=0x" FMT_LIMB " hr=0x" FMT_LIMB " expected=0x" FMT_LIMB "\n", ha, hb, hr, h_ref);
        exit(1);
    }
#endif
    return 0;
 fail:
    ntt_free(s, buf1);
    ntt_free(s, buf2);
    return -1;
}
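
/* Illustrative sketch (not part of the library): calling fft_mul()
   directly on two raw limb arrays. 'ctx' is assumed to be an
   initialized bf_context_t and 'r' an initialized bf_t; the result
   overlaps neither operand, so mul_flags is 0. On success r->tab
   holds the a_len + b_len limb product, least significant limb
   first. */
#if 0
static int fft_mul_example(bf_context_t *ctx, bf_t *r)
{
    limb_t a[2] = { 1, 2 };  /* a = 2*2^LIMB_BITS + 1 */
    limb_t b[2] = { 3, 4 };  /* b = 4*2^LIMB_BITS + 3 */
    return fft_mul(ctx, r, a, 2, b, 2, 0);
}
#endif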

#else /* USE_FFT_MUL */

int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
{
    return 0;
}

#endif /* !USE_FFT_MUL */
