1 /*
2 Name: imath.c
3 Purpose: Arbitrary precision integer arithmetic routines.
4 Author: M. J. Fromberger <http://spinning-yarns.org/michael/>
5
6 Copyright (C) 2002-2007 Michael J. Fromberger, All Rights Reserved.
7
8 Permission is hereby granted, free of charge, to any person obtaining a copy
9 of this software and associated documentation files (the "Software"), to deal
10 in the Software without restriction, including without limitation the rights
11 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 copies of the Software, and to permit persons to whom the Software is
13 furnished to do so, subject to the following conditions:
14
15 The above copyright notice and this permission notice shall be included in
16 all copies or substantial portions of the Software.
17
18 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 SOFTWARE.
25 */
26
27 #include "imath.h"
28
29 #if DEBUG
30 #include <stdio.h>
31 #endif
32
33 #include <stdlib.h>
34 #include <string.h>
35 #include <ctype.h>
36
37 #include <assert.h>
38
/* Linkage for the library's internal helpers.  Ordinary builds keep them
   file-local; DEBUG builds drop the `static` so that the helpers become
   externally visible. */
#if !DEBUG
#define STATIC static
#else
#define STATIC /* public */
#endif
44
45 const mp_result MP_OK = 0; /* no error, all is well */
46 const mp_result MP_FALSE = 0; /* boolean false */
47 const mp_result MP_TRUE = -1; /* boolean true */
48 const mp_result MP_MEMORY = -2; /* out of memory */
49 const mp_result MP_RANGE = -3; /* argument out of range */
50 const mp_result MP_UNDEF = -4; /* result undefined */
51 const mp_result MP_TRUNC = -5; /* output truncated */
52 const mp_result MP_BADARG = -6; /* invalid null argument */
53 const mp_result MP_MINERR = -6;
54
55 const mp_sign MP_NEG = 1; /* value is strictly negative */
56 const mp_sign MP_ZPOS = 0; /* value is non-negative */
57
58 STATIC const char *s_unknown_err = "unknown result code";
59 STATIC const char *s_error_msg[] = {
60 "error code 0",
61 "boolean true",
62 "out of memory",
63 "argument out of range",
64 "result undefined",
65 "output truncated",
66 "invalid argument",
67 NULL
68 };
69
/* Argument-validation macros.  Both currently expand to assert(), so they
   disappear entirely when NDEBUG is defined.  The convention is CHECK() in
   functions that return a value and NRCHECK() in functions that do not. */
#define CHECK(TEST)   assert(TEST)
#define NRCHECK(TEST) assert(TEST)
74
75 /* The ith entry of this table gives the value of log_i(2).
76
77 An integer value n requires ceil(log_i(n)) digits to be represented
78 in base i. Since it is easy to compute lg(n), by counting bits, we
79 can compute log_i(n) = lg(n) * log_i(2).
80
81 The use of this table eliminates a dependency upon linkage against
82 the standard math libraries.
83
84 If MP_MAX_RADIX is increased, this table should be expanded too.
85 */
86 STATIC const double s_log2[] = {
87 0.000000000, 0.000000000, 1.000000000, 0.630929754, /* (D)(D) 2 3 */
88 0.500000000, 0.430676558, 0.386852807, 0.356207187, /* 4 5 6 7 */
89 0.333333333, 0.315464877, 0.301029996, 0.289064826, /* 8 9 10 11 */
90 0.278942946, 0.270238154, 0.262649535, 0.255958025, /* 12 13 14 15 */
91 0.250000000, 0.244650542, 0.239812467, 0.235408913, /* 16 17 18 19 */
92 0.231378213, 0.227670249, 0.224243824, 0.221064729, /* 20 21 22 23 */
93 0.218104292, 0.215338279, 0.212746054, 0.210309918, /* 24 25 26 27 */
94 0.208014598, 0.205846832, 0.203795047, 0.201849087, /* 28 29 30 31 */
95 0.200000000, 0.198239863, 0.196561632, 0.194959022, /* 32 33 34 35 */
96 0.193426404, /* 36 */
97 };
98
99
100
/* How many mp_digit words it takes to hold the bytes of V, rounded up. */
#define MP_VALUE_DIGITS(V) \
  ((sizeof(V) + (sizeof(mp_digit) - 1)) / sizeof(mp_digit))

/* Round a digit count P up to the next even number. */
#define ROUND_PREC(P) ((mp_size) (2 * (((P) + 1) / 2)))

/* Clear the first S digits of the array P. */
#define ZERO(P, S) \
  do { \
    mp_size zlen_ = (S) * sizeof(mp_digit); \
    mp_digit *zdst_ = (P); \
    memset(zdst_, 0, zlen_); \
  } while (0)

/* Copy S digits from array P into array Q (the arrays must not overlap). */
#define COPY(P, Q, S) \
  do { \
    mp_size clen_ = (S) * sizeof(mp_digit); \
    mp_digit *csrc_ = (P); \
    mp_digit *cdst_ = (Q); \
    memcpy(cdst_, csrc_, clen_); \
  } while (0)
123
/* Reverse, in place, the first N elements of array A, whose elements have
   type T. */
#define REV(T, A, N) \
  do { \
    T *lo_ = (A), *hi_ = lo_ + (N) - 1; \
    for (; lo_ < hi_; ++lo_, --hi_) { \
      T tmp_ = *lo_; \
      *lo_ = *hi_; \
      *hi_ = tmp_; \
    } \
  } while (0)

/* Drop high-order zero digits from Z by shrinking its used count.  At least
   one digit is always kept. */
#define CLAMP(Z) \
  do { \
    mp_int cz_ = (Z); \
    mp_size cn_ = MP_USED(cz_); \
    while (cn_ > 1 && MP_DIGITS(cz_)[cn_ - 1] == 0) \
      --cn_; \
    MP_USED(cz_) = cn_; \
  } while (0)

/* Smaller / larger of A and B.  Each argument may be evaluated twice, so
   avoid arguments with side effects (e.g. x++). */
#define MIN(A, B) (((A) <= (B)) ? (A) : (B))
#define MAX(A, B) (((A) >= (B)) ? (A) : (B))

/* Exchange the contents of lvalues A and B, both of type T,
   e.g. SWAP(int, x, y). */
#define SWAP(T, A, B) \
  do { \
    T swap_tmp_ = (A); \
    (A) = (B); \
    (B) = swap_tmp_; \
  } while (0)
158
/* A small stack of temporary mp_ints local to one function.

   DECLARE_TEMP(N)  declares room for N temporaries, plus a count of how
                    many have been initialized so far.
   SETUP(E)         evaluates E, which must yield an mp_result, into the
                    caller's `res`.  On failure it jumps to CLEANUP; on
                    success it bumps the count.
   CLEANUP_TEMP()   defines the CLEANUP label, then clears every initialized
                    temporary, newest first.
   TEMP(K)          is the K'th temporary.
   LAST_TEMP()      is the next temporary not yet initialized. */
#define DECLARE_TEMP(N) \
  mpz_t temp[(N)]; \
  int last__ = 0
#define CLEANUP_TEMP() \
  CLEANUP: \
  for (--last__; last__ >= 0; --last__) \
    mp_int_clear(TEMP(last__))
#define TEMP(K) (&temp[(K)])
#define LAST_TEMP() TEMP(last__)
#define SETUP(E) \
  do { \
    res = (E); \
    if (res != MP_OK) \
      goto CLEANUP; \
    last__ += 1; \
  } while (0)
175
/* Sign of Z: 0 if Z is zero, -1 if negative, +1 if positive.  Z is
   evaluated more than once. */
#define CMPZ(Z) \
  ((((Z)->used == 1) && ((Z)->digits[0] == 0)) ? 0 : \
   (((Z)->sign == MP_NEG) ? -1 : 1))

/* Z = |X| * |Y|.  Z's sign is left unchanged.  Z must already have room for
   MP_USED(X) + MP_USED(Y) digits.  Z must not alias X or Y: Z's digits are
   cleared before X and Y are read.
   NOTE(review): s_kmul's status is discarded, so an allocation failure
   inside it goes unreported. */
#define UMUL(X, Y, Z) \
  do { \
    mp_size nx_ = MP_USED(X); \
    mp_size ny_ = MP_USED(Y); \
    mp_size nz_ = nx_ + ny_; \
    ZERO(MP_DIGITS(Z), nz_); \
    (void) s_kmul(MP_DIGITS(X), MP_DIGITS(Y), MP_DIGITS(Z), nx_, ny_); \
    MP_USED(Z) = nz_; \
    CLAMP(Z); \
  } while (0)

/* Z = X * X, ignoring sign.  Z must already have room for 2 * MP_USED(X)
   digits and must not alias X.
   NOTE(review): s_ksqr's status is discarded as well. */
#define USQR(X, Z) \
  do { \
    mp_size nx_ = MP_USED(X); \
    mp_size nz_ = 2 * nx_; \
    ZERO(MP_DIGITS(Z), nz_); \
    (void) s_ksqr(MP_DIGITS(X), MP_DIGITS(Z), nx_); \
    MP_USED(Z) = nz_; \
    CLAMP(Z); \
  } while (0)

/* High and low mp_digit halves of a double-width mp_word. */
#define UPPER_HALF(W)           ((mp_word) ((W) >> MP_DIGIT_BIT))
#define LOWER_HALF(W)           ((mp_digit) (W))
/* Nonzero when the most significant bit of the mp_word W is set. */
#define HIGH_BIT_SET(W)         ((W) >> (MP_WORD_BIT - 1))
/* Nonzero when W + V would exceed MP_WORD_MAX. */
#define ADD_WILL_OVERFLOW(W, V) ((MP_WORD_MAX - (V)) < (W))
207
208
209
210 /* Default number of digits allocated to a new mp_int */
211 #if IMATH_TEST
212 mp_size default_precision = MP_DEFAULT_PREC;
213 #else
214 STATIC const mp_size default_precision = MP_DEFAULT_PREC;
215 #endif
216
217 /* Minimum number of digits to invoke recursive multiply */
218 #if IMATH_TEST
219 mp_size multiply_threshold = MP_MULT_THRESH;
220 #else
221 STATIC const mp_size multiply_threshold = MP_MULT_THRESH;
222 #endif
223
224 /* Allocate a buffer of (at least) num digits, or return
225 NULL if that couldn't be done. */
226 STATIC mp_digit *s_alloc(mp_size num);
227
228 /* Release a buffer of digits allocated by s_alloc(). */
229 STATIC void s_free(void *ptr);
230
231 /* Insure that z has at least min digits allocated, resizing if
232 necessary. Returns true if successful, false if out of memory. */
233 STATIC int s_pad(mp_int z, mp_size min);
234
235 /* Fill in a "fake" mp_int on the stack with a given value */
236 STATIC void s_fake(mp_int z, mp_small value, mp_digit vbuf[]);
237 STATIC void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[]);
238
239 /* Compare two runs of digits of given length, returns <0, 0, >0 */
240 STATIC int s_cdig(mp_digit *da, mp_digit *db, mp_size len);
241
242 /* Pack the unsigned digits of v into array t */
243 STATIC int s_uvpack(mp_usmall v, mp_digit t[]);
244
245 /* Compare magnitudes of a and b, returns <0, 0, >0 */
246 STATIC int s_ucmp(mp_int a, mp_int b);
247
248 /* Compare magnitudes of a and v, returns <0, 0, >0 */
249 STATIC int s_vcmp(mp_int a, mp_small v);
250 STATIC int s_uvcmp(mp_int a, mp_usmall uv);
251
252 /* Unsigned magnitude addition; assumes dc is big enough.
253 Carry out is returned (no memory allocated). */
254 STATIC mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc,
255 mp_size size_a, mp_size size_b);
256
257 /* Unsigned magnitude subtraction. Assumes dc is big enough. */
258 STATIC void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc,
259 mp_size size_a, mp_size size_b);
260
261 /* Unsigned recursive multiplication. Assumes dc is big enough. */
262 STATIC int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc,
263 mp_size size_a, mp_size size_b);
264
265 /* Unsigned magnitude multiplication. Assumes dc is big enough. */
266 STATIC void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc,
267 mp_size size_a, mp_size size_b);
268
269 /* Unsigned recursive squaring. Assumes dc is big enough. */
270 STATIC int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a);
271
272 /* Unsigned magnitude squaring. Assumes dc is big enough. */
273 STATIC void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a);
274
275 /* Single digit addition. Assumes a is big enough. */
276 STATIC void s_dadd(mp_int a, mp_digit b);
277
278 /* Single digit multiplication. Assumes a is big enough. */
279 STATIC void s_dmul(mp_int a, mp_digit b);
280
281 /* Single digit multiplication on buffers; assumes dc is big enough. */
282 STATIC void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc,
283 mp_size size_a);
284
285 /* Single digit division. Replaces a with the quotient,
286 returns the remainder. */
287 STATIC mp_digit s_ddiv(mp_int a, mp_digit b);
288
289 /* Quick division by a power of 2, replaces z (no allocation) */
290 STATIC void s_qdiv(mp_int z, mp_size p2);
291
292 /* Quick remainder by a power of 2, replaces z (no allocation) */
293 STATIC void s_qmod(mp_int z, mp_size p2);
294
295 /* Quick multiplication by a power of 2, replaces z.
296 Allocates if necessary; returns false in case this fails. */
297 STATIC int s_qmul(mp_int z, mp_size p2);
298
299 /* Quick subtraction from a power of 2, replaces z.
300 Allocates if necessary; returns false in case this fails. */
301 STATIC int s_qsub(mp_int z, mp_size p2);
302
303 /* Return maximum k such that 2^k divides z. */
304 STATIC int s_dp2k(mp_int z);
305
306 /* Return k >= 0 such that z = 2^k, or -1 if there is no such k. */
307 STATIC int s_isp2(mp_int z);
308
309 /* Set z to 2^k. May allocate; returns false in case this fails. */
310 STATIC int s_2expt(mp_int z, mp_small k);
311
312 /* Normalize a and b for division, returns normalization constant */
313 STATIC int s_norm(mp_int a, mp_int b);
314
315 /* Compute constant mu for Barrett reduction, given modulus m, result
316 replaces z, m is untouched. */
317 STATIC mp_result s_brmu(mp_int z, mp_int m);
318
319 /* Reduce a modulo m, using Barrett's algorithm. */
320 STATIC int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2);
321
322 /* Modular exponentiation, using Barrett reduction */
323 STATIC mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c);
324
325 /* Unsigned magnitude division. Assumes |a| > |b|. Allocates temporaries;
326 overwrites a with quotient, b with remainder. */
327 STATIC mp_result s_udiv_knuth(mp_int a, mp_int b);
328
329 /* Compute the number of digits in radix r required to represent the given
330 value. Does not account for sign flags, terminators, etc. */
331 STATIC int s_outlen(mp_int z, mp_size r);
332
333 /* Guess how many digits of precision will be needed to represent a radix r
334 value of the specified number of digits. Returns a value guaranteed to be
335 no smaller than the actual number required. */
336 STATIC mp_size s_inlen(int len, mp_size r);
337
338 /* Convert a character to a digit value in radix r, or
339 -1 if out of range */
340 STATIC int s_ch2val(char c, int r);
341
342 /* Convert a digit value to a character */
343 STATIC char s_val2ch(int v, int caps);
344
345 /* Take 2's complement of a buffer in place */
346 STATIC void s_2comp(unsigned char *buf, int len);
347
348 /* Convert a value to binary, ignoring sign. On input, *limpos is the bound on
349 how many bytes should be written to buf; on output, *limpos is set to the
350 number of bytes actually written. */
351 STATIC mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad);
352
#if DEBUG
/* Debugging aids: print a tagged representation of an mp_int, or of a raw
   buffer of num digits, to standard output. */
