• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_RSA
18 
19 #include "crypt_utils.h"
20 #include "crypt_rsa.h"
21 #include "rsa_local.h"
22 #include "crypt_errno.h"
23 #include "crypt_eal_md.h"
24 #include "bsl_sal.h"
25 #include "securec.h"
26 #include "bsl_err_internal.h"
27 #include "eal_pkey_local.h"
28 #include "eal_md_local.h"
29 #include "bsl_params.h"
30 #include "crypt_params_key.h"
31 
32 // rsa-decrypt Calculation used by Chinese Remainder Theorem(CRT). intermediate variables:
33 typedef struct {
34     BN_BigNum *cP;
35     BN_BigNum *cQ;
36     BN_BigNum *mP;
37     BN_BigNum *mQ;
38     BN_Mont *montP;
39     BN_Mont *montQ;
40 } RsaDecProcedurePara;
41 
InputRangeCheck(const BN_BigNum * input,const BN_BigNum * n,uint32_t bits)42 static int32_t InputRangeCheck(const BN_BigNum *input, const BN_BigNum *n, uint32_t bits)
43 {
44     // The value range defined in RFC is [0, n - 1]. Because the operation result of 0, 1, n - 1 is relatively fixed,
45     // it is considered invalid here. The actual valid value range is [2, n - 2].
46     int32_t ret;
47     BN_BigNum *nMinusOne = NULL;
48     if (BN_IsZero(input) == true || BN_IsOne(input) == true) {
49         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
50         return CRYPT_RSA_ERR_INPUT_VALUE;
51     }
52     /* Allocate 8 extra bits to prevent calculation errors due to the feature of BigNum calculation. */
53     nMinusOne = BN_Create(bits);
54     if (nMinusOne == NULL) {
55         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
56         return CRYPT_MEM_ALLOC_FAIL;
57     }
58     ret = BN_SubLimb(nMinusOne, n, 1);
59     if (ret != CRYPT_SUCCESS) {
60         BSL_ERR_PUSH_ERROR(ret);
61         BN_Destroy(nMinusOne);
62         return ret;
63     }
64     if (BN_Cmp(input, nMinusOne) >= 0) {
65         ret = CRYPT_RSA_ERR_INPUT_VALUE;
66         BSL_ERR_PUSH_ERROR(ret);
67     }
68     BN_Destroy(nMinusOne);
69     return ret;
70 }
71 
AddZero(uint32_t bits,uint8_t * out,uint32_t * outLen)72 static int32_t AddZero(uint32_t bits, uint8_t *out, uint32_t *outLen)
73 {
74     uint32_t i;
75     uint32_t zeros = 0;
76     uint32_t needBytes = BN_BITS_TO_BYTES(bits);
77     /* Divide bits by 8 to obtain the byte length. If it is smaller than the key length, pad it with 0. */
78     if ((*outLen) < needBytes) {
79         /* Divide bits by 8 to obtain the byte length. If it is smaller than the key length, pad it with 0. */
80         zeros = needBytes - (*outLen);
81         if (memmove_s(out + zeros, needBytes - zeros, out, (*outLen)) != EOK) {
82             BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
83             return CRYPT_SECUREC_FAIL;
84         }
85         for (i = 0; i < zeros; i++) {
86             out[i] = 0x0;
87         }
88     }
89     *outLen = needBytes;
90     return CRYPT_SUCCESS;
91 }
92 
ResultToOut(uint32_t bits,const BN_BigNum * result,uint8_t * out,uint32_t * outLen)93 static int32_t ResultToOut(uint32_t bits, const BN_BigNum *result, uint8_t *out, uint32_t *outLen)
94 {
95     int32_t ret = BN_Bn2Bin(result, out, outLen);
96     if (ret != CRYPT_SUCCESS) {
97         BSL_ERR_PUSH_ERROR(ret);
98         return ret;
99     }
100     return AddZero(bits, out, outLen);
101 }
102 
/*
 * Create the output big number (bits + 1 bits) and the input big number (bits bits), then
 * load the big-endian input into *inputBN.
 * Rejects input longer than the modulus before allocating anything. Whatever this function
 * allocates is stored in *result / *inputBN even on failure, so the caller must destroy both.
 */
static int32_t AllocResultAndInputBN(uint32_t bits, BN_BigNum **result, BN_BigNum **inputBN,
    const uint8_t *input, uint32_t inputLen)
{
    const uint32_t maxLen = BN_BITS_TO_BYTES(bits);
    if (inputLen > maxLen) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
        return CRYPT_RSA_ERR_INPUT_VALUE;
    }

    *result = BN_Create(bits + 1);
    *inputBN = BN_Create(bits);
    if (*result != NULL && *inputBN != NULL) {
        return BN_Bin2Bn(*inputBN, input, inputLen);
    }
    BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
    return CRYPT_MEM_ALLOC_FAIL;
}
118 
/*
 * Non-CRT private-key exponentiation: result = input^d mod n.
 * A temporary Montgomery context for n is created and destroyed here. consttime selects
 * BN_MontExpConsttime instead of BN_MontExp.
 * Returns CRYPT_RSA_NO_KEY_INFO when n or d is zero; exponentiation errors are returned
 * unpushed.
 */
