• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 1995-2022 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15 
16 #include "internal/cryptlib.h"
17 #include "crypto/bn.h"
18 #include "rsa_local.h"
19 #include "internal/constant_time.h"
20 
/*
 * Forward declarations of the callbacks that make up the built-in RSA
 * method. They are defined further down in this file and are referenced by
 * the rsa_pkcs1_ossl_meth table.
 */
static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                  unsigned char *to, RSA *rsa, int padding);
static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding);
static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
                                  unsigned char *to, RSA *rsa, int padding);
static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding);
/* CRT-based private exponentiation; see the definition for details. */
static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
                           BN_CTX *ctx);
/* Per-key setup and teardown hooks of the built-in method. */
static int rsa_ossl_init(RSA *rsa);
static int rsa_ossl_finish(RSA *rsa);
/*
 * The built-in RSA method table. Entries are positional, so their order
 * must match the RSA_METHOD structure declaration (not visible in this
 * file). RSA_PKCS1_OpenSSL() hands out a pointer to this table and it is
 * the initial value of default_RSA_meth.
 */
static RSA_METHOD rsa_pkcs1_ossl_meth = {
    "OpenSSL PKCS#1 RSA",
    rsa_ossl_public_encrypt,
    rsa_ossl_public_decrypt,     /* signature verification */
    rsa_ossl_private_encrypt,    /* signing */
    rsa_ossl_private_decrypt,
    rsa_ossl_mod_exp,
    BN_mod_exp_mont,            /* XXX probably we should not use Montgomery
                                 * if e == 3 */
    rsa_ossl_init,
    rsa_ossl_finish,
    RSA_FLAG_FIPS_METHOD,       /* flags */
    NULL,
    0,                          /* rsa_sign */
    0,                          /* rsa_verify */
    NULL,                       /* rsa_keygen */
    NULL                        /* rsa_multi_prime_keygen */
};
51 
/*
 * Current default RSA method: read by RSA_get_default_method() and replaced
 * by RSA_set_default_method(). Starts out as the built-in method.
 */
static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
53 
RSA_set_default_method(const RSA_METHOD * meth)54 void RSA_set_default_method(const RSA_METHOD *meth)
55 {
56     default_RSA_meth = meth;
57 }
58 
RSA_get_default_method(void)59 const RSA_METHOD *RSA_get_default_method(void)
60 {
61     return default_RSA_meth;
62 }
63 
RSA_PKCS1_OpenSSL(void)64 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
65 {
66     return &rsa_pkcs1_ossl_meth;
67 }
68 
RSA_null_method(void)69 const RSA_METHOD *RSA_null_method(void)
70 {
71     return NULL;
72 }
73 
/*
 * Public-key encryption callback of the built-in RSA method.
 *
 * The |flen| bytes at |from| are padded according to |padding| into a block
 * of BN_num_bytes(rsa->n) bytes, the block is raised to rsa->e modulo
 * rsa->n through rsa->meth->bn_mod_exp, and the result is written to |to|
 * with BN_bn2binpad() using that same length.
 *
 * Supported paddings: RSA_PKCS1_PADDING, RSA_PKCS1_OAEP_PADDING and
 * RSA_NO_PADDING.
 *
 * Returns the result of BN_bn2binpad() when every earlier step succeeded,
 * otherwise -1.
 */