void s_print(char *tag, mp_int z);
void s_print_buf(char *tag, mp_digit *buf, mp_size num);
#endif
358
mp_int_init(mp_int z)359 mp_result mp_int_init(mp_int z)
360 {
361 if (z == NULL)
362 return MP_BADARG;
363
364 z->single = 0;
365 z->digits = &(z->single);
366 z->alloc = 1;
367 z->used = 1;
368 z->sign = MP_ZPOS;
369
370 return MP_OK;
371 }
372
mp_int_alloc(void)373 mp_int mp_int_alloc(void)
374 {
375 mp_int out = malloc(sizeof(mpz_t));
376
377 if (out != NULL)
378 mp_int_init(out);
379
380 return out;
381 }
382
mp_int_init_size(mp_int z,mp_size prec)383 mp_result mp_int_init_size(mp_int z, mp_size prec)
384 {
385 CHECK(z != NULL);
386
387 if (prec == 0)
388 prec = default_precision;
389 else if (prec == 1)
390 return mp_int_init(z);
391 else
392 prec = (mp_size) ROUND_PREC(prec);
393
394 if ((MP_DIGITS(z) = s_alloc(prec)) == NULL)
395 return MP_MEMORY;
396
397 z->digits[0] = 0;
398 MP_USED(z) = 1;
399 MP_ALLOC(z) = prec;
400 MP_SIGN(z) = MP_ZPOS;
401
402 return MP_OK;
403 }
404
mp_int_init_copy(mp_int z,mp_int old)405 mp_result mp_int_init_copy(mp_int z, mp_int old)
406 {
407 mp_result res;
408 mp_size uold;
409
410 CHECK(z != NULL && old != NULL);
411
412 uold = MP_USED(old);
413 if (uold == 1) {
414 mp_int_init(z);
415 }
416 else {
417 mp_size target = MAX(uold, default_precision);
418
419 if ((res = mp_int_init_size(z, target)) != MP_OK)
420 return res;
421 }
422
423 MP_USED(z) = uold;
424 MP_SIGN(z) = MP_SIGN(old);
425 COPY(MP_DIGITS(old), MP_DIGITS(z), uold);
426
427 return MP_OK;
428 }
429
mp_int_init_value(mp_int z,mp_small value)430 mp_result mp_int_init_value(mp_int z, mp_small value)
431 {
432 mpz_t vtmp;
433 mp_digit vbuf[MP_VALUE_DIGITS(value)];
434
435 s_fake(&vtmp, value, vbuf);
436 return mp_int_init_copy(z, &vtmp);
437 }
438
mp_int_init_uvalue(mp_int z,mp_usmall uvalue)439 mp_result mp_int_init_uvalue(mp_int z, mp_usmall uvalue)
440 {
441 mpz_t vtmp;
442 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
443
444 s_ufake(&vtmp, uvalue, vbuf);
445 return mp_int_init_copy(z, &vtmp);
446 }
447
mp_int_set_value(mp_int z,mp_small value)448 mp_result mp_int_set_value(mp_int z, mp_small value)
449 {
450 mpz_t vtmp;
451 mp_digit vbuf[MP_VALUE_DIGITS(value)];
452
453 s_fake(&vtmp, value, vbuf);
454 return mp_int_copy(&vtmp, z);
455 }
456
mp_int_set_uvalue(mp_int z,mp_usmall uvalue)457 mp_result mp_int_set_uvalue(mp_int z, mp_usmall uvalue)
458 {
459 mpz_t vtmp;
460 mp_digit vbuf[MP_VALUE_DIGITS(uvalue)];
461
462 s_ufake(&vtmp, uvalue, vbuf);
463 return mp_int_copy(&vtmp, z);
464 }
465
/* Release any heap storage held by z and set its digits pointer to NULL.
   Calling this on NULL, or on an already-cleared value, does nothing.  The
   built-in single digit is never passed to s_free(). */
void mp_int_clear(mp_int z)
{
  if (z == NULL || MP_DIGITS(z) == NULL)
    return;

  if (MP_DIGITS(z) != &(z->single))
    s_free(MP_DIGITS(z));
  MP_DIGITS(z) = NULL;
}
478
/* Clear z, then release the mp_int struct itself.  The struct is released
   with free() rather than s_free(), so it should have come from malloc(),
   for example via mp_int_alloc(). */
void mp_int_free(mp_int z)
{
  NRCHECK(z != NULL);

  mp_int_clear(z);
  free(z);
}
486
mp_int_copy(mp_int a,mp_int c)487 mp_result mp_int_copy(mp_int a, mp_int c)
488 {
489 CHECK(a != NULL && c != NULL);
490
491 if (a != c) {
492 mp_size ua = MP_USED(a);
493 mp_digit *da, *dc;
494
495 if (!s_pad(c, ua))
496 return MP_MEMORY;
497
498 da = MP_DIGITS(a); dc = MP_DIGITS(c);
499 COPY(da, dc, ua);
500
501 MP_USED(c) = ua;
502 MP_SIGN(c) = MP_SIGN(a);
503 }
504
505 return MP_OK;
506 }
507
/* Exchange the contents of a and c without copying digit arrays.  A value
   held in the built-in single digit travels with the struct copy, so its
   digits pointer must be redirected to the new owner's single field. */