static int32_t CalcMontExp(CRYPT_RSA_PrvKey *prvKey,
    BN_BigNum *result, const BN_BigNum *input, BN_Optimizer *opt, bool consttime)
{
    if (BN_IsZero(prvKey->n) || BN_IsZero(prvKey->d)) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
        return CRYPT_RSA_NO_KEY_INFO;
    }

    BN_Mont *montN = BN_MontCreate(prvKey->n);
    if (montN == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    int32_t ret = consttime ? BN_MontExpConsttime(result, input, prvKey->d, montN, opt)
                            : BN_MontExp(result, input, prvKey->d, montN, opt);
    BN_MontDestroy(montN);
    return ret;
}
142 
143 #if defined(HITLS_CRYPTO_RSA_ENCRYPT) || defined(HITLS_CRYPTO_RSA_VERIFY) || defined(HITLS_CRYPTO_RSA_SIGN)
CRYPT_RSA_PubEnc(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,uint8_t * out,uint32_t * outLen)144 int32_t  CRYPT_RSA_PubEnc(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
145     uint8_t *out, uint32_t *outLen)
146 {
147     int32_t ret;
148     BN_BigNum *inputBN = NULL;
149     BN_BigNum *result = NULL;
150     if (ctx == NULL || input == NULL || out == NULL || outLen == NULL) {
151         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
152         return CRYPT_NULL_INPUT;
153     }
154     CRYPT_RSA_PubKey *pubKey = ctx->pubKey;
155     if (pubKey == NULL) {
156         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
157         return CRYPT_RSA_NO_KEY_INFO;
158     }
159     uint32_t bits = CRYPT_RSA_GetBits(ctx);
160     if ((*outLen) < BN_BITS_TO_BYTES(bits)) {
161         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
162         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
163     }
164     BN_Optimizer *optimizer = BN_OptimizerCreate();
165     if (optimizer == NULL) {
166         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
167         return CRYPT_MEM_ALLOC_FAIL;
168     }
169     GOTO_ERR_IF_EX(AllocResultAndInputBN(bits, &result, &inputBN, input, inputLen), ret);
170     GOTO_ERR_IF_EX(InputRangeCheck(inputBN, pubKey->n, bits), ret);
171 
172     // pubKey->mont: Ensure that this value is not empty when the public key is set or generated.
173     GOTO_ERR_IF(BN_MontExp(result, inputBN, pubKey->e, pubKey->mont, optimizer), ret);
174     ret = ResultToOut(bits, result, out, outLen);
175 ERR:
176     BN_Destroy(result);
177     BN_Destroy(inputBN);
178     BN_OptimizerDestroy(optimizer);
179     return ret;
180 }
181 #endif
182 
183 /* Release intermediate variables. */
RsaDecProcedureFree(RsaDecProcedurePara * para)184 static void RsaDecProcedureFree(RsaDecProcedurePara *para)
185 {
186     if (para == NULL) {
187         return;
188     }
189     BN_Destroy(para->cP);
190     BN_Destroy(para->cQ);
191     BN_Destroy(para->mP);
192     BN_Destroy(para->mQ);
193     BN_MontDestroy(para->montP);
194     BN_MontDestroy(para->montQ);
195 }
196 
197 /* Apply for intermediate variables. */
RsaDecProcedureAlloc(RsaDecProcedurePara * para,uint32_t bits,const CRYPT_RSA_PrvKey * priKey)198 static int32_t RsaDecProcedureAlloc(RsaDecProcedurePara *para, uint32_t bits, const CRYPT_RSA_PrvKey *priKey)
199 {
200     para->cP = BN_Create(bits);
201     para->cQ = BN_Create(bits);
202     para->mP = BN_Create(bits);
203     para->mQ = BN_Create(bits);
204     para->montP = BN_MontCreate(priKey->p);
205     para->montQ = BN_MontCreate(priKey->q);
206     if (para->cP == NULL || para->cQ == NULL ||
207         para->mP == NULL || para->mQ == NULL || para->montP == NULL || para->montQ == NULL) {
208         RsaDecProcedureFree(para);
209         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
210         return CRYPT_MEM_ALLOC_FAIL;
211     }
212     return CRYPT_SUCCESS;
213 }
214 
215 /* rsa decryption calculation by CRT. Message is the BigNum converted from the original input ciphertext. */
NormalDecProcedure(const CRYPT_RSA_Ctx * ctx,const BN_BigNum * message,BN_BigNum * result,BN_Optimizer * opt)216 static int32_t NormalDecProcedure(const CRYPT_RSA_Ctx *ctx, const BN_BigNum *message, BN_BigNum *result,
217     BN_Optimizer *opt)
218 {
219     CRYPT_RSA_PrvKey *priKey = ctx->prvKey;
220     uint32_t bits = CRYPT_RSA_GetBits(ctx);
221     RsaDecProcedurePara procedure = {0}; // Temporary variable
222     /* Apply for temporary variable */
223     int32_t ret = RsaDecProcedureAlloc(&procedure, bits, priKey);
224     if (ret != CRYPT_SUCCESS) {
225         return ret;
226     }
227     /* cP = M mod P where inp = M = Message */
228     ret = BN_Mod(procedure.cP, message, priKey->p, opt);
229     if (ret != CRYPT_SUCCESS) {
230         goto EXIT;
231     }
232     /* cQ = M mod Q where inp = M = Message */
233     ret = BN_Mod(procedure.cQ, message, priKey->q, opt);
234     if (ret != CRYPT_SUCCESS) {
235         goto EXIT;
236     }
237     /* mP = cP^dP mod p */
238     ret = BN_MontExpConsttime(procedure.mP, procedure.cP, priKey->dP, procedure.montP, opt);
239     if (ret != CRYPT_SUCCESS) {
240         goto EXIT;
241     }
242     /* mQ = cQ^dQ mod q */
243     ret = BN_MontExpConsttime(procedure.mQ, procedure.cQ, priKey->dQ, procedure.montQ, opt);
244     if (ret != CRYPT_SUCCESS) {
245         goto EXIT;
246     }
247     /* result = (mP - mQ) mod p */
248     ret = BN_ModSub(result, procedure.mP, procedure.mQ, priKey->p, opt);
249     if (ret != CRYPT_SUCCESS) {
250         goto EXIT;
251     }
252     /* result = result * qInv mod p */
253     ret = MontMulCore(result, result, priKey->qInv, procedure.montP, opt);
254     if (ret != CRYPT_SUCCESS) {
255         goto EXIT;
256     }
257     /* result = result * q */
258     ret = BN_Mul(result, result, priKey->q, opt);
259     if (ret != CRYPT_SUCCESS) {
260         goto EXIT;
261     }
262     /* result = result + mQ */
263     ret = BN_Add(result, result, procedure.mQ);
264 EXIT:
265     RsaDecProcedureFree(&procedure);
266     return ret;
267 }
268 
269 #ifdef HITLS_CRYPTO_RSA_BLINDING
/* r1 = p - 1 and r2 = q - 1; the first failure is pushed and returned. */
static int32_t RSA_GetSub(const BN_BigNum *p, const BN_BigNum *q, BN_BigNum *r1, BN_BigNum *r2)
{
    int32_t ret = BN_SubLimb(r1, p, 1);
    if (ret == CRYPT_SUCCESS) {
        ret = BN_SubLimb(r2, q, 1);
    }
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
283 
/*
 * l = r1 * r2 / gcd(r1, r2), i.e. lcm(r1, r2); u receives gcd(r1, r2).
 * The first failing step's error is pushed and returned.
 */
static int32_t RSA_GetL(BN_BigNum *l, BN_BigNum *u, BN_BigNum *r1, BN_BigNum *r2, BN_Optimizer *opt)
{
    int32_t ret = BN_Mul(l, r1, r2, opt);
    if (ret == CRYPT_SUCCESS) {
        ret = BN_Gcd(u, r1, r2, opt);
    }
    if (ret == CRYPT_SUCCESS) {
        ret = BN_Div(l, NULL, l, u, opt);
    }
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
304 
/*
 * Recover the public exponent from private-key material: e = d^-1 mod lcm(p - 1, q - 1).
 * bits is the modulus size and sizes the temporaries.
 * Returns a newly created big number owned by the caller, or NULL on any failure (the cause
 * is pushed to the error stack).
 * NOTE(review): r1/r2 are sized at half the modulus bits; confirm BN_SubLimb grows its output
 * if a prime is longer than that (unbalanced keys).
 */
static BN_BigNum *RSA_GetPublicExp(const BN_BigNum *d, const BN_BigNum *p,
    const BN_BigNum *q, uint32_t bits, BN_Optimizer *opt)
{
    int32_t ret;
    /* l = lcm(p - 1, q - 1), r1 = p - 1, r2 = q - 1, u = gcd(r1, r2) */
    BN_BigNum *l = BN_Create(bits);
    BN_BigNum *r1 = BN_Create(bits >> 1);
    BN_BigNum *r2 = BN_Create(bits >> 1);
    BN_BigNum *u = BN_Create(bits + 1);
    BN_BigNum *e = BN_Create(bits);

    if (l == NULL || r1 == NULL || r2 == NULL || u == NULL || e == NULL) {
        /* An allocation failed: report it as such, not as a NULL argument. */
        ret = CRYPT_MEM_ALLOC_FAIL;
        BSL_ERR_PUSH_ERROR(ret);
        goto EXIT;
    }

    /* RSA_GetSub and RSA_GetL push their own errors, so none are added here. */
    ret = RSA_GetSub(p, q, r1, r2);
    if (ret != CRYPT_SUCCESS) {
        goto EXIT;
    }
    ret = RSA_GetL(l, u, r1, r2, opt);
    if (ret != CRYPT_SUCCESS) {
        goto EXIT;
    }

    /* e = d^-1 mod l */
    ret = BN_ModInv(e, d, l, opt);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
EXIT:
    BN_Destroy(r1);
    BN_Destroy(r2);
    BN_Destroy(l);
    BN_Destroy(u);
    if (ret != CRYPT_SUCCESS) {
        BN_Destroy(e);
        e = NULL;
    }
    return e;
}
349 
/*
 * Lazily create the side-channel blinding context ctx->scBlind for private-key operations.
 * If the private key carries no public exponent (NULL or zero), e is recomputed from d, p and q
 * and released afterwards.
 * ctx->scBlind is only installed once RSA_BlindCreateParam has succeeded. On any failure it
 * stays NULL, so a later call retries initialisation instead of reusing half-built parameters.
 */
static int32_t RSA_InitBlind(CRYPT_RSA_Ctx *ctx, BN_Optimizer *opt)
{
    uint32_t bits = BN_Bits(ctx->prvKey->n);
    bool needDestroyE = false;
    BN_BigNum *e = ctx->prvKey->e;
    if (e == NULL || BN_IsZero(e)) {
        e = RSA_GetPublicExp(ctx->prvKey->d, ctx->prvKey->p, ctx->prvKey->q, bits, opt);
        if (e == NULL) {
            BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_E_VALUE);
            return CRYPT_RSA_ERR_E_VALUE;
        }
        needDestroyE = true;
    }

    int32_t ret;
    RSA_Blind *blind = RSA_BlindNewCtx();
    if (blind == NULL) {
        ret = CRYPT_MEM_ALLOC_FAIL;
        BSL_ERR_PUSH_ERROR(ret);
    } else {
        ret = RSA_BlindCreateParam(ctx->libCtx, blind, e, ctx->prvKey->n, bits, opt);
        if (ret == CRYPT_SUCCESS) {
            ctx->scBlind = blind;
        } else {
            /* Do not leave an unusable context installed: RSA_BlindProcess would skip init next time. */
            RSA_BlindFreeCtx(blind);
        }
    }

    if (needDestroyE) {
        BN_Destroy(e);
    }
    return ret;
}
372 
/*
 * Blind message with ctx->scBlind before a private-key operation, first creating the
 * blinding context if it does not exist yet. Returns the initialisation error, or
 * RSA_BlindCovert's result.
 */
static int32_t RSA_BlindProcess(CRYPT_RSA_Ctx *ctx, BN_BigNum *message, BN_Optimizer *opt)
{
    int32_t ret = (ctx->scBlind == NULL) ? RSA_InitBlind(ctx, opt) : CRYPT_SUCCESS;
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }
    return RSA_BlindCovert(ctx->scBlind, message, ctx->prvKey->n, opt);
}
385 #endif
386 
/*
 * Prepare a private-key operation in three steps:
 *   1. allocate *result and *message;
 *   2. load the big-endian input into *message;
 *   3. check that it lies in [2, n - 2].
 * On failure both big numbers are destroyed and *result / *message are reset to NULL, so the
 * caller owns nothing and is never left holding dangling pointers.
 */