static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                  unsigned char *to, RSA *rsa, int padding)
{
    BIGNUM *plain, *cipher;
    int padded, modlen = 0, outlen = -1;
    unsigned char *block = NULL;
    BN_CTX *bnctx = NULL;

    /* Key sanity checks; nothing has been allocated yet, so just return. */
    if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
        ERR_raise(ERR_LIB_RSA, RSA_R_MODULUS_TOO_LARGE);
        return -1;
    }

    if (BN_ucmp(rsa->n, rsa->e) <= 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
        return -1;
    }

    /* The exponent size limit is only enforced for large moduli. */
    if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS
        && BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
        ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
        return -1;
    }

    bnctx = BN_CTX_new_ex(rsa->libctx);
    if (bnctx == NULL)
        goto done;
    BN_CTX_start(bnctx);
    plain = BN_CTX_get(bnctx);
    cipher = BN_CTX_get(bnctx);
    modlen = BN_num_bytes(rsa->n);
    block = OPENSSL_malloc(modlen);
    /* Only the last BN_CTX_get() result is tested, as in the BN idiom. */
    if (cipher == NULL || block == NULL) {
        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
        goto done;
    }

    /* Build the padded block, one modulus length long. */
    if (padding == RSA_PKCS1_PADDING) {
        padded = ossl_rsa_padding_add_PKCS1_type_2_ex(rsa->libctx, block,
                                                      modlen, from, flen);
    } else if (padding == RSA_PKCS1_OAEP_PADDING) {
        padded = ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(rsa->libctx, block,
                                                         modlen, from, flen,
                                                         NULL, 0, NULL, NULL);
    } else if (padding == RSA_NO_PADDING) {
        padded = RSA_padding_add_none(block, modlen, from, flen);
    } else {
        ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
        goto done;
    }
    if (padded <= 0)
        goto done;

    if (BN_bin2bn(block, modlen, plain) == NULL)
        goto done;

    /* The padding routines normally guarantee this already. */
    if (BN_ucmp(plain, rsa->n) >= 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
        goto done;
    }

    /* Make sure the cached Montgomery context for n exists, if requested. */
    if ((rsa->flags & RSA_FLAG_CACHE_PUBLIC) != 0
        && !BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
                                   rsa->n, bnctx))
        goto done;

    if (!rsa->meth->bn_mod_exp(cipher, plain, rsa->e, rsa->n, bnctx,
                               rsa->_method_mod_n))
        goto done;

    /*
     * BN_bn2binpad left-pads with zero bytes when the number is shorter
     * than the modulus.
     */
    outlen = BN_bn2binpad(cipher, to, modlen);

 done:
    BN_CTX_end(bnctx);
    BN_CTX_free(bnctx);
    /* The block held padded plaintext: wipe it before releasing it. */
    OPENSSL_clear_free(block, modlen);
    return outlen;
}
161 
rsa_get_blinding(RSA * rsa,int * local,BN_CTX * ctx)162 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
163 {
164     BN_BLINDING *ret;
165 
166     if (!CRYPTO_THREAD_write_lock(rsa->lock))
167         return NULL;
168 
169     if (rsa->blinding == NULL) {
170         rsa->blinding = RSA_setup_blinding(rsa, ctx);
171     }
172 
173     ret = rsa->blinding;
174     if (ret == NULL)
175         goto err;
176 
177     if (BN_BLINDING_is_current_thread(ret)) {
178         /* rsa->blinding is ours! */
179 
180         *local = 1;
181     } else {
182         /* resort to rsa->mt_blinding instead */
183 
184         /*
185          * instructs rsa_blinding_convert(), rsa_blinding_invert() that the
186          * BN_BLINDING is shared, meaning that accesses require locks, and
187          * that the blinding factor must be stored outside the BN_BLINDING
188          */
189         *local = 0;
190 
191         if (rsa->mt_blinding == NULL) {
192             rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
193         }
194         ret = rsa->mt_blinding;
195     }
196 
197  err:
198     CRYPTO_THREAD_unlock(rsa->lock);
199     return ret;
200 }
201 
/*
 * Blind |f| with |b|.
 *
 * A NULL |unblind| means local blinding: the unblinding factor stays inside
 * the BN_BLINDING and no lock is taken. A non-NULL |unblind| means the
 * blinding is shared: the conversion runs under the BN_BLINDING lock and
 * the unblinding factor is stored in |unblind|.
 *
 * Returns the result of BN_BLINDING_convert_ex(), or 0 if the lock could
 * not be acquired.
 */
static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
                                BN_CTX *ctx)
{
    int ok;

    /* Local blinding: nothing to lock, factor kept in the BN_BLINDING. */
    if (unblind == NULL)
        return BN_BLINDING_convert_ex(f, NULL, b, ctx);

    /* Shared blinding: serialize access and export the factor. */
    if (!BN_BLINDING_lock(b))
        return 0;

    ok = BN_BLINDING_convert_ex(f, unblind, b, ctx);
    BN_BLINDING_unlock(b);

    return ok;
}
225 
/*
 * Remove the blinding from |f|.
 *
 * With local blinding |unblind| is NULL and BN_BLINDING_invert_ex takes the
 * unblinding factor from the BN_BLINDING itself. With a BN_BLINDING shared
 * between threads |unblind| must be non-NULL: BN_BLINDING_invert_ex then
 * uses that caller-held factor and only reads the modulus from the
 * BN_BLINDING. Either way no lock is needed here.
 */
