• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <openssl/e_os2.h>
2 #include <stddef.h>
3 #include <sys/types.h>
4 #include <string.h>
5 #include <openssl/bn.h>
6 #include <openssl/err.h>
7 #include <openssl/rsaerr.h>
8 #include "internal/numbers.h"
9 #include "internal/constant_time.h"
10 #include "bn_local.h"
11 
/*
 * limb_t mirrors the BIGNUM word size (BN_BYTES).  Where the toolchain
 * provides a double-width integer type, limb2_t is defined and
 * HAVE_LIMB2_T set so _mul_limb() can use a single native multiply.
 */
# if BN_BYTES == 8
typedef uint64_t limb_t;
#  if defined(__SIZEOF_INT128__) && __SIZEOF_INT128__ == 16
/* nonstandard; implemented by gcc on 64-bit platforms */
typedef __uint128_t limb2_t;
#   define HAVE_LIMB2_T
#  endif
#  define LIMB_BIT_SIZE 64
#  define LIMB_BYTE_SIZE 8
# elif BN_BYTES == 4
typedef uint32_t limb_t;
typedef uint64_t limb2_t;
#  define LIMB_BIT_SIZE 32
#  define LIMB_BYTE_SIZE 4
#  define HAVE_LIMB2_T
# else
#  error "Not supported"
# endif
30 
31 /*
32  * For multiplication we're using schoolbook multiplication,
33  * so if we have two numbers, each with 6 "digits" (words)
34  * the multiplication is calculated as follows:
35  *                        A B C D E F
36  *                     x  I J K L M N
37  *                     --------------
38  *                                N*F
39  *                              N*E
40  *                            N*D
41  *                          N*C
42  *                        N*B
43  *                      N*A
44  *                              M*F
45  *                            M*E
46  *                          M*D
47  *                        M*C
48  *                      M*B
49  *                    M*A
50  *                            L*F
51  *                          L*E
52  *                        L*D
53  *                      L*C
54  *                    L*B
55  *                  L*A
56  *                          K*F
57  *                        K*E
58  *                      K*D
59  *                    K*C
60  *                  K*B
61  *                K*A
62  *                        J*F
63  *                      J*E
64  *                    J*D
65  *                  J*C
66  *                J*B
67  *              J*A
68  *                      I*F
69  *                    I*E
70  *                  I*D
71  *                I*C
72  *              I*B
73  *         +  I*A
74  *         ==========================
75  *                        N*B N*D N*F
76  *                    + N*A N*C N*E
77  *                    + M*B M*D M*F
78  *                  + M*A M*C M*E
79  *                  + L*B L*D L*F
80  *                + L*A L*C L*E
81  *                + K*B K*D K*F
82  *              + K*A K*C K*E
83  *              + J*B J*D J*F
84  *            + J*A J*C J*E
85  *            + I*B I*D I*F
86  *          + I*A I*C I*E
87  *
88  *                1+1 1+3 1+5
89  *              1+0 1+2 1+4
90  *              0+1 0+3 0+5
91  *            0+0 0+2 0+4
92  *
93  *            0 1 2 3 4 5 6
94  * which requires n^2 multiplications and 2n full length additions
95  * as we can keep every other result of limb multiplication in two separate
96  * limbs
97  */
98 
#if defined HAVE_LIMB2_T
/*
 * Full-width multiply: compute the double-limb product a * b and split it
 * into its high (*hi) and low (*lo) limbs.
 */
static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
{
    limb2_t t;
    /*
     * this is idiomatic code to tell compiler to use the native mul
     * those three lines will actually compile to single instruction
     */

    t = (limb2_t)a * b;
    *hi = t >> LIMB_BIT_SIZE;
    *lo = (limb_t)t;
}
#elif (BN_BYTES == 8) && (defined _MSC_VER)
/* https://learn.microsoft.com/en-us/cpp/intrinsics/umul128?view=msvc-170 */
#pragma intrinsic(_umul128)
/* 64-bit MSVC has no __uint128_t, but _umul128 returns both product halves */
static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
{
    *lo = _umul128(a, b, hi);
}
#else
/*
 * if the compiler doesn't have either a 128bit data type nor a "return
 * high 64 bits of multiplication"
 */
