1 #include <openssl/e_os2.h>
2 #include <stddef.h>
3 #include <sys/types.h>
4 #include <string.h>
5 #include <openssl/bn.h>
6 #include <openssl/err.h>
7 #include <openssl/rsaerr.h>
8 #include "internal/numbers.h"
9 #include "internal/constant_time.h"
10 #include "bn_local.h"
11
/*
 * limb_t is the native machine word matching BN_BYTES; where the toolchain
 * offers a double-width integer type, limb2_t is defined and HAVE_LIMB2_T
 * set so _mul_limb() can use a single native full-width multiplication.
 */
# if BN_BYTES == 8
typedef uint64_t limb_t;
#  if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16
/* nonstandard; implemented by gcc on 64-bit platforms */
typedef __uint128_t limb2_t;
#   define HAVE_LIMB2_T
#  endif
#  define LIMB_BIT_SIZE 64
#  define LIMB_BYTE_SIZE 8
# elif BN_BYTES == 4
typedef uint32_t limb_t;
typedef uint64_t limb2_t;
#  define LIMB_BIT_SIZE 32
#  define LIMB_BYTE_SIZE 4
#  define HAVE_LIMB2_T
# else
#  error "Not supported"
# endif
30
31 /*
32 * For multiplication we're using schoolbook multiplication,
33 * so if we have two numbers, each with 6 "digits" (words)
34 * the multiplication is calculated as follows:
35 * A B C D E F
36 * x I J K L M N
37 * --------------
38 * N*F
39 * N*E
40 * N*D
41 * N*C
42 * N*B
43 * N*A
44 * M*F
45 * M*E
46 * M*D
47 * M*C
48 * M*B
49 * M*A
50 * L*F
51 * L*E
52 * L*D
53 * L*C
54 * L*B
55 * L*A
56 * K*F
57 * K*E
58 * K*D
59 * K*C
60 * K*B
61 * K*A
62 * J*F
63 * J*E
64 * J*D
65 * J*C
66 * J*B
67 * J*A
68 * I*F
69 * I*E
70 * I*D
71 * I*C
72 * I*B
73 * + I*A
74 * ==========================
75 * N*B N*D N*F
76 * + N*A N*C N*E
77 * + M*B M*D M*F
78 * + M*A M*C M*E
79 * + L*B L*D L*F
80 * + L*A L*C L*E
81 * + K*B K*D K*F
82 * + K*A K*C K*E
83 * + J*B J*D J*F
84 * + J*A J*C J*E
85 * + I*B I*D I*F
86 * + I*A I*C I*E
87 *
88 * 1+1 1+3 1+5
89 * 1+0 1+2 1+4
90 * 0+1 0+3 0+5
91 * 0+0 0+2 0+4
92 *
93 * 0 1 2 3 4 5 6
94 * which requires n^2 multiplications and 2n full length additions
95 * as we can keep every other result of limb multiplication in two separate
96 * limbs
97 */
98
#if defined HAVE_LIMB2_T
/* full multiply: sets (*hi:*lo) to the double-width product a * b */
static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
{
    limb2_t t;
    /*
     * this is idiomatic code to tell compiler to use the native mul
     * those three lines will actually compile to single instruction
     */

    t = (limb2_t)a * b;
    *hi = t >> LIMB_BIT_SIZE;
    *lo = (limb_t)t;
}
#elif (BN_BYTES == 8) && (defined _MSC_VER)
/* https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-170 */
#pragma intrinsic(_umul128)
/* MSVC x64: _umul128 returns the low half and stores the high half in *hi */
static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
{
    *lo = _umul128(a, b, hi);
}
#else
/*
 * if the compiler doesn't have either a 128bit data type nor a "return
 * high 64 bits of multiplication"
 */