static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
                               BN_CTX *ctx)
{
    int inverted = BN_BLINDING_invert_ex(f, unblind, b, ctx);

    return inverted;
}
239 
240 /* signing */
rsa_ossl_private_encrypt(int flen,const unsigned char * from,unsigned char * to,RSA * rsa,int padding)241 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
242                                    unsigned char *to, RSA *rsa, int padding)
243 {
244     BIGNUM *f, *ret, *res;
245     int i, num = 0, r = -1;
246     unsigned char *buf = NULL;
247     BN_CTX *ctx = NULL;
248     int local_blinding = 0;
249     /*
250      * Used only if the blinding structure is shared. A non-NULL unblind
251      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
252      * the unblinding factor outside the blinding structure.
253      */
254     BIGNUM *unblind = NULL;
255     BN_BLINDING *blinding = NULL;
256 
257     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
258         goto err;
259     BN_CTX_start(ctx);
260     f = BN_CTX_get(ctx);
261     ret = BN_CTX_get(ctx);
262     num = BN_num_bytes(rsa->n);
263     buf = OPENSSL_malloc(num);
264     if (ret == NULL || buf == NULL) {
265         ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
266         goto err;
267     }
268 
269     switch (padding) {
270     case RSA_PKCS1_PADDING:
271         i = RSA_padding_add_PKCS1_type_1(buf, num, from, flen);
272         break;
273     case RSA_X931_PADDING:
274         i = RSA_padding_add_X931(buf, num, from, flen);
275         break;
276     case RSA_NO_PADDING:
277         i = RSA_padding_add_none(buf, num, from, flen);
278         break;
279     default:
280         ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
281         goto err;
282     }
283     if (i <= 0)
284         goto err;
285 
286     if (BN_bin2bn(buf, num, f) == NULL)
287         goto err;
288 
289     if (BN_ucmp(f, rsa->n) >= 0) {
290         /* usually the padding functions would catch this */
291         ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
292         goto err;
293     }
294 
295     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
296         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
297                                     rsa->n, ctx))
298             goto err;
299 
300     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
301         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
302         if (blinding == NULL) {
303             ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
304             goto err;
305         }
306     }
307 
308     if (blinding != NULL) {
309         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
310             ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
311             goto err;
312         }
313         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
314             goto err;
315     }
316 
317     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
318         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
319         ((rsa->p != NULL) &&
320          (rsa->q != NULL) &&
321          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
322         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
323             goto err;
324     } else {
325         BIGNUM *d = BN_new();
326         if (d == NULL) {
327             ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
328             goto err;
329         }
330         if (rsa->d == NULL) {
331             ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
332             BN_free(d);
333             goto err;
334         }
335         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
336 
337         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
338                                    rsa->_method_mod_n)) {
339             BN_free(d);
340             goto err;
341         }
342         /* We MUST free d before any further use of rsa->d */
343         BN_free(d);
344     }
345 
346     if (blinding)
347         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
348             goto err;
349 
350     if (padding == RSA_X931_PADDING) {
351         if (!BN_sub(f, rsa->n, ret))
352             goto err;
353         if (BN_cmp(ret, f) > 0)
354             res = f;
355         else
356             res = ret;
357     } else {
358         res = ret;
359     }
360 
361     /*
362      * BN_bn2binpad puts in leading 0 bytes if the number is less than
363      * the length of the modulus.
364      */
365     r = BN_bn2binpad(res, to, num);
366  err:
367     BN_CTX_end(ctx);
368     BN_CTX_free(ctx);
369     OPENSSL_clear_free(buf, num);
370     return r;
371 }
372 
/*
 * Private-key decryption callback of the built-in RSA method.
 *
 * Converts the |flen| bytes at |from| to a number, applies the private-key
 * operation (blinded unless RSA_FLAG_NO_BLINDING is set), serializes the
 * result into a scratch buffer of BN_num_bytes(rsa->n) bytes and finally
 * strips the padding selected by |padding| (RSA_PKCS1_PADDING,
 * RSA_PKCS1_OAEP_PADDING or RSA_NO_PADDING) into |to|.
 *
 * Returns the value produced by the padding check (or the serialized length
 * for RSA_NO_PADDING); -1 if any step before the padding check fails.
 */