static ossl_inline void _mul_limb(limb_t *hi, limb_t *lo, limb_t a, limb_t b)
{
    /* portable fallback: split both operands into 32-bit halves */
    limb_t a_low = (limb_t)(uint32_t)a;
    limb_t a_hi = a >> 32;
    limb_t b_low = (limb_t)(uint32_t)b;
    limb_t b_hi = b >> 32;

    /* four 32x32 -> 64 partial products */
    limb_t p0 = a_low * b_low;
    limb_t p1 = a_low * b_hi;
    limb_t p2 = a_hi * b_low;
    limb_t p3 = a_hi * b_hi;

    /* carry out of the middle 32-bit column of the low half */
    uint32_t cy = (uint32_t)(((p0 >> 32) + (uint32_t)p1 + (uint32_t)p2) >> 32);

    *lo = p0 + (p1 << 32) + (p2 << 32);
    *hi = p3 + (p1 >> 32) + (p2 >> 32) + cy;
}
#endif
142 
143 /* add two limbs with carry in, return carry out */
_add_limb(limb_t * ret,limb_t a,limb_t b,limb_t carry)144 static ossl_inline limb_t _add_limb(limb_t *ret, limb_t a, limb_t b, limb_t carry)
145 {
146     limb_t carry1, carry2, t;
147     /*
148      * `c = a + b; if (c < a)` is idiomatic code that makes compilers
149      * use add with carry on assembly level
150      */
151 
152     *ret = a + carry;
153     if (*ret < a)
154         carry1 = 1;
155     else
156         carry1 = 0;
157 
158     t = *ret;
159     *ret = t + b;
160     if (*ret < t)
161         carry2 = 1;
162     else
163         carry2 = 0;
164 
165     return carry1 + carry2;
166 }
167 
168 /*
169  * add two numbers of the same size, return overflow
170  *
171  * add a to b, place result in ret; all arrays need to be n limbs long
172  * return overflow from addition (0 or 1)
173  */
add(limb_t * ret,limb_t * a,limb_t * b,size_t n)174 static ossl_inline limb_t add(limb_t *ret, limb_t *a, limb_t *b, size_t n)
175 {
176     limb_t c = 0;
177     ossl_ssize_t i;
178 
179     for(i = n - 1; i > -1; i--)
180         c = _add_limb(&ret[i], a[i], b[i], c);
181 
182     return c;
183 }
184 
185 /*
186  * return number of limbs necessary for temporary values
187  * when multiplying numbers n limbs large
188  */
mul_limb_numb(size_t n)189 static ossl_inline size_t mul_limb_numb(size_t n)
190 {
191     return  2 * n * 2;
192 }
193 
/*
 * multiply two numbers of the same size
 *
 * multiply a by b, place result in ret; a and b need to be n limbs long
 * ret needs to be 2*n limbs long, tmp needs to be mul_limb_numb(n) limbs
 * long
 */
static void limb_mul(limb_t *ret, limb_t *a, limb_t *b, size_t n, limb_t *tmp)
{
    limb_t *r_odd, *r_even;
    size_t i, j, k;

    /* split the scratch space into two 2*n-limb arrays */
    r_odd = tmp;
    r_even = &tmp[2 * n];

    memset(ret, 0, 2 * n * sizeof(limb_t));

    for (i = 0; i < n; i++) {
        /* clear only the limbs this row of partial products can touch */
        for (k = 0; k < i + n + 1; k++) {
            r_even[k] = 0;
            r_odd[k] = 0;
        }
        for (j = 0; j < n; j++) {
            /*
             * place results from even and odd limbs in separate arrays so that
             * we don't have to calculate overflow every time we get individual
             * limb multiplication result
             */
            if (j % 2 == 0)
                _mul_limb(&r_even[i + j], &r_even[i + j + 1], a[i], b[j]);
            else
                _mul_limb(&r_odd[i + j], &r_odd[i + j + 1], a[i], b[j]);
        }
        /*
         * skip the least significant limbs when adding multiples of
         * more significant limbs (they're zero anyway)
         */
        add(ret, ret, r_even, n + i + 1);
        add(ret, ret, r_odd, n + i + 1);
    }
}
235 
236 /* modifies the value in place by performing a right shift by one bit */
rshift1(limb_t * val,size_t n)237 static ossl_inline void rshift1(limb_t *val, size_t n)
238 {
239     limb_t shift_in = 0, shift_out = 0;
240     size_t i;
241 
242     for (i = 0; i < n; i++) {
243         shift_out = val[i] & 1;
244         val[i] = shift_in << (LIMB_BIT_SIZE - 1) | (val[i] >> 1);
245         shift_in = shift_out;
246     }
247 }
248 
249 /* extend the LSB of flag to all bits of limb */
mk_mask(limb_t flag)250 static ossl_inline limb_t mk_mask(limb_t flag)
251 {
252     flag |= flag << 1;
253     flag |= flag << 2;
254     flag |= flag << 4;
255     flag |= flag << 8;
256     flag |= flag << 16;
257 #if (LIMB_BYTE_SIZE == 8)
258     flag |= flag << 32;
259 #endif
260     return flag;
261 }
262 
/*
 * copy from either a or b to ret based on flag
 * when flag == 0, then copies from b
 * when flag == 1, then copies from a
 */