void mp_int_swap(mp_int a, mp_int c)
{
  mpz_t saved;

  if (a == c)
    return;

  saved = *a;
  *a = *c;
  *c = saved;

  if (MP_DIGITS(a) == &(c->single))
    MP_DIGITS(a) = &(a->single);
  if (MP_DIGITS(c) == &(a->single))
    MP_DIGITS(c) = &(c->single);
}
522
/* Set z to zero.  Any storage z already owns is kept. */
void mp_int_zero(mp_int z)
{
  NRCHECK(z != NULL);

  MP_SIGN(z) = MP_ZPOS;
  MP_USED(z) = 1;
  z->digits[0] = 0;
}
531
mp_int_abs(mp_int a,mp_int c)532 mp_result mp_int_abs(mp_int a, mp_int c)
533 {
534 mp_result res;
535
536 CHECK(a != NULL && c != NULL);
537
538 if ((res = mp_int_copy(a, c)) != MP_OK)
539 return res;
540
541 MP_SIGN(c) = MP_ZPOS;
542 return MP_OK;
543 }
544
mp_int_neg(mp_int a,mp_int c)545 mp_result mp_int_neg(mp_int a, mp_int c)
546 {
547 mp_result res;
548
549 CHECK(a != NULL && c != NULL);
550
551 if ((res = mp_int_copy(a, c)) != MP_OK)
552 return res;
553
554 if (CMPZ(c) != 0)
555 MP_SIGN(c) = 1 - MP_SIGN(a);
556
557 return MP_OK;
558 }
559
mp_int_add(mp_int a,mp_int b,mp_int c)560 mp_result mp_int_add(mp_int a, mp_int b, mp_int c)
561 {
562 mp_size ua, ub, uc, max;
563
564 CHECK(a != NULL && b != NULL && c != NULL);
565
566 ua = MP_USED(a); ub = MP_USED(b); uc = MP_USED(c);
567 max = MAX(ua, ub);
568
569 if (MP_SIGN(a) == MP_SIGN(b)) {
570 /* Same sign -- add magnitudes, preserve sign of addends */
571 mp_digit carry;
572
573 if (!s_pad(c, max))
574 return MP_MEMORY;
575
576 carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
577 uc = max;
578
579 if (carry) {
580 if (!s_pad(c, max + 1))
581 return MP_MEMORY;
582
583 c->digits[max] = carry;
584 ++uc;
585 }
586
587 MP_USED(c) = uc;
588 MP_SIGN(c) = MP_SIGN(a);
589
590 }
591 else {
592 /* Different signs -- subtract magnitudes, preserve sign of greater */
593 mp_int x, y;
594 int cmp = s_ucmp(a, b); /* magnitude comparision, sign ignored */
595
596 /* Set x to max(a, b), y to min(a, b) to simplify later code.
597 A special case yields zero for equal magnitudes.
598 */
599 if (cmp == 0) {
600 mp_int_zero(c);
601 return MP_OK;
602 }
603 else if (cmp < 0) {
604 x = b; y = a;
605 }
606 else {
607 x = a; y = b;
608 }
609
610 if (!s_pad(c, MP_USED(x)))
611 return MP_MEMORY;
612
613 /* Subtract smaller from larger */
614 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
615 MP_USED(c) = MP_USED(x);
616 CLAMP(c);
617
618 /* Give result the sign of the larger */
619 MP_SIGN(c) = MP_SIGN(x);
620 }
621
622 return MP_OK;
623 }
624
mp_int_add_value(mp_int a,mp_small value,mp_int c)625 mp_result mp_int_add_value(mp_int a, mp_small value, mp_int c)
626 {
627 mpz_t vtmp;
628 mp_digit vbuf[MP_VALUE_DIGITS(value)];
629
630 s_fake(&vtmp, value, vbuf);
631
632 return mp_int_add(a, &vtmp, c);
633 }
634
mp_int_sub(mp_int a,mp_int b,mp_int c)635 mp_result mp_int_sub(mp_int a, mp_int b, mp_int c)
636 {
637 mp_size ua, ub, uc, max;
638
639 CHECK(a != NULL && b != NULL && c != NULL);
640
641 ua = MP_USED(a); ub = MP_USED(b); uc = MP_USED(c);
642 max = MAX(ua, ub);
643
644 if (MP_SIGN(a) != MP_SIGN(b)) {
645 /* Different signs -- add magnitudes and keep sign of a */
646 mp_digit carry;
647
648 if (!s_pad(c, max))
649 return MP_MEMORY;
650
651 carry = s_uadd(MP_DIGITS(a), MP_DIGITS(b), MP_DIGITS(c), ua, ub);
652 uc = max;
653
654 if (carry) {
655 if (!s_pad(c, max + 1))
656 return MP_MEMORY;
657
658 c->digits[max] = carry;
659 ++uc;
660 }
661
662 MP_USED(c) = uc;
663 MP_SIGN(c) = MP_SIGN(a);
664
665 }
666 else {
667 /* Same signs -- subtract magnitudes */
668 mp_int x, y;
669 mp_sign osign;
670 int cmp = s_ucmp(a, b);
671
672 if (!s_pad(c, max))
673 return MP_MEMORY;
674
675 if (cmp >= 0) {
676 x = a; y = b; osign = MP_ZPOS;
677 }
678 else {
679 x = b; y = a; osign = MP_NEG;
680 }
681
682 if (MP_SIGN(a) == MP_NEG && cmp != 0)
683 osign = 1 - osign;
684
685 s_usub(MP_DIGITS(x), MP_DIGITS(y), MP_DIGITS(c), MP_USED(x), MP_USED(y));
686 MP_USED(c) = MP_USED(x);
687 CLAMP(c);
688
689 MP_SIGN(c) = osign;
690 }
691
692 return MP_OK;
693 }
694
mp_int_sub_value(mp_int a,mp_small value,mp_int c)695 mp_result mp_int_sub_value(mp_int a, mp_small value, mp_int c)
696 {
697 mpz_t vtmp;
698 mp_digit vbuf[MP_VALUE_DIGITS(value)];
699
700 s_fake(&vtmp, value, vbuf);
701
702 return mp_int_sub(a, &vtmp, c);
703 }
704
mp_int_mul(mp_int a,mp_int b,mp_int c)705 mp_result mp_int_mul(mp_int a, mp_int b, mp_int c)
706 {
707 mp_digit *out;
708 mp_size osize, ua, ub, p = 0;
709 mp_sign osign;
710
711 CHECK(a != NULL && b != NULL && c != NULL);
712
713 /* If either input is zero, we can shortcut multiplication */
714 if (mp_int_compare_zero(a) == 0 || mp_int_compare_zero(b) == 0) {
715 mp_int_zero(c);
716 return MP_OK;
717 }
718
719 /* Output is positive if inputs have same sign, otherwise negative */
720 osign = (MP_SIGN(a) == MP_SIGN(b)) ? MP_ZPOS : MP_NEG;
721
722 /* If the output is not identical to any of the inputs, we'll write the
723 results directly; otherwise, allocate a temporary space. */
724 ua = MP_USED(a); ub = MP_USED(b);
725 osize = MAX(ua, ub);
726 osize = 4 * ((osize + 1) / 2);
727
728 if (c == a || c == b) {
729 p = ROUND_PREC(osize);
730 p = MAX(p, default_precision);
731
732 if ((out = s_alloc(p)) == NULL)
733 return MP_MEMORY;
734 }
735 else {
736 if (!s_pad(c, osize))
737 return MP_MEMORY;
738
739 out = MP_DIGITS(c);
740 }
741 ZERO(out, osize);
742
743 if (!s_kmul(MP_DIGITS(a), MP_DIGITS(b), out, ua, ub))
744 return MP_MEMORY;
745
746 /* If we allocated a new buffer, get rid of whatever memory c was already
747 using, and fix up its fields to reflect that.
748 */
749 if (out != MP_DIGITS(c)) {
750 if ((void *) MP_DIGITS(c) != (void *) c)
751 s_free(MP_DIGITS(c));
752 MP_DIGITS(c) = out;
753 MP_ALLOC(c) = p;
754 }
755
756 MP_USED(c) = osize; /* might not be true, but we'll fix it ... */
757 CLAMP(c); /* ... right here */
758 MP_SIGN(c) = osign;
759
760 return MP_OK;
761 }
762
mp_int_mul_value(mp_int a,mp_small value,mp_int c)763 mp_result mp_int_mul_value(mp_int a, mp_small value, mp_int c)
764 {
765 mpz_t vtmp;
766 mp_digit vbuf[MP_VALUE_DIGITS(value)];
767
768 s_fake(&vtmp, value, vbuf);
769
770 return mp_int_mul(a, &vtmp, c);
771 }
772
mp_int_mul_pow2(mp_int a,mp_small p2,mp_int c)773 mp_result mp_int_mul_pow2(mp_int a, mp_small p2, mp_int c)
774 {
775 mp_result res;
776 CHECK(a != NULL && c != NULL && p2 >= 0);
777
778 if ((res = mp_int_copy(a, c)) != MP_OK)
779 return res;
780
781 if (s_qmul(c, (mp_size) p2))
782 return MP_OK;
783 else
784 return MP_MEMORY;
785 }
786
mp_int_sqr(mp_int a,mp_int c)787 mp_result mp_int_sqr(mp_int a, mp_int c)
788 {
789 mp_digit *out;
790 mp_size osize, p = 0;
791
792 CHECK(a != NULL && c != NULL);
793
794 /* Get a temporary buffer big enough to hold the result */
795 osize = (mp_size) 4 * ((MP_USED(a) + 1) / 2);
796 if (a == c) {
797 p = ROUND_PREC(osize);
798 p = MAX(p, default_precision);
799
800 if ((out = s_alloc(p)) == NULL)
801 return MP_MEMORY;
802 }
803 else {
804 if (!s_pad(c, osize))
805 return MP_MEMORY;
806
807 out = MP_DIGITS(c);
808 }
809 ZERO(out, osize);
810
811 s_ksqr(MP_DIGITS(a), out, MP_USED(a));
812
813 /* Get rid of whatever memory c was already using, and fix up its fields to
814 reflect the new digit array it's using
815 */
816 if (out != MP_DIGITS(c)) {
817 if ((void *) MP_DIGITS(c) != (void *) c)
818 s_free(MP_DIGITS(c));
819 MP_DIGITS(c) = out;
820 MP_ALLOC(c) = p;
821 }
822
823 MP_USED(c) = osize; /* might not be true, but we'll fix it ... */
824 CLAMP(c); /* ... right here */
825 MP_SIGN(c) = MP_ZPOS;
826
827 return MP_OK;
828 }
829
mp_int_div(mp_int a,mp_int b,mp_int q,mp_int r)830 mp_result mp_int_div(mp_int a, mp_int b, mp_int q, mp_int r)
831 {
832 int cmp, lg;
833 mp_result res = MP_OK;
834 mp_int qout, rout;
835 mp_sign sa = MP_SIGN(a), sb = MP_SIGN(b);
836 DECLARE_TEMP(2);
837
838 CHECK(a != NULL && b != NULL && q != r);
839
840 if (CMPZ(b) == 0)
841 return MP_UNDEF;
842 else if ((cmp = s_ucmp(a, b)) < 0) {
843 /* If |a| < |b|, no division is required:
844 q = 0, r = a
845 */
846 if (r && (res = mp_int_copy(a, r)) != MP_OK)
847 return res;
848
849 if (q)
850 mp_int_zero(q);
851
852 return MP_OK;
853 }
854 else if (cmp == 0) {
855 /* If |a| = |b|, no division is required:
856 q = 1 or -1, r = 0
857 */
858 if (r)
859 mp_int_zero(r);
860
861 if (q) {
862 mp_int_zero(q);
863 q->digits[0] = 1;
864
865 if (sa != sb)
866 MP_SIGN(q) = MP_NEG;
867 }
868
869 return MP_OK;
870 }
871
872 /* When |a| > |b|, real division is required. We need someplace to store
873 quotient and remainder, but q and r are allowed to be NULL or to overlap
874 with the inputs.
875 */
876 if ((lg = s_isp2(b)) < 0) {
877 if (q && b != q) {
878 if ((res = mp_int_copy(a, q)) != MP_OK)
879 goto CLEANUP;
880 else
881 qout = q;
882 }
883 else {
884 qout = LAST_TEMP();
885 SETUP(mp_int_init_copy(LAST_TEMP(), a));
886 }
887
888 if (r && a != r) {
889 if ((res = mp_int_copy(b, r)) != MP_OK)
890 goto CLEANUP;
891 else
892 rout = r;
893 }
894 else {
895 rout = LAST_TEMP();
896 SETUP(mp_int_init_copy(LAST_TEMP(), b));
897 }
898
899 if ((res = s_udiv_knuth(qout, rout)) != MP_OK) goto CLEANUP;
900 }
901 else {
902 if (q && (res = mp_int_copy(a, q)) != MP_OK) goto CLEANUP;
903 if (r && (res = mp_int_copy(a, r)) != MP_OK) goto CLEANUP;
904
905 if (q) s_qdiv(q, (mp_size) lg); qout = q;
906 if (r) s_qmod(r, (mp_size) lg); rout = r;
907 }
908
909 /* Recompute signs for output */
910 if (rout) {
911 MP_SIGN(rout) = sa;
912 if (CMPZ(rout) == 0)
913 MP_SIGN(rout) = MP_ZPOS;
914 }
915 if (qout) {
916 MP_SIGN(qout) = (sa == sb) ? MP_ZPOS : MP_NEG;
917 if (CMPZ(qout) == 0)
918 MP_SIGN(qout) = MP_ZPOS;
919 }
920
921 if (q && (res = mp_int_copy(qout, q)) != MP_OK) goto CLEANUP;
922 if (r && (res = mp_int_copy(rout, r)) != MP_OK) goto CLEANUP;
923
924 CLEANUP_TEMP();
925 return res;
926 }
927
mp_int_mod(mp_int a,mp_int m,mp_int c)928 mp_result mp_int_mod(mp_int a, mp_int m, mp_int c)
929 {
930 mp_result res;
931 mpz_t tmp;
932 mp_int out;
933
934 if (m == c) {
935 mp_int_init(&tmp);
936 out = &tmp;
937 }
938 else {
939 out = c;
940 }
941
942 if ((res = mp_int_div(a, m, NULL, out)) != MP_OK)
943 goto CLEANUP;
944
945 if (CMPZ(out) < 0)
946 res = mp_int_add(out, m, c);
947 else
948 res = mp_int_copy(out, c);
949
950 CLEANUP:
951 if (out != c)
952 mp_int_clear(&tmp);
953
954 return res;
955 }
956
mp_int_div_value(mp_int a,mp_small value,mp_int q,mp_small * r)957 mp_result mp_int_div_value(mp_int a, mp_small value, mp_int q, mp_small *r)
958 {
959 mpz_t vtmp, rtmp;
960 mp_digit vbuf[MP_VALUE_DIGITS(value)];
961 mp_result res;
962
963 mp_int_init(&rtmp);
964 s_fake(&vtmp, value, vbuf);
965
966 if ((res = mp_int_div(a, &vtmp, q, &rtmp)) != MP_OK)
967 goto CLEANUP;
968
969 if (r)
970 (void) mp_int_to_int(&rtmp, r); /* can't fail */
971
972 CLEANUP:
973 mp_int_clear(&rtmp);
974 return res;
975 }
976
mp_int_div_pow2(mp_int a,mp_small p2,mp_int q,mp_int r)977 mp_result mp_int_div_pow2(mp_int a, mp_small p2, mp_int q, mp_int r)
978 {
979 mp_result res = MP_OK;
980
981 CHECK(a != NULL && p2 >= 0 && q != r);
982
983 if (q != NULL && (res = mp_int_copy(a, q)) == MP_OK)
984 s_qdiv(q, (mp_size) p2);
985
986 if (res == MP_OK && r != NULL && (res = mp_int_copy(a, r)) == MP_OK)
987 s_qmod(r, (mp_size) p2);
988
989 return res;
990 }
991
mp_int_expt(mp_int a,mp_small b,mp_int c)992 mp_result mp_int_expt(mp_int a, mp_small b, mp_int c)
993 {
994 mpz_t t;
995 mp_result res;
996 unsigned int v = labs(b);
997
998 CHECK(c != NULL);
999 if (b < 0)
1000 return MP_RANGE;
1001
1002 if ((res = mp_int_init_copy(&t, a)) != MP_OK)
1003 return res;
1004
1005 (void) mp_int_set_value(c, 1);
1006 while (v != 0) {
1007 if (v & 1) {
1008 if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1009 goto CLEANUP;
1010 }
1011
1012 v >>= 1;
1013 if (v == 0) break;
1014
1015 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1016 goto CLEANUP;
1017 }
1018
1019 CLEANUP:
1020 mp_int_clear(&t);
1021 return res;
1022 }
1023
mp_int_expt_value(mp_small a,mp_small b,mp_int c)1024 mp_result mp_int_expt_value(mp_small a, mp_small b, mp_int c)
1025 {
1026 mpz_t t;
1027 mp_result res;
1028 unsigned int v = labs(b);
1029
1030 CHECK(c != NULL);
1031 if (b < 0)
1032 return MP_RANGE;
1033
1034 if ((res = mp_int_init_value(&t, a)) != MP_OK)
1035 return res;
1036
1037 (void) mp_int_set_value(c, 1);
1038 while (v != 0) {
1039 if (v & 1) {
1040 if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1041 goto CLEANUP;
1042 }
1043
1044 v >>= 1;
1045 if (v == 0) break;
1046
1047 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1048 goto CLEANUP;
1049 }
1050
1051 CLEANUP:
1052 mp_int_clear(&t);
1053 return res;
1054 }
1055
mp_int_expt_full(mp_int a,mp_int b,mp_int c)1056 mp_result mp_int_expt_full(mp_int a, mp_int b, mp_int c)
1057 {
1058 mpz_t t;
1059 mp_result res;
1060 unsigned ix, jx;
1061
1062 CHECK(a != NULL && b != NULL && c != NULL);
1063 if (MP_SIGN(b) == MP_NEG)
1064 return MP_RANGE;
1065
1066 if ((res = mp_int_init_copy(&t, a)) != MP_OK)
1067 return res;
1068
1069 (void) mp_int_set_value(c, 1);
1070 for (ix = 0; ix < MP_USED(b); ++ix) {
1071 mp_digit d = b->digits[ix];
1072
1073 for (jx = 0; jx < MP_DIGIT_BIT; ++jx) {
1074 if (d & 1) {
1075 if ((res = mp_int_mul(c, &t, c)) != MP_OK)
1076 goto CLEANUP;
1077 }
1078
1079 d >>= 1;
1080 if (d == 0 && ix + 1 == MP_USED(b))
1081 break;
1082 if ((res = mp_int_sqr(&t, &t)) != MP_OK)
1083 goto CLEANUP;
1084 }
1085 }
1086
1087 CLEANUP:
1088 mp_int_clear(&t);
1089 return res;
1090 }
1091
/* Signed comparison of a and b: <0, 0 or >0 as a is less than, equal to or
   greater than b. */