static int32_t RSA_AllocAndCheck(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
    BN_BigNum **result, BN_BigNum **message)
{
    if (ctx->prvKey == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
        return CRYPT_RSA_NO_KEY_INFO;
    }

    uint32_t bits = CRYPT_RSA_GetBits(ctx);
    int32_t ret = AllocResultAndInputBN(bits, result, message, input, inputLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    } else {
        ret = InputRangeCheck(*message, ctx->prvKey->n, bits); /* pushes its own error */
        if (ret == CRYPT_SUCCESS) {
            return CRYPT_SUCCESS;
        }
    }

    BN_Destroy(*result);
    BN_Destroy(*message);
    *result = NULL;
    *message = NULL;
    return ret;
}
413 
/*
 * Private-key exponentiation of message into result.
 * With HITLS_CRYPTO_RSA_BLINDING built in and the CRYPT_RSA_BLINDING flag set, message is
 * blinded before the exponentiation and result is unblinded afterwards. Blinding may create
 * ctx->scBlind on first use, hence the cast away from const.
 * A zero prime p selects the plain (non-CRT) path with constant-time exponentiation;
 * otherwise the CRT path is used. Any failure is pushed and returned.
 */
static int32_t RSA_PrvProcess(const CRYPT_RSA_Ctx *ctx, BN_BigNum *message, BN_BigNum *result, BN_Optimizer *opt)
{
    int32_t ret;
#ifdef HITLS_CRYPTO_RSA_BLINDING
    if ((ctx->flags & CRYPT_RSA_BLINDING) != 0) {
        ret = RSA_BlindProcess((CRYPT_RSA_Ctx *)(uintptr_t)ctx, message, opt);
        if (ret != CRYPT_SUCCESS) {
            BSL_ERR_PUSH_ERROR(ret);
            return ret;
        }
    }
#else
    (void)opt;
#endif

    ret = BN_IsZero(ctx->prvKey->p) ? CalcMontExp(ctx->prvKey, result, message, opt, true)
                                    : NormalDecProcedure(ctx, message, result, opt);

#ifdef HITLS_CRYPTO_RSA_BLINDING
    if (ret == CRYPT_SUCCESS && (ctx->flags & CRYPT_RSA_BLINDING) != 0) {
        ret = RSA_BlindInvert(ctx->scBlind, result, ctx->prvKey->n, opt);
    }
#endif
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
453 
CRYPT_RSA_PrvDec(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,uint8_t * out,uint32_t * outLen)454 int32_t CRYPT_RSA_PrvDec(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
455     uint8_t *out, uint32_t *outLen)
456 {
457     int32_t ret;
458     uint32_t bits;
459     BN_BigNum *result = NULL;
460     BN_BigNum *message = NULL;
461     BN_Optimizer *opt = NULL;
462 
463     if (ctx == NULL || input == NULL || out == NULL || outLen == NULL) {
464         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
465         return CRYPT_NULL_INPUT;
466     }
467 
468     if (ctx->prvKey == NULL) {
469         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
470         return CRYPT_RSA_NO_KEY_INFO;
471     }
472 
473     bits = CRYPT_RSA_GetBits(ctx);
474     if ((*outLen) < BN_BITS_TO_BYTES(bits)) {
475         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
476         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
477     }
478     opt = BN_OptimizerCreate();
479     if (opt == NULL) {
480         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
481         return CRYPT_MEM_ALLOC_FAIL;
482     }
483     ret = RSA_AllocAndCheck(ctx, input, inputLen, &result, &message);
484     if (ret != CRYPT_SUCCESS) {
485         BN_OptimizerDestroy(opt);
486         return ret;
487     }
488 
489     (void)OptimizerStart(opt);
490     ret = RSA_PrvProcess(ctx, message, result, opt);
491     if (ret != CRYPT_SUCCESS) {
492         goto EXIT;
493     }
494 
495     ret = ResultToOut(bits, result, out, outLen);
496 EXIT:
497     OptimizerEnd(opt);
498     BN_OptimizerDestroy(opt);
499     BN_Destroy(result);
500     BN_Destroy(message);
501     return ret;
502 }
503 
504 #if defined(HITLS_CRYPTO_RSA_SIGN) || defined(HITLS_CRYPTO_RSA_VERIFY)
GetHashLen(const CRYPT_RSA_Ctx * ctx)505 static uint32_t GetHashLen(const CRYPT_RSA_Ctx *ctx)
506 {
507     if (ctx->pad.type == EMSA_PKCSV15) {
508         return CRYPT_GetMdSizeById(ctx->pad.para.pkcsv15.mdId);
509     }
510 
511     return (uint32_t)(ctx->pad.para.pss.mdMeth->mdSize);
512 }
513 
CheckHashLen(uint32_t inputLen)514 static int32_t CheckHashLen(uint32_t inputLen)
515 {
516     if (inputLen > 64) {  // 64 is the maximum of the hash length.
517         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_ALGID);
518         return CRYPT_RSA_ERR_ALGID;
519     }
520     // Inconsistent length
521     BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_ALGID);
522     return CRYPT_RSA_ERR_ALGID;
523 }
524 #endif
525 
526 #if defined(HITLS_CRYPTO_RSA_EMSA_PSS) && defined(HITLS_CRYPTO_RSA_SIGN)
/*
 * EMSA-PSS-encode input into out (outLen bytes) for the key size of ctx.
 * Salt selection:
 *   - A salt pre-loaded in ctx->pad.salt (KAT use) is consumed once: it is detached from ctx,
 *     used as-is, and not freed here.
 *   - Otherwise, when a non-zero salt length is configured, a fresh salt is generated, then
 *     cleansed and freed after encoding.
 */