static ossl_inline void cselect(limb_t flag, limb_t *ret, limb_t *a, limb_t *b, size_t n)
{
    /*
     * would be more efficient with non volatile mask, but then gcc
     * generates code with jumps
     */
    volatile limb_t mask;
    size_t i;

    /* mask is all-ones for flag == 1 and all-zeros for flag == 0 */
    mask = mk_mask(flag);
    for (i = 0; i < n; i++) {
#if (LIMB_BYTE_SIZE == 8)
        ret[i] = constant_time_select_64(mask, a[i], b[i]);
#else
        ret[i] = constant_time_select_32(mask, a[i], b[i]);
#endif
    }
}
286 
/* subtract b and the incoming borrow from a, return the outgoing borrow */
static limb_t _sub_limb(limb_t *ret, limb_t a, limb_t b, limb_t borrow)
{
    limb_t diff, b_lo, b_hi;

    /*
     * while it doesn't look constant-time, this is idiomatic code
     * to tell compilers to use the carry bit from subtraction
     */
    diff = a - borrow;
    b_lo = (diff > a) ? 1 : 0;

    *ret = diff - b;
    b_hi = (*ret > diff) ? 1 : 0;

    return b_lo + b_hi;
}
310 
311 /*
312  * place the result of a - b into ret, return the borrow bit.
313  * All arrays need to be n limbs long
314  */
sub(limb_t * ret,limb_t * a,limb_t * b,size_t n)315 static limb_t sub(limb_t *ret, limb_t *a, limb_t *b, size_t n)
316 {
317     limb_t borrow = 0;
318     ossl_ssize_t i;
319 
320     for (i = n - 1; i > -1; i--)
321         borrow = _sub_limb(&ret[i], a[i], b[i], borrow);
322 
323     return borrow;
324 }
325 
326 /* return the number of limbs necessary to allocate for the mod() tmp operand */
mod_limb_numb(size_t anum,size_t modnum)327 static ossl_inline size_t mod_limb_numb(size_t anum, size_t modnum)
328 {
329     return (anum + modnum) * 3;
330 }
331 
/*
 * calculate a % mod, place the result in ret
 * size of a is defined by anum, size of ret and mod is modnum,
 * size of tmp is returned by mod_limb_numb()
 *
 * Constant-time binary reduction: the modulus is placed above a, then
 * shifted right one bit per iteration with a conditional subtraction,
 * so the work done is independent of the values involved.
 */