static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding)
{
    BIGNUM *f, *ret;
    int j, num = 0, r = -1;
    unsigned char *buf = NULL;
    BN_CTX *ctx = NULL;
    int local_blinding = 0;
    /*
     * Used only if the blinding structure is shared. A non-NULL unblind
     * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
     * the unblinding factor outside the blinding structure.
     */
    BIGNUM *unblind = NULL;
    BN_BLINDING *blinding = NULL;

    if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
        goto err;
    BN_CTX_start(ctx);
    f = BN_CTX_get(ctx);
    ret = BN_CTX_get(ctx);
    num = BN_num_bytes(rsa->n);
    /* Scratch buffer for the still-padded plaintext, one modulus long. */
    buf = OPENSSL_malloc(num);
    if (ret == NULL || buf == NULL) {
        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
        goto err;
    }

    /*
     * This check was for equality but PGP does evil things and chops off the
     * top '0' bytes
     */
    if (flen > num) {
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_GREATER_THAN_MOD_LEN);
        goto err;
    }

    /* make data into a big number */
    if (BN_bin2bn(from, (int)flen, f) == NULL)
        goto err;

    /* The ciphertext must be a value strictly below the modulus. */
    if (BN_ucmp(f, rsa->n) >= 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
        goto err;
    }

    /* Blind the input unless the key explicitly opts out. */
    if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
        blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
        if (blinding == NULL) {
            ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
            goto err;
        }
    }

    if (blinding != NULL) {
        /* A shared blinding needs caller-side storage for the factor. */
        if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
            ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
            goto err;
        }
        if (!rsa_blinding_convert(blinding, f, unblind, ctx))
            goto err;
    }

    /* do the decrypt */
    /*
     * The method's rsa_mod_exp is used for externally held keys, multi-prime
     * keys and keys with a full set of CRT parameters; otherwise a plain
     * exponentiation with d is performed.
     */
    if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
        (rsa->version == RSA_ASN1_VERSION_MULTI) ||
        ((rsa->p != NULL) &&
         (rsa->q != NULL) &&
         (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
        if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
            goto err;
    } else {
        BIGNUM *d = BN_new();
        if (d == NULL) {
            ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
            goto err;
        }
        if (rsa->d == NULL) {
            ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
            BN_free(d);
            goto err;
        }
        /* |d| is a flagged view of rsa->d requesting constant-time code. */
        BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);

        if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
            if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
                                        rsa->n, ctx)) {
                BN_free(d);
                goto err;
            }
        if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
                                   rsa->_method_mod_n)) {
            BN_free(d);
            goto err;
        }
        /* We MUST free d before any further use of rsa->d */
        BN_free(d);
    }

    if (blinding) {
        /*
         * ossl_bn_rsa_do_unblind() combines blinding inversion and
         * 0-padded BN BE serialization
         */
        j = ossl_bn_rsa_do_unblind(ret, blinding, unblind, rsa->n, ctx,
                                   buf, num);
        if (j == 0)
            goto err;
    } else {
        j = BN_bn2binpad(ret, buf, num);
        if (j < 0)
            goto err;
    }

    /* Strip the padding; |r| receives the padding checker's verdict. */
    switch (padding) {
    case RSA_PKCS1_PADDING:
        r = RSA_padding_check_PKCS1_type_2(to, num, buf, j, num);
        break;
    case RSA_PKCS1_OAEP_PADDING:
        r = RSA_padding_check_PKCS1_OAEP(to, num, buf, j, num, NULL, 0);
        break;
    case RSA_NO_PADDING:
        /* No padding: hand back all |j| serialized bytes unchanged. */
        memcpy(to, buf, (r = j));
        break;
    default:
        ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
        goto err;
    }
#ifndef FIPS_MODULE
    /*
     * This trick doesn't work in the FIPS provider because libcrypto manages
     * the error stack. Instead we opt not to put an error on the stack at all
     * in case of padding failure in the FIPS provider.
     */
    /*
     * The error is raised unconditionally and then cleared again without
     * branching on |r|. NOTE(review): assuming constant_time_msb() yields an
     * all-ones mask when the top bit of |r| is set, the argument below is 1
     * (clear) for non-negative |r| and 0 (keep) for negative |r| - confirm
     * against internal/constant_time.h.
     */
    ERR_raise(ERR_LIB_RSA, RSA_R_PADDING_CHECK_FAILED);
    err_clear_last_constant_time(1 & ~constant_time_msb(r));