int mp_int_compare(mp_int a, mp_int b)
{
  mp_sign sa;
  int mag;

  CHECK(a != NULL && b != NULL);

  sa = MP_SIGN(a);
  if (sa != MP_SIGN(b))
    return (sa == MP_ZPOS) ? 1 : -1;

  /* Same sign: compare magnitudes, reversing the sense for negatives. */
  mag = s_ucmp(a, b);
  return (sa == MP_ZPOS) ? mag : -mag;
}
1117
/* Compare |a| with |b|, ignoring signs; the result is <0, 0 or >0. */
int mp_int_compare_unsigned(mp_int a, mp_int b)
{
  NRCHECK(a != NULL && b != NULL);

  return s_ucmp(a, b);
}
1124
/* Sign test: 0 if z is zero, 1 if positive, -1 if negative. */
int mp_int_compare_zero(mp_int z)
{
  NRCHECK(z != NULL);

  if (MP_USED(z) != 1 || z->digits[0] != 0)
    return (MP_SIGN(z) == MP_ZPOS) ? 1 : -1;

  return 0;
}
1136
/* Signed comparison of z with a machine integer: <0, 0 or >0. */
int mp_int_compare_value(mp_int z, mp_small value)
{
  mp_sign vsign = (value < 0) ? MP_NEG : MP_ZPOS;
  int mag;

  CHECK(z != NULL);

  /* Different signs settle the order without examining magnitudes. */
  if (vsign != MP_SIGN(z))
    return (value < 0) ? 1 : -1;

  mag = s_vcmp(z, value);
  return (vsign == MP_ZPOS) ? mag : -mag;
}
1153
/* Compare z with an unsigned machine integer.  Any negative z is smaller. */
int mp_int_compare_uvalue(mp_int z, mp_usmall uv)
{
  CHECK(z != NULL);

  return (MP_SIGN(z) == MP_NEG) ? -1 : s_uvcmp(z, uv);
}
1163
mp_int_exptmod(mp_int a,mp_int b,mp_int m,mp_int c)1164 mp_result mp_int_exptmod(mp_int a, mp_int b, mp_int m, mp_int c)
1165 {
1166 mp_result res;
1167 mp_size um;
1168 mp_int s;
1169 DECLARE_TEMP(3);
1170
1171 CHECK(a != NULL && b != NULL && c != NULL && m != NULL);
1172
1173 /* Zero moduli and negative exponents are not considered. */
1174 if (CMPZ(m) == 0)
1175 return MP_UNDEF;
1176 if (CMPZ(b) < 0)
1177 return MP_RANGE;
1178
1179 um = MP_USED(m);
1180 SETUP(mp_int_init_size(TEMP(0), 2 * um));
1181 SETUP(mp_int_init_size(TEMP(1), 2 * um));
1182
1183 if (c == b || c == m) {
1184 SETUP(mp_int_init_size(TEMP(2), 2 * um));
1185 s = TEMP(2);
1186 }
1187 else {
1188 s = c;
1189 }
1190
1191 if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK) goto CLEANUP;
1192
1193 if ((res = s_brmu(TEMP(1), m)) != MP_OK) goto CLEANUP;
1194
1195 if ((res = s_embar(TEMP(0), b, m, TEMP(1), s)) != MP_OK)
1196 goto CLEANUP;
1197
1198 res = mp_int_copy(s, c);
1199
1200 CLEANUP_TEMP();
1201 return res;
1202 }
1203
mp_int_exptmod_evalue(mp_int a,mp_small value,mp_int m,mp_int c)1204 mp_result mp_int_exptmod_evalue(mp_int a, mp_small value, mp_int m, mp_int c)
1205 {
1206 mpz_t vtmp;
1207 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1208
1209 s_fake(&vtmp, value, vbuf);
1210
1211 return mp_int_exptmod(a, &vtmp, m, c);
1212 }
1213
mp_int_exptmod_bvalue(mp_small value,mp_int b,mp_int m,mp_int c)1214 mp_result mp_int_exptmod_bvalue(mp_small value, mp_int b,
1215 mp_int m, mp_int c)
1216 {
1217 mpz_t vtmp;
1218 mp_digit vbuf[MP_VALUE_DIGITS(value)];
1219
1220 s_fake(&vtmp, value, vbuf);
1221
1222 return mp_int_exptmod(&vtmp, b, m, c);
1223 }
1224
mp_int_exptmod_known(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)1225 mp_result mp_int_exptmod_known(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
1226 {
1227 mp_result res;
1228 mp_size um;
1229 mp_int s;
1230 DECLARE_TEMP(2);
1231
1232 CHECK(a && b && m && c);
1233
1234 /* Zero moduli and negative exponents are not considered. */
1235 if (CMPZ(m) == 0)
1236 return MP_UNDEF;
1237 if (CMPZ(b) < 0)
1238 return MP_RANGE;
1239
1240 um = MP_USED(m);
1241 SETUP(mp_int_init_size(TEMP(0), 2 * um));
1242
1243 if (c == b || c == m) {
1244 SETUP(mp_int_init_size(TEMP(1), 2 * um));
1245 s = TEMP(1);
1246 }
1247 else {
1248 s = c;
1249 }
1250
1251 if ((res = mp_int_mod(a, m, TEMP(0))) != MP_OK) goto CLEANUP;
1252
1253 if ((res = s_embar(TEMP(0), b, m, mu, s)) != MP_OK)
1254 goto CLEANUP;
1255
1256 res = mp_int_copy(s, c);
1257
1258 CLEANUP_TEMP();
1259 return res;
1260 }
1261
mp_int_redux_const(mp_int m,mp_int c)1262 mp_result mp_int_redux_const(mp_int m, mp_int c)
1263 {
1264 CHECK(m != NULL && c != NULL && m != c);
1265
1266 return s_brmu(c, m);
1267 }
1268
/* Compute c = a^-1 (mod m), the multiplicative inverse of a modulo m.

   Returns MP_RANGE if a is zero or m is not positive, and MP_UNDEF if
   gcd(a, m) != 1, in which case no inverse exists.
*/
mp_result mp_int_invmod(mp_int a, mp_int m, mp_int c)
{
  mp_result res;
  mp_sign sa;
  DECLARE_TEMP(2);

  CHECK(a != NULL && m != NULL && c != NULL);

  if (CMPZ(a) == 0 || CMPZ(m) <= 0)
    return MP_RANGE;

  sa = MP_SIGN(a); /* need this for the result later */

  /* TEMP(0) will receive gcd(a, m) and TEMP(1) the cofactor of a. */
  for (last__ = 0; last__ < 2; ++last__)
    mp_int_init(LAST_TEMP());

  /* mp_int_egcd works on the absolute values of its inputs, so TEMP(1) is a
     cofactor for |a|.  The cofactor of m is not needed, hence NULL. */
  if ((res = mp_int_egcd(a, m, TEMP(0), TEMP(1), NULL)) != MP_OK)
    goto CLEANUP;

  /* An inverse exists only when a and m are coprime. */
  if (mp_int_compare_value(TEMP(0), 1) != 0) {
    res = MP_UNDEF;
    goto CLEANUP;
  }

  /* It is first necessary to constrain the value to the proper range */
  if ((res = mp_int_mod(TEMP(1), m, TEMP(1))) != MP_OK)
    goto CLEANUP;

  /* Now, if 'a' was originally negative, the value we have is actually the
     magnitude of the negative representative; to get the positive value we
     have to subtract from the modulus.  Otherwise, the value is okay as it
     stands.
   */
  if (sa == MP_NEG)
    res = mp_int_sub(m, TEMP(1), c);
  else
    res = mp_int_copy(TEMP(1), c);

  CLEANUP_TEMP();
  return res;
}
1310
/* Binary GCD algorithm due to Josef Stein, 1961.

   Sets c = gcd(|a|, |b|); the result is produced by mp_int_abs and is never
   negative.  If exactly one input is zero, the result is the absolute value
   of the other.  If both are zero the GCD is undefined and MP_UNDEF is
   returned.
*/
mp_result mp_int_gcd(mp_int a, mp_int b, mp_int c)
{
  int ca, cb, k = 0;
  mpz_t u, v, t;
  mp_result res;

  CHECK(a != NULL && b != NULL && c != NULL);

  ca = CMPZ(a);
  cb = CMPZ(b);
  if (ca == 0 && cb == 0)
    return MP_UNDEF;
  else if (ca == 0)
    return mp_int_abs(b, c);
  else if (cb == 0)
    return mp_int_abs(a, c);

  /* The labels at the bottom unwind in reverse order of initialization.
     A failure copying a clears only t; a failure copying b clears u and t. */
  mp_int_init(&t);
  if ((res = mp_int_init_copy(&u, a)) != MP_OK)
    goto U;
  if ((res = mp_int_init_copy(&v, b)) != MP_OK)
    goto V;

  /* Work with magnitudes only. */
  MP_SIGN(&u) = MP_ZPOS; MP_SIGN(&v) = MP_ZPOS;

  { /* Divide out common factors of 2 from u and v.  The factor 2^k is
       restored at the end with s_qmul. */
    int div2_u = s_dp2k(&u), div2_v = s_dp2k(&v);

    k = MIN(div2_u, div2_v);
    s_qdiv(&u, (mp_size) k);
    s_qdiv(&v, (mp_size) k);
  }

  /* Seed t: -v if u is odd, otherwise u. */
  if (mp_int_is_odd(&u)) {
    if ((res = mp_int_neg(&v, &t)) != MP_OK)
      goto CLEANUP;
  }
  else {
    if ((res = mp_int_copy(&u, &t)) != MP_OK)
      goto CLEANUP;
  }

  for (;;) {
    /* Strip every factor of two from t. */
    s_qdiv(&t, s_dp2k(&t));

    /* A positive t replaces u; a non-positive t replaces v (as -t). */
    if (CMPZ(&t) > 0) {
      if ((res = mp_int_copy(&t, &u)) != MP_OK)
        goto CLEANUP;
    }
    else {
      if ((res = mp_int_neg(&t, &v)) != MP_OK)
        goto CLEANUP;
    }

    if ((res = mp_int_sub(&u, &v, &t)) != MP_OK)
      goto CLEANUP;

    /* u == v: that common value is the odd part of the GCD. */
    if (CMPZ(&t) == 0)
      break;
  }

  /* Restore the common power of two removed earlier. */
  if ((res = mp_int_abs(&u, c)) != MP_OK)
    goto CLEANUP;
  if (!s_qmul(c, (mp_size) k))
    res = MP_MEMORY;

 CLEANUP:
  mp_int_clear(&v);
 V: mp_int_clear(&u);
 U: mp_int_clear(&t);

  return res;
}
1385
1386 /* This is the binary GCD algorithm again, but this time we keep track of the
1387 elementary matrix operations as we go, so we can get values x and y
1388 satisfying c = ax + by.
1389 */
mp_int_egcd(mp_int a,mp_int b,mp_int c,mp_int x,mp_int y)1390 mp_result mp_int_egcd(mp_int a, mp_int b, mp_int c,
1391 mp_int x, mp_int y)
1392 {
1393 int k, ca, cb;
1394 mp_result res;
1395 DECLARE_TEMP(8);
1396
1397 CHECK(a != NULL && b != NULL && c != NULL &&
1398 (x != NULL || y != NULL));
1399
1400 ca = CMPZ(a);
1401 cb = CMPZ(b);
1402 if (ca == 0 && cb == 0)
1403 return MP_UNDEF;
1404 else if (ca == 0) {
1405 if ((res = mp_int_abs(b, c)) != MP_OK) return res;
1406 mp_int_zero(x); (void) mp_int_set_value(y, 1); return MP_OK;
1407 }
1408 else if (cb == 0) {
1409 if ((res = mp_int_abs(a, c)) != MP_OK) return res;
1410 (void) mp_int_set_value(x, 1); mp_int_zero(y); return MP_OK;
1411 }
1412
1413 /* Initialize temporaries:
1414 A:0, B:1, C:2, D:3, u:4, v:5, ou:6, ov:7 */
1415 for (last__ = 0; last__ < 4; ++last__)
1416 mp_int_init(LAST_TEMP());
1417 TEMP(0)->digits[0] = 1;
1418 TEMP(3)->digits[0] = 1;
1419
1420 SETUP(mp_int_init_copy(TEMP(4), a));
1421 SETUP(mp_int_init_copy(TEMP(5), b));
1422
1423 /* We will work with absolute values here */
1424 MP_SIGN(TEMP(4)) = MP_ZPOS;
1425 MP_SIGN(TEMP(5)) = MP_ZPOS;
1426
1427 { /* Divide out common factors of 2 from u and v */
1428 int div2_u = s_dp2k(TEMP(4)), div2_v = s_dp2k(TEMP(5));
1429
1430 k = MIN(div2_u, div2_v);
1431 s_qdiv(TEMP(4), k);
1432 s_qdiv(TEMP(5), k);
1433 }
1434
1435 SETUP(mp_int_init_copy(TEMP(6), TEMP(4)));
1436 SETUP(mp_int_init_copy(TEMP(7), TEMP(5)));
1437
1438 for (;;) {
1439 while (mp_int_is_even(TEMP(4))) {
1440 s_qdiv(TEMP(4), 1);
1441
1442 if (mp_int_is_odd(TEMP(0)) || mp_int_is_odd(TEMP(1))) {
1443 if ((res = mp_int_add(TEMP(0), TEMP(7), TEMP(0))) != MP_OK)
1444 goto CLEANUP;
1445 if ((res = mp_int_sub(TEMP(1), TEMP(6), TEMP(1))) != MP_OK)
1446 goto CLEANUP;
1447 }
1448
1449 s_qdiv(TEMP(0), 1);
1450 s_qdiv(TEMP(1), 1);
1451 }
1452
1453 while (mp_int_is_even(TEMP(5))) {
1454 s_qdiv(TEMP(5), 1);
1455
1456 if (mp_int_is_odd(TEMP(2)) || mp_int_is_odd(TEMP(3))) {
1457 if ((res = mp_int_add(TEMP(2), TEMP(7), TEMP(2))) != MP_OK)
1458 goto CLEANUP;
1459 if ((res = mp_int_sub(TEMP(3), TEMP(6), TEMP(3))) != MP_OK)
1460 goto CLEANUP;
1461 }
1462
1463 s_qdiv(TEMP(2), 1);
1464 s_qdiv(TEMP(3), 1);
1465 }
1466
1467 if (mp_int_compare(TEMP(4), TEMP(5)) >= 0) {
1468 if ((res = mp_int_sub(TEMP(4), TEMP(5), TEMP(4))) != MP_OK) goto CLEANUP;
1469 if ((res = mp_int_sub(TEMP(0), TEMP(2), TEMP(0))) != MP_OK) goto CLEANUP;
1470 if ((res = mp_int_sub(TEMP(1), TEMP(3), TEMP(1))) != MP_OK) goto CLEANUP;
1471 }
1472 else {
1473 if ((res = mp_int_sub(TEMP(5), TEMP(4), TEMP(5))) != MP_OK) goto CLEANUP;
1474 if ((res = mp_int_sub(TEMP(2), TEMP(0), TEMP(2))) != MP_OK) goto CLEANUP;
1475 if ((res = mp_int_sub(TEMP(3), TEMP(1), TEMP(3))) != MP_OK) goto CLEANUP;
1476 }
1477
1478 if (CMPZ(TEMP(4)) == 0) {
1479 if (x && (res = mp_int_copy(TEMP(2), x)) != MP_OK) goto CLEANUP;
1480 if (y && (res = mp_int_copy(TEMP(3), y)) != MP_OK) goto CLEANUP;
1481 if (c) {
1482 if (!s_qmul(TEMP(5), k)) {
1483 res = MP_MEMORY;
1484 goto CLEANUP;
1485 }
1486
1487 res = mp_int_copy(TEMP(5), c);
1488 }
1489
1490 break;
1491 }
1492 }
1493
1494 CLEANUP_TEMP();
1495 return res;
1496 }
1497
mp_int_lcm(mp_int a,mp_int b,mp_int c)1498 mp_result mp_int_lcm(mp_int a, mp_int b, mp_int c)
1499 {
1500 mpz_t lcm;
1501 mp_result res;
1502
1503 CHECK(a != NULL && b != NULL && c != NULL);
1504
1505 /* Since a * b = gcd(a, b) * lcm(a, b), we can compute
1506 lcm(a, b) = (a / gcd(a, b)) * b.
1507
1508 This formulation insures everything works even if the input
1509 variables share space.
1510 */
1511 if ((res = mp_int_init(&lcm)) != MP_OK)
1512 return res;
1513 if ((res = mp_int_gcd(a, b, &lcm)) != MP_OK)
1514 goto CLEANUP;
1515 if ((res = mp_int_div(a, &lcm, &lcm, NULL)) != MP_OK)
1516 goto CLEANUP;
1517 if ((res = mp_int_mul(&lcm, b, &lcm)) != MP_OK)
1518 goto CLEANUP;
1519
1520 res = mp_int_copy(&lcm, c);
1521
1522 CLEANUP:
1523 mp_int_clear(&lcm);
1524
1525 return res;
1526 }
1527
/* Return 1 if a is an exact multiple of v, and 0 otherwise.  A failed
   division (any result other than MP_OK) also yields 0. */
int mp_int_divisible_value(mp_int a, mp_small v)
{
  mp_small remainder = 0;
  mp_result status = mp_int_div_value(a, v, NULL, &remainder);

  return (status == MP_OK) && (remainder == 0);
}
1537
/* Test whether z is a power of two by delegating to s_isp2.
   NOTE(review): the exact return convention (e.g. exponent vs. -1) is
   defined by s_isp2 -- confirm there. */
int mp_int_is_pow2(mp_int z)
{
  int result;

  CHECK(z != NULL);

  result = s_isp2(z);
  return result;
}
1544
/* Implementation of Newton's root finding method, based loosely on a patch
   contributed by Hal Finkel <half@halssoftware.com>
   modified by M. J. Fromberger.

   Sets c to an integer b-th root of a, with b > 0.
   - For negative a, the root of |a| is computed and then negated.  This is
     only allowed for odd b; an even root of a negative value returns
     MP_UNDEF.
   - The iteration starts from |a| and stops at the first estimate x with
     x^b <= |a|.
*/
mp_result mp_int_root(mp_int a, mp_small b, mp_int c)
{
  mp_result res = MP_OK;
  int flips = 0;
  DECLARE_TEMP(5);

  CHECK(a != NULL && c != NULL && b > 0);

  if (b == 1) {
    return mp_int_copy(a, c);
  }
  if (MP_SIGN(a) == MP_NEG) {
    if (b % 2 == 0)
      return MP_UNDEF; /* root does not exist for negative a with even b */
    else
      flips = 1;
  }

  /* Temporaries:
       TEMP(0) = |a|, the target value
       TEMP(1) = x, the current estimate (initially |a|)
       TEMP(2) = x^b, then x^b - |a|
       TEMP(3) = b * x^(b-1), the derivative at x
       TEMP(4) = the next estimate */
  SETUP(mp_int_init_copy(LAST_TEMP(), a));
  SETUP(mp_int_init_copy(LAST_TEMP(), a));
  SETUP(mp_int_init(LAST_TEMP()));
  SETUP(mp_int_init(LAST_TEMP()));
  SETUP(mp_int_init(LAST_TEMP()));

  (void) mp_int_abs(TEMP(0), TEMP(0));
  (void) mp_int_abs(TEMP(1), TEMP(1));

  for (;;) {
    if ((res = mp_int_expt(TEMP(1), b, TEMP(2))) != MP_OK)
      goto CLEANUP;

    /* Stop as soon as the estimate no longer overshoots |a|. */
    if (mp_int_compare_unsigned(TEMP(2), TEMP(0)) <= 0)
      break;

    /* Newton step using integer division:
       x' = x - (x^b - |a|) / (b * x^(b-1)) */
    if ((res = mp_int_sub(TEMP(2), TEMP(0), TEMP(2))) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_expt(TEMP(1), b - 1, TEMP(3))) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_mul_value(TEMP(3), b, TEMP(3))) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_div(TEMP(2), TEMP(3), TEMP(4), NULL)) != MP_OK)
      goto CLEANUP;
    if ((res = mp_int_sub(TEMP(1), TEMP(4), TEMP(4))) != MP_OK)
      goto CLEANUP;

    /* If the quotient rounded to zero the estimate did not move; force it
       down by one so the loop keeps making progress. */
    if (mp_int_compare_unsigned(TEMP(1), TEMP(4)) == 0) {
      if ((res = mp_int_sub_value(TEMP(4), 1, TEMP(4))) != MP_OK)
        goto CLEANUP;
    }
    if ((res = mp_int_copy(TEMP(4), TEMP(1))) != MP_OK)
      goto CLEANUP;
  }

  if ((res = mp_int_copy(TEMP(1), c)) != MP_OK)
    goto CLEANUP;

  /* If the original value of a was negative, flip the output sign. */
  if (flips)
    (void) mp_int_neg(c, c); /* cannot fail */

  CLEANUP_TEMP();
  return res;
}
1612
mp_int_to_int(mp_int z,mp_small * out)1613 mp_result mp_int_to_int(mp_int z, mp_small *out)
1614 {
1615 mp_usmall uv = 0;
1616 mp_size uz;
1617 mp_digit *dz;
1618 mp_sign sz;
1619
1620 CHECK(z != NULL);
1621
1622 /* Make sure the value is representable as a small integer */
1623 sz = MP_SIGN(z);
1624 if ((sz == MP_ZPOS && mp_int_compare_value(z, MP_SMALL_MAX) > 0) ||
1625 mp_int_compare_value(z, MP_SMALL_MIN) < 0)
1626 return MP_RANGE;
1627
1628 uz = MP_USED(z);
1629 dz = MP_DIGITS(z) + uz - 1;
1630
1631 while (uz > 0) {
1632 uv <<= MP_DIGIT_BIT/2;
1633 uv = (uv << (MP_DIGIT_BIT/2)) | *dz--;
1634 --uz;
1635 }
1636
1637 if (out)
1638 *out = (mp_small)((sz == MP_NEG) ? -uv : uv);
1639
1640 return MP_OK;
1641 }
1642
mp_int_to_uint(mp_int z,mp_usmall * out)1643 mp_result mp_int_to_uint(mp_int z, mp_usmall *out)
1644 {
1645 mp_usmall uv = 0;
1646 mp_size uz;
1647 mp_digit *dz;
1648 mp_sign sz;
1649
1650 CHECK(z != NULL);
1651
1652 /* Make sure the value is representable as an unsigned small integer */
1653 sz = MP_SIGN(z);
1654 if (sz == MP_NEG || mp_int_compare_uvalue(z, MP_USMALL_MAX) > 0)
1655 return MP_RANGE;
1656
1657 uz = MP_USED(z);
1658 dz = MP_DIGITS(z) + uz - 1;
1659
1660 while (uz > 0) {
1661 uv <<= MP_DIGIT_BIT/2;
1662 uv = (uv << (MP_DIGIT_BIT/2)) | *dz--;
1663 --uz;
1664 }
1665
1666 if (out)
1667 *out = uv;
1668
1669 return MP_OK;
1670 }
1671
mp_int_to_string(mp_int z,mp_size radix,char * str,int limit)1672 mp_result mp_int_to_string(mp_int z, mp_size radix,
1673 char *str, int limit)
1674 {
1675 mp_result res;
1676 int cmp = 0;
1677
1678 CHECK(z != NULL && str != NULL && limit >= 2);
1679
1680 if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1681 return MP_RANGE;
1682
1683 if (CMPZ(z) == 0) {
1684 *str++ = s_val2ch(0, 1);
1685 }
1686 else {
1687 mpz_t tmp;
1688 char *h, *t;
1689
1690 if ((res = mp_int_init_copy(&tmp, z)) != MP_OK)
1691 return res;
1692
1693 if (MP_SIGN(z) == MP_NEG) {
1694 *str++ = '-';
1695 --limit;
1696 }
1697 h = str;
1698
1699 /* Generate digits in reverse order until finished or limit reached */
1700 for (/* */; limit > 0; --limit) {
1701 mp_digit d;
1702
1703 if ((cmp = CMPZ(&tmp)) == 0)
1704 break;
1705
1706 d = s_ddiv(&tmp, (mp_digit)radix);
1707 *str++ = s_val2ch(d, 1);
1708 }
1709 t = str - 1;
1710
1711 /* Put digits back in correct output order */
1712 while (h < t) {
1713 char tc = *h;
1714 *h++ = *t;
1715 *t-- = tc;
1716 }
1717
1718 mp_int_clear(&tmp);
1719 }
1720
1721 *str = '\0';
1722 if (cmp == 0)
1723 return MP_OK;
1724 else
1725 return MP_TRUNC;
1726 }
1727
mp_int_string_len(mp_int z,mp_size radix)1728 mp_result mp_int_string_len(mp_int z, mp_size radix)
1729 {
1730 int len;
1731
1732 CHECK(z != NULL);
1733
1734 if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1735 return MP_RANGE;
1736
1737 len = s_outlen(z, radix) + 1; /* for terminator */
1738
1739 /* Allow for sign marker on negatives */
1740 if (MP_SIGN(z) == MP_NEG)
1741 len += 1;
1742
1743 return len;
1744 }
1745
1746 /* Read zero-terminated string into z */
mp_int_read_string(mp_int z,mp_size radix,const char * str)1747 mp_result mp_int_read_string(mp_int z, mp_size radix, const char *str)
1748 {
1749 return mp_int_read_cstring(z, radix, str, NULL);
1750 }
1751
mp_int_read_cstring(mp_int z,mp_size radix,const char * str,char ** end)1752 mp_result mp_int_read_cstring(mp_int z, mp_size radix, const char *str, char **end)
1753 {
1754 int ch;
1755
1756 CHECK(z != NULL && str != NULL);
1757
1758 if (radix < MP_MIN_RADIX || radix > MP_MAX_RADIX)
1759 return MP_RANGE;
1760
1761 /* Skip leading whitespace */
1762 while (isspace((int)*str))
1763 ++str;
1764
1765 /* Handle leading sign tag (+/-, positive default) */
1766 switch (*str) {
1767 case '-':
1768 MP_SIGN(z) = MP_NEG;
1769 ++str;
1770 break;
1771 case '+':
1772 ++str; /* fallthrough */
1773 default:
1774 MP_SIGN(z) = MP_ZPOS;
1775 break;
1776 }
1777
1778 /* Skip leading zeroes */
1779 while ((ch = s_ch2val(*str, radix)) == 0)
1780 ++str;
1781
1782 /* Make sure there is enough space for the value */
1783 if (!s_pad(z, s_inlen(strlen(str), radix)))
1784 return MP_MEMORY;
1785
1786 MP_USED(z) = 1; z->digits[0] = 0;
1787
1788 while (*str != '\0' && ((ch = s_ch2val(*str, radix)) >= 0)) {
1789 s_dmul(z, (mp_digit)radix);
1790 s_dadd(z, (mp_digit)ch);
1791 ++str;
1792 }
1793
1794 CLAMP(z);
1795
1796 /* Override sign for zero, even if negative specified. */
1797 if (CMPZ(z) == 0)
1798 MP_SIGN(z) = MP_ZPOS;
1799
1800 if (end != NULL)
1801 *end = (char *)str;
1802
1803 /* Return a truncation error if the string has unprocessed characters
1804 remaining, so the caller can tell if the whole string was done */
1805 if (*str != '\0')
1806 return MP_TRUNC;
1807 else
1808 return MP_OK;
1809 }
1810
mp_int_count_bits(mp_int z)1811 mp_result mp_int_count_bits(mp_int z)
1812 {
1813 mp_size nbits = 0, uz;
1814 mp_digit d;
1815
1816 CHECK(z != NULL);
1817
1818 uz = MP_USED(z);
1819 if (uz == 1 && z->digits[0] == 0)
1820 return 1;
1821
1822 --uz;
1823 nbits = uz * MP_DIGIT_BIT;
1824 d = z->digits[uz];
1825
1826 while (d != 0) {
1827 d >>= 1;
1828 ++nbits;
1829 }
1830
1831 return nbits;
1832 }
1833
mp_int_to_binary(mp_int z,unsigned char * buf,int limit)1834 mp_result mp_int_to_binary(mp_int z, unsigned char *buf, int limit)
1835 {
1836 static const int PAD_FOR_2C = 1;
1837
1838 mp_result res;
1839 int limpos = limit;
1840
1841 CHECK(z != NULL && buf != NULL);
1842
1843 res = s_tobin(z, buf, &limpos, PAD_FOR_2C);
1844
1845 if (MP_SIGN(z) == MP_NEG)
1846 s_2comp(buf, limpos);
1847
1848 return res;
1849 }
1850
mp_int_read_binary(mp_int z,unsigned char * buf,int len)1851 mp_result mp_int_read_binary(mp_int z, unsigned char *buf, int len)
1852 {
1853 mp_size need, i;
1854 unsigned char *tmp;
1855 mp_digit *dz;
1856
1857 CHECK(z != NULL && buf != NULL && len > 0);
1858
1859 /* Figure out how many digits are needed to represent this value */
1860 need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1861 if (!s_pad(z, need))
1862 return MP_MEMORY;
1863
1864 mp_int_zero(z);
1865
1866 /* If the high-order bit is set, take the 2's complement before reading the
1867 value (it will be restored afterward) */
1868 if (buf[0] >> (CHAR_BIT - 1)) {
1869 MP_SIGN(z) = MP_NEG;
1870 s_2comp(buf, len);
1871 }
1872
1873 dz = MP_DIGITS(z);
1874 for (tmp = buf, i = len; i > 0; --i, ++tmp) {
1875 s_qmul(z, (mp_size) CHAR_BIT);
1876 *dz |= *tmp;
1877 }
1878
1879 /* Restore 2's complement if we took it before */
1880 if (MP_SIGN(z) == MP_NEG)
1881 s_2comp(buf, len);
1882
1883 return MP_OK;
1884 }
1885
mp_int_binary_len(mp_int z)1886 mp_result mp_int_binary_len(mp_int z)
1887 {
1888 mp_result res = mp_int_count_bits(z);
1889 int bytes = mp_int_unsigned_len(z);
1890
1891 if (res <= 0)
1892 return res;
1893
1894 bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
1895
1896 /* If the highest-order bit falls exactly on a byte boundary, we need to pad
1897 with an extra byte so that the sign will be read correctly when reading it
1898 back in. */
1899 if (bytes * CHAR_BIT == res)
1900 ++bytes;
1901
1902 return bytes;
1903 }
1904
mp_int_to_unsigned(mp_int z,unsigned char * buf,int limit)1905 mp_result mp_int_to_unsigned(mp_int z, unsigned char *buf, int limit)
1906 {
1907 static const int NO_PADDING = 0;
1908
1909 CHECK(z != NULL && buf != NULL);
1910
1911 return s_tobin(z, buf, &limit, NO_PADDING);
1912 }
1913
mp_int_read_unsigned(mp_int z,unsigned char * buf,int len)1914 mp_result mp_int_read_unsigned(mp_int z, unsigned char *buf, int len)
1915 {
1916 mp_size need, i;
1917 unsigned char *tmp;
1918
1919 CHECK(z != NULL && buf != NULL && len > 0);
1920
1921 /* Figure out how many digits are needed to represent this value */
1922 need = ((len * CHAR_BIT) + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT;
1923 if (!s_pad(z, need))
1924 return MP_MEMORY;
1925
1926 mp_int_zero(z);
1927
1928 for (tmp = buf, i = len; i > 0; --i, ++tmp) {
1929 (void) s_qmul(z, CHAR_BIT);
1930 *MP_DIGITS(z) |= *tmp;
1931 }
1932
1933 return MP_OK;
1934 }
1935
mp_int_unsigned_len(mp_int z)1936 mp_result mp_int_unsigned_len(mp_int z)
1937 {
1938 mp_result res = mp_int_count_bits(z);
1939 int bytes;
1940
1941 if (res <= 0)
1942 return res;
1943
1944 bytes = (res + (CHAR_BIT - 1)) / CHAR_BIT;
1945
1946 return bytes;
1947 }
1948
/* Map a result code to a human-readable message.
   - Codes are non-positive; -res indexes the NULL-terminated table
     s_error_msg.
   - Positive codes, and codes beyond the end of the table, yield
     s_unknown_err. */
const char *mp_error_string(mp_result res)
{
  int code, pos;

  if (res > 0)
    return s_unknown_err;

  /* Walk forward, stopping early at the table's NULL sentinel. */
  code = -res;
  pos = 0;
  while (pos < code && s_error_msg[pos] != NULL)
    ++pos;

  return (s_error_msg[pos] != NULL) ? s_error_msg[pos] : s_unknown_err;
}
1964
1965 /*------------------------------------------------------------------------*/
1966 /* Private functions for internal use. These make assumptions. */
1967
s_alloc(mp_size num)1968 STATIC mp_digit *s_alloc(mp_size num)
1969 {
1970 mp_digit *out = malloc(num * sizeof(mp_digit));
1971
1972 assert(out != NULL); /* for debugging */
1973 #if DEBUG > 1
1974 {
1975 mp_digit v = (mp_digit) 0xdeadbeef;
1976 int ix;
1977
1978 for (ix = 0; ix < num; ++ix)
1979 out[ix] = v;
1980 }
1981 #endif
1982
1983 return out;
1984 }
1985
s_realloc(mp_digit * old,mp_size osize,mp_size nsize)1986 STATIC mp_digit *s_realloc(mp_digit *old, mp_size osize, mp_size nsize)
1987 {
1988 #if DEBUG > 1
1989 mp_digit *new = s_alloc(nsize);
1990 int ix;
1991
1992 for (ix = 0; ix < nsize; ++ix)
1993 new[ix] = (mp_digit) 0xdeadbeef;
1994
1995 memcpy(new, old, osize * sizeof(mp_digit));
1996 #else
1997 mp_digit *new = realloc(old, nsize * sizeof(mp_digit));
1998
1999 assert(new != NULL); /* for debugging */
2000 #endif
2001 return new;
2002 }
2003
s_free(void * ptr)2004 STATIC void s_free(void *ptr)
2005 {
2006 free(ptr);
2007 }
2008
s_pad(mp_int z,mp_size min)2009 STATIC int s_pad(mp_int z, mp_size min)
2010 {
2011 if (MP_ALLOC(z) < min) {
2012 mp_size nsize = ROUND_PREC(min);
2013 mp_digit *tmp;
2014
2015 if ((void *)z->digits == (void *)z) {
2016 if ((tmp = s_alloc(nsize)) == NULL)
2017 return 0;
2018
2019 COPY(MP_DIGITS(z), tmp, MP_USED(z));
2020 }
2021 else if ((tmp = s_realloc(MP_DIGITS(z), MP_ALLOC(z), nsize)) == NULL)
2022 return 0;
2023
2024 MP_DIGITS(z) = tmp;
2025 MP_ALLOC(z) = nsize;
2026 }
2027
2028 return 1;
2029 }
2030
2031 /* Note: This will not work correctly when value == MP_SMALL_MIN */
s_fake(mp_int z,mp_small value,mp_digit vbuf[])2032 STATIC void s_fake(mp_int z, mp_small value, mp_digit vbuf[])
2033 {
2034 mp_usmall uv = (mp_usmall) (value < 0) ? -value : value;
2035 s_ufake(z, uv, vbuf);
2036 if (value < 0)
2037 z->sign = MP_NEG;
2038 }
2039
s_ufake(mp_int z,mp_usmall value,mp_digit vbuf[])2040 STATIC void s_ufake(mp_int z, mp_usmall value, mp_digit vbuf[])
2041 {
2042 mp_size ndig = (mp_size) s_uvpack(value, vbuf);
2043
2044 z->used = ndig;
2045 z->alloc = MP_VALUE_DIGITS(value);
2046 z->sign = MP_ZPOS;
2047 z->digits = vbuf;
2048 }
2049
s_cdig(mp_digit * da,mp_digit * db,mp_size len)2050 STATIC int s_cdig(mp_digit *da, mp_digit *db, mp_size len)
2051 {
2052 mp_digit *dat = da + len - 1, *dbt = db + len - 1;
2053
2054 for (/* */; len != 0; --len, --dat, --dbt) {
2055 if (*dat > *dbt)
2056 return 1;
2057 else if (*dat < *dbt)
2058 return -1;
2059 }
2060
2061 return 0;
2062 }
2063
s_uvpack(mp_usmall uv,mp_digit t[])2064 STATIC int s_uvpack(mp_usmall uv, mp_digit t[])
2065 {
2066 int ndig = 0;
2067
2068 if (uv == 0)
2069 t[ndig++] = 0;
2070 else {
2071 while (uv != 0) {
2072 t[ndig++] = (mp_digit) uv;
2073 uv >>= MP_DIGIT_BIT/2;
2074 uv >>= MP_DIGIT_BIT/2;
2075 }
2076 }
2077
2078 return ndig;
2079 }
2080
s_ucmp(mp_int a,mp_int b)2081 STATIC int s_ucmp(mp_int a, mp_int b)
2082 {
2083 mp_size ua = MP_USED(a), ub = MP_USED(b);
2084
2085 if (ua > ub)
2086 return 1;
2087 else if (ub > ua)
2088 return -1;
2089 else
2090 return s_cdig(MP_DIGITS(a), MP_DIGITS(b), ua);
2091 }
2092
s_vcmp(mp_int a,mp_small v)2093 STATIC int s_vcmp(mp_int a, mp_small v)
2094 {
2095 mp_usmall uv = (v < 0) ? -(mp_usmall) v : (mp_usmall) v;
2096 return s_uvcmp(a, uv);
2097 }
2098
s_uvcmp(mp_int a,mp_usmall uv)2099 STATIC int s_uvcmp(mp_int a, mp_usmall uv)
2100 {
2101 mpz_t vtmp;
2102 mp_digit vdig[MP_VALUE_DIGITS(uv)];
2103
2104 s_ufake(&vtmp, uv, vdig);
2105 return s_ucmp(a, &vtmp);
2106 }
2107
s_uadd(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2108 STATIC mp_digit s_uadd(mp_digit *da, mp_digit *db, mp_digit *dc,
2109 mp_size size_a, mp_size size_b)
2110 {
2111 mp_size pos;
2112 mp_word w = 0;
2113
2114 /* Insure that da is the longer of the two to simplify later code */
2115 if (size_b > size_a) {
2116 SWAP(mp_digit *, da, db);
2117 SWAP(mp_size, size_a, size_b);
2118 }
2119
2120 /* Add corresponding digits until the shorter number runs out */
2121 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) {
2122 w = w + (mp_word) *da + (mp_word) *db;
2123 *dc = LOWER_HALF(w);
2124 w = UPPER_HALF(w);
2125 }
2126
2127 /* Propagate carries as far as necessary */
2128 for (/* */; pos < size_a; ++pos, ++da, ++dc) {
2129 w = w + *da;
2130
2131 *dc = LOWER_HALF(w);
2132 w = UPPER_HALF(w);
2133 }
2134
2135 /* Return carry out */
2136 return (mp_digit)w;
2137 }
2138
s_usub(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2139 STATIC void s_usub(mp_digit *da, mp_digit *db, mp_digit *dc,
2140 mp_size size_a, mp_size size_b)
2141 {
2142 mp_size pos;
2143 mp_word w = 0;
2144
2145 /* We assume that |a| >= |b| so this should definitely hold */
2146 assert(size_a >= size_b);
2147
2148 /* Subtract corresponding digits and propagate borrow */
2149 for (pos = 0; pos < size_b; ++pos, ++da, ++db, ++dc) {
2150 w = ((mp_word)MP_DIGIT_MAX + 1 + /* MP_RADIX */
2151 (mp_word)*da) - w - (mp_word)*db;
2152
2153 *dc = LOWER_HALF(w);
2154 w = (UPPER_HALF(w) == 0);
2155 }
2156
2157 /* Finish the subtraction for remaining upper digits of da */
2158 for (/* */; pos < size_a; ++pos, ++da, ++dc) {
2159 w = ((mp_word)MP_DIGIT_MAX + 1 + /* MP_RADIX */
2160 (mp_word)*da) - w;
2161
2162 *dc = LOWER_HALF(w);
2163 w = (UPPER_HALF(w) == 0);
2164 }
2165
2166 /* If there is a borrow out at the end, it violates the precondition */
2167 assert(w == 0);
2168 }
2169
s_kmul(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2170 STATIC int s_kmul(mp_digit *da, mp_digit *db, mp_digit *dc,
2171 mp_size size_a, mp_size size_b)
2172 {
2173 mp_size bot_size;
2174
2175 /* Make sure b is the smaller of the two input values */
2176 if (size_b > size_a) {
2177 SWAP(mp_digit *, da, db);
2178 SWAP(mp_size, size_a, size_b);
2179 }
2180
2181 /* Insure that the bottom is the larger half in an odd-length split; the code
2182 below relies on this being true.
2183 */
2184 bot_size = (size_a + 1) / 2;
2185
2186 /* If the values are big enough to bother with recursion, use the Karatsuba
2187 algorithm to compute the product; otherwise use the normal multiplication
2188 algorithm
2189 */
2190 if (multiply_threshold &&
2191 size_a >= multiply_threshold &&
2192 size_b > bot_size) {
2193
2194 mp_digit *t1, *t2, *t3, carry;
2195
2196 mp_digit *a_top = da + bot_size;
2197 mp_digit *b_top = db + bot_size;
2198
2199 mp_size at_size = size_a - bot_size;
2200 mp_size bt_size = size_b - bot_size;
2201 mp_size buf_size = 2 * bot_size;
2202
2203 /* Do a single allocation for all three temporary buffers needed; each
2204 buffer must be big enough to hold the product of two bottom halves, and
2205 one buffer needs space for the completed product; twice the space is
2206 plenty.
2207 */
2208 if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
2209 t2 = t1 + buf_size;
2210 t3 = t2 + buf_size;
2211 ZERO(t1, 4 * buf_size);
2212
2213 /* t1 and t2 are initially used as temporaries to compute the inner product
2214 (a1 + a0)(b1 + b0) = a1b1 + a1b0 + a0b1 + a0b0
2215 */
2216 carry = s_uadd(da, a_top, t1, bot_size, at_size); /* t1 = a1 + a0 */
2217 t1[bot_size] = carry;
2218
2219 carry = s_uadd(db, b_top, t2, bot_size, bt_size); /* t2 = b1 + b0 */
2220 t2[bot_size] = carry;
2221
2222 (void) s_kmul(t1, t2, t3, bot_size + 1, bot_size + 1); /* t3 = t1 * t2 */
2223
2224 /* Now we'll get t1 = a0b0 and t2 = a1b1, and subtract them out so that
2225 we're left with only the pieces we want: t3 = a1b0 + a0b1
2226 */
2227 ZERO(t1, buf_size);
2228 ZERO(t2, buf_size);
2229 (void) s_kmul(da, db, t1, bot_size, bot_size); /* t1 = a0 * b0 */
2230 (void) s_kmul(a_top, b_top, t2, at_size, bt_size); /* t2 = a1 * b1 */
2231
2232 /* Subtract out t1 and t2 to get the inner product */
2233 s_usub(t3, t1, t3, buf_size + 2, buf_size);
2234 s_usub(t3, t2, t3, buf_size + 2, buf_size);
2235
2236 /* Assemble the output value */
2237 COPY(t1, dc, buf_size);
2238 carry = s_uadd(t3, dc + bot_size, dc + bot_size,
2239 buf_size + 1, buf_size);
2240 assert(carry == 0);
2241
2242 carry = s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size,
2243 buf_size, buf_size);
2244 assert(carry == 0);
2245
2246 s_free(t1); /* note t2 and t3 are just internal pointers to t1 */
2247 }
2248 else {
2249 s_umul(da, db, dc, size_a, size_b);
2250 }
2251
2252 return 1;
2253 }
2254
s_umul(mp_digit * da,mp_digit * db,mp_digit * dc,mp_size size_a,mp_size size_b)2255 STATIC void s_umul(mp_digit *da, mp_digit *db, mp_digit *dc,
2256 mp_size size_a, mp_size size_b)
2257 {
2258 mp_size a, b;
2259 mp_word w;
2260
2261 for (a = 0; a < size_a; ++a, ++dc, ++da) {
2262 mp_digit *dct = dc;
2263 mp_digit *dbt = db;
2264
2265 if (*da == 0)
2266 continue;
2267
2268 w = 0;
2269 for (b = 0; b < size_b; ++b, ++dbt, ++dct) {
2270 w = (mp_word)*da * (mp_word)*dbt + w + (mp_word)*dct;
2271
2272 *dct = LOWER_HALF(w);
2273 w = UPPER_HALF(w);
2274 }
2275
2276 *dct = (mp_digit)w;
2277 }
2278 }
2279
s_ksqr(mp_digit * da,mp_digit * dc,mp_size size_a)2280 STATIC int s_ksqr(mp_digit *da, mp_digit *dc, mp_size size_a)
2281 {
2282 if (multiply_threshold && size_a > multiply_threshold) {
2283 mp_size bot_size = (size_a + 1) / 2;
2284 mp_digit *a_top = da + bot_size;
2285 mp_digit *t1, *t2, *t3, carry;
2286 mp_size at_size = size_a - bot_size;
2287 mp_size buf_size = 2 * bot_size;
2288
2289 if ((t1 = s_alloc(4 * buf_size)) == NULL) return 0;
2290 t2 = t1 + buf_size;
2291 t3 = t2 + buf_size;
2292 ZERO(t1, 4 * buf_size);
2293
2294 (void) s_ksqr(da, t1, bot_size); /* t1 = a0 ^ 2 */
2295 (void) s_ksqr(a_top, t2, at_size); /* t2 = a1 ^ 2 */
2296
2297 (void) s_kmul(da, a_top, t3, bot_size, at_size); /* t3 = a0 * a1 */
2298
2299 /* Quick multiply t3 by 2, shifting left (can't overflow) */
2300 {
2301 int i, top = bot_size + at_size;
2302 mp_word w, save = 0;
2303
2304 for (i = 0; i < top; ++i) {
2305 w = t3[i];
2306 w = (w << 1) | save;
2307 t3[i] = LOWER_HALF(w);
2308 save = UPPER_HALF(w);
2309 }
2310 t3[i] = LOWER_HALF(save);
2311 }
2312
2313 /* Assemble the output value */
2314 COPY(t1, dc, 2 * bot_size);
2315 carry = s_uadd(t3, dc + bot_size, dc + bot_size,
2316 buf_size + 1, buf_size);
2317 assert(carry == 0);
2318
2319 carry = s_uadd(t2, dc + 2*bot_size, dc + 2*bot_size,
2320 buf_size, buf_size);
2321 assert(carry == 0);
2322
2323 s_free(t1); /* note that t2 and t2 are internal pointers only */
2324
2325 }
2326 else {
2327 s_usqr(da, dc, size_a);
2328 }
2329
2330 return 1;
2331 }
2332
/* Schoolbook squaring: accumulate da^2 (da has size_a digits) into the
   existing contents of dc.

   Row i adds a_i^2 at position 2i, then 2 * a_i * a_j at position i + j for
   each j > i.  The cross products appear twice in the square, which is why
   they are doubled.  Any final carry is rippled upward through dc.
*/
STATIC void s_usqr(mp_digit *da, mp_digit *dc, mp_size size_a)
{
  mp_size i, j;
  mp_word w;

  /* dc advances by two digits per row, so dct starts at position 2i. */
  for (i = 0; i < size_a; ++i, dc += 2, ++da) {
    mp_digit *dct = dc, *dat = da;

    if (*da == 0)
      continue;

    /* Take care of the first digit, no rollover */
    w = (mp_word)*dat * (mp_word)*dat + (mp_word)*dct;
    *dct = LOWER_HALF(w);
    w = UPPER_HALF(w);
    ++dat; ++dct;

    for (j = i + 1; j < size_a; ++j, ++dat, ++dct) {
      mp_word t = (mp_word)*da * (mp_word)*dat;
      mp_word u = w + (mp_word)*dct, ov = 0;

      /* Check if doubling t will overflow a word */
      if (HIGH_BIT_SET(t))
        ov = 1;

      w = t + t;

      /* Check if adding u to w will overflow a word */
      if (ADD_WILL_OVERFLOW(w, u))
        ov = 1;

      w += u;

      *dct = LOWER_HALF(w);
      w = UPPER_HALF(w);
      /* A lost word-level overflow is worth one radix in the carry. */
      if (ov) {
        w += MP_DIGIT_MAX; /* MP_RADIX */
        ++w;
      }
    }

    /* Fold the remaining carry into the next digit and ripple it upward. */
    w = w + *dct;
    *dct = (mp_digit)w;
    while ((w = UPPER_HALF(w)) != 0) {
      ++dct; w = w + *dct;
      *dct = LOWER_HALF(w);
    }

    assert(w == 0);
  }
}
2384
s_dadd(mp_int a,mp_digit b)2385 STATIC void s_dadd(mp_int a, mp_digit b)
2386 {
2387 mp_word w = 0;
2388 mp_digit *da = MP_DIGITS(a);
2389 mp_size ua = MP_USED(a);
2390
2391 w = (mp_word)*da + b;
2392 *da++ = LOWER_HALF(w);
2393 w = UPPER_HALF(w);
2394
2395 for (ua -= 1; ua > 0; --ua, ++da) {
2396 w = (mp_word)*da + w;
2397
2398 *da = LOWER_HALF(w);
2399 w = UPPER_HALF(w);
2400 }
2401
2402 if (w) {
2403 *da = (mp_digit)w;
2404 MP_USED(a) += 1;
2405 }
2406 }
2407
s_dmul(mp_int a,mp_digit b)2408 STATIC void s_dmul(mp_int a, mp_digit b)
2409 {
2410 mp_word w = 0;
2411 mp_digit *da = MP_DIGITS(a);
2412 mp_size ua = MP_USED(a);
2413
2414 while (ua > 0) {
2415 w = (mp_word)*da * b + w;
2416 *da++ = LOWER_HALF(w);
2417 w = UPPER_HALF(w);
2418 --ua;
2419 }
2420
2421 if (w) {
2422 *da = (mp_digit)w;
2423 MP_USED(a) += 1;
2424 }
2425 }
2426
s_dbmul(mp_digit * da,mp_digit b,mp_digit * dc,mp_size size_a)2427 STATIC void s_dbmul(mp_digit *da, mp_digit b, mp_digit *dc, mp_size size_a)
2428 {
2429 mp_word w = 0;
2430
2431 while (size_a > 0) {
2432 w = (mp_word)*da++ * (mp_word)b + w;
2433
2434 *dc++ = LOWER_HALF(w);
2435 w = UPPER_HALF(w);
2436 --size_a;
2437 }
2438
2439 if (w)
2440 *dc = LOWER_HALF(w);
2441 }
2442
s_ddiv(mp_int a,mp_digit b)2443 STATIC mp_digit s_ddiv(mp_int a, mp_digit b)
2444 {
2445 mp_word w = 0, qdigit;
2446 mp_size ua = MP_USED(a);
2447 mp_digit *da = MP_DIGITS(a) + ua - 1;
2448
2449 for (/* */; ua > 0; --ua, --da) {
2450 w = (w << MP_DIGIT_BIT) | *da;
2451
2452 if (w >= b) {
2453 qdigit = w / b;
2454 w = w % b;
2455 }
2456 else {
2457 qdigit = 0;
2458 }
2459
2460 *da = (mp_digit)qdigit;
2461 }
2462
2463 CLAMP(a);
2464 return (mp_digit)w;
2465 }
2466
s_qdiv(mp_int z,mp_size p2)2467 STATIC void s_qdiv(mp_int z, mp_size p2)
2468 {
2469 mp_size ndig = p2 / MP_DIGIT_BIT, nbits = p2 % MP_DIGIT_BIT;
2470 mp_size uz = MP_USED(z);
2471
2472 if (ndig) {
2473 mp_size mark;
2474 mp_digit *to, *from;
2475
2476 if (ndig >= uz) {
2477 mp_int_zero(z);
2478 return;
2479 }
2480
2481 to = MP_DIGITS(z); from = to + ndig;
2482
2483 for (mark = ndig; mark < uz; ++mark)
2484 *to++ = *from++;
2485
2486 MP_USED(z) = uz - ndig;
2487 }
2488
2489 if (nbits) {
2490 mp_digit d = 0, *dz, save;
2491 mp_size up = MP_DIGIT_BIT - nbits;
2492
2493 uz = MP_USED(z);
2494 dz = MP_DIGITS(z) + uz - 1;
2495
2496 for (/* */; uz > 0; --uz, --dz) {
2497 save = *dz;
2498
2499 *dz = (*dz >> nbits) | (d << up);
2500 d = save;
2501 }
2502
2503 CLAMP(z);
2504 }
2505
2506 if (MP_USED(z) == 1 && z->digits[0] == 0)
2507 MP_SIGN(z) = MP_ZPOS;
2508 }
2509
s_qmod(mp_int z,mp_size p2)2510 STATIC void s_qmod(mp_int z, mp_size p2)
2511 {
2512 mp_size start = p2 / MP_DIGIT_BIT + 1, rest = p2 % MP_DIGIT_BIT;
2513 mp_size uz = MP_USED(z);
2514 mp_digit mask = (1u << rest) - 1;
2515
2516 if (start <= uz) {
2517 MP_USED(z) = start;
2518 z->digits[start - 1] &= mask;
2519 CLAMP(z);
2520 }
2521 }
2522
s_qmul(mp_int z,mp_size p2)2523 STATIC int s_qmul(mp_int z, mp_size p2)
2524 {
2525 mp_size uz, need, rest, extra, i;
2526 mp_digit *from, *to, d;
2527
2528 if (p2 == 0)
2529 return 1;
2530
2531 uz = MP_USED(z);
2532 need = p2 / MP_DIGIT_BIT; rest = p2 % MP_DIGIT_BIT;
2533
2534 /* Figure out if we need an extra digit at the top end; this occurs if the
2535 topmost `rest' bits of the high-order digit of z are not zero, meaning
2536 they will be shifted off the end if not preserved */
2537 extra = 0;
2538 if (rest != 0) {
2539 mp_digit *dz = MP_DIGITS(z) + uz - 1;
2540
2541 if ((*dz >> (MP_DIGIT_BIT - rest)) != 0)
2542 extra = 1;
2543 }
2544
2545 if (!s_pad(z, uz + need + extra))
2546 return 0;
2547
2548 /* If we need to shift by whole digits, do that in one pass, then
2549 to back and shift by partial digits.
2550 */
2551 if (need > 0) {
2552 from = MP_DIGITS(z) + uz - 1;
2553 to = from + need;
2554
2555 for (i = 0; i < uz; ++i)
2556 *to-- = *from--;
2557
2558 ZERO(MP_DIGITS(z), need);
2559 uz += need;
2560 }
2561
2562 if (rest) {
2563 d = 0;
2564 for (i = need, from = MP_DIGITS(z) + need; i < uz; ++i, ++from) {
2565 mp_digit save = *from;
2566
2567 *from = (*from << rest) | (d >> (MP_DIGIT_BIT - rest));
2568 d = save;
2569 }
2570
2571 d >>= (MP_DIGIT_BIT - rest);
2572 if (d != 0) {
2573 *from = d;
2574 uz += extra;
2575 }
2576 }
2577
2578 MP_USED(z) = uz;
2579 CLAMP(z);
2580
2581 return 1;
2582 }
2583
2584 /* Compute z = 2^p2 - |z|; requires that 2^p2 >= |z|
2585 The sign of the result is always zero/positive.
2586 */
s_qsub(mp_int z,mp_size p2)2587 STATIC int s_qsub(mp_int z, mp_size p2)
2588 {
2589 mp_digit hi = (1 << (p2 % MP_DIGIT_BIT)), *zp;
2590 mp_size tdig = (p2 / MP_DIGIT_BIT), pos;
2591 mp_word w = 0;
2592
2593 if (!s_pad(z, tdig + 1))
2594 return 0;
2595
2596 for (pos = 0, zp = MP_DIGITS(z); pos < tdig; ++pos, ++zp) {
2597 w = ((mp_word) MP_DIGIT_MAX + 1) - w - (mp_word)*zp;
2598
2599 *zp = LOWER_HALF(w);
2600 w = UPPER_HALF(w) ? 0 : 1;
2601 }
2602
2603 w = ((mp_word) MP_DIGIT_MAX + 1 + hi) - w - (mp_word)*zp;
2604 *zp = LOWER_HALF(w);
2605
2606 assert(UPPER_HALF(w) != 0); /* no borrow out should be possible */
2607
2608 MP_SIGN(z) = MP_ZPOS;
2609 CLAMP(z);
2610
2611 return 1;
2612 }
2613
s_dp2k(mp_int z)2614 STATIC int s_dp2k(mp_int z)
2615 {
2616 int k = 0;
2617 mp_digit *dp = MP_DIGITS(z), d;
2618
2619 if (MP_USED(z) == 1 && *dp == 0)
2620 return 1;
2621
2622 while (*dp == 0) {
2623 k += MP_DIGIT_BIT;
2624 ++dp;
2625 }
2626
2627 d = *dp;
2628 while ((d & 1) == 0) {
2629 d >>= 1;
2630 ++k;
2631 }
2632
2633 return k;
2634 }
2635
s_isp2(mp_int z)2636 STATIC int s_isp2(mp_int z)
2637 {
2638 mp_size uz = MP_USED(z), k = 0;
2639 mp_digit *dz = MP_DIGITS(z), d;
2640
2641 while (uz > 1) {
2642 if (*dz++ != 0)
2643 return -1;
2644 k += MP_DIGIT_BIT;
2645 --uz;
2646 }
2647
2648 d = *dz;
2649 while (d > 1) {
2650 if (d & 1)
2651 return -1;
2652 ++k; d >>= 1;
2653 }
2654
2655 return (int) k;
2656 }
2657
s_2expt(mp_int z,mp_small k)2658 STATIC int s_2expt(mp_int z, mp_small k)
2659 {
2660 mp_size ndig, rest;
2661 mp_digit *dz;
2662
2663 ndig = (k + MP_DIGIT_BIT) / MP_DIGIT_BIT;
2664 rest = k % MP_DIGIT_BIT;
2665
2666 if (!s_pad(z, ndig))
2667 return 0;
2668
2669 dz = MP_DIGITS(z);
2670 ZERO(dz, ndig);
2671 *(dz + ndig - 1) = (1 << rest);
2672 MP_USED(z) = ndig;
2673
2674 return 1;
2675 }
2676
s_norm(mp_int a,mp_int b)2677 STATIC int s_norm(mp_int a, mp_int b)
2678 {
2679 mp_digit d = b->digits[MP_USED(b) - 1];
2680 int k = 0;
2681
2682 while (d < (1u << (mp_digit)(MP_DIGIT_BIT - 1))) { /* d < (MP_RADIX / 2) */
2683 d <<= 1;
2684 ++k;
2685 }
2686
2687 /* These multiplications can't fail */
2688 if (k != 0) {
2689 (void) s_qmul(a, (mp_size) k);
2690 (void) s_qmul(b, (mp_size) k);
2691 }
2692
2693 return k;
2694 }
2695
s_brmu(mp_int z,mp_int m)2696 STATIC mp_result s_brmu(mp_int z, mp_int m)
2697 {
2698 mp_size um = MP_USED(m) * 2;
2699
2700 if (!s_pad(z, um))
2701 return MP_MEMORY;
2702
2703 s_2expt(z, MP_DIGIT_BIT * um);
2704 return mp_int_div(z, m, z, NULL);
2705 }
2706
s_reduce(mp_int x,mp_int m,mp_int mu,mp_int q1,mp_int q2)2707 STATIC int s_reduce(mp_int x, mp_int m, mp_int mu, mp_int q1, mp_int q2)
2708 {
2709 mp_size um = MP_USED(m), umb_p1, umb_m1;
2710
2711 umb_p1 = (um + 1) * MP_DIGIT_BIT;
2712 umb_m1 = (um - 1) * MP_DIGIT_BIT;
2713
2714 if (mp_int_copy(x, q1) != MP_OK)
2715 return 0;
2716
2717 /* Compute q2 = floor((floor(x / b^(k-1)) * mu) / b^(k+1)) */
2718 s_qdiv(q1, umb_m1);
2719 UMUL(q1, mu, q2);
2720 s_qdiv(q2, umb_p1);
2721
2722 /* Set x = x mod b^(k+1) */
2723 s_qmod(x, umb_p1);
2724
2725 /* Now, q is a guess for the quotient a / m.
2726 Compute x - q * m mod b^(k+1), replacing x. This may be off
2727 by a factor of 2m, but no more than that.
2728 */
2729 UMUL(q2, m, q1);
2730 s_qmod(q1, umb_p1);
2731 (void) mp_int_sub(x, q1, x); /* can't fail */
2732
2733 /* The result may be < 0; if it is, add b^(k+1) to pin it in the proper
2734 range. */
2735 if ((CMPZ(x) < 0) && !s_qsub(x, umb_p1))
2736 return 0;
2737
2738 /* If x > m, we need to back it off until it is in range. This will be
2739 required at most twice. */
2740 if (mp_int_compare(x, m) >= 0) {
2741 (void) mp_int_sub(x, m, x);
2742 if (mp_int_compare(x, m) >= 0)
2743 (void) mp_int_sub(x, m, x);
2744 }
2745
2746 /* At this point, x has been properly reduced. */
2747 return 1;
2748 }
2749
2750 /* Perform modular exponentiation using Barrett's method, where mu is the
2751 reduction constant for m. Assumes a < m, b > 0. */
s_embar(mp_int a,mp_int b,mp_int m,mp_int mu,mp_int c)2752 STATIC mp_result s_embar(mp_int a, mp_int b, mp_int m, mp_int mu, mp_int c)
2753 {
2754 mp_digit *db, *dbt, umu, d;
2755 mp_result res;
2756 DECLARE_TEMP(3);
2757
2758 umu = MP_USED(mu); db = MP_DIGITS(b); dbt = db + MP_USED(b) - 1;
2759
2760 while (last__ < 3) {
2761 SETUP(mp_int_init_size(LAST_TEMP(), 4 * umu));
2762 ZERO(MP_DIGITS(TEMP(last__ - 1)), MP_ALLOC(TEMP(last__ - 1)));
2763 }
2764
2765 (void) mp_int_set_value(c, 1);
2766
2767 /* Take care of low-order digits */
2768 while (db < dbt) {
2769 int i;
2770
2771 for (d = *db, i = MP_DIGIT_BIT; i > 0; --i, d >>= 1) {
2772 if (d & 1) {
2773 /* The use of a second temporary avoids allocation */
2774 UMUL(c, a, TEMP(0));
2775 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2776 res = MP_MEMORY; goto CLEANUP;
2777 }
2778 mp_int_copy(TEMP(0), c);
2779 }
2780
2781
2782 USQR(a, TEMP(0));
2783 assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2784 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2785 res = MP_MEMORY; goto CLEANUP;
2786 }
2787 assert(MP_SIGN(TEMP(0)) == MP_ZPOS);
2788 mp_int_copy(TEMP(0), a);
2789 }
2790
2791 ++db;
2792 }
2793
2794 /* Take care of highest-order digit */
2795 d = *dbt;
2796 for (;;) {
2797 if (d & 1) {
2798 UMUL(c, a, TEMP(0));
2799 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2800 res = MP_MEMORY; goto CLEANUP;
2801 }
2802 mp_int_copy(TEMP(0), c);
2803 }
2804
2805 d >>= 1;
2806 if (!d) break;
2807
2808 USQR(a, TEMP(0));
2809 if (!s_reduce(TEMP(0), m, mu, TEMP(1), TEMP(2))) {
2810 res = MP_MEMORY; goto CLEANUP;
2811 }
2812 (void) mp_int_copy(TEMP(0), a);
2813 }
2814
2815 CLEANUP_TEMP();
2816 return res;
2817 }
2818
2819 /* Division of nonnegative integers
2820
2821 This function implements division algorithm for unsigned multi-precision
2822 integers. The algorithm is based on Algorithm D from Knuth's "The Art of
2823 Computer Programming", 3rd ed. 1998, pg 272-273.
2824
2825 We diverge from Knuth's algorithm in that we do not perform the subtraction
2826 from the remainder until we have determined that we have the correct
2827 quotient digit. This makes our algorithm less efficient that Knuth because
2828 we might have to perform multiple multiplication and comparison steps before
2829 the subtraction. The advantage is that it is easy to implement and ensure
2830 correctness without worrying about underflow from the subtraction.
2831
2832 inputs: u a n+m digit integer in base b (b is 2^MP_DIGIT_BIT)
2833 v a n digit integer in base b (b is 2^MP_DIGIT_BIT)
2834 n >= 1
2835 m >= 0
2836 outputs: u / v stored in u
2837 u % v stored in v
2838 */
s_udiv_knuth(mp_int u,mp_int v)2839 STATIC mp_result s_udiv_knuth(mp_int u, mp_int v) {
2840 mpz_t q, r, t;
2841 mp_result
2842 res = MP_OK;
2843 int k,j;
2844 mp_size m,n;
2845
2846 /* Force signs to positive */
2847 MP_SIGN(u) = MP_ZPOS;
2848 MP_SIGN(v) = MP_ZPOS;
2849
2850 /* Use simple division algorithm when v is only one digit long */
2851 if (MP_USED(v) == 1) {
2852 mp_digit d, rem;
2853 d = v->digits[0];
2854 rem = s_ddiv(u, d);
2855 mp_int_set_value(v, rem);
2856 return MP_OK;
2857 }
2858
2859 /* Algorithm D
2860
2861 The n and m variables are defined as used by Knuth.
2862 u is an n digit number with digits u_{n-1}..u_0.
2863 v is an n+m digit number with digits from v_{m+n-1}..v_0.
2864 We require that n > 1 and m >= 0
2865 */
2866 n = MP_USED(v);
2867 m = MP_USED(u) - n;
2868 assert(n > 1);
2869 assert(m >= 0);
2870
2871 /* D1: Normalize.
2872 The normalization step provides the necessary condition for Theorem B,
2873 which states that the quotient estimate for q_j, call it qhat
2874
2875 qhat = u_{j+n}u_{j+n-1} / v_{n-1}
2876
2877 is bounded by
2878
2879 qhat - 2 <= q_j <= qhat.
2880
2881 That is, qhat is always greater than the actual quotient digit q,
2882 and it is never more than two larger than the actual quotient digit.
2883 */
2884 k = s_norm(u, v);
2885
2886 /* Extend size of u by one if needed.
2887
2888 The algorithm begins with a value of u that has one more digit of input.
2889 The normalization step sets u_{m+n}..u_0 = 2^k * u_{m+n-1}..u_0. If the
2890 multiplication did not increase the number of digits of u, we need to add
2891 a leading zero here.
2892 */
2893 if (k == 0 || MP_USED(u) != m + n + 1) {
2894 if (!s_pad(u, m+n+1))
2895 return MP_MEMORY;
2896 u->digits[m+n] = 0;
2897 u->used = m+n+1;
2898 }
2899
2900 /* Add a leading 0 to v.
2901
2902 The multiplication in step D4 multiplies qhat * 0v_{n-1}..v_0. We need to
2903 add the leading zero to v here to ensure that the multiplication will
2904 produce the full n+1 digit result.
2905 */
2906 if (!s_pad(v, n+1)) return MP_MEMORY; v->digits[n] = 0;
2907
2908 /* Initialize temporary variables q and t.
2909 q allocates space for m+1 digits to store the quotient digits
2910 t allocates space for n+1 digits to hold the result of q_j*v
2911 */
2912 if ((res = mp_int_init_size(&q, m + 1)) != MP_OK) return res;
2913 if ((res = mp_int_init_size(&t, n + 1)) != MP_OK) goto CLEANUP;
2914
2915 /* D2: Initialize j */
2916 j = m;
2917 r.digits = MP_DIGITS(u) + j; /* The contents of r are shared with u */
2918 r.used = n + 1;
2919 r.sign = MP_ZPOS;
2920 r.alloc = MP_ALLOC(u);
2921 ZERO(t.digits, t.alloc);
2922
2923 /* Calculate the m+1 digits of the quotient result */
2924 for (; j >= 0; j--) {
2925 /* D3: Calculate q' */
2926 /* r->digits is aligned to position j of the number u */
2927 mp_word pfx, qhat;
2928 pfx = r.digits[n];
2929 pfx <<= MP_DIGIT_BIT / 2;
2930 pfx <<= MP_DIGIT_BIT / 2;
2931 pfx |= r.digits[n-1]; /* pfx = u_{j+n}{j+n-1} */
2932
2933 qhat = pfx / v->digits[n-1];
2934 /* Check to see if qhat > b, and decrease qhat if so.
2935 Theorem B guarantess that qhat is at most 2 larger than the
2936 actual value, so it is possible that qhat is greater than
2937 the maximum value that will fit in a digit */
2938 if (qhat > MP_DIGIT_MAX)
2939 qhat = MP_DIGIT_MAX;
2940
2941 /* D4,D5,D6: Multiply qhat * v and test for a correct value of q
2942
2943 We proceed a bit different than the way described by Knuth. This way is
2944 simpler but less efficent. Instead of doing the multiply and subtract
2945 then checking for underflow, we first do the multiply of qhat * v and
2946 see if it is larger than the current remainder r. If it is larger, we
2947 decrease qhat by one and try again. We may need to decrease qhat one
2948 more time before we get a value that is smaller than r.
2949
2950 This way is less efficent than Knuth becuase we do more multiplies, but
2951 we do not need to worry about underflow this way.
2952 */
2953 /* t = qhat * v */
2954 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, t.digits, n+1); t.used = n + 1;
2955 CLAMP(&t);
2956
2957 /* Clamp r for the comparison. Comparisons do not like leading zeros. */
2958 CLAMP(&r);
2959 if (s_ucmp(&t, &r) > 0) { /* would the remainder be negative? */
2960 qhat -= 1; /* try a smaller q */
2961 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, t.digits, n+1);
2962 t.used = n + 1; CLAMP(&t);
2963 if (s_ucmp(&t, &r) > 0) { /* would the remainder be negative? */
2964 assert(qhat > 0);
2965 qhat -= 1; /* try a smaller q */
2966 s_dbmul(MP_DIGITS(v), (mp_digit) qhat, t.digits, n+1);
2967 t.used = n + 1; CLAMP(&t);
2968 }
2969 assert(s_ucmp(&t, &r) <= 0 && "The mathematics failed us.");
2970 }
2971 /* Unclamp r. The D algorithm expects r = u_{j+n}..u_j to always be n+1
2972 digits long. */
2973 r.used = n + 1;
2974
2975 /* D4: Multiply and subtract
2976
2977 Note: The multiply was completed above so we only need to subtract here.
2978 */
2979 s_usub(r.digits, t.digits, r.digits, r.used, t.used);
2980
2981 /* D5: Test remainder
2982
2983 Note: Not needed because we always check that qhat is the correct value
2984 before performing the subtract. Value cast to mp_digit to prevent
2985 warning, qhat has been clamped to MP_DIGIT_MAX
2986 */
2987 q.digits[j] = (mp_digit)qhat;
2988
2989 /* D6: Add back
2990 Note: Not needed because we always check that qhat is the correct value
2991 before performing the subtract.
2992 */
2993
2994 /* D7: Loop on j */
2995 r.digits--;
2996 ZERO(t.digits, t.alloc);
2997 }
2998
2999 /* Get rid of leading zeros in q */
3000 q.used = m + 1;
3001 CLAMP(&q);
3002
3003 /* Denormalize the remainder */
3004 CLAMP(u); /* use u here because the r.digits pointer is off-by-one */
3005 if (k != 0)
3006 s_qdiv(u, k);
3007
3008 mp_int_copy(u, v); /* ok: 0 <= r < v */
3009 mp_int_copy(&q, u); /* ok: q <= u */
3010
3011 mp_int_clear(&t);
3012 CLEANUP:
3013 mp_int_clear(&q);
3014 return res;
3015 }
3016
s_outlen(mp_int z,mp_size r)3017 STATIC int s_outlen(mp_int z, mp_size r)
3018 {
3019 mp_result bits;
3020 double raw;
3021
3022 assert(r >= MP_MIN_RADIX && r <= MP_MAX_RADIX);
3023
3024 bits = mp_int_count_bits(z);
3025 raw = (double)bits * s_log2[r];
3026
3027 return (int)(raw + 0.999999);
3028 }
3029
s_inlen(int len,mp_size r)3030 STATIC mp_size s_inlen(int len, mp_size r)
3031 {
3032 double raw = (double)len / s_log2[r];
3033 mp_size bits = (mp_size)(raw + 0.5);
3034
3035 return (mp_size)((bits + (MP_DIGIT_BIT - 1)) / MP_DIGIT_BIT) + 1;
3036 }
3037
s_ch2val(char c,int r)3038 STATIC int s_ch2val(char c, int r)
3039 {
3040 int out;
3041
3042 if (isdigit((unsigned char) c))
3043 out = c - '0';
3044 else if (r > 10 && isalpha((unsigned char) c))
3045 out = toupper(c) - 'A' + 10;
3046 else
3047 return -1;
3048
3049 return (out >= r) ? -1 : out;
3050 }
3051
s_val2ch(int v,int caps)3052 STATIC char s_val2ch(int v, int caps)
3053 {
3054 assert(v >= 0);
3055
3056 if (v < 10)
3057 return v + '0';
3058 else {
3059 char out = (v - 10) + 'a';
3060
3061 if (caps)
3062 return toupper(out);
3063 else
3064 return out;
3065 }
3066 }
3067
s_2comp(unsigned char * buf,int len)3068 STATIC void s_2comp(unsigned char *buf, int len)
3069 {
3070 int i;
3071 unsigned short s = 1;
3072
3073 for (i = len - 1; i >= 0; --i) {
3074 unsigned char c = ~buf[i];
3075
3076 s = c + s;
3077 c = s & UCHAR_MAX;
3078 s >>= CHAR_BIT;
3079
3080 buf[i] = c;
3081 }
3082
3083 /* last carry out is ignored */
3084 }
3085
s_tobin(mp_int z,unsigned char * buf,int * limpos,int pad)3086 STATIC mp_result s_tobin(mp_int z, unsigned char *buf, int *limpos, int pad)
3087 {
3088 mp_size uz;
3089 mp_digit *dz;
3090 int pos = 0, limit = *limpos;
3091
3092 uz = MP_USED(z); dz = MP_DIGITS(z);
3093 while (uz > 0 && pos < limit) {
3094 mp_digit d = *dz++;
3095 int i;
3096
3097 for (i = sizeof(mp_digit); i > 0 && pos < limit; --i) {
3098 buf[pos++] = (unsigned char)d;
3099 d >>= CHAR_BIT;
3100
3101 /* Don't write leading zeroes */
3102 if (d == 0 && uz == 1)
3103 i = 0; /* exit loop without signaling truncation */
3104 }
3105
3106 /* Detect truncation (loop exited with pos >= limit) */
3107 if (i > 0) break;
3108
3109 --uz;
3110 }
3111
3112 if (pad != 0 && (buf[pos - 1] >> (CHAR_BIT - 1))) {
3113 if (pos < limit)
3114 buf[pos++] = 0;
3115 else
3116 uz = 1;
3117 }
3118
3119 /* Digits are in reverse order, fix that */
3120 REV(unsigned char, buf, pos);
3121
3122 /* Return the number of bytes actually written */
3123 *limpos = pos;
3124
3125 return (uz == 0) ? MP_OK : MP_TRUNC;
3126 }
3127
3128 #if DEBUG
/* Debugging aid: write "tag: " and the sign of z ('-' or '+') to stderr,
   followed by its digits from most to least significant, each as
   zero-padded uppercase hexadecimal of MP_DIGIT_BIT / 4 characters, and a
   newline. */
void s_print(char *tag, mp_int z)
{
  int idx = MP_USED(z) - 1;
  char sign = (MP_SIGN(z) == MP_NEG) ? '-' : '+';

  fprintf(stderr, "%s: %c ", tag, sign);

  while (idx >= 0) {
    fprintf(stderr, "%0*X", (int)(MP_DIGIT_BIT / 4), z->digits[idx]);
    --idx;
  }

  fputc('\n', stderr);
}
3142
/* Debugging aid: write "tag: " to stderr, followed by buf[num-1] down to
   buf[0], each as zero-padded uppercase hexadecimal of MP_DIGIT_BIT / 4
   characters, and a newline. */
void s_print_buf(char *tag, mp_digit *buf, mp_size num)
{
  int idx = num - 1;

  fprintf(stderr, "%s: ", tag);

  while (idx >= 0) {
    fprintf(stderr, "%0*X", (int)(MP_DIGIT_BIT / 4), buf[idx]);
    --idx;
  }

  fputc('\n', stderr);
}
3154 #endif
3155
3156 /* Here there be dragons */
3157