static void mod(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
               size_t modnum, limb_t *tmp)
{
    limb_t *atmp, *modtmp, *rettmp;
    limb_t res;
    size_t i;

    memset(tmp, 0, mod_limb_numb(anum, modnum) * LIMB_BYTE_SIZE);

    /* three (anum + modnum)-limb work areas carved out of tmp */
    atmp = tmp;
    modtmp = &tmp[anum + modnum];
    rettmp = &tmp[(anum + modnum) * 2];

    /* a goes into the least significant limbs (limbs are MSB-first) */
    for (i = modnum; i <modnum + anum; i++)
        atmp[i] = a[i-modnum];

    /* mod goes into the most significant limbs, i.e. pre-shifted left
     * by anum * LIMB_BIT_SIZE bits relative to a */
    for (i = 0; i < modnum; i++)
        modtmp[i] = mod[i];

    for (i = 0; i < anum * LIMB_BIT_SIZE; i++) {
        rshift1(modtmp, anum + modnum);
        res = sub(rettmp, atmp, modtmp, anum+modnum);
        /* res == 1 means the subtraction borrowed: keep the old value */
        cselect(res, atmp, atmp, rettmp, anum+modnum);
    }

    /* the remainder now sits in the modnum least significant limbs */
    memcpy(ret, &atmp[anum], sizeof(limb_t) * modnum);
}
364 
365 /* necessary size of tmp for a _mul_add_limb() call with provided anum */
_mul_add_limb_numb(size_t anum)366 static ossl_inline size_t _mul_add_limb_numb(size_t anum)
367 {
368     return 2 * (anum + 1);
369 }
370 
/*
 * multiply a by m, add to ret, return carry
 *
 * a is anum limbs long, ret at least anum limbs long; tmp must be
 * _mul_add_limb_numb(anum) limbs long
 */
static limb_t _mul_add_limb(limb_t *ret, limb_t *a, size_t anum,
                           limb_t m, limb_t *tmp)
{
    limb_t carry = 0;
    limb_t *r_odd, *r_even;
    size_t i;

    memset(tmp, 0, sizeof(limb_t) * (anum + 1) * 2);

    /* two (anum + 1)-limb scratch arrays for the partial products */
    r_odd = tmp;
    r_even = &tmp[anum + 1];

    for (i = 0; i < anum; i++) {
        /*
         * place the results from even and odd limbs in separate arrays
         * so that we have to worry about carry just once
         */
        if (i % 2 == 0)
            _mul_limb(&r_even[i], &r_even[i + 1], a[i], m);
        else
            _mul_limb(&r_odd[i], &r_odd[i + 1], a[i], m);
    }
    /* assert: add() carry here will be equal zero */
    add(r_even, r_even, r_odd, anum + 1);
    /*
     * while here it will not overflow as the max value from multiplication
     * is -2 while max overflow from addition is 1, so the max value of
     * carry is -1 (i.e. max int)
     */
    carry = add(ret, ret, &r_even[1], anum) + r_even[0];

    return carry;
}
405 
mod_montgomery_limb_numb(size_t modnum)406 static ossl_inline size_t mod_montgomery_limb_numb(size_t modnum)
407 {
408     return modnum * 2 + _mul_add_limb_numb(modnum);
409 }
410 
/*
 * calculate a % mod, place result in ret
 * assumes that a is in Montgomery form with the R (Montgomery modulus) being
 * smallest power of two big enough to fit mod and that's also a power
 * of the count of number of bits in limb_t (B).
 * For calculation, we also need n', such that mod * n' == -1 mod B.
 * anum must be <= 2 * modnum
 * ret needs to be modnum words long
 * tmp needs to be mod_montgomery_limb_numb(modnum) limbs long
 */
static void mod_montgomery(limb_t *ret, limb_t *a, size_t anum, limb_t *mod,
                          size_t modnum, limb_t ni0, limb_t *tmp)
{
    limb_t carry, v;
    limb_t *res, *rp, *tmp2;
    ossl_ssize_t i;

    res = tmp;
    /*
     * for intermediate result we need an integer twice as long as modulus
     * but keep the input in the least significant limbs
     */
    memset(res, 0, sizeof(limb_t) * (modnum * 2));
    memcpy(&res[modnum * 2 - anum], a, sizeof(limb_t) * anum);
    /* rp walks down (towards MSB) through res; tmp2 is helper scratch */
    rp = &res[modnum];
    tmp2 = &res[modnum * 2];

    carry = 0;

    /* add multiples of the modulus to the value until R divides it cleanly */
    for (i = modnum; i > 0; i--, rp--) {
        /* rp[modnum - 1] * ni0 makes the lowest limb of the sum zero */
        v = _mul_add_limb(rp, mod, modnum, rp[modnum - 1] * ni0, tmp2);
        /*
         * constant-time carry propagation into rp[-1], the limb just
         * above the current window: carry tracks overflow of the add
         */
        v = v + carry + rp[-1];
        carry |= (v != rp[-1]);
        carry &= (v <= rp[-1]);
        rp[-1] = v;
    }

    /* perform the final reduction by mod... */
    carry -= sub(ret, rp, mod, modnum);

    /* ...conditionally */
    cselect(carry, ret, rp, ret, modnum);
}
455 
/*
 * Convert bn into a fixed-width, most-significant-limb-first limb array.
 * buf must be `limbs` limbs long; leading limbs that bn does not fill are
 * left untouched (the callers use OPENSSL_zalloc, so they stay zero).
 * NOTE(review): assumes a BIGNUM word (bn->d element) and limb_t have the
 * same width — both derive from BN_BYTES, see typedefs at the top of file.
 */