#endif

 err:
    BN_CTX_end(ctx);
    BN_CTX_free(ctx);
    /* The buffer held decrypted data: wipe it before releasing it. */
    OPENSSL_clear_free(buf, num);
    return r;
}
517 
518 /* signature verification */
rsa_ossl_public_decrypt(int flen,const unsigned char * from,unsigned char * to,RSA * rsa,int padding)519 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
520                                   unsigned char *to, RSA *rsa, int padding)
521 {
522     BIGNUM *f, *ret;
523     int i, num = 0, r = -1;
524     unsigned char *buf = NULL;
525     BN_CTX *ctx = NULL;
526 
527     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
528         ERR_raise(ERR_LIB_RSA, RSA_R_MODULUS_TOO_LARGE);
529         return -1;
530     }
531 
532     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
533         ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
534         return -1;
535     }
536 
537     /* for large moduli, enforce exponent limit */
538     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
539         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
540             ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
541             return -1;
542         }
543     }
544 
545     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
546         goto err;
547     BN_CTX_start(ctx);
548     f = BN_CTX_get(ctx);
549     ret = BN_CTX_get(ctx);
550     num = BN_num_bytes(rsa->n);
551     buf = OPENSSL_malloc(num);
552     if (ret == NULL || buf == NULL) {
553         ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
554         goto err;
555     }
556 
557     /*
558      * This check was for equality but PGP does evil things and chops off the
559      * top '0' bytes
560      */
561     if (flen > num) {
562         ERR_raise(ERR_LIB_RSA, RSA_R_DATA_GREATER_THAN_MOD_LEN);
563         goto err;
564     }
565 
566     if (BN_bin2bn(from, flen, f) == NULL)
567         goto err;
568 
569     if (BN_ucmp(f, rsa->n) >= 0) {
570         ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
571         goto err;
572     }
573 
574     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
575         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
576                                     rsa->n, ctx))
577             goto err;
578 
579     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
580                                rsa->_method_mod_n))
581         goto err;
582 
583     if ((padding == RSA_X931_PADDING) && ((bn_get_words(ret)[0] & 0xf) != 12))
584         if (!BN_sub(ret, rsa->n, ret))
585             goto err;
586 
587     i = BN_bn2binpad(ret, buf, num);
588     if (i < 0)
589         goto err;
590 
591     switch (padding) {
592     case RSA_PKCS1_PADDING:
593         r = RSA_padding_check_PKCS1_type_1(to, num, buf, i, num);
594         break;
595     case RSA_X931_PADDING:
596         r = RSA_padding_check_X931(to, num, buf, i, num);
597         break;
598     case RSA_NO_PADDING:
599         memcpy(to, buf, (r = i));
600         break;
601     default:
602         ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
603         goto err;
604     }
605     if (r < 0)
606         ERR_raise(ERR_LIB_RSA, RSA_R_PADDING_CHECK_FAILED);
607 
608  err:
609     BN_CTX_end(ctx);
610     BN_CTX_free(ctx);
611     OPENSSL_clear_free(buf, num);
612     return r;
613 }
614 
/*
 * Private-key exponentiation of the built-in RSA method, reached through
 * rsa->meth->rsa_mod_exp.
 *
 * Computes |r0| from |I| using the CRT components of |rsa| (p, q, dmp1,
 * dmq1, iqmp and, outside the FIPS module, the extra primes listed in
 * rsa->prime_infos for multi-prime keys). |ctx| supplies the temporaries.
 *
 * A fixed-top Montgomery fast path ("smooth") is taken when
 * RSA_FLAG_CACHE_PRIVATE is set, the method's bn_mod_exp is
 * BN_mod_exp_mont, there are no extra primes and p and q have the same bit
 * length; otherwise the generic CRT recombination is used.
 *
 * When rsa->e and rsa->n are available the result is verified by raising it
 * to e modulo n and comparing with |I|. If the two are not congruent the
 * CRT output is discarded and replaced by a plain exponentiation with
 * rsa->d; if rsa->d is missing in that situation the call fails rather than
 * dereferencing a NULL exponent.
 *
 * Returns 1 on success and 0 on failure.
 */
