• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Tiny arbitrary precision floating point library
3  *
4  * Copyright (c) 2017-2020 Fabrice Bellard
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to deal
8  * in the Software without restriction, including without limitation the rights
9  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10  * copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in
14  * all copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22  * THE SOFTWARE.
23  */
24 #include <stdlib.h>
25 #include <stdio.h>
26 #include <inttypes.h>
27 #include <math.h>
28 #include <string.h>
29 #include <assert.h>
30 
31 #ifdef __AVX2__
32 #include <immintrin.h>
33 #endif
34 
35 #include "cutils.h"
36 #include "libbf.h"
37 
38 /* enable it to check the multiplication result */
39 //#define USE_MUL_CHECK
40 /* enable it to use FFT/NTT multiplication */
41 #define USE_FFT_MUL
42 /* enable decimal floating point support */
43 #define USE_BF_DEC
44 
45 //#define inline __attribute__((always_inline))
46 
47 #ifdef __AVX2__
48 #define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
49 #else
50 #define FFT_MUL_THRESHOLD 100 /* in limbs of the smallest factor */
51 #endif
52 
53 /* XXX: adjust */
54 #define DIVNORM_LARGE_THRESHOLD 50
55 #define UDIV1NORM_THRESHOLD 3
56 
57 #if LIMB_BITS == 64
58 #define FMT_LIMB1 "%" PRIx64
59 #define FMT_LIMB "%016" PRIx64
60 #define PRId_LIMB PRId64
61 #define PRIu_LIMB PRIu64
62 
63 #else
64 
65 #define FMT_LIMB1 "%x"
66 #define FMT_LIMB "%08x"
67 #define PRId_LIMB "d"
68 #define PRIu_LIMB "u"
69 
70 #endif
71 
72 typedef intptr_t mp_size_t;
73 
74 typedef int bf_op2_func_t(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
75                           bf_flags_t flags);
76 
77 #ifdef USE_FFT_MUL
78 
79 #define FFT_MUL_R_OVERLAP_A (1 << 0)
80 #define FFT_MUL_R_OVERLAP_B (1 << 1)
81 #define FFT_MUL_R_NORESIZE  (1 << 2)
82 
83 static no_inline int fft_mul(bf_context_t *s,
84                              bf_t *res, limb_t *a_tab, limb_t a_len,
85                              limb_t *b_tab, limb_t b_len, int mul_flags);
86 static void fft_clear_cache(bf_context_t *s);
87 #endif
88 #ifdef USE_BF_DEC
89 static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos);
90 #endif
91 
92 
93 /* could leading zeros */
clz(limb_t a)94 static inline int clz(limb_t a)
95 {
96     if (a == 0) {
97         return LIMB_BITS;
98     } else {
99 #if LIMB_BITS == 64
100         return clz64(a);
101 #else
102         return clz32(a);
103 #endif
104     }
105 }
106 
/* count trailing zeros: number of zero bits below the least significant
   set bit of 'a'. a == 0 is handled here and gives LIMB_BITS. */
static inline int ctz(limb_t a)
{
    if (a == 0)
        return LIMB_BITS;
#if LIMB_BITS == 64
    return ctz64(a);
#else
    return ctz32(a);
#endif
}
119 
/* smallest n such that 2^n >= a (0 when a <= 1) */
static inline int ceil_log2(limb_t a)
{
    return (a <= 1) ? 0 : LIMB_BITS - clz(a - 1);
}
127 
128 /* b must be >= 1 */
ceil_div(slimb_t a,slimb_t b)129 static inline slimb_t ceil_div(slimb_t a, slimb_t b)
130 {
131     if (a >= 0)
132         return (a + b - 1) / b;
133     else
134         return a / b;
135 }
136 
137 /* b must be >= 1 */
floor_div(slimb_t a,slimb_t b)138 static inline slimb_t floor_div(slimb_t a, slimb_t b)
139 {
140     if (a >= 0) {
141         return a / b;
142     } else {
143         return (a - b + 1) / b;
144     }
145 }
146 
147 /* return r = a modulo b (0 <= r <= b - 1. b must be >= 1 */
smod(slimb_t a,slimb_t b)148 static inline limb_t smod(slimb_t a, slimb_t b)
149 {
150     a = a % (slimb_t)b;
151     if (a < 0)
152         a += b;
153     return a;
154 }
155 
156 /* signed addition with saturation */
sat_add(slimb_t a,slimb_t b)157 static inline slimb_t sat_add(slimb_t a, slimb_t b)
158 {
159     slimb_t r;
160     r = a + b;
161     /* overflow ? */
162     if (((a ^ r) & (b ^ r)) < 0)
163         r = (a >> (LIMB_BITS - 1)) ^ (((limb_t)1 << (LIMB_BITS - 1)) - 1);
164     return r;
165 }
166 
/* From here on the C allocator must not be called directly: each call is
   renamed to an undeclared *_is_forbidden() function so that an
   accidental use fails to build. Memory goes through the context
   allocator (bf_realloc()/bf_free()). */
#define malloc(sz) malloc_is_forbidden(sz)
#define free(ptr) free_is_forbidden(ptr)
#define realloc(ptr, sz) realloc_is_forbidden(ptr, sz)
170 
/* Initialize context 's': every field is zeroed, then the user allocator
   'realloc_func' and its opaque argument are recorded. */
void bf_context_init(bf_context_t *s, bf_realloc_func_t *realloc_func,
                     void *realloc_opaque)
{
    memset(s, 0, sizeof *s);
    s->realloc_opaque = realloc_opaque;
    s->realloc_func = realloc_func;
}
178 
/* Tear down context 's' by releasing its caches via bf_clear_cache().
   The context structure itself is not freed here. */
void bf_context_end(bf_context_t *s)
{
    bf_clear_cache(s);
    return;
}
183 
/* Initialize 'r' as +0 bound to context 's', with no limbs allocated. */
void bf_init(bf_context_t *s, bf_t *r)
{
    r->tab = NULL;
    r->len = 0;
    r->expn = BF_EXP_ZERO;
    r->sign = 0;
    r->ctx = s;
}
192 
193 /* return 0 if OK, -1 if alloc error */
bf_resize(bf_t * r,limb_t len)194 int bf_resize(bf_t *r, limb_t len)
195 {
196     limb_t *tab;
197 
198     if (len != r->len) {
199         tab = bf_realloc(r->ctx, r->tab, len * sizeof(limb_t));
200         if (!tab && len != 0)
201             return -1;
202         r->tab = tab;
203         r->len = len;
204     }
205     return 0;
206 }
207 
208 /* return 0 or BF_ST_MEM_ERROR */
bf_set_ui(bf_t * r,uint64_t a)209 int bf_set_ui(bf_t *r, uint64_t a)
210 {
211     r->sign = 0;
212     if (a == 0) {
213         r->expn = BF_EXP_ZERO;
214         bf_resize(r, 0); /* cannot fail */
215     }
216 #if LIMB_BITS == 32
217     else if (a <= 0xffffffff)
218 #else
219     else
220 #endif
221     {
222         int shift;
223         if (bf_resize(r, 1))
224             goto fail;
225         shift = clz(a);
226         r->tab[0] = a << shift;
227         r->expn = LIMB_BITS - shift;
228     }
229 #if LIMB_BITS == 32
230     else {
231         uint32_t a1, a0;
232         int shift;
233         if (bf_resize(r, 2))
234             goto fail;
235         a0 = a;
236         a1 = a >> 32;
237         shift = clz(a1);
238         r->tab[0] = a0 << shift;
239         r->tab[1] = (a1 << shift) | (a0 >> (LIMB_BITS - shift));
240         r->expn = 2 * LIMB_BITS - shift;
241     }
242 #endif
243     return 0;
244  fail:
245     bf_set_nan(r);
246     return BF_ST_MEM_ERROR;
247 }
248 
249 /* return 0 or BF_ST_MEM_ERROR */
bf_set_si(bf_t * r,int64_t a)250 int bf_set_si(bf_t *r, int64_t a)
251 {
252     int ret;
253 
254     if (a < 0) {
255         ret = bf_set_ui(r, -a);
256         r->sign = 1;
257     } else {
258         ret = bf_set_ui(r, a);
259     }
260     return ret;
261 }
262 
/* Set 'r' to NaN (sign cleared) and shrink its limb array to 0 limbs. */
void bf_set_nan(bf_t *r)
{
    bf_resize(r, 0); /* shrinking to zero limbs cannot fail */
    r->sign = 0;
    r->expn = BF_EXP_NAN;
}
269 
/* Set 'r' to zero with sign 'is_neg' and shrink its limb array to 0 limbs. */
void bf_set_zero(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* shrinking to zero limbs cannot fail */
    r->sign = is_neg;
    r->expn = BF_EXP_ZERO;
}
276 
/* Set 'r' to infinity with sign 'is_neg' and shrink its limb array to 0
   limbs. */
void bf_set_inf(bf_t *r, int is_neg)
{
    bf_resize(r, 0); /* shrinking to zero limbs cannot fail */
    r->sign = is_neg;
    r->expn = BF_EXP_INF;
}
283 
284 /* return 0 or BF_ST_MEM_ERROR */
bf_set(bf_t * r,const bf_t * a)285 int bf_set(bf_t *r, const bf_t *a)
286 {
287     if (r == a)
288         return 0;
289     if (bf_resize(r, a->len)) {
290         bf_set_nan(r);
291         return BF_ST_MEM_ERROR;
292     }
293     r->sign = a->sign;
294     r->expn = a->expn;
295     memcpy(r->tab, a->tab, a->len * sizeof(limb_t));
296     return 0;
297 }
298 
299 /* equivalent to bf_set(r, a); bf_delete(a) */
bf_move(bf_t * r,bf_t * a)300 void bf_move(bf_t *r, bf_t *a)
301 {
302     bf_context_t *s = r->ctx;
303     if (r == a)
304         return;
305     bf_free(s, r->tab);
306     *r = *a;
307 }
308 
/* return limb 'idx' of 'a', or 0 when idx is past the end of the array */
static limb_t get_limbz(const bf_t *a, limb_t idx)
{
    return (idx < a->len) ? a->tab[idx] : 0;
}
316 
317 /* get LIMB_BITS at bit position 'pos' in tab */
get_bits(const limb_t * tab,limb_t len,slimb_t pos)318 static inline limb_t get_bits(const limb_t *tab, limb_t len, slimb_t pos)
319 {
320     limb_t i, a0, a1;
321     int p;
322 
323     i = pos >> LIMB_LOG2_BITS;
324     p = pos & (LIMB_BITS - 1);
325     if (i < len)
326         a0 = tab[i];
327     else
328         a0 = 0;
329     if (p == 0) {
330         return a0;
331     } else {
332         i++;
333         if (i < len)
334             a1 = tab[i];
335         else
336             a1 = 0;
337         return (a0 >> p) | (a1 << (LIMB_BITS - p));
338     }
339 }
340 
/* return bit 'pos' of 'tab' (bit 0 = LSB of tab[0]); positions outside
   the 'len' limbs read as 0 */
static inline limb_t get_bit(const limb_t *tab, limb_t len, slimb_t pos)
{
    slimb_t idx = pos >> LIMB_LOG2_BITS;

    if (idx >= 0 && idx < len)
        return (tab[idx] >> (pos & (LIMB_BITS - 1))) & 1;
    return 0;
}
349 
limb_mask(int start,int last)350 static inline limb_t limb_mask(int start, int last)
351 {
352     limb_t v;
353     int n;
354     n = last - start + 1;
355     if (n == LIMB_BITS)
356         v = -1;
357     else
358         v = (((limb_t)1 << n) - 1) << start;
359     return v;
360 }
361 
/* return 1 if at least one of the 'n' limbs of 'tab' is non zero, 0
   otherwise (also 0 when n <= 0) */
static limb_t mp_scan_nz(const limb_t *tab, mp_size_t n)
{
    mp_size_t k = 0;

    while (k < n) {
        if (tab[k++] != 0)
            return 1;
    }
    return 0;
}
371 
372 /* return != 0 if one bit between 0 and bit_pos inclusive is not zero. */
scan_bit_nz(const bf_t * r,slimb_t bit_pos)373 static inline limb_t scan_bit_nz(const bf_t *r, slimb_t bit_pos)
374 {
375     slimb_t pos;
376     limb_t v;
377 
378     pos = bit_pos >> LIMB_LOG2_BITS;
379     if (pos < 0)
380         return 0;
381     v = r->tab[pos] & limb_mask(0, bit_pos & (LIMB_BITS - 1));
382     if (v != 0)
383         return 1;
384     pos--;
385     while (pos >= 0) {
386         if (r->tab[pos] != 0)
387             return 1;
388         pos--;
389     }
390     return 0;
391 }
392 
393 /* return the addend for rounding. Note that prec can be <= 0 (for
394    BF_FLAG_RADPNT_PREC) */
bf_get_rnd_add(int * pret,const bf_t * r,limb_t l,slimb_t prec,int rnd_mode)395 static int bf_get_rnd_add(int *pret, const bf_t *r, limb_t l,
396                           slimb_t prec, int rnd_mode)
397 {
398     int add_one, inexact;
399     limb_t bit1, bit0;
400 
401     if (rnd_mode == BF_RNDF) {
402         bit0 = 1; /* faithful rounding does not honor the INEXACT flag */
403     } else {
404         /* starting limb for bit 'prec + 1' */
405         bit0 = scan_bit_nz(r, l * LIMB_BITS - 1 - bf_max(0, prec + 1));
406     }
407 
408     /* get the bit at 'prec' */
409     bit1 = get_bit(r->tab, l, l * LIMB_BITS - 1 - prec);
410     inexact = (bit1 | bit0) != 0;
411 
412     add_one = 0;
413     switch(rnd_mode) {
414     case BF_RNDZ:
415         break;
416     case BF_RNDN:
417         if (bit1) {
418             if (bit0) {
419                 add_one = 1;
420             } else {
421                 /* round to even */
422                 add_one =
423                     get_bit(r->tab, l, l * LIMB_BITS - 1 - (prec - 1));
424             }
425         }
426         break;
427     case BF_RNDD:
428     case BF_RNDU:
429         if (r->sign == (rnd_mode == BF_RNDD))
430             add_one = inexact;
431         break;
432     case BF_RNDA:
433         add_one = inexact;
434         break;
435     case BF_RNDNA:
436     case BF_RNDF:
437         add_one = bit1;
438         break;
439     default:
440         abort();
441     }
442 
443     if (inexact)
444         *pret |= BF_ST_INEXACT;
445     return add_one;
446 }
447 
/* Store in 'r' the result of an overflow with sign 'sign'.
   - Infinity, when the precision is infinite or the rounding mode rounds
     the value away from zero (RNDN, RNDNA, RNDA, or RNDD/RNDU toward the
     infinity of that sign).
   - Otherwise the largest finite number with 'prec' bits: all mantissa
     bits set, exponent 2^(exp_bits - 1).
   Return BF_ST_OVERFLOW | BF_ST_INEXACT, or BF_ST_MEM_ERROR ('r' set to
   NaN) if the limbs cannot be allocated. */
static int bf_set_overflow(bf_t *r, int sign, limb_t prec, bf_flags_t flags)
{
    int rnd_mode = flags & BF_RND_MASK;
    BOOL to_inf;
    slimb_t n_limbs, k;

    to_inf = prec == BF_PREC_INF ||
        rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA || rnd_mode == BF_RNDA ||
        (rnd_mode == BF_RNDD && sign == 1) ||
        (rnd_mode == BF_RNDU && sign == 0);
    if (to_inf) {
        bf_set_inf(r, sign);
        return BF_ST_OVERFLOW | BF_ST_INEXACT;
    }

    n_limbs = (prec + LIMB_BITS - 1) / LIMB_BITS;
    if (bf_resize(r, n_limbs)) {
        bf_set_nan(r);
        return BF_ST_MEM_ERROR;
    }
    /* tab[0] holds the low 'prec % LIMB_BITS' mantissa bits in its top
       positions (a full limb when prec is a multiple of LIMB_BITS) */
    r->tab[0] = limb_mask((-prec) & (LIMB_BITS - 1), LIMB_BITS - 1);
    for (k = 1; k < n_limbs; k++)
        r->tab[k] = (limb_t)-1;
    r->expn = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
    r->sign = sign;
    return BF_ST_OVERFLOW | BF_ST_INEXACT;
}
478 
479 /* Round 'r' to 'prec1' bits assuming 'r' is non zero and finite.
   The mantissa is taken from the 'l' limbs r->tab[0..l-1]
   (1 <= l <= r->len); the precision is counted from bit
   l * LIMB_BITS - 1, so presumably the caller has normalized it (MSB of
   r->tab[l - 1] set) - see bf_normalize_and_round().
   Note: 'prec1' can be infinite (BF_PREC_INF). With BF_FLAG_RADPNT_PREC,
   'prec1' is instead the number of bits after the radix point; with
   BF_FLAG_SUBNORMAL the precision is reduced for exponents below e_min.
   'ret' is 0 or BF_ST_INEXACT if the result is already known to be
   inexact; the returned status adds BF_ST_INEXACT, BF_ST_UNDERFLOW or
   BF_ST_OVERFLOW as needed. Can fail with BF_ST_MEM_ERROR in case of
   overflow not returning infinity (the largest finite number built by
   bf_set_overflow() may need more limbs). */
__bf_round(bf_t * r,limb_t prec1,bf_flags_t flags,limb_t l,int ret)484 static int __bf_round(bf_t *r, limb_t prec1, bf_flags_t flags, limb_t l,
485                       int ret)
486 {
487     limb_t v, a;
488     int shift, add_one, rnd_mode;
489     slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;
490
491     /* e_min and e_max are computed to match the IEEE 754 conventions */
    /* the mantissa lies in [1/2, 1) (value = 0.mantissa * 2^expn), so
       both bounds are the IEEE ones shifted by one for an exponent field
       of bf_get_exp_bits(flags) bits */
492     e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
493     e_min = -e_range + 3;
494     e_max = e_range;
495
496     if (flags & BF_FLAG_RADPNT_PREC) {
497         /* 'prec' is the precision after the radix point */
498         if (prec1 != BF_PREC_INF)
499             prec = r->expn + prec1;
500         else
501             prec = prec1;
502     } else if (unlikely(r->expn < e_min) && (flags & BF_FLAG_SUBNORMAL)) {
503         /* restrict the precision in case of potentially subnormal
504            result */
505         assert(prec1 != BF_PREC_INF);
506         prec = prec1 - (e_min - r->expn);
507     } else {
508         prec = prec1;
509     }
510
511     /* round to prec bits */
512     rnd_mode = flags & BF_RND_MASK;
513     add_one = bf_get_rnd_add(&ret, r, l, prec, rnd_mode);
514
    /* prec <= 0: no mantissa bit is kept. Either rounding yields a single
       1 bit at position 'prec - 1' (one unit of the last kept place,
       above the current MSB) or the result underflows to zero. */
515     if (prec <= 0) {
516         if (add_one) {
517             bf_resize(r, 1); /* cannot fail */
518             r->tab[0] = (limb_t)1 << (LIMB_BITS - 1);
519             r->expn += 1 - prec;
520             ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
521             return ret;
522         } else {
523             goto underflow;
524         }
525     } else if (add_one) {
526         limb_t carry;
527
528         /* add one starting at digit 'prec - 1' */
529         bit_pos = l * LIMB_BITS - 1 - (prec - 1);
530         pos = bit_pos >> LIMB_LOG2_BITS;
531         carry = (limb_t)1 << (bit_pos & (LIMB_BITS - 1));
532
533         for(i = pos; i < l; i++) {
534             v = r->tab[i] + carry;
535             carry = (v < carry);
536             r->tab[i] = v;
537             if (carry == 0)
538                 break;
539         }
        /* a carry out of the top limb means every bit from 'prec - 1'
           upwards is now zero: shift a 1 in at the top and increment the
           exponent */
540         if (carry) {
541             /* shift right by one digit */
542             v = 1;
543             for(i = l - 1; i >= pos; i--) {
544                 a = r->tab[i];
545                 r->tab[i] = (a >> 1) | (v << (LIMB_BITS - 1));
546                 v = a;
547             }
548             r->expn++;
549         }
550     }
551
552     /* check underflow */
553     if (unlikely(r->expn < e_min)) {
554         if (flags & BF_FLAG_SUBNORMAL) {
555             /* if inexact, also set the underflow flag */
556             if (ret & BF_ST_INEXACT)
557                 ret |= BF_ST_UNDERFLOW;
558         } else {
            /* also the target of the 'prec <= 0' case above */
559         underflow:
560             ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
561             bf_set_zero(r, r->sign);
562             return ret;
563         }
564     }
565
566     /* check overflow */
567     if (unlikely(r->expn > e_max))
568         return bf_set_overflow(r, r->sign, prec1, flags);
569
570     /* keep the bits starting at 'prec - 1' */
    /* i.e. clear the discarded bits of the limb holding bit 'prec - 1';
       the limbs below it are dropped as zero limbs just after */
571     bit_pos = l * LIMB_BITS - 1 - (prec - 1);
572     i = bit_pos >> LIMB_LOG2_BITS;
573     if (i >= 0) {
574         shift = bit_pos & (LIMB_BITS - 1);
575         if (shift != 0)
576             r->tab[i] &= limb_mask(shift, LIMB_BITS - 1);
577     } else {
578         i = 0;
579     }
580     /* remove trailing zeros */
    /* the top limb keeps its MSB (assuming normalized input), so this
       scan stops inside the array */
581     while (r->tab[i] == 0)
582         i++;
583     if (i > 0) {
584         l -= i;
585         memmove(r->tab, r->tab + i, l * sizeof(limb_t));
586     }
587     bf_resize(r, l); /* cannot fail */
588     return ret;
589 }
590 
591 /* 'r' must be a finite number. */
bf_normalize_and_round(bf_t * r,limb_t prec1,bf_flags_t flags)592 int bf_normalize_and_round(bf_t *r, limb_t prec1, bf_flags_t flags)
593 {
594     limb_t l, v, a;
595     int shift, ret;
596     slimb_t i;
597 
598     //    bf_print_str("bf_renorm", r);
599     l = r->len;
600     while (l > 0 && r->tab[l - 1] == 0)
601         l--;
602     if (l == 0) {
603         /* zero */
604         r->expn = BF_EXP_ZERO;
605         bf_resize(r, 0); /* cannot fail */
606         ret = 0;
607     } else {
608         r->expn -= (r->len - l) * LIMB_BITS;
609         /* shift to have the MSB set to '1' */
610         v = r->tab[l - 1];
611         shift = clz(v);
612         if (shift != 0) {
613             v = 0;
614             for(i = 0; i < l; i++) {
615                 a = r->tab[i];
616                 r->tab[i] = (a << shift) | (v >> (LIMB_BITS - shift));
617                 v = a;
618             }
619             r->expn -= shift;
620         }
621         ret = __bf_round(r, prec1, flags, l, 0);
622     }
623     //    bf_print_str("r_final", r);
624     return ret;
625 }
626 
627 /* return true if rounding can be done at precision 'prec' assuming
628    the exact result r is such that |r-a| <= 2^(EXP(a)-k). */
629 /* XXX: check the case where the exponent would be incremented by the
630    rounding */
bf_can_round(const bf_t * a,slimb_t prec,bf_rnd_t rnd_mode,slimb_t k)631 int bf_can_round(const bf_t *a, slimb_t prec, bf_rnd_t rnd_mode, slimb_t k)
632 {
633     BOOL is_rndn;
634     slimb_t bit_pos, n;
635     limb_t bit;
636 
637     if (a->expn == BF_EXP_INF || a->expn == BF_EXP_NAN)
638         return FALSE;
639     if (rnd_mode == BF_RNDF) {
640         return (k >= (prec + 1));
641     }
642     if (a->expn == BF_EXP_ZERO)
643         return FALSE;
644     is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA);
645     if (k < (prec + 2))
646         return FALSE;
647     bit_pos = a->len * LIMB_BITS - 1 - prec;
648     n = k - prec;
649     /* bit pattern for RNDN or RNDNA: 0111.. or 1000...
650        for other rounding modes: 000... or 111...
651     */
652     bit = get_bit(a->tab, a->len, bit_pos);
653     bit_pos--;
654     n--;
655     bit ^= is_rndn;
656     /* XXX: slow, but a few iterations on average */
657     while (n != 0) {
658         if (get_bit(a->tab, a->len, bit_pos) != bit)
659             return TRUE;
660         bit_pos--;
661         n--;
662     }
663     return FALSE;
664 }
665 
666 /* Cannot fail with BF_ST_MEM_ERROR. */
bf_round(bf_t * r,limb_t prec,bf_flags_t flags)667 int bf_round(bf_t *r, limb_t prec, bf_flags_t flags)
668 {
669     if (r->len == 0)
670         return 0;
671     return __bf_round(r, prec, flags, r->len, 0);
672 }
673 
674 /* for debugging */
dump_limbs(const char * str,const limb_t * tab,limb_t n)675 static __maybe_unused void dump_limbs(const char *str, const limb_t *tab, limb_t n)
676 {
677     limb_t i;
678     printf("%s: len=%" PRId_LIMB "\n", str, n);
679     for(i = 0; i < n; i++) {
680         printf("%" PRId_LIMB ": " FMT_LIMB "\n",
681                i, tab[i]);
682     }
683 }
684 
/* Print "<str>= 0x" followed by the 'n' limbs of 'tab', most significant
   first, separated by '_', then a newline. */
void mp_print_str(const char *str, const limb_t *tab, limb_t n)
{
    slimb_t k;

    printf("%s= 0x", str);
    for (k = n - 1; k >= 0; k--) {
        printf(FMT_LIMB, tab[k]);
        if (k != 0)
            printf("_");
    }
    printf("\n");
}
696 
mp_print_str_h(const char * str,const limb_t * tab,limb_t n,limb_t high)697 static __maybe_unused void mp_print_str_h(const char *str,
698                                           const limb_t *tab, limb_t n,
699                                           limb_t high)
700 {
701     slimb_t i;
702     printf("%s= 0x", str);
703     printf(FMT_LIMB, high);
704     for(i = n - 1; i >= 0; i--) {
705         printf("_");
706         printf(FMT_LIMB, tab[i]);
707     }
708     printf("\n");
709 }
710 
711 /* for debugging */
bf_print_str(const char * str,const bf_t * a)712 void bf_print_str(const char *str, const bf_t *a)
713 {
714     slimb_t i;
715     printf("%s=", str);
716 
717     if (a->expn == BF_EXP_NAN) {
718         printf("NaN");
719     } else {
720         if (a->sign)
721             putchar('-');
722         if (a->expn == BF_EXP_ZERO) {
723             putchar('0');
724         } else if (a->expn == BF_EXP_INF) {
725             printf("Inf");
726         } else {
727             printf("0x0.");
728             for(i = a->len - 1; i >= 0; i--)
729                 printf(FMT_LIMB, a->tab[i]);
730             printf("p%" PRId_LIMB, a->expn);
731         }
732     }
733     printf("\n");
734 }
735 
736 /* compare the absolute value of 'a' and 'b'. Return < 0 if a < b, 0
737    if a = b and > 0 otherwise. */
bf_cmpu(const bf_t * a,const bf_t * b)738 int bf_cmpu(const bf_t *a, const bf_t *b)
739 {
740     slimb_t i;
741     limb_t len, v1, v2;
742 
743     if (a->expn != b->expn) {
744         if (a->expn < b->expn)
745             return -1;
746         else
747             return 1;
748     }
749     len = bf_max(a->len, b->len);
750     for(i = len - 1; i >= 0; i--) {
751         v1 = get_limbz(a, a->len - len + i);
752         v2 = get_limbz(b, b->len - len + i);
753         if (v1 != v2) {
754             if (v1 < v2)
755                 return -1;
756             else
757                 return 1;
758         }
759     }
760     return 0;
761 }
762 
763 /* Full order: -0 < 0, NaN == NaN and NaN is larger than all other numbers */
bf_cmp_full(const bf_t * a,const bf_t * b)764 int bf_cmp_full(const bf_t *a, const bf_t *b)
765 {
766     int res;
767 
768     if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
769         if (a->expn == b->expn)
770             res = 0;
771         else if (a->expn == BF_EXP_NAN)
772             res = 1;
773         else
774             res = -1;
775     } else if (a->sign != b->sign) {
776         res = 1 - 2 * a->sign;
777     } else {
778         res = bf_cmpu(a, b);
779         if (a->sign)
780             res = -res;
781     }
782     return res;
783 }
784 
785 /* Standard floating point comparison: return 2 if one of the operands
786    is NaN (unordered) or -1, 0, 1 depending on the ordering assuming
787    -0 == +0 */
bf_cmp(const bf_t * a,const bf_t * b)788 int bf_cmp(const bf_t *a, const bf_t *b)
789 {
790     int res;
791 
792     if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
793         res = 2;
794     } else if (a->sign != b->sign) {
795         if (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_ZERO)
796             res = 0;
797         else
798             res = 1 - 2 * a->sign;
799     } else {
800         res = bf_cmpu(a, b);
801         if (a->sign)
802             res = -res;
803     }
804     return res;
805 }
806 
807 /* Compute the number of bits 'n' matching the pattern:
   a= X1000..0
   b= X0111..1
   where X is a common prefix of a and b aligned on their exponents.

   When computing a-b, the result will have at least n leading zero
   bits, so the caller (bf_add_internal) adds n bits of precision.

   Precondition: |a| > |b| and a.expn - b.expn = 0 or 1
*/
count_cancelled_bits(const bf_t * a,const bf_t * b)816 static limb_t count_cancelled_bits(const bf_t *a, const bf_t *b)
817 {
818     slimb_t bit_offset, b_offset, n;
819     int p, p1;
820     limb_t v1, v2, mask;
821
    /* bit_offset walks down 'a' one limb at a time, starting at its MSB.
       bit_offset + b_offset is the position in 'b' of the LIMB_BITS-bit
       window that has the same weights as the limb of 'a' whose top bit
       is bit_offset. */
822     bit_offset = a->len * LIMB_BITS - 1;
823     b_offset = (b->len - a->len) * LIMB_BITS - (LIMB_BITS - 1) +
824         a->expn - b->expn;
825     n = 0;
826
827     /* first search the equals bits */
    /* whole limbs of the common prefix X; stops because a != b */
828     for(;;) {
829         v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
830         v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
831         //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
832         if (v1 != v2)
833             break;
834         n += LIMB_BITS;
835         bit_offset -= LIMB_BITS;
836     }
837     /* find the position of the first different bit */
    /* p counts the rest of the prefix plus that bit (1 in a, 0 in b
       since |a| > |b|) */
838     p = clz(v1 ^ v2) + 1;
839     n += p;
840     /* then search for '0' in a and '1' in b */
841     p = LIMB_BITS - p;
842     if (p > 0) {
843         /* search in the trailing p bits of v1 and v2 */
        /* p1 = length of the run of (0 in a, 1 in b) at the top of those
           p bits; if it does not reach the end of the limb, done */
844         mask = limb_mask(0, p - 1);
845         p1 = bf_min(clz(v1 & mask), clz((~v2) & mask)) - (LIMB_BITS - p);
846         n += p1;
847         if (p1 != p)
848             goto done;
849     }
850     bit_offset -= LIMB_BITS;
    /* continue the (0 in a, 1 in b) run over the following limbs. Past
       the end of either number the limbs read as zero, so v2 eventually
       differs from all ones and the loop stops. */
851     for(;;) {
852         v1 = get_limbz(a, bit_offset >> LIMB_LOG2_BITS);
853         v2 = get_bits(b->tab, b->len, bit_offset + b_offset);
854         //        printf("v1=" FMT_LIMB " v2=" FMT_LIMB "\n", v1, v2);
855         if (v1 != 0 || v2 != -1) {
856             /* different: count the matching bits */
857             p1 = bf_min(clz(v1), clz(~v2));
858             n += p1;
859             break;
860         }
861         n += LIMB_BITS;
862         bit_offset -= LIMB_BITS;
863     }
864  done:
865     return n;
866 }
867 
/* Compute r = a + b (b_neg == 0) or r = a - b (b_neg == 1), rounded to
   'prec' bits according to 'flags'. Return the status flags, or
   BF_ST_MEM_ERROR with 'r' set to NaN. 'r' is resized while 'a' and 'b'
   are still read, so presumably it must not alias them; callers are
   expected to go through a wrapper using a temporary (TODO confirm). */
static int bf_add_internal(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                           bf_flags_t flags, int b_neg)
{
    const bf_t *tmp;
    int is_sub, ret, cmp_res, a_sign, b_sign;

    a_sign = a->sign;
    b_sign = b->sign ^ b_neg;
    is_sub = a_sign ^ b_sign;
    cmp_res = bf_cmpu(a, b);
    if (cmp_res < 0) {
        /* swap so that |a| >= |b|: the result takes the sign of 'a' */
        tmp = a;
        a = b;
        b = tmp;
        a_sign = b_sign; /* b_sign is never used later */
    }
    /* abs(a) >= abs(b) */
    if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
        /* exact cancellation: +0, or -0 when rounding toward -inf */
        bf_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
        ret = 0;
    } else if (a->len == 0 || b->len == 0) {
        ret = 0;
        if (a->expn >= BF_EXP_INF) {
            if (a->expn == BF_EXP_NAN) {
                /* at least one operand is NaN */
                bf_set_nan(r);
            } else if (b->expn == BF_EXP_INF && is_sub) {
                /* infinities with different signs */
                bf_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bf_set_inf(r, a_sign);
            }
        } else {
            /* 'b' is zero (since |a| >= |b|): the result is 'a' */
            /* report an allocation failure: renormalizing the NaN left
               by bf_set() would silently turn it into zero */
            if (bf_set(r, a))
                return BF_ST_MEM_ERROR;
            r->sign = a_sign;
            goto renorm;
        }
    } else {
        slimb_t d, a_offset, b_bit_offset, i, cancelled_bits;
        limb_t carry, v1, v2, u, r_len, carry1, precl, tot_len, z, sub_mask;

        r->sign = a_sign;
        r->expn = a->expn;
        d = a->expn - b->expn;
        /* must add more precision for the leading cancelled bits in
           subtraction */
        if (is_sub) {
            if (d <= 1)
                cancelled_bits = count_cancelled_bits(a, b);
            else
                cancelled_bits = 1;
        } else {
            cancelled_bits = 0;
        }

        /* add two extra bits for rounding */
        precl = (cancelled_bits + prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
        tot_len = bf_max(a->len, b->len + (d + LIMB_BITS - 1) / LIMB_BITS);
        r_len = bf_min(precl, tot_len);
        if (bf_resize(r, r_len))
            goto fail;
        a_offset = a->len - r_len;
        b_bit_offset = (b->len - r_len) * LIMB_BITS + d;

        /* compute the bits before for the rounding: the limbs below the
           r_len kept ones only contribute a carry and a sticky bit 'z'.
           A subtraction is done as a + ~b + 1 (sub_mask, initial carry) */
        carry = is_sub;
        z = 0;
        sub_mask = -is_sub;
        i = r_len - tot_len;
        while (i < 0) {
            slimb_t ap, bp;
            BOOL inflag;

            ap = a_offset + i;
            bp = b_bit_offset + i * LIMB_BITS;
            inflag = FALSE;
            if (ap >= 0 && ap < a->len) {
                v1 = a->tab[ap];
                inflag = TRUE;
            } else {
                v1 = 0;
            }
            if (bp + LIMB_BITS > 0 && bp < (slimb_t)(b->len * LIMB_BITS)) {
                v2 = get_bits(b->tab, b->len, bp);
                inflag = TRUE;
            } else {
                v2 = 0;
            }
            if (!inflag) {
                /* outside 'a' and 'b': go directly to the next value
                   inside a or b so that the running time does not
                   depend on the exponent difference */
                i = 0;
                if (ap < 0)
                    i = bf_min(i, -a_offset);
                /* b_bit_offset + i * LIMB_BITS + LIMB_BITS >= 1
                   equivalent to
                   i >= ceil(-b_bit_offset + 1 - LIMB_BITS) / LIMB_BITS)
                */
                if (bp + LIMB_BITS <= 0)
                    i = bf_min(i, (-b_bit_offset) >> LIMB_LOG2_BITS);
            } else {
                i++;
            }
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            z |= u;
        }
        /* and the result */
        for(i = 0; i < r_len; i++) {
            v1 = get_limbz(a, a_offset + i);
            v2 = get_bits(b->tab, b->len, b_bit_offset + i * LIMB_BITS);
            v2 ^= sub_mask;
            u = v1 + v2;
            carry1 = u < v1;
            u += carry;
            carry = (u < carry) | carry1;
            r->tab[i] = u;
        }
        /* set the extra bits for the rounding */
        r->tab[0] |= (z != 0);

        /* carry is only possible in add case */
        if (!is_sub && carry) {
            if (bf_resize(r, r_len + 1))
                goto fail;
            r->tab[r_len] = 1;
            r->expn += LIMB_BITS;
        }
    renorm:
        ret = bf_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}
1011 
/* r = a + b rounded to 'prec' bits (bf_add_internal() without negation);
   matches bf_op2_func_t */
static int __bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                     bf_flags_t flags)
{
    const int negate_b = 0;

    return bf_add_internal(r, a, b, prec, flags, negate_b);
}
1017 
/* r = a - b. Delegates to bf_add_internal() with its last argument set to
   1 (subtraction); 'prec' and 'flags' are forwarded unchanged. */
static int __bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    const int do_sub = 1;

    return bf_add_internal(r, a, b, prec, flags, do_sub);
}
1023 
/* res[0..n-1] = op1[0..n-1] + op2[0..n-1] + carry.
   Return the carry out of the most significant limb (0 or 1 when n > 0;
   'carry' itself when n == 0). Each limb of op1/op2 is read before
   res[i] is written, so res may be the same array as op1 or op2. */
limb_t mp_add(limb_t *res, const limb_t *op1, const limb_t *op2,
              limb_t n, limb_t carry)
{
    limb_t i; /* same unsigned type as 'n': no signed/unsigned compare */
    limb_t k, a, v, k1;

    k = carry;
    for(i = 0; i < n; i++) {
        v = op1[i];
        a = v + op2[i];
        k1 = a < v;           /* carry from op1[i] + op2[i] */
        a = a + k;
        k = (a < k) | k1;     /* carry from adding the incoming carry */
        res[i] = a;
    }
    return k;
}
1041 
/* Add the single limb 'b' to tab[0..n-1] in place. The carry is
   propagated upwards and the loop stops as soon as it is absorbed.
   Return the carry out of the top limb (0 or 1), or 'b' when n == 0. */
limb_t mp_add_ui(limb_t *tab, limb_t b, size_t n)
{
    limb_t carry = b;
    size_t idx = 0;

    while (idx < n && carry != 0) {
        limb_t sum = tab[idx] + carry;
        carry = sum < carry;
        tab[idx++] = sum;
    }
    return carry;
}
1057 
/* res[0..n-1] = op1[0..n-1] - op2[0..n-1] - carry, where 'carry' is an
   incoming borrow subtracted from the lowest limb.
   Return the outgoing borrow (0 or 1 when n > 0; 'carry' when n == 0).
   Each limb of op1/op2 is read before res[i] is written, so res may be
   the same array as op1 or op2. */
limb_t mp_sub(limb_t *res, const limb_t *op1, const limb_t *op2,
              mp_size_t n, limb_t carry)
{
    mp_size_t i; /* same type as 'n': an 'int' index overflows for huge n */
    limb_t k, a, v, k1;

    k = carry;
    for(i = 0; i < n; i++) {
        v = op1[i];
        a = v - op2[i];
        k1 = a > v;           /* borrow from op1[i] - op2[i] */
        v = a - k;
        k = (v > a) | k1;     /* borrow from subtracting the incoming one */
        res[i] = v;
    }
    return k;
}
1075 
1076 /* compute 0 - op2 */
mp_neg(limb_t * res,const limb_t * op2,mp_size_t n,limb_t carry)1077 static limb_t mp_neg(limb_t *res, const limb_t *op2, mp_size_t n, limb_t carry)
1078 {
1079     int i;
1080     limb_t k, a, v, k1;
1081 
1082     k = carry;
1083     for(i=0;i<n;i++) {
1084         v = 0;
1085         a = v - op2[i];
1086         k1 = a > v;
1087         v = a - k;
1088         k = (v > a) | k1;
1089         res[i] = v;
1090     }
1091     return k;
1092 }
1093 
/* Subtract the single limb 'b' from tab[0..n-1] in place. The borrow is
   propagated upwards and the loop stops as soon as it is absorbed.
   Return the borrow out of the top limb (0 or 1), or 'b' when n <= 0. */
limb_t mp_sub_ui(limb_t *tab, limb_t b, mp_size_t n)
{
    limb_t borrow = b;
    mp_size_t idx = 0;

    while (idx < n) {
        limb_t old = tab[idx];
        tab[idx] = old - borrow;
        borrow = tab[idx] > old;
        idx++;
        if (borrow == 0)
            break;
    }
    return borrow;
}
1110 
1111 /* r = (a + high*B^n) >> shift. Return the remainder r (0 <= r < 2^shift).
1112    1 <= shift <= LIMB_BITS - 1 */
mp_shr(limb_t * tab_r,const limb_t * tab,mp_size_t n,int shift,limb_t high)1113 static limb_t mp_shr(limb_t *tab_r, const limb_t *tab, mp_size_t n,
1114                      int shift, limb_t high)
1115 {
1116     mp_size_t i;
1117     limb_t l, a;
1118 
1119     assert(shift >= 1 && shift < LIMB_BITS);
1120     l = high;
1121     for(i = n - 1; i >= 0; i--) {
1122         a = tab[i];
1123         tab_r[i] = (a >> shift) | (l << (LIMB_BITS - shift));
1124         l = a;
1125     }
1126     return l & (((limb_t)1 << shift) - 1);
1127 }
1128 
1129 /* tabr[] = taba[] * b + l. Return the high carry */
mp_mul1(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b,limb_t l)1130 static limb_t mp_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
1131                       limb_t b, limb_t l)
1132 {
1133     limb_t i;
1134     dlimb_t t;
1135 
1136     for(i = 0; i < n; i++) {
1137         t = (dlimb_t)taba[i] * (dlimb_t)b + l;
1138         tabr[i] = t;
1139         l = t >> LIMB_BITS;
1140     }
1141     return l;
1142 }
1143 
1144 /* tabr[] += taba[] * b, return the high word. */
mp_add_mul1(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b)1145 static limb_t mp_add_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
1146                           limb_t b)
1147 {
1148     limb_t i, l;
1149     dlimb_t t;
1150 
1151     l = 0;
1152     for(i = 0; i < n; i++) {
1153         t = (dlimb_t)taba[i] * (dlimb_t)b + l + tabr[i];
1154         tabr[i] = t;
1155         l = t >> LIMB_BITS;
1156     }
1157     return l;
1158 }
1159 
1160 /* size of the result : op1_size + op2_size. */
mp_mul_basecase(limb_t * result,const limb_t * op1,limb_t op1_size,const limb_t * op2,limb_t op2_size)1161 static void mp_mul_basecase(limb_t *result,
1162                             const limb_t *op1, limb_t op1_size,
1163                             const limb_t *op2, limb_t op2_size)
1164 {
1165     limb_t i, r;
1166 
1167     result[op1_size] = mp_mul1(result, op1, op1_size, op2[0], 0);
1168     for(i=1;i<op2_size;i++) {
1169         r = mp_add_mul1(result + i, op1, op1_size, op2[i]);
1170         result[i + op1_size] = r;
1171     }
1172 }
1173 
1174 /* return 0 if OK, -1 if memory error */
1175 /* XXX: change API so that result can be allocated */
mp_mul(bf_context_t * s,limb_t * result,const limb_t * op1,limb_t op1_size,const limb_t * op2,limb_t op2_size)1176 int mp_mul(bf_context_t *s, limb_t *result,
1177            const limb_t *op1, limb_t op1_size,
1178            const limb_t *op2, limb_t op2_size)
1179 {
1180 #ifdef USE_FFT_MUL
1181     if (unlikely(bf_min(op1_size, op2_size) >= FFT_MUL_THRESHOLD)) {
1182         bf_t r_s, *r = &r_s;
1183         r->tab = result;
1184         /* XXX: optimize memory usage in API */
1185         if (fft_mul(s, r, (limb_t *)op1, op1_size,
1186                     (limb_t *)op2, op2_size, FFT_MUL_R_NORESIZE))
1187             return -1;
1188     } else
1189 #endif
1190     {
1191         mp_mul_basecase(result, op1, op1_size, op2, op2_size);
1192     }
1193     return 0;
1194 }
1195 
1196 /* tabr[] -= taba[] * b. Return the value to substract to the high
1197    word. */
mp_sub_mul1(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b)1198 static limb_t mp_sub_mul1(limb_t *tabr, const limb_t *taba, limb_t n,
1199                           limb_t b)
1200 {
1201     limb_t i, l;
1202     dlimb_t t;
1203 
1204     l = 0;
1205     for(i = 0; i < n; i++) {
1206         t = tabr[i] - (dlimb_t)taba[i] * (dlimb_t)b - l;
1207         tabr[i] = t;
1208         l = -(t >> LIMB_BITS);
1209     }
1210     return l;
1211 }
1212 
1213 /* WARNING: d must be >= 2^(LIMB_BITS-1) */
udiv1norm_init(limb_t d)1214 static inline limb_t udiv1norm_init(limb_t d)
1215 {
1216     limb_t a0, a1;
1217     a1 = -d - 1;
1218     a0 = -1;
1219     return (((dlimb_t)a1 << LIMB_BITS) | a0) / d;
1220 }
1221 
1222 /* Return the quotient of 'a1*2^LIMB_BITS+a0' divided by 'd' and store
1223    the remainder in '*pr', with 0 <= a1 < d. 'd' must be normalized
   (d >= 2^(LIMB_BITS-1)) and 'd_inv' must be udiv1norm_init(d), as the
   callers in this file do. No division instruction is used: the
   quotient is estimated with the precomputed reciprocal and then
   corrected. */
udiv1norm(limb_t * pr,limb_t a1,limb_t a0,limb_t d,limb_t d_inv)1224 static inline limb_t udiv1norm(limb_t *pr, limb_t a1, limb_t a0,
1225                                 limb_t d, limb_t d_inv)
1226 {
1227     limb_t n1m, n_adj, q, r, ah;
1228     dlimb_t a;
    /* n1m: all ones if the top bit of a0 is set, 0 otherwise (relies on
       an arithmetic right shift of negative values) */
1229     n1m = ((slimb_t)a0 >> (LIMB_BITS - 1));
1230     n_adj = a0 + (n1m & d);
    /* quotient estimate from the reciprocal */
1231     a = (dlimb_t)d_inv * (a1 - n1m) + n_adj;
1232     q = (a >> LIMB_BITS) + a1;
1233     /* compute a - (q + 1) * d and update q so that the remainder is
1234        between 0 and d - 1 */
1235     a = ((dlimb_t)a1 << LIMB_BITS) | a0;
1236     a = a - (dlimb_t)q * d - d;
    /* ah is expected to be 0 (q + 1 is the quotient) or all ones (the
       difference wrapped: q is the quotient and d must be added back) */
1237     ah = a >> LIMB_BITS;
1238     q += 1 + ah;
1239     r = (limb_t)a + (ah & d);
1240     *pr = r;
1241     return q;
1242 }
1243 
1244 /* b must be >= 1 << (LIMB_BITS - 1) */
mp_div1norm(limb_t * tabr,const limb_t * taba,limb_t n,limb_t b,limb_t r)1245 static limb_t mp_div1norm(limb_t *tabr, const limb_t *taba, limb_t n,
1246                           limb_t b, limb_t r)
1247 {
1248     slimb_t i;
1249 
1250     if (n >= UDIV1NORM_THRESHOLD) {
1251         limb_t b_inv;
1252         b_inv = udiv1norm_init(b);
1253         for(i = n - 1; i >= 0; i--) {
1254             tabr[i] = udiv1norm(&r, r, taba[i], b, b_inv);
1255         }
1256     } else {
1257         dlimb_t a1;
1258         for(i = n - 1; i >= 0; i--) {
1259             a1 = ((dlimb_t)r << LIMB_BITS) | taba[i];
1260             tabr[i] = a1 / b;
1261             r = a1 % b;
1262         }
1263     }
1264     return r;
1265 }
1266 
/* forward declaration: mp_divnorm() delegates large divisions to the
   subquadratic mp_divnorm_large(), which is defined further below */
1267 static int mp_divnorm_large(bf_context_t *s,
1268                             limb_t *tabq, limb_t *taba, limb_t na,
1269                             const limb_t *tabb, limb_t nb);
1270 
1271 /* base case division: divides taba[0..na-1] by tabb[0..nb-1]. tabb[nb
1272    - 1] must be >= 1 << (LIMB_BITS - 1). na - nb must be >= 0. 'taba'
1273    is modified and contains the remainder (nb limbs). tabq[0..na-nb]
1274    contains the quotient with tabq[na - nb] <= 1. */
mp_divnorm(bf_context_t * s,limb_t * tabq,limb_t * taba,limb_t na,const limb_t * tabb,limb_t nb)1275 static int mp_divnorm(bf_context_t *s, limb_t *tabq, limb_t *taba, limb_t na,
1276                       const limb_t *tabb, limb_t nb)
1277 {
1278     limb_t r, a, c, q, v, b1, b1_inv, n, dummy_r;
1279     slimb_t i, j;
1280 
1281     b1 = tabb[nb - 1];
1282     if (nb == 1) {
1283         taba[0] = mp_div1norm(tabq, taba, na, b1, 0);
1284         return 0;
1285     }
1286     n = na - nb;
1287     if (bf_min(n, nb) >= DIVNORM_LARGE_THRESHOLD) {
1288         return mp_divnorm_large(s, tabq, taba, na, tabb, nb);
1289     }
1290 
1291     if (n >= UDIV1NORM_THRESHOLD)
1292         b1_inv = udiv1norm_init(b1);
1293     else
1294         b1_inv = 0;
1295 
1296     /* first iteration: the quotient is only 0 or 1 */
1297     q = 1;
1298     for(j = nb - 1; j >= 0; j--) {
1299         if (taba[n + j] != tabb[j]) {
1300             if (taba[n + j] < tabb[j])
1301                 q = 0;
1302             break;
1303         }
1304     }
1305     tabq[n] = q;
1306     if (q) {
1307         mp_sub(taba + n, taba + n, tabb, nb, 0);
1308     }
1309 
1310     for(i = n - 1; i >= 0; i--) {
1311         if (unlikely(taba[i + nb] >= b1)) {
1312             q = -1;
1313         } else if (b1_inv) {
1314             q = udiv1norm(&dummy_r, taba[i + nb], taba[i + nb - 1], b1, b1_inv);
1315         } else {
1316             dlimb_t al;
1317             al = ((dlimb_t)taba[i + nb] << LIMB_BITS) | taba[i + nb - 1];
1318             q = al / b1;
1319             r = al % b1;
1320         }
1321         r = mp_sub_mul1(taba + i, tabb, nb, q);
1322 
1323         v = taba[i + nb];
1324         a = v - r;
1325         c = (a > v);
1326         taba[i + nb] = a;
1327 
1328         if (c != 0) {
1329             /* negative result */
1330             for(;;) {
1331                 q--;
1332                 c = mp_add(taba + i, taba + i, tabb, nb, 0);
1333                 /* propagate carry and test if positive result */
1334                 if (c != 0) {
1335                     if (++taba[i + nb] == 0) {
1336                         break;
1337                     }
1338                 }
1339             }
1340         }
1341         tabq[i] = q;
1342     }
1343     return 0;
1344 }
1345 
1346 /* compute r=B^(2*n)/a such that a*r < B^(2*n) < a*(r + 2) with n >= 1. 'a'
1347    has n limbs with a[n-1] >= B/2 and 'r' has n+1 limbs with r[n] = 1.
1348 
1349    See Modern Computer Arithmetic by Richard P. Brent and Paul
1350    Zimmermann, algorithm 3.5. Return 0 if OK, -1 if memory error. */
mp_recip(bf_context_t * s,limb_t * tabr,const limb_t * taba,limb_t n)1351 int mp_recip(bf_context_t *s, limb_t *tabr, const limb_t *taba, limb_t n)
1352 {
1353     mp_size_t l, h, k, i;
1354     limb_t *tabxh, *tabt, c, *tabu;
1355 
1356     if (n <= 2) {
1357         /* return ceil(B^(2*n)/a) - 1 */
1358         /* XXX: could avoid allocation */
        /* base case: divide B^(2*n) (2*n+1 limbs) by a with mp_divnorm() */
1359         tabu = bf_malloc(s, sizeof(limb_t) * (2 * n + 1));
1360         tabt = bf_malloc(s, sizeof(limb_t) * (n + 2));
1361         if (!tabt || !tabu)
1362             goto fail;
1363         for(i = 0; i < 2 * n; i++)
1364             tabu[i] = 0;
1365         tabu[2 * n] = 1;
1366         if (mp_divnorm(s, tabt, tabu, 2 * n + 1, taba, n))
1367             goto fail;
1368         for(i = 0; i < n + 1; i++)
1369             tabr[i] = tabt[i];
        /* zero remainder: the exact quotient was computed, subtract 1 to
           get ceil(B^(2*n)/a) - 1 */
1370         if (mp_scan_nz(tabu, n) == 0) {
1371             /* only happens for a=B^n/2 */
1372             mp_sub_ui(tabr, 1, n + 1);
1373         }
1374     } else {
        /* recursive step: approximate the reciprocal Xh of the h high
           limbs of a, then refine it to n + 1 limbs */
1375         l = (n - 1) / 2;
1376         h = n - l;
1377         /* n=2p  -> l=p-1, h = p + 1, k = p + 3
1378            n=2p+1-> l=p,  h = p + 1; k = p + 2
1379         */
1380         tabt = bf_malloc(s, sizeof(limb_t) * (n + h + 1));
1381         tabu = bf_malloc(s, sizeof(limb_t) * (n + 2 * h - l + 2));
1382         if (!tabt || !tabu)
1383             goto fail;
        /* Xh is computed directly in the high h + 1 limbs of the result */
1384         tabxh = tabr + l;
1385         if (mp_recip(s, tabxh, taba + l, h))
1386             goto fail;
        /* T = A * Xh */
1387         if (mp_mul(s, tabt, taba, n, tabxh, h + 1)) /* n + h + 1 limbs */
1388             goto fail;
        /* while T >= B^(n+h): Xh -= 1 and T -= A */
1389         while (tabt[n + h] != 0) {
1390             mp_sub_ui(tabxh, 1, h + 1);
1391             c = mp_sub(tabt, tabt, taba, n, 0);
1392             mp_sub_ui(tabt + n, c, h + 1);
1393         }
1394         /* T = B^(n+h) - T */
1395         mp_neg(tabt, tabt, n + h + 1, 0);
1396         tabt[n + h]++;
        /* U = (T >> (l limbs)) * Xh */
1397         if (mp_mul(s, tabu, tabt + l, n + h + 1 - l, tabxh, h + 1))
1398             goto fail;
1399         /* n + 2*h - l + 2 limbs */
        /* the low l limbs of the result come from U; the high part of U
           is added to Xh */
1400         k = 2 * h - l;
1401         for(i = 0; i < l; i++)
1402             tabr[i] = tabu[i + k];
1403         mp_add(tabr + l, tabr + l, tabu + 2 * h, h, 0);
1404     }
1405     bf_free(s, tabt);
1406     bf_free(s, tabu);
1407     return 0;
1408  fail:
1409     bf_free(s, tabt);
1410     bf_free(s, tabu);
1411     return -1;
1412 }
1413 
1414 /* return -1, 0 or 1 */
mp_cmp(const limb_t * taba,const limb_t * tabb,mp_size_t n)1415 static int mp_cmp(const limb_t *taba, const limb_t *tabb, mp_size_t n)
1416 {
1417     mp_size_t i;
1418     for(i = n - 1; i >= 0; i--) {
1419         if (taba[i] != tabb[i]) {
1420             if (taba[i] < tabb[i])
1421                 return -1;
1422             else
1423                 return 1;
1424         }
1425     }
1426     return 0;
1427 }
1428 
1429 //#define DEBUG_DIVNORM_LARGE
1430 //#define DEBUG_DIVNORM_LARGE2
1431 
1432 /* subquadratic divnorm */
/* Same arguments as mp_divnorm(): the divisor tabb[0..nb-1] is
   normalized, the quotient goes to tabq[0..na-nb] and the remainder to
   the low nb limbs of taba. The quotient is approximated as
   A * (reciprocal of the top of B) and then corrected upwards.
   Return 0 if OK, -1 if memory error. */
mp_divnorm_large(bf_context_t * s,limb_t * tabq,limb_t * taba,limb_t na,const limb_t * tabb,limb_t nb)1433 static int mp_divnorm_large(bf_context_t *s,
1434                             limb_t *tabq, limb_t *taba, limb_t na,
1435                             const limb_t *tabb, limb_t nb)
1436 {
1437     limb_t *tabb_inv, nq, *tabt, i, n;
1438     nq = na - nb;
1439 #ifdef DEBUG_DIVNORM_LARGE
1440     printf("na=%d nb=%d nq=%d\n", (int)na, (int)nb, (int)nq);
1441     mp_print_str("a", taba, na);
1442     mp_print_str("b", tabb, nb);
1443 #endif
1444     assert(nq >= 1);
    /* n: number of limbs of the divisor approximation whose reciprocal
       is computed */
1445     n = nq;
1446     if (nq < nb)
1447         n++;
1448     tabb_inv = bf_malloc(s, sizeof(limb_t) * (n + 1));
1449     tabt = bf_malloc(s, sizeof(limb_t) * 2 * (n + 1));
1450     if (!tabb_inv || !tabt)
1451         goto fail;
1452 
1453     if (n >= nb) {
        /* B fits in n limbs: pad it with zero limbs at the bottom */
1454         for(i = 0; i < n - nb; i++)
1455             tabt[i] = 0;
1456         for(i = 0; i < nb; i++)
1457             tabt[i + n - nb] = tabb[i];
1458     } else {
1459         /* truncate B: need to increment it so that the approximate
1460            inverse is smaller than the exact inverse */
1461         for(i = 0; i < n; i++)
1462             tabt[i] = tabb[i + nb - n];
1463         if (mp_add_ui(tabt, 1, n)) {
1464             /* tabt = B^n : tabb_inv = B^n */
1465             memset(tabb_inv, 0, n * sizeof(limb_t));
1466             tabb_inv[n] = 1;
1467             goto recip_done;
1468         }
1469     }
1470     if (mp_recip(s, tabb_inv, tabt, n))
1471         goto fail;
1472  recip_done:
1473     /* Q=A*B^-1 */
1474     if (mp_mul(s, tabt, tabb_inv, n + 1, taba + na - (n + 1), n + 1))
1475         goto fail;
1476 
    /* the quotient is the top nq + 1 limbs of the 2 * (n + 1) limb product */
1477     for(i = 0; i < nq + 1; i++)
1478         tabq[i] = tabt[i + 2 * (n + 1) - (nq + 1)];
1479 #ifdef DEBUG_DIVNORM_LARGE
1480     mp_print_str("q", tabq, nq + 1);
1481 #endif
1482 
1483     bf_free(s, tabt);
1484     bf_free(s, tabb_inv);
1485     tabb_inv = NULL;
1486 
1487     /* R=A-B*Q */
1488     tabt = bf_malloc(s, sizeof(limb_t) * (na + 1));
1489     if (!tabt)
1490         goto fail;
1491     if (mp_mul(s, tabt, tabq, nq + 1, tabb, nb))
1492         goto fail;
1493     /* we add one more limb for the result */
1494     mp_sub(taba, taba, tabt, nb + 1, 0);
1495     bf_free(s, tabt);
1496     /* the approximated quotient is smaller than the exact one,
1497        hence we may have to increment it */
1498 #ifdef DEBUG_DIVNORM_LARGE2
1499     int cnt = 0;
1500     static int cnt_max;
1501 #endif
    /* while R >= B: R -= B, Q += 1 */
1502     for(;;) {
1503         if (taba[nb] == 0 && mp_cmp(taba, tabb, nb) < 0)
1504             break;
1505         taba[nb] -= mp_sub(taba, taba, tabb, nb, 0);
1506         mp_add_ui(tabq, 1, nq + 1);
1507 #ifdef DEBUG_DIVNORM_LARGE2
1508         cnt++;
1509 #endif
1510     }
1511 #ifdef DEBUG_DIVNORM_LARGE2
1512     if (cnt > cnt_max) {
1513         cnt_max = cnt;
1514         printf("\ncnt=%d nq=%d nb=%d\n", cnt_max, (int)nq, (int)nb);
1515     }
1516 #endif
1517     return 0;
1518  fail:
1519     bf_free(s, tabb_inv);
1520     bf_free(s, tabt);
1521     return -1;
1522 }
1523 
/* r = a * b rounded to 'prec' bits with 'flags'. r may be the same object
   as a or b. Return the rounding status, BF_ST_INVALID_OP for
   0 * infinity (r = NaN), or BF_ST_MEM_ERROR (r = NaN) on allocation
   failure. */
bf_mul(bf_t * r,const bf_t * a,const bf_t * b,limb_t prec,bf_flags_t flags)1524 int bf_mul(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
1525            bf_flags_t flags)
1526 {
1527     int ret, r_sign;
1528 
1529     if (a->len < b->len) {
1530         const bf_t *tmp = a;
1531         a = b;
1532         b = tmp;
1533     }
1534     r_sign = a->sign ^ b->sign;
1535     /* here b->len <= a->len */
    /* b has no limbs: it is zero, infinity or NaN */
1536     if (b->len == 0) {
1537         if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
1538             bf_set_nan(r);
1539             ret = 0;
1540         } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_INF) {
1541             if ((a->expn == BF_EXP_INF && b->expn == BF_EXP_ZERO) ||
1542                 (a->expn == BF_EXP_ZERO && b->expn == BF_EXP_INF)) {
1543                 bf_set_nan(r);
1544                 ret = BF_ST_INVALID_OP;
1545             } else {
1546                 bf_set_inf(r, r_sign);
1547                 ret = 0;
1548             }
1549         } else {
1550             bf_set_zero(r, r_sign);
1551             ret = 0;
1552         }
1553     } else {
        /* tmp/r1: temporary result used by the base case when r aliases
           a or b; it is moved back into r1 at 'done' */
1554         bf_t tmp, *r1 = NULL;
1555         limb_t a_len, b_len, precl;
1556         limb_t *a_tab, *b_tab;
1557 
1558         a_len = a->len;
1559         b_len = b->len;
1560 
1561         if ((flags & BF_RND_MASK) == BF_RNDF) {
1562             /* faithful rounding does not require using the full inputs */
1563             precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
1564             a_len = bf_min(a_len, precl);
1565             b_len = bf_min(b_len, precl);
1566         }
        /* keep the most significant limbs (dropping low limbs does not
           change the exponent) */
1567         a_tab = a->tab + a->len - a_len;
1568         b_tab = b->tab + b->len - b_len;
1569 
1570 #ifdef USE_FFT_MUL
1571         if (b_len >= FFT_MUL_THRESHOLD) {
            /* fft_mul() writes into r and is told if r overlaps an input */
1572             int mul_flags = 0;
1573             if (r == a)
1574                 mul_flags |= FFT_MUL_R_OVERLAP_A;
1575             if (r == b)
1576                 mul_flags |= FFT_MUL_R_OVERLAP_B;
1577             if (fft_mul(r->ctx, r, a_tab, a_len, b_tab, b_len, mul_flags))
1578                 goto fail;
1579         } else
1580 #endif
1581         {
            /* the base case cannot write into one of its inputs */
1582             if (r == a || r == b) {
1583                 bf_init(r->ctx, &tmp);
1584                 r1 = r;
1585                 r = &tmp;
1586             }
1587             if (bf_resize(r, a_len + b_len)) {
            /* also reached from the FFT path on memory error */
1588             fail:
1589                 bf_set_nan(r);
1590                 ret = BF_ST_MEM_ERROR;
1591                 goto done;
1592             }
1593             mp_mul_basecase(r->tab, a_tab, a_len, b_tab, b_len);
1594         }
1595         r->sign = r_sign;
1596         r->expn = a->expn + b->expn;
1597         ret = bf_normalize_and_round(r, prec, flags);
1598     done:
1599         if (r == &tmp)
1600             bf_move(r1, &tmp);
1601     }
1602     return ret;
1603 }
1604 
1605 /* multiply 'r' by 2^e */
bf_mul_2exp(bf_t * r,slimb_t e,limb_t prec,bf_flags_t flags)1606 int bf_mul_2exp(bf_t *r, slimb_t e, limb_t prec, bf_flags_t flags)
1607 {
1608     slimb_t e_max;
1609     if (r->len == 0)
1610         return 0;
1611     e_max = ((limb_t)1 << BF_EXT_EXP_BITS_MAX) - 1;
1612     e = bf_max(e, -e_max);
1613     e = bf_min(e, e_max);
1614     r->expn += e;
1615     return __bf_round(r, prec, flags, r->len, 0);
1616 }
1617 
1618 /* Return e such as a=m*2^e with m odd integer. return 0 if a is zero,
1619    Infinite or Nan. */
bf_get_exp_min(const bf_t * a)1620 slimb_t bf_get_exp_min(const bf_t *a)
1621 {
1622     slimb_t i;
1623     limb_t v;
1624     int k;
1625 
1626     for(i = 0; i < a->len; i++) {
1627         v = a->tab[i];
1628         if (v != 0) {
1629             k = ctz(v);
1630             return a->expn - (a->len - i) * LIMB_BITS + k;
1631         }
1632     }
1633     return 0;
1634 }
1635 
1636 /* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
1637    integer defined as floor(a/b) and r = a - q * b. */
bf_tdivremu(bf_t * q,bf_t * r,const bf_t * a,const bf_t * b)1638 static void bf_tdivremu(bf_t *q, bf_t *r,
1639                         const bf_t *a, const bf_t *b)
1640 {
1641     if (bf_cmpu(a, b) < 0) {
1642         bf_set_ui(q, 0);
1643         bf_set(r, a);
1644     } else {
1645         bf_div(q, a, b, bf_max(a->expn - b->expn + 1, 2), BF_RNDZ);
1646         bf_rint(q, BF_RNDZ);
1647         bf_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
1648         bf_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
1649     }
1650 }
1651 
/* r = a / b rounded to 'prec' bits with 'flags'.
   Special values: a NaN operand gives NaN (status 0); inf/inf and 0/0
   give NaN with BF_ST_INVALID_OP; inf/x gives a signed infinity and
   x/inf a signed zero (status 0); x/0 with x finite non-zero gives a
   signed infinity with BF_ST_DIVIDE_ZERO.
   Return the rounding status, or BF_ST_MEM_ERROR (r = NaN) on
   allocation failure. */
static int __bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
                    bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    int r_sign = a->sign ^ b->sign;
    limb_t n, nb, na, precl;
    limb_t *taba;
    slimb_t d;

    if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        }
        if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        /* one operand is infinite */
        if (a->expn == BF_EXP_INF)
            bf_set_inf(r, r_sign);
        else
            bf_set_zero(r, r_sign);
        return 0;
    }
    if (a->expn == BF_EXP_ZERO) {
        if (b->expn == BF_EXP_ZERO) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        bf_set_zero(r, r_sign);
        return 0;
    }
    if (b->expn == BF_EXP_ZERO) {
        bf_set_inf(r, r_sign);
        return BF_ST_DIVIDE_ZERO;
    }

    /* number of limbs of the quotient (2 extra bits for rounding) */
    precl = (prec + 2 + LIMB_BITS - 1) / LIMB_BITS;
    nb = b->len;
    n = bf_max(a->len, precl);
    na = n + nb;

    /* dividend: a placed in the top limbs of an na-limb buffer,
       zero padded below */
    taba = bf_malloc(s, (na + 1) * sizeof(limb_t));
    if (!taba)
        goto fail;
    d = na - a->len;
    memset(taba, 0, d * sizeof(limb_t));
    memcpy(taba + d, a->tab, a->len * sizeof(limb_t));

    if (bf_resize(r, n + 1) ||
        mp_divnorm(s, r->tab, taba, na, b->tab, nb)) {
        bf_free(s, taba);
        goto fail;
    }
    /* sticky bit: a non-zero remainder sets the lowest quotient bit */
    if (mp_scan_nz(taba, nb))
        r->tab[0] |= 1;
    bf_free(s, taba);
    r->expn = a->expn - b->expn + LIMB_BITS;
    r->sign = r_sign;
    return bf_normalize_and_round(r, prec, flags);

 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}
1723 
1724 /* division and remainder.
1725 
1726    rnd_mode is the rounding mode for the quotient. The additional
1727    rounding mode BF_DIVREM_EUCLIDIAN is supported.
1728 
1729    'q' is an integer. 'r' is rounded with prec and flags (prec can be
1730    BF_PREC_INF).

   q and r must be distinct objects, both distinct from a and b (checked
   by assert). Return the status of rounding r, BF_ST_INVALID_OP for an
   infinite dividend or a zero divisor, or BF_ST_MEM_ERROR with q and r
   set to NaN.
1731 */
bf_divrem(bf_t * q,bf_t * r,const bf_t * a,const bf_t * b,limb_t prec,bf_flags_t flags,int rnd_mode)1732 int bf_divrem(bf_t *q, bf_t *r, const bf_t *a, const bf_t *b,
1733               limb_t prec, bf_flags_t flags, int rnd_mode)
1734 {
1735     bf_t a1_s, *a1 = &a1_s;
1736     bf_t b1_s, *b1 = &b1_s;
1737     int q_sign, ret;
1738     BOOL is_ceil, is_rndn;
1739 
1740     assert(q != a && q != b);
1741     assert(r != a && r != b);
1742     assert(q != r);
1743 
    /* special values: q = 0; r = NaN for NaN operands, an infinite
       dividend or a zero divisor, otherwise r = a */
1744     if (a->len == 0 || b->len == 0) {
1745         bf_set_zero(q, 0);
1746         if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
1747             bf_set_nan(r);
1748             return 0;
1749         } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
1750             bf_set_nan(r);
1751             return BF_ST_INVALID_OP;
1752         } else {
1753             bf_set(r, a);
1754             return bf_round(r, prec, flags);
1755         }
1756     }
1757 
    /* is_ceil: whether a non-zero remainder increments the truncated
       quotient magnitude. Unknown modes behave like BF_RNDZ. */
1758     q_sign = a->sign ^ b->sign;
1759     is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA);
1760     switch(rnd_mode) {
1761     default:
1762     case BF_RNDZ:
1763     case BF_RNDN:
1764     case BF_RNDNA:
1765         is_ceil = FALSE;
1766         break;
1767     case BF_RNDD:
1768         is_ceil = q_sign;
1769         break;
1770     case BF_RNDU:
1771         is_ceil = q_sign ^ 1;
1772         break;
1773     case BF_RNDA:
1774         is_ceil = TRUE;
1775         break;
1776     case BF_DIVREM_EUCLIDIAN:
1777         is_ceil = a->sign;
1778         break;
1779     }
1780 
    /* a1, b1: |a| and |b| sharing the limb arrays of a and b (only expn,
       tab, len and sign are set) */
1781     a1->expn = a->expn;
1782     a1->tab = a->tab;
1783     a1->len = a->len;
1784     a1->sign = 0;
1785 
1786     b1->expn = b->expn;
1787     b1->tab = b->tab;
1788     b1->len = b->len;
1789     b1->sign = 0;
1790 
1791     /* XXX: could improve to avoid having a large 'q' */
1792     bf_tdivremu(q, r, a1, b1);
1793     if (bf_is_nan(q) || bf_is_nan(r))
1794         goto fail;
1795 
1796     if (r->len != 0) {
1797         if (is_rndn) {
            /* round to nearest: compare r with |b|/2 (b1 is halved by
               decrementing its exponent); ties go up for BF_RNDNA or
               when the units bit of the integer q is set */
1798             int res;
1799             b1->expn--;
1800             res = bf_cmpu(r, b1);
1801             b1->expn++;
1802             if (res > 0 ||
1803                 (res == 0 &&
1804                  (rnd_mode == BF_RNDNA ||
1805                   get_bit(q->tab, q->len, q->len * LIMB_BITS - q->expn)))) {
1806                 goto do_sub_r;
1807             }
1808         } else if (is_ceil) {
            /* q = q + 1, r = r - |b| */
1809         do_sub_r:
1810             ret = bf_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
1811             ret |= bf_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
1812             if (ret & BF_ST_MEM_ERROR)
1813                 goto fail;
1814         }
1815     }
1816 
    /* the remainder takes the sign of a, the quotient the sign of a/b */
1817     r->sign ^= a->sign;
1818     q->sign = q_sign;
1819     return bf_round(r, prec, flags);
1820  fail:
1821     bf_set_nan(q);
1822     bf_set_nan(r);
1823     return BF_ST_MEM_ERROR;
1824 }
1825 
/* r = remainder of a / b, where the implicit quotient is rounded with
   'rnd_mode' (see bf_divrem()); r is rounded with 'prec' and 'flags'.
   Return the status reported by bf_divrem(). */
int bf_rem(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags, int rnd_mode)
{
    bf_t quo;
    int status;

    bf_init(r->ctx, &quo);
    status = bf_divrem(&quo, r, a, b, prec, flags, rnd_mode);
    bf_delete(&quo);
    return status;
}
1837 
/* Convert 'a' to a signed limb-sized integer in *pres by calling the
   64-bit or 32-bit converter that matches LIMB_BITS. 'flags' is passed
   through, and the converter's status is returned. */
static inline int bf_get_limb(slimb_t *pres, const bf_t *a, int flags)
{
#if LIMB_BITS != 32
    return bf_get_int64(pres, a, flags);
#else
    return bf_get_int32(pres, a, flags);
#endif
}
1846 
/* Like bf_rem(), but also store the rounded quotient in *pq, converted
   with BF_GET_INT_MOD (presumably reduced modulo 2^LIMB_BITS; confirm
   against bf_get_int64/bf_get_int32). The conversion status is ignored;
   the status of bf_divrem() is returned. */
int bf_remquo(slimb_t *pq, bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
              bf_flags_t flags, int rnd_mode)
{
    bf_t quo;
    int status;

    bf_init(r->ctx, &quo);
    status = bf_divrem(&quo, r, a, b, prec, flags, rnd_mode);
    bf_get_limb(pq, &quo, BF_GET_INT_MOD);
    bf_delete(&quo);
    return status;
}
1859 
mul_mod(limb_t a,limb_t b,limb_t m)1860 static __maybe_unused inline limb_t mul_mod(limb_t a, limb_t b, limb_t m)
1861 {
1862     dlimb_t t;
1863     t = (dlimb_t)a * (dlimb_t)b;
1864     return t % m;
1865 }
1866 
#if defined(USE_MUL_CHECK)
/* Return (r * B^n + tab[0..n-1]) mod m, with B = 2^LIMB_BITS, using
   Horner evaluation from the most significant limb down. */
static limb_t mp_mod1(const limb_t *tab, limb_t n, limb_t m, limb_t r)
{
    limb_t k;

    for (k = n; k-- > 0; ) {
        dlimb_t cur = ((dlimb_t)r << LIMB_BITS) | tab[k];
        r = cur % m;
    }
    return r;
}
#endif
1880 
/* Initial 16 -> 8 bit square root estimate for mp_sqrtrem1(). The index
   is the top 8 bits of the normalized input minus 64 (the top byte is in
   [64, 255]). Entries appear to be floor(sqrt((i + 64) * 256)); the
   caller corrects the estimate by at most one. */
1881 static const uint16_t sqrt_table[192] = {
1882 128,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,144,145,146,147,148,149,150,150,151,152,153,154,155,155,156,157,158,159,160,160,161,162,163,163,164,165,166,167,167,168,169,170,170,171,172,173,173,174,175,176,176,177,178,178,179,180,181,181,182,183,183,184,185,185,186,187,187,188,189,189,190,191,192,192,193,193,194,195,195,196,197,197,198,199,199,200,201,201,202,203,203,204,204,205,206,206,207,208,208,209,209,210,211,211,212,212,213,214,214,215,215,216,217,217,218,218,219,219,220,221,221,222,222,223,224,224,225,225,226,226,227,227,228,229,229,230,230,231,231,232,232,233,234,234,235,235,236,236,237,237,238,238,239,240,240,241,241,242,242,243,243,244,244,245,245,246,246,247,247,248,248,249,249,250,250,251,251,252,252,253,253,254,254,255,
1883 };
1884 
1885 /* a >= 2^(LIMB_BITS - 2).  Return (s, r) with s=floor(sqrt(a)) and
1886    r=a-s^2. 0 <= r <= 2 * s. s is the return value and r is stored in
   '*pr'. The root is built from a table lookup, then 8 more bits (and
   16 more bits for 64-bit limbs) at a time: each step divides the
   extended remainder by 2*s1 to get the next root digits q, subtracts
   q^2 and corrects once if the result is negative. */
mp_sqrtrem1(limb_t * pr,limb_t a)1887 static limb_t mp_sqrtrem1(limb_t *pr, limb_t a)
1888 {
1889     limb_t s1, r1, s, r, q, u, num;
1890 
1891     /* use a table for the 16 -> 8 bit sqrt */
1892     s1 = sqrt_table[(a >> (LIMB_BITS - 8)) - 64];
1893     r1 = (a >> (LIMB_BITS - 16)) - s1 * s1;
    /* the table estimate may be one too small: fix it */
1894     if (r1 > 2 * s1) {
1895         r1 -= 2 * s1 + 1;
1896         s1++;
1897     }
1898 
1899     /* one iteration to get a 32 -> 16 bit sqrt */
1900     num = (r1 << 8) | ((a >> (LIMB_BITS - 32 + 8)) & 0xff);
1901     q = num / (2 * s1); /* q <= 2^8 */
1902     u = num % (2 * s1);
1903     s = (s1 << 8) + q;
1904     r = (u << 8) | ((a >> (LIMB_BITS - 32)) & 0xff);
1905     r -= q * q;
    /* negative remainder: the root estimate is one too large */
1906     if ((slimb_t)r < 0) {
1907         s--;
1908         r += 2 * s + 1;
1909     }
1910 
1911 #if LIMB_BITS == 64
1912     s1 = s;
1913     r1 = r;
1914     /* one more iteration for 64 -> 32 bit sqrt */
1915     num = (r1 << 16) | ((a >> (LIMB_BITS - 64 + 16)) & 0xffff);
1916     q = num / (2 * s1); /* q <= 2^16 */
1917     u = num % (2 * s1);
1918     s = (s1 << 16) + q;
1919     r = (u << 16) | ((a >> (LIMB_BITS - 64)) & 0xffff);
1920     r -= q * q;
1921     if ((slimb_t)r < 0) {
1922         s--;
1923         r += 2 * s + 1;
1924     }
1925 #endif
1926     *pr = r;
1927     return s;
1928 }
1929 
1930 /* return floor(sqrt(a)) */
bf_isqrt(limb_t a)1931 limb_t bf_isqrt(limb_t a)
1932 {
1933     limb_t s, r;
1934     int k;
1935 
1936     if (a == 0)
1937         return 0;
1938     k = clz(a) & ~1;
1939     s = mp_sqrtrem1(&r, a << k);
1940     s >>= (k >> 1);
1941     return s;
1942 }
1943 
/* Square root of the two-limb value taba[1]*2^LIMB_BITS + taba[0], with
   taba[1] >= 2^(LIMB_BITS - 2) as required by mp_sqrtrem1(). Store
   s = floor(sqrt(value)) in tabs[0] and the low limb of the remainder
   value - s^2 in taba[0]. Return the high part of the remainder, which
   is expected to be 0 or 1 since the remainder is at most 2 * s.
   taba[1] is not written. */
mp_sqrtrem2(limb_t * tabs,limb_t * taba)1944 static limb_t mp_sqrtrem2(limb_t *tabs, limb_t *taba)
1945 {
1946     limb_t s1, r1, s, q, u, a0, a1;
1947     dlimb_t r, num;
1948     int l;
1949 
1950     a0 = taba[0];
1951     a1 = taba[1];
    /* root of the high limb, then one step for the next half limb */
1952     s1 = mp_sqrtrem1(&r1, a1);
1953     l = LIMB_BITS / 2;
1954     num = ((dlimb_t)r1 << l) | (a0 >> l);
1955     q = num / (2 * s1);
1956     u = num % (2 * s1);
1957     s = (s1 << l) + q;
1958     r = ((dlimb_t)u << l) | (a0 & (((limb_t)1 << l) - 1));
    /* q * q does not fit in a limb when q == 2^l */
1959     if (unlikely((q >> l) != 0))
1960         r -= (dlimb_t)1 << LIMB_BITS; /* special case when q=2^l */
1961     else
1962         r -= q * q;
    /* negative remainder: the root estimate is one too large */
1963     if ((slimb_t)(r >> LIMB_BITS) < 0) {
1964         s--;
1965         r += 2 * (dlimb_t)s + 1;
1966     }
1967     tabs[0] = s;
1968     taba[0] = r;
1969     return r >> LIMB_BITS;
1970 }
1971 
1972 //#define DEBUG_SQRTREM
1973 
1974 /* tmp_buf must contain (n / 2 + 1 limbs). *prh contains the highest
1975    limb of the remainder. */
mp_sqrtrem_rec(bf_context_t * s,limb_t * tabs,limb_t * taba,limb_t n,limb_t * tmp_buf,limb_t * prh)1976 static int mp_sqrtrem_rec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n,
1977                           limb_t *tmp_buf, limb_t *prh)
1978 {
1979     limb_t l, h, rh, ql, qh, c, i;
1980 
1981     if (n == 1) {
1982         *prh = mp_sqrtrem2(tabs, taba);
1983         return 0;
1984     }
1985 #ifdef DEBUG_SQRTREM
1986     mp_print_str("a", taba, 2 * n);
1987 #endif
1988     l = n / 2;
1989     h = n - l;
1990     if (mp_sqrtrem_rec(s, tabs + l, taba + 2 * l, h, tmp_buf, &qh))
1991         return -1;
1992 #ifdef DEBUG_SQRTREM
1993     mp_print_str("s1", tabs + l, h);
1994     mp_print_str_h("r1", taba + 2 * l, h, qh);
1995     mp_print_str_h("r2", taba + l, n, qh);
1996 #endif
1997 
1998     /* the remainder is in taba + 2 * l. Its high bit is in qh */
1999     if (qh) {
2000         mp_sub(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
2001     }
2002     /* instead of dividing by 2*s, divide by s (which is normalized)
2003        and update q and r */
2004     if (mp_divnorm(s, tmp_buf, taba + l, n, tabs + l, h))
2005         return -1;
2006     qh += tmp_buf[l];
2007     for(i = 0; i < l; i++)
2008         tabs[i] = tmp_buf[i];
2009     ql = mp_shr(tabs, tabs, l, 1, qh & 1);
2010     qh = qh >> 1; /* 0 or 1 */
2011     if (ql)
2012         rh = mp_add(taba + l, taba + l, tabs + l, h, 0);
2013     else
2014         rh = 0;
2015 #ifdef DEBUG_SQRTREM
2016     mp_print_str_h("q", tabs, l, qh);
2017     mp_print_str_h("u", taba + l, h, rh);
2018 #endif
2019 
2020     mp_add_ui(tabs + l, qh, h);
2021 #ifdef DEBUG_SQRTREM
2022     mp_print_str_h("s2", tabs, n, sh);
2023 #endif
2024 
2025     /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
2026     /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
2027     if (qh) {
2028         c = qh;
2029     } else {
2030         if (mp_mul(s, taba + n, tabs, l, tabs, l))
2031             return -1;
2032         c = mp_sub(taba, taba, taba + n, 2 * l, 0);
2033     }
2034     rh -= mp_sub_ui(taba + 2 * l, c, n - 2 * l);
2035     if ((slimb_t)rh < 0) {
2036         mp_sub_ui(tabs, 1, n);
2037         rh += mp_add_mul1(taba, tabs, n, 2);
2038         rh += mp_add_ui(taba, 1, n);
2039     }
2040     *prh = rh;
2041     return 0;
2042 }
2043 
2044 /* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= 2 ^ (LIMB_BITS
2045    - 2). Return (s, r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2
2046    * s. tabs has n limbs. r is returned in the lower n limbs of
2047    taba. Its r[n] is the returned value of the function. */
2048 /* Algorithm from the article "Karatsuba Square Root" by Paul Zimmermann and
2049    inspirated from its GMP implementation */
mp_sqrtrem(bf_context_t * s,limb_t * tabs,limb_t * taba,limb_t n)2050 int mp_sqrtrem(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
2051 {
2052     limb_t tmp_buf1[8];
2053     limb_t *tmp_buf;
2054     mp_size_t n2;
2055     int ret;
2056     n2 = n / 2 + 1;
2057     if (n2 <= countof(tmp_buf1)) {
2058         tmp_buf = tmp_buf1;
2059     } else {
2060         tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
2061         if (!tmp_buf)
2062             return -1;
2063     }
2064     ret = mp_sqrtrem_rec(s, tabs, taba, n, tmp_buf, taba + n);
2065     if (tmp_buf != tmp_buf1)
2066         bf_free(s, tmp_buf);
2067     return ret;
2068 }
2069 
2070 /* Integer square root with remainder. 'a' must be an integer. r =
2071    floor(sqrt(a)) and rem = a - r^2.  BF_ST_INEXACT is set if the result
2072    is inexact. 'rem' can be NULL if the remainder is not needed. */
bf_sqrtrem(bf_t * r,bf_t * rem1,const bf_t * a)2073 int bf_sqrtrem(bf_t *r, bf_t *rem1, const bf_t *a)
2074 {
2075     int ret;
2076 
2077     if (a->len == 0) {
2078         if (a->expn == BF_EXP_NAN) {
2079             bf_set_nan(r);
2080         } else if (a->expn == BF_EXP_INF && a->sign) {
2081             goto invalid_op;
2082         } else {
2083             bf_set(r, a);
2084         }
2085         if (rem1)
2086             bf_set_ui(rem1, 0);
2087         ret = 0;
2088     } else if (a->sign) {
2089  invalid_op:
2090         bf_set_nan(r);
2091         if (rem1)
2092             bf_set_ui(rem1, 0);
2093         ret = BF_ST_INVALID_OP;
2094     } else {
2095         bf_t rem_s, *rem;
2096 
2097         bf_sqrt(r, a, (a->expn + 1) / 2, BF_RNDZ);
2098         bf_rint(r, BF_RNDZ);
2099         /* see if the result is exact by computing the remainder */
2100         if (rem1) {
2101             rem = rem1;
2102         } else {
2103             rem = &rem_s;
2104             bf_init(r->ctx, rem);
2105         }
2106         /* XXX: could avoid recomputing the remainder */
2107         bf_mul(rem, r, r, BF_PREC_INF, BF_RNDZ);
2108         bf_neg(rem);
2109         bf_add(rem, rem, a, BF_PREC_INF, BF_RNDZ);
2110         if (bf_is_nan(rem)) {
2111             ret = BF_ST_MEM_ERROR;
2112             goto done;
2113         }
2114         if (rem->len != 0) {
2115             ret = BF_ST_INEXACT;
2116         } else {
2117             ret = 0;
2118         }
2119     done:
2120         if (!rem1)
2121             bf_delete(rem);
2122     }
2123     return ret;
2124 }
2125 
/* r = sqrt(a) rounded to 'prec' bits according to 'flags'. 'r' must not
   alias 'a'. Returns BF_ST_INVALID_OP (r = NaN) for negative a including
   -inf, BF_ST_MEM_ERROR (r = NaN) on allocation failure, otherwise the
   status of the final bf_round(). */
int bf_sqrt(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = a->ctx;
    limb_t *mant;
    slimb_t n, n_copy;
    limb_t inexact;

    assert(r != a);

    if (a->len == 0) {
        /* NaN, zero or infinity */
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        }
        if (a->expn == BF_EXP_INF && a->sign) {
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        bf_set(r, a);
        return 0;
    }
    if (a->sign) {
        bf_set_nan(r);
        return BF_ST_INVALID_OP;
    }

    /* convert the mantissa to an integer with at least 2 *
       prec + 4 bits */
    n = (2 * (prec + 2) + 2 * LIMB_BITS - 1) / (2 * LIMB_BITS);
    if (bf_resize(r, n))
        goto fail;
    mant = bf_malloc(s, sizeof(limb_t) * 2 * n);
    if (!mant)
        goto fail;
    n_copy = bf_min(2 * n, a->len);
    memset(mant, 0, (2 * n - n_copy) * sizeof(limb_t));
    memcpy(mant + 2 * n - n_copy, a->tab + a->len - n_copy,
           n_copy * sizeof(limb_t));
    /* make the exponent even; a bit shifted out makes the result inexact */
    inexact = (a->expn & 1) ? mp_shr(mant, mant, 2 * n, 1, 0) : 0;
    if (mp_sqrtrem(s, r->tab, mant, n)) {
        bf_free(s, mant);
        goto fail;
    }
    /* non-zero remainder */
    if (!inexact)
        inexact = mp_scan_nz(mant, n + 1);
    bf_free(s, mant);
    /* low limbs of 'a' that were not copied */
    if (!inexact)
        inexact = mp_scan_nz(a->tab, a->len - n_copy);
    /* sticky bit so that bf_round() sees the result as inexact */
    if (inexact != 0)
        r->tab[0] |= 1;
    r->sign = 0;
    r->expn = (a->expn + 1) >> 1;
    return bf_round(r, prec, flags);

 fail:
    bf_set_nan(r);
    return BF_ST_MEM_ERROR;
}

bf_op2(bf_t * r,const bf_t * a,const bf_t * b,limb_t prec,bf_flags_t flags,bf_op2_func_t * func)2190 static no_inline int bf_op2(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
2191                             bf_flags_t flags, bf_op2_func_t *func)
2192 {
2193     bf_t tmp;
2194     int ret;
2195 
2196     if (r == a || r == b) {
2197         bf_init(r->ctx, &tmp);
2198         ret = func(&tmp, a, b, prec, flags);
2199         bf_move(r, &tmp);
2200     } else {
2201         ret = func(r, a, b, prec, flags);
2202     }
2203     return ret;
2204 }
2205 
/* r = a + b rounded to 'prec' bits; r may alias a or b. */
int bf_add(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    bf_op2_func_t *op = __bf_add;

    return bf_op2(r, a, b, prec, flags, op);
}

/* r = a - b rounded to 'prec' bits; r may alias a or b. */
int bf_sub(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    bf_op2_func_t *op = __bf_sub;

    return bf_op2(r, a, b, prec, flags, op);
}

/* r = a / b rounded to 'prec' bits; r may alias a or b. */
int bf_div(bf_t *r, const bf_t *a, const bf_t *b, limb_t prec,
           bf_flags_t flags)
{
    bf_op2_func_t *op = __bf_div;

    return bf_op2(r, a, b, prec, flags, op);
}

/* r = a * b1 (unsigned 64-bit) rounded to 'prec' bits with 'flags'.
   Returns the OR of the status flags of the steps. */
int bf_mul_ui(bf_t *r, const bf_t *a, uint64_t b1, limb_t prec,
              bf_flags_t flags)
{
    bf_t factor;
    int status;

    bf_init(r->ctx, &factor);
    status = bf_set_ui(&factor, b1);
    status |= bf_mul(r, a, &factor, prec, flags);
    bf_delete(&factor);
    return status;
}

/* r = a * b1 (signed 64-bit) rounded to 'prec' bits with 'flags'.
   Returns the OR of the status flags of the steps. */
int bf_mul_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
              bf_flags_t flags)
{
    bf_t factor;
    int status;

    bf_init(r->ctx, &factor);
    status = bf_set_si(&factor, b1);
    status |= bf_mul(r, a, &factor, prec, flags);
    bf_delete(&factor);
    return status;
}

/* r = a + b1 (signed 64-bit) rounded to 'prec' bits with 'flags'.
   r may alias a (bf_add() handles aliasing). Returns the OR of the
   status flags of the steps. */
int bf_add_si(bf_t *r, const bf_t *a, int64_t b1, limb_t prec,
              bf_flags_t flags)
{
    bf_t addend;
    int status;

    bf_init(r->ctx, &addend);
    status = bf_set_si(&addend, b1);
    status |= bf_add(r, a, &addend, prec, flags);
    bf_delete(&addend);
    return status;
}

/* r = a^b by left-to-right binary exponentiation; every product is
   rounded to 'prec' bits with 'flags'. 'r' must not alias 'a'. Returns
   the OR of the status flags of the individual operations. */
static int bf_pow_ui(bf_t *r, const bf_t *a, limb_t b, limb_t prec,
                     bf_flags_t flags)
{
    limb_t bit;
    int ret;

    assert(r != a);
    if (b == 0)
        return bf_set_ui(r, 1);
    /* the most significant bit of b is accounted for by r = a */
    ret = bf_set(r, a);
    bit = (limb_t)1 << (LIMB_BITS - 1 - clz(b));
    while ((bit >>= 1) != 0) {
        ret |= bf_mul(r, r, r, prec, flags);
        if (b & bit)
            ret |= bf_mul(r, r, a, prec, flags);
    }
    return ret;
}

/* r = a1^b. Powers of ten up to 10^LIMB_DIGITS are taken exactly from
   mp_pow_dec[] without rounding; other powers go through bf_pow_ui()
   with 'prec' and 'flags'. */
static int bf_pow_ui_ui(bf_t *r, limb_t a1, limb_t b,
                        limb_t prec, bf_flags_t flags)
{
    bf_t base;
    int status;

    if (a1 == 10 && b <= LIMB_DIGITS) {
        /* use precomputed powers. We do not round at this point
           because we expect the caller to do it */
        return bf_set_ui(r, mp_pow_dec[b]);
    }
    bf_init(r->ctx, &base);
    status = bf_set_ui(&base, a1);
    status |= bf_pow_ui(r, &base, b, prec, flags);
    bf_delete(&base);
    return status;
}

2298 /* convert to integer (infinite precision) */
bf_rint(bf_t * r,int rnd_mode)2299 int bf_rint(bf_t *r, int rnd_mode)
2300 {
2301     return bf_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
2302 }
2303 
/* logical operations: opcodes understood by bf_logic_op1()/bf_logic_op() */
enum {
    BF_LOGIC_OR  = 0,
    BF_LOGIC_XOR = 1,
    BF_LOGIC_AND = 2,
};

/* Apply the bitwise operation 'op' to two limbs. Any opcode other than
   BF_LOGIC_OR or BF_LOGIC_XOR is treated as BF_LOGIC_AND. */
static inline limb_t bf_logic_op1(limb_t a, limb_t b, int op)
{
    if (op == BF_LOGIC_OR)
        return a | b;
    if (op == BF_LOGIC_XOR)
        return a ^ b;
    return a & b;
}

/* Bitwise operation 'op' on the integers a1 and b1 with infinite two's
   complement semantics for negative values; the result is stored in r.
   'r' must not alias a1 or b1. Return 0 or BF_ST_MEM_ERROR (r is then
   set to NaN). */
static int bf_logic_op(bf_t *r, const bf_t *a1, const bf_t *b1, int op)
{
    bf_t a_tmp, b_tmp;
    bf_t *a = (bf_t *)a1, *b = (bf_t *)b1;
    limb_t a_neg, b_neg, r_neg, a_mask, b_mask, r_mask, va, vb;
    slimb_t nbits, nlimbs, k, a_off, b_off;
    int ret;

    assert(r != a1 && r != b1);

    /* minus zero is considered as positive */
    a_neg = (a1->expn <= 0) ? 0 : a1->sign;
    b_neg = (b1->expn <= 0) ? 0 : b1->sign;

    /* a negative x is replaced by x + 1 = -(|x| - 1) and its magnitude bits
       are inverted below: in two's complement x == ~(|x| - 1) */
    if (a_neg) {
        a = &a_tmp;
        bf_init(r->ctx, a);
        if (bf_add_si(a, a1, 1, BF_PREC_INF, BF_RNDZ))
            goto fail;
    }
    if (b_neg) {
        b = &b_tmp;
        bf_init(r->ctx, b);
        if (bf_add_si(b, b1, 1, BF_PREC_INF, BF_RNDZ))
            goto fail;
    }

    r_neg = bf_logic_op1(a_neg, b_neg, op);
    if (op == BF_LOGIC_AND && r_neg == 0) {
        /* no need to compute extra zeros for and: the result is bounded
           by the non-negative operand(s) */
        if (!a_neg && !b_neg)
            nbits = bf_min(a->expn, b->expn);
        else if (!a_neg)
            nbits = a->expn;
        else
            nbits = b->expn;
    } else {
        nbits = bf_max(a->expn, b->expn);
    }
    /* Note: a or b can be zero */
    nlimbs = (bf_max(nbits, 1) + LIMB_BITS - 1) / LIMB_BITS;
    if (bf_resize(r, nlimbs))
        goto fail;

    a_off = a->len * LIMB_BITS - a->expn;
    b_off = b->len * LIMB_BITS - b->expn;
    a_mask = -a_neg;
    b_mask = -b_neg;
    r_mask = -r_neg;
    for (k = 0; k < nlimbs; k++) {
        va = get_bits(a->tab, a->len, a_off + k * LIMB_BITS) ^ a_mask;
        vb = get_bits(b->tab, b->len, b_off + k * LIMB_BITS) ^ b_mask;
        r->tab[k] = bf_logic_op1(va, vb, op) ^ r_mask;
    }
    r->expn = nlimbs * LIMB_BITS;
    r->sign = r_neg;
    bf_normalize_and_round(r, BF_PREC_INF, BF_RNDZ); /* cannot fail */
    if (r_neg && bf_add_si(r, r, -1, BF_PREC_INF, BF_RNDZ))
        goto fail;
    ret = 0;

 done:
    if (a == &a_tmp)
        bf_delete(a);
    if (b == &b_tmp)
        bf_delete(b);
    return ret;

 fail:
    bf_set_nan(r);
    ret = BF_ST_MEM_ERROR;
    goto done;
}

2408 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
bf_logic_or(bf_t * r,const bf_t * a,const bf_t * b)2409 int bf_logic_or(bf_t *r, const bf_t *a, const bf_t *b)
2410 {
2411     return bf_logic_op(r, a, b, BF_LOGIC_OR);
2412 }
2413 
2414 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
bf_logic_xor(bf_t * r,const bf_t * a,const bf_t * b)2415 int bf_logic_xor(bf_t *r, const bf_t *a, const bf_t *b)
2416 {
2417     return bf_logic_op(r, a, b, BF_LOGIC_XOR);
2418 }
2419 
2420 /* 'a' and 'b' must be integers. Return 0 or BF_ST_MEM_ERROR. */
bf_logic_and(bf_t * r,const bf_t * a,const bf_t * b)2421 int bf_logic_and(bf_t *r, const bf_t *a, const bf_t *b)
2422 {
2423     return bf_logic_op(r, a, b, BF_LOGIC_AND);
2424 }
2425 
/* conversion between fixed size types */

/* Access the bit pattern of an IEEE 754 binary64 value as a 64-bit
   integer through a union. */
typedef union Float64Union {
    double d;   /* floating point view */
    uint64_t u; /* raw bits */
} Float64Union;

/* Convert 'a' to an IEEE 754 binary64 value rounded with 'rnd_mode'
   (53-bit mantissa, 11-bit exponent, subnormals enabled) and store it in
   *pres. Returns the rounding status flags. A NaN input gives a quiet NaN
   and returns 0. If the temporary copy hits a memory error, *pres is a
   quiet NaN and the returned flags include BF_ST_MEM_ERROR. */
int bf_get_float64(const bf_t *a, double *pres, bf_rnd_t rnd_mode)
{
    Float64Union u;
    int e, ret;
    uint64_t m;

    ret = 0;
    if (a->expn == BF_EXP_NAN) {
        u.u = 0x7ff8000000000000; /* quiet nan */
    } else {
        bf_t b_s, *b = &b_s;

        bf_init(a->ctx, b);
        ret = bf_set(b, a);
        if (bf_is_finite(b)) {
            ret |= bf_round(b, 53, rnd_mode | BF_FLAG_SUBNORMAL | bf_set_exp_bits(11));
        }
        if (bf_is_nan(b)) {
            /* 'a' is not NaN, so b can only be NaN after a memory error in
               bf_set() or bf_round(); it then has no mantissa limbs, so
               b->tab must not be read */
            u.u = 0x7ff8000000000000; /* quiet nan */
        } else {
            if (b->expn == BF_EXP_INF) {
                e = (1 << 11) - 1;
                m = 0;
            } else if (b->expn == BF_EXP_ZERO) {
                e = 0;
                m = 0;
            } else {
                /* biased exponent; the mantissa has its leading one in
                   the most significant bit */
                e = b->expn + 1023 - 1;
#if LIMB_BITS == 32
                if (b->len == 2) {
                    m = ((uint64_t)b->tab[1] << 32) | b->tab[0];
                } else {
                    m = ((uint64_t)b->tab[0] << 32);
                }
#else
                m = b->tab[0];
#endif
                if (e <= 0) {
                    /* subnormal */
                    m = m >> (12 - e);
                    e = 0;
                } else {
                    /* drop the implicit leading one */
                    m = (m << 1) >> 12;
                }
            }
            u.u = m | ((uint64_t)e << 52) | ((uint64_t)b->sign << 63);
        }
        bf_delete(b);
    }
    *pres = u.d;
    return ret;
}

/* Set 'a' to the exact value of the IEEE 754 binary64 number 'd'
   (NaN, infinities, signed zeros and subnormals included).
   Return 0, or BF_ST_MEM_ERROR (a = NaN) if the mantissa could not be
   allocated. */
int bf_set_float64(bf_t *a, double d)
{
    Float64Union u;
    uint64_t mant;
    int biased_exp, sgn, lz;

    u.d = d;
    sgn = u.u >> 63;
    biased_exp = (u.u >> 52) & ((1 << 11) - 1);
    mant = u.u & (((uint64_t)1 << 52) - 1);

    if (biased_exp == ((1 << 11) - 1)) {
        /* infinity or NaN */
        if (mant != 0)
            bf_set_nan(a);
        else
            bf_set_inf(a, sgn);
        return 0;
    }
    if (biased_exp == 0) {
        if (mant == 0) {
            bf_set_zero(a, sgn);
            return 0;
        }
        /* subnormal number: move the leading one to bit 63 */
        mant <<= 12;
        lz = clz64(mant);
        mant <<= lz;
        biased_exp = -lz;
    } else {
        /* normal number: restore the implicit leading one in bit 63 */
        mant = (mant << 11) | ((uint64_t)1 << 63);
    }

    a->expn = biased_exp - 1023 + 1;
#if LIMB_BITS == 32
    if (bf_resize(a, 2))
        goto fail;
    a->tab[0] = mant;
    a->tab[1] = mant >> 32;
#else
    if (bf_resize(a, 1))
        goto fail;
    a->tab[0] = mant;
#endif
    a->sign = sgn;
    return 0;

 fail:
    bf_set_nan(a);
    return BF_ST_MEM_ERROR;
}

2531 /* The rounding mode is always BF_RNDZ. Return BF_ST_INVALID_OP if there
2532    is an overflow and 0 otherwise. */
bf_get_int32(int * pres,const bf_t * a,int flags)2533 int bf_get_int32(int *pres, const bf_t *a, int flags)
2534 {
2535     uint32_t v;
2536     int ret;
2537     if (a->expn >= BF_EXP_INF) {
2538         ret = BF_ST_INVALID_OP;
2539         if (flags & BF_GET_INT_MOD) {
2540             v = 0;
2541         } else if (a->expn == BF_EXP_INF) {
2542             v = (uint32_t)INT32_MAX + a->sign;
2543         } else {
2544             v = INT32_MAX;
2545         }
2546     } else if (a->expn <= 0) {
2547         v = 0;
2548         ret = 0;
2549     } else if (a->expn <= 31) {
2550         v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2551         if (a->sign)
2552             v = -v;
2553         ret = 0;
2554     } else if (!(flags & BF_GET_INT_MOD)) {
2555         ret = BF_ST_INVALID_OP;
2556         if (a->sign) {
2557             v = (uint32_t)INT32_MAX + 1;
2558             if (a->expn == 32 &&
2559                 (a->tab[a->len - 1] >> (LIMB_BITS - 32)) == v) {
2560                 ret = 0;
2561             }
2562         } else {
2563             v = INT32_MAX;
2564         }
2565     } else {
2566         v = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
2567         if (a->sign)
2568             v = -v;
2569         ret = 0;
2570     }
2571     *pres = v;
2572     return ret;
2573 }
2574 
2575 /* The rounding mode is always BF_RNDZ. Return BF_ST_INVALID_OP if there
2576    is an overflow and 0 otherwise. */
bf_get_int64(int64_t * pres,const bf_t * a,int flags)2577 int bf_get_int64(int64_t *pres, const bf_t *a, int flags)
2578 {
2579     uint64_t v;
2580     int ret;
2581     if (a->expn >= BF_EXP_INF) {
2582         ret = BF_ST_INVALID_OP;
2583         if (flags & BF_GET_INT_MOD) {
2584             v = 0;
2585         } else if (a->expn == BF_EXP_INF) {
2586             v = (uint64_t)INT64_MAX + a->sign;
2587         } else {
2588             v = INT64_MAX;
2589         }
2590     } else if (a->expn <= 0) {
2591         v = 0;
2592         ret = 0;
2593     } else if (a->expn <= 63) {
2594 #if LIMB_BITS == 32
2595         if (a->expn <= 32)
2596             v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2597         else
2598             v = (((uint64_t)a->tab[a->len - 1] << 32) |
2599                  get_limbz(a, a->len - 2)) >> (64 - a->expn);
2600 #else
2601         v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2602 #endif
2603         if (a->sign)
2604             v = -v;
2605         ret = 0;
2606     } else if (!(flags & BF_GET_INT_MOD)) {
2607         ret = BF_ST_INVALID_OP;
2608         if (a->sign) {
2609             uint64_t v1;
2610             v = (uint64_t)INT64_MAX + 1;
2611             if (a->expn == 64) {
2612                 v1 = a->tab[a->len - 1];
2613 #if LIMB_BITS == 32
2614                 v1 = (v1 << 32) | get_limbz(a, a->len - 2);
2615 #endif
2616                 if (v1 == v)
2617                     ret = 0;
2618             }
2619         } else {
2620             v = INT64_MAX;
2621         }
2622     } else {
2623         slimb_t bit_pos = a->len * LIMB_BITS - a->expn;
2624         v = get_bits(a->tab, a->len, bit_pos);
2625 #if LIMB_BITS == 32
2626         v |= (uint64_t)get_bits(a->tab, a->len, bit_pos + 32) << 32;
2627 #endif
2628         if (a->sign)
2629             v = -v;
2630         ret = 0;
2631     }
2632     *pres = v;
2633     return ret;
2634 }
2635 
2636 /* The rounding mode is always BF_RNDZ. Return BF_ST_INVALID_OP if there
2637    is an overflow and 0 otherwise. */
bf_get_uint64(uint64_t * pres,const bf_t * a)2638 int bf_get_uint64(uint64_t *pres, const bf_t *a)
2639 {
2640     uint64_t v;
2641     int ret;
2642     if (a->expn == BF_EXP_NAN) {
2643         goto overflow;
2644     } else if (a->expn <= 0) {
2645         v = 0;
2646         ret = 0;
2647     } else if (a->sign) {
2648         v = 0;
2649         ret = BF_ST_INVALID_OP;
2650     } else if (a->expn <= 64) {
2651 #if LIMB_BITS == 32
2652         if (a->expn <= 32)
2653             v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2654         else
2655             v = (((uint64_t)a->tab[a->len - 1] << 32) |
2656                  get_limbz(a, a->len - 2)) >> (64 - a->expn);
2657 #else
2658         v = a->tab[a->len - 1] >> (LIMB_BITS - a->expn);
2659 #endif
2660         ret = 0;
2661     } else {
2662     overflow:
2663         v = UINT64_MAX;
2664         ret = BF_ST_INVALID_OP;
2665     }
2666     *pres = v;
2667     return ret;
2668 }
2669 
2670 /* base conversion from radix */
2671 
2672 static const uint8_t digits_per_limb_table[BF_RADIX_MAX - 1] = {
2673 #if LIMB_BITS == 32
2674 32,20,16,13,12,11,10,10, 9, 9, 8, 8, 8, 8, 8, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
2675 #else
2676 64,40,32,27,24,22,21,20,19,18,17,17,16,16,16,15,15,15,14,14,14,14,13,13,13,13,13,13,13,12,12,12,12,12,12,
2677 #endif
2678 };
2679 
get_limb_radix(int radix)2680 static limb_t get_limb_radix(int radix)
2681 {
2682     int i, k;
2683     limb_t radixl;
2684 
2685     k = digits_per_limb_table[radix - 2];
2686     radixl = radix;
2687     for(i = 1; i < k; i++)
2688         radixl *= radix;
2689     return radixl;
2690 }
2691 
2692 /* return != 0 if error */
bf_integer_from_radix_rec(bf_t * r,const limb_t * tab,limb_t n,int level,limb_t n0,limb_t radix,bf_t * pow_tab)2693 static int bf_integer_from_radix_rec(bf_t *r, const limb_t *tab,
2694                                      limb_t n, int level, limb_t n0,
2695                                      limb_t radix, bf_t *pow_tab)
2696 {
2697     int ret;
2698     if (n == 1) {
2699         ret = bf_set_ui(r, tab[0]);
2700     } else {
2701         bf_t T_s, *T = &T_s, *B;
2702         limb_t n1, n2;
2703 
2704         n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
2705         n1 = n - n2;
2706         //        printf("level=%d n0=%ld n1=%ld n2=%ld\n", level, n0, n1, n2);
2707         B = &pow_tab[level];
2708         if (B->len == 0) {
2709             ret = bf_pow_ui_ui(B, radix, n2, BF_PREC_INF, BF_RNDZ);
2710             if (ret)
2711                 return ret;
2712         }
2713         ret = bf_integer_from_radix_rec(r, tab + n2, n1, level + 1, n0,
2714                                         radix, pow_tab);
2715         if (ret)
2716             return ret;
2717         ret = bf_mul(r, r, B, BF_PREC_INF, BF_RNDZ);
2718         if (ret)
2719             return ret;
2720         bf_init(r->ctx, T);
2721         ret = bf_integer_from_radix_rec(T, tab, n2, level + 1, n0,
2722                                         radix, pow_tab);
2723         if (!ret)
2724             ret = bf_add(r, r, T, BF_PREC_INF, BF_RNDZ);
2725         bf_delete(T);
2726     }
2727     return ret;
2728     //    bf_print_str("  r=", r);
2729 }
2730 
2731 /* return 0 if OK != 0 if memory error */
bf_integer_from_radix(bf_t * r,const limb_t * tab,limb_t n,limb_t radix)2732 static int bf_integer_from_radix(bf_t *r, const limb_t *tab,
2733                                  limb_t n, limb_t radix)
2734 {
2735     bf_context_t *s = r->ctx;
2736     int pow_tab_len, i, ret;
2737     limb_t radixl;
2738     bf_t *pow_tab;
2739 
2740     radixl = get_limb_radix(radix);
2741     pow_tab_len = ceil_log2(n) + 2; /* XXX: check */
2742     pow_tab = bf_malloc(s, sizeof(pow_tab[0]) * pow_tab_len);
2743     if (!pow_tab)
2744         return -1;
2745     for(i = 0; i < pow_tab_len; i++)
2746         bf_init(r->ctx, &pow_tab[i]);
2747     ret = bf_integer_from_radix_rec(r, tab, n, 0, n, radixl, pow_tab);
2748     for(i = 0; i < pow_tab_len; i++) {
2749         bf_delete(&pow_tab[i]);
2750     }
2751     bf_free(s, pow_tab);
2752     return ret;
2753 }
2754 
2755 /* compute and round T * radix^expn. */
bf_mul_pow_radix(bf_t * r,const bf_t * T,limb_t radix,slimb_t expn,limb_t prec,bf_flags_t flags)2756 int bf_mul_pow_radix(bf_t *r, const bf_t *T, limb_t radix,
2757                      slimb_t expn, limb_t prec, bf_flags_t flags)
2758 {
2759     int ret, expn_sign, overflow;
2760     slimb_t e, extra_bits, prec1, ziv_extra_bits;
2761     bf_t B_s, *B = &B_s;
2762 
2763     if (T->len == 0) {
2764         return bf_set(r, T);
2765     } else if (expn == 0) {
2766         ret = bf_set(r, T);
2767         ret |= bf_round(r, prec, flags);
2768         return ret;
2769     }
2770 
2771     e = expn;
2772     expn_sign = 0;
2773     if (e < 0) {
2774         e = -e;
2775         expn_sign = 1;
2776     }
2777     bf_init(r->ctx, B);
2778     if (prec == BF_PREC_INF) {
2779         /* infinite precision: only used if the result is known to be exact */
2780         ret = bf_pow_ui_ui(B, radix, e, BF_PREC_INF, BF_RNDN);
2781         if (expn_sign) {
2782             ret |= bf_div(r, T, B, T->len * LIMB_BITS, BF_RNDN);
2783         } else {
2784             ret |= bf_mul(r, T, B, BF_PREC_INF, BF_RNDN);
2785         }
2786     } else {
2787         ziv_extra_bits = 16;
2788         for(;;) {
2789             prec1 = prec + ziv_extra_bits;
2790             /* XXX: correct overflow/underflow handling */
2791             /* XXX: rigorous error analysis needed */
2792             extra_bits = ceil_log2(e) * 2 + 1;
2793             ret = bf_pow_ui_ui(B, radix, e, prec1 + extra_bits, BF_RNDN | BF_FLAG_EXT_EXP);
2794             overflow = !bf_is_finite(B);
2795             /* XXX: if bf_pow_ui_ui returns an exact result, can stop
2796                after the next operation */
2797             if (expn_sign)
2798                 ret |= bf_div(r, T, B, prec1 + extra_bits, BF_RNDN | BF_FLAG_EXT_EXP);
2799             else
2800                 ret |= bf_mul(r, T, B, prec1 + extra_bits, BF_RNDN | BF_FLAG_EXT_EXP);
2801             if (ret & BF_ST_MEM_ERROR)
2802                 break;
2803             if ((ret & BF_ST_INEXACT) &&
2804                 !bf_can_round(r, prec, flags & BF_RND_MASK, prec1) &&
2805                 !overflow) {
2806                 /* and more precision and retry */
2807                 ziv_extra_bits = ziv_extra_bits  + (ziv_extra_bits / 2);
2808             } else {
2809                 /* XXX: need to use __bf_round() to pass the inexact
2810                    flag for the subnormal case */
2811                 ret = bf_round(r, prec, flags) | (ret & BF_ST_INEXACT);
2812                 break;
2813             }
2814         }
2815     }
2816     bf_delete(B);
2817     return ret;
2818 }
2819 
/* Value of the digit character 'c': '0'..'9' give 0..9, 'a'..'z' and
   'A'..'Z' give 10..35. Any other character gives 36, which is not a
   valid digit in any radix up to 36. */
static inline int to_digit(int c)
{
    int value = 36;

    if (c >= '0' && c <= '9')
        value = c - '0';
    else if (c >= 'a' && c <= 'z')
        value = c - 'a' + 10;
    else if (c >= 'A' && c <= 'Z')
        value = c - 'A' + 10;
    return value;
}

2832 /* add a limb at 'pos' and decrement pos. new space is created if
2833    needed. Return 0 if OK, -1 if memory error */
bf_add_limb(bf_t * a,slimb_t * ppos,limb_t v)2834 static int bf_add_limb(bf_t *a, slimb_t *ppos, limb_t v)
2835 {
2836     slimb_t pos;
2837     pos = *ppos;
2838     if (unlikely(pos < 0)) {
2839         limb_t new_size, d, *new_tab;
2840         new_size = bf_max(a->len + 1, a->len * 3 / 2);
2841         new_tab = bf_realloc(a->ctx, a->tab, sizeof(limb_t) * new_size);
2842         if (!new_tab)
2843             return -1;
2844         a->tab = new_tab;
2845         d = new_size - a->len;
2846         memmove(a->tab + d, a->tab, a->len * sizeof(limb_t));
2847         a->len = new_size;
2848         pos += d;
2849     }
2850     a->tab[pos--] = v;
2851     *ppos = pos;
2852     return 0;
2853 }
2854 
/* ASCII lower-casing that does not depend on the C locale: 'A'..'Z' map
   to 'a'..'z', every other value is returned unchanged. */
static int bf_tolower(int c)
{
    return (c >= 'A' && c <= 'Z') ? c + ('a' - 'A') : c;
}

/* Return 1 if 'str' starts with 'val', ignoring ASCII case in 'str' ('val'
   must be lower case), and then store in *ptr (if not NULL) a pointer just
   past the matched prefix. Return 0 otherwise, leaving *ptr untouched. */
static int strcasestart(const char *str, const char *val, const char **ptr)
{
    size_t i;

    for (i = 0; val[i] != '\0'; i++) {
        int c = str[i];

        if (c >= 'A' && c <= 'Z')
            c += 'a' - 'A';
        if (c != val[i])
            return 0;
    }
    if (ptr)
        *ptr = str + i;
    return 1;
}

/* Parse a number from 'str' and store it in 'r'.

   'radix' is the base; 0 means 10 unless a prefix selects another one.
   A "0x"/"0X" prefix selects base 16 when radix is 0 or 16 and
   BF_ATOF_NO_HEX is not set. "0o"/"0b" select base 8/2 when radix is 0
   and BF_ATOF_BIN_OCT is set. A digit must follow a prefix, otherwise
   NaN is returned. "nan" and "inf" (case insensitive) are accepted for
   radix <= 16 unless BF_ATOF_NO_NAN_INF is set.

   If 'is_dec' is TRUE, 'r' is really a bfdec_t and the resolved radix is
   asserted to be 10.

   *pexponent is set to 0, except when BF_ATOF_EXPONENT is set and the
   radix is neither decimal nor a power of two. In that case 'r' receives
   the (signed) integer formed by the digits and the parsed number is
   r * radix^(*pexponent).

   If 'pnext' is not NULL, *pnext is set to the first unparsed character.
   Input without any digit gives NaN with a zero status. Return the
   floating point status (BF_ST_MEM_ERROR on allocation failure). */
static int bf_atof_internal(bf_t *r, slimb_t *pexponent,
                            const char *str, const char **pnext, int radix,
                            limb_t prec, bf_flags_t flags, BOOL is_dec)
{
    const char *p, *p_start;
    int is_neg, radix_bits, exp_is_neg, ret, digits_per_limb, shift;
    limb_t cur_limb;
    slimb_t pos, expn, int_len, digit_count;
    BOOL has_decpt, is_bin_exp;
    bf_t a_s, *a;

    *pexponent = 0;
    p = str;
    if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
        strcasestart(p, "nan", &p)) {
        bf_set_nan(r);
        ret = 0;
        goto done;
    }

    /* optional sign; p_start is the first character after it */
    is_neg = 0;
    if (p[0] == '+') {
        p++;
    } else if (p[0] == '-') {
        is_neg = 1;
        p++;
    }
    p_start = p;

    if (p[0] == '0') {
        if ((p[1] == 'x' || p[1] == 'X') &&
            (radix == 0 || radix == 16) &&
            !(flags & BF_ATOF_NO_HEX)) {
            radix = 16;
            p += 2;
        } else if ((p[1] == 'o' || p[1] == 'O') &&
                   radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
            p += 2;
            radix = 8;
        } else if ((p[1] == 'b' || p[1] == 'B') &&
                   radix == 0 && (flags & BF_ATOF_BIN_OCT)) {
            p += 2;
            radix = 2;
        } else {
            goto no_prefix;
        }
        /* there must be a digit after the prefix */
        if (to_digit((uint8_t)*p) >= radix) {
            bf_set_nan(r);
            ret = 0;
            goto done;
        }
    no_prefix: ;
    } else {
        if (!(flags & BF_ATOF_NO_NAN_INF) && radix <= 16 &&
            strcasestart(p, "inf", &p)) {
            bf_set_inf(r, is_neg);
            ret = 0;
            goto done;
        }
    }

    if (radix == 0)
        radix = 10;
    /* 'a' accumulates the digit limbs. It is 'r' itself, except for
       binary radixes that are not a power of two: there it is a temporary
       that must be freed on every exit path (tested with 'a != r'). */
    if (is_dec) {
        assert(radix == 10);
        radix_bits = 0;
        a = r;
    } else if ((radix & (radix - 1)) != 0) {
        radix_bits = 0; /* base is not a power of two */
        a = &a_s;
        bf_init(r->ctx, a);
    } else {
        radix_bits = ceil_log2(radix);
        a = r;
    }

    /* skip leading zeros */
    /* XXX: could also skip zeros after the decimal point */
    while (*p == '0')
        p++;

    if (radix_bits) {
        shift = digits_per_limb = LIMB_BITS;
    } else {
        shift = digits_per_limb = digits_per_limb_table[radix - 2];
    }
    cur_limb = 0;
    /* bf_add_limb() writes a->tab[pos] and moves 'pos' downwards, growing
       the array when needed, so at least one limb must be available */
    if (bf_resize(a, 1))
        goto mem_error;
    pos = 0;
    has_decpt = FALSE;
    int_len = digit_count = 0;
    for(;;) {
        limb_t c;
        /* accept one decimal point; a '.' right at p_start must be
           followed by a digit */
        if (*p == '.' && (p > p_start || to_digit(p[1]) < radix)) {
            if (has_decpt)
                break;
            has_decpt = TRUE;
            int_len = digit_count;
            p++;
        }
        c = to_digit(*p);
        if (c >= radix)
            break;
        digit_count++;
        p++;
        if (radix_bits) {
            /* pack radix_bits bits per digit, most significant first */
            shift -= radix_bits;
            if (shift <= 0) {
                cur_limb |= c >> (-shift);
                if (bf_add_limb(a, &pos, cur_limb))
                    goto mem_error;
                if (shift < 0)
                    cur_limb = c << (LIMB_BITS + shift);
                else
                    cur_limb = 0;
                shift += LIMB_BITS;
            } else {
                cur_limb |= c << shift;
            }
        } else {
            /* accumulate digits_per_limb digits per limb */
            cur_limb = cur_limb * radix + c;
            shift--;
            if (shift == 0) {
                if (bf_add_limb(a, &pos, cur_limb))
                    goto mem_error;
                shift = digits_per_limb;
                cur_limb = 0;
            }
        }
    }
    if (!has_decpt)
        int_len = digit_count;

    /* add the last limb and pad with zeros */
    if (shift != digits_per_limb) {
        if (radix_bits == 0) {
            while (shift != 0) {
                cur_limb *= radix;
                shift--;
            }
        }
        if (bf_add_limb(a, &pos, cur_limb)) {
        mem_error:
            ret = BF_ST_MEM_ERROR;
            /* never delete 'r' here: bf_set_nan() below releases its
               limbs (deleting it first would free them twice) */
            if (a != r)
                bf_delete(a);
            bf_set_nan(r);
            goto done;
        }
    }

    /* reset the next limbs to zero (we prefer to reallocate in the
       renormalization) */
    memset(a->tab, 0, (pos + 1) * sizeof(limb_t));

    if (p == p_start) {
        /* no digit at all */
        ret = 0;
        if (a != r)
            bf_delete(a);
        bf_set_nan(r);
        goto done;
    }

    /* parse the exponent, if any */
    expn = 0;
    is_bin_exp = FALSE;
    if (((radix == 10 && (*p == 'e' || *p == 'E')) ||
         (radix != 10 && (*p == '@' ||
                          (radix_bits && (*p == 'p' || *p == 'P'))))) &&
        p > p_start) {
        is_bin_exp = (*p == 'p' || *p == 'P');
        p++;
        exp_is_neg = 0;
        if (*p == '+') {
            p++;
        } else if (*p == '-') {
            exp_is_neg = 1;
            p++;
        }
        for(;;) {
            int c;
            c = to_digit(*p);
            if (c >= 10)
                break;
            if (unlikely(expn > ((BF_RAW_EXP_MAX - 2 - 9) / 10))) {
                /* exponent overflow: release the temporary mantissa */
                if (a != r)
                    bf_delete(a);
                if (exp_is_neg) {
                    bf_set_zero(r, is_neg);
                    ret = BF_ST_UNDERFLOW | BF_ST_INEXACT;
                } else {
                    bf_set_inf(r, is_neg);
                    ret = BF_ST_OVERFLOW | BF_ST_INEXACT;
                }
                goto done;
            }
            p++;
            expn = expn * 10 + c;
        }
        if (exp_is_neg)
            expn = -expn;
    }
    if (is_dec) {
        a->expn = expn + int_len;
        a->sign = is_neg;
        ret = bfdec_normalize_and_round((bfdec_t *)a, prec, flags);
    } else if (radix_bits) {
        /* XXX: may overflow */
        if (!is_bin_exp)
            expn *= radix_bits;
        a->expn = expn + (int_len * radix_bits);
        a->sign = is_neg;
        ret = bf_normalize_and_round(a, prec, flags);
    } else {
        limb_t l;
        pos++;
        l = a->len - pos; /* number of limbs */
        if (l == 0) {
            bf_set_zero(r, is_neg);
            ret = 0;
        } else {
            bf_t T_s, *T = &T_s;

            expn -= l * digits_per_limb - int_len;
            bf_init(r->ctx, T);
            if (bf_integer_from_radix(T, a->tab + pos, l, radix)) {
                bf_set_nan(r);
                ret = BF_ST_MEM_ERROR;
            } else {
                T->sign = is_neg;
                if (flags & BF_ATOF_EXPONENT) {
                    /* return the exponent */
                    *pexponent = expn;
                    ret = bf_set(r, T);
                } else {
                    ret = bf_mul_pow_radix(r, T, radix, expn, prec, flags);
                }
            }
            bf_delete(T);
        }
        bf_delete(a);
    }
 done:
    if (pnext)
        *pnext = p;
    return ret;
}
3128 
3129 /*
   Return the floating point status. The parsed number is stored in
   'r' and the exponent in '*pexponent'.

   If (flags & BF_ATOF_EXPONENT) and if the radix is not a power of
   two, the parsed number is equal to r *
   radix^(*pexponent). Otherwise *pexponent = 0.
3136 */
int bf_atof2(bf_t *r, slimb_t *pexponent,
             const char *str, const char **pnext, int radix,
             limb_t prec, bf_flags_t flags)
{
    /* binary floating point result (not a bfdec_t) */
    const BOOL is_dec = FALSE;

    return bf_atof_internal(r, pexponent, str, pnext, radix, prec, flags,
                            is_dec);
}
3144 
/* Same as bf_atof2() but the returned radix exponent is discarded. */
int bf_atof(bf_t *r, const char *str, const char **pnext, int radix,
            limb_t prec, bf_flags_t flags)
{
    slimb_t ignored_exponent;

    return bf_atof2(r, &ignored_exponent, str, pnext, radix, prec, flags);
}
3151 
3152 /* base conversion to radix */
3153 
/* Limb radix used for base 10 conversions: 10^19 when LIMB_BITS is 64,
   10^9 otherwise (the largest power of ten that fits in a 64 or 32 bit
   limb respectively). Being a macro, divisions by it use a constant
   divisor. */
#if LIMB_BITS != 64
#define RADIXL_10 UINT64_C(1000000000)
#else
#define RADIXL_10 UINT64_C(10000000000000000000)
#endif
3159 
3160 static const uint32_t inv_log2_radix[BF_RADIX_MAX - 1][LIMB_BITS / 32 + 1] = {
3161 #if LIMB_BITS == 32
3162 { 0x80000000, 0x00000000,},
3163 { 0x50c24e60, 0xd4d4f4a7,},
3164 { 0x40000000, 0x00000000,},
3165 { 0x372068d2, 0x0a1ee5ca,},
3166 { 0x3184648d, 0xb8153e7a,},
3167 { 0x2d983275, 0x9d5369c4,},
3168 { 0x2aaaaaaa, 0xaaaaaaab,},
3169 { 0x28612730, 0x6a6a7a54,},
3170 { 0x268826a1, 0x3ef3fde6,},
3171 { 0x25001383, 0xbac8a744,},
3172 { 0x23b46706, 0x82c0c709,},
3173 { 0x229729f1, 0xb2c83ded,},
3174 { 0x219e7ffd, 0xa5ad572b,},
3175 { 0x20c33b88, 0xda7c29ab,},
3176 { 0x20000000, 0x00000000,},
3177 { 0x1f50b57e, 0xac5884b3,},
3178 { 0x1eb22cc6, 0x8aa6e26f,},
3179 { 0x1e21e118, 0x0c5daab2,},
3180 { 0x1d9dcd21, 0x439834e4,},
3181 { 0x1d244c78, 0x367a0d65,},
3182 { 0x1cb40589, 0xac173e0c,},
3183 { 0x1c4bd95b, 0xa8d72b0d,},
3184 { 0x1bead768, 0x98f8ce4c,},
3185 { 0x1b903469, 0x050f72e5,},
3186 { 0x1b3b433f, 0x2eb06f15,},
3187 { 0x1aeb6f75, 0x9c46fc38,},
3188 { 0x1aa038eb, 0x0e3bfd17,},
3189 { 0x1a593062, 0xb38d8c56,},
3190 { 0x1a15f4c3, 0x2b95a2e6,},
3191 { 0x19d630dc, 0xcc7ddef9,},
3192 { 0x19999999, 0x9999999a,},
3193 { 0x195fec80, 0x8a609431,},
3194 { 0x1928ee7b, 0x0b4f22f9,},
3195 { 0x18f46acf, 0x8c06e318,},
3196 { 0x18c23246, 0xdc0a9f3d,},
3197 #else
3198 { 0x80000000, 0x00000000, 0x00000000,},
3199 { 0x50c24e60, 0xd4d4f4a7, 0x021f57bc,},
3200 { 0x40000000, 0x00000000, 0x00000000,},
3201 { 0x372068d2, 0x0a1ee5ca, 0x19ea911b,},
3202 { 0x3184648d, 0xb8153e7a, 0x7fc2d2e1,},
3203 { 0x2d983275, 0x9d5369c4, 0x4dec1661,},
3204 { 0x2aaaaaaa, 0xaaaaaaaa, 0xaaaaaaab,},
3205 { 0x28612730, 0x6a6a7a53, 0x810fabde,},
3206 { 0x268826a1, 0x3ef3fde6, 0x23e2566b,},
3207 { 0x25001383, 0xbac8a744, 0x385a3349,},
3208 { 0x23b46706, 0x82c0c709, 0x3f891718,},
3209 { 0x229729f1, 0xb2c83ded, 0x15fba800,},
3210 { 0x219e7ffd, 0xa5ad572a, 0xe169744b,},
3211 { 0x20c33b88, 0xda7c29aa, 0x9bddee52,},
3212 { 0x20000000, 0x00000000, 0x00000000,},
3213 { 0x1f50b57e, 0xac5884b3, 0x70e28eee,},
3214 { 0x1eb22cc6, 0x8aa6e26f, 0x06d1a2a2,},
3215 { 0x1e21e118, 0x0c5daab1, 0x81b4f4bf,},
3216 { 0x1d9dcd21, 0x439834e3, 0x81667575,},
3217 { 0x1d244c78, 0x367a0d64, 0xc8204d6d,},
3218 { 0x1cb40589, 0xac173e0c, 0x3b7b16ba,},
3219 { 0x1c4bd95b, 0xa8d72b0d, 0x5879f25a,},
3220 { 0x1bead768, 0x98f8ce4c, 0x66cc2858,},
3221 { 0x1b903469, 0x050f72e5, 0x0cf5488e,},
3222 { 0x1b3b433f, 0x2eb06f14, 0x8c89719c,},
3223 { 0x1aeb6f75, 0x9c46fc37, 0xab5fc7e9,},
3224 { 0x1aa038eb, 0x0e3bfd17, 0x1bd62080,},
3225 { 0x1a593062, 0xb38d8c56, 0x7998ab45,},
3226 { 0x1a15f4c3, 0x2b95a2e6, 0x46aed6a0,},
3227 { 0x19d630dc, 0xcc7ddef9, 0x5aadd61b,},
3228 { 0x19999999, 0x99999999, 0x9999999a,},
3229 { 0x195fec80, 0x8a609430, 0xe1106014,},
3230 { 0x1928ee7b, 0x0b4f22f9, 0x5f69791d,},
3231 { 0x18f46acf, 0x8c06e318, 0x4d2aeb2c,},
3232 { 0x18c23246, 0xdc0a9f3d, 0x3fe16970,},
3233 #endif
3234 };
3235 
3236 static const limb_t log2_radix[BF_RADIX_MAX - 1] = {
3237 #if LIMB_BITS == 32
3238 0x20000000,
3239 0x32b80347,
3240 0x40000000,
3241 0x4a4d3c26,
3242 0x52b80347,
3243 0x59d5d9fd,
3244 0x60000000,
3245 0x6570068e,
3246 0x6a4d3c26,
3247 0x6eb3a9f0,
3248 0x72b80347,
3249 0x766a008e,
3250 0x79d5d9fd,
3251 0x7d053f6d,
3252 0x80000000,
3253 0x82cc7edf,
3254 0x8570068e,
3255 0x87ef05ae,
3256 0x8a4d3c26,
3257 0x8c8ddd45,
3258 0x8eb3a9f0,
3259 0x90c10501,
3260 0x92b80347,
3261 0x949a784c,
3262 0x966a008e,
3263 0x982809d6,
3264 0x99d5d9fd,
3265 0x9b74948f,
3266 0x9d053f6d,
3267 0x9e88c6b3,
3268 0xa0000000,
3269 0xa16bad37,
3270 0xa2cc7edf,
3271 0xa4231623,
3272 0xa570068e,
3273 #else
3274 0x2000000000000000,
3275 0x32b803473f7ad0f4,
3276 0x4000000000000000,
3277 0x4a4d3c25e68dc57f,
3278 0x52b803473f7ad0f4,
3279 0x59d5d9fd5010b366,
3280 0x6000000000000000,
3281 0x6570068e7ef5a1e8,
3282 0x6a4d3c25e68dc57f,
3283 0x6eb3a9f01975077f,
3284 0x72b803473f7ad0f4,
3285 0x766a008e4788cbcd,
3286 0x79d5d9fd5010b366,
3287 0x7d053f6d26089673,
3288 0x8000000000000000,
3289 0x82cc7edf592262d0,
3290 0x8570068e7ef5a1e8,
3291 0x87ef05ae409a0289,
3292 0x8a4d3c25e68dc57f,
3293 0x8c8ddd448f8b845a,
3294 0x8eb3a9f01975077f,
3295 0x90c10500d63aa659,
3296 0x92b803473f7ad0f4,
3297 0x949a784bcd1b8afe,
3298 0x966a008e4788cbcd,
3299 0x982809d5be7072dc,
3300 0x99d5d9fd5010b366,
3301 0x9b74948f5532da4b,
3302 0x9d053f6d26089673,
3303 0x9e88c6b3626a72aa,
3304 0xa000000000000000,
3305 0xa16bad3758efd873,
3306 0xa2cc7edf592262d0,
3307 0xa4231623369e78e6,
3308 0xa570068e7ef5a1e8,
3309 #endif
3310 };
3311 
3312 /* compute floor(a*b) or ceil(a*b) with b = log2(radix) or
3313    b=1/log2(radix). For is_inv = 0, strict accuracy is not guaranteed
3314    when radix is not a power of two. */
bf_mul_log2_radix(slimb_t a1,unsigned int radix,int is_inv,int is_ceil1)3315 slimb_t bf_mul_log2_radix(slimb_t a1, unsigned int radix, int is_inv,
3316                           int is_ceil1)
3317 {
3318     int is_neg;
3319     limb_t a;
3320     BOOL is_ceil;
3321 
3322     is_ceil = is_ceil1;
3323     a = a1;
3324     if (a1 < 0) {
3325         a = -a;
3326         is_neg = 1;
3327     } else {
3328         is_neg = 0;
3329     }
3330     is_ceil ^= is_neg;
3331     if ((radix & (radix - 1)) == 0) {
3332         int radix_bits;
3333         /* radix is a power of two */
3334         radix_bits = ceil_log2(radix);
3335         if (is_inv) {
3336             if (is_ceil)
3337                 a += radix_bits - 1;
3338             a = a / radix_bits;
3339         } else {
3340             a = a * radix_bits;
3341         }
3342     } else {
3343         const uint32_t *tab;
3344         limb_t b0, b1;
3345         dlimb_t t;
3346 
3347         if (is_inv) {
3348             tab = inv_log2_radix[radix - 2];
3349 #if LIMB_BITS == 32
3350             b1 = tab[0];
3351             b0 = tab[1];
3352 #else
3353             b1 = ((limb_t)tab[0] << 32) | tab[1];
3354             b0 = (limb_t)tab[2] << 32;
3355 #endif
3356             t = (dlimb_t)b0 * (dlimb_t)a;
3357             t = (dlimb_t)b1 * (dlimb_t)a + (t >> LIMB_BITS);
3358             a = t >> (LIMB_BITS - 1);
3359         } else {
3360             b0 = log2_radix[radix - 2];
3361             t = (dlimb_t)b0 * (dlimb_t)a;
3362             a = t >> (LIMB_BITS - 3);
3363         }
3364         /* a = floor(result) and 'result' cannot be an integer */
3365         a += is_ceil;
3366     }
3367     if (is_neg)
3368         a = -a;
3369     return a;
3370 }
3371 
3372 /* 'n' is the number of output limbs */
bf_integer_to_radix_rec(bf_t * pow_tab,limb_t * out,const bf_t * a,limb_t n,int level,limb_t n0,limb_t radixl,unsigned int radixl_bits)3373 static int bf_integer_to_radix_rec(bf_t *pow_tab,
3374                                    limb_t *out, const bf_t *a, limb_t n,
3375                                    int level, limb_t n0, limb_t radixl,
3376                                    unsigned int radixl_bits)
3377 {
3378     limb_t n1, n2, q_prec;
3379     int ret;
3380 
3381     assert(n >= 1);
3382     if (n == 1) {
3383         out[0] = get_bits(a->tab, a->len, a->len * LIMB_BITS - a->expn);
3384     } else if (n == 2) {
3385         dlimb_t t;
3386         slimb_t pos;
3387         pos = a->len * LIMB_BITS - a->expn;
3388         t = ((dlimb_t)get_bits(a->tab, a->len, pos + LIMB_BITS) << LIMB_BITS) |
3389             get_bits(a->tab, a->len, pos);
3390         if (likely(radixl == RADIXL_10)) {
3391             /* use division by a constant when possible */
3392             out[0] = t % RADIXL_10;
3393             out[1] = t / RADIXL_10;
3394         } else {
3395             out[0] = t % radixl;
3396             out[1] = t / radixl;
3397         }
3398     } else {
3399         bf_t Q, R, *B, *B_inv;
3400         int q_add;
3401         bf_init(a->ctx, &Q);
3402         bf_init(a->ctx, &R);
3403         n2 = (((n0 * 2) >> (level + 1)) + 1) / 2;
3404         n1 = n - n2;
3405         B = &pow_tab[2 * level];
3406         B_inv = &pow_tab[2 * level + 1];
3407         ret = 0;
3408         if (B->len == 0) {
3409             /* compute BASE^n2 */
3410             ret |= bf_pow_ui_ui(B, radixl, n2, BF_PREC_INF, BF_RNDZ);
3411             /* we use enough bits for the maximum possible 'n1' value,
3412                i.e. n2 + 1 */
3413             ret |= bf_set_ui(&R, 1);
3414             ret |= bf_div(B_inv, &R, B, (n2 + 1) * radixl_bits + 2, BF_RNDN);
3415         }
3416         //        printf("%d: n1=% " PRId64 " n2=%" PRId64 "\n", level, n1, n2);
3417         q_prec = n1 * radixl_bits;
3418         ret |= bf_mul(&Q, a, B_inv, q_prec, BF_RNDN);
3419         ret |= bf_rint(&Q, BF_RNDZ);
3420 
3421         ret |= bf_mul(&R, &Q, B, BF_PREC_INF, BF_RNDZ);
3422         ret |= bf_sub(&R, a, &R, BF_PREC_INF, BF_RNDZ);
3423 
3424         if (ret & BF_ST_MEM_ERROR)
3425             goto fail;
3426         /* adjust if necessary */
3427         q_add = 0;
3428         while (R.sign && R.len != 0) {
3429             if (bf_add(&R, &R, B, BF_PREC_INF, BF_RNDZ))
3430                 goto fail;
3431             q_add--;
3432         }
3433         while (bf_cmpu(&R, B) >= 0) {
3434             if (bf_sub(&R, &R, B, BF_PREC_INF, BF_RNDZ))
3435                 goto fail;
3436             q_add++;
3437         }
3438         if (q_add != 0) {
3439             if (bf_add_si(&Q, &Q, q_add, BF_PREC_INF, BF_RNDZ))
3440                 goto fail;
3441         }
3442         if (bf_integer_to_radix_rec(pow_tab, out + n2, &Q, n1, level + 1, n0,
3443                                     radixl, radixl_bits))
3444             goto fail;
3445         if (bf_integer_to_radix_rec(pow_tab, out, &R, n2, level + 1, n0,
3446                                     radixl, radixl_bits)) {
3447         fail:
3448             bf_delete(&Q);
3449             bf_delete(&R);
3450             return -1;
3451         }
3452         bf_delete(&Q);
3453         bf_delete(&R);
3454     }
3455     return 0;
3456 }
3457 
3458 /* return 0 if OK != 0 if memory error */
bf_integer_to_radix(bf_t * r,const bf_t * a,limb_t radixl)3459 static int bf_integer_to_radix(bf_t *r, const bf_t *a, limb_t radixl)
3460 {
3461     bf_context_t *s = r->ctx;
3462     limb_t r_len;
3463     bf_t *pow_tab;
3464     int i, pow_tab_len, ret;
3465 
3466     r_len = r->len;
3467     pow_tab_len = (ceil_log2(r_len) + 2) * 2; /* XXX: check */
3468     pow_tab = bf_malloc(s, sizeof(pow_tab[0]) * pow_tab_len);
3469     if (!pow_tab)
3470         return -1;
3471     for(i = 0; i < pow_tab_len; i++)
3472         bf_init(r->ctx, &pow_tab[i]);
3473 
3474     ret = bf_integer_to_radix_rec(pow_tab, r->tab, a, r_len, 0, r_len, radixl,
3475                                   ceil_log2(radixl));
3476 
3477     for(i = 0; i < pow_tab_len; i++) {
3478         bf_delete(&pow_tab[i]);
3479     }
3480     bf_free(s, pow_tab);
3481     return ret;
3482 }
3483 
3484 /* a must be >= 0. 'P' is the wanted number of digits in radix
3485    'radix'. 'r' is the mantissa represented as an integer. *pE
3486    contains the exponent. Return != 0 if memory error. */
bf_convert_to_radix(bf_t * r,slimb_t * pE,const bf_t * a,int radix,limb_t P,bf_rnd_t rnd_mode,BOOL is_fixed_exponent)3487 static int bf_convert_to_radix(bf_t *r, slimb_t *pE,
3488                                const bf_t *a, int radix,
3489                                limb_t P, bf_rnd_t rnd_mode,
3490                                BOOL is_fixed_exponent)
3491 {
3492     slimb_t E, e, prec, extra_bits, ziv_extra_bits, prec0;
3493     bf_t B_s, *B = &B_s;
3494     int e_sign, ret, res;
3495 
3496     if (a->len == 0) {
3497         /* zero case */
3498         *pE = 0;
3499         return bf_set(r, a);
3500     }
3501 
3502     if (is_fixed_exponent) {
3503         E = *pE;
3504     } else {
3505         /* compute the new exponent */
3506         E = 1 + bf_mul_log2_radix(a->expn - 1, radix, TRUE, FALSE);
3507     }
3508     //    bf_print_str("a", a);
3509     //    printf("E=%ld P=%ld radix=%d\n", E, P, radix);
3510 
3511     for(;;) {
3512         e = P - E;
3513         e_sign = 0;
3514         if (e < 0) {
3515             e = -e;
3516             e_sign = 1;
3517         }
3518         /* Note: precision for log2(radix) is not critical here */
3519         prec0 = bf_mul_log2_radix(P, radix, FALSE, TRUE);
3520         ziv_extra_bits = 16;
3521         for(;;) {
3522             prec = prec0 + ziv_extra_bits;
3523             /* XXX: rigorous error analysis needed */
3524             extra_bits = ceil_log2(e) * 2 + 1;
3525             ret = bf_pow_ui_ui(r, radix, e, prec + extra_bits,
3526                                BF_RNDN | BF_FLAG_EXT_EXP);
3527             if (!e_sign)
3528                 ret |= bf_mul(r, r, a, prec + extra_bits,
3529                               BF_RNDN | BF_FLAG_EXT_EXP);
3530             else
3531                 ret |= bf_div(r, a, r, prec + extra_bits,
3532                               BF_RNDN | BF_FLAG_EXT_EXP);
3533             if (ret & BF_ST_MEM_ERROR)
3534                 return BF_ST_MEM_ERROR;
3535             /* if the result is not exact, check that it can be safely
3536                rounded to an integer */
3537             if ((ret & BF_ST_INEXACT) &&
3538                 !bf_can_round(r, r->expn, rnd_mode, prec)) {
3539                 /* and more precision and retry */
3540                 ziv_extra_bits = ziv_extra_bits  + (ziv_extra_bits / 2);
3541                 continue;
3542             } else {
3543                 ret = bf_rint(r, rnd_mode);
3544                 if (ret & BF_ST_MEM_ERROR)
3545                     return BF_ST_MEM_ERROR;
3546                 break;
3547             }
3548         }
3549         if (is_fixed_exponent)
3550             break;
3551         /* check that the result is < B^P */
3552         /* XXX: do a fast approximate test first ? */
3553         bf_init(r->ctx, B);
3554         ret = bf_pow_ui_ui(B, radix, P, BF_PREC_INF, BF_RNDZ);
3555         if (ret) {
3556             bf_delete(B);
3557             return ret;
3558         }
3559         res = bf_cmpu(r, B);
3560         bf_delete(B);
3561         if (res < 0)
3562             break;
3563         /* try a larger exponent */
3564         E++;
3565     }
3566     *pE = E;
3567     return 0;
3568 }
3569 
/* Write the 'len' least significant base 'radix' digits of 'n' into
   buf[0..len-1], most significant digit first. Digits >= 10 use
   lowercase letters. No terminating NUL is written. */
static void limb_to_a(char *buf, limb_t n, unsigned int radix, int len)
{
    int k = len;

    if (radix == 10) {
        /* specific case with constant divisor */
        while (k-- > 0) {
            buf[k] = '0' + (int)(n % 10);
            n /= 10;
        }
    } else {
        while (k-- > 0) {
            int d = n % radix;
            n /= radix;
            buf[k] = d < 10 ? '0' + d : 'a' + (d - 10);
        }
    }
}
3593 
3594 /* for power of 2 radixes */
limb_to_a2(char * buf,limb_t n,unsigned int radix_bits,int len)3595 static void limb_to_a2(char *buf, limb_t n, unsigned int radix_bits, int len)
3596 {
3597     int digit, i;
3598     unsigned int mask;
3599 
3600     mask = (1 << radix_bits) - 1;
3601     for(i = len - 1; i >= 0; i--) {
3602         digit = n & mask;
3603         n >>= radix_bits;
3604         if (digit < 10)
3605             digit += '0';
3606         else
3607             digit += 'a' - 10;
3608         buf[i] = digit;
3609     }
3610 }
3611 
3612 /* 'a' must be an integer if the is_dec = FALSE or if the radix is not
3613    a power of two. A dot is added before the 'dot_pos' digit. dot_pos
3614    = n_digits does not display the dot. 0 <= dot_pos <=
3615    n_digits. n_digits >= 1. */
output_digits(DynBuf * s,const bf_t * a1,int radix,limb_t n_digits,limb_t dot_pos,BOOL is_dec)3616 static void output_digits(DynBuf *s, const bf_t *a1, int radix, limb_t n_digits,
3617                           limb_t dot_pos, BOOL is_dec)
3618 {
3619     limb_t i, v, l;
3620     slimb_t pos, pos_incr;
3621     int digits_per_limb, buf_pos, radix_bits, first_buf_pos;
3622     char buf[65];
3623     bf_t a_s, *a;
3624 
3625     if (is_dec) {
3626         digits_per_limb = LIMB_DIGITS;
3627         a = (bf_t *)a1;
3628         radix_bits = 0;
3629         pos = a->len;
3630         pos_incr = 1;
3631         first_buf_pos = 0;
3632     } else if ((radix & (radix - 1)) == 0) {
3633         a = (bf_t *)a1;
3634         radix_bits = ceil_log2(radix);
3635         digits_per_limb = LIMB_BITS / radix_bits;
3636         pos_incr = digits_per_limb * radix_bits;
3637         /* digits are aligned relative to the radix point */
3638         pos = a->len * LIMB_BITS + smod(-a->expn, radix_bits);
3639         first_buf_pos = 0;
3640     } else {
3641         limb_t n, radixl;
3642 
3643         digits_per_limb = digits_per_limb_table[radix - 2];
3644         radixl = get_limb_radix(radix);
3645         a = &a_s;
3646         bf_init(a1->ctx, a);
3647         n = (n_digits + digits_per_limb - 1) / digits_per_limb;
3648         if (bf_resize(a, n)) {
3649             dbuf_set_error(s);
3650             goto done;
3651         }
3652         if (bf_integer_to_radix(a, a1, radixl)) {
3653             dbuf_set_error(s);
3654             goto done;
3655         }
3656         radix_bits = 0;
3657         pos = n;
3658         pos_incr = 1;
3659         first_buf_pos = pos * digits_per_limb - n_digits;
3660     }
3661     buf_pos = digits_per_limb;
3662     i = 0;
3663     while (i < n_digits) {
3664         if (buf_pos == digits_per_limb) {
3665             pos -= pos_incr;
3666             if (radix_bits == 0) {
3667                 v = get_limbz(a, pos);
3668                 limb_to_a(buf, v, radix, digits_per_limb);
3669             } else {
3670                 v = get_bits(a->tab, a->len, pos);
3671                 limb_to_a2(buf, v, radix_bits, digits_per_limb);
3672             }
3673             buf_pos = first_buf_pos;
3674             first_buf_pos = 0;
3675         }
3676         if (i < dot_pos) {
3677             l = dot_pos;
3678         } else {
3679             if (i == dot_pos)
3680                 dbuf_putc(s, '.');
3681             l = n_digits;
3682         }
3683         l = bf_min(digits_per_limb - buf_pos, l - i);
3684         dbuf_put(s, (uint8_t *)(buf + buf_pos), l);
3685         buf_pos += l;
3686         i += l;
3687     }
3688  done:
3689     if (a != a1)
3690         bf_delete(a);
3691 }
3692 
bf_dbuf_realloc(void * opaque,void * ptr,size_t size)3693 static void *bf_dbuf_realloc(void *opaque, void *ptr, size_t size)
3694 {
3695     bf_context_t *s = opaque;
3696     return bf_realloc(s, ptr, size);
3697 }
3698 
3699 /* return the length in bytes. A trailing '\0' is added */
bf_ftoa_internal(size_t * plen,const bf_t * a2,int radix,limb_t prec,bf_flags_t flags,BOOL is_dec)3700 static char *bf_ftoa_internal(size_t *plen, const bf_t *a2, int radix,
3701                               limb_t prec, bf_flags_t flags, BOOL is_dec)
3702 {
3703     bf_context_t *ctx = a2->ctx;
3704     DynBuf s_s, *s = &s_s;
3705     int radix_bits;
3706 
3707     //    bf_print_str("ftoa", a2);
3708     //    printf("radix=%d\n", radix);
3709     dbuf_init2(s, ctx, bf_dbuf_realloc);
3710     if (a2->expn == BF_EXP_NAN) {
3711         dbuf_putstr(s, "NaN");
3712     } else {
3713         if (a2->sign)
3714             dbuf_putc(s, '-');
3715         if (a2->expn == BF_EXP_INF) {
3716             if (flags & BF_FTOA_JS_QUIRKS)
3717                 dbuf_putstr(s, "Infinity");
3718             else
3719                 dbuf_putstr(s, "Inf");
3720         } else {
3721             int fmt, ret;
3722             slimb_t n_digits, n, i, n_max, n1;
3723             bf_t a1_s, *a1 = &a1_s;
3724 
3725             if ((radix & (radix - 1)) != 0)
3726                 radix_bits = 0;
3727             else
3728                 radix_bits = ceil_log2(radix);
3729 
3730             fmt = flags & BF_FTOA_FORMAT_MASK;
3731             bf_init(ctx, a1);
3732             if (fmt == BF_FTOA_FORMAT_FRAC) {
3733                 if (is_dec || radix_bits != 0) {
3734                     if (bf_set(a1, a2))
3735                         goto fail1;
3736 #ifdef USE_BF_DEC
3737                     if (is_dec) {
3738                         if (bfdec_round((bfdec_t *)a1, prec, (flags & BF_RND_MASK) | BF_FLAG_RADPNT_PREC) & BF_ST_MEM_ERROR)
3739                             goto fail1;
3740                         n = a1->expn;
3741                     } else
3742 #endif
3743                     {
3744                         if (bf_round(a1, prec * radix_bits, (flags & BF_RND_MASK) | BF_FLAG_RADPNT_PREC) & BF_ST_MEM_ERROR)
3745                             goto fail1;
3746                         n = ceil_div(a1->expn, radix_bits);
3747                     }
3748                     if (flags & BF_FTOA_ADD_PREFIX) {
3749                         if (radix == 16)
3750                             dbuf_putstr(s, "0x");
3751                         else if (radix == 8)
3752                             dbuf_putstr(s, "0o");
3753                         else if (radix == 2)
3754                             dbuf_putstr(s, "0b");
3755                     }
3756                     if (a1->expn == BF_EXP_ZERO) {
3757                         dbuf_putstr(s, "0");
3758                         if (prec > 0) {
3759                             dbuf_putstr(s, ".");
3760                             for(i = 0; i < prec; i++) {
3761                                 dbuf_putc(s, '0');
3762                             }
3763                         }
3764                     } else {
3765                         n_digits = prec + n;
3766                         if (n <= 0) {
3767                             /* 0.x */
3768                             dbuf_putstr(s, "0.");
3769                             for(i = 0; i < -n; i++) {
3770                                 dbuf_putc(s, '0');
3771                             }
3772                             if (n_digits > 0) {
3773                                 output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3774                             }
3775                         } else {
3776                             output_digits(s, a1, radix, n_digits, n, is_dec);
3777                         }
3778                     }
3779                 } else {
3780                     size_t pos, start;
3781                     bf_t a_s, *a = &a_s;
3782 
3783                     /* make a positive number */
3784                     a->tab = a2->tab;
3785                     a->len = a2->len;
3786                     a->expn = a2->expn;
3787                     a->sign = 0;
3788 
3789                     /* one more digit for the rounding */
3790                     n = 1 + bf_mul_log2_radix(bf_max(a->expn, 0), radix, TRUE, TRUE);
3791                     n_digits = n + prec;
3792                     n1 = n;
3793                     if (bf_convert_to_radix(a1, &n1, a, radix, n_digits,
3794                                             flags & BF_RND_MASK, TRUE))
3795                         goto fail1;
3796                     start = s->size;
3797                     output_digits(s, a1, radix, n_digits, n, is_dec);
3798                     /* remove leading zeros because we allocated one more digit */
3799                     pos = start;
3800                     while ((pos + 1) < s->size && s->buf[pos] == '0' &&
3801                            s->buf[pos + 1] != '.')
3802                         pos++;
3803                     if (pos > start) {
3804                         memmove(s->buf + start, s->buf + pos, s->size - pos);
3805                         s->size -= (pos - start);
3806                     }
3807                 }
3808             } else {
3809 #ifdef USE_BF_DEC
3810                 if (is_dec) {
3811                     if (bf_set(a1, a2))
3812                         goto fail1;
3813                     if (fmt == BF_FTOA_FORMAT_FIXED) {
3814                         n_digits = prec;
3815                         n_max = n_digits;
3816                         if (bfdec_round((bfdec_t *)a1, prec, (flags & BF_RND_MASK)) & BF_ST_MEM_ERROR)
3817                             goto fail1;
3818                     } else {
3819                         /* prec is ignored */
3820                         prec = n_digits = a1->len * LIMB_DIGITS;
3821                         /* remove the trailing zero digits */
3822                         while (n_digits > 1 &&
3823                                get_digit(a1->tab, a1->len, prec - n_digits) == 0) {
3824                             n_digits--;
3825                         }
3826                         n_max = n_digits + 4;
3827                     }
3828                     n = a1->expn;
3829                 } else
3830 #endif
3831                 if (radix_bits != 0) {
3832                     if (bf_set(a1, a2))
3833                         goto fail1;
3834                     if (fmt == BF_FTOA_FORMAT_FIXED) {
3835                         slimb_t prec_bits;
3836                         n_digits = prec;
3837                         n_max = n_digits;
3838                         /* align to the radix point */
3839                         prec_bits = prec * radix_bits -
3840                             smod(-a1->expn, radix_bits);
3841                         if (bf_round(a1, prec_bits,
3842                                      (flags & BF_RND_MASK)) & BF_ST_MEM_ERROR)
3843                             goto fail1;
3844                     } else {
3845                         limb_t digit_mask;
3846                         slimb_t pos;
3847                         /* position of the digit before the most
3848                            significant digit in bits */
3849                         pos = a1->len * LIMB_BITS +
3850                             smod(-a1->expn, radix_bits);
3851                         n_digits = ceil_div(pos, radix_bits);
3852                         /* remove the trailing zero digits */
3853                         digit_mask = ((limb_t)1 << radix_bits) - 1;
3854                         while (n_digits > 1 &&
3855                                (get_bits(a1->tab, a1->len, pos - n_digits * radix_bits) & digit_mask) == 0) {
3856                             n_digits--;
3857                         }
3858                         n_max = n_digits + 4;
3859                     }
3860                     n = ceil_div(a1->expn, radix_bits);
3861                 } else {
3862                     bf_t a_s, *a = &a_s;
3863 
3864                     /* make a positive number */
3865                     a->tab = a2->tab;
3866                     a->len = a2->len;
3867                     a->expn = a2->expn;
3868                     a->sign = 0;
3869 
3870                     if (fmt == BF_FTOA_FORMAT_FIXED) {
3871                         n_digits = prec;
3872                         n_max = n_digits;
3873                     } else {
3874                         slimb_t n_digits_max, n_digits_min;
3875 
3876                         assert(prec != BF_PREC_INF);
3877                         n_digits = 1 + bf_mul_log2_radix(prec, radix, TRUE, TRUE);
3878                         /* max number of digits for non exponential
3879                            notation. The rational is to have the same rule
3880                            as JS i.e. n_max = 21 for 64 bit float in base 10. */
3881                         n_max = n_digits + 4;
3882                         if (fmt == BF_FTOA_FORMAT_FREE_MIN) {
3883                             bf_t b_s, *b = &b_s;
3884 
3885                             /* find the minimum number of digits by
3886                                dichotomy. */
3887                             /* XXX: inefficient */
3888                             n_digits_max = n_digits;
3889                             n_digits_min = 1;
3890                             bf_init(ctx, b);
3891                             while (n_digits_min < n_digits_max) {
3892                                 n_digits = (n_digits_min + n_digits_max) / 2;
3893                                 if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3894                                                         flags & BF_RND_MASK, FALSE)) {
3895                                     bf_delete(b);
3896                                     goto fail1;
3897                                 }
3898                                 /* convert back to a number and compare */
3899                                 ret = bf_mul_pow_radix(b, a1, radix, n - n_digits,
3900                                                        prec,
3901                                                        (flags & ~BF_RND_MASK) |
3902                                                        BF_RNDN);
3903                                 if (ret & BF_ST_MEM_ERROR) {
3904                                     bf_delete(b);
3905                                     goto fail1;
3906                                 }
3907                                 if (bf_cmpu(b, a) == 0) {
3908                                     n_digits_max = n_digits;
3909                                 } else {
3910                                     n_digits_min = n_digits + 1;
3911                                 }
3912                             }
3913                             bf_delete(b);
3914                             n_digits = n_digits_max;
3915                         }
3916                     }
3917                     if (bf_convert_to_radix(a1, &n, a, radix, n_digits,
3918                                             flags & BF_RND_MASK, FALSE)) {
3919                     fail1:
3920                         bf_delete(a1);
3921                         goto fail;
3922                     }
3923                 }
3924                 if (a1->expn == BF_EXP_ZERO &&
3925                     fmt != BF_FTOA_FORMAT_FIXED &&
3926                     !(flags & BF_FTOA_FORCE_EXP)) {
3927                     /* just output zero */
3928                     dbuf_putstr(s, "0");
3929                 } else {
3930                     if (flags & BF_FTOA_ADD_PREFIX) {
3931                         if (radix == 16)
3932                             dbuf_putstr(s, "0x");
3933                         else if (radix == 8)
3934                             dbuf_putstr(s, "0o");
3935                         else if (radix == 2)
3936                             dbuf_putstr(s, "0b");
3937                     }
3938                     if (a1->expn == BF_EXP_ZERO)
3939                         n = 1;
3940                     if ((flags & BF_FTOA_FORCE_EXP) ||
3941                         n <= -6 || n > n_max) {
3942                         const char *fmt;
3943                         /* exponential notation */
3944                         output_digits(s, a1, radix, n_digits, 1, is_dec);
3945                         if (radix_bits != 0 && radix <= 16) {
3946                             if (flags & BF_FTOA_JS_QUIRKS)
3947                                 fmt = "p%+" PRId_LIMB;
3948                             else
3949                                 fmt = "p%" PRId_LIMB;
3950                             dbuf_printf(s, fmt, (n - 1) * radix_bits);
3951                         } else {
3952                             if (flags & BF_FTOA_JS_QUIRKS)
3953                                 fmt = "%c%+" PRId_LIMB;
3954                             else
3955                                 fmt = "%c%" PRId_LIMB;
3956                             dbuf_printf(s, fmt,
3957                                         radix <= 10 ? 'e' : '@', n - 1);
3958                         }
3959                     } else if (n <= 0) {
3960                         /* 0.x */
3961                         dbuf_putstr(s, "0.");
3962                         for(i = 0; i < -n; i++) {
3963                             dbuf_putc(s, '0');
3964                         }
3965                         output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3966                     } else {
3967                         if (n_digits <= n) {
3968                             /* no dot */
3969                             output_digits(s, a1, radix, n_digits, n_digits, is_dec);
3970                             for(i = 0; i < (n - n_digits); i++)
3971                                 dbuf_putc(s, '0');
3972                         } else {
3973                             output_digits(s, a1, radix, n_digits, n, is_dec);
3974                         }
3975                     }
3976                 }
3977             }
3978             bf_delete(a1);
3979         }
3980     }
3981     dbuf_putc(s, '\0');
3982     if (dbuf_error(s))
3983         goto fail;
3984     if (plen)
3985         *plen = s->size - 1;
3986     return (char *)s->buf;
3987  fail:
3988     bf_free(ctx, s->buf);
3989     if (plen)
3990         *plen = 0;
3991     return NULL;
3992 }
3993 
/* Convert the binary float 'a' to a NUL-terminated string in base
   'radix'; 'prec' and 'flags' select the output format and rounding.
   This forwards to bf_ftoa_internal() with decimal (bfdec) mode off.
   On success the string length (terminator excluded) is stored in
   '*plen' when 'plen' is not NULL; on failure NULL is returned and
   '*plen' is set to 0. */
char *bf_ftoa(size_t *plen, const bf_t *a, int radix, limb_t prec,
              bf_flags_t flags)
{
    const BOOL is_dec = FALSE; /* 'a' is a bf_t, not a bfdec_t */

    return bf_ftoa_internal(plen, a, radix, prec, flags, is_dec);
}
3999 
4000 /***************************************************************/
4001 /* transcendental functions */
4002 
4003 /* Note: the algorithm is from MPFR */
bf_const_log2_rec(bf_t * T,bf_t * P,bf_t * Q,limb_t n1,limb_t n2,BOOL need_P)4004 static void bf_const_log2_rec(bf_t *T, bf_t *P, bf_t *Q, limb_t n1,
4005                               limb_t n2, BOOL need_P)
4006 {
4007     bf_context_t *s = T->ctx;
4008     if ((n2 - n1) == 1) {
4009         if (n1 == 0) {
4010             bf_set_ui(P, 3);
4011         } else {
4012             bf_set_ui(P, n1);
4013             P->sign = 1;
4014         }
4015         bf_set_ui(Q, 2 * n1 + 1);
4016         Q->expn += 2;
4017         bf_set(T, P);
4018     } else {
4019         limb_t m;
4020         bf_t T1_s, *T1 = &T1_s;
4021         bf_t P1_s, *P1 = &P1_s;
4022         bf_t Q1_s, *Q1 = &Q1_s;
4023 
4024         m = n1 + ((n2 - n1) >> 1);
4025         bf_const_log2_rec(T, P, Q, n1, m, TRUE);
4026         bf_init(s, T1);
4027         bf_init(s, P1);
4028         bf_init(s, Q1);
4029         bf_const_log2_rec(T1, P1, Q1, m, n2, need_P);
4030         bf_mul(T, T, Q1, BF_PREC_INF, BF_RNDZ);
4031         bf_mul(T1, T1, P, BF_PREC_INF, BF_RNDZ);
4032         bf_add(T, T, T1, BF_PREC_INF, BF_RNDZ);
4033         if (need_P)
4034             bf_mul(P, P, P1, BF_PREC_INF, BF_RNDZ);
4035         bf_mul(Q, Q, Q1, BF_PREC_INF, BF_RNDZ);
4036         bf_delete(T1);
4037         bf_delete(P1);
4038         bf_delete(Q1);
4039     }
4040 }
4041 
4042 /* compute log(2) with faithful rounding at precision 'prec' */
bf_const_log2_internal(bf_t * T,limb_t prec)4043 static void bf_const_log2_internal(bf_t *T, limb_t prec)
4044 {
4045     limb_t w, N;
4046     bf_t P_s, *P = &P_s;
4047     bf_t Q_s, *Q = &Q_s;
4048 
4049     w = prec + 15;
4050     N = w / 3 + 1;
4051     bf_init(T->ctx, P);
4052     bf_init(T->ctx, Q);
4053     bf_const_log2_rec(T, P, Q, 0, N, FALSE);
4054     bf_div(T, T, Q, prec, BF_RNDN);
4055     bf_delete(P);
4056     bf_delete(Q);
4057 }
4058 
/* PI constant: Chudnovsky series
   1/pi = 12 * sum_k (-1)^k (6k)! (A + B*k) / ((3k)! (k!)^3 C^(3k + 3/2)) */

enum {
    CHUD_A = 13591409,
    CHUD_B = 545140134,
    CHUD_C = 640320,
    /* approximate number of bits of pi gained per series term */
    CHUD_BITS_PER_TERM = 47
};
4065 
/* Binary splitting of the Chudnovsky series over the terms [a, b).
   Leaf (b == a + 1):
     G = (2b - 1)(6b - 1)(6b - 5)
     P = (-1)^b * G * (A + B * b)
     Q = b^3 * C^3 / 24
   Merge of [a, mid) and [mid, b):
     P = P1 * Q2 + P2 * G1,  Q = Q1 * Q2,  G = G1 * G2
   G is only kept when 'need_g' is non zero; otherwise it is set to 0
   at merge nodes. All operations round to nearest at 'prec'. */
static void chud_bs(bf_t *P, bf_t *Q, bf_t *G, int64_t a, int64_t b, int need_g,
                    limb_t prec)
{
    bf_context_t *s = P->ctx;
    bf_t lin, cst;
    bf_t P2, Q2, G2;
    int64_t mid;

    if (a == (b - 1)) {
        bf_init(s, &lin);
        bf_init(s, &cst);

        /* G = (2b - 1)(6b - 1)(6b - 5) */
        bf_set_ui(G, 2 * b - 1);
        bf_mul_ui(G, G, 6 * b - 1, prec, BF_RNDN);
        bf_mul_ui(G, G, 6 * b - 5, prec, BF_RNDN);

        /* P = (-1)^b * G * (A + B * b) */
        bf_set_ui(&lin, CHUD_B);
        bf_mul_ui(&lin, &lin, b, prec, BF_RNDN);
        bf_set_ui(&cst, CHUD_A);
        bf_add(&lin, &lin, &cst, prec, BF_RNDN);
        bf_mul(P, G, &lin, prec, BF_RNDN);
        P->sign = b & 1;

        /* Q = b^3 * C^3 / 24 */
        bf_set_ui(Q, b);
        bf_mul_ui(Q, Q, b, prec, BF_RNDN);
        bf_mul_ui(Q, Q, b, prec, BF_RNDN);
        bf_mul_ui(Q, Q, (uint64_t)CHUD_C * CHUD_C * CHUD_C / 24, prec, BF_RNDN);

        bf_delete(&lin);
        bf_delete(&cst);
        return;
    }

    bf_init(s, &P2);
    bf_init(s, &Q2);
    bf_init(s, &G2);

    mid = (a + b) / 2;
    /* the left G is always needed: it scales the right P */
    chud_bs(P, Q, G, a, mid, 1, prec);
    chud_bs(&P2, &Q2, &G2, mid, b, need_g, prec);

    /* P = P1 * Q2 + P2 * G1 */
    bf_mul(&P2, &P2, G, prec, BF_RNDN);
    if (!need_g)
        bf_set_ui(G, 0);
    bf_mul(P, P, &Q2, prec, BF_RNDN);
    bf_add(P, P, &P2, prec, BF_RNDN);
    bf_delete(&P2);

    /* Q = Q1 * Q2 */
    bf_mul(Q, Q, &Q2, prec, BF_RNDN);
    bf_delete(&Q2);

    /* G = G1 * G2 */
    if (need_g)
        bf_mul(G, G, &G2, prec, BF_RNDN);
    bf_delete(&G2);
}
4121 
4122 /* compute Pi with faithful rounding at precision 'prec' using the
4123    Chudnovsky formula */
bf_const_pi_internal(bf_t * Q,limb_t prec)4124 static void bf_const_pi_internal(bf_t *Q, limb_t prec)
4125 {
4126     bf_context_t *s = Q->ctx;
4127     int64_t n, prec1;
4128     bf_t P, G;
4129 
4130     /* number of serie terms */
4131     n = prec / CHUD_BITS_PER_TERM + 1;
4132     /* XXX: precision analysis */
4133     prec1 = prec + 32;
4134 
4135     bf_init(s, &P);
4136     bf_init(s, &G);
4137 
4138     chud_bs(&P, Q, &G, 0, n, 0, BF_PREC_INF);
4139 
4140     bf_mul_ui(&G, Q, CHUD_A, prec1, BF_RNDN);
4141     bf_add(&P, &G, &P, prec1, BF_RNDN);
4142     bf_div(Q, Q, &P, prec1, BF_RNDF);
4143 
4144     bf_set_ui(&P, CHUD_C);
4145     bf_sqrt(&G, &P, prec1, BF_RNDF);
4146     bf_mul_ui(&G, &G, (uint64_t)CHUD_C / 12, prec1, BF_RNDF);
4147     bf_mul(Q, Q, &G, prec, BF_RNDN);
4148     bf_delete(&P);
4149     bf_delete(&G);
4150 }
4151 
/* Set T to the constant produced by 'func', with sign 'sign', rounded
   to 'prec' bits according to 'flags'.
   - The unrounded value is cached in 'c'. It is recomputed only when a
     higher working precision than the cached one is required.
   - Ziv strategy: the number of extra bits grows by 50% until
     bf_can_round() accepts the value.
   Returns the status of the final bf_round(). */
static int bf_const_get(bf_t *T, limb_t prec, bf_flags_t flags,
                        BFConstCache *c,
                        void (*func)(bf_t *res, limb_t prec), int sign)
{
    limb_t extra_bits = 32;
    limb_t work_prec;

    for (;;) {
        work_prec = prec + extra_bits;
        if (c->prec >= work_prec) {
            /* the cached value is precise enough: use all of it */
            work_prec = c->prec;
        } else {
            /* a zeroed (never filled or freed) cache has len == 0 */
            if (c->val.len == 0)
                bf_init(T->ctx, &c->val);
            func(&c->val, work_prec);
            c->prec = work_prec;
        }
        bf_set(T, &c->val);
        T->sign = sign;
        if (bf_can_round(T, prec, flags & BF_RND_MASK, work_prec))
            break;
        /* add more precision and retry */
        extra_bits += extra_bits / 2;
    }
    return bf_round(T, prec, flags);
}
4180 
/* Release the cached constant and reset every field of the cache to
   zero, so the next bf_const_get() recomputes it. */
static void bf_const_free(BFConstCache *c)
{
    bf_delete(&c->val);
    memset(c, 0, sizeof *c);
}
4186 
/* Set T to log(2) rounded to 'prec' bits according to 'flags', using
   the context's log(2) cache. Returns the rounding status. */
int bf_const_log2(bf_t *T, limb_t prec, bf_flags_t flags)
{
    BFConstCache *cache = &T->ctx->log2_cache;

    return bf_const_get(T, prec, flags, cache, bf_const_log2_internal, 0);
}
4192 
4193 /* return rounded pi * (1 - 2 * sign) */
bf_const_pi_signed(bf_t * T,int sign,limb_t prec,bf_flags_t flags)4194 static int bf_const_pi_signed(bf_t *T, int sign, limb_t prec, bf_flags_t flags)
4195 {
4196     bf_context_t *s = T->ctx;
4197     return bf_const_get(T, prec, flags, &s->pi_cache, bf_const_pi_internal,
4198                         sign);
4199 }
4200 
/* Set T to +pi rounded to 'prec' bits according to 'flags'. */
int bf_const_pi(bf_t *T, limb_t prec, bf_flags_t flags)
{
    const int positive = 0;

    return bf_const_pi_signed(T, positive, prec, flags);
}
4205 
/* Free the caches held by context 's':
   - the FFT cache (only when FFT multiplication is compiled in),
   - the cached log(2) constant,
   - the cached pi constant. */
void bf_clear_cache(bf_context_t *s)
{
#ifdef USE_FFT_MUL
    fft_clear_cache(s);
#endif
    /* constant caches */
    bf_const_free(&s->log2_cache);
    bf_const_free(&s->pi_cache);
}
4214 
4215 /* ZivFunc should compute the result 'r' with faithful rounding at
4216    precision 'prec'. For efficiency purposes, the final bf_round()
4217    does not need to be done in the function. */
4218 typedef int ZivFunc(bf_t *r, const bf_t *a, limb_t prec, void *opaque);
4219 
/* Evaluate 'f' at 'a' and round the result to 'prec' bits according to
   'flags', using Ziv's strategy.
   - 'f' is evaluated with extra precision. The extra bit count doubles
     until either the result is exact or bf_can_round() accepts it.
   - With BF_RNDF a single evaluation at 'prec' is enough.
   - Overflow, underflow and memory errors from 'f' are returned as-is,
     without the final rounding. */
static int bf_ziv_rounding(bf_t *r, const bf_t *a,
                           limb_t prec, bf_flags_t flags,
                           ZivFunc *f, void *opaque)
{
    int rnd_mode = flags & BF_RND_MASK;
    int status;

    if (rnd_mode == BF_RNDF) {
        /* no need to iterate. NOTE(review): the status returned by 'f'
           (including BF_ST_MEM_ERROR) is discarded on this path */
        f(r, a, prec, opaque);
        status = 0;
    } else {
        slimb_t extra_bits = 32;

        for (;;) {
            slimb_t work_prec = prec + extra_bits;

            status = f(r, a, work_prec, opaque);
            if (status & (BF_ST_OVERFLOW | BF_ST_UNDERFLOW | BF_ST_MEM_ERROR)) {
                /* overflow or underflow should never happen because it
                   means correct rounding cannot be ensured, but not all
                   cases are caught beforehand */
                return status;
            }
            if ((status & BF_ST_INEXACT) == 0) {
                /* exact result: stop here */
                status = 0;
                break;
            }
            if (bf_can_round(r, prec, rnd_mode, work_prec)) {
                status = BF_ST_INEXACT;
                break;
            }
            extra_bits *= 2;
        }
    }

    if (r->len == 0)
        return status;
    return __bf_round(r, prec, flags, r->len, status);
}
4261 
4262 /* add (1 - 2*e_sign) * 2^e */
bf_add_epsilon(bf_t * r,const bf_t * a,slimb_t e,int e_sign,limb_t prec,int flags)4263 static int bf_add_epsilon(bf_t *r, const bf_t *a, slimb_t e, int e_sign,
4264                           limb_t prec, int flags)
4265 {
4266     bf_t T_s, *T = &T_s;
4267     int ret;
4268     /* small argument case: result = 1 + epsilon * sign(x) */
4269     bf_init(a->ctx, T);
4270     bf_set_ui(T, 1);
4271     T->sign = e_sign;
4272     T->expn += e;
4273     ret = bf_add(r, r, T, prec, flags);
4274     bf_delete(T);
4275     return ret;
4276 }
4277 
4278 /* Compute the exponential using faithful rounding at precision 'prec'.
4279    Note: the algorithm is from MPFR */
bf_exp_internal(bf_t * r,const bf_t * a,limb_t prec,void * opaque)4280 static int bf_exp_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
4281 {
4282     bf_context_t *s = r->ctx;
4283     bf_t T_s, *T = &T_s;
4284     slimb_t n, K, l, i, prec1;
4285 
4286     assert(r != a);
4287 
4288     /* argument reduction:
4289        T = a - n*log(2) with 0 <= T < log(2) and n integer.
4290     */
4291     bf_init(s, T);
4292     if (a->expn <= -1) {
4293         /* 0 <= abs(a) <= 0.5 */
4294         if (a->sign)
4295             n = -1;
4296         else
4297             n = 0;
4298     } else {
4299         bf_const_log2(T, LIMB_BITS, BF_RNDZ);
4300         bf_div(T, a, T, LIMB_BITS, BF_RNDD);
4301         bf_get_limb(&n, T, 0);
4302     }
4303 
4304     K = bf_isqrt((prec + 1) / 2);
4305     l = (prec - 1) / K + 1;
4306     /* XXX: precision analysis ? */
4307     prec1 = prec + (K + 2 * l + 18) + K + 8;
4308     if (a->expn > 0)
4309         prec1 += a->expn;
4310     //    printf("n=%ld K=%ld prec1=%ld\n", n, K, prec1);
4311 
4312     bf_const_log2(T, prec1, BF_RNDF);
4313     bf_mul_si(T, T, n, prec1, BF_RNDN);
4314     bf_sub(T, a, T, prec1, BF_RNDN);
4315 
4316     /* reduce the range of T */
4317     bf_mul_2exp(T, -K, BF_PREC_INF, BF_RNDZ);
4318 
4319     /* Taylor expansion around zero :
4320      1 + x + x^2/2 + ... + x^n/n!
4321      = (1 + x * (1 + x/2 * (1 + ... (x/n))))
4322     */
4323     {
4324         bf_t U_s, *U = &U_s;
4325 
4326         bf_init(s, U);
4327         bf_set_ui(r, 1);
4328         for(i = l ; i >= 1; i--) {
4329             bf_set_ui(U, i);
4330             bf_div(U, T, U, prec1, BF_RNDN);
4331             bf_mul(r, r, U, prec1, BF_RNDN);
4332             bf_add_si(r, r, 1, prec1, BF_RNDN);
4333         }
4334         bf_delete(U);
4335     }
4336     bf_delete(T);
4337 
4338     /* undo the range reduction */
4339     for(i = 0; i < K; i++) {
4340         bf_mul(r, r, r, prec1, BF_RNDN | BF_FLAG_EXT_EXP);
4341     }
4342 
4343     /* undo the argument reduction */
4344     bf_mul_2exp(r, n, BF_PREC_INF, BF_RNDZ | BF_FLAG_EXT_EXP);
4345 
4346     return BF_ST_INEXACT;
4347 }
4348 
4349 /* crude overflow and underflow tests for exp(a). a_low <= a <= a_high */
check_exp_underflow_overflow(bf_context_t * s,bf_t * r,const bf_t * a_low,const bf_t * a_high,limb_t prec,bf_flags_t flags)4350 static int check_exp_underflow_overflow(bf_context_t *s, bf_t *r,
4351                                         const bf_t *a_low, const bf_t *a_high,
4352                                         limb_t prec, bf_flags_t flags)
4353 {
4354     bf_t T_s, *T = &T_s;
4355     bf_t log2_s, *log2 = &log2_s;
4356     slimb_t e_min, e_max;
4357 
4358     if (a_high->expn <= 0)
4359         return 0;
4360 
4361     e_max = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
4362     e_min = -e_max + 3;
4363     if (flags & BF_FLAG_SUBNORMAL)
4364         e_min -= (prec - 1);
4365 
4366     bf_init(s, T);
4367     bf_init(s, log2);
4368     bf_const_log2(log2, LIMB_BITS, BF_RNDU);
4369     bf_mul_ui(T, log2, e_max, LIMB_BITS, BF_RNDU);
4370     /* a_low > e_max * log(2) implies exp(a) > e_max */
4371     if (bf_cmp_lt(T, a_low) > 0) {
4372         /* overflow */
4373         bf_delete(T);
4374         bf_delete(log2);
4375         return bf_set_overflow(r, 0, prec, flags);
4376     }
4377     /* a_high < (e_min - 2) * log(2) implies exp(a) < (e_min - 2) */
4378     bf_const_log2(log2, LIMB_BITS, BF_RNDD);
4379     bf_mul_si(T, log2, e_min - 2, LIMB_BITS, BF_RNDD);
4380     if (bf_cmp_lt(a_high, T)) {
4381         int rnd_mode = flags & BF_RND_MASK;
4382 
4383         /* underflow */
4384         bf_delete(T);
4385         bf_delete(log2);
4386         if (rnd_mode == BF_RNDU) {
4387             /* set the smallest value */
4388             bf_set_ui(r, 1);
4389             r->expn = e_min;
4390         } else {
4391             bf_set_zero(r, 0);
4392         }
4393         return BF_ST_UNDERFLOW | BF_ST_INEXACT;
4394     }
4395     bf_delete(log2);
4396     bf_delete(T);
4397     return 0;
4398 }
4399 
/* Set r = exp(a) rounded to 'prec' bits according to 'flags'. 'r' must
   not alias 'a'. Returns a BF_ST_* status mask.
   Special values: exp(NaN) = NaN, exp(+/-0) = 1, exp(-inf) = +0,
   exp(+inf) = +inf. */
int bf_exp(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    int status;

    assert(r != a);

    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN)
            bf_set_nan(r);
        else if (a->expn != BF_EXP_INF)
            bf_set_ui(r, 1);        /* exp(+/-0) */
        else if (a->sign)
            bf_set_zero(r, 0);      /* exp(-inf) */
        else
            bf_set_inf(r, 0);       /* exp(+inf) */
        return 0;
    }

    status = check_exp_underflow_overflow(s, r, a, a, prec, flags);
    if (status != 0)
        return status;

    if (a->expn < 0 && (-a->expn) >= (prec + 2)) {
        /* |a| < 2^-(prec+2): result = 1 + epsilon * sign(a) */
        bf_set_ui(r, 1);
        return bf_add_epsilon(r, r, -(prec + 2), a->sign, prec, flags);
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_exp_internal, NULL);
}
4430 
/* Compute log(a) into r with faithful rounding at precision 'prec'.
   'r' must not alias 'a'; 'opaque' is unused. Callers pass a finite
   a > 0 (see bf_log()). Always returns BF_ST_INEXACT.
   Method:
     1. a = x * 2^n with about 2/3 <= x < 4/3, then x = x - 1,
     2. K times x = x / (1 + sqrt(1 + x)), which maps 1 + x to
        sqrt(1 + x) and so halves log(1 + x),
     3. log(1 + x) = 2 * atanh(y) with y = x / (2 + x), using an odd
        series truncated at y^(2*l + 1),
     4. result = 2^(K + 1) * atanh(y) + n * log(2). */
static int bf_log_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    bf_context_t *ctx = r->ctx;
    bf_t x_s, *x = &x_s;       /* reduced argument */
    bf_t u_s, *u = &u_s;       /* scratch */
    bf_t v_s, *v = &v_s;       /* scratch */
    slimb_t n, work_prec, l, k, K;

    assert(r != a);

    bf_init(ctx, x);

    /* reduction 1: x = a * 2^-n */
    {
        bf_t two_thirds_s, *two_thirds = &two_thirds_s;

        bf_set(x, a);
        n = x->expn;
        x->expn = 0;                    /* x in [1/2, 1) */
        bf_init(ctx, two_thirds);
        bf_set_ui(two_thirds, 0xaaaaaaaa);
        two_thirds->expn = 0;           /* 0.1010...b, about 2/3 */
        if (bf_cmp_lt(x, two_thirds)) {
            /* double x so that it lands in [1, 4/3) */
            x->expn++;
            n--;
        }
        bf_delete(two_thirds);
    }

    /* XXX: precision analysis */
    K = bf_isqrt((prec + 1) / 2);       /* square root steps */
    l = prec / (2 * K) + 1;             /* series order */
    work_prec = prec + K + 2 * l + 32;

    bf_init(ctx, u);
    bf_init(ctx, v);

    /* x = x - 1 at infinite precision because cancellation occurs here
       (XXX: reduce the precision by computing the exact cancellation) */
    bf_add_si(x, x, -1, BF_PREC_INF, BF_RNDN);

    /* reduction 2: x = x / (1 + sqrt(1 + x)), K times */
    for (k = 0; k < K; k++) {
        bf_add_si(u, x, 1, work_prec, BF_RNDN);
        bf_sqrt(v, u, work_prec, BF_RNDF);
        bf_add_si(u, v, 1, work_prec, BF_RNDN);
        bf_div(x, x, u, work_prec, BF_RNDN);
    }

    /* r = atanh(y) = y * (1 + y^2 * (1/3 + y^2 * (1/5 + ...))) */
    {
        bf_t y_s, *y = &y_s;
        bf_t y2_s, *y2 = &y2_s;

        bf_init(ctx, y);
        bf_init(ctx, y2);

        /* y = x / (2 + x) */
        bf_add_si(y, x, 2, work_prec, BF_RNDN);
        bf_div(y, x, y, work_prec, BF_RNDN);
        bf_mul(y2, y, y, work_prec, BF_RNDN);

        bf_set_ui(r, 0);
        k = l;
        while (k >= 1) {
            /* r = (r + 1 / (2k + 1)) * y^2 */
            bf_set_ui(u, 1);
            bf_set_ui(v, 2 * k + 1);
            bf_div(u, u, v, work_prec, BF_RNDN);
            bf_add(r, r, u, work_prec, BF_RNDN);
            bf_mul(r, r, y2, work_prec, BF_RNDN);
            k--;
        }
        bf_add_si(r, r, 1, work_prec, BF_RNDN);
        bf_mul(r, r, y, work_prec, BF_RNDN);

        bf_delete(y);
        bf_delete(y2);
    }
    bf_delete(v);
    bf_delete(u);

    /* factor 2 of 2 * atanh(y), and 2^K to undo reduction 2 */
    bf_mul_2exp(r, K + 1, BF_PREC_INF, BF_RNDZ);

    /* undo reduction 1: r = r + n * log(2) */
    bf_const_log2(x, work_prec, BF_RNDF);
    bf_mul_si(x, x, n, work_prec, BF_RNDN);
    bf_add(r, r, x, work_prec, BF_RNDN);

    bf_delete(x);
    return BF_ST_INEXACT;
}
4529 
/* Set r = log(a) rounded to 'prec' bits according to 'flags'. 'r' must
   not alias 'a'. Special values:
     log(NaN) = NaN, log(+/-0) = -inf, log(+inf) = +inf, log(1) = +0;
     log(-inf) and log(a < 0) give NaN with BF_ST_INVALID_OP. */
int bf_log(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    bf_t one_s, *one = &one_s;
    BOOL a_is_one;

    assert(r != a);

    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        }
        if (a->expn != BF_EXP_INF) {
            /* log(+/-0) = -inf */
            bf_set_inf(r, 1);
            return 0;
        }
        if (a->sign) {
            /* log(-inf) */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        bf_set_inf(r, 0);
        return 0;
    }

    /* negative argument */
    if (a->sign) {
        bf_set_nan(r);
        return BF_ST_INVALID_OP;
    }

    /* log(1) is exactly +0 */
    bf_init(s, one);
    bf_set_ui(one, 1);
    a_is_one = bf_cmp_eq(a, one);
    bf_delete(one);
    if (a_is_one) {
        bf_set_zero(r, 0);
        return 0;
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_log_internal, NULL);
}
4568 
4569 /* x and y finite and x > 0 */
bf_pow_generic(bf_t * r,const bf_t * x,limb_t prec,void * opaque)4570 static int bf_pow_generic(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4571 {
4572     bf_context_t *s = r->ctx;
4573     const bf_t *y = opaque;
4574     bf_t T_s, *T = &T_s;
4575     limb_t prec1;
4576 
4577     bf_init(s, T);
4578     /* XXX: proof for the added precision */
4579     prec1 = prec + 32;
4580     bf_log(T, x, prec1, BF_RNDF | BF_FLAG_EXT_EXP);
4581     bf_mul(T, T, y, prec1, BF_RNDF | BF_FLAG_EXT_EXP);
4582     if (bf_is_nan(T))
4583         bf_set_nan(r);
4584     else
4585         bf_exp_internal(r, T, prec1, NULL); /* no overflow/underlow test needed */
4586     bf_delete(T);
4587     return BF_ST_INEXACT;
4588 }
4589 
4590 /* x and y finite, x > 0, y integer and y fits on one limb */
bf_pow_int(bf_t * r,const bf_t * x,limb_t prec,void * opaque)4591 static int bf_pow_int(bf_t *r, const bf_t *x, limb_t prec, void *opaque)
4592 {
4593     bf_context_t *s = r->ctx;
4594     const bf_t *y = opaque;
4595     bf_t T_s, *T = &T_s;
4596     limb_t prec1;
4597     int ret;
4598     slimb_t y1;
4599 
4600     bf_get_limb(&y1, y, 0);
4601     if (y1 < 0)
4602         y1 = -y1;
4603     /* XXX: proof for the added precision */
4604     prec1 = prec + ceil_log2(y1) * 2 + 8;
4605     ret = bf_pow_ui(r, x, y1 < 0 ? -y1 : y1, prec1, BF_RNDN | BF_FLAG_EXT_EXP);
4606     if (y->sign) {
4607         bf_init(s, T);
4608         bf_set_ui(T, 1);
4609         ret |= bf_div(r, T, r, prec1, BF_RNDN | BF_FLAG_EXT_EXP);
4610         bf_delete(T);
4611     }
4612     return ret;
4613 }
4614 
4615 /* x must be a finite non zero float. Return TRUE if there is a
4616    floating point number r such as x=r^(2^n) and return this floating
4617    point number 'r'. Otherwise return FALSE and r is undefined. */
check_exact_power2n(bf_t * r,const bf_t * x,slimb_t n)4618 static BOOL check_exact_power2n(bf_t *r, const bf_t *x, slimb_t n)
4619 {
4620     bf_context_t *s = r->ctx;
4621     bf_t T_s, *T = &T_s;
4622     slimb_t e, i, er;
4623     limb_t v;
4624 
4625     /* x = m*2^e with m odd integer */
4626     e = bf_get_exp_min(x);
4627     /* fast check on the exponent */
4628     if (n > (LIMB_BITS - 1)) {
4629         if (e != 0)
4630             return FALSE;
4631         er = 0;
4632     } else {
4633         if ((e & (((limb_t)1 << n) - 1)) != 0)
4634             return FALSE;
4635         er = e >> n;
4636     }
4637     /* every perfect odd square = 1 modulo 8 */
4638     v = get_bits(x->tab, x->len, x->len * LIMB_BITS - x->expn + e);
4639     if ((v & 7) != 1)
4640         return FALSE;
4641 
4642     bf_init(s, T);
4643     bf_set(T, x);
4644     T->expn -= e;
4645     for(i = 0; i < n; i++) {
4646         if (i != 0)
4647             bf_set(T, r);
4648         if (bf_sqrtrem(r, NULL, T) != 0)
4649             return FALSE;
4650     }
4651     r->expn += er;
4652     return TRUE;
4653 }
4654 
/* prec = BF_PREC_INF is accepted for x and y integers and y >= 0 */
/* Compute r = x^y rounded to 'prec' bits with the rounding mode and
   options in 'flags' (including BF_POW_JS_QUIRKS). Returns a BF_ST_*
   status mask.
   - If x or y is zero, infinite or NaN, the result is set directly.
   - Otherwise T = abs(x), the power is computed on T and the sign of
     the result (r_sign) is stored into r at the 'done' label.
   NOTE(review): r is used as a scratch value (bf_set_ui(r, 1)) before
   x and y are last read, so r presumably must not alias x or y. */
int bf_pow(bf_t *r, const bf_t *x, const bf_t *y, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    bf_t T_s, *T = &T_s;
    bf_t ytmp_s;
    BOOL y_is_int, y_is_odd;
    int r_sign, ret, rnd_mode;
    slimb_t y_emin;

    if (x->len == 0 || y->len == 0) {
        /* at least one operand is zero, infinite or NaN */
        if (y->expn == BF_EXP_ZERO) {
            /* pow(x, 0) = 1 */
            bf_set_ui(r, 1);
        } else if (x->expn == BF_EXP_NAN) {
            bf_set_nan(r);
        } else {
            int cmp_x_abs_1;
            bf_set_ui(r, 1);
            cmp_x_abs_1 = bf_cmpu(x, r); /* compare abs(x) with 1 */
            if (cmp_x_abs_1 == 0 && (flags & BF_POW_JS_QUIRKS) &&
                (y->expn >= BF_EXP_INF)) {
                /* JS semantics: pow(+/-1, y) = NaN for a non finite y
                   (presumably y->expn >= BF_EXP_INF covers inf and NaN) */
                bf_set_nan(r);
            } else if (cmp_x_abs_1 == 0 &&
                       (!x->sign || y->expn != BF_EXP_NAN)) {
                /* pow(1, y) = 1 even if y = NaN */
                /* pow(-1, +/-inf) = 1 */
            } else if (y->expn == BF_EXP_NAN) {
                bf_set_nan(r);
            } else if (y->expn == BF_EXP_INF) {
                /* +0 for (abs(x) > 1, y = -inf) or (abs(x) < 1, y = +inf),
                   +inf otherwise */
                if (y->sign == (cmp_x_abs_1 > 0)) {
                    bf_set_zero(r, 0);
                } else {
                    bf_set_inf(r, 0);
                }
            } else {
                /* y is finite and non zero, so x is +/-0 or +/-inf. The
                   result is negative only for x < 0 and an odd integer y
                   (y_emin == 0 means y = odd * 2^0) */
                y_emin = bf_get_exp_min(y);
                y_is_odd = (y_emin == 0);
                if (y->sign == (x->expn == BF_EXP_ZERO)) {
                    /* pow(0, y < 0) or pow(inf, y > 0) */
                    bf_set_inf(r, y_is_odd & x->sign);
                    if (y->sign) {
                        /* pow(0, y) with y < 0 */
                        return BF_ST_DIVIDE_ZERO;
                    }
                } else {
                    bf_set_zero(r, y_is_odd & x->sign);
                }
            }
        }
        return 0;
    }
    bf_init(s, T);
    bf_set(T, x);
    y_emin = bf_get_exp_min(y);
    y_is_int = (y_emin >= 0);
    rnd_mode = flags & BF_RND_MASK;
    if (x->sign) {
        /* x < 0: the power is only defined for an integer y */
        if (!y_is_int) {
            bf_set_nan(r);
            bf_delete(T);
            return BF_ST_INVALID_OP;
        }
        y_is_odd = (y_emin == 0);
        r_sign = y_is_odd;
        /* change the directed rounding mode if the sign of the result
           is changed */
        if (r_sign && (rnd_mode == BF_RNDD || rnd_mode == BF_RNDU))
            flags ^= 1;
        bf_neg(T); /* T = abs(x) */
    } else {
        r_sign = 0;
    }

    bf_set_ui(r, 1);
    if (bf_cmp_eq(T, r)) {
        /* abs(x) = 1: nothing more to do */
        ret = 0;
    } else {
        /* check the overflow/underflow cases */
        {
            bf_t al_s, *al = &al_s;
            bf_t ah_s, *ah = &ah_s;
            limb_t precl = LIMB_BITS;

            bf_init(s, al);
            bf_init(s, ah);
            /* compute bounds of log(abs(x)) * y with a low precision */
            /* XXX: compute bf_log() once */
            /* XXX: add a fast test before this slow test */
            bf_log(al, T, precl, BF_RNDD);
            bf_log(ah, T, precl, BF_RNDU);
            /* XOR with y->sign toggles the directed rounding mode when
               y < 0 (same trick as 'flags ^= 1' above) */
            bf_mul(al, al, y, precl, BF_RNDD ^ y->sign);
            bf_mul(ah, ah, y, precl, BF_RNDU ^ y->sign);
            ret = check_exp_underflow_overflow(s, r, al, ah, prec, flags);
            bf_delete(al);
            bf_delete(ah);
            if (ret)
                goto done;
        }

        if (y_is_int) {
            slimb_t T_bits, e;
        int_pow:
            /* also entered by 'goto int_pow' from the non integer case
               once x^y has been reduced to an integer power.
               T = a*2^b with a odd: T_bits - 1 = floor_log2(a), so
               T_bits == 1 means T is a power of two */
            T_bits = T->expn - bf_get_exp_min(T);
            if (T_bits == 1) {
                /* pow(2^b, y) = 2^(b*y) */
                bf_mul_si(T, y, T->expn - 1, LIMB_BITS, BF_RNDZ);
                bf_get_limb(&e, T, 0);
                bf_set_ui(r, 1);
                ret = bf_mul_2exp(r, e, prec, flags);
            } else if (prec == BF_PREC_INF) {
                slimb_t y1;
                /* specific case for infinite precision (integer case) */
                bf_get_limb(&y1, y, 0);
                assert(!y->sign);
                /* x must be an integer, so abs(x) >= 2 */
                if (y1 >= ((slimb_t)1 << BF_EXP_BITS_MAX)) {
                    bf_delete(T);
                    return bf_set_overflow(r, 0, BF_PREC_INF, flags);
                }
                ret = bf_pow_ui(r, T, y1, BF_PREC_INF, BF_RNDZ);
            } else {
                if (y->expn <= 31) {
                    /* small enough power: use exponentiation in all cases */
                } else if (y->sign) {
                    /* cannot be exact */
                    goto general_case;
                } else {
                    if (rnd_mode == BF_RNDF)
                        goto general_case; /* no need to track exact results */
                    /* see if the result has a chance to be exact:
                       if x=a*2^b (a odd), x^y=a^y*2^(b*y)
                       x^y needs a precision of at least floor_log2(a)*y bits
                    */
                    bf_mul_si(r, y, T_bits - 1, LIMB_BITS, BF_RNDZ);
                    bf_get_limb(&e, r, 0);
                    if (prec < e)
                        goto general_case;
                }
                ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_int, (void *)y);
            }
        } else {
            if (rnd_mode != BF_RNDF) {
                bf_t *y1;
                /* y = k/2^n with n = -y_emin: if T = r^(2^n) exactly,
                   then x^y = r^k with k = y*2^n an integer */
                if (y_emin < 0 && check_exact_power2n(r, T, -y_emin)) {
                    /* the problem is reduced to a power to an integer */
#if 0
                    printf("\nn=%" PRId64 "\n", -(int64_t)y_emin);
                    bf_print_str("T", T);
                    bf_print_str("r", r);
#endif
                    bf_set(T, r);
                    /* y1 = y * 2^(-y_emin): shares y's limb array (no
                       copy, never freed), only the exponent differs */
                    y1 = &ytmp_s;
                    y1->tab = y->tab;
                    y1->len = y->len;
                    y1->sign = y->sign;
                    y1->expn = y->expn - y_emin;
                    y = y1;
                    goto int_pow;
                }
            }
        general_case:
            ret = bf_ziv_rounding(r, T, prec, flags, bf_pow_generic, (void *)y);
        }
    }
 done:
    bf_delete(T);
    r->sign = r_sign;
    return ret;
}
4825 
4826 /* compute sqrt(-2*x-x^2) to get |sin(x)| from cos(x) - 1. */
bf_sqrt_sin(bf_t * r,const bf_t * x,limb_t prec1)4827 static void bf_sqrt_sin(bf_t *r, const bf_t *x, limb_t prec1)
4828 {
4829     bf_context_t *s = r->ctx;
4830     bf_t T_s, *T = &T_s;
4831     bf_init(s, T);
4832     bf_set(T, x);
4833     bf_mul(r, T, T, prec1, BF_RNDN);
4834     bf_mul_2exp(T, 1, BF_PREC_INF, BF_RNDZ);
4835     bf_add(T, T, r, prec1, BF_RNDN);
4836     bf_neg(T);
4837     bf_sqrt(r, T, prec1, BF_RNDF);
4838     bf_delete(T);
4839 }
4840 
/* Compute s = sin(a) and/or c = cos(a) (either output may be NULL)
   with about 'prec' bits of precision. The argument is reduced modulo
   pi/2, cos(x) - 1 is evaluated by a Taylor series on x/2^K, and the
   reduction is undone with the double angle formula. s and c must not
   alias a. Always returns BF_ST_INEXACT. */
static int bf_sincos(bf_t *s, bf_t *c, const bf_t *a, limb_t prec)
{
    bf_context_t *s1 = a->ctx;
    bf_t T_s, *T = &T_s;
    bf_t U_s, *U = &U_s;
    bf_t r_s, *r = &r_s;
    slimb_t K, prec1, i, l, mod, prec2;
    int is_neg;

    assert(c != a && s != a);

    bf_init(s1, T);
    bf_init(s1, U);
    bf_init(s1, r);

    /* XXX: precision analysis */
    /* K = number of argument halvings, l = number of Taylor terms */
    K = bf_isqrt(prec / 2);
    if (K < 1)
        K = 1; /* prec < 2 would otherwise divide by zero just below */
    l = prec / (2 * K) + 1;
    prec1 = prec + 2 * K + l + 8;

    /* after the modulo reduction, -pi/4 <= T <= pi/4 */
    if (a->expn <= -1) {
        /* abs(a) < 0.5 < pi/4: no modulo reduction needed */
        bf_set(T, a);
        mod = 0;
    } else {
        slimb_t cancel;
        cancel = 0;
        for(;;) {
            /* reduce a modulo pi/2; 'cancel' adds bits until T keeps
               enough significant bits despite the cancellation */
            prec2 = prec1 + a->expn + cancel;
            bf_const_pi(U, prec2, BF_RNDF);
            bf_mul_2exp(U, -1, BF_PREC_INF, BF_RNDZ);
            bf_remquo(&mod, T, a, U, prec2, BF_RNDN, BF_RNDN);
            if (mod == 0 || (T->expn != BF_EXP_ZERO &&
                             (T->expn + prec2) >= (prec1 - 1)))
                break;
            /* increase the number of bits until the precision is good enough */
            cancel = bf_max(-T->expn, (cancel + 1) * 3 / 2);
        }
        /* quadrant: presumably a = mod * (pi/2) + T */
        mod &= 3;
    }

    is_neg = T->sign;

    /* compute cosm1(x) = cos(x) - 1 */
    bf_mul(T, T, T, prec1, BF_RNDN);
    bf_mul_2exp(T, -2 * K, BF_PREC_INF, BF_RNDZ); /* T = (x/2^K)^2 */

    /* Taylor expansion:
       -x^2/2 + x^4/4! - x^6/6! + ...
       evaluated with the Horner scheme
    */
    bf_set_ui(r, 1);
    for(i = l ; i >= 1; i--) {
        bf_set_ui(U, 2 * i - 1);
        bf_mul_ui(U, U, 2 * i, BF_PREC_INF, BF_RNDZ);
        bf_div(U, T, U, prec1, BF_RNDN);
        bf_mul(r, r, U, prec1, BF_RNDN);
        bf_neg(r);
        if (i != 1)
            bf_add_si(r, r, 1, prec1, BF_RNDN);
    }
    bf_delete(U);

    /* undo argument reduction:
       cosm1(2*x)= 2*(2*cosm1(x)+cosm1(x)^2)
    */
    for(i = 0; i < K; i++) {
        bf_mul(T, r, r, prec1, BF_RNDN);
        bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
        bf_add(r, r, T, prec1, BF_RNDN);
        bf_mul_2exp(r, 1, BF_PREC_INF, BF_RNDZ);
    }
    bf_delete(T);

    /* r = cos(T) - 1. For an odd quadrant, cos and sin swap and
       cos(T + pi/2) = -sin(T); quadrants 2 and 3 negate both. */
    if (c) {
        if ((mod & 1) == 0) {
            bf_add_si(c, r, 1, prec1, BF_RNDN);
        } else {
            bf_sqrt_sin(c, r, prec1);
            c->sign = is_neg ^ 1;
        }
        c->sign ^= mod >> 1;
    }
    if (s) {
        if ((mod & 1) == 0) {
            bf_sqrt_sin(s, r, prec1);
            s->sign = is_neg;
        } else {
            bf_add_si(s, r, 1, prec1, BF_RNDN);
        }
        s->sign ^= mod >> 1;
    }
    bf_delete(r);
    return BF_ST_INEXACT;
}
4937 
/* Ziv callback: r = cos(a) evaluated with 'prec' bits; opaque is unused. */
static int bf_cos_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    (void)opaque;
    /* the sine output is not requested */
    return bf_sincos(NULL, r, a, prec);
}
4942 
/* r = cos(a) rounded to 'prec' bits according to 'flags'. */
int bf_cos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    if (a->len == 0) {
        /* special values */
        if (a->expn == BF_EXP_INF) {
            /* cos(+/-inf) is undefined */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        if (a->expn == BF_EXP_NAN)
            bf_set_nan(r);
        else
            bf_set_ui(r, 1); /* cos(+/-0) = 1 */
        return 0;
    }

    /* tiny arguments: cos(x) = 1 + r(x) with r(x) = -x^2/2 + O(x^4),
       assuming abs(r(x)) < 2^(2*EXP(x) - 1). Below the rounding
       threshold the result is 1 minus an epsilon. */
    if (a->expn < 0) {
        slimb_t err_exp = 2 * a->expn - 1;
        if (err_exp < -(prec + 2)) {
            bf_set_ui(r, 1);
            return bf_add_epsilon(r, r, err_exp, 1, prec, flags);
        }
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_cos_internal, NULL);
}
4971 
/* Ziv callback: r = sin(a) evaluated with 'prec' bits; opaque is unused. */
static int bf_sin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    (void)opaque;
    /* the cosine output is not requested */
    return bf_sincos(r, NULL, a, prec);
}
4976 
/* r = sin(a) rounded to 'prec' bits according to 'flags'. */
int bf_sin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    if (a->len == 0) {
        /* special values */
        if (a->expn == BF_EXP_INF) {
            /* sin(+/-inf) is undefined */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        if (a->expn == BF_EXP_NAN)
            bf_set_nan(r);
        else
            bf_set_zero(r, a->sign); /* sin(+/-0) = +/-0 */
        return 0;
    }

    /* tiny arguments: sin(x) = x + r(x) with r(x) = -x^3/6 + O(x^5),
       assuming abs(r(x)) < 2^(3*EXP(x) - 2). The correction has the
       sign opposite to x. */
    if (a->expn < 0) {
        slimb_t err_exp = sat_add(2 * a->expn, a->expn - 2);
        if (err_exp < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
            bf_set(r, a);
            return bf_add_epsilon(r, r, err_exp, 1 - a->sign, prec, flags);
        }
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_sin_internal, NULL);
}
5005 
/* Ziv callback: r = tan(a) = sin(a) / cos(a) with 'prec' bits plus a
   few guard bits; opaque is unused. Always returns BF_ST_INEXACT. */
static int bf_tan_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    bf_t cos_s, *cos_a = &cos_s;
    /* XXX: precision analysis */
    limb_t work_prec = prec + 8;

    (void)opaque;
    bf_init(r->ctx, cos_a);
    bf_sincos(r, cos_a, a, work_prec);       /* r = sin(a), cos_a = cos(a) */
    bf_div(r, r, cos_a, work_prec, BF_RNDF);
    bf_delete(cos_a);
    return BF_ST_INEXACT;
}
5020 
/* r = tan(a) rounded to 'prec' bits according to 'flags'.
   r must not alias a. */
int bf_tan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    assert(r != a);
    if (a->len == 0) {
        /* special values */
        if (a->expn == BF_EXP_INF) {
            /* tan(+/-inf) is undefined */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        if (a->expn == BF_EXP_NAN)
            bf_set_nan(r);
        else
            bf_set_zero(r, a->sign); /* tan(+/-0) = +/-0 */
        return 0;
    }

    /* tiny arguments: tan(x) = x + r(x) with r(x) = x^3/3 + O(x^5),
       assuming abs(r(x)) < 2^(3*EXP(x) - 1). The correction has the
       same sign as x. */
    if (a->expn < 0) {
        slimb_t err_exp = sat_add(2 * a->expn, a->expn - 1);
        if (err_exp < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
            bf_set(r, a);
            return bf_add_epsilon(r, r, err_exp, a->sign, prec, flags);
        }
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_tan_internal, NULL);
}
5050 
/* if add_pi2 is true, add pi/2 to the result (used for acos(x) to
   avoid cancellation) */
/* Ziv callback: r = atan(a), plus pi/2 when opaque is TRUE, computed
   with about 'prec' bits. If abs(a) >= 1 it works on 1/a and fixes the
   result at the end. The argument is halved K times with
   T / (1 + sqrt(1 + T^2)) before an l term Taylor series, and the
   halvings are undone by multiplying by 2^K. Always returns
   BF_ST_INEXACT.
   NOTE(review): prec must be >= 1, otherwise K = 0 and prec / (2 * K)
   divides by zero. */
static int bf_atan_internal(bf_t *r, const bf_t *a, limb_t prec,
                            void *opaque)
{
    bf_context_t *s = r->ctx;
    BOOL add_pi2 = (BOOL)(intptr_t)opaque;
    bf_t T_s, *T = &T_s;
    bf_t U_s, *U = &U_s;
    bf_t V_s, *V = &V_s;
    bf_t X2_s, *X2 = &X2_s;
    int cmp_1;
    slimb_t prec1, i, K, l;

    /* XXX: precision analysis */
    K = bf_isqrt((prec + 1) / 2);
    l = prec / (2 * K) + 1;
    prec1 = prec + K + 2 * l + 32;
    //    printf("prec=%d K=%d l=%d prec1=%d\n", (int)prec, (int)K, (int)l, (int)prec1);

    bf_init(s, T);
    cmp_1 = (a->expn >= 1); /* abs(a) >= 1 */
    if (cmp_1) {
        /* use atan(a) = sign(a)*pi/2 - atan(1/a), undone below */
        bf_set_ui(T, 1);
        bf_div(T, T, a, prec1, BF_RNDN);
    } else {
        bf_set(T, a);
    }

    /* abs(T) <= 1 */

    /* argument reduction */

    bf_init(s, U);
    bf_init(s, V);
    bf_init(s, X2);
    for(i = 0; i < K; i++) {
        /* T = T / (1 + sqrt(1 + T^2)) */
        bf_mul(U, T, T, prec1, BF_RNDN);
        bf_add_si(U, U, 1, prec1, BF_RNDN);
        bf_sqrt(V, U, prec1, BF_RNDN);
        bf_add_si(V, V, 1, prec1, BF_RNDN);
        bf_div(T, T, V, prec1, BF_RNDN);
    }

    /* Taylor series:
       T - T^3/3 + ... + (-1)^l * T^(2*l + 1) / (2*l+1)
       evaluated with the Horner scheme in X2 = T^2
    */
    bf_mul(X2, T, T, prec1, BF_RNDN);
    bf_set_ui(r, 0);
    for(i = l; i >= 1; i--) {
        bf_set_si(U, 1);
        bf_set_ui(V, 2 * i + 1);
        bf_div(U, U, V, prec1, BF_RNDN);
        bf_neg(r);
        bf_add(r, r, U, prec1, BF_RNDN);
        bf_mul(r, r, X2, prec1, BF_RNDN);
    }
    bf_neg(r);
    bf_add_si(r, r, 1, prec1, BF_RNDN);
    bf_mul(r, r, T, prec1, BF_RNDN);

    /* undo the argument reduction */
    bf_mul_2exp(r, K, BF_PREC_INF, BF_RNDZ);

    bf_delete(U);
    bf_delete(V);
    bf_delete(X2);

    /* i is reused here as the number of pi/2 to add */
    i = add_pi2;
    if (cmp_1 > 0) {
        /* undo the inversion : r = sign(a)*PI/2 - r */
        bf_neg(r);
        i += 1 - 2 * a->sign;
    }
    /* add i*(pi/2) with -1 <= i <= 2 */
    if (i != 0) {
        bf_const_pi(T, prec1, BF_RNDF);
        if (i != 2)
            bf_mul_2exp(T, -1, BF_PREC_INF, BF_RNDZ);
        T->sign = (i < 0);
        bf_add(r, T, r, prec1, BF_RNDN);
    }

    bf_delete(T);
    return BF_ST_INEXACT;
}
5138 
/* r = atan(a) rounded to 'prec' bits according to 'flags'. */
int bf_atan(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    bf_t one_s, *one = &one_s;
    int cmp_abs_1;

    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        }
        if (a->expn != BF_EXP_INF) {
            /* atan(+/-0) = +/-0 */
            bf_set_zero(r, a->sign);
            return 0;
        }
        /* atan(+/-inf) = +/-pi/2 */
        bf_const_pi_signed(r, a->sign, prec, flags);
        bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
        return BF_ST_INEXACT;
    }

    /* compare abs(a) with 1 */
    bf_init(s, one);
    bf_set_ui(one, 1);
    cmp_abs_1 = bf_cmpu(a, one);
    bf_delete(one);
    if (cmp_abs_1 == 0) {
        /* atan(+/-1) = +/-pi/4 */
        bf_const_pi_signed(r, a->sign, prec, flags);
        bf_mul_2exp(r, -2, BF_PREC_INF, BF_RNDZ);
        return BF_ST_INEXACT;
    }

    /* tiny arguments: atan(x) = x + r(x) with r(x) = -x^3/3 + O(x^5),
       assuming abs(r(x)) < 2^(3*EXP(x) - 1). The correction has the
       sign opposite to x. */
    if (a->expn < 0) {
        slimb_t err_exp = sat_add(2 * a->expn, a->expn - 1);
        if (err_exp < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
            bf_set(r, a);
            return bf_add_epsilon(r, r, err_exp, 1 - a->sign, prec, flags);
        }
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_atan_internal, (void *)FALSE);
}
5184 
/* Ziv callback for atan2(y, x), with 'opaque' pointing to x.
   It computes atan(y/x), taking inf/inf as +/-1 and 0/0 as +/-0, then
   adds sign(y)*pi when x is negative (including x = -0). */
static int bf_atan2_internal(bf_t *r, const bf_t *y, limb_t prec, void *opaque)
{
    const bf_t *x = opaque;
    bf_t q_s, *q = &q_s;
    limb_t work_prec;
    int status;

    if (x->expn == BF_EXP_NAN || y->expn == BF_EXP_NAN) {
        bf_set_nan(r);
        return 0;
    }

    bf_init(r->ctx, q);
    work_prec = prec + 32;
    if (y->expn == BF_EXP_ZERO && x->expn == BF_EXP_ZERO) {
        bf_set_zero(q, y->sign ^ x->sign);
    } else if (y->expn == BF_EXP_INF && x->expn == BF_EXP_INF) {
        bf_set_ui(q, 1);
        q->sign = y->sign ^ x->sign;
    } else {
        bf_div(q, y, x, work_prec, BF_RNDF);
    }
    status = bf_atan(r, q, work_prec, BF_RNDF);

    if (x->sign) {
        /* left half plane: shift by sign(y)*pi */
        bf_const_pi(q, work_prec, BF_RNDF);
        q->sign = y->sign;
        bf_add(r, r, q, work_prec, BF_RNDN);
        status |= BF_ST_INEXACT;
    }

    bf_delete(q);
    return status;
}
5222 
/* r = atan2(y, x) rounded to 'prec' bits according to 'flags'. */
int bf_atan2(bf_t *r, const bf_t *y, const bf_t *x,
             limb_t prec, bf_flags_t flags)
{
    /* x reaches the Ziv callback through its opaque pointer */
    void *x_arg = (void *)x;

    return bf_ziv_rounding(r, y, prec, flags, bf_atan2_internal, x_arg);
}
5228 
/* Ziv callback shared by asin (opaque = FALSE) and acos (opaque = TRUE):
   asin(x) = atan(x/sqrt(1-x^2)) and
   acos(x) = pi/2 - asin(x) = pi/2 + atan(-x/sqrt(1-x^2)).
   Always returns BF_ST_INEXACT. */
static int bf_asin_internal(bf_t *r, const bf_t *a, limb_t prec, void *opaque)
{
    BOOL is_acos = (BOOL)(intptr_t)opaque;
    bf_t t_s, *t = &t_s;
    limb_t work_prec = prec + 8;
    /* when a->expn >= 0 (a close to 1), 1 - a^2 cancels: square a exactly */
    /* XXX: use less precision when possible */
    limb_t sq_prec = (a->expn >= 0) ? BF_PREC_INF : work_prec;

    bf_init(r->ctx, t);
    bf_mul(t, a, a, sq_prec, BF_RNDN);       /* t = a^2 */
    bf_neg(t);
    bf_add_si(t, t, 1, sq_prec, BF_RNDN);    /* t = 1 - a^2 */
    bf_sqrt(r, t, work_prec, BF_RNDN);       /* r = sqrt(1 - a^2) */
    bf_div(t, a, r, work_prec, BF_RNDN);     /* t = a / sqrt(1 - a^2) */
    if (is_acos)
        bf_neg(t);
    bf_atan_internal(r, t, work_prec, (void *)(intptr_t)is_acos);
    bf_delete(t);
    return BF_ST_INEXACT;
}
5259 
/* r = asin(a) rounded to 'prec' bits according to 'flags'. */
int bf_asin(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    bf_t one_s, *one = &one_s;
    int cmp_abs_1;

    if (a->len == 0) {
        if (a->expn == BF_EXP_INF) {
            /* asin(+/-inf) is undefined */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        if (a->expn == BF_EXP_NAN)
            bf_set_nan(r);
        else
            bf_set_zero(r, a->sign); /* asin(+/-0) = +/-0 */
        return 0;
    }

    /* asin is only defined for abs(a) <= 1 */
    bf_init(s, one);
    bf_set_ui(one, 1);
    cmp_abs_1 = bf_cmpu(a, one);
    bf_delete(one);
    if (cmp_abs_1 > 0) {
        bf_set_nan(r);
        return BF_ST_INVALID_OP;
    }

    /* tiny arguments: asin(x) = x + r(x) with r(x) = x^3/6 + O(x^5),
       assuming abs(r(x)) < 2^(3*EXP(x) - 2). The correction has the
       same sign as x. */
    if (a->expn < 0) {
        slimb_t err_exp = sat_add(2 * a->expn, a->expn - 2);
        if (err_exp < a->expn - bf_max(prec + 2, a->len * LIMB_BITS + 2)) {
            bf_set(r, a);
            return bf_add_epsilon(r, r, err_exp, a->sign, prec, flags);
        }
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal, (void *)FALSE);
}
5300 
/* r = acos(a) rounded to 'prec' bits according to 'flags'. */
int bf_acos(bf_t *r, const bf_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = r->ctx;
    bf_t one_s, *one = &one_s;
    int cmp_abs_1;

    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bf_set_nan(r);
            return 0;
        }
        if (a->expn == BF_EXP_INF) {
            /* acos(+/-inf) is undefined */
            bf_set_nan(r);
            return BF_ST_INVALID_OP;
        }
        /* acos(+/-0) = pi/2 */
        bf_const_pi(r, prec, flags);
        bf_mul_2exp(r, -1, BF_PREC_INF, BF_RNDZ);
        return BF_ST_INEXACT;
    }

    /* compare abs(a) with 1 */
    bf_init(s, one);
    bf_set_ui(one, 1);
    cmp_abs_1 = bf_cmpu(a, one);
    bf_delete(one);
    if (cmp_abs_1 > 0) {
        /* abs(a) > 1: undefined */
        bf_set_nan(r);
        return BF_ST_INVALID_OP;
    }
    if (cmp_abs_1 == 0 && !a->sign) {
        /* acos(1) = +0 exactly */
        bf_set_zero(r, 0);
        return 0;
    }

    return bf_ziv_rounding(r, a, prec, flags, bf_asin_internal, (void *)TRUE);
}
5334 
5335 /***************************************************************/
5336 /* decimal floating point numbers */
5337 
5338 #ifdef USE_BF_DEC
5339 
/* adddq: double-limb addition (r1:r0) += (a1:a0). The carry out of the
   low limb is detected by the unsigned wrap-around (new r0 < old r0).
   r1 and r0 must be lvalues and r0 is evaluated several times, so pass
   simple variables. NOTE(review): __t is in the implementation-reserved
   namespace; renaming it could capture caller arguments such as the
   __t0/__t1 passed by divdq_base. */
#define adddq(r1, r0, a1, a0)                   \
    do {                                        \
        limb_t __t = r0;                        \
        r0 += (a0);                             \
        r1 += (a1) + (r0 < __t);                \
    } while (0)

/* subdq: double-limb subtraction (r1:r0) -= (a1:a0). The borrow out of
   the low limb is detected by the unsigned wrap-around
   (new r0 > old r0). Same argument constraints as adddq. */
#define subdq(r1, r0, a1, a0)                   \
    do {                                        \
        limb_t __t = r0;                        \
        r0 -= (a0);                             \
        r1 -= (a1) + (r0 > __t);                \
    } while (0)
5353 
#if LIMB_BITS == 64

/* Note: we assume __int128 is available */
/* muldq: (r1:r0) = a * b, the full double-limb product of two limbs. */
#define muldq(r1, r0, a, b)                     \
    do {                                        \
        unsigned __int128 __t;                          \
        __t = (unsigned __int128)(a) * (unsigned __int128)(b);  \
        r0 = __t;                               \
        r1 = __t >> 64;                         \
    } while (0)

/* divdq: q = (a1:a0) / b and r = (a1:a0) % b.
   NOTE(review): the quotient only fits in one limb when a1 < b;
   callers presumably guarantee this. */
#define divdq(q, r, a1, a0, b)                  \
    do {                                        \
        unsigned __int128 __t;                  \
        limb_t __b = (b);                       \
        __t = ((unsigned __int128)(a1) << 64) | (a0);   \
        q = __t / __b;                                  \
        r = __t % __b;                                  \
    } while (0)

#else

/* 32-bit limbs: the same two macros, using uint64_t as the
   double-limb type */
#define muldq(r1, r0, a, b)                     \
    do {                                        \
        uint64_t __t;                          \
        __t = (uint64_t)(a) * (uint64_t)(b);  \
        r0 = __t;                               \
        r1 = __t >> 32;                         \
    } while (0)

#define divdq(q, r, a1, a0, b)                  \
    do {                                        \
        uint64_t __t;                  \
        limb_t __b = (b);                       \
        __t = ((uint64_t)(a1) << 32) | (a0);   \
        q = __t / __b;                                  \
        r = __t % __b;                                  \
    } while (0)

#endif /* LIMB_BITS != 64 */
5394 
shrd(limb_t low,limb_t high,long shift)5395 static inline __maybe_unused limb_t shrd(limb_t low, limb_t high, long shift)
5396 {
5397     if (shift != 0)
5398         low = (low >> shift) | (high << (LIMB_BITS - shift));
5399     return low;
5400 }
5401 
shld(limb_t a1,limb_t a0,long shift)5402 static inline __maybe_unused limb_t shld(limb_t a1, limb_t a0, long shift)
5403 {
5404     if (shift != 0)
5405         return (a1 << shift) | (a0 >> (LIMB_BITS - shift));
5406     else
5407         return a1;
5408 }
5409 
#if LIMB_DIGITS == 19

/* WARNING: hardcoded for b = 1e19. It is assumed that:
   0 <= a1 < 2^63 */
/* Presumably the same contract as divdq(q, r, a1, a0, BF_DEC_BASE)
   (q = (a1:a0) / 1e19, r = remainder), but without a hardware
   division. The quotient is estimated from (a1:a0) >> 63 (via shld)
   times 17014118346046923173 = floor(2^127 / 10^19) and then
   corrected with branch-free adjustment steps.
   NOTE(review): '(slimb_t)__a1 >> 1' relies on an arithmetic right
   shift of a negative value, which is implementation-defined in C. */
#define divdq_base(q, r, a1, a0)\
do {\
    uint64_t __a0, __a1, __t0, __t1, __b = BF_DEC_BASE; \
    __a0 = a0;\
    __a1 = a1;\
    __t0 = __a1;\
    __t0 = shld(__t0, __a0, 1);\
    muldq(q, __t1, __t0, UINT64_C(17014118346046923173)); \
    muldq(__t1, __t0, q, __b);\
    subdq(__a1, __a0, __t1, __t0);\
    subdq(__a1, __a0, 1, __b * 2);    \
    __t0 = (slimb_t)__a1 >> 1; \
    q += 2 + __t0;\
    adddq(__a1, __a0, 0, __b & __t0);\
    q += __a1;                  \
    __a0 += __b & __a1;           \
    r = __a0;\
} while(0)

#elif LIMB_DIGITS == 9

/* WARNING: hardcoded for b = 1e9. It is assumed that:
   0 <= a1 < 2^29 */
/* 32-bit variant: estimates q from (a1:a0) >> 29 times
   2305843009 = floor(2^61 / 10^9), then applies at most one
   correction step. NOTE(review): 'a0' is expanded twice, so it must
   be free of side effects. */
#define divdq_base(q, r, a1, a0)\
do {\
    uint32_t __t0, __t1, __b = BF_DEC_BASE; \
    __t0 = a1;\
    __t1 = a0;\
    __t0 = (__t0 << 3) | (__t1 >> (32 - 3));    \
    muldq(q, __t1, __t0, 2305843009U);\
    r = a0 - q * __b;\
    __t1 = (r >= __b);\
    q += __t1;\
    if (__t1)\
        r -= __b;\
} while(0)

#endif
5452 
/* fast integer division by a fixed constant */

/* Precomputed parameters for dividing limbs by an invariant divisor d;
   filled by fast_udiv_init() and used by fast_udiv(). The field order
   matters because the mp_pow_div table initializes the fields
   positionally. */
typedef struct FastDivData {
    limb_t m1; /* multiplier */
    int8_t shift1; /* min(l, 1) with l = ceil(log2(d)) */
    int8_t shift2; /* max(l - 1, 0) */
} FastDivData;
5460 
5461 /* From "Division by Invariant Integers using Multiplication" by
5462    Torborn Granlund and Peter L. Montgomery */
5463 /* d must be != 0 */
fast_udiv_init(FastDivData * s,limb_t d)5464 static inline __maybe_unused void fast_udiv_init(FastDivData *s, limb_t d)
5465 {
5466     int l;
5467     limb_t q, r, m1;
5468     if (d == 1)
5469         l = 0;
5470     else
5471         l = 64 - clz64(d - 1);
5472     divdq(q, r, ((limb_t)1 << l) - d, 0, d);
5473     (void)r;
5474     m1 = q + 1;
5475     //    printf("d=%lu l=%d m1=0x%016lx\n", d, l, m1);
5476     s->m1 = m1;
5477     s->shift1 = l;
5478     if (s->shift1 > 1)
5479         s->shift1 = 1;
5480     s->shift2 = l - 1;
5481     if (s->shift2 < 0)
5482         s->shift2 = 0;
5483 }
5484 
/* floor(a / d) for the divisor d whose parameters were prepared by
   fast_udiv_init(): one high multiply and two shifts, no division. */
static inline limb_t fast_udiv(limb_t a, const FastDivData *s)
{
    limb_t hi, lo;

    muldq(hi, lo, s->m1, a);   /* hi = high limb of m1 * a */
    (void)lo;
    return (hi + ((a - hi) >> s->shift1)) >> s->shift2;
}
5492 
/* contains 10^i */
/* mp_pow_dec[i] = 10^i for 0 <= i <= LIMB_DIGITS; the last entry is
   10^LIMB_DIGITS (10^9 with 32-bit limbs, 10^19 with 64-bit limbs). */
const limb_t mp_pow_dec[LIMB_DIGITS + 1] = {
    1U,
    10U,
    100U,
    1000U,
    10000U,
    100000U,
    1000000U,
    10000000U,
    100000000U,
    1000000000U,
#if LIMB_BITS == 64
    /* powers that only fit in 64-bit limbs */
    10000000000U,
    100000000000U,
    1000000000000U,
    10000000000000U,
    100000000000000U,
    1000000000000000U,
    10000000000000000U,
    100000000000000000U,
    1000000000000000000U,
    10000000000000000000U,
#endif
};
5518 
/* precomputed from fast_udiv_init(10^i) */
/* mp_pow_div[i] holds the fast_udiv() parameters { m1, shift1, shift2 }
   for the divisor 10^i, 0 <= i <= LIMB_DIGITS. The initializers are
   positional and follow the FastDivData field order. */
static const FastDivData mp_pow_div[LIMB_DIGITS + 1] = {
#if LIMB_BITS == 32
    { 0x00000001, 0, 0 },
    { 0x9999999a, 1, 3 },
    { 0x47ae147b, 1, 6 },
    { 0x0624dd30, 1, 9 },
    { 0xa36e2eb2, 1, 13 },
    { 0x4f8b588f, 1, 16 },
    { 0x0c6f7a0c, 1, 19 },
    { 0xad7f29ac, 1, 23 },
    { 0x5798ee24, 1, 26 },
    { 0x12e0be83, 1, 29 },
#else
    { 0x0000000000000001, 0, 0 },
    { 0x999999999999999a, 1, 3 },
    { 0x47ae147ae147ae15, 1, 6 },
    { 0x0624dd2f1a9fbe77, 1, 9 },
    { 0xa36e2eb1c432ca58, 1, 13 },
    { 0x4f8b588e368f0847, 1, 16 },
    { 0x0c6f7a0b5ed8d36c, 1, 19 },
    { 0xad7f29abcaf48579, 1, 23 },
    { 0x5798ee2308c39dfa, 1, 26 },
    { 0x12e0be826d694b2f, 1, 29 },
    { 0xb7cdfd9d7bdbab7e, 1, 33 },
    { 0x5fd7fe17964955fe, 1, 36 },
    { 0x19799812dea11198, 1, 39 },
    { 0xc25c268497681c27, 1, 43 },
    { 0x6849b86a12b9b01f, 1, 46 },
    { 0x203af9ee756159b3, 1, 49 },
    { 0xcd2b297d889bc2b7, 1, 53 },
    { 0x70ef54646d496893, 1, 56 },
    { 0x2725dd1d243aba0f, 1, 59 },
    /* 10^19 > 2^63: here l = 64 = LIMB_BITS (shift2 = 63) */
    { 0xd83c94fb6d2ac34d, 1, 63 },
#endif
};
5555 
5556 /* divide by 10^shift with 0 <= shift <= LIMB_DIGITS */
fast_shr_dec(limb_t a,int shift)5557 static inline limb_t fast_shr_dec(limb_t a, int shift)
5558 {
5559     return fast_udiv(a, &mp_pow_div[shift]);
5560 }
5561 
/* Division and remainder by 10^shift:
   q = a / 10^shift, r = a - q * 10^shift (i.e. a % 10^shift).
   q and r must be lvalues; 0 <= shift <= LIMB_DIGITS.
   All arguments and the whole expansion are parenthesized so that
   arguments such as 'x + 1' or 'c ? x : y' expand correctly. Note that
   'a' and 'shift' are still evaluated twice and 'q' is read after being
   written: do not pass expressions with side effects. */
#define fast_shr_rem_dec(q, r, a, shift)                \
    ((q) = fast_shr_dec((a), (shift)),                  \
     (r) = (a) - (q) * mp_pow_dec[(shift)])
5564 
/* res[] = op1[] + op2[] + carry over n decimal limbs (limb values are
   assumed to be in [0, BF_DEC_BASE)). 'carry' is the incoming carry
   (callers here pass 0). Returns the outgoing carry. res may be the same
   array as op1 (callers use it in place). */
limb_t mp_add_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
                  mp_size_t n, limb_t carry)
{
    const limb_t base = BF_DEC_BASE;
    limb_t c = carry;
    mp_size_t idx;

    for (idx = 0; idx < n; idx++) {
        /* XXX: reuse the trick in add_mod */
        limb_t x = op1[idx];
        /* Compute x + y + c - base with unsigned wrap-around: the result
           is <= x exactly when the true sum reached base (carry out);
           otherwise it wrapped and adding base back restores the sum. */
        limb_t sum = x + op2[idx] + c - base;
        c = (sum <= x);
        res[idx] = c ? sum : sum + base;
    }
    return c;
}
5584 
/* tab[] += b in place over at most n decimal limbs, stopping at the
   first limb that produces no carry. Returns the remaining carry (0 or
   1 once a limb has been processed; b itself when n == 0). */
limb_t mp_add_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
{
    const limb_t base = BF_DEC_BASE;
    limb_t carry = b;
    mp_size_t idx;

    for (idx = 0; idx < n; idx++) {
        limb_t old = tab[idx];
        /* same wrap-around carry test as mp_add_dec() */
        limb_t sum = old + carry - base;
        carry = (sum <= old);
        tab[idx] = carry ? sum : sum + base;
        if (!carry)
            break;
    }
    return carry;
}
5604 
/* res[] = op1[] - op2[] - carry over n decimal limbs. 'carry' is the
   incoming borrow. Returns the outgoing borrow (1 when the result would
   be negative; the stored limbs are then the base-complement value).
   res may be the same array as op1. */
limb_t mp_sub_dec(limb_t *res, const limb_t *op1, const limb_t *op2,
                  mp_size_t n, limb_t carry)
{
    const limb_t base = BF_DEC_BASE;
    limb_t borrow = carry;
    mp_size_t idx;

    for (idx = 0; idx < n; idx++) {
        limb_t x = op1[idx];
        limb_t diff = x - op2[idx] - borrow;
        /* unsigned wrap-around above x means the difference went negative */
        borrow = (diff > x);
        res[idx] = borrow ? diff + base : diff;
    }
    return borrow;
}
5623 
/* tab[] -= b in place over at most n decimal limbs, stopping at the
   first limb that produces no borrow. Returns the remaining borrow
   (b itself when n == 0). */
limb_t mp_sub_ui_dec(limb_t *tab, limb_t b, mp_size_t n)
{
    const limb_t base = BF_DEC_BASE;
    limb_t borrow = b;
    mp_size_t idx;

    for (idx = 0; idx < n; idx++) {
        limb_t old = tab[idx];
        limb_t diff = old - borrow;
        borrow = (diff > old);
        tab[idx] = borrow ? diff + base : diff;
        if (!borrow)
            break;
    }
    return borrow;
}
5643 
5644 /* taba[] = taba[] * b + l. 0 <= b, l <= base - 1. Return the high carry */
mp_mul1_dec(limb_t * tabr,const limb_t * taba,mp_size_t n,limb_t b,limb_t l)5645 limb_t mp_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5646                    limb_t b, limb_t l)
5647 {
5648     mp_size_t i;
5649     limb_t t0, t1, r;
5650 
5651     for(i = 0; i < n; i++) {
5652         muldq(t1, t0, taba[i], b);
5653         adddq(t1, t0, 0, l);
5654         divdq_base(l, r, t1, t0);
5655         tabr[i] = r;
5656     }
5657     return l;
5658 }
5659 
5660 /* tabr[] += taba[] * b. 0 <= b <= base - 1. Return the value to add
5661    to the high word */
mp_add_mul1_dec(limb_t * tabr,const limb_t * taba,mp_size_t n,limb_t b)5662 limb_t mp_add_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5663                        limb_t b)
5664 {
5665     mp_size_t i;
5666     limb_t l, t0, t1, r;
5667 
5668     l = 0;
5669     for(i = 0; i < n; i++) {
5670         muldq(t1, t0, taba[i], b);
5671         adddq(t1, t0, 0, l);
5672         adddq(t1, t0, 0, tabr[i]);
5673         divdq_base(l, r, t1, t0);
5674         tabr[i] = r;
5675     }
5676     return l;
5677 }
5678 
5679 /* tabr[] -= taba[] * b. 0 <= b <= base - 1. Return the value to
5680    substract to the high word. */
mp_sub_mul1_dec(limb_t * tabr,const limb_t * taba,mp_size_t n,limb_t b)5681 limb_t mp_sub_mul1_dec(limb_t *tabr, const limb_t *taba, mp_size_t n,
5682                        limb_t b)
5683 {
5684     limb_t base = BF_DEC_BASE;
5685     mp_size_t i;
5686     limb_t l, t0, t1, r, a, v, c;
5687 
5688     /* XXX: optimize */
5689     l = 0;
5690     for(i = 0; i < n; i++) {
5691         muldq(t1, t0, taba[i], b);
5692         adddq(t1, t0, 0, l);
5693         divdq_base(l, r, t1, t0);
5694         v = tabr[i];
5695         a = v - r;
5696         c = a > v;
5697         if (c)
5698             a += base;
5699         /* never bigger than base because r = 0 when l = base - 1 */
5700         l += c;
5701         tabr[i] = a;
5702     }
5703     return l;
5704 }
5705 
5706 /* size of the result : op1_size + op2_size. */
mp_mul_basecase_dec(limb_t * result,const limb_t * op1,mp_size_t op1_size,const limb_t * op2,mp_size_t op2_size)5707 void mp_mul_basecase_dec(limb_t *result,
5708                          const limb_t *op1, mp_size_t op1_size,
5709                          const limb_t *op2, mp_size_t op2_size)
5710 {
5711     mp_size_t i;
5712     limb_t r;
5713 
5714     result[op1_size] = mp_mul1_dec(result, op1, op1_size, op2[0], 0);
5715 
5716     for(i=1;i<op2_size;i++) {
5717         r = mp_add_mul1_dec(result + i, op1, op1_size, op2[i]);
5718         result[i + op1_size] = r;
5719     }
5720 }
5721 
5722 /* taba[] = (taba[] + r*base^na) / b. 0 <= b < base. 0 <= r <
5723    b. Return the remainder. */
mp_div1_dec(limb_t * tabr,const limb_t * taba,mp_size_t na,limb_t b,limb_t r)5724 limb_t mp_div1_dec(limb_t *tabr, const limb_t *taba, mp_size_t na,
5725                    limb_t b, limb_t r)
5726 {
5727     limb_t base = BF_DEC_BASE;
5728     mp_size_t i;
5729     limb_t t0, t1, q;
5730     int shift;
5731 
5732 #if (BF_DEC_BASE % 2) == 0
5733     if (b == 2) {
5734         limb_t base_div2;
5735         /* Note: only works if base is even */
5736         base_div2 = base >> 1;
5737         if (r)
5738             r = base_div2;
5739         for(i = na - 1; i >= 0; i--) {
5740             t0 = taba[i];
5741             tabr[i] = (t0 >> 1) + r;
5742             r = 0;
5743             if (t0 & 1)
5744                 r = base_div2;
5745         }
5746         if (r)
5747             r = 1;
5748     } else
5749 #endif
5750     if (na >= UDIV1NORM_THRESHOLD) {
5751         shift = clz(b);
5752         if (shift == 0) {
5753             /* normalized case: b >= 2^(LIMB_BITS-1) */
5754             limb_t b_inv;
5755             b_inv = udiv1norm_init(b);
5756             for(i = na - 1; i >= 0; i--) {
5757                 muldq(t1, t0, r, base);
5758                 adddq(t1, t0, 0, taba[i]);
5759                 q = udiv1norm(&r, t1, t0, b, b_inv);
5760                 tabr[i] = q;
5761             }
5762         } else {
5763             limb_t b_inv;
5764             b <<= shift;
5765             b_inv = udiv1norm_init(b);
5766             for(i = na - 1; i >= 0; i--) {
5767                 muldq(t1, t0, r, base);
5768                 adddq(t1, t0, 0, taba[i]);
5769                 t1 = (t1 << shift) | (t0 >> (LIMB_BITS - shift));
5770                 t0 <<= shift;
5771                 q = udiv1norm(&r, t1, t0, b, b_inv);
5772                 r >>= shift;
5773                 tabr[i] = q;
5774             }
5775         }
5776     } else {
5777         for(i = na - 1; i >= 0; i--) {
5778             muldq(t1, t0, r, base);
5779             adddq(t1, t0, 0, taba[i]);
5780             divdq(q, r, t1, t0, b);
5781             tabr[i] = q;
5782         }
5783     }
5784     return r;
5785 }
5786 
mp_print_str_dec(const char * str,const limb_t * tab,slimb_t n)5787 static __maybe_unused void mp_print_str_dec(const char *str,
5788                                        const limb_t *tab, slimb_t n)
5789 {
5790     slimb_t i;
5791     printf("%s=", str);
5792     for(i = n - 1; i >= 0; i--) {
5793         if (i != n - 1)
5794             printf("_");
5795         printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5796     }
5797     printf("\n");
5798 }
5799 
mp_print_str_h_dec(const char * str,const limb_t * tab,slimb_t n,limb_t high)5800 static __maybe_unused void mp_print_str_h_dec(const char *str,
5801                                               const limb_t *tab, slimb_t n,
5802                                               limb_t high)
5803 {
5804     slimb_t i;
5805     printf("%s=", str);
5806     printf("%0*" PRIu_LIMB, LIMB_DIGITS, high);
5807     for(i = n - 1; i >= 0; i--) {
5808         printf("_");
5809         printf("%0*" PRIu_LIMB, LIMB_DIGITS, tab[i]);
5810     }
5811     printf("\n");
5812 }
5813 
//#define DEBUG_DIV_SLOW

/* Divisors with at most this many limbs are normalized into a stack
   buffer in mp_div_dec(); longer ones need bf_malloc(). */
#define DIV_STATIC_ALLOC_LEN 16

/* return q = a / b and r = a % b, for decimal limb arrays
   (base B = BF_DEC_BASE).

   a is taba[0 .. na - 1] and b is tabb1[0 .. nb - 1]. na must be >= nb
   and tabb1[nb - 1] must be != zero.

   If tabb1[nb - 1] < B / 2, b is not normalized: a and b are both
   multiplied by mult = B / (tabb1[nb - 1] + 1). This writes an extra
   limb to taba[na], which must therefore be allocated. The scaled b is
   kept in a stack buffer, or in a bf_malloc(s, ...) buffer when
   nb > DIV_STATIC_ALLOC_LEN. 's' can be NULL if tabb1[nb - 1] >= B / 2
   (from the code, also whenever nb <= DIV_STATIC_ALLOC_LEN).

   The remainder is returned in taba[0 .. nb - 1]; the higher limbs of
   taba are overwritten. tabq receives na - nb + 1 limbs. No overlap is
   permitted.

   Running time of the standard method: (na - nb + 1) * nb
   Return 0 if OK, -1 if memory alloc error
*/
/* XXX: optimize */
static int mp_div_dec(bf_context_t *s, limb_t *tabq,
                      limb_t *taba, mp_size_t na,
                      const limb_t *tabb1, mp_size_t nb)
{
    limb_t base = BF_DEC_BASE;
    limb_t r, mult, t0, t1, a, c, q, v, *tabb;
    mp_size_t i, j;
    limb_t static_tabb[DIV_STATIC_ALLOC_LEN];

#ifdef DEBUG_DIV_SLOW
    mp_print_str_dec("a", taba, na);
    mp_print_str_dec("b", tabb1, nb);
#endif

    /* normalize tabb */
    r = tabb1[nb - 1];
    assert(r != 0);
    i = na - nb;
    if (r >= BF_DEC_BASE / 2) {
        /* b is already normalized and used as is. The top quotient limb
           is 0 or 1: it is 1 iff the top nb limbs of a are >= b
           (compared from the most significant limb down). */
        mult = 1;
        tabb = (limb_t *)tabb1;
        q = 1;
        for(j = nb - 1; j >= 0; j--) {
            if (taba[i + j] != tabb[j]) {
                if (taba[i + j] < tabb[j])
                    q = 0;
                break;
            }
        }
        tabq[i] = q;
        if (q) {
            mp_sub_dec(taba + i, taba + i, tabb, nb, 0);
        }
        i--;
    } else {
        /* scale a and b by mult so that the top limb of b is >= B / 2;
           a grows by one limb (taba[na]) */
        mult = base / (r + 1);
        if (likely(nb <= DIV_STATIC_ALLOC_LEN)) {
            tabb = static_tabb;
        } else {
            tabb = bf_malloc(s, sizeof(limb_t) * nb);
            if (!tabb)
                return -1;
        }
        mp_mul1_dec(tabb, tabb1, nb, mult, 0);
        taba[na] = mp_mul1_dec(taba, taba, na, mult, 0);
    }

#ifdef DEBUG_DIV_SLOW
    printf("mult=" FMT_LIMB "\n", mult);
    mp_print_str_dec("a_norm", taba, na + 1);
    mp_print_str_dec("b_norm", tabb, nb);
#endif

    /* one quotient limb per iteration, from the most significant */
    for(; i >= 0; i--) {
        /* estimate q from the top two limbs of the partial remainder and
           the top limb of b. The code below only corrects an estimate
           that is too large (by adding b back). */
        if (unlikely(taba[i + nb] >= tabb[nb - 1])) {
            /* XXX: check if it is really possible */
            q = base - 1;
        } else {
            muldq(t1, t0, taba[i + nb], base);
            adddq(t1, t0, 0, taba[i + nb - 1]);
            divdq(q, r, t1, t0, tabb[nb - 1]);
        }
        //        printf("i=%d q1=%ld\n", i, q);

        /* partial remainder -= q * b */
        r = mp_sub_mul1_dec(taba + i, tabb, nb, q);
        //        mp_dump("r1", taba + i, nb, bd);
        //        printf("r2=%ld\n", r);

        /* propagate the high part into limb i + nb; c != 0 means the
           partial remainder went negative */
        v = taba[i + nb];
        a = v - r;
        c = a > v;
        if (c)
            a += base;
        taba[i + nb] = a;

        if (c != 0) {
            /* negative result */
            for(;;) {
                /* q was too large: add b back until the high limb,
                   incremented by each carry, wraps up to base (the
                   remainder is then non negative again). The high limb
                   is left equal to base; it is not read again. */
                q--;
                c = mp_add_dec(taba + i, taba + i, tabb, nb, 0);
                /* propagate carry and test if positive result */
                if (c != 0) {
                    if (++taba[i + nb] == base) {
                        break;
                    }
                }
            }
        }
        tabq[i] = q;
    }

#ifdef DEBUG_DIV_SLOW
    mp_print_str_dec("q", tabq, na - nb + 1);
    mp_print_str_dec("r", taba, nb);
#endif

    /* remove the normalization: the remainder was scaled by mult. Free
       the heap copy of b if one was allocated. */
    if (mult != 1) {
        mp_div1_dec(taba, taba, nb, mult, 0);
        if (unlikely(tabb != static_tabb))
            bf_free(s, tabb);
    }
    return 0;
}
5935 
5936 /* divide by 10^shift */
mp_shr_dec(limb_t * tab_r,const limb_t * tab,mp_size_t n,limb_t shift,limb_t high)5937 static limb_t mp_shr_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5938                          limb_t shift, limb_t high)
5939 {
5940     mp_size_t i;
5941     limb_t l, a, q, r;
5942 
5943     assert(shift >= 1 && shift < LIMB_DIGITS);
5944     l = high;
5945     for(i = n - 1; i >= 0; i--) {
5946         a = tab[i];
5947         fast_shr_rem_dec(q, r, a, shift);
5948         tab_r[i] = q + l * mp_pow_dec[LIMB_DIGITS - shift];
5949         l = r;
5950     }
5951     return l;
5952 }
5953 
5954 /* multiply by 10^shift */
mp_shl_dec(limb_t * tab_r,const limb_t * tab,mp_size_t n,limb_t shift,limb_t low)5955 static limb_t mp_shl_dec(limb_t *tab_r, const limb_t *tab, mp_size_t n,
5956                          limb_t shift, limb_t low)
5957 {
5958     mp_size_t i;
5959     limb_t l, a, q, r;
5960 
5961     assert(shift >= 1 && shift < LIMB_DIGITS);
5962     l = low;
5963     for(i = 0; i < n; i++) {
5964         a = tab[i];
5965         fast_shr_rem_dec(q, r, a, LIMB_DIGITS - shift);
5966         tab_r[i] = r * mp_pow_dec[shift] + l;
5967         l = q;
5968     }
5969     return l;
5970 }
5971 
/* Square root of the two-limb decimal number a = taba[1] * B + taba[0]
   (B = BF_DEC_BASE): tabs[0] = s = floor(sqrt(a)). The remainder a - s^2
   is split in base B: its low limb goes to taba[0] and its high part is
   returned.
   NOTE(review): clz() is applied to the high binary word of a, which is
   zero when a < 2^LIMB_BITS; callers appear to pass a top limb >= B/4 —
   confirm clz(0) can never be reached. */
static limb_t mp_sqrtrem2_dec(limb_t *tabs, limb_t *taba)
{
    dlimb_t val, shifted, rem;
    limb_t bin[2], root, rem_lo, rem_hi;
    int norm;

    /* convert to binary and normalize by an even shift, so that the
       square root can be un-normalized by a plain right shift */
    val = (dlimb_t)taba[1] * BF_DEC_BASE + taba[0];
    norm = clz(val >> LIMB_BITS) & ~1;
    shifted = val << norm;
    bin[0] = (limb_t)shifted;
    bin[1] = (limb_t)(shifted >> LIMB_BITS);
    mp_sqrtrem2(&root, bin);
    root >>= (norm >> 1);

    /* recompute the remainder from the unshifted value and convert it
       back to decimal */
    rem = val - (dlimb_t)root * (dlimb_t)root;
    divdq_base(rem_hi, rem_lo, rem >> LIMB_BITS, rem);
    taba[0] = rem_lo;
    tabs[0] = root;
    return rem_hi;
}
5993 
//#define DEBUG_SQRTREM_DEC

/* Recursive decimal square root with remainder.

   taba holds the 2 * n limbs of a. On return tabs[0 .. n - 1] holds
   s = floor(sqrt(a)), taba[0 .. n - 1] holds the low n limbs of
   r = a - s^2, and the function returns the high limb of r.
   The top limb of a is presumably expected to be >= B / 4 (the
   precondition stated by mp_sqrtrem_dec()).

   The n limbs of s are split as n = l + h with l = n / 2. The top h
   limbs come from a recursive call on the top 2 * h limbs of a. The low
   l limbs come from dividing the remainder by that partial root, and
   the result is corrected when the final remainder is negative. This
   appears to follow Zimmermann's "Karatsuba square root" scheme.

   tmp_buf must contain (n / 2 + 1 limbs); it receives the quotient of
   the division step. */
static limb_t mp_sqrtrem_rec_dec(limb_t *tabs, limb_t *taba, limb_t n,
                                 limb_t *tmp_buf)
{
    limb_t l, h, rh, ql, qh, c, i;

    /* base case: two limbs */
    if (n == 1)
        return mp_sqrtrem2_dec(tabs, taba);
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_dec("a", taba, 2 * n);
#endif
    l = n / 2;
    h = n - l;
    /* top half: s' = sqrt of taba[2l .. 2n - 1] into tabs[l .. n - 1] */
    qh = mp_sqrtrem_rec_dec(tabs + l, taba + 2 * l, h, tmp_buf);
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_dec("s1", tabs + l, h);
    mp_print_str_h_dec("r1", taba + 2 * l, h, qh);
    mp_print_str_h_dec("r2", taba + l, n, qh);
#endif

    /* the remainder is in taba + 2 * l. Its high bit is in qh.
       When qh != 0, s' is subtracted once before the division; that
       multiple stays accounted for in qh (qh += tmp_buf[l] below). */
    if (qh) {
        mp_sub_dec(taba + 2 * l, taba + 2 * l, tabs + l, h, 0);
    }
    /* instead of dividing by 2*s, divide by s (which is normalized)
       and update q and r.
       s (ctx) is NULL and the return value is ignored: presumably safe
       because a normalized divisor makes mp_div_dec() skip allocation. */
    mp_div_dec(NULL, tmp_buf, taba + l, n, tabs + l, h);
    qh += tmp_buf[l];
    for(i = 0; i < l; i++)
        tabs[i] = tmp_buf[i];
    /* halve the quotient; ql is the bit shifted out */
    ql = mp_div1_dec(tabs, tabs, l, 2, qh & 1);
    qh = qh >> 1; /* 0 or 1 */
    /* odd quotient: one more s' belongs to the remainder */
    if (ql)
        rh = mp_add_dec(taba + l, taba + l, tabs + l, h, 0);
    else
        rh = 0;
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_h_dec("q", tabs, l, qh);
    mp_print_str_h_dec("u", taba + l, h, rh);
#endif

    /* s = s' * B^l + q */
    mp_add_ui_dec(tabs + l, qh, h);
#ifdef DEBUG_SQRTREM_DEC
    mp_print_str_dec("s2", tabs, n);
#endif

    /* q = qh, tabs[l - 1 ... 0], r = taba[n - 1 ... l] */
    /* subtract q^2. if qh = 1 then q = B^l, so we can take shortcuts */
    if (qh) {
        c = qh;
    } else {
        mp_mul_basecase_dec(taba + n, tabs, l, tabs, l);
        c = mp_sub_dec(taba, taba, taba + n, 2 * l, 0);
    }
    rh -= mp_sub_ui_dec(taba + 2 * l, c, n - 2 * l);
    /* negative remainder: use s - 1 instead, so r += 2 * (s - 1) + 1 */
    if ((slimb_t)rh < 0) {
        mp_sub_ui_dec(tabs, 1, n);
        rh += mp_add_mul1_dec(taba, tabs, n, 2);
        rh += mp_add_ui_dec(taba, 1, n);
    }
    return rh;
}
6058 
6059 /* 'taba' has 2*n limbs with n >= 1 and taba[2*n-1] >= B/4. Return (s,
6060    r) with s=floor(sqrt(a)) and r=a-s^2. 0 <= r <= 2 * s. tabs has n
6061    limbs. r is returned in the lower n limbs of taba. Its r[n] is the
6062    returned value of the function. */
mp_sqrtrem_dec(bf_context_t * s,limb_t * tabs,limb_t * taba,limb_t n)6063 int mp_sqrtrem_dec(bf_context_t *s, limb_t *tabs, limb_t *taba, limb_t n)
6064 {
6065     limb_t tmp_buf1[8];
6066     limb_t *tmp_buf;
6067     mp_size_t n2;
6068     n2 = n / 2 + 1;
6069     if (n2 <= countof(tmp_buf1)) {
6070         tmp_buf = tmp_buf1;
6071     } else {
6072         tmp_buf = bf_malloc(s, sizeof(limb_t) * n2);
6073         if (!tmp_buf)
6074             return -1;
6075     }
6076     taba[n] = mp_sqrtrem_rec_dec(tabs, taba, n, tmp_buf);
6077     if (tmp_buf != tmp_buf1)
6078         bf_free(s, tmp_buf);
6079     return 0;
6080 }
6081 
6082 /* return the number of leading zero digits, from 0 to LIMB_DIGITS */
clz_dec(limb_t a)6083 static int clz_dec(limb_t a)
6084 {
6085     if (a == 0)
6086         return LIMB_DIGITS;
6087     switch(LIMB_BITS - 1 - clz(a)) {
6088     case 0: /* 1-1 */
6089         return LIMB_DIGITS - 1;
6090     case 1: /* 2-3 */
6091         return LIMB_DIGITS - 1;
6092     case 2: /* 4-7 */
6093         return LIMB_DIGITS - 1;
6094     case 3: /* 8-15 */
6095         if (a < 10)
6096             return LIMB_DIGITS - 1;
6097         else
6098             return LIMB_DIGITS - 2;
6099     case 4: /* 16-31 */
6100         return LIMB_DIGITS - 2;
6101     case 5: /* 32-63 */
6102         return LIMB_DIGITS - 2;
6103     case 6: /* 64-127 */
6104         if (a < 100)
6105             return LIMB_DIGITS - 2;
6106         else
6107             return LIMB_DIGITS - 3;
6108     case 7: /* 128-255 */
6109         return LIMB_DIGITS - 3;
6110     case 8: /* 256-511 */
6111         return LIMB_DIGITS - 3;
6112     case 9: /* 512-1023 */
6113         if (a < 1000)
6114             return LIMB_DIGITS - 3;
6115         else
6116             return LIMB_DIGITS - 4;
6117     case 10: /* 1024-2047 */
6118         return LIMB_DIGITS - 4;
6119     case 11: /* 2048-4095 */
6120         return LIMB_DIGITS - 4;
6121     case 12: /* 4096-8191 */
6122         return LIMB_DIGITS - 4;
6123     case 13: /* 8192-16383 */
6124         if (a < 10000)
6125             return LIMB_DIGITS - 4;
6126         else
6127             return LIMB_DIGITS - 5;
6128     case 14: /* 16384-32767 */
6129         return LIMB_DIGITS - 5;
6130     case 15: /* 32768-65535 */
6131         return LIMB_DIGITS - 5;
6132     case 16: /* 65536-131071 */
6133         if (a < 100000)
6134             return LIMB_DIGITS - 5;
6135         else
6136             return LIMB_DIGITS - 6;
6137     case 17: /* 131072-262143 */
6138         return LIMB_DIGITS - 6;
6139     case 18: /* 262144-524287 */
6140         return LIMB_DIGITS - 6;
6141     case 19: /* 524288-1048575 */
6142         if (a < 1000000)
6143             return LIMB_DIGITS - 6;
6144         else
6145             return LIMB_DIGITS - 7;
6146     case 20: /* 1048576-2097151 */
6147         return LIMB_DIGITS - 7;
6148     case 21: /* 2097152-4194303 */
6149         return LIMB_DIGITS - 7;
6150     case 22: /* 4194304-8388607 */
6151         return LIMB_DIGITS - 7;
6152     case 23: /* 8388608-16777215 */
6153         if (a < 10000000)
6154             return LIMB_DIGITS - 7;
6155         else
6156             return LIMB_DIGITS - 8;
6157     case 24: /* 16777216-33554431 */
6158         return LIMB_DIGITS - 8;
6159     case 25: /* 33554432-67108863 */
6160         return LIMB_DIGITS - 8;
6161     case 26: /* 67108864-134217727 */
6162         if (a < 100000000)
6163             return LIMB_DIGITS - 8;
6164         else
6165             return LIMB_DIGITS - 9;
6166 #if LIMB_BITS == 64
6167     case 27: /* 134217728-268435455 */
6168         return LIMB_DIGITS - 9;
6169     case 28: /* 268435456-536870911 */
6170         return LIMB_DIGITS - 9;
6171     case 29: /* 536870912-1073741823 */
6172         if (a < 1000000000)
6173             return LIMB_DIGITS - 9;
6174         else
6175             return LIMB_DIGITS - 10;
6176     case 30: /* 1073741824-2147483647 */
6177         return LIMB_DIGITS - 10;
6178     case 31: /* 2147483648-4294967295 */
6179         return LIMB_DIGITS - 10;
6180     case 32: /* 4294967296-8589934591 */
6181         return LIMB_DIGITS - 10;
6182     case 33: /* 8589934592-17179869183 */
6183         if (a < 10000000000)
6184             return LIMB_DIGITS - 10;
6185         else
6186             return LIMB_DIGITS - 11;
6187     case 34: /* 17179869184-34359738367 */
6188         return LIMB_DIGITS - 11;
6189     case 35: /* 34359738368-68719476735 */
6190         return LIMB_DIGITS - 11;
6191     case 36: /* 68719476736-137438953471 */
6192         if (a < 100000000000)
6193             return LIMB_DIGITS - 11;
6194         else
6195             return LIMB_DIGITS - 12;
6196     case 37: /* 137438953472-274877906943 */
6197         return LIMB_DIGITS - 12;
6198     case 38: /* 274877906944-549755813887 */
6199         return LIMB_DIGITS - 12;
6200     case 39: /* 549755813888-1099511627775 */
6201         if (a < 1000000000000)
6202             return LIMB_DIGITS - 12;
6203         else
6204             return LIMB_DIGITS - 13;
6205     case 40: /* 1099511627776-2199023255551 */
6206         return LIMB_DIGITS - 13;
6207     case 41: /* 2199023255552-4398046511103 */
6208         return LIMB_DIGITS - 13;
6209     case 42: /* 4398046511104-8796093022207 */
6210         return LIMB_DIGITS - 13;
6211     case 43: /* 8796093022208-17592186044415 */
6212         if (a < 10000000000000)
6213             return LIMB_DIGITS - 13;
6214         else
6215             return LIMB_DIGITS - 14;
6216     case 44: /* 17592186044416-35184372088831 */
6217         return LIMB_DIGITS - 14;
6218     case 45: /* 35184372088832-70368744177663 */
6219         return LIMB_DIGITS - 14;
6220     case 46: /* 70368744177664-140737488355327 */
6221         if (a < 100000000000000)
6222             return LIMB_DIGITS - 14;
6223         else
6224             return LIMB_DIGITS - 15;
6225     case 47: /* 140737488355328-281474976710655 */
6226         return LIMB_DIGITS - 15;
6227     case 48: /* 281474976710656-562949953421311 */
6228         return LIMB_DIGITS - 15;
6229     case 49: /* 562949953421312-1125899906842623 */
6230         if (a < 1000000000000000)
6231             return LIMB_DIGITS - 15;
6232         else
6233             return LIMB_DIGITS - 16;
6234     case 50: /* 1125899906842624-2251799813685247 */
6235         return LIMB_DIGITS - 16;
6236     case 51: /* 2251799813685248-4503599627370495 */
6237         return LIMB_DIGITS - 16;
6238     case 52: /* 4503599627370496-9007199254740991 */
6239         return LIMB_DIGITS - 16;
6240     case 53: /* 9007199254740992-18014398509481983 */
6241         if (a < 10000000000000000)
6242             return LIMB_DIGITS - 16;
6243         else
6244             return LIMB_DIGITS - 17;
6245     case 54: /* 18014398509481984-36028797018963967 */
6246         return LIMB_DIGITS - 17;
6247     case 55: /* 36028797018963968-72057594037927935 */
6248         return LIMB_DIGITS - 17;
6249     case 56: /* 72057594037927936-144115188075855871 */
6250         if (a < 100000000000000000)
6251             return LIMB_DIGITS - 17;
6252         else
6253             return LIMB_DIGITS - 18;
6254     case 57: /* 144115188075855872-288230376151711743 */
6255         return LIMB_DIGITS - 18;
6256     case 58: /* 288230376151711744-576460752303423487 */
6257         return LIMB_DIGITS - 18;
6258     case 59: /* 576460752303423488-1152921504606846975 */
6259         if (a < 1000000000000000000)
6260             return LIMB_DIGITS - 18;
6261         else
6262             return LIMB_DIGITS - 19;
6263 #endif
6264     default:
6265         return 0;
6266     }
6267 }
6268 
6269 /* for debugging */
bfdec_print_str(const char * str,const bfdec_t * a)6270 void bfdec_print_str(const char *str, const bfdec_t *a)
6271 {
6272     slimb_t i;
6273     printf("%s=", str);
6274 
6275     if (a->expn == BF_EXP_NAN) {
6276         printf("NaN");
6277     } else {
6278         if (a->sign)
6279             putchar('-');
6280         if (a->expn == BF_EXP_ZERO) {
6281             putchar('0');
6282         } else if (a->expn == BF_EXP_INF) {
6283             printf("Inf");
6284         } else {
6285             printf("0.");
6286             for(i = a->len - 1; i >= 0; i--)
6287                 printf("%0*" PRIu_LIMB, LIMB_DIGITS, a->tab[i]);
6288             printf("e%" PRId_LIMB, a->expn);
6289         }
6290     }
6291     printf("\n");
6292 }
6293 
6294 /* return != 0 if one digit between 0 and bit_pos inclusive is not zero. */
scan_digit_nz(const bfdec_t * r,slimb_t bit_pos)6295 static inline limb_t scan_digit_nz(const bfdec_t *r, slimb_t bit_pos)
6296 {
6297     slimb_t pos;
6298     limb_t v, q;
6299     int shift;
6300 
6301     if (bit_pos < 0)
6302         return 0;
6303     pos = (limb_t)bit_pos / LIMB_DIGITS;
6304     shift = (limb_t)bit_pos % LIMB_DIGITS;
6305     fast_shr_rem_dec(q, v, r->tab[pos], shift + 1);
6306     (void)q;
6307     if (v != 0)
6308         return 1;
6309     pos--;
6310     while (pos >= 0) {
6311         if (r->tab[pos] != 0)
6312             return 1;
6313         pos--;
6314     }
6315     return 0;
6316 }
6317 
/* Return decimal digit number 'pos' of the len-limb array tab (position
   0 is the least significant digit of tab[0]), or 0 when pos lies
   outside the array. */
static limb_t get_digit(const limb_t *tab, limb_t len, slimb_t pos)
{
    slimb_t limb_idx = floor_div(pos, LIMB_DIGITS);
    int digit_idx;

    if (limb_idx < 0 || limb_idx >= len)
        return 0;
    digit_idx = pos - limb_idx * LIMB_DIGITS;
    return fast_shr_dec(tab[limb_idx], digit_idx) % 10;
}
6328 
/* Disabled (not compiled). get_digits() would return the LIMB_DIGITS
   decimal digits of tab starting at digit position 'pos' (position 0
   being the least significant digit of tab[0]), reading the two limbs
   that straddle that window. Limbs outside 0 .. len - 1 read as zero. */
#if 0
static limb_t get_digits(const limb_t *tab, limb_t len, slimb_t pos)
{
    limb_t a0, a1;
    int shift;
    slimb_t i;

    i = floor_div(pos, LIMB_DIGITS);
    shift = pos - i * LIMB_DIGITS;
    if (i >= 0 && i < len)
        a0 = tab[i];
    else
        a0 = 0;
    if (shift == 0) {
        /* window aligned with a limb */
        return a0;
    } else {
        i++;
        if (i >= 0 && i < len)
            a1 = tab[i];
        else
            a1 = 0;
        /* high digits of a0 followed by the low digits of a1 */
        return fast_shr_dec(a0, shift) +
            fast_urem(a1, &mp_pow_div[LIMB_DIGITS - shift]) *
            mp_pow_dec[shift];
    }
}
#endif
6356 
6357 /* return the addend for rounding. Note that prec can be <= 0 for bf_rint() */
bfdec_get_rnd_add(int * pret,const bfdec_t * r,limb_t l,slimb_t prec,int rnd_mode)6358 static int bfdec_get_rnd_add(int *pret, const bfdec_t *r, limb_t l,
6359                              slimb_t prec, int rnd_mode)
6360 {
6361     int add_one, inexact;
6362     limb_t digit1, digit0;
6363 
6364     //    bfdec_print_str("get_rnd_add", r);
6365     if (rnd_mode == BF_RNDF) {
6366         digit0 = 1; /* faithful rounding does not honor the INEXACT flag */
6367     } else {
6368         /* starting limb for bit 'prec + 1' */
6369         digit0 = scan_digit_nz(r, l * LIMB_DIGITS - 1 - bf_max(0, prec + 1));
6370     }
6371 
6372     /* get the digit at 'prec' */
6373     digit1 = get_digit(r->tab, l, l * LIMB_DIGITS - 1 - prec);
6374     inexact = (digit1 | digit0) != 0;
6375 
6376     add_one = 0;
6377     switch(rnd_mode) {
6378     case BF_RNDZ:
6379         break;
6380     case BF_RNDN:
6381         if (digit1 == 5) {
6382             if (digit0) {
6383                 add_one = 1;
6384             } else {
6385                 /* round to even */
6386                 add_one =
6387                     get_digit(r->tab, l, l * LIMB_DIGITS - 1 - (prec - 1)) & 1;
6388             }
6389         } else if (digit1 > 5) {
6390             add_one = 1;
6391         }
6392         break;
6393     case BF_RNDD:
6394     case BF_RNDU:
6395         if (r->sign == (rnd_mode == BF_RNDD))
6396             add_one = inexact;
6397         break;
6398     case BF_RNDNA:
6399     case BF_RNDF:
6400         add_one = (digit1 >= 5);
6401         break;
6402     case BF_RNDA:
6403         add_one = inexact;
6404         break;
6405     default:
6406         abort();
6407     }
6408 
6409     if (inexact)
6410         *pret |= BF_ST_INEXACT;
6411     return add_one;
6412 }
6413 
/* Round 'r' in place to prec1 decimal digits, assuming 'r' is non zero
   and finite. Only the limbs r->tab[0 .. l - 1] are taken as the
   mantissa (1 <= l <= r->len, tab[l - 1] most significant); r is resized
   to at most l limbs at the end. prec1 can be BF_PREC_INF. With
   BF_FLAG_RADPNT_PREC, prec1 counts digits after the decimal point
   instead of significant digits.
   Returns BF_ST_* status flags (INEXACT, UNDERFLOW, OVERFLOW). On
   overflow r becomes Inf; on underflow without BF_FLAG_SUBNORMAL r
   becomes zero. Cannot fail with BF_ST_MEM_ERROR.
   NOTE(review): an earlier comment stated that BF_FLAG_SUBNORMAL is not
   supported, yet this body does handle it (precision reduction and the
   underflow check below) — confirm the intended support level.
 */
static int __bfdec_round(bfdec_t *r, limb_t prec1, bf_flags_t flags, limb_t l)
{
    int shift, add_one, rnd_mode, ret;
    slimb_t i, bit_pos, pos, e_min, e_max, e_range, prec;

    /* exponent range selected by the exponent bits in 'flags' */
    /* XXX: align to IEEE 754 2008 for decimal numbers ? */
    e_range = (limb_t)1 << (bf_get_exp_bits(flags) - 1);
    e_min = -e_range + 3;
    e_max = e_range;

    if (flags & BF_FLAG_RADPNT_PREC) {
        /* 'prec' is the precision after the decimal point */
        if (prec1 != BF_PREC_INF)
            prec = r->expn + prec1;
        else
            prec = prec1;
    } else if (unlikely(r->expn < e_min) && (flags & BF_FLAG_SUBNORMAL)) {
        /* restrict the precision in case of potentially subnormal
           result */
        assert(prec1 != BF_PREC_INF);
        prec = prec1 - (e_min - r->expn);
    } else {
        prec = prec1;
    }

    /* round to prec digits */
    rnd_mode = flags & BF_RND_MASK;
    ret = 0;
    add_one = bfdec_get_rnd_add(&ret, r, l, prec, rnd_mode);

    if (prec <= 0) {
        /* no digit is kept: the result is either a single leading
           digit 1 one position above the old precision, or zero */
        if (add_one) {
            bfdec_resize(r, 1); /* cannot fail because r is non zero */
            r->tab[0] = BF_DEC_BASE / 10;
            r->expn += 1 - prec;
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            return ret;
        } else {
            goto underflow;
        }
    } else if (add_one) {
        limb_t carry;

        /* add one starting at digit 'prec - 1' */
        bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
        pos = bit_pos / LIMB_DIGITS;
        carry = mp_pow_dec[bit_pos % LIMB_DIGITS];
        carry = mp_add_ui_dec(r->tab + pos, carry, l - pos);
        if (carry) {
            /* the mantissa overflowed its top limb: shift right by one
               digit, inserting a leading 1, and bump the exponent */
            mp_shr_dec(r->tab + pos, r->tab + pos, l - pos, 1, 1);
            r->expn++;
        }
    }

    /* check underflow */
    if (unlikely(r->expn < e_min)) {
        if (flags & BF_FLAG_SUBNORMAL) {
            /* if inexact, also set the underflow flag */
            if (ret & BF_ST_INEXACT)
                ret |= BF_ST_UNDERFLOW;
        } else {
        underflow:
            bfdec_set_zero(r, r->sign);
            ret |= BF_ST_UNDERFLOW | BF_ST_INEXACT;
            return ret;
        }
    }

    /* check overflow */
    if (unlikely(r->expn > e_max)) {
        bfdec_set_inf(r, r->sign);
        ret |= BF_ST_OVERFLOW | BF_ST_INEXACT;
        return ret;
    }

    /* keep the bits starting at 'prec - 1': zero the discarded digits of
       the limb containing that digit; limbs below it are dropped by the
       memmove below */
    bit_pos = l * LIMB_DIGITS - 1 - (prec - 1);
    i = floor_div(bit_pos, LIMB_DIGITS);
    if (i >= 0) {
        shift = smod(bit_pos, LIMB_DIGITS);
        if (shift != 0) {
            r->tab[i] = fast_shr_dec(r->tab[i], shift) *
                mp_pow_dec[shift];
        }
    } else {
        i = 0;
    }
    /* remove trailing zeros (r is non zero, so the loop terminates) */
    while (r->tab[i] == 0)
        i++;
    if (i > 0) {
        l -= i;
        memmove(r->tab, r->tab + i, l * sizeof(limb_t));
    }
    bfdec_resize(r, l); /* cannot fail */
    return ret;
}
6517 
6518 /* Cannot fail with BF_ST_MEM_ERROR. */
bfdec_round(bfdec_t * r,limb_t prec,bf_flags_t flags)6519 int bfdec_round(bfdec_t *r, limb_t prec, bf_flags_t flags)
6520 {
6521     if (r->len == 0)
6522         return 0;
6523     return __bfdec_round(r, prec, flags, r->len);
6524 }
6525 
6526 /* 'r' must be a finite number. Cannot fail with BF_ST_MEM_ERROR.  */
bfdec_normalize_and_round(bfdec_t * r,limb_t prec1,bf_flags_t flags)6527 int bfdec_normalize_and_round(bfdec_t *r, limb_t prec1, bf_flags_t flags)
6528 {
6529     limb_t l, v;
6530     int shift, ret;
6531 
6532     //    bfdec_print_str("bf_renorm", r);
6533     l = r->len;
6534     while (l > 0 && r->tab[l - 1] == 0)
6535         l--;
6536     if (l == 0) {
6537         /* zero */
6538         r->expn = BF_EXP_ZERO;
6539         bfdec_resize(r, 0); /* cannot fail */
6540         ret = 0;
6541     } else {
6542         r->expn -= (r->len - l) * LIMB_DIGITS;
6543         /* shift to have the MSB set to '1' */
6544         v = r->tab[l - 1];
6545         shift = clz_dec(v);
6546         if (shift != 0) {
6547             mp_shl_dec(r->tab, r->tab, l, shift, 0);
6548             r->expn -= shift;
6549         }
6550         ret = __bfdec_round(r, prec1, flags, l);
6551     }
6552     //    bf_print_str("r_final", r);
6553     return ret;
6554 }
6555 
/* Set 'r' to the unsigned integer 'v'. The conversion is exact: it is
   normalized with BF_PREC_INF precision.
   Returns 0 on success. On allocation failure, 'r' is set to NaN and
   BF_ST_MEM_ERROR is returned. */
int bfdec_set_ui(bfdec_t *r, uint64_t v)
{
#if LIMB_BITS == 32
    /* A uint64_t may need three base-BF_DEC_BASE limbs. The threshold
       BF_DEC_BASE^2 is computed in 64-bit arithmetic. If BF_DEC_BASE has
       a 32-bit type, BF_DEC_BASE * BF_DEC_BASE would otherwise wrap
       around and select the 3-limb path for much smaller values. */
    if (v >= (uint64_t)BF_DEC_BASE * BF_DEC_BASE) {
        if (bfdec_resize(r, 3))
            goto fail;
        r->tab[0] = v % BF_DEC_BASE;
        v /= BF_DEC_BASE;
        r->tab[1] = v % BF_DEC_BASE;
        r->tab[2] = v / BF_DEC_BASE;
        r->expn = 3 * LIMB_DIGITS;
    } else
#endif
    if (v >= BF_DEC_BASE) {
        if (bfdec_resize(r, 2))
            goto fail;
        r->tab[0] = v % BF_DEC_BASE;
        r->tab[1] = v / BF_DEC_BASE;
        r->expn = 2 * LIMB_DIGITS;
    } else {
        if (bfdec_resize(r, 1))
            goto fail;
        r->tab[0] = v;
        r->expn = LIMB_DIGITS;
    }
    r->sign = 0;
    /* strip leading zero limbs/digits and fix up the exponent */
    return bfdec_normalize_and_round(r, BF_PREC_INF, 0);
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6587 
/* Set 'r' to the signed integer 'v' (exact, no rounding).
   Returns 0 on success. On allocation failure, 'r' is set to NaN and
   BF_ST_MEM_ERROR is returned. */
int bfdec_set_si(bfdec_t *r, int64_t v)
{
    int ret;
    if (v < 0) {
        /* Negate in unsigned arithmetic. '-v' is signed overflow
           (undefined behavior) for v == INT64_MIN. -(uint64_t)v yields
           the exact magnitude for every negative v, including 2^63. */
        ret = bfdec_set_ui(r, -(uint64_t)v);
        r->sign = 1;
    } else {
        ret = bfdec_set_ui(r, v);
    }
    return ret;
}
6599 
/* Add or subtract two decimal floats: r = a + b when b_neg == 0, and
   r = a - b when b_neg == 1. The result is rounded to 'prec' digits
   according to 'flags'.
   Return value (BF_ST_* status bits):
   - BF_ST_INVALID_OP for the difference of two infinities of the same
     sign;
   - BF_ST_MEM_ERROR, with 'r' set to NaN, when a limb allocation fails;
   - otherwise the status of the final bfdec_normalize_and_round(). */
static int bfdec_add_internal(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec, bf_flags_t flags, int b_neg)
{
    bf_context_t *s = r->ctx;
    int is_sub, cmp_res, a_sign, b_sign, ret;

    a_sign = a->sign;
    b_sign = b->sign ^ b_neg;
    /* the magnitudes are subtracted when the effective signs differ */
    is_sub = a_sign ^ b_sign;
    cmp_res = bfdec_cmpu(a, b);
    if (cmp_res < 0) {
        /* make 'a' the operand of larger magnitude; it determines the
           sign of the result */
        const bfdec_t *tmp;
        tmp = a;
        a = b;
        b = tmp;
        a_sign = b_sign; /* b_sign is never used later */
    }
    /* abs(a) >= abs(b) */
    if (cmp_res == 0 && is_sub && a->expn < BF_EXP_INF) {
        /* zero result: exact cancellation gives +0, or -0 when rounding
           toward -infinity (BF_RNDD) */
        bfdec_set_zero(r, (flags & BF_RND_MASK) == BF_RNDD);
        ret = 0;
    } else if (a->len == 0 || b->len == 0) {
        /* at least one operand is zero, infinite or NaN */
        ret = 0;
        if (a->expn >= BF_EXP_INF) {
            if (a->expn == BF_EXP_NAN) {
                /* at least one operand is NaN */
                bfdec_set_nan(r);
                ret = 0;
            } else if (b->expn == BF_EXP_INF && is_sub) {
                /* infinities with different signs */
                bfdec_set_nan(r);
                ret = BF_ST_INVALID_OP;
            } else {
                bfdec_set_inf(r, a_sign);
            }
        } else {
            /* 'a' is finite and, since abs(a) >= abs(b), 'b' is zero:
               the result is 'a' with the effective sign a_sign (the
               exact 0 - 0 case was handled above) */
            if (bfdec_set(r, a))
                return BF_ST_MEM_ERROR;
            r->sign = a_sign;
            goto renorm;
        }
    } else {
        slimb_t d, a_offset, b_offset, i, r_len;
        limb_t carry;
        limb_t *b1_tab;
        int b_shift;
        mp_size_t b1_len;

        /* number of decimal digits by which 'b' sits below 'a' (>= 0) */
        d = a->expn - b->expn;

        /* XXX: not efficient in time and memory if the precision is
           not infinite */
        /* enough limbs to hold both operands aligned on the same digit
           grid, without truncating 'b' */
        r_len = bf_max(a->len, b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);
        if (bfdec_resize(r, r_len))
            goto fail;
        r->sign = a_sign;
        r->expn = a->expn;

        /* copy 'a' into the most significant limbs of r, zero below */
        a_offset = r_len - a->len;
        for(i = 0; i < a_offset; i++)
            r->tab[i] = 0;
        for(i = 0; i < a->len; i++)
            r->tab[a_offset + i] = a->tab[i];

        /* When d is not a multiple of LIMB_DIGITS, work on a temporary
           copy of 'b' shifted right by b_shift digits. The extra low limb
           b1_tab[0] receives the digits shifted out, scaled back to the
           top of that limb. */
        b_shift = d % LIMB_DIGITS;
        if (b_shift == 0) {
            b1_len = b->len;
            b1_tab = (limb_t *)b->tab;
        } else {
            b1_len = b->len + 1;
            b1_tab = bf_malloc(s, sizeof(limb_t) * b1_len);
            if (!b1_tab)
                goto fail;
            b1_tab[0] = mp_shr_dec(b1_tab + 1, b->tab, b->len, b_shift, 0) *
                mp_pow_dec[LIMB_DIGITS - b_shift];
        }
        /* limb index in r where the (shifted) 'b' starts */
        b_offset = r_len - (b->len + (d + LIMB_DIGITS - 1) / LIMB_DIGITS);

        if (is_sub) {
            carry = mp_sub_dec(r->tab + b_offset, r->tab + b_offset,
                               b1_tab, b1_len, 0);
            if (carry != 0) {
                /* propagate the borrow; it cannot leave the top limb
                   because abs(a) >= abs(b) */
                carry = mp_sub_ui_dec(r->tab + b_offset + b1_len, carry,
                                      r_len - (b_offset + b1_len));
                assert(carry == 0);
            }
        } else {
            carry = mp_add_dec(r->tab + b_offset, r->tab + b_offset,
                               b1_tab, b1_len, 0);
            if (carry != 0) {
                carry = mp_add_ui_dec(r->tab + b_offset + b1_len, carry,
                                      r_len - (b_offset + b1_len));
            }
            if (carry != 0) {
                /* carry out of the top limb: append a limb holding 1 and
                   raise the exponent by one limb */
                if (bfdec_resize(r, r_len + 1)) {
                    if (b_shift != 0)
                        bf_free(s, b1_tab);
                    goto fail;
                }
                r->tab[r_len] = 1;
                r->expn += LIMB_DIGITS;
            }
        }
        if (b_shift != 0)
            bf_free(s, b1_tab);
    renorm:
        ret = bfdec_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6714 
/* bf_op2() worker for addition: r = a + b, rounded to 'prec' digits
   according to 'flags'. */
static int __bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
                     bf_flags_t flags)
{
    const int negate_b = 0; /* use 'b' with its own sign */
    return bfdec_add_internal(r, a, b, prec, flags, negate_b);
}
6720 
/* bf_op2() worker for subtraction: r = a - b, rounded to 'prec' digits
   according to 'flags'. */
static int __bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
                     bf_flags_t flags)
{
    const int negate_b = 1; /* flip the sign of 'b' */
    return bfdec_add_internal(r, a, b, prec, flags, negate_b);
}
6726 
/* r = a + b rounded to 'prec' digits according to 'flags'. The work is
   dispatched through the generic bf_op2() helper with __bfdec_add as
   the operation. Returns the BF_ST_* status bits. */
int bfdec_add(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    bf_op2_func_t *op = (bf_op2_func_t *)__bfdec_add;
    return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags, op);
}
6733 
/* r = a - b rounded to 'prec' digits according to 'flags'. The work is
   dispatched through the generic bf_op2() helper with __bfdec_sub as
   the operation. Returns the BF_ST_* status bits. */
int bfdec_sub(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    bf_op2_func_t *op = (bf_op2_func_t *)__bfdec_sub;
    return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags, op);
}
6740 
/* r = a * b, rounded to 'prec' digits according to 'flags'. 'r' may
   alias 'a' and/or 'b'; the product is then built in a temporary and
   moved into 'r'.

   Special values (the result sign is the XOR of the operand signs):
   - NaN * x  -> NaN (status 0)
   - 0 * inf  -> NaN with BF_ST_INVALID_OP
   - inf * x  -> inf
   - 0 * x    -> 0

   Returns BF_ST_MEM_ERROR (result NaN) if the limb array cannot be
   allocated, otherwise the rounding status. */
int bfdec_mul(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    const bfdec_t *longer = a, *shorter = b;
    bfdec_t prod_s, *prod;
    int sign, status;

    /* order the operands so that shorter->len <= longer->len */
    if (a->len < b->len) {
        longer = b;
        shorter = a;
    }
    sign = longer->sign ^ shorter->sign;

    if (shorter->len == 0) {
        /* a zero, an infinity or a NaN is involved */
        if (longer->expn == BF_EXP_NAN || shorter->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
            return 0;
        }
        if (longer->expn == BF_EXP_INF || shorter->expn == BF_EXP_INF) {
            /* one factor is infinite: 0 * inf is undefined */
            if (longer->expn == BF_EXP_ZERO || shorter->expn == BF_EXP_ZERO) {
                bfdec_set_nan(r);
                return BF_ST_INVALID_OP;
            }
            bfdec_set_inf(r, sign);
            return 0;
        }
        bfdec_set_zero(r, sign);
        return 0;
    }

    /* use a scratch result when 'r' is also one of the inputs */
    prod = r;
    if (r == a || r == b) {
        bfdec_init(r->ctx, &prod_s);
        prod = &prod_s;
    }
    if (bfdec_resize(prod, longer->len + shorter->len)) {
        bfdec_set_nan(prod);
        status = BF_ST_MEM_ERROR;
    } else {
        mp_mul_basecase_dec(prod->tab, longer->tab, longer->len,
                            shorter->tab, shorter->len);
        prod->sign = sign;
        prod->expn = longer->expn + shorter->expn;
        status = bfdec_normalize_and_round(prod, prec, flags);
    }
    if (prod != r)
        bfdec_move(r, prod);
    return status;
}
6800 
/* r = a * b1, where b1 is a signed 64-bit integer, rounded to 'prec'
   digits according to 'flags'. The returned status is the OR of the
   conversion of 'b1' and the multiplication statuses. */
int bfdec_mul_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
                 bf_flags_t flags)
{
    bfdec_t factor;
    int status;

    bfdec_init(r->ctx, &factor);
    status = bfdec_set_si(&factor, b1);
    status |= bfdec_mul(r, a, &factor, prec, flags);
    bfdec_delete(&factor);
    return status;
}
6812 
/* r = a + b1, where b1 is a signed 64-bit integer, rounded to 'prec'
   digits according to 'flags'. The returned status is the OR of the
   conversion of 'b1' and the addition statuses. */
int bfdec_add_si(bfdec_t *r, const bfdec_t *a, int64_t b1, limb_t prec,
                 bf_flags_t flags)
{
    bfdec_t term;
    int status;

    bfdec_init(r->ctx, &term);
    status = bfdec_set_si(&term, b1);
    status |= bfdec_add(r, a, &term, prec, flags);
    bfdec_delete(&term);
    return status;
}
6825 
/* bf_op2() worker for division: r = a / b, rounded to 'prec' digits
   according to 'flags'.

   With BF_FLAG_RADPNT_PREC, 'prec' counts digits after the decimal
   point. With prec == BF_PREC_INF, only exact quotients are accepted;
   an inexact quotient yields NaN and BF_ST_INVALID_OP.

   Special values:
   - NaN          -> NaN
   - inf/inf, 0/0 -> NaN with BF_ST_INVALID_OP
   - inf/x        -> inf
   - x/inf        -> 0
   - x/0          -> inf with BF_ST_DIVIDE_ZERO

   Returns BF_ST_MEM_ERROR (r = NaN) on allocation failure.
   NOTE(review): b->tab is read after 'r' is resized, so this presumably
   relies on bf_op2() ensuring 'r' does not alias 'b'. Confirm. */
static int __bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
                       limb_t prec, bf_flags_t flags)
{
    int ret, r_sign;
    limb_t n, nb, precl;

    r_sign = a->sign ^ b->sign;
    if (a->expn >= BF_EXP_INF || b->expn >= BF_EXP_INF) {
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF && b->expn == BF_EXP_INF) {
            bfdec_set_nan(r);
            return BF_ST_INVALID_OP;
        } else if (a->expn == BF_EXP_INF) {
            bfdec_set_inf(r, r_sign);
            return 0;
        } else {
            bfdec_set_zero(r, r_sign);
            return 0;
        }
    } else if (a->expn == BF_EXP_ZERO) {
        if (b->expn == BF_EXP_ZERO) {
            bfdec_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bfdec_set_zero(r, r_sign);
            return 0;
        }
    } else if (b->expn == BF_EXP_ZERO) {
        bfdec_set_inf(r, r_sign);
        return BF_ST_DIVIDE_ZERO;
    }

    /* both operands are finite and non zero from here on */
    nb = b->len;
    if (prec == BF_PREC_INF) {
        /* infinite precision: return BF_ST_INVALID_OP if not an exact
           result */
        /* XXX: check */
        precl = nb + 1;
    } else if (flags & BF_FLAG_RADPNT_PREC) {
        /* number of digits after the decimal point */
        /* XXX: check (2 extra digits for rounding + 2 digits) */
        precl = (bf_max(a->expn - b->expn, 0) + 2 +
                 prec + 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
    } else {
        /* number of limbs of the quotient (2 extra digits for rounding) */
        precl = (prec + 2 + LIMB_DIGITS - 1) / LIMB_DIGITS;
    }
    /* quotient size in limbs (at least as many as 'a' has) */
    n = bf_max(a->len, precl);

    {
        limb_t *taba, na, i;
        slimb_t d;

        /* dividend: 'a' placed in the top limbs of an (n + nb)-limb
           buffer, zero padded below, so the integer quotient carries
           about n limbs of precision */
        na = n + nb;
        taba = bf_malloc(r->ctx, (na + 1) * sizeof(limb_t));
        if (!taba)
            goto fail;
        d = na - a->len;
        memset(taba, 0, d * sizeof(limb_t));
        memcpy(taba + d, a->tab, a->len * sizeof(limb_t));
        if (bfdec_resize(r, n + 1))
            goto fail1;
        if (mp_div_dec(r->ctx, r->tab, taba, na, b->tab, nb)) {
        fail1:
            bf_free(r->ctx, taba);
            goto fail;
        }
        /* see if non zero remainder */
        /* (mp_div_dec presumably leaves the remainder in taba[0..nb-1]) */
        for(i = 0; i < nb; i++) {
            if (taba[i] != 0)
                break;
        }
        bf_free(r->ctx, taba);
        if (i != nb) {
            if (prec == BF_PREC_INF) {
                bfdec_set_nan(r);
                return BF_ST_INVALID_OP;
            } else {
                /* sticky digit: make the lowest digit non zero so the
                   rounding step sees the quotient as inexact */
                r->tab[0] |= 1;
            }
        }
        r->expn = a->expn - b->expn + LIMB_DIGITS;
        r->sign = r_sign;
        ret = bfdec_normalize_and_round(r, prec, flags);
    }
    return ret;
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
6918 
/* r = a / b rounded to 'prec' digits according to 'flags'. The work is
   dispatched through the generic bf_op2() helper with __bfdec_div as
   the operation. Returns the BF_ST_* status bits. */
int bfdec_div(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags)
{
    bf_op2_func_t *op = (bf_op2_func_t *)__bfdec_div;
    return bf_op2((bf_t *)r, (bf_t *)a, (bf_t *)b, prec, flags, op);
}
6925 
6926 /* a and b must be finite numbers with a >= 0 and b > 0. 'q' is the
6927    integer defined as floor(a/b) and r = a - q * b. */
bfdec_tdivremu(bf_context_t * s,bfdec_t * q,bfdec_t * r,const bfdec_t * a,const bfdec_t * b)6928 static void bfdec_tdivremu(bf_context_t *s, bfdec_t *q, bfdec_t *r,
6929                            const bfdec_t *a, const bfdec_t *b)
6930 {
6931     if (bfdec_cmpu(a, b) < 0) {
6932         bfdec_set_ui(q, 0);
6933         bfdec_set(r, a);
6934     } else {
6935         bfdec_div(q, a, b, 0, BF_RNDZ | BF_FLAG_RADPNT_PREC);
6936         bfdec_mul(r, q, b, BF_PREC_INF, BF_RNDZ);
6937         bfdec_sub(r, a, r, BF_PREC_INF, BF_RNDZ);
6938     }
6939 }
6940 
/* division and remainder.

   rnd_mode is the rounding mode for the quotient. The additional
   rounding mode BF_DIVREM_EUCLIDIAN is supported.

   'q' is an integer. 'r' is rounded with prec and flags (prec can be
   BF_PREC_INF).

   'q' and 'r' must be distinct from each other and from 'a' and 'b'
   (asserted).

   Special cases (q is set to +0):
   - NaN operand          -> r = NaN, status 0
   - a infinite or b zero -> r = NaN, BF_ST_INVALID_OP
   - a zero or b infinite -> r = a, rounded

   Returns the rounding status of 'r', or BF_ST_MEM_ERROR with both 'q'
   and 'r' set to NaN.
*/
int bfdec_divrem(bfdec_t *q, bfdec_t *r, const bfdec_t *a, const bfdec_t *b,
                 limb_t prec, bf_flags_t flags, int rnd_mode)
{
    bf_context_t *s = q->ctx;
    bfdec_t a1_s, *a1 = &a1_s;
    bfdec_t b1_s, *b1 = &b1_s;
    bfdec_t r1_s, *r1 = &r1_s;
    int q_sign, res;
    BOOL is_ceil, is_rndn;

    assert(q != a && q != b);
    assert(r != a && r != b);
    assert(q != r);

    if (a->len == 0 || b->len == 0) {
        bfdec_set_zero(q, 0);
        if (a->expn == BF_EXP_NAN || b->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
            return 0;
        } else if (a->expn == BF_EXP_INF || b->expn == BF_EXP_ZERO) {
            bfdec_set_nan(r);
            return BF_ST_INVALID_OP;
        } else {
            bfdec_set(r, a);
            return bfdec_round(r, prec, flags);
        }
    }

    q_sign = a->sign ^ b->sign;
    is_rndn = (rnd_mode == BF_RNDN || rnd_mode == BF_RNDNA);
    /* is_ceil: whether the truncated magnitude quotient must be bumped
       by one whenever the remainder is non zero */
    switch(rnd_mode) {
    default:
    case BF_RNDZ:
    case BF_RNDN:
    case BF_RNDNA:
        is_ceil = FALSE;
        break;
    case BF_RNDD:
        is_ceil = q_sign;
        break;
    case BF_RNDU:
        is_ceil = q_sign ^ 1;
        break;
    case BF_RNDA:
        is_ceil = TRUE;
        break;
    case BF_DIVREM_EUCLIDIAN:
        is_ceil = a->sign;
        break;
    }

    /* a1, b1: non-negative views of 'a' and 'b' that share their limb
       arrays (shallow copies, never freed here) */
    a1->expn = a->expn;
    a1->tab = a->tab;
    a1->len = a->len;
    a1->sign = 0;

    b1->expn = b->expn;
    b1->tab = b->tab;
    b1->len = b->len;
    b1->sign = 0;

    //    bfdec_print_str("a1", a1);
    //    bfdec_print_str("b1", b1);
    /* XXX: could improve to avoid having a large 'q' */
    bfdec_tdivremu(s, q, r, a1, b1);
    if (bfdec_is_nan(q) || bfdec_is_nan(r))
        goto fail;
    //    bfdec_print_str("q", q);
    //    bfdec_print_str("r", r);

    if (r->len != 0) {
        if (is_rndn) {
            /* round to nearest: compare 2 * |r| with |b| */
            bfdec_init(s, r1);
            /* NOTE(review): r1 is not deleted on this failure path;
               presumably bfdec_set leaves it without limbs on error */
            if (bfdec_set(r1, r))
                goto fail;
            if (bfdec_mul_si(r1, r1, 2, BF_PREC_INF, BF_RNDZ)) {
                bfdec_delete(r1);
                goto fail;
            }
            res = bfdec_cmpu(r1, b);
            bfdec_delete(r1);
            /* above half, or a tie broken away from zero (BF_RNDNA) or
               toward an even quotient (units digit of q is odd) */
            if (res > 0 ||
                (res == 0 &&
                 (rnd_mode == BF_RNDNA ||
                  (get_digit(q->tab, q->len, q->len * LIMB_DIGITS - q->expn) & 1) != 0))) {
                goto do_sub_r;
            }
        } else if (is_ceil) {
        do_sub_r:
            /* round the magnitude quotient up: q += 1, r -= |b| */
            res = bfdec_add_si(q, q, 1, BF_PREC_INF, BF_RNDZ);
            res |= bfdec_sub(r, r, b1, BF_PREC_INF, BF_RNDZ);
            if (res & BF_ST_MEM_ERROR)
                goto fail;
        }
    }

    /* the remainder was computed for |a|; restore the sign of 'a' */
    r->sign ^= a->sign;
    q->sign = q_sign;
    return bfdec_round(r, prec, flags);
 fail:
    bfdec_set_nan(q);
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
7053 
/* Remainder only: 'r' receives the remainder of a / b when the implied
   quotient is rounded with 'rnd_mode'. 'r' is rounded with 'prec' and
   'flags'. Returns the status of bfdec_divrem(). */
int bfdec_rem(bfdec_t *r, const bfdec_t *a, const bfdec_t *b, limb_t prec,
              bf_flags_t flags, int rnd_mode)
{
    bfdec_t quot;
    int status;

    bfdec_init(r->ctx, &quot);
    status = bfdec_divrem(&quot, r, a, b, prec, flags, rnd_mode);
    bfdec_delete(&quot);
    return status;
}
7065 
7066 /* convert to integer (infinite precision) */
bfdec_rint(bfdec_t * r,int rnd_mode)7067 int bfdec_rint(bfdec_t *r, int rnd_mode)
7068 {
7069     return bfdec_round(r, 0, rnd_mode | BF_FLAG_RADPNT_PREC);
7070 }
7071 
/* r = sqrt(a), rounded to 'prec' digits according to 'flags'. With
   BF_FLAG_RADPNT_PREC, 'prec' counts digits after the decimal point.
   'r' must not alias 'a' (asserted).

   Special cases:
   - NaN            -> NaN
   - +/-0 and +inf  -> returned as is
   - -inf, a < 0, or prec == BF_PREC_INF -> NaN with BF_ST_INVALID_OP

   Returns BF_ST_MEM_ERROR (r = NaN) on allocation failure, otherwise
   the rounding status. */
int bfdec_sqrt(bfdec_t *r, const bfdec_t *a, limb_t prec, bf_flags_t flags)
{
    bf_context_t *s = a->ctx;
    int ret, k;
    limb_t *a1, v;
    slimb_t n, n1, prec1;
    limb_t res;

    assert(r != a);

    if (a->len == 0) {
        if (a->expn == BF_EXP_NAN) {
            bfdec_set_nan(r);
        } else if (a->expn == BF_EXP_INF && a->sign) {
            goto invalid_op;
        } else {
            bfdec_set(r, a);
        }
        ret = 0;
    } else if (a->sign || prec == BF_PREC_INF) {
 invalid_op:
        bfdec_set_nan(r);
        ret = BF_ST_INVALID_OP;
    } else {
        /* prec1: number of significant digits wanted for the root */
        if (flags & BF_FLAG_RADPNT_PREC) {
            prec1 = bf_max(floor_div(a->expn + 1, 2) + prec, 1);
        } else {
            prec1 = prec;
        }
        /* convert the mantissa to an integer with at least 2 *
           prec + 4 digits */
        /* n: limbs of the root; a1 holds 2 * n limbs */
        n = (2 * (prec1 + 2) + 2 * LIMB_DIGITS - 1) / (2 * LIMB_DIGITS);
        if (bfdec_resize(r, n))
            goto fail;
        a1 = bf_malloc(s, sizeof(limb_t) * 2 * n);
        if (!a1)
            goto fail;
        /* copy the n1 most significant limbs of 'a' to the top of a1 */
        n1 = bf_min(2 * n, a->len);
        memset(a1, 0, (2 * n - n1) * sizeof(limb_t));
        memcpy(a1 + 2 * n - n1, a->tab + a->len - n1, n1 * sizeof(limb_t));
        /* odd exponent: shift one digit right so the remaining exponent
           is even; 'res' keeps the digit shifted out (for stickiness) */
        if (a->expn & 1) {
            res = mp_shr_dec(a1, a1, 2 * n, 1, 0);
        } else {
            res = 0;
        }
        /* normalize so that a1 >= B^(2*n)/4. Not needed for n = 1
           because mp_sqrtrem2_dec already does it */
        /* scaling a1 by 4^k scales the root by 2^k, undone below */
        k = 0;
        if (n > 1) {
            v = a1[2 * n - 1];
            while (v < BF_DEC_BASE / 4) {
                k++;
                v *= 4;
            }
            if (k != 0)
                mp_mul1_dec(a1, a1, 2 * n, 1 << (2 * k), 0);
        }
        if (mp_sqrtrem_dec(s, r->tab, a1, n)) {
            bf_free(s, a1);
            goto fail;
        }
        if (k != 0)
            mp_div1_dec(r->tab, r->tab, n, 1 << k, 0);
        /* inexact if a shifted-out digit, the remainder (left in the
           low n + 1 limbs of a1) or an unused low limb of 'a' is non
           zero */
        if (!res) {
            res = mp_scan_nz(a1, n + 1);
        }
        bf_free(s, a1);
        if (!res) {
            res = mp_scan_nz(a->tab, a->len - n1);
        }
        /* sticky digit so that rounding sees the result as inexact */
        if (res != 0)
            r->tab[0] |= 1;
        r->sign = 0;
        r->expn = (a->expn + 1) >> 1;
        ret = bfdec_round(r, prec, flags);
    }
    return ret;
 fail:
    bfdec_set_nan(r);
    return BF_ST_MEM_ERROR;
}
7153 
/* Convert 'a' to a 32-bit signed integer, stored in *pres.
   The rounding mode is always BF_RNDZ. Return BF_ST_OVERFLOW if there
   is an overflow and 0 otherwise. No memory error is possible.

   Out-of-range finite values saturate to INT32_MAX / INT32_MIN.
   Infinities saturate the same way but return 0, and NaN gives
   INT32_MAX with status 0.
   The result is built as a uint32_t bit pattern: negation and
   INT32_MAX + 1 rely on unsigned wrap-around. The final store converts
   values above INT32_MAX to int, which is implementation-defined; it
   yields the expected negative value on two's complement targets. */
int bfdec_get_int32(int *pres, const bfdec_t *a)
{
    uint32_t v;
    int ret;
    if (a->expn >= BF_EXP_INF) {
        ret = 0;
        if (a->expn == BF_EXP_INF) {
            v = (uint32_t)INT32_MAX + a->sign;
             /* XXX: return overflow ? */
        } else {
            v = INT32_MAX;
        }
    } else if (a->expn <= 0) {
        /* |a| < 1 (or zero): truncates to 0 */
        v = 0;
        ret = 0;
    } else if (a->expn <= 9) {
        /* at most 9 integer digits: always fits, taken from the top
           limb (assumes LIMB_DIGITS >= 9) */
        v = fast_shr_dec(a->tab[a->len - 1], LIMB_DIGITS - a->expn);
        if (a->sign)
            v = -v;
        ret = 0;
    } else if (a->expn == 10) {
        /* 10 integer digits: may exceed the int32 range */
        uint64_t v1;
        uint32_t v_max;
#if LIMB_BITS == 64
        v1 = fast_shr_dec(a->tab[a->len - 1], LIMB_DIGITS - a->expn);
#else
        /* 32-bit limbs hold fewer digits: append the top digit of the
           next lower limb */
        v1 = (uint64_t)a->tab[a->len - 1] * 10 +
            get_digit(a->tab, a->len, (a->len - 1) * LIMB_DIGITS - 1);
#endif
        v_max = (uint32_t)INT32_MAX + a->sign;
        if (v1 > v_max) {
            v = v_max;
            ret = BF_ST_OVERFLOW;
        } else {
            v = v1;
            if (a->sign)
                v = -v;
            ret = 0;
        }
    } else {
        /* more than 10 integer digits: always out of range */
        v = (uint32_t)INT32_MAX + a->sign;
        ret = BF_ST_OVERFLOW;
    }
    *pres = v;
    return ret;
}
7202 
7203 /* power to an integer with infinite precision */
bfdec_pow_ui(bfdec_t * r,const bfdec_t * a,limb_t b)7204 int bfdec_pow_ui(bfdec_t *r, const bfdec_t *a, limb_t b)
7205 {
7206     int ret, n_bits, i;
7207 
7208     assert(r != a);
7209     if (b == 0)
7210         return bfdec_set_ui(r, 1);
7211     ret = bfdec_set(r, a);
7212     n_bits = LIMB_BITS - clz(b);
7213     for(i = n_bits - 2; i >= 0; i--) {
7214         ret |= bfdec_mul(r, r, r, BF_PREC_INF, BF_RNDZ);
7215         if ((b >> i) & 1)
7216             ret |= bfdec_mul(r, r, a, BF_PREC_INF, BF_RNDZ);
7217     }
7218     return ret;
7219 }
7220 
/* Convert the decimal float 'a' to a string in base 10. It forwards to
   bf_ftoa_internal(), which defines the meaning of 'prec', 'flags', the
   returned buffer and *plen (NOTE(review): presumably the string
   length; confirm there). */
char *bfdec_ftoa(size_t *plen, const bfdec_t *a, limb_t prec, bf_flags_t flags)
{
    const bf_t *as_bf = (const bf_t *)a;
    return bf_ftoa_internal(plen, as_bf, 10, prec, flags, TRUE);
}
7225 
/* Parse a base-10 number from 'str' into 'r', rounded with 'prec' and
   'flags'. It forwards to bf_atof_internal(). The exponent value that
   helper reports through its second argument is not needed here. */
int bfdec_atof(bfdec_t *r, const char *str, const char **pnext,
               limb_t prec, bf_flags_t flags)
{
    slimb_t unused_exp;
    bf_t *dst = (bf_t *)r;

    return bf_atof_internal(dst, &unused_exp, str, pnext, 10, prec,
                            flags, TRUE);
}
7233 
7234 #endif /* USE_BF_DEC */
7235 
7236 #ifdef USE_FFT_MUL
7237 /***************************************************************/
7238 /* Integer multiplication with FFT */
7239 
7240 /* or LIMB_BITS at bit position 'pos' in tab */
put_bits(limb_t * tab,limb_t len,slimb_t pos,limb_t val)7241 static inline void put_bits(limb_t *tab, limb_t len, slimb_t pos, limb_t val)
7242 {
7243     limb_t i;
7244     int p;
7245 
7246     i = pos >> LIMB_LOG2_BITS;
7247     p = pos & (LIMB_BITS - 1);
7248     if (i < len)
7249         tab[i] |= val << p;
7250     if (p != 0) {
7251         i++;
7252         if (i < len) {
7253             tab[i] |= val >> (LIMB_BITS - p);
7254         }
7255     }
7256 }
7257 
#if defined(__AVX2__)

/* AVX2 build: NTT residues are stored as doubles. Every modulus is
   below 2^NTT_MOD_LOG2_MAX = 2^51, so residues are exactly
   representable in a double's 53-bit significand. */
typedef double NTTLimb;

/* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
/* (each modulus m satisfies 2^NTT_MOD_LOG2_MIN <= m < 2^NTT_MOD_LOG2_MAX) */
#define NTT_MOD_LOG2_MIN 50
#define NTT_MOD_LOG2_MAX 51
/* number of prime moduli */
#define NB_MODS 5
/* each modulus has the form c * 2^NTT_PROOT_2EXP + 1 */
#define NTT_PROOT_2EXP 39
/* NOTE(review): per-modulus bit counts (the last one equals
   NTT_MOD_LOG2_MIN); their exact role is defined by the code that
   reads them */
static const int ntt_int_bits[NB_MODS] = { 254, 203, 152, 101, 50, };

static const limb_t ntt_mods[NB_MODS] = { 0x00073a8000000001, 0x0007858000000001, 0x0007a38000000001, 0x0007a68000000001, 0x0007fd8000000001,
};

/* one root per modulus in each row; NOTE(review): presumably row 0 for
   the forward and row 1 for the inverse transform - confirm */
static const limb_t ntt_proot[2][NB_MODS] = {
    { 0x00056198d44332c8, 0x0002eb5d640aad39, 0x00047e31eaa35fd0, 0x0005271ac118a150, 0x00075e0ce8442bd5, },
    { 0x000461169761bcc5, 0x0002dac3cb2da688, 0x0004abc97751e3bf, 0x000656778fc8c485, 0x0000dc6469c269fa, },
};

/* one constant per pair of moduli; NOTE(review): presumably the CRT
   reconstruction coefficients */
static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
 0x00020e4da740da8e, 0x0004c3dc09c09c1d, 0x000063bd097b4271, 0x000799d8f18f18fd,
 0x0005384222222264, 0x000572b07c1f07fe, 0x00035cd08888889a,
 0x00066015555557e3, 0x000725960b60b623,
 0x0002fc1fa1d6ce12,
};

#else

/* scalar build: NTT residues are stored as integer limbs */
typedef limb_t NTTLimb;

#if LIMB_BITS == 64

/* moduli m satisfy 2^NTT_MOD_LOG2_MIN <= m < 2^NTT_MOD_LOG2_MAX */
#define NTT_MOD_LOG2_MIN 61
#define NTT_MOD_LOG2_MAX 62
#define NB_MODS 5
/* each modulus has the form c * 2^NTT_PROOT_2EXP + 1 */
#define NTT_PROOT_2EXP 51
static const int ntt_int_bits[NB_MODS] = { 307, 246, 185, 123, 61, };

static const limb_t ntt_mods[NB_MODS] = { 0x28d8000000000001, 0x2a88000000000001, 0x2ed8000000000001, 0x3508000000000001, 0x3aa8000000000001,
};

static const limb_t ntt_proot[2][NB_MODS] = {
    { 0x1b8ea61034a2bea7, 0x21a9762de58206fb, 0x02ca782f0756a8ea, 0x278384537a3e50a1, 0x106e13fee74ce0ab, },
    { 0x233513af133e13b8, 0x1d13140d1c6f75f1, 0x12cde57f97e3eeda, 0x0d6149e23cbe654f, 0x36cd204f522a1379, },
};

static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
 0x08a9ed097b425eea, 0x18a44aaaaaaaaab3, 0x2493f57f57f57f5d, 0x126b8d0649a7f8d4,
 0x09d80ed7303b5ccc, 0x25b8bcf3cf3cf3d5, 0x2ce6ce63398ce638,
 0x0e31fad40a57eb59, 0x02a3529fd4a7f52f,
 0x3a5493e93e93e94a,
};

#elif LIMB_BITS == 32

/* we must have: modulo >= 1 << NTT_MOD_LOG2_MIN */
#define NTT_MOD_LOG2_MIN 29
#define NTT_MOD_LOG2_MAX 30
#define NB_MODS 5
/* each modulus has the form c * 2^NTT_PROOT_2EXP + 1 */
#define NTT_PROOT_2EXP 20
static const int ntt_int_bits[NB_MODS] = { 148, 119, 89, 59, 29, };

static const limb_t ntt_mods[NB_MODS] = { 0x0000000032b00001, 0x0000000033700001, 0x0000000036d00001, 0x0000000037300001, 0x000000003e500001,
};

static const limb_t ntt_proot[2][NB_MODS] = {
    { 0x0000000032525f31, 0x0000000005eb3b37, 0x00000000246eda9f, 0x0000000035f25901, 0x00000000022f5768, },
    { 0x00000000051eba1a, 0x00000000107be10e, 0x000000001cd574e0, 0x00000000053806e6, 0x000000002cd6bf98, },
};

static const limb_t ntt_mods_cr[NB_MODS * (NB_MODS - 1) / 2] = {
 0x000000000449559a, 0x000000001eba6ca9, 0x000000002ec18e46, 0x000000000860160b,
 0x000000000d321307, 0x000000000bf51120, 0x000000000f662938,
 0x000000000932ab3e, 0x000000002f40eef8,
 0x000000002e760905,
};

#endif /* LIMB_BITS */

#endif /* !AVX2 */
7338 
/* largest index k for which a ntt_trig[][][k] table slot exists */
#if defined(__AVX2__)
#define NTT_TRIG_K_MAX 18
#else
#define NTT_TRIG_K_MAX 19
#endif

/* Per-context precomputed data for the NTT multiplication. The first
   index of each per-modulus table selects one of the NB_MODS moduli. */
typedef struct BFNTTState {
    bf_context_t *ctx;

    /* used for mul_mod_fast() */
    /* (presumably the m_inv values from init_mul_mod_fast()) */
    limb_t ntt_mods_div[NB_MODS];

    /* powers of the primitive roots, for each modulus and each of the
       two ntt_proot rows, indexed by power-of-two exponent */
    limb_t ntt_proot_pow[NB_MODS][2][NTT_PROOT_2EXP + 1];
    limb_t ntt_proot_pow_inv[NB_MODS][2][NTT_PROOT_2EXP + 1];
    /* cached twiddle tables returned by get_trig() */
    NTTLimb *ntt_trig[NB_MODS][2][NTT_TRIG_K_MAX + 1];
    /* 1/2^n mod m */
    limb_t ntt_len_inv[NB_MODS][NTT_PROOT_2EXP + 1][2];
#if defined(__AVX2__)
    /* vector copies of ntt_mods_cr / ntt_mods and their inverses */
    __m256d ntt_mods_cr_vec[NB_MODS * (NB_MODS - 1) / 2];
    __m256d ntt_mods_vec[NB_MODS];
    __m256d ntt_mods_inv_vec[NB_MODS];
#else
    limb_t ntt_mods_cr_inv[NB_MODS * (NB_MODS - 1) / 2];
#endif
} BFNTTState;

/* NOTE(review): presumably returns (building on demand) the twiddle
   table for a 2^k-point transform, direction 'inverse', modulus index
   'm_idx'; defined later in the file */
static NTTLimb *get_trig(BFNTTState *s, int k, int inverse, int m_idx);
7366 
7367 /* add modulo with up to (LIMB_BITS-1) bit modulo */
add_mod(limb_t a,limb_t b,limb_t m)7368 static inline limb_t add_mod(limb_t a, limb_t b, limb_t m)
7369 {
7370     limb_t r;
7371     r = a + b;
7372     if (r >= m)
7373         r -= m;
7374     return r;
7375 }
7376 
7377 /* sub modulo with up to LIMB_BITS bit modulo */
sub_mod(limb_t a,limb_t b,limb_t m)7378 static inline limb_t sub_mod(limb_t a, limb_t b, limb_t m)
7379 {
7380     limb_t r;
7381     r = a - b;
7382     if (r > a)
7383         r += m;
7384     return r;
7385 }
7386 
/* Return r mod m, where 'r' is a double-limb value (r0 + r1 * B with
   B = 2^LIMB_BITS) and m_inv = init_mul_mod_fast(m).
   precondition: 0 <= r < 2^(LIMB_BITS + NTT_MOD_LOG2_MIN), i.e.
   2^(64+NTT_MOD_LOG2_MIN) with 64-bit limbs, so that r >> NTT_MOD_LOG2_MIN
   fits in one limb.
*/
static inline limb_t mod_fast(dlimb_t r,
                                limb_t m, limb_t m_inv)
{
    limb_t a1, q, t0, r1, r0;

    /* top bits of r, scaled to match m_inv */
    a1 = r >> NTT_MOD_LOG2_MIN;

    /* quotient estimate. NOTE(review): the correction steps below
       assume 0 <= r - q * m < 3 * m */
    q = ((dlimb_t)a1 * m_inv) >> LIMB_BITS;
    /* candidate remainder minus 2m; may be negative (wrapped) */
    r = r - (dlimb_t)q * m - m * 2;
    /* first correction: t0 is all ones when r is negative, then add m */
    r1 = r >> LIMB_BITS;
    t0 = (slimb_t)r1 >> 1;
    r += m & t0;
    /* second correction: add m again if still negative */
    r0 = r;
    r1 = r >> LIMB_BITS;
    r0 += m & r1;
    return r0;
}
7407 
7408 /* faster version using precomputed modulo inverse.
7409    precondition: 0 <= a * b < 2^(64+NTT_MOD_LOG2_MIN) */
mul_mod_fast(limb_t a,limb_t b,limb_t m,limb_t m_inv)7410 static inline limb_t mul_mod_fast(limb_t a, limb_t b,
7411                                     limb_t m, limb_t m_inv)
7412 {
7413     dlimb_t r;
7414     r = (dlimb_t)a * (dlimb_t)b;
7415     return mod_fast(r, m, m_inv);
7416 }
7417 
/* Reciprocal floor(2^(LIMB_BITS + NTT_MOD_LOG2_MIN) / m) consumed by
   mod_fast()/mul_mod_fast(). The modulus must satisfy
   2^NTT_MOD_LOG2_MIN <= m < 2^NTT_MOD_LOG2_MAX. */
static inline limb_t init_mul_mod_fast(limb_t m)
{
    const dlimb_t numerator = (dlimb_t)1 << (LIMB_BITS + NTT_MOD_LOG2_MIN);

    assert(m >= (limb_t)1 << NTT_MOD_LOG2_MIN);
    assert(m < (limb_t)1 << NTT_MOD_LOG2_MAX);
    return numerator / m;
}
7426 
7427 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7428    0 <= b < m. */
mul_mod_fast2(limb_t a,limb_t b,limb_t m,limb_t b_inv)7429 static inline limb_t mul_mod_fast2(limb_t a, limb_t b,
7430                                      limb_t m, limb_t b_inv)
7431 {
7432     limb_t r, q;
7433 
7434     q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7435     r = a * b - q * m;
7436     if (r >= m)
7437         r -= m;
7438     return r;
7439 }
7440 
7441 /* Faster version used when the multiplier is constant. 0 <= a < 2^64,
7442    0 <= b < m. Let r = a * b mod m. The return value is 'r' or 'r +
7443    m'. */
mul_mod_fast3(limb_t a,limb_t b,limb_t m,limb_t b_inv)7444 static inline limb_t mul_mod_fast3(limb_t a, limb_t b,
7445                                      limb_t m, limb_t b_inv)
7446 {
7447     limb_t r, q;
7448 
7449     q = ((dlimb_t)a * (dlimb_t)b_inv) >> LIMB_BITS;
7450     r = a * b - q * m;
7451     return r;
7452 }
7453 
/* Precomputed quotient floor(b * 2^LIMB_BITS / m) used by mul_mod_fast2()
   and mul_mod_fast3() for the constant multiplier b (0 <= b < m, so the
   quotient fits in a limb). */
static inline limb_t init_mul_mod_fast2(limb_t b, limb_t m)
{
    dlimb_t scaled_b = (dlimb_t)b << LIMB_BITS;
    return (limb_t)(scaled_b / m);
}
7458 
7459 #ifdef __AVX2__
7460 
ntt_limb_to_int(NTTLimb a,limb_t m)7461 static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
7462 {
7463     slimb_t v;
7464     v = a;
7465     if (v < 0)
7466         v += m;
7467     if (v >= m)
7468         v -= m;
7469     return v;
7470 }
7471 
int_to_ntt_limb(limb_t a,limb_t m)7472 static inline NTTLimb int_to_ntt_limb(limb_t a, limb_t m)
7473 {
7474     return (slimb_t)a;
7475 }
7476 
int_to_ntt_limb2(limb_t a,limb_t m)7477 static inline NTTLimb int_to_ntt_limb2(limb_t a, limb_t m)
7478 {
7479     if (a >= (m / 2))
7480         a -= m;
7481     return (slimb_t)a;
7482 }
7483 
/* return r + m if r < 0 otherwise r.
   _mm256_blendv_pd selects per lane on the sign bit of the mask operand
   (here r itself), so lanes with the sign bit set take r + m. */
static inline __m256d ntt_mod1(__m256d r, __m256d m)
{
    return _mm256_blendv_pd(r, r + m, r);
}
7489 
/* input: abs(r) < 2 * m. Output: abs(r) < m
   Lanes with the sign bit set become r + 2m - m = r + m; the other lanes
   become r - m. mf = m and m2f = 2 * m, broadcast. */
static inline __m256d ntt_mod(__m256d r, __m256d mf, __m256d m2f)
{
    return _mm256_blendv_pd(r, r + m2f, r) - mf;
}
7495 
/* input: abs(a*b) < 2 * m^2, output: abs(r) < m
   Modular product in double precision. The rounded products a*b and q*m
   are complemented by their rounding errors (obtained exactly with fused
   multiply-subtract), so r = a*b - q*m is computed from the high and low
   parts. mf is m broadcast; m_inv is 1.0 / m broadcast. */
static inline __m256d ntt_mul_mod(__m256d a, __m256d b, __m256d mf,
                                  __m256d m_inv)
{
    __m256d r, q, ab1, ab0, qm0, qm1;
    ab1 = a * b;
    q = _mm256_round_pd(ab1 * m_inv, 0); /* round to nearest */
    qm1 = q * mf;
    qm0 = _mm256_fmsub_pd(q, mf, qm1); /* low part */
    ab0 = _mm256_fmsub_pd(a, b, ab1); /* low part */
    r = (ab1 - qm1) + (ab0 - qm0);
    return r;
}
7509 
/* Allocate 'size' bytes aligned on 'align' bytes from the context
   allocator. 'align' must be a non-zero power of two. The pointer returned
   by bf_malloc() is stored in the slot just before the returned block so
   that bf_aligned_free() can release it.
   Returns NULL on allocation failure, or when the padded size would not
   fit in size_t. */
static void *bf_aligned_malloc(bf_context_t *s, size_t size, size_t align)
{
    void *ptr;
    void **ptr1;
    size_t extra;

    /* the mask computation below only works for powers of two */
    assert(align != 0 && (align & (align - 1)) == 0);
    /* room for the back-pointer plus worst-case alignment padding */
    extra = sizeof(void *) + align - 1;
    if (size > SIZE_MAX - extra)
        return NULL; /* size + extra would wrap around */
    ptr = bf_malloc(s, size + extra);
    if (!ptr)
        return NULL;
    /* build the mask in uintptr_t so no high address bits are cleared when
       size_t is narrower than uintptr_t */
    ptr1 = (void **)(((uintptr_t)ptr + extra) & ~((uintptr_t)align - 1));
    ptr1[-1] = ptr;
    return ptr1;
}
7522 
/* Release a block returned by bf_aligned_malloc(). The original bf_malloc()
   pointer is read back from the slot just before the block. NULL is a
   no-op. */
static void bf_aligned_free(bf_context_t *s, void *ptr)
{
    void **header;

    if (ptr == NULL)
        return;
    header = (void **)ptr;
    bf_free(s, header[-1]);
}
7529 
/* Allocate an NTT buffer with 64-byte alignment, since the AVX2 code uses
   aligned _mm256_load_pd/_mm256_store_pd on these buffers. */
static void *ntt_malloc(BFNTTState *s, size_t size)
{
    const size_t ntt_align = 64;
    return bf_aligned_malloc(s->ctx, size, ntt_align);
}
7534 
/* Release a buffer obtained from ntt_malloc() (NULL is accepted). */
static void ntt_free(BFNTTState *s, void *ptr)
{
    bf_context_t *ctx = s->ctx;
    bf_aligned_free(ctx, ptr);
}
7539 
/* Forward (inverse == 0) or inverse (inverse != 0) NTT of size
   n = 2^fft_len_log2 modulo m = ntt_mods[m_idx]. This is the AVX2 version:
   it works on doubles, 4 lanes per vector.
   Every pass applies butterflies (a0 + a1, (a0 - a1) * w) to entries
   stride_in = n/2 apart, with twiddles w taken from get_trig().
   in_buf and tmp_buf are used as ping-pong buffers, so in_buf is
   overwritten; the last pass writes to out_buf.
   Requires n >= 8 and buffers suitable for aligned 256-bit loads/stores.
   Butterfly outputs satisfy abs(x) < m, given the ntt_mod()/ntt_mul_mod()
   contracts and inputs with abs(x) < m.
   Return 0, or -1 if a twiddle table cannot be obtained. */
static no_inline int ntt_fft(BFNTTState *s,
                             NTTLimb *out_buf, NTTLimb *in_buf,
                             NTTLimb *tmp_buf, int fft_len_log2,
                             int inverse, int m_idx)
{
    limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j;
    NTTLimb *tab_in, *tab_out, *tmp, *trig;
    __m256d m_inv, mf, m2f, c, a0, a1, b0, b1;
    limb_t m;
    int l;

    m = ntt_mods[m_idx];

    m_inv = _mm256_set1_pd(1.0 / (double)m);
    mf = _mm256_set1_pd(m);
    m2f = _mm256_set1_pd(m * 2);

    n = (limb_t)1 << fft_len_log2;
    assert(n >= 8);
    stride_in = n / 2;

    tab_in = in_buf;
    tab_out = tmp_buf;
    /* first pass (blocks of 1 element): one distinct twiddle per lane */
    trig = get_trig(s, fft_len_log2, inverse, m_idx);
    if (!trig)
        return -1;
    p = 0;
    for(k = 0; k < stride_in; k += 4) {
        a0 = _mm256_load_pd(&tab_in[k]);
        a1 = _mm256_load_pd(&tab_in[k + stride_in]);
        c = _mm256_load_pd(trig);
        trig += 4;
        b0 = ntt_mod(a0 + a1, mf, m2f);
        b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
        /* interleave lanes: b0[0], b1[0], b0[1], b1[1], ..., b0[3], b1[3] */
        a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
        a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
        a0 = _mm256_permute4x64_pd(a0, 0xd8);
        a1 = _mm256_permute4x64_pd(a1, 0xd8);
        _mm256_store_pd(&tab_out[p], a0);
        _mm256_store_pd(&tab_out[p + 4], a1);
        p += 2 * 4;
    }
    tmp = tab_in;
    tab_in = tab_out;
    tab_out = tmp;

    /* second pass (blocks of 2 elements): each twiddle serves 2 lanes */
    trig = get_trig(s, fft_len_log2 - 1, inverse, m_idx);
    if (!trig)
        return -1;
    p = 0;
    for(k = 0; k < stride_in; k += 4) {
        a0 = _mm256_load_pd(&tab_in[k]);
        a1 = _mm256_load_pd(&tab_in[k + stride_in]);
        c = _mm256_setr_pd(trig[0], trig[0], trig[1], trig[1]);
        trig += 2;
        b0 = ntt_mod(a0 + a1, mf, m2f);
        b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
        /* output pairs: b0[0..1], b1[0..1], b0[2..3], b1[2..3] */
        a0 = _mm256_permute2f128_pd(b0, b1, 0x20);
        a1 = _mm256_permute2f128_pd(b0, b1, 0x31);
        _mm256_store_pd(&tab_out[p], a0);
        _mm256_store_pd(&tab_out[p + 4], a1);
        p += 2 * 4;
    }
    tmp = tab_in;
    tab_in = tab_out;
    tab_out = tmp;

    nb_blocks = n / 4;
    fft_per_block = 4;

    /* remaining passes except the last: blocks of >= 4 elements share one
       broadcast twiddle */
    l = fft_len_log2 - 2;
    while (nb_blocks != 2) {
        nb_blocks >>= 1;
        p = 0;
        k = 0;
        trig = get_trig(s, l, inverse, m_idx);
        if (!trig)
            return -1;
        for(i = 0; i < nb_blocks; i++) {
            c = _mm256_set1_pd(trig[0]);
            trig++;
            for(j = 0; j < fft_per_block; j += 4) {
                a0 = _mm256_load_pd(&tab_in[k + j]);
                a1 = _mm256_load_pd(&tab_in[k + j + stride_in]);
                b0 = ntt_mod(a0 + a1, mf, m2f);
                b1 = ntt_mul_mod(a0 - a1, c, mf, m_inv);
                _mm256_store_pd(&tab_out[p + j], b0);
                _mm256_store_pd(&tab_out[p + j + fft_per_block], b1);
            }
            k += fft_per_block;
            p += 2 * fft_per_block;
        }
        fft_per_block <<= 1;
        l--;
        tmp = tab_in;
        tab_in = tab_out;
        tab_out = tmp;
    }

    /* last pass: no twiddle, result goes to out_buf */
    tab_out = out_buf;
    for(k = 0; k < stride_in; k += 4) {
        a0 = _mm256_load_pd(&tab_in[k]);
        a1 = _mm256_load_pd(&tab_in[k + stride_in]);
        b0 = ntt_mod(a0 + a1, mf, m2f);
        b1 = ntt_mod(a0 - a1, mf, m2f);
        _mm256_store_pd(&tab_out[k], b0);
        _mm256_store_pd(&tab_out[k + stride_in], b1);
    }
    return 0;
}
7650 
/* Pointwise product in the transform domain, AVX2 version: for the
   2^fft_len_log2 entries, tab1[i] = tab1[i] * tab2[i] * (1/2^k_tot)
   mod ntt_mods[m_idx], processed 4 lanes at a time. */
static void ntt_vec_mul(BFNTTState *s,
                        NTTLimb *tab1, NTTLimb *tab2, limb_t fft_len_log2,
                        int k_tot, int m_idx)
{
    const limb_t m = ntt_mods[m_idx];
    const limb_t len = (limb_t)1 << fft_len_log2;
    const limb_t norm_int = s->ntt_len_inv[m_idx][k_tot][0];
    const __m256d mf = _mm256_set1_pd(m);
    const __m256d m_inv = _mm256_set1_pd(1.0 / (double)m);
    /* normalization factor 1/2^k_tot mod m, broadcast to every lane */
    const __m256d norm = _mm256_set1_pd(int_to_ntt_limb(norm_int, m));
    limb_t idx;

    for (idx = 0; idx < len; idx += 4) {
        __m256d lhs = _mm256_load_pd(&tab1[idx]);
        __m256d rhs = _mm256_load_pd(&tab2[idx]);
        __m256d prod = ntt_mul_mod(lhs, rhs, mf, m_inv);
        _mm256_store_pd(&tab1[idx], ntt_mul_mod(prod, norm, mf, m_inv));
    }
}
7672 
mul_trig(NTTLimb * buf,limb_t n,limb_t c1,limb_t m,limb_t m_inv1)7673 static no_inline void mul_trig(NTTLimb *buf,
7674                                limb_t n, limb_t c1, limb_t m, limb_t m_inv1)
7675 {
7676     limb_t i, c2, c3, c4;
7677     __m256d c, c_mul, a0, mf, m_inv;
7678     assert(n >= 2);
7679 
7680     mf = _mm256_set1_pd(m);
7681     m_inv = _mm256_set1_pd(1.0 / (double)m);
7682 
7683     c2 = mul_mod_fast(c1, c1, m, m_inv1);
7684     c3 = mul_mod_fast(c2, c1, m, m_inv1);
7685     c4 = mul_mod_fast(c2, c2, m, m_inv1);
7686     c = _mm256_setr_pd(1, int_to_ntt_limb(c1, m),
7687                        int_to_ntt_limb(c2, m), int_to_ntt_limb(c3, m));
7688     c_mul = _mm256_set1_pd(int_to_ntt_limb(c4, m));
7689     for(i = 0; i < n; i += 4) {
7690         a0 = _mm256_load_pd(&buf[i]);
7691         a0 = ntt_mul_mod(a0, c, mf, m_inv);
7692         _mm256_store_pd(&buf[i], a0);
7693         c = ntt_mul_mod(c, c_mul, mf, m_inv);
7694     }
7695 }
7696 
7697 #else
7698 
ntt_malloc(BFNTTState * s,size_t size)7699 static void *ntt_malloc(BFNTTState *s, size_t size)
7700 {
7701     return bf_malloc(s->ctx, size);
7702 }
7703 
/* Release a buffer obtained from ntt_malloc(). */
static void ntt_free(BFNTTState *s, void *ptr)
{
    bf_context_t *ctx = s->ctx;
    bf_free(ctx, ptr);
}
7708 
/* Fully reduce a lazily reduced NTT value (assumed to be in [0, 2m)) to
   [0, m) with a single conditional subtraction. */
static inline limb_t ntt_limb_to_int(NTTLimb a, limb_t m)
{
    return (a < m) ? a : a - m;
}
7715 
int_to_ntt_limb(slimb_t a,limb_t m)7716 static inline NTTLimb int_to_ntt_limb(slimb_t a, limb_t m)
7717 {
7718     return a;
7719 }
7720 
/* Forward (inverse == 0) or inverse (inverse != 0) NTT of size
   n = 2^fft_len_log2 modulo m = ntt_mods[m_idx], portable version.
   Every pass applies butterflies (a0 + a1, (a0 - a1) * w) to entries
   stride_in = n/2 apart. Each twiddle w comes from get_trig() together
   with its mul_mod_fast3() reciprocal.
   in_buf and tmp_buf are used as ping-pong buffers, so in_buf is
   overwritten; the last pass writes to out_buf.
   Values are lazily reduced modulo m2 = 2 * m: inputs are assumed to be
   in [0, 2m), and add_mod/sub_mod/mul_mod_fast3 keep outputs in [0, 2m).
   Return 0, or -1 if a twiddle table cannot be obtained. */
static no_inline int ntt_fft(BFNTTState *s, NTTLimb *out_buf, NTTLimb *in_buf,
                             NTTLimb *tmp_buf, int fft_len_log2,
                             int inverse, int m_idx)
{
    limb_t nb_blocks, fft_per_block, p, k, n, stride_in, i, j, m, m2;
    NTTLimb *tab_in, *tab_out, *tmp, a0, a1, b0, b1, c, *trig, c_inv;
    int l;

    m = ntt_mods[m_idx];
    m2 = 2 * m;
    n = (limb_t)1 << fft_len_log2;
    nb_blocks = n;
    fft_per_block = 1;
    stride_in = n / 2;
    tab_in = in_buf;
    tab_out = tmp_buf;
    l = fft_len_log2;
    while (nb_blocks != 2) {
        nb_blocks >>= 1;
        p = 0;
        k = 0;
        trig = get_trig(s, l, inverse, m_idx);
        if (!trig)
            return -1;
        for(i = 0; i < nb_blocks; i++) {
            /* twiddle and its precomputed reciprocal */
            c = trig[0];
            c_inv = trig[1];
            trig += 2;
            for(j = 0; j < fft_per_block; j++) {
                a0 = tab_in[k + j];
                a1 = tab_in[k + j + stride_in];
                b0 = add_mod(a0, a1, m2);
                /* + m2 keeps the difference non-negative */
                b1 = a0 - a1 + m2;
                b1 = mul_mod_fast3(b1, c, m, c_inv);
                tab_out[p + j] = b0;
                tab_out[p + j + fft_per_block] = b1;
            }
            k += fft_per_block;
            p += 2 * fft_per_block;
        }
        fft_per_block <<= 1;
        l--;
        tmp = tab_in;
        tab_in = tab_out;
        tab_out = tmp;
    }
    /* no twiddle in last step */
    tab_out = out_buf;
    for(k = 0; k < stride_in; k++) {
        a0 = tab_in[k];
        a1 = tab_in[k + stride_in];
        b0 = add_mod(a0, a1, m2);
        b1 = sub_mod(a0, a1, m2);
        tab_out[k] = b0;
        tab_out[k + stride_in] = b1;
    }
    return 0;
}
7779 
/* Pointwise product in the transform domain, portable version: for the
   2^fft_len_log2 entries, tab1[i] = tab1[i] * tab2[i] * (1/2^k_tot)
   mod ntt_mods[m_idx]. The results come from mul_mod_fast3(), so they are
   left lazily reduced (r or r + m). */
static void ntt_vec_mul(BFNTTState *s,
                        NTTLimb *tab1, NTTLimb *tab2, int fft_len_log2,
                        int k_tot, int m_idx)
{
    const limb_t m = ntt_mods[m_idx];
    const limb_t m_inv = s->ntt_mods_div[m_idx];
    const limb_t norm = s->ntt_len_inv[m_idx][k_tot][0];
    const limb_t norm_inv = s->ntt_len_inv[m_idx][k_tot][1];
    const limb_t len = (limb_t)1 << fft_len_log2;
    limb_t idx;

    for (idx = 0; idx < len; idx++) {
        /* reduce to [0, m) first so that the product stays below
           2^(LIMB_BITS+NTT_MOD_LOG2_MIN), as mul_mod_fast() requires */
        limb_t x = ntt_limb_to_int(tab1[idx], m);
        x = mul_mod_fast(x, tab2[idx], m, m_inv);
        tab1[idx] = mul_mod_fast3(x, norm, m, norm_inv);
    }
}
7802 
mul_trig(NTTLimb * buf,limb_t n,limb_t c_mul,limb_t m,limb_t m_inv)7803 static no_inline void mul_trig(NTTLimb *buf,
7804                                limb_t n, limb_t c_mul, limb_t m, limb_t m_inv)
7805 {
7806     limb_t i, c0, c_mul_inv;
7807 
7808     c0 = 1;
7809     c_mul_inv = init_mul_mod_fast2(c_mul, m);
7810     for(i = 0; i < n; i++) {
7811         buf[i] = mul_mod_fast(buf[i], c0, m, m_inv);
7812         c0 = mul_mod_fast2(c0, c_mul, m, c_mul_inv);
7813     }
7814 }
7815 
7816 #endif /* !AVX2 */
7817 
get_trig(BFNTTState * s,int k,int inverse,int m_idx)7818 static no_inline NTTLimb *get_trig(BFNTTState *s,
7819                                    int k, int inverse, int m_idx)
7820 {
7821     NTTLimb *tab;
7822     limb_t i, n2, c, c_mul, m, c_mul_inv;
7823 
7824     if (k > NTT_TRIG_K_MAX)
7825         return NULL;
7826 
7827     tab = s->ntt_trig[m_idx][inverse][k];
7828     if (tab)
7829         return tab;
7830     n2 = (limb_t)1 << (k - 1);
7831     m = ntt_mods[m_idx];
7832 #ifdef __AVX2__
7833     tab = ntt_malloc(s, sizeof(NTTLimb) * n2);
7834 #else
7835     tab = ntt_malloc(s, sizeof(NTTLimb) * n2 * 2);
7836 #endif
7837     if (!tab)
7838         return NULL;
7839     c = 1;
7840     c_mul = s->ntt_proot_pow[m_idx][inverse][k];
7841     c_mul_inv = s->ntt_proot_pow_inv[m_idx][inverse][k];
7842     for(i = 0; i < n2; i++) {
7843 #ifdef __AVX2__
7844         tab[i] = int_to_ntt_limb2(c, m);
7845 #else
7846         tab[2 * i] = int_to_ntt_limb(c, m);
7847         tab[2 * i + 1] = init_mul_mod_fast2(c, m);
7848 #endif
7849         c = mul_mod_fast2(c, c_mul, m, c_mul_inv);
7850     }
7851     s->ntt_trig[m_idx][inverse][k] = tab;
7852     return tab;
7853 }
7854 
/* Free every cached twiddle table and the NTT state attached to s1, then
   reset s1->ntt_state to NULL. Does nothing when no state exists. */
void fft_clear_cache(bf_context_t *s1)
{
    BFNTTState *s = s1->ntt_state;
    int m_idx, dir, k;

    if (!s)
        return;

    for (m_idx = 0; m_idx < NB_MODS; m_idx++) {
        for (dir = 0; dir < 2; dir++) {
            NTTLimb **tables = s->ntt_trig[m_idx][dir];
            for (k = 0; k <= NTT_TRIG_K_MAX; k++) {
                if (tables[k] != NULL) {
                    ntt_free(s, tables[k]);
                    tables[k] = NULL;
                }
            }
        }
    }
#if defined(__AVX2__)
    bf_aligned_free(s1, s);
#else
    bf_free(s1, s);
#endif
    s1->ntt_state = NULL;
}
7878 
/* number of matrix columns gathered and transformed per batch in
   ntt_fft_partial() */
#define STRIP_LEN 16

/* In-place transform of buf1 (n1 * n2 entries, n1 = 2^k1, n2 = 2^k2)
   modulo ntt_mods[m_idx]; forward when inverse == 0, inverse otherwise.
   If k2 == 0 this is a single ntt_fft() of size 2^k1. Otherwise buf1 is
   viewed as an n1 x n2 row-major matrix. STRIP_LEN columns at a time are
   copied into the temporary buf2, where column j + l becomes a contiguous
   run of n1 entries. Each column gets a size-2^k1 ntt_fft() and a twiddle
   multiplication (row i of column col is multiplied by c0^(col * i),
   where c0 is the root for size 2^(k1+k2)). The twiddles are applied
   before the transform in the inverse direction and after it in the
   forward one. The columns are then copied back.
   buf3 is scratch space for ntt_fft(). Requires n2 % STRIP_LEN == 0.
   Return 0 on success, -1 on allocation or twiddle-table failure. */
static int ntt_fft_partial(BFNTTState *s, NTTLimb *buf1,
                           int k1, int k2, limb_t n1, limb_t n2, int inverse,
                           limb_t m_idx)
{
    limb_t i, j, c_mul, c0, m, m_inv, strip_len, l;
    NTTLimb *buf2, *buf3;

    buf2 = NULL;
    buf3 = ntt_malloc(s, sizeof(NTTLimb) * n1);
    if (!buf3)
        goto fail;
    if (k2 == 0) {
        if (ntt_fft(s, buf1, buf1, buf3, k1, inverse, m_idx))
            goto fail;
    } else {
        strip_len = STRIP_LEN;
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * n1 * strip_len);
        if (!buf2)
            goto fail;
        m = ntt_mods[m_idx];
        m_inv = s->ntt_mods_div[m_idx];
        c0 = s->ntt_proot_pow[m_idx][inverse][k1 + k2];
        /* c_mul = c0^column, advanced once per column */
        c_mul = 1;
        assert((n2 % strip_len) == 0);
        for(j = 0; j < n2; j += strip_len) {
            /* gather columns j .. j + strip_len - 1 */
            for(i = 0; i < n1; i++) {
                for(l = 0; l < strip_len; l++) {
                    buf2[i + l * n1] = buf1[i * n2 + (j + l)];
                }
            }
            for(l = 0; l < strip_len; l++) {
                if (inverse)
                    mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
                if (ntt_fft(s, buf2 + l * n1, buf2 + l * n1, buf3, k1, inverse, m_idx))
                    goto fail;
                if (!inverse)
                    mul_trig(buf2 + l * n1, n1, c_mul, m, m_inv);
                c_mul = mul_mod_fast(c_mul, c0, m, m_inv);
            }

            /* scatter the transformed columns back */
            for(i = 0; i < n1; i++) {
                for(l = 0; l < strip_len; l++) {
                    buf1[i * n2 + (j + l)] = buf2[i + l *n1];
                }
            }
        }
        ntt_free(s, buf2);
    }
    ntt_free(s, buf3);
    return 0;
 fail:
    ntt_free(s, buf2);
    ntt_free(s, buf3);
    return -1;
}
7937 
7938 
7939 /* dst = buf1, src = buf2, tmp = buf3 */
ntt_conv(BFNTTState * s,NTTLimb * buf1,NTTLimb * buf2,int k,int k_tot,limb_t m_idx)7940 static int ntt_conv(BFNTTState *s, NTTLimb *buf1, NTTLimb *buf2,
7941                     int k, int k_tot, limb_t m_idx)
7942 {
7943     limb_t n1, n2, i;
7944     int k1, k2;
7945 
7946     if (k <= NTT_TRIG_K_MAX) {
7947         k1 = k;
7948     } else {
7949         /* recursive split of the FFT */
7950         k1 = bf_min(k / 2, NTT_TRIG_K_MAX);
7951     }
7952     k2 = k - k1;
7953     n1 = (limb_t)1 << k1;
7954     n2 = (limb_t)1 << k2;
7955 
7956     if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 0, m_idx))
7957         return -1;
7958     if (ntt_fft_partial(s, buf2, k1, k2, n1, n2, 0, m_idx))
7959         return -1;
7960     if (k2 == 0) {
7961         ntt_vec_mul(s, buf1, buf2, k, k_tot, m_idx);
7962     } else {
7963         for(i = 0; i < n1; i++) {
7964             ntt_conv(s, buf1 + i * n2, buf2 + i * n2, k2, k_tot, m_idx);
7965         }
7966     }
7967     if (ntt_fft_partial(s, buf1, k1, k2, n1, n2, 1, m_idx))
7968         return -1;
7969     return 0;
7970 }
7971 
7972 
/* Split the a_len-limb integer taba into digits of dpl bits (digit i
   starts at bit i * dpl). Each digit is reduced modulo the nb_mods moduli
   ntt_mods[first_m_idx + j], and the residue is stored in
   tabr[i + j * fft_len]. At most fft_len digits are converted; the rest
   of tabr (fft_len * nb_mods entries) is zero-filled.
   When dpl > LIMB_BITS + NTT_MOD_LOG2_MIN, a digit is too wide for a single
   mod_fast() call. Its low bits are then split off into a0: the high part
   is reduced first, then (r << low_bits) | a0 is reduced again. */
static no_inline void limb_to_ntt(BFNTTState *s,
                                  NTTLimb *tabr, limb_t fft_len,
                                  const limb_t *taba, limb_t a_len, int dpl,
                                  int first_m_idx, int nb_mods)
{
    slimb_t i, n;
    dlimb_t a, b;
    int j, shift;
    limb_t base_mask1, a0, a1, a2, r, m, m_inv;

#if 0
    for(i = 0; i < a_len; i++) {
        printf("%" PRId64 ": " FMT_LIMB "\n",
               (int64_t)i, taba[i]);
    }
#endif
    memset(tabr, 0, sizeof(NTTLimb) * fft_len * nb_mods);
    /* mask for the bits of the digit's most significant limb */
    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    /* number of digits: ceil(a_len * LIMB_BITS / dpl), capped at fft_len */
    n = bf_min(fft_len, (a_len * LIMB_BITS + dpl - 1) / dpl);
    for(i = 0; i < n; i++) {
        a0 = get_bits(taba, a_len, i * dpl);
        if (dpl <= LIMB_BITS) {
            a0 &= base_mask1;
            a = a0;
        } else {
            a1 = get_bits(taba, a_len, i * dpl + LIMB_BITS);
            if (dpl <= (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
                /* digit fits the mod_fast() precondition directly */
                a = a0 | ((dlimb_t)(a1 & base_mask1) << LIMB_BITS);
            } else {
                if (dpl > 2 * LIMB_BITS) {
                    a2 = get_bits(taba, a_len, i * dpl + LIMB_BITS * 2) &
                        base_mask1;
                } else {
                    a1 &= base_mask1;
                    a2 = 0;
                }
                //            printf("a=0x%016lx%016lx%016lx\n", a2, a1, a0);
                /* a = digit without its low
                   (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN) bits,
                   which are kept in a0 */
                a = (a0 >> (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) |
                    ((dlimb_t)a1 << (NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN)) |
                    ((dlimb_t)a2 << (LIMB_BITS + NTT_MOD_LOG2_MAX - NTT_MOD_LOG2_MIN));
                a0 &= ((limb_t)1 << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) - 1;
            }
        }
        for(j = 0; j < nb_mods; j++) {
            m = ntt_mods[first_m_idx + j];
            m_inv = s->ntt_mods_div[first_m_idx + j];
            r = mod_fast(a, m, m_inv);
            if (dpl > (LIMB_BITS + NTT_MOD_LOG2_MIN)) {
                /* re-attach the low bits and reduce again */
                b = ((dlimb_t)r << (LIMB_BITS - NTT_MOD_LOG2_MAX + NTT_MOD_LOG2_MIN)) | a0;
                r = mod_fast(b, m, m_inv);
            }
            tabr[i + j * fft_len] = int_to_ntt_limb(r, m);
        }
    }
}
8032 
#if defined(__AVX2__)

/* number of doubles per __m256d vector */
#define VEC_LEN 4

/* lane access to an AVX2 vector */
typedef union {
    __m256d v;
    double d[4];
} VecUnion;

/* Convert the nb_mods residue vectors in buf back to an integer written
   into tabr[0..r_len-1] as dpl-bit digits (digit i at bit offset i * dpl).
   Vector j is at buf + j * fft_len and is reduced modulo
   mods[j] = ntt_mods[NB_MODS - nb_mods + j].
   For each position:
   (1) the residues are turned into mixed-radix digits y[] (Chinese
       remainder step);
   (2) the digits are recombined into the multi-limb integer u[];
   (3) the carry from the previous position is added;
   (4) the low dpl bits are written, and the remaining bits become the
       carry.
   AVX2 version: step (1) is vectorized over VEC_LEN positions. len is
   rounded up to a multiple of VEC_LEN. */
static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
                                  const NTTLimb *buf, int fft_len_log2, int dpl,
                                  int nb_mods)
{
    const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
    const __m256d *mods_cr_vec, *mf, *m_inv;
    VecUnion y[NB_MODS];
    limb_t u[NB_MODS], carry[NB_MODS], fft_len, base_mask1, r;
    slimb_t i, len, pos;
    int j, k, l, shift, n_limb1, p;
    dlimb_t t;

    /* skip the CRT constants of the unused leading moduli */
    j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
    mods_cr_vec = s->ntt_mods_cr_vec + j;
    mf = s->ntt_mods_vec + NB_MODS - nb_mods;
    m_inv = s->ntt_mods_inv_vec + NB_MODS - nb_mods;

    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    /* index of the limb holding the top bits of a digit */
    n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS;
    for(j = 0; j < NB_MODS; j++)
        carry[j] = 0;
    for(j = 0; j < NB_MODS; j++)
        u[j] = 0; /* avoid warnings */
    memset(tabr, 0, sizeof(limb_t) * r_len);
    fft_len = (limb_t)1 << fft_len_log2;
    len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
    len = (len + VEC_LEN - 1) & ~(VEC_LEN - 1);
    i = 0;
    while (i < len) {
        for(j = 0; j < nb_mods; j++)
            y[j].v = *(__m256d *)&buf[i + fft_len * j];

        /* Chinese remainder to get mixed radix representation */
        l = 0;
        for(j = 0; j < nb_mods - 1; j++) {
            y[j].v = ntt_mod1(y[j].v, mf[j]);
            for(k = j + 1; k < nb_mods; k++) {
                y[k].v = ntt_mul_mod(y[k].v - y[j].v,
                                     mods_cr_vec[l], mf[k], m_inv[k]);
                l++;
            }
        }
        /* j == nb_mods - 1 here */
        y[j].v = ntt_mod1(y[j].v, mf[j]);

        for(p = 0; p < VEC_LEN; p++) {
            /* back to normal representation */
            u[0] = (int64_t)y[nb_mods - 1].d[p];
            l = 1;
            for(j = nb_mods - 2; j >= 1; j--) {
                r = (int64_t)y[j].d[p];
                for(k = 0; k < l; k++) {
                    t = (dlimb_t)u[k] * mods[j] + r;
                    r = t >> LIMB_BITS;
                    u[k] = t;
                }
                u[l] = r;
                l++;
            }
            /* XXX: for nb_mods = 5, l should be 4 */

            /* last step adds the carry (j == 0 here) */
            r = (int64_t)y[0].d[p];
            for(k = 0; k < l; k++) {
                t = (dlimb_t)u[k] * mods[j] + r + carry[k];
                r = t >> LIMB_BITS;
                u[k] = t;
            }
            u[l] = r + carry[l];

#if 0
            printf("%" PRId64 ": ", i);
            for(j = nb_mods - 1; j >= 0; j--) {
                printf(" %019" PRIu64, u[j]);
            }
            printf("\n");
#endif

            /* write the digits */
            pos = i * dpl;
            for(j = 0; j < n_limb1; j++) {
                put_bits(tabr, r_len, pos, u[j]);
                pos += LIMB_BITS;
            }
            put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
            /* shift by dpl digits and set the carry */
            if (shift == 0) {
                for(j = n_limb1 + 1; j < nb_mods; j++)
                    carry[j - (n_limb1 + 1)] = u[j];
            } else {
                for(j = n_limb1; j < nb_mods - 1; j++) {
                    carry[j - n_limb1] = (u[j] >> shift) |
                        (u[j + 1] << (LIMB_BITS - shift));
                }
                carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
            }
            i++;
        }
    }
}
#else
/* Convert the nb_mods residue vectors in buf back to an integer written
   into tabr[0..r_len-1] as dpl-bit digits (digit i at bit offset i * dpl).
   Vector j is at buf + j * fft_len and is reduced modulo
   mods[j] = ntt_mods[NB_MODS - nb_mods + j].
   For each position:
   (1) the residues are turned into mixed-radix digits y[] (Chinese
       remainder step);
   (2) the digits are recombined into the multi-limb integer u[];
   (3) the carry from the previous position is added;
   (4) the low dpl bits are written, and the remaining bits become the
       carry.
   Portable version: one position per iteration. */
static no_inline void ntt_to_limb(BFNTTState *s, limb_t *tabr, limb_t r_len,
                                  const NTTLimb *buf, int fft_len_log2, int dpl,
                                  int nb_mods)
{
    const limb_t *mods = ntt_mods + NB_MODS - nb_mods;
    const limb_t *mods_cr, *mods_cr_inv;
    limb_t y[NB_MODS], u[NB_MODS], carry[NB_MODS], fft_len, base_mask1, r;
    slimb_t i, len, pos;
    int j, k, l, shift, n_limb1;
    dlimb_t t;

    /* skip the CRT constants of the unused leading moduli */
    j = NB_MODS * (NB_MODS - 1) / 2 - nb_mods * (nb_mods - 1) / 2;
    mods_cr = ntt_mods_cr + j;
    mods_cr_inv = s->ntt_mods_cr_inv + j;

    shift = dpl & (LIMB_BITS - 1);
    if (shift == 0)
        base_mask1 = -1;
    else
        base_mask1 = ((limb_t)1 << shift) - 1;
    /* index of the limb holding the top bits of a digit */
    n_limb1 = ((unsigned)dpl - 1) / LIMB_BITS;
    for(j = 0; j < NB_MODS; j++)
        carry[j] = 0;
    for(j = 0; j < NB_MODS; j++)
        u[j] = 0; /* avoid warnings */
    memset(tabr, 0, sizeof(limb_t) * r_len);
    fft_len = (limb_t)1 << fft_len_log2;
    len = bf_min(fft_len, (r_len * LIMB_BITS + dpl - 1) / dpl);
    for(i = 0; i < len; i++) {
        for(j = 0; j < nb_mods; j++)  {
            y[j] = ntt_limb_to_int(buf[i + fft_len * j], mods[j]);
        }

        /* Chinese remainder to get mixed radix representation */
        l = 0;
        for(j = 0; j < nb_mods - 1; j++) {
            for(k = j + 1; k < nb_mods; k++) {
                limb_t m;
                m = mods[k];
                /* Note: y[k] - y[j] + m cannot overflow because the moduli
                   are sorted in increasing order (so y[j] < mods[j] < m) */
                y[k] = mul_mod_fast2(y[k] - y[j] + m,
                                     mods_cr[l], m, mods_cr_inv[l]);
                l++;
            }
        }

        /* back to normal representation */
        u[0] = y[nb_mods - 1];
        l = 1;
        for(j = nb_mods - 2; j >= 1; j--) {
            r = y[j];
            for(k = 0; k < l; k++) {
                t = (dlimb_t)u[k] * mods[j] + r;
                r = t >> LIMB_BITS;
                u[k] = t;
            }
            u[l] = r;
            l++;
        }

        /* last step adds the carry (j == 0 here) */
        r = y[0];
        for(k = 0; k < l; k++) {
            t = (dlimb_t)u[k] * mods[j] + r + carry[k];
            r = t >> LIMB_BITS;
            u[k] = t;
        }
        u[l] = r + carry[l];

#if 0
        printf("%" PRId64 ": ", (int64_t)i);
        for(j = nb_mods - 1; j >= 0; j--) {
            printf(" " FMT_LIMB, u[j]);
        }
        printf("\n");
#endif

        /* write the digits */
        pos = i * dpl;
        for(j = 0; j < n_limb1; j++) {
            put_bits(tabr, r_len, pos, u[j]);
            pos += LIMB_BITS;
        }
        put_bits(tabr, r_len, pos, u[n_limb1] & base_mask1);
        /* shift by dpl digits and set the carry */
        if (shift == 0) {
            for(j = n_limb1 + 1; j < nb_mods; j++)
                carry[j - (n_limb1 + 1)] = u[j];
        } else {
            for(j = n_limb1; j < nb_mods - 1; j++) {
                carry[j - n_limb1] = (u[j] >> shift) |
                    (u[j + 1] << (LIMB_BITS - shift));
            }
            carry[nb_mods - 1 - n_limb1] = u[nb_mods - 1] >> shift;
        }
    }
}
#endif
8245 
/* Allocate and fill s1->ntt_state: modulus reciprocals, the 1/2^i mod m
   tables, the powers of the primitive roots, and the CRT constants.
   Does nothing if the state already exists.
   Return 0 on success, -1 on allocation failure. */
static int ntt_static_init(bf_context_t *s1)
{
    BFNTTState *s;
    int dir, i, j, k, pair;
    limb_t m, m_inv, half, inv_pow, root;

    if (s1->ntt_state != NULL)
        return 0;

#if defined(__AVX2__)
    s = bf_aligned_malloc(s1, sizeof(*s), 64);
#else
    s = bf_malloc(s1, sizeof(*s));
#endif
    if (s == NULL)
        return -1;
    memset(s, 0, sizeof(*s));
    s->ctx = s1;
    s1->ntt_state = s;

    for (j = 0; j < NB_MODS; j++) {
        m = ntt_mods[j];
        m_inv = init_mul_mod_fast(m);
        s->ntt_mods_div[j] = m_inv;
#if defined(__AVX2__)
        s->ntt_mods_vec[j] = _mm256_set1_pd(m);
        s->ntt_mods_inv_vec[j] = _mm256_set1_pd(1.0 / (double)m);
#endif
        /* ntt_len_inv[j][i] = (1/2)^i mod m and its reciprocal;
           (m + 1) / 2 is 1/2 mod m for odd m */
        half = (m + 1) / 2;
        inv_pow = 1;
        for (i = 0; i <= NTT_PROOT_2EXP; i++) {
            s->ntt_len_inv[j][i][0] = inv_pow;
            s->ntt_len_inv[j][i][1] = init_mul_mod_fast2(inv_pow, m);
            inv_pow = mul_mod_fast(inv_pow, half, m, m_inv);
        }

        /* entry i (from NTT_PROOT_2EXP down to 1) receives the root squared
           NTT_PROOT_2EXP - i times */
        for (dir = 0; dir < 2; dir++) {
            root = ntt_proot[dir][j];
            for (i = NTT_PROOT_2EXP; i > 0; i--) {
                s->ntt_proot_pow[j][dir][i] = root;
                s->ntt_proot_pow_inv[j][dir][i] = init_mul_mod_fast2(root, m);
                root = mul_mod_fast(root, root, m, m_inv);
            }
        }
    }

    /* CRT constants, one per modulus pair (j, k) with j < k */
    pair = 0;
    for (j = 0; j < NB_MODS - 1; j++) {
        for (k = j + 1; k < NB_MODS; k++, pair++) {
#if defined(__AVX2__)
            s->ntt_mods_cr_vec[pair] =
                _mm256_set1_pd(int_to_ntt_limb2(ntt_mods_cr[pair],
                                                ntt_mods[k]));
#else
            s->ntt_mods_cr_inv[pair] =
                init_mul_mod_fast2(ntt_mods_cr[pair], ntt_mods[k]);
#endif
        }
    }
    return 0;
}
8307 
/* Choose NTT parameters for a product of 'len' limbs.

   Each number of moduli n in [3, NB_MODS] is tried. Starting from the
   largest allowed digits-per-limb value d and counting down, it looks for
   the first d with fft_len_log2 + 2 * d <= ntt_int_bits[NB_MODS - n],
   where fft_len_log2 = ceil_log2(ceil(len * LIMB_BITS / d)). The candidate
   with the smallest estimated cost (fft_len_log2 + 1) * 2^fft_len_log2 * n
   is kept.

   On return *pdpl and *pnb_mods hold the chosen d and n, and the
   function returns fft_len_log2. Calls abort() if no candidate
   satisfies the constraints. */
int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
{
    const limb_t nbits = len * LIMB_BITS;
    limb_t best_cost = (limb_t)-1;
    int best_dpl = 0;
    int best_nb_mods = 4;
    int best_log2 = 0;
    int n;

    for (n = 3; n <= NB_MODS; n++) {
        const int cap = ntt_int_bits[NB_MODS - n];
        int d = bf_min((cap - 4) / 2,
                       2 * LIMB_BITS + 2 * NTT_MOD_LOG2_MIN - NTT_MOD_LOG2_MAX);

        /* A smaller d never shortens the transform, so stop at the first
           d that fits, or as soon as the transform becomes too long. */
        do {
            const int lg = ceil_log2((nbits + d - 1) / d);

            if (lg > NTT_PROOT_2EXP)
                break;
            if (lg + 2 * d <= cap) {
                const limb_t cost = ((limb_t)(lg + 1) << lg) * n;
                if (cost < best_cost) {
                    best_cost = cost;
                    best_dpl = d;
                    best_nb_mods = n;
                    best_log2 = lg;
                }
                break;
            }
        } while (--d != 0);
    }
    if (best_dpl == 0)
        abort();

    /* limit dpl if possible to reduce fixed cost of limb/NTT conversion */
    if (best_dpl > (LIMB_BITS + NTT_MOD_LOG2_MIN) &&
        ((limb_t)(LIMB_BITS + NTT_MOD_LOG2_MIN) << best_log2) >= nbits)
        best_dpl = LIMB_BITS + NTT_MOD_LOG2_MIN;

    *pnb_mods = best_nb_mods;
    *pdpl = best_dpl;
    return best_log2;
}
8356 
/* Multiply the a_len-limb number a_tab by the b_len-limb number b_tab
   with NTT-based convolution and store the a_len + b_len limb product
   in res->tab.

   mul_flags:
     FFT_MUL_R_OVERLAP_A / FFT_MUL_R_OVERLAP_B - res may share storage
       with a_tab / b_tab. The bf_resize(res, 0) calls below are placed
       so that res is only shrunk after every operand it may overlap has
       been completely read. The exact statement order matters here.
     FFT_MUL_R_NORESIZE - res is never resized. NOTE(review): the caller
       presumably guarantees that res->tab already has room for
       a_len + b_len limbs - confirm.

   On failure res may already have been resized to 0.

   return 0 if OK, -1 if memory error */
static no_inline int fft_mul(bf_context_t *s1,
                             bf_t *res, limb_t *a_tab, limb_t a_len,
                             limb_t *b_tab, limb_t b_len, int mul_flags)
{
    BFNTTState *s;
    int dpl, fft_len_log2, j, nb_mods, reduced_mem;
    slimb_t len, fft_len;
    NTTLimb *buf1, *buf2, *ptr;
#if defined(USE_MUL_CHECK)
    limb_t ha, hb, hr, h_ref;
#endif

    /* build the NTT tables attached to s1 on first use */
    if (ntt_static_init(s1))
        return -1;
    s = s1->ntt_state;

    /* find the optimal number of digits per limb (dpl) */
    len = a_len + b_len;
    fft_len_log2 = bf_get_fft_size(&dpl, &nb_mods, len);
    fft_len = (uint64_t)1 << fft_len_log2;
    //    printf("len=%" PRId64 " fft_len_log2=%d dpl=%d\n", len, fft_len_log2, dpl);
#if defined(USE_MUL_CHECK)
    /* operand checksums, compared with the product's checksum at the end */
    ha = mp_mod1(a_tab, a_len, BF_CHKSUM_MOD, 0);
    hb = mp_mod1(b_tab, b_len, BF_CHKSUM_MOD, 0);
#endif
    if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) == 0) {
        /* res aliases neither operand: it can be shrunk right away,
           presumably to lower peak memory before the large allocations */
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0);
    } else if (mul_flags & FFT_MUL_R_OVERLAP_B) {
        limb_t *tmp_tab, tmp_len;
        /* it is better to free 'b' first */
        /* swap the operands so that the one res overlaps (b) becomes
           a_tab and is the first converted into buf1 */
        tmp_tab = a_tab;
        a_tab = b_tab;
        b_tab = tmp_tab;
        tmp_len = a_len;
        a_len = b_len;
        b_len = tmp_len;
    }
    /* transform the first operand for all nb_mods moduli at once. The
       last two limb_to_ntt() arguments appear to be (first modulus index,
       number of moduli), so the moduli used are NB_MODS - nb_mods ..
       NB_MODS - 1. */
    buf1 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
    if (!buf1)
        return -1;
    limb_to_ntt(s, buf1, fft_len, a_tab, a_len, dpl,
                NB_MODS - nb_mods, nb_mods);
    if ((mul_flags & (FFT_MUL_R_OVERLAP_A | FFT_MUL_R_OVERLAP_B)) ==
        FFT_MUL_R_OVERLAP_A) {
        /* res only overlapped a_tab, which has now been consumed */
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0);
    }
    /* for large transforms, keep only one modulus worth of the second
       operand in buf2 and re-convert b_tab for each modulus */
    reduced_mem = (fft_len_log2 >= 14);
    if (!reduced_mem) {
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len * nb_mods);
        if (!buf2)
            goto fail;
        limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
                    NB_MODS - nb_mods, nb_mods);
        if (!(mul_flags & FFT_MUL_R_NORESIZE))
            bf_resize(res, 0); /* in case res == b */
    } else {
        buf2 = ntt_malloc(s, sizeof(NTTLimb) * fft_len);
        if (!buf2)
            goto fail;
    }
    /* one convolution per modulus. buf2 is freed before ntt_to_limb()
       reads buf1, so the per-modulus results are left in buf1 */
    for(j = 0; j < nb_mods; j++) {
        if (reduced_mem) {
            limb_to_ntt(s, buf2, fft_len, b_tab, b_len, dpl,
                        NB_MODS - nb_mods + j, 1);
            ptr = buf2;
        } else {
            ptr = buf2 + fft_len * j;
        }
        if (ntt_conv(s, buf1 + fft_len * j, ptr,
                     fft_len_log2, fft_len_log2, j + NB_MODS - nb_mods))
            goto fail;
    }
    if (!(mul_flags & FFT_MUL_R_NORESIZE))
        bf_resize(res, 0); /* in case res == b and reduced mem */
    ntt_free(s, buf2);
    buf2 = NULL; /* keeps the fail path from freeing it a second time */
    if (!(mul_flags & FFT_MUL_R_NORESIZE)) {
        if (bf_resize(res, len))
            goto fail;
    }
    /* convert the nb_mods per-modulus results in buf1 back into len limbs */
    ntt_to_limb(s, res->tab, len, buf1, fft_len_log2, dpl, nb_mods);
    ntt_free(s, buf1);
#if defined(USE_MUL_CHECK)
    /* terminate the process if the product checksum does not match */
    hr = mp_mod1(res->tab, len, BF_CHKSUM_MOD, 0);
    h_ref = mul_mod(ha, hb, BF_CHKSUM_MOD);
    if (hr != h_ref) {
        printf("ntt_mul_error: len=%" PRId_LIMB " fft_len_log2=%d dpl=%d nb_mods=%d\n",
               len, fft_len_log2, dpl, nb_mods);
        //        printf("ha=0x" FMT_LIMB" hb=0x" FMT_LIMB " hr=0x" FMT_LIMB " expected=0x" FMT_LIMB "\n", ha, hb, hr, h_ref);
        exit(1);
    }
#endif
    return 0;
 fail:
    /* buf2 may be NULL here (failed allocation, or already freed above),
       so this path relies on ntt_free() accepting NULL */
    ntt_free(s, buf1);
    ntt_free(s, buf2);
    return -1;
}
8458 
8459 #else /* USE_FFT_MUL */
8460 
/* Stub used when FFT multiplication is compiled out (USE_FFT_MUL not
   defined). It reports a zero transform size. Both out-parameters are
   set to 0 so that callers never read indeterminate values; the
   previous stub left *pdpl and *pnb_mods unwritten. Always returns 0. */
int bf_get_fft_size(int *pdpl, int *pnb_mods, limb_t len)
{
    (void)len; /* unused without FFT support */
    *pdpl = 0;
    *pnb_mods = 0;
    return 0;
}
8465 
8466 #endif /* !USE_FFT_MUL */
8467