static int32_t PssPad(CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen, uint8_t *out, uint32_t outLen)
{
    CRYPT_Data salt = { 0 };
    bool ownsSalt = false;

    if (ctx->pad.salt.data != NULL) {
        salt.data = ctx->pad.salt.data;
        salt.len = ctx->pad.salt.len;
        ctx->pad.salt.data = NULL;
        ctx->pad.salt.len = 0;
    } else if (ctx->pad.para.pss.saltLen != 0) {
        if (GenPssSalt(ctx->libCtx, &salt, ctx->pad.para.pss.mdMeth, ctx->pad.para.pss.saltLen,
            outLen) != CRYPT_SUCCESS) {
            BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_GEN_SALT);
            return CRYPT_RSA_ERR_GEN_SALT;
        }
        ownsSalt = true;
    }

    int32_t ret = CRYPT_RSA_SetPss(ctx->pad.para.pss.mdMeth, ctx->pad.para.pss.mgfMeth, CRYPT_RSA_GetBits(ctx),
        salt.data, salt.len, input, inputLen, out, outLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    }
    if (ownsSalt) {
        BSL_SAL_CleanseData(salt.data, salt.len);
        BSL_SAL_FREE(salt.data);
    }
    return ret;
}
560 #endif
561 
562 #ifdef HITLS_CRYPTO_RSA_BSSA
563 
BlindInputCheck(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,const uint8_t * out,const uint32_t * outLen)564 static int32_t BlindInputCheck(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
565     const uint8_t *out, const uint32_t *outLen)
566 {
567     if (ctx == NULL || input == NULL || inputLen == 0 || out == NULL || outLen == NULL) {
568         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
569         return CRYPT_NULL_INPUT;
570     }
571     if (ctx->pubKey == NULL) {
572         // Check whether the private key information exists.
573         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_NO_PUBKEY_INFO);
574         return CRYPT_RSA_ERR_NO_PUBKEY_INFO;
575     }
576     uint32_t bits = CRYPT_RSA_GetBits(ctx);
577     if ((*outLen) < BN_BITS_TO_BYTES(bits)) {
578         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
579         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
580     }
581     if (ctx->pad.type != EMSA_PSS) {
582         BSL_ERR_PUSH_ERROR(CRYPT_RSA_PADDING_NOT_SUPPORTED);
583         return CRYPT_RSA_PADDING_NOT_SUPPORTED;
584     }
585     return CRYPT_SUCCESS;
586 }
587 
588 #ifdef HITLS_CRYPTO_RSA_SIGN
BssaParamNew(void)589 static RSA_BlindParam *BssaParamNew(void)
590 {
591     RSA_BlindParam *param = BSL_SAL_Calloc(1u, sizeof(RSA_BlindParam));
592     if (param == NULL) {
593         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
594         return NULL;
595     }
596     param->para.bssa = RSA_BlindNewCtx();
597     if (param->para.bssa == NULL) {
598         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
599         BSL_SAL_FREE(param);
600         return NULL;
601     }
602     param->type = RSABSSA;
603     return param;
604 }
605 
/*
 * RSABSSA Blind step (RFC 9474):
 *   1. PSS-encode the digest;
 *   2. require the encoded message to be coprime with n;
 *   3. write encoded_msg * r mod n to out.
 * Blinding parameters are created on first use and stored in ctx->blindParam only on success.
 * An already present ctx->blindParam is reused unchanged.
 */
static int32_t BssaBlind(CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
    uint8_t *out, uint32_t *outLen)
{
    int32_t ret;
    const uint32_t bits = CRYPT_RSA_GetBits(ctx);
    const uint32_t emLen = BN_BITS_TO_BYTES(bits);
    BN_BigNum *n = ctx->pubKey->n;
    BN_BigNum *e = ctx->pubKey->e;
    RSA_BlindParam *param = NULL;
    uint8_t *encoded = BSL_SAL_Malloc(emLen);
    BN_Optimizer *opt = BN_OptimizerCreate();
    BN_BigNum *msg = BN_Create(bits);
    BN_BigNum *gcd = BN_Create(bits);

    if (encoded == NULL || opt == NULL || msg == NULL || gcd == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        ret = CRYPT_MEM_ALLOC_FAIL;
        goto ERR;
    }

    /* encoded_msg = EMSA-PSS-ENCODE(msg, bit_len(n)) */
    GOTO_ERR_IF(PssPad(ctx, input, inputLen, encoded, emLen), ret);
    GOTO_ERR_IF(BN_Bin2Bn(msg, encoded, emLen), ret);

    /* The encoded message must be invertible modulo n, i.e. gcd(encoded_msg, n) == 1. */
    GOTO_ERR_IF(BN_Gcd(gcd, msg, n, opt), ret);
    if (!BN_IsOne(gcd)) {
        ret = CRYPT_INVALID_ARG;
        BSL_ERR_PUSH_ERROR(ret);
        goto ERR;
    }

    param = ctx->blindParam;
    if (param == NULL) {
        param = BssaParamNew();
        if (param == NULL) {
            ret = CRYPT_MEM_ALLOC_FAIL;
            goto ERR;
        }
        GOTO_ERR_IF(RSA_BlindCreateParam(ctx->libCtx, param->para.bssa, e, n, bits, opt), ret);
    }

    /* blinded_msg = encoded_msg * r mod n */
    GOTO_ERR_IF(BN_ModMul(msg, msg, param->para.bssa->r, n, opt), ret);
    GOTO_ERR_IF(ResultToOut(bits, msg, out, outLen), ret);
    ctx->blindParam = param;
ERR:
    /* Free parameters created by this call only; ctx-owned ones are kept. */
    if (ret != CRYPT_SUCCESS && ctx->blindParam == NULL && param != NULL) {
        RSA_BlindFreeCtx(param->para.bssa);
        BSL_SAL_Free(param);
    }
    BN_Destroy(msg);
    BN_Destroy(gcd);
    BSL_SAL_FREE(encoded);
    BN_OptimizerDestroy(opt);
    return ret;
}
663 
/*
 * RSABSSA BlindSign step (RFC 9474):
 *   1. s = blinded_msg^d mod n, via CRYPT_RSA_PrvDec;
 *   2. re-apply the public key to s;
 *   3. require the result to equal the input before copying s to sign.
 * SignInputCheck has already required dataLen to equal the modulus byte length and *signLen
 * to be at least that long.
 */