static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
{
    /* schoolbook 64x64 -> 128 multiply built from four 32x32 -> 64 products */
    limb_t a_low = (limb_t)(uint32_t)a;
    limb_t a_hi = a >> 32;
    limb_t b_low = (limb_t)(uint32_t)b;
    limb_t b_hi = b >> 32;

    limb_t p0 = a_low * b_low;
    limb_t p1 = a_low * b_hi;
    limb_t p2 = a_hi * b_low;
    limb_t p3 = a_hi * b_hi;

    /* carry out of summing the middle 32-bit column (p0 high + p1 low + p2 low) */
    uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32);

    *lo = p0 + (p1 << 32) + (p2 << 32);
    *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy;
}
#endif
142
/* add two limbs with carry in, return carry out (0 or 1) */
static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry)
{
    limb_t carry1, carry2, t;
    /*
     * `c = a + b; if (c < a)` is idiomatic code that makes compilers
     * use add with carry on assembly level
     */

    *ret = a + carry;
    if (*ret < a)
        carry1 = 1;
    else
        carry1 = 0;

    t = *ret;
    *ret = t + b;
    if (*ret < t)
        carry2 = 1;
    else
        carry2 = 0;

    /* at most one of the two partial additions can wrap, so the sum is 0 or 1 */
    return carry1 + carry2;
}
167
168 /*
169 * add two numbers of the same size, return overflow
170 *
171 * add a to b, place result in ret; all arrays need to be n limbs long
172 * return overflow from addition (0 or 1)
173 */
add(limb_t * ret,limb_t * a,limb_t * b,size_t n)174 static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n)
175 {
176 limb_t c = 0;
177 ossl_ssize_t i;
178
179 for(i = n - 1; i > -1; i--)
180 c = _add_limb(&ret[i], a[i], b[i], c);
181
182 return c;
183 }
184
185 /*
186 * return number of limbs necessary for temporary values
187 * when multiplying numbers n limbs large
188 */
mul_limb_numb(size_t n)189 static ossl_inline size_t mul_limb_numb(size_t n)
190 {
191 return 2 * n * 2;
192 }
193
/*
 * multiply two numbers of the same size
 *
 * multiply a by b, place result in ret; a and b need to be n limbs long
 * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs
 * long
 *
 * Limbs are stored most significant first; the (hi, lo) halves of the
 * partial product a[i] * b[j] therefore land at offsets i + j and
 * i + j + 1 respectively.
 */
static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp)
{
    limb_t *r_odd, *r_even;
    size_t i, j, k;

    /* split the scratch space into two 2*n-limb wide accumulators */
    r_odd = tmp;
    r_even = &tmp[2 * n];

    memset(ret, 0, 2 * n * sizeof(limb_t));

    for (i = 0; i < n; i++) {
        /* clear only the limbs this row can touch; the rest stay zero */
        for (k = 0; k < i + n + 1; k++) {
            r_even[k] = 0;
            r_odd[k] = 0;
        }
        for (j = 0; j < n; j++) {
            /*
             * place results from even and odd limbs in separate arrays so that
             * we don't have to calculate overflow every time we get individual
             * limb multiplication result
             */
            if (j % 2 == 0)
                _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]);
            else
                _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]);
        }
        /*
         * skip the least significant limbs when adding multiples of
         * more significant limbs (they're zero anyway)
         */
        add(ret, ret, r_even, n + i + 1);
        add(ret, ret, r_odd, n + i + 1);
    }
}
235
236 /* modifies the value in place by performing a right shift by one bit */
rshift1(limb_t * val,size_t n)237 static ossl_inline void rshift1(limb_t *val, size_t n)
238 {
239 limb_t shift_in = 0, shift_out = 0;
240 size_t i;
241
242 for (i = 0; i < n; i++) {
243 shift_out = val[i] & 1;
244 val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1);
245 shift_in = shift_out;
246 }
247 }
248
249 /* extend the LSB of flag to all bits of limb */
mk_mask(limb_t flag)250 static ossl_inline limb_t mk_mask(limb_t flag)
251 {
252 flag |= flag << 1;
253 flag |= flag << 2;
254 flag |= flag << 4;
255 flag |= flag << 8;
256 flag |= flag << 16;
257 #if (LIMB_BYTE_SIZE == 8)
258 flag |= flag << 32;
259 #endif
260 return flag;
261 }
262
/*
 * copy from either a or b to ret based on flag
 * when flag == 0, then copies from b
 * when flag == 1, then copies from a
 *
 * Constant-time: both inputs are always read and there are no
 * data-dependent branches.
 */