static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
{
    BIGNUM *r1, *m1, *vrfy;
    int ret = 0, smooth = 0;
#ifndef FIPS_MODULE
    BIGNUM *r2, *m[RSA_MAX_PRIME_NUM - 2];
    int i, ex_primes = 0;
    RSA_PRIME_INFO *pinfo;
#endif

    BN_CTX_start(ctx);

    r1 = BN_CTX_get(ctx);
#ifndef FIPS_MODULE
    r2 = BN_CTX_get(ctx);
#endif
    m1 = BN_CTX_get(ctx);
    vrfy = BN_CTX_get(ctx);
    /* Only the last BN_CTX_get() result is tested, as in the BN idiom. */
    if (vrfy == NULL)
        goto err;

#ifndef FIPS_MODULE
    /* A multi-prime key must carry between 1 and RSA_MAX_PRIME_NUM - 2
     * extra primes, which also bounds the indices used for m[]. */
    if (rsa->version == RSA_ASN1_VERSION_MULTI
        && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
             || ex_primes > RSA_MAX_PRIME_NUM - 2))
        goto err;
#endif

    if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
        BIGNUM *factor = BN_new();

        if (factor == NULL)
            goto err;

        /*
         * Make sure BN_mod_inverse in Montgomery initialization uses the
         * BN_FLG_CONSTTIME flag
         */
        if (!(BN_with_flags(factor, rsa->p, BN_FLG_CONSTTIME),
              BN_MONT_CTX_set_locked(&rsa->_method_mod_p, rsa->lock,
                                     factor, ctx))
            || !(BN_with_flags(factor, rsa->q, BN_FLG_CONSTTIME),
                 BN_MONT_CTX_set_locked(&rsa->_method_mod_q, rsa->lock,
                                        factor, ctx))) {
            BN_free(factor);
            goto err;
        }
#ifndef FIPS_MODULE
        for (i = 0; i < ex_primes; i++) {
            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
            BN_with_flags(factor, pinfo->r, BN_FLG_CONSTTIME);
            if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, factor, ctx)) {
                BN_free(factor);
                goto err;
            }
        }
#endif
        /*
         * We MUST free |factor| before any further use of the prime factors
         */
        BN_free(factor);

        smooth = (rsa->meth->bn_mod_exp == BN_mod_exp_mont)
#ifndef FIPS_MODULE
                 && (ex_primes == 0)
#endif
                 && (BN_num_bits(rsa->q) == BN_num_bits(rsa->p));
    }

    if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
        if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
                                    rsa->n, ctx))
            goto err;

    if (smooth) {
        /*
         * Conversion from Montgomery domain, a.k.a. Montgomery reduction,
         * accepts values in [0-m*2^w) range. w is m's bit width rounded up
         * to limb width. So that at the very least if |I| is fully reduced,
         * i.e. less than p*q, we can count on from-to round to perform
         * below modulo operations on |I|. Unlike BN_mod it's constant time.
         */
        if (/* m1 = I mod q */
            !bn_from_mont_fixed_top(m1, I, rsa->_method_mod_q, ctx)
            || !bn_to_mont_fixed_top(m1, m1, rsa->_method_mod_q, ctx)
            /* r1 = I mod p */
            || !bn_from_mont_fixed_top(r1, I, rsa->_method_mod_p, ctx)
            || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
            /*
             * Use parallel exponentiations optimization if possible,
             * otherwise fallback to two sequential exponentiations:
             *    m1 = m1^dmq1 mod q
             *    r1 = r1^dmp1 mod p
             */
            || !BN_mod_exp_mont_consttime_x2(m1, m1, rsa->dmq1, rsa->q,
                                             rsa->_method_mod_q,
                                             r1, r1, rsa->dmp1, rsa->p,
                                             rsa->_method_mod_p,
                                             ctx)
            /* r1 = (r1 - m1) mod p */
            /*
             * bn_mod_sub_fixed_top is not regular modular subtraction,
             * it can tolerate subtrahend to be larger than modulus, but
             * not bit-wise wider. This makes up for uncommon q>p case,
             * when |m1| can be larger than |rsa->p|.
             */
            || !bn_mod_sub_fixed_top(r1, r1, m1, rsa->p)

            /* r1 = r1 * iqmp mod p */
            || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
            || !bn_mul_mont_fixed_top(r1, r1, rsa->iqmp, rsa->_method_mod_p,
                                      ctx)
            /* r0 = r1 * q + m1 */
            || !bn_mul_fixed_top(r0, r1, rsa->q, ctx)
            || !bn_mod_add_fixed_top(r0, r0, m1, rsa->n))
            goto err;

        goto tail;
    }

    /* compute I mod q */
    {
        BIGNUM *c = BN_new();
        if (c == NULL)
            goto err;
        BN_with_flags(c, I, BN_FLG_CONSTTIME);

        if (!BN_mod(r1, c, rsa->q, ctx)) {
            BN_free(c);
            goto err;
        }

        {
            BIGNUM *dmq1 = BN_new();
            if (dmq1 == NULL) {
                BN_free(c);
                goto err;
            }
            BN_with_flags(dmq1, rsa->dmq1, BN_FLG_CONSTTIME);

            /* compute r1^dmq1 mod q */
            if (!rsa->meth->bn_mod_exp(m1, r1, dmq1, rsa->q, ctx,
                                       rsa->_method_mod_q)) {
                BN_free(c);
                BN_free(dmq1);
                goto err;
            }
            /* We MUST free dmq1 before any further use of rsa->dmq1 */
            BN_free(dmq1);
        }

        /* compute I mod p */
        if (!BN_mod(r1, c, rsa->p, ctx)) {
            BN_free(c);
            goto err;
        }
        /* We MUST free c before any further use of I */
        BN_free(c);
    }

    {
        BIGNUM *dmp1 = BN_new();
        if (dmp1 == NULL)
            goto err;
        BN_with_flags(dmp1, rsa->dmp1, BN_FLG_CONSTTIME);

        /* compute r1^dmp1 mod p */
        if (!rsa->meth->bn_mod_exp(r0, r1, dmp1, rsa->p, ctx,
                                   rsa->_method_mod_p)) {
            BN_free(dmp1);
            goto err;
        }
        /* We MUST free dmp1 before any further use of rsa->dmp1 */
        BN_free(dmp1);
    }