static int32_t BlindSign(CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen,
    uint8_t *sign, uint32_t *signLen)
{
    const uint32_t keyBytes = BN_BITS_TO_BYTES(CRYPT_RSA_GetBits(ctx));
    uint32_t sLen = keyBytes;
    uint32_t mLen = keyBytes;
    uint8_t *s = BSL_SAL_Malloc(sLen);
    uint8_t *m = BSL_SAL_Malloc(mLen);
    int32_t ret;

    if (s == NULL || m == NULL) {
        ret = CRYPT_MEM_ALLOC_FAIL;
        BSL_ERR_PUSH_ERROR(ret);
    } else {
        ret = CRYPT_RSA_PrvDec(ctx, data, dataLen, s, &sLen);
        if (ret == CRYPT_SUCCESS) {
            ret = CRYPT_RSA_PubEnc(ctx, s, sLen, m, &mLen);
        }
        if (ret == CRYPT_SUCCESS) {
            if (dataLen != mLen || memcmp(data, m, mLen) != 0) {
                /* The private-key result does not verify: never release it. */
                ret = CRYPT_RSA_NOR_VERIFY_FAIL;
                BSL_ERR_PUSH_ERROR(ret);
            } else {
                (void)memcpy_s(sign, *signLen, s, sLen);
                *signLen = sLen;
            }
        }
    }

    BSL_SAL_FREE(m);
    BSL_SAL_FREE(s);
    return ret;
}
708 
/*
 * RSABSSA Blind entry point.
 * algId must match the PSS digest configured on ctx. The message is hashed with algId and
 * the digest is blinded. Only the CRYPT_RSA_BSSA blinding type is supported.
 */
int32_t CRYPT_RSA_Blind(CRYPT_RSA_Ctx *ctx, int32_t algId, const uint8_t *input, uint32_t inputLen,
    uint8_t *out, uint32_t *outLen)
{
    int32_t ret = BlindInputCheck(ctx, input, inputLen, out, outLen);
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }
    if ((int32_t)ctx->pad.para.pss.mdId != algId) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_MD_ALGID);
        return CRYPT_RSA_ERR_MD_ALGID;
    }

    uint8_t digest[64]; /* 64 bytes: the largest supported digest */
    uint32_t digestLen = (uint32_t)sizeof(digest);
    ret = CRYPT_EAL_Md(algId, input, inputLen, digest, &digestLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    if ((ctx->flags & CRYPT_RSA_BSSA) == 0) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_BLIND_TYPE);
        return CRYPT_RSA_ERR_BLIND_TYPE;
    }
    return BssaBlind(ctx, digest, digestLen, out, outLen);
}
735 #endif // HITLS_CRYPTO_RSA_SIGN
736 
737 #ifdef HITLS_CRYPTO_RSA_VERIFY
BssaUnBlind(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,uint8_t * out,uint32_t * outLen)738 static int32_t BssaUnBlind(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
739     uint8_t *out, uint32_t *outLen)
740 {
741     uint32_t bits = CRYPT_RSA_GetBits(ctx);
742     uint32_t sigLen = BN_BITS_TO_BYTES(bits);
743     if (inputLen != sigLen) {
744         BSL_ERR_PUSH_ERROR(CRYPT_INVALID_ARG);
745         return CRYPT_INVALID_ARG;
746     }
747     int32_t ret;
748     RSA_Blind *blind = NULL;
749     BN_BigNum *n = ctx->pubKey->n;
750     BN_Optimizer *opt = BN_OptimizerCreate();
751     BN_BigNum *z = BN_Create(bits);
752     BN_BigNum *s = BN_Create(bits);
753     if (opt == NULL || z == NULL || s == NULL) {
754         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
755         ret = CRYPT_MEM_ALLOC_FAIL;
756         goto ERR;
757     }
758     if (ctx->blindParam == NULL || ctx->blindParam->para.bssa == NULL) {
759         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_NO_BLIND_INFO);
760         ret = CRYPT_RSA_ERR_NO_BLIND_INFO;
761         goto ERR;
762     }
763     if (ctx->blindParam->type != RSABSSA) {
764         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_BLIND_TYPE);
765         ret = CRYPT_RSA_ERR_BLIND_TYPE;
766         goto ERR;
767     }
768     blind = ctx->blindParam->para.bssa;
769     GOTO_ERR_IF(BN_Bin2Bn(z, input, inputLen), ret);
770     GOTO_ERR_IF(BN_ModMul(s, z, blind->rInv, n, opt), ret);
771     GOTO_ERR_IF(ResultToOut(bits, s, out, outLen), ret);
772 ERR:
773     BN_Destroy(z);
774     BN_Destroy(s);
775     BN_OptimizerDestroy(opt);
776     return ret;
777 }
778 
CRYPT_RSA_UnBlind(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,uint8_t * out,uint32_t * outLen)779 int32_t CRYPT_RSA_UnBlind(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
780     uint8_t *out, uint32_t *outLen)
781 {
782     int32_t ret;
783     ret = BlindInputCheck(ctx, input, inputLen, out, outLen);
784     if (ret != CRYPT_SUCCESS) {
785         return ret;
786     }
787     if ((ctx->flags & CRYPT_RSA_BSSA) != 0) {
788         ret = BssaUnBlind(ctx, input, inputLen, out, outLen);
789     } else {
790         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_BLIND_TYPE);
791         ret = CRYPT_RSA_ERR_BLIND_TYPE;
792     }
793     return ret;
794 }
795 #endif // HITLS_CRYPTO_RSA_VERIFY
796 #endif // HITLS_CRYPTO_RSA_BSSA
797 
798 #ifdef HITLS_CRYPTO_RSA_SIGN
SignInputCheck(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,const uint8_t * out,const uint32_t * outLen)799 static int32_t SignInputCheck(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
800     const uint8_t *out, const uint32_t *outLen)
801 {
802     if (ctx == NULL || input == NULL || out == NULL || outLen == NULL) {
803         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
804         return CRYPT_NULL_INPUT;
805     }
806     if (ctx->prvKey == NULL) {
807         // Check whether the private key information exists.
808         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
809         return CRYPT_RSA_NO_KEY_INFO;
810     }
811     // Check whether the length of the out is sufficient to place the signature information.
812     uint32_t bits = CRYPT_RSA_GetBits(ctx);
813     if ((*outLen) < BN_BITS_TO_BYTES(bits)) {
814         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
815         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
816     }
817     if (ctx->pad.type != EMSA_PKCSV15 && ctx->pad.type != EMSA_PSS) {
818         // No padding type is set.
819         BSL_ERR_PUSH_ERROR(CRYPT_RSA_PAD_NO_SET_ERROR);
820         return CRYPT_RSA_PAD_NO_SET_ERROR;
821     }
822 #ifdef HITLS_CRYPTO_RSA_BSSA
823     if ((ctx->flags & CRYPT_RSA_BSSA) != 0) {
824         if (BN_BITS_TO_BYTES(bits) != inputLen) {
825             BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
826             return CRYPT_RSA_ERR_INPUT_VALUE;
827         }
828         return CRYPT_SUCCESS;
829     }
830 #endif
831     if (GetHashLen(ctx) != inputLen) {
832         return CheckHashLen(inputLen);
833     }
834     return CRYPT_SUCCESS;
835 }
836 
/*
 * Sign an already-computed digest (or, in BSSA mode, a blinded message).
 *
 * The input is encoded with the configured EMSA padding (PKCS#1 v1.5 type 1 or PSS)
 * into a modulus-sized buffer. That buffer is run through CRYPT_RSA_PrvDec into sign/signLen.
 * In BSSA mode the request is delegated to BlindSign instead.
 * The temporary encoding buffer is zeroed and freed on every path after allocation.
 */