static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n)
{
    /*
     * would be more efficient with non volatile mask, but then gcc
     * generates code with jumps
     */
    volatile limb_t mask;
    size_t i;

    mask = mk_mask(flag);
    for (i = 0; i < n; i++) {
#if (LIMB_BYTE_SIZE == 8)
        ret[i] = constant_time_select_64(mask, a[i], b[i]);
#else
        ret[i] = constant_time_select_32(mask, a[i], b[i]);
#endif
    }
}
286
/* subtract with borrow in: *ret = a - b - borrow, return borrow out (0 or 1) */
static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow)
{
    limb_t borrow1, borrow2, t;
    /*
     * while it doesn't look constant-time, this is idiomatic code
     * to tell compilers to use the carry bit from subtraction
     */

    *ret = a - borrow;
    if (*ret > a)
        borrow1 = 1;
    else
        borrow1 = 0;

    t = *ret;
    *ret = t - b;
    if (*ret > t)
        borrow2 = 1;
    else
        borrow2 = 0;

    /* at most one of the two partial subtractions can wrap */
    return borrow1 + borrow2;
}
310
311 /*
312 * place the result of a - b into ret, return the borrow bit.
313 * All arrays need to be n limbs long
314 */
sub(limb_t * ret,limb_t * a,limb_t * b,size_t n)315 static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n)
316 {
317 limb_t borrow = 0;
318 ossl_ssize_t i;
319
320 for (i = n - 1; i > -1; i--)
321 borrow = _sub_limb(&ret[i], a[i], b[i], borrow);
322
323 return borrow;
324 }
325
326 /* return the number of limbs necessary to allocate for the mod() tmp operand */
mod_limb_numb(size_t anum,size_t modnum)327 static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum)
328 {
329 return (anum + modnum) * 3;
330 }
331
/*
 * calculate a % mod, place the result in ret
 * size of a is defined by anum, size of ret and mod is modnum,
 * size of tmp is returned by mod_limb_numb()
 *
 * Constant-time binary reduction: mod is positioned anum limbs up (i.e.
 * multiplied by 2^(anum * LIMB_BIT_SIZE)), then shifted right one bit per
 * iteration with a constant-time conditional subtraction at every step.
 */
static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
                size_t modnum, limb_t *tmp)
{
    limb_t *atmp, *modtmp, *rettmp;
    limb_t res;
    size_t i;

    memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE);

    /* tmp splits into three (anum + modnum)-limb scratch values */
    atmp = tmp;
    modtmp = &tmp[anum + modnum];
    rettmp = &tmp[(anum + modnum) * 2];

    /* a goes into the least significant limbs (high index = less significant) */
    for (i = modnum; i < modnum + anum; i++)
        atmp[i] = a[i - modnum];

    /* mod goes into the most significant limbs: mod << (anum * LIMB_BIT_SIZE) */
    for (i = 0; i < modnum; i++)
        modtmp[i] = mod[i];

    for (i = 0; i < anum * LIMB_BIT_SIZE; i++) {
        rshift1(modtmp, anum + modnum);
        /* res is the borrow: 1 when atmp < modtmp */
        res = sub(rettmp, atmp, modtmp, anum + modnum);
        /* keep the subtraction result only when it did not go negative */
        cselect(res, atmp, atmp, rettmp, anum + modnum);
    }

    /* the remainder now sits in the least significant modnum limbs */
    memcpy(ret, &atmp[anum], sizeof(limb_t) * modnum);
}
364
365 /* necessary size of tmp for a _mul_add_limb() call with provided anum */
_mul_add_limb_numb(size_t anum)366 static ossl_inline size_t _mul_add_limb_numb(size_t anum)
367 {
368 return 2 * (anum + 1);
369 }
370
/*
 * multiply a by m, add to ret, return carry
 *
 * a is anum limbs long, ret at least anum limbs long;
 * tmp needs to be _mul_add_limb_numb(anum) limbs long
 */