#ifndef FIPS_MODULE
    if (ex_primes > 0) {
        BIGNUM *di = BN_new(), *cc = BN_new();

        if (cc == NULL || di == NULL) {
            BN_free(cc);
            BN_free(di);
            goto err;
        }

        for (i = 0; i < ex_primes; i++) {
            /* prepare m_i */
            if ((m[i] = BN_CTX_get(ctx)) == NULL) {
                BN_free(cc);
                BN_free(di);
                goto err;
            }

            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);

            /* prepare c and d_i */
            BN_with_flags(cc, I, BN_FLG_CONSTTIME);
            BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);

            if (!BN_mod(r1, cc, pinfo->r, ctx)) {
                BN_free(cc);
                BN_free(di);
                goto err;
            }
            /* compute r1 ^ d_i mod r_i */
            if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
                BN_free(cc);
                BN_free(di);
                goto err;
            }
        }

        BN_free(cc);
        BN_free(di);
    }
#endif

    if (!BN_sub(r0, r0, m1))
        goto err;
    /*
     * This will help stop the size of r0 increasing, which does affect the
     * multiply if it optimised for a power of 2 size
     */
    if (BN_is_negative(r0))
        if (!BN_add(r0, r0, rsa->p))
            goto err;

    if (!BN_mul(r1, r0, rsa->iqmp, ctx))
        goto err;

    {
        BIGNUM *pr1 = BN_new();
        if (pr1 == NULL)
            goto err;
        BN_with_flags(pr1, r1, BN_FLG_CONSTTIME);

        if (!BN_mod(r0, pr1, rsa->p, ctx)) {
            BN_free(pr1);
            goto err;
        }
        /* We MUST free pr1 before any further use of r1 */
        BN_free(pr1);
    }

    /*
     * If p < q it is occasionally possible for the correction of adding 'p'
     * if r0 is negative above to leave the result still negative. This can
     * break the private key operations: the following second correction
     * should *always* correct this rare occurrence. This will *never* happen
     * with OpenSSL generated keys because they ensure p > q [steve]
     */
    if (BN_is_negative(r0))
        if (!BN_add(r0, r0, rsa->p))
            goto err;
    if (!BN_mul(r1, r0, rsa->q, ctx))
        goto err;
    if (!BN_add(r0, r1, m1))
        goto err;

#ifndef FIPS_MODULE
    /* add m_i to m in multi-prime case */
    if (ex_primes > 0) {
        BIGNUM *pr2 = BN_new();

        if (pr2 == NULL)
            goto err;

        for (i = 0; i < ex_primes; i++) {
            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
            if (!BN_sub(r1, m[i], r0)) {
                BN_free(pr2);
                goto err;
            }

            if (!BN_mul(r2, r1, pinfo->t, ctx)) {
                BN_free(pr2);
                goto err;
            }

            BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);

            if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
                BN_free(pr2);
                goto err;
            }

            if (BN_is_negative(r1))
                if (!BN_add(r1, r1, pinfo->r)) {
                    BN_free(pr2);
                    goto err;
                }
            if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
                BN_free(pr2);
                goto err;
            }
            if (!BN_add(r0, r0, r1)) {
                BN_free(pr2);
                goto err;
            }
        }
        BN_free(pr2);
    }