int32_t CRYPT_RSA_SignData(CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen,
    uint8_t *sign, uint32_t *signLen)
{
    int32_t rc = SignInputCheck(ctx, data, dataLen, sign, signLen);
    if (rc != CRYPT_SUCCESS) {
        return rc; // SignInputCheck records its own failures.
    }
#ifdef HITLS_CRYPTO_RSA_BSSA
    if ((ctx->flags & CRYPT_RSA_BSSA) != 0) {
        return BlindSign(ctx, data, dataLen, sign, signLen);
    }
#endif
    uint32_t keyBits = CRYPT_RSA_GetBits(ctx);
    uint32_t emLen = BN_BITS_TO_BYTES(keyBits);
    uint8_t *em = BSL_SAL_Malloc(emLen);
    if (em == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    // Step 1: encode the message into em.
    switch (ctx->pad.type) {
#ifdef HITLS_CRYPTO_RSA_EMSA_PKCSV15
        case EMSA_PKCSV15:
            rc = CRYPT_RSA_SetPkcsV15Type1(ctx->pad.para.pkcsv15.mdId, data, dataLen, em, emLen);
            break;
#endif
#ifdef HITLS_CRYPTO_RSA_EMSA_PSS
        case EMSA_PSS:
            rc = PssPad(ctx, data, dataLen, em, emLen);
            break;
#endif
        default: // Unreachable in practice: SignInputCheck already validated pad.type.
            rc = CRYPT_RSA_PAD_NO_SET_ERROR;
            break;
    }

    // Step 2: apply the private key, only if encoding succeeded.
    if (rc == CRYPT_SUCCESS) {
        rc = CRYPT_RSA_PrvDec(ctx, em, emLen, sign, signLen);
    }
    if (rc != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(rc);
    }

    (void)memset_s(em, emLen, 0, emLen);
    BSL_SAL_FREE(em);
    return rc;
}
885 
/*
 * Hash `data` with the algorithm `algId` via EAL_Md, then sign the digest with
 * CRYPT_RSA_SignData. A hashing failure is pushed and returned; otherwise
 * the result of CRYPT_RSA_SignData is returned directly.
 */
int32_t CRYPT_RSA_Sign(CRYPT_RSA_Ctx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen,
    uint8_t *sign, uint32_t *signLen)
{
    uint8_t digest[64]; // 64 bytes: the largest digest this path is sized for.
    uint32_t digestLen = (uint32_t)sizeof(digest);
    int32_t rc = EAL_Md(algId, data, dataLen, digest, &digestLen);
    if (rc == CRYPT_SUCCESS) {
        return CRYPT_RSA_SignData(ctx, digest, digestLen, sign, signLen);
    }
    BSL_ERR_PUSH_ERROR(rc);
    return rc;
}
898 #endif // HITLS_CRYPTO_RSA_SIGN
899 
900 #ifdef HITLS_CRYPTO_RSA_VERIFY
VerifyInputCheck(const CRYPT_RSA_Ctx * ctx,const uint8_t * data,uint32_t dataLen,const uint8_t * sign)901 static int32_t VerifyInputCheck(const CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen,
902     const uint8_t *sign)
903 {
904     if (ctx == NULL || data == NULL || sign == NULL) {
905         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
906         return CRYPT_NULL_INPUT;
907     }
908     if (ctx->pubKey == NULL) {
909         // Check whether the private key information exists.
910         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
911         return CRYPT_RSA_NO_KEY_INFO;
912     }
913     if (ctx->pad.type != EMSA_PKCSV15 && ctx->pad.type != EMSA_PSS) {
914         // No padding type is set.
915         BSL_ERR_PUSH_ERROR(CRYPT_RSA_PAD_NO_SET_ERROR);
916         return CRYPT_RSA_PAD_NO_SET_ERROR;
917     }
918     if (GetHashLen(ctx) != dataLen) {
919         return CheckHashLen(dataLen);
920     }
921     return CRYPT_SUCCESS;
922 }
923 
/*
 * Verify an RSA signature over an already-computed digest.
 *
 * ctx      RSA context with a public key and an EMSA padding (PKCS#1 v1.5 or PSS).
 * data     The digest that was signed; its length is checked by VerifyInputCheck.
 * sign     The signature value. It is run through CRYPT_RSA_PubEnc to recover the
 *          encoded message.
 * signLen  Length of sign in bytes.
 *
 * Returns CRYPT_SUCCESS if the recovered encoded message verifies against data under
 * the configured padding; otherwise the error of the first failing step.
 */
int32_t CRYPT_RSA_VerifyData(CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen,
    const uint8_t *sign, uint32_t signLen)
{
    int32_t ret = VerifyInputCheck(ctx, data, dataLen, sign);
    if (ret != CRYPT_SUCCESS) {
        return ret; // VerifyInputCheck has already pushed the error.
    }
    uint32_t bits = CRYPT_RSA_GetBits(ctx);
    uint32_t bufLen = BN_BITS_TO_BYTES(bits);
    // In: capacity of pad. Out: number of bytes CRYPT_RSA_PubEnc wrote.
    uint32_t padLen = bufLen;
    uint8_t *pad = BSL_SAL_Malloc(bufLen);
    if (pad == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    ret = CRYPT_RSA_PubEnc(ctx, sign, signLen, pad, &padLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        goto EXIT;
    }
    switch (ctx->pad.type) {
#ifdef HITLS_CRYPTO_RSA_EMSA_PKCSV15
        case EMSA_PKCSV15:
            ret = CRYPT_RSA_VerifyPkcsV15Type1(ctx->pad.para.pkcsv15.mdId, pad, padLen,
                data, dataLen);
            break;
#endif
#ifdef HITLS_CRYPTO_RSA_EMSA_PSS
        case EMSA_PSS: {
            if (ctx->pad.para.pss.mdMeth == NULL) {
                // Record the failure on the error stack, like every other error path here.
                ret = CRYPT_NULL_INPUT;
                BSL_ERR_PUSH_ERROR(ret);
                goto EXIT;
            }
            uint32_t saltLen = (uint32_t)ctx->pad.para.pss.saltLen;
            if (ctx->pad.para.pss.saltLen == CRYPT_RSA_SALTLEN_TYPE_HASHLEN) { // saltLen is -1
                saltLen = (uint32_t)ctx->pad.para.pss.mdMeth->mdSize;
            } else if (ctx->pad.para.pss.saltLen == CRYPT_RSA_SALTLEN_TYPE_MAXLEN) { // saltLen is -2
                // NOTE(review): unsigned arithmetic; this would wrap if padLen < mdSize + 2.
                // Presumably unreachable for supported modulus sizes; confirm.
                saltLen = (uint32_t)(padLen - ctx->pad.para.pss.mdMeth->mdSize - 2);
            }
            ret = CRYPT_RSA_VerifyPss(ctx->pad.para.pss.mdMeth, ctx->pad.para.pss.mgfMeth,
                bits, saltLen, data, dataLen, pad, padLen);
            break;
        }
#endif
        default: // This branch cannot be entered because it's been verified before.
            ret = CRYPT_RSA_PAD_NO_SET_ERROR;
            BSL_ERR_PUSH_ERROR(ret);
            break;
    }
EXIT:
    // Wipe the whole allocation; padLen may have been changed by CRYPT_RSA_PubEnc.
    (void)memset_s(pad, bufLen, 0, bufLen);
    BSL_SAL_FREE(pad);
    return ret;
}
980 
/*
 * Hash `data` with the algorithm `algId` via EAL_Md, then verify `sign` against
 * the digest with CRYPT_RSA_VerifyData. A hashing failure is pushed and returned;
 * otherwise the result of CRYPT_RSA_VerifyData is returned directly.
 */
int32_t CRYPT_RSA_Verify(CRYPT_RSA_Ctx *ctx, int32_t algId, const uint8_t *data, uint32_t dataLen,
    const uint8_t *sign, uint32_t signLen)
{
    uint8_t digest[64]; // 64 bytes: the largest digest this path is sized for.
    uint32_t digestLen = (uint32_t)sizeof(digest);
    int32_t rc = EAL_Md(algId, data, dataLen, digest, &digestLen);
    if (rc == CRYPT_SUCCESS) {
        return CRYPT_RSA_VerifyData(ctx, digest, digestLen, sign, signLen);
    }
    BSL_ERR_PUSH_ERROR(rc);
    return rc;
}
993 #endif // HITLS_CRYPTO_RSA_VERIFY
994 
995 #if defined(HITLS_CRYPTO_RSA_ENCRYPT) || defined(HITLS_CRYPTO_RSA_VERIFY)
EncryptInputCheck(const CRYPT_RSA_Ctx * ctx,const uint8_t * input,uint32_t inputLen,const uint8_t * out,const uint32_t * outLen)996 static int32_t EncryptInputCheck(const CRYPT_RSA_Ctx *ctx, const uint8_t *input, uint32_t inputLen,
997     const uint8_t *out, const uint32_t *outLen)
998 {
999     if (ctx == NULL || (input == NULL && inputLen != 0) || out == NULL || outLen == NULL) {
1000         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
1001         return CRYPT_NULL_INPUT;
1002     }
1003     if (ctx->pubKey == NULL) {
1004         // Check whether the public key information exists.
1005         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
1006         return CRYPT_RSA_NO_KEY_INFO;
1007     }
1008     // Check whether the length of the out is sufficient to place the encryption information.
1009     uint32_t bits = CRYPT_RSA_GetBits(ctx);
1010     if ((*outLen) < BN_BITS_TO_BYTES(bits)) {
1011         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
1012         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
1013     }
1014     if (inputLen > BN_BITS_TO_BYTES(bits)) {
1015         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_ENC_BITS);
1016         return CRYPT_RSA_ERR_ENC_BITS;
1017     }
1018     return CRYPT_SUCCESS;
1019 }
1020 #endif
1021 
1022 #ifdef HITLS_CRYPTO_RSA_ENCRYPT
/*
 * RSA public-key encryption.
 *
 * The plaintext is first encoded into a modulus-sized buffer according to ctx->pad.type:
 *   - PKCS#1 v1.5 type 2 (plain or TLS variant);
 *   - OAEP;
 *   - no padding, which requires dataLen to equal the modulus size.
 * The encoded block is then passed through CRYPT_RSA_PubEnc into out/outLen.
 * The encoding buffer is zeroed and freed on every path after allocation.
 */
int32_t CRYPT_RSA_Encrypt(CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen,
    uint8_t *out, uint32_t *outLen)
{
    int32_t rc = EncryptInputCheck(ctx, data, dataLen, out, outLen);
    if (rc != CRYPT_SUCCESS) {
        return rc; // EncryptInputCheck has already pushed the error.
    }

    uint32_t keyBits = CRYPT_RSA_GetBits(ctx);
    uint32_t emLen = BN_BITS_TO_BYTES(keyBits);
    uint8_t *em = BSL_SAL_Malloc(emLen);
    if (em == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    // Step 1: build the encoded message.
    switch (ctx->pad.type) {
#if defined(HITLS_CRYPTO_RSAES_PKCSV15_TLS) || defined(HITLS_CRYPTO_RSAES_PKCSV15)
        case RSAES_PKCSV15_TLS:
        case RSAES_PKCSV15:
            rc = CRYPT_RSA_SetPkcsV15Type2(ctx->libCtx, data, dataLen, em, emLen);
            break;
#endif
#ifdef HITLS_CRYPTO_RSAES_OAEP
        case RSAES_OAEP:
            rc = CRYPT_RSA_SetPkcs1Oaep(ctx, data, dataLen, em, emLen);
            break;
#endif
#ifdef HITLS_CRYPTO_RSA_NO_PAD
        case RSA_NO_PAD:
            // Raw RSA: the caller must supply exactly one modulus-sized block.
            if (dataLen == emLen) {
                (void)memcpy_s(em, emLen, data, dataLen);
            } else {
                rc = CRYPT_RSA_ERR_ENC_INPUT_NOT_ENOUGH;
            }
            break;
#endif
        default:
            rc = CRYPT_RSA_PAD_NO_SET_ERROR;
            break;
    }

    // Step 2: apply the public key, only if encoding succeeded.
    if (rc == CRYPT_SUCCESS) {
        rc = CRYPT_RSA_PubEnc(ctx, em, emLen, out, outLen);
    }
    if (rc != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(rc);
    }

    (void)memset_s(em, emLen, 0, emLen);
    BSL_SAL_FREE(em);
    return rc;
}
1086 #endif // HITLS_CRYPTO_RSA_ENCRYPT
1087 
1088 #ifdef HITLS_CRYPTO_RSA_DECRYPT
DecryptInputCheck(const CRYPT_RSA_Ctx * ctx,const uint8_t * data,uint32_t dataLen,const uint8_t * out,const uint32_t * outLen)1089 static int32_t DecryptInputCheck(const CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen,
1090     const uint8_t *out, const uint32_t *outLen)
1091 {
1092     if (ctx == NULL || data == NULL || out == NULL || outLen == NULL) {
1093         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
1094         return CRYPT_NULL_INPUT;
1095     }
1096     if (ctx->prvKey == NULL) {
1097         // Check whether the private key information exists.
1098         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NO_KEY_INFO);
1099         return CRYPT_RSA_NO_KEY_INFO;
1100     }
1101 
1102     uint32_t bits = CRYPT_RSA_GetBits(ctx);
1103     if (dataLen != BN_BITS_TO_BYTES(bits)) {
1104         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_DEC_BITS);
1105         return CRYPT_RSA_ERR_DEC_BITS;
1106     }
1107     return CRYPT_SUCCESS;
1108 }
1109 
/*
 * RSA private-key decryption.
 *
 * The ciphertext is run through CRYPT_RSA_PrvDec into a modulus-sized scratch buffer.
 * The padding selected by ctx->pad.type is then removed, writing the plaintext to out/outLen:
 *   - OAEP, using the context label;
 *   - PKCS#1 v1.5 type 2, plain or TLS variant;
 *   - no padding, which copies the whole block.
 * The scratch buffer is cleansed with BSL_SAL_CleanseData before it is freed.
 */