static limb_t _mul_add_limb(limb_t *ret, limb_t *a, size_t anum,
                            limb_t m, limb_t *tmp)
{
    limb_t carry = 0;
    limb_t *r_odd, *r_even;
    size_t i;

    memset(tmp, 0, sizeof(limb_t) * (anum + 1) * 2);

    /* split tmp into two (anum + 1)-limb accumulators */
    r_odd = tmp;
    r_even = &tmp[anum + 1];

    for (i = 0; i < anum; i++) {
        /*
         * place the results from even and odd limbs in separate arrays
         * so that we have to worry about carry just once
         */
        if (i % 2 == 0)
            _mul_limb(&r_even[i], &r_even[i + 1], a[i], m);
        else
            _mul_limb(&r_odd[i], &r_odd[i + 1], a[i], m);
    }
    /* assert: add() carry here will be equal zero */
    add(r_even, r_even, r_odd, anum + 1);
    /*
     * while here it will not overflow as the max value from multiplication
     * is -2 while max overflow from addition is 1, so the max value of
     * carry is -1 (i.e. max int)
     */
    /* r_even[0] is the most significant limb of the product, i.e. its overflow */
    carry = add(ret, ret, &r_even[1], anum) + r_even[0];

    return carry;
}
405
mod_montgomery_limb_numb(size_t modnum)406 static ossl_inline size_t mod_montgomery_limb_numb(size_t modnum)
407 {
408 return modnum * 2 + _mul_add_limb_numb(modnum);
409 }
410
/*
 * calculate a % mod, place result in ret
 * assumes that a is in Montgomery form with the R (Montgomery modulus) being
 * smallest power of two big enough to fit mod and that's also a power
 * of the count of number of bits in limb_t (B).
 * For calculation, we also need n', such that mod * n' == -1 mod B.
 * anum must be <= 2 * modnum
 * ret needs to be modnum words long
 * tmp needs to be mod_montgomery_limb_numb(modnum) limbs long
 */
static void mod_montgomery(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
                           size_t modnum, limb_t ni0, limb_t *tmp)
{
    limb_t carry, v;
    limb_t *res, *rp, *tmp2;
    ossl_ssize_t i;

    res = tmp;
    /*
     * for intermediate result we need an integer twice as long as modulus
     * but keep the input in the least significant limbs
     */
    memset(res, 0, sizeof(limb_t) * (modnum * 2));
    memcpy(&res[modnum * 2 - anum], a, sizeof(limb_t) * anum);
    /* rp tracks the modnum-limb window currently being reduced */
    rp = &res[modnum];
    /* tmp2 is the scratch area handed down to _mul_add_limb() */
    tmp2 = &res[modnum * 2];

    carry = 0;

    /* add multiples of the modulus to the value until R divides it cleanly */
    for (i = modnum; i > 0; i--, rp--) {
        /*
         * rp[modnum - 1] is the least significant limb of the window;
         * multiplying mod by (that limb * n') zeroes it after the addition
         */
        v = _mul_add_limb(rp, mod, modnum, rp[modnum - 1] * ni0, tmp2);
        /* fold the returned carry, the running carry and rp[-1] together */
        v = v + carry + rp[-1];
        /* branch-free detection of wrap-around in the addition above */
        carry |= (v != rp[-1]);
        carry &= (v <= rp[-1]);
        rp[-1] = v;
    }

    /* perform the final reduction by mod... */
    carry -= sub(ret, rp, mod, modnum);

    /* ...conditionally (constant-time choice between rp and rp - mod) */
    cselect(carry, ret, rp, ret, modnum);
}
455
/*
 * Convert a BIGNUM into an array of limbs stored most significant limb
 * first.  buf is allocated (and later freed) by the caller and must be at
 * least `limbs` limbs long; limbs that bn does not fill are left untouched,
 * so callers zero-allocate buf.
 *
 * NOTE(review): assumes limbs >= limbs needed for bn, and that bn->d stores
 * words least significant first -- confirm against callers/BIGNUM layout.
 */