#endif

 tail:
    if (rsa->e && rsa->n) {
        if (rsa->meth->bn_mod_exp == BN_mod_exp_mont) {
            if (!BN_mod_exp_mont(vrfy, r0, rsa->e, rsa->n, ctx,
                                 rsa->_method_mod_n))
                goto err;
        } else {
            bn_correct_top(r0);
            if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
                                       rsa->_method_mod_n))
                goto err;
        }
        /*
         * If 'I' was greater than (or equal to) rsa->n, the operation will
         * be equivalent to using 'I mod n'. However, the result of the
         * verify will *always* be less than 'n' so we don't check for
         * absolute equality, just congruency.
         */
        if (!BN_sub(vrfy, vrfy, I))
            goto err;
        if (BN_is_zero(vrfy)) {
            bn_correct_top(r0);
            ret = 1;
            goto err;   /* not actually error */
        }
        if (!BN_mod(vrfy, vrfy, rsa->n, ctx))
            goto err;
        if (BN_is_negative(vrfy))
            if (!BN_add(vrfy, vrfy, rsa->n))
                goto err;
        if (!BN_is_zero(vrfy)) {
            /*
             * 'I' and 'vrfy' aren't congruent mod n. Don't leak
             * miscalculated CRT output, just do a raw (slower) mod_exp and
             * return that instead.
             */
            BIGNUM *d;

            /*
             * A key may carry CRT parameters without the private exponent.
             * The raw fallback is impossible then, so fail instead of
             * handing a NULL exponent to BN_with_flags().
             */
            if (rsa->d == NULL) {
                ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
                goto err;
            }

            d = BN_new();
            if (d == NULL)
                goto err;
            BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);

            if (!rsa->meth->bn_mod_exp(r0, I, d, rsa->n, ctx,
                                       rsa->_method_mod_n)) {
                BN_free(d);
                goto err;
            }
            /* We MUST free d before any further use of rsa->d */
            BN_free(d);
        }
    }
    /*
     * It's unfortunate that we have to bn_correct_top(r0). What hopefully
     * saves the day is that correction is highly unlikely, and private key
     * operations are customarily performed on blinded message. Which means
     * that attacker won't observe correlation with chosen plaintext.
     * Secondly, remaining code would still handle it in same computational
     * time and even conceal memory access pattern around corrected top.
     */
    bn_correct_top(r0);
    ret = 1;
 err:
    BN_CTX_end(ctx);
    return ret;
}
985 
/*
 * Init hook of the built-in RSA method: turns on caching of both the public
 * and the private Montgomery contexts for |rsa|. Always returns 1.
 */
static int rsa_ossl_init(RSA *rsa)
{
    rsa->flags |= RSA_FLAG_CACHE_PUBLIC;
    rsa->flags |= RSA_FLAG_CACHE_PRIVATE;

    return 1;
}
991 
/*
 * Finish hook of the built-in RSA method: releases the Montgomery contexts
 * cached on |rsa| (for n, p, q and, outside the FIPS module, for each extra
 * prime in rsa->prime_infos). Always returns 1.
 *
 * Every pointer is reset to NULL after it is freed. Other functions in this
 * file pass the addresses of these same slots to BN_MONT_CTX_set_locked(),
 * so a stale value left behind would be a dangling pointer if the key is
 * used again after this hook ran (or if the hook runs twice).
 */
static int rsa_ossl_finish(RSA *rsa)
{
#ifndef FIPS_MODULE
    int i;
    RSA_PRIME_INFO *pinfo;

    for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
        pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
        BN_MONT_CTX_free(pinfo->m);
        pinfo->m = NULL;
    }
#endif

    BN_MONT_CTX_free(rsa->_method_mod_n);
    rsa->_method_mod_n = NULL;
    BN_MONT_CTX_free(rsa->_method_mod_p);
    rsa->_method_mod_p = NULL;
    BN_MONT_CTX_free(rsa->_method_mod_q);
    rsa->_method_mod_q = NULL;
    return 1;
}
1009