int32_t CRYPT_RSA_Decrypt(CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen, uint8_t *out, uint32_t *outLen)
{
    int32_t rc = DecryptInputCheck(ctx, data, dataLen, out, outLen);
    if (rc != CRYPT_SUCCESS) {
        return rc; // DecryptInputCheck has already pushed the error.
    }

    uint32_t keyBits = CRYPT_RSA_GetBits(ctx);
    uint32_t emLen = BN_BITS_TO_BYTES(keyBits); // in: capacity, out: bytes produced by PrvDec
    uint8_t *em = BSL_SAL_Malloc(emLen);
    if (em == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    rc = CRYPT_RSA_PrvDec(ctx, data, dataLen, em, &emLen);
    if (rc != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(rc);
    } else {
        // Strip the padding from the recovered block.
        switch (ctx->pad.type) {
#ifdef HITLS_CRYPTO_RSAES_OAEP
            case RSAES_OAEP:
                rc = CRYPT_RSA_VerifyPkcs1Oaep(ctx->pad.para.oaep.mdMeth,
                    ctx->pad.para.oaep.mgfMeth, em, emLen, ctx->label.data, ctx->label.len, out, outLen);
                break;
#endif
#ifdef HITLS_CRYPTO_RSAES_PKCSV15
            case RSAES_PKCSV15:
                rc = CRYPT_RSA_VerifyPkcsV15Type2(em, emLen, out, outLen);
                break;
#endif
#ifdef HITLS_CRYPTO_RSAES_PKCSV15_TLS
            case RSAES_PKCSV15_TLS:
                rc = CRYPT_RSA_VerifyPkcsV15Type2TLS(em, emLen, out, outLen);
                break;
#endif
#ifdef HITLS_CRYPTO_RSA_NO_PAD
            case RSA_NO_PAD:
                if (memcpy_s(out, *outLen, em, emLen) == EOK) {
                    *outLen = emLen;
                } else {
                    rc = CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
                    BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
                }
                break;
#endif
            default:
                rc = CRYPT_RSA_PAD_NO_SET_ERROR;
                BSL_ERR_PUSH_ERROR(rc);
                break;
        }
    }

    BSL_SAL_CleanseData(em, emLen);
    BSL_SAL_FREE(em);
    return rc;
}
1169 #endif // HITLS_CRYPTO_RSA_DECRYPT
1170 
1171 #ifdef HITLS_CRYPTO_RSA_VERIFY
/*
 * Recover the payload embedded in a signature (public-key "decrypt" of a signature).
 *
 * The signature `data` is run through CRYPT_RSA_PubEnc into a modulus-sized buffer.
 * The padding selected by ctx->pad.type is then removed:
 *   - PKCS#1 v1.5 type 1 unpacking;
 *   - no padding, which copies the whole block.
 * The result is written to out/outLen, and the temporary buffer is zeroed and freed.
 */
int32_t CRYPT_RSA_Recover(CRYPT_RSA_Ctx *ctx, const uint8_t *data, uint32_t dataLen, uint8_t *out, uint32_t *outLen)
{
    int32_t ret = EncryptInputCheck(ctx, data, dataLen, out, outLen);
    // EncryptInputCheck has already pushed the error; pushing it again would duplicate the entry.
    if (ret != CRYPT_SUCCESS) {
        return ret;
    }
    // EncryptInputCheck accepts (NULL, 0); recovery always needs a signature value.
    if (data == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    uint8_t *emMsg = NULL;
    uint32_t bits = CRYPT_RSA_GetBits(ctx);
    uint32_t emLen = BN_BITS_TO_BYTES(bits);
    emMsg = BSL_SAL_Malloc(emLen);
    if (emMsg == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    ret = CRYPT_RSA_PubEnc(ctx, data, dataLen, emMsg, &emLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        goto ERR;
    }

    switch (ctx->pad.type) { // Remove padding based on the padding type to obtain the plaintext.
#ifdef HITLS_CRYPTO_RSA_EMSA_PKCSV15
        // NOTE(review): guarded by the EMSA feature macro but matches the RSAES_PKCSV15 pad type;
        // confirm this is the pad type callers configure for recovery.
        case RSAES_PKCSV15:
            ret = CRYPT_RSA_UnPackPkcsV15Type1(emMsg, emLen, out, outLen);
            break;
#endif
#ifdef HITLS_CRYPTO_RSA_NO_PAD
        case RSA_NO_PAD:
            if (memcpy_s(out, *outLen, emMsg, emLen) != EOK) {
                BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
                ret = CRYPT_SECUREC_FAIL;
                goto ERR;
            }
            *outLen = emLen;
            break;
#endif
        default:
            ret = CRYPT_RSA_PAD_NO_SET_ERROR;
            BSL_ERR_PUSH_ERROR(ret);
            goto ERR;
    }
ERR: // Common cleanup, reached on both success and failure.
    (void)memset_s(emMsg, emLen, 0, emLen);
    BSL_SAL_FREE(emMsg);
    return ret;
}
1223 #endif // HITLS_CRYPTO_RSA_VERIFY
1224 
1225 #endif /* HITLS_CRYPTO_RSA */