static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs)
{
    int i;
    int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
    limb_t *ptr = buf + (limbs - real_limbs);

    /* reverse the word order while copying */
    for (i = 0; i < real_limbs; i++)
        ptr[i] = bn->d[real_limbs - i - 1];
}
466
467 #if LIMB_BYTE_SIZE == 8
be64(uint64_t host)468 static ossl_inline uint64_t be64(uint64_t host)
469 {
470 const union {
471 long one;
472 char little;
473 } is_endian = { 1 };
474
475 if (is_endian.little) {
476 uint64_t big = 0;
477
478 big |= (host & 0xff00000000000000) >> 56;
479 big |= (host & 0x00ff000000000000) >> 40;
480 big |= (host & 0x0000ff0000000000) >> 24;
481 big |= (host & 0x000000ff00000000) >> 8;
482 big |= (host & 0x00000000ff000000) << 8;
483 big |= (host & 0x0000000000ff0000) << 24;
484 big |= (host & 0x000000000000ff00) << 40;
485 big |= (host & 0x00000000000000ff) << 56;
486 return big;
487 } else {
488 return host;
489 }
490 }
491
492 #else
493 /* Not all platforms have htobe32(). */
be32(uint32_t host)494 static ossl_inline uint32_t be32(uint32_t host)
495 {
496 const union {
497 long one;
498 char little;
499 } is_endian = { 1 };
500
501 if (is_endian.little) {
502 uint32_t big = 0;
503
504 big |= (host & 0xff000000) >> 24;
505 big |= (host & 0x00ff0000) >> 8;
506 big |= (host & 0x0000ff00) << 8;
507 big |= (host & 0x000000ff) << 24;
508 return big;
509 } else {
510 return host;
511 }
512 }
513 #endif
514
/*
 * We assume that intermediate, possible_arg2, blinding, and ctx are used
 * similar to BN_BLINDING_invert_ex() arguments.
 * to_mod is RSA modulus.
 * buf and num is the serialization buffer and its length.
 *
 * Here we use classic/Montgomery multiplication and modulo. After the calculation finished
 * we serialize the new structure instead of BIGNUMs taking endianness into account.
 *
 * Returns num (bytes written to buf) on success, 0 on error.
 * All limb array sizes depend only on BN_num_bytes() of the operands, never
 * on their numeric values.  ctx is accepted for interface compatibility but
 * is not used by this implementation.
 */
int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate,
                           const BN_BLINDING *blinding,
                           const BIGNUM *possible_arg2,
                           const BIGNUM *to_mod, BN_CTX *ctx,
                           unsigned char *buf, int num)
{
    limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL;
    limb_t *l_ret = NULL, *l_tmp = NULL, l_buf;
    size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0;
    size_t l_tmp_count = 0;
    int ret = 0;
    size_t i;
    unsigned char *tmp;
    const BIGNUM *arg1 = intermediate;
    const BIGNUM *arg2 = (possible_arg2 == NULL) ? blinding->Ai : possible_arg2;

    /* number of limbs needed to hold each operand */
    l_im_count = (BN_num_bytes(arg1) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
    l_mul_count = (BN_num_bytes(arg2) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
    l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;

    /* both multiplicands use arrays of the same (maximum) size */
    l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count;
    l_im = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
    l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
    l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE);

    if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL))
        goto err;

    BN_to_limb(arg1, l_im, l_size);
    BN_to_limb(arg2, l_mul, l_size);
    BN_to_limb(to_mod, l_mod, l_mod_count);

    /* the product needs twice the operand size; NULL-checked below */
    l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE);

    if (blinding->m_ctx != NULL) {
        /* Montgomery path: scratch must fit both limb_mul() and mod_montgomery() */
        l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ?
                      mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count);
        l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
    } else {
        /* classic path: scratch must fit both limb_mul() and mod() */
        l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ?
                      mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count);
        l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
    }

    if ((l_ret == NULL) || (l_tmp == NULL))
        goto err;

    if (blinding->m_ctx != NULL) {
        /* input is in Montgomery form: multiply, then Montgomery-reduce */
        limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
        mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count,
                       blinding->m_ctx->n0[0], l_tmp);
    } else {
        limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
        mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp);
    }

    /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */
    if (num < BN_num_bytes(to_mod)) {
        BNerr(BN_F_OSSL_BN_RSA_DO_UNBLIND, ERR_R_PASSED_INVALID_ARGUMENT);
        goto err;
    }

    /* serialize the result big-endian, right-aligned in buf */
    memset(buf, 0, num);
    tmp = buf + num - BN_num_bytes(to_mod);
    for (i = 0; i < l_mod_count; i++) {
#if LIMB_BYTE_SIZE == 8
        l_buf = be64(l_ret[i]);
#else
        l_buf = be32(l_ret[i]);
#endif
        if (i == 0) {
            /*
             * the most significant limb may be only partially used.
             * NOTE(review): assumes num <= l_mod_count * LIMB_BYTE_SIZE so
             * that delta stays in [1, LIMB_BYTE_SIZE] -- confirm with callers.
             */
            int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num);

            memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta);
            tmp += delta;
        } else {
            memcpy(tmp, &l_buf, LIMB_BYTE_SIZE);
            tmp += LIMB_BYTE_SIZE;
        }
    }
    ret = num;

 err:
    OPENSSL_free(l_im);
    OPENSSL_free(l_mul);
    OPENSSL_free(l_mod);
    OPENSSL_free(l_tmp);
    OPENSSL_free(l_ret);

    return ret;
}
615