static void BN_to_limb(const BIGNUM *bn, limb_t *buf, size_t limbs)
{
    int i;
    /* number of limbs actually occupied by bn */
    int real_limbs = (BN_num_bytes(bn) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
    limb_t *ptr = buf + (limbs - real_limbs);

    /* bn->d is least-significant-word first; reverse into MSB-first order */
    for (i = 0; i < real_limbs; i++)
         ptr[i] = bn->d[real_limbs - i - 1];
}
466 
467 #if LIMB_BYTE_SIZE == 8
be64(uint64_t host)468 static ossl_inline uint64_t be64(uint64_t host)
469 {
470     const union {
471         long one;
472         char little;
473     } is_endian = { 1 };
474 
475     if (is_endian.little) {
476         uint64_t big = 0;
477 
478         big |= (host & 0xff00000000000000) >> 56;
479         big |= (host & 0x00ff000000000000) >> 40;
480         big |= (host & 0x0000ff0000000000) >> 24;
481         big |= (host & 0x000000ff00000000) >>  8;
482         big |= (host & 0x00000000ff000000) <<  8;
483         big |= (host & 0x0000000000ff0000) << 24;
484         big |= (host & 0x000000000000ff00) << 40;
485         big |= (host & 0x00000000000000ff) << 56;
486         return big;
487     } else {
488         return host;
489     }
490 }
491 
492 #else
493 /* Not all platforms have htobe32(). */
be32(uint32_t host)494 static ossl_inline uint32_t be32(uint32_t host)
495 {
496     const union {
497         long one;
498         char little;
499     } is_endian = { 1 };
500 
501     if (is_endian.little) {
502         uint32_t big = 0;
503 
504         big |= (host & 0xff000000) >> 24;
505         big |= (host & 0x00ff0000) >> 8;
506         big |= (host & 0x0000ff00) << 8;
507         big |= (host & 0x000000ff) << 24;
508         return big;
509     } else {
510         return host;
511     }
512 }
513 #endif
514 
/*
 * We assume that intermediate, possible_arg2, blinding, and ctx are used
 * similar to BN_BLINDING_invert_ex() arguments.
 * to_mod is RSA modulus.
 * buf and num is the serialization buffer and its length.
 *
 * Here we use classic/Montgomery multiplication and modulo. After the calculation finished
 * we serialize the new structure instead of BIGNUMs taking endianness into account.
 *
 * Returns num (bytes written to buf) on success, 0 on error.
 */
int ossl_bn_rsa_do_unblind(const BIGNUM *intermediate,
                           const BN_BLINDING *blinding,
                           const BIGNUM *possible_arg2,
                           const BIGNUM *to_mod, BN_CTX *ctx,
                           unsigned char *buf, int num)
{
    limb_t *l_im = NULL, *l_mul = NULL, *l_mod = NULL;
    limb_t *l_ret = NULL, *l_tmp = NULL, l_buf;
    size_t l_im_count = 0, l_mul_count = 0, l_size = 0, l_mod_count = 0;
    size_t l_tmp_count = 0;
    int ret = 0;
    size_t i;
    unsigned char *tmp;
    const BIGNUM *arg1 = intermediate;
    const BIGNUM *arg2 = (possible_arg2 == NULL) ? blinding->Ai : possible_arg2;

    /* limb counts of the two factors and of the modulus */
    l_im_count  = (BN_num_bytes(arg1)   + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
    l_mul_count = (BN_num_bytes(arg2)   + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;
    l_mod_count = (BN_num_bytes(to_mod) + LIMB_BYTE_SIZE - 1) / LIMB_BYTE_SIZE;

    /* both factors are zero-padded to a common width for limb_mul() */
    l_size = l_im_count > l_mul_count ? l_im_count : l_mul_count;
    l_im  = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
    l_mul = OPENSSL_zalloc(l_size * LIMB_BYTE_SIZE);
    l_mod = OPENSSL_zalloc(l_mod_count * LIMB_BYTE_SIZE);

    if ((l_im == NULL) || (l_mul == NULL) || (l_mod == NULL))
        goto err;

    BN_to_limb(arg1,   l_im,  l_size);
    BN_to_limb(arg2,   l_mul, l_size);
    BN_to_limb(to_mod, l_mod, l_mod_count);

    /* the product needs twice the limbs of its (padded) operands */
    l_ret = OPENSSL_malloc(2 * l_size * LIMB_BYTE_SIZE);

    if (blinding->m_ctx != NULL) {
        /* scratch must fit both the multiplication and Montgomery reduction */
        l_tmp_count = mul_limb_numb(l_size) > mod_montgomery_limb_numb(l_mod_count) ?
                      mul_limb_numb(l_size) : mod_montgomery_limb_numb(l_mod_count);
        l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
    } else {
        /* scratch must fit both the multiplication and classic reduction */
        l_tmp_count = mul_limb_numb(l_size) > mod_limb_numb(2 * l_size, l_mod_count) ?
                      mul_limb_numb(l_size) : mod_limb_numb(2 * l_size, l_mod_count);
        l_tmp = OPENSSL_malloc(l_tmp_count * LIMB_BYTE_SIZE);
    }

    if ((l_ret == NULL) || (l_tmp == NULL))
        goto err;

    if (blinding->m_ctx != NULL) {
        /* Montgomery path: n0[0] is the n' from mod_montgomery()'s contract */
        limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
        mod_montgomery(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count,
                       blinding->m_ctx->n0[0], l_tmp);
    } else {
        limb_mul(l_ret, l_im, l_mul, l_size, l_tmp);
        mod(l_ret, l_ret, 2 * l_size, l_mod, l_mod_count, l_tmp);
    }

    /* modulus size in bytes can be equal to num but after limbs conversion it becomes bigger */
    if (num < BN_num_bytes(to_mod)) {
        BNerr(BN_F_OSSL_BN_RSA_DO_UNBLIND, ERR_R_PASSED_INVALID_ARGUMENT);
        goto err;
    }

    memset(buf, 0, num);
    /* write the result right-aligned in buf; leading bytes stay zero */
    tmp = buf + num - BN_num_bytes(to_mod);
    for (i = 0; i < l_mod_count; i++) {
        /* each limb is serialized big-endian */
#if LIMB_BYTE_SIZE == 8
        l_buf = be64(l_ret[i]);
#else
        l_buf = be32(l_ret[i]);
#endif
        if (i == 0) {
            /*
             * first limb may be only partially occupied; copy just its
             * trailing bytes.
             * NOTE(review): assumes num <= l_mod_count * LIMB_BYTE_SIZE
             * (holds when num equals the modulus byte size) — confirm callers
             */
            int delta = LIMB_BYTE_SIZE - ((l_mod_count * LIMB_BYTE_SIZE) - num);

            memcpy(tmp, ((char *)&l_buf) + LIMB_BYTE_SIZE - delta, delta);
            tmp += delta;
        } else {
            memcpy(tmp, &l_buf, LIMB_BYTE_SIZE);
            tmp += LIMB_BYTE_SIZE;
        }
    }
    ret = num;

 err:
    OPENSSL_free(l_im);
    OPENSSL_free(l_mul);
    OPENSSL_free(l_mod);
    OPENSSL_free(l_tmp);
    OPENSSL_free(l_ret);

    return ret;
}
615