• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * This file is part of the openHiTLS project.
3  *
4  * openHiTLS is licensed under the Mulan PSL v2.
5  * You can use this software according to the terms and conditions of the Mulan PSL v2.
6  * You may obtain a copy of Mulan PSL v2 at:
7  *
8  *     http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  * See the Mulan PSL v2 for more details.
14  */
15 
16 #include "hitls_build.h"
17 #ifdef HITLS_CRYPTO_RSA
18 
19 #include "crypt_rsa.h"
20 #include "rsa_local.h"
21 #include "crypt_errno.h"
22 #include "crypt_types.h"
23 #include "crypt_utils.h"
24 #include "crypt_util_rand.h"
25 #include "securec.h"
26 #include "bsl_sal.h"
27 #include "bsl_bytes.h"
28 #include "bsl_err_internal.h"
29 
30 #define UINT32_SIZE 4
31 
32 #ifdef HITLS_CRYPTO_RSA_EMSA_PSS
/*
 * Fold the PSS data block DB = PS || 0x01 || salt into an MGF1 mask, in place.
 *
 * maskedDB: [in]  maskDB produced by MGF1, `len` bytes long
 *           [out] maskedDB = DB xor maskDB
 * salt:     the salt bytes occupying the last `saltLen` bytes of DB; the caller must
 *           guarantee len >= saltLen + 1 so the 0x01 separator fits in front of them.
 * msBit:    number of valid bits in the most significant byte of EM; 0 means all
 *           8 bits are valid, otherwise the top (8 - msBit) bits of maskedDB[0] are cleared.
 * PS is all zero octets, so only the separator and the salt alter the mask.
 */
static void MaskDB(uint8_t *maskedDB, uint32_t len, const uint8_t *salt, uint32_t saltLen, uint32_t msBit)
{
    const uint32_t sepPos = len - saltLen - 1; // index of the 0x01 separator
    maskedDB[sepPos] ^= 0x01;

    uint8_t *saltArea = &maskedDB[sepPos + 1];
    for (uint32_t k = 0; k < saltLen; ++k) {
        saltArea[k] ^= salt[k];
    }

    if (msBit == 0) {
        return;
    }
    // Set the leftmost 8emLen - emBits bits of the leftmost octet in maskedDB to zero
    maskedDB[0] &= (uint8_t)(0xFF >> (8 - msBit));
}
52 
PssEncodeLengthCheck(uint32_t modBits,uint32_t hLen,uint32_t saltLen,uint32_t dataLen,uint32_t padLen)53 static int32_t PssEncodeLengthCheck(uint32_t modBits, uint32_t hLen,
54     uint32_t saltLen, uint32_t dataLen, uint32_t padLen)
55 {
56     if (modBits < RSA_MIN_MODULUS_BITS || modBits > RSA_MAX_MODULUS_BITS) {
57         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_KEY_BITS);
58         return CRYPT_RSA_ERR_KEY_BITS;
59     }
60     if (hLen > RSA_MAX_MODULUS_LEN || dataLen != hLen) {
61         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
62         return CRYPT_RSA_ERR_INPUT_VALUE;
63     }
64     uint32_t keyBytes = BN_BITS_TO_BYTES(modBits);
65     if (keyBytes != padLen) { // The length required for padding does not match the key module length (API convention).
66         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
67         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
68     }
69     if (saltLen == (uint32_t)CRYPT_RSA_SALTLEN_TYPE_AUTOLEN) {
70         return CRYPT_SUCCESS;
71     }
72     if (saltLen > RSA_MAX_MODULUS_LEN) {
73         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_PSS_SALT_LEN);
74         return CRYPT_RSA_ERR_PSS_SALT_LEN;
75     }
76     uint32_t emLen = keyBytes;
77     // the octet length of EM will be one less than k if modBits - 1 is divisible by 8 and equal to k otherwise
78     if (((modBits - 1) & 0x7) == 0) {
79         emLen--;
80     }
81     if (emLen < hLen + saltLen + 2) { // RFC: If emLen < hLen + sLen + 2, output "encoding error" and stop.
82         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_PSS_SALT_LEN);
83         return CRYPT_RSA_ERR_PSS_SALT_LEN;
84     }
85     return CRYPT_SUCCESS;
86 }
87 
88 #if defined(HITLS_CRYPTO_RSA_SIGN) || defined(HITLS_CRYPTO_RSA_BSSA)
/*
 * Allocate and fill a random PSS salt.
 *
 * The salt length is chosen from saltLen:
 *   CRYPT_RSA_SALTLEN_TYPE_HASHLEN (-1)                 -> the digest size of mdMethod,
 *   CRYPT_RSA_SALTLEN_TYPE_MAXLEN / _AUTOLEN (-2 / -3)  -> padBuffLen - hashLen - 2,
 *   any other value                                     -> used literally.
 * On success salt->data is a heap buffer owned by the caller (release with BSL_SAL_FREE),
 * or NULL when the resulting length is 0. On failure salt->data is NULL.
 */
int32_t GenPssSalt(void *libCtx, CRYPT_Data *salt, const EAL_MdMethod *mdMethod, int32_t saltLen, uint32_t padBuffLen)
{
    uint32_t hashLen = mdMethod->mdSize;
    salt->data = NULL; // never leave an indeterminate pointer behind on an early return
    if (saltLen == CRYPT_RSA_SALTLEN_TYPE_HASHLEN) { // saltLen is -1
        salt->len = hashLen;
    } else if (saltLen == CRYPT_RSA_SALTLEN_TYPE_MAXLEN ||
        saltLen == CRYPT_RSA_SALTLEN_TYPE_AUTOLEN) { // saltLen is -2 or -3
        // Without this guard the unsigned subtraction wraps and requests a ~4 GiB salt.
        if (padBuffLen < hashLen + 2) {
            salt->len = 0;
            BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_PSS_SALT_LEN);
            return CRYPT_RSA_ERR_PSS_SALT_LEN;
        }
        salt->len = padBuffLen - hashLen - 2; // salt, obtains from the DRBG
    } else {
        salt->len = (uint32_t)saltLen;
    }

    if (salt->len == 0) {
        // Empty salt: nothing to allocate or randomize (CRYPT_RSA_SetPss accepts {NULL, 0}).
        return CRYPT_SUCCESS;
    }

    salt->data = BSL_SAL_Malloc(salt->len);
    if (salt->data == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }
    // Obtain the salt through the public random number.
    int32_t ret = CRYPT_RandEx(libCtx, salt->data, salt->len);
    if (ret != CRYPT_SUCCESS) {
        BSL_SAL_FREE(salt->data);
        BSL_ERR_PUSH_ERROR(ret);
    }
    return ret;
}
114 
115 /**
116  * EMSA-PSS Encoding Operation
117  *                                    +-----------+
118  *                                    |     M     |
119  *                                    +-----------+
120  *                                          |
121  *                                          V
122  *                                        Hash
123  *                                          |
124  *                                          V
125  *                            +--------+----------+----------+
126  *                       M' = |Padding1|  mHash   |   salt   |
127  *                            +--------+----------+----------+
128  *                                           |
129  *                 +--------+----------+     V
130  *           DB =  |Padding2|   salt   |   Hash
131  *                 +--------+----------+     |
132  *                           |               |
133  *                           V               |
134  *                          xor <--- MGF <---|  maskDB = MGF(H, emLen - hLen - 1).
135  *                           |               |
136  *                           |               |
137  *                           V               V
138  *                 +-------------------+----------+--+
139  *           EM =  |    maskedDB       |     H    |bc|
140  *                 +-------------------+----------+--+
141  * Output EM data with a fixed length (keyBytes) to the pad buffer.
142  * Add 0s to the first byte, if the EM length + 1 = keyBytes.
143  * Of which:
144  * The data is the mHash in the preceding figure.
145  * M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt
146  * DB = PS || 0x01 || salt; DB is an octet string of length emLen - hLen - 1
147  * PS consisting of emLen - sLen - hLen - 2 zero octets, The length of PS may be 0.
148  */
CRYPT_RSA_SetPss(const EAL_MdMethod * hashMethod,const EAL_MdMethod * mgfMethod,uint32_t keyBits,const uint8_t * salt,uint32_t saltLen,const uint8_t * data,uint32_t dataLen,uint8_t * pad,uint32_t padLen)149 int32_t CRYPT_RSA_SetPss(const EAL_MdMethod *hashMethod, const EAL_MdMethod *mgfMethod, uint32_t keyBits,
150     const uint8_t *salt, uint32_t saltLen, const uint8_t *data, uint32_t dataLen, uint8_t *pad, uint32_t padLen)
151 {
152     int32_t ret;
153     if (hashMethod == NULL || mgfMethod == NULL || pad == NULL || data == NULL) {
154         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
155         return CRYPT_NULL_INPUT;
156     }
157     if (salt == NULL && saltLen != 0) {
158         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_PSS_SALT_DATA);
159         return CRYPT_RSA_ERR_PSS_SALT_DATA;
160     }
161     uint32_t hLen = hashMethod->mdSize;
162     ret = PssEncodeLengthCheck(keyBits, hLen, saltLen, dataLen, padLen);
163     if (ret != CRYPT_SUCCESS) {
164         BSL_ERR_PUSH_ERROR(ret);
165         return ret;
166     }
167     uint32_t keyBytes = BN_BITS_TO_BYTES(keyBits);
168     uint8_t *em = pad;
169     uint32_t emLen = keyBytes;
170     // the octet length of EM will be one less than k if modBits - 1 is divisible by 8 and equal to k otherwise
171     uint32_t msBit = ((keyBits - 1) & 0x7);
172     if (msBit == 0) {
173         emLen--;
174         *em = 0;
175         em++;
176     }
177     em[emLen - 1] = 0xbc; // EM = maskedDB || H || 0xbc.
178 
179     // set H
180     static const uint8_t zeros8[8] = {0};
181     const CRYPT_ConstData hashData[] = {
182         {zeros8, sizeof(zeros8)},
183         {data, dataLen}, // mHash
184         {salt, saltLen}  // salt
185     };
186 
187     const uint32_t maskedDBLen = emLen - hLen - 1;
188     uint8_t *h = em + maskedDBLen;
189     ret = CRYPT_CalcHash(hashMethod, hashData, sizeof(hashData) / sizeof(hashData[0]), h, &hLen);
190     if (ret != CRYPT_SUCCESS) {
191         BSL_ERR_PUSH_ERROR(ret);
192         return ret;
193     }
194 
195     // set maskedDB
196     ret = CRYPT_Mgf1(mgfMethod, h, hLen, em, maskedDBLen);
197     if (ret != CRYPT_SUCCESS) {
198         BSL_ERR_PUSH_ERROR(ret);
199         return ret;
200     }
201     MaskDB(em, maskedDBLen, salt, saltLen, msBit);
202     return CRYPT_SUCCESS;
203 }
204 #endif // HITLS_CRYPTO_RSA_SIGN || HITLS_CRYPTO_RSA_BSSA
205 
206 #ifdef HITLS_CRYPTO_RSA_VERIFY
GetVerifySaltLen(const uint8_t * emData,const uint8_t * dbBuff,uint32_t maskedDBLen,uint32_t msBit,uint32_t * saltLen)207 static int32_t GetVerifySaltLen(const uint8_t *emData, const uint8_t *dbBuff, uint32_t maskedDBLen, uint32_t msBit,
208     uint32_t *saltLen)
209 {
210     uint32_t i = 0;
211     uint8_t *tmpBuff = (uint8_t *)BSL_SAL_Malloc(maskedDBLen);
212     if (tmpBuff == NULL) {
213         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
214         return CRYPT_MEM_ALLOC_FAIL;
215     }
216     (void)memcpy_s(tmpBuff, maskedDBLen, dbBuff, maskedDBLen);
217     if (msBit != 0) {
218         tmpBuff[0] &= ((uint8_t)(0xFF >> (8 - msBit)));  // Set the leftmost 8emLen - emBits bits to zero
219     }
220 
221     for (i = 0; i < maskedDBLen; i++) {
222         tmpBuff[i] ^= emData[i];
223         if (tmpBuff[i] != 0) {
224             break;
225         }
226     }
227     if (i == maskedDBLen || tmpBuff[i] != 0x01) {
228         BSL_SAL_FREE(tmpBuff);
229         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_PSS_SALT_LEN);
230         return CRYPT_RSA_ERR_PSS_SALT_LEN;
231     }
232     i++;
233     BSL_SAL_FREE(tmpBuff);
234     *saltLen = maskedDBLen - i;
235     return CRYPT_SUCCESS;
236 }
237 
GetAndVerifyDB(const EAL_MdMethod * mgfMethod,const CRYPT_Data * emData,const CRYPT_Data * dbBuff,uint32_t * saltLen,uint32_t msBit)238 static int32_t GetAndVerifyDB(const EAL_MdMethod *mgfMethod, const CRYPT_Data *emData,
239     const CRYPT_Data *dbBuff, uint32_t *saltLen, uint32_t msBit)
240 {
241     uint32_t maskedDBLen = dbBuff->len;
242     uint32_t hLen = emData->len - maskedDBLen - 1;
243     uint32_t tmpSaltLen = *saltLen;
244     const uint8_t *h = emData->data + maskedDBLen;
245     int32_t ret = CRYPT_Mgf1(mgfMethod, h, hLen, dbBuff->data, dbBuff->len);
246     if (ret != CRYPT_SUCCESS) {
247         BSL_ERR_PUSH_ERROR(ret);
248         return ret;
249     }
250     if (tmpSaltLen == (uint32_t)CRYPT_RSA_SALTLEN_TYPE_AUTOLEN) {
251         ret = GetVerifySaltLen(emData->data, dbBuff->data, maskedDBLen, msBit, &tmpSaltLen);
252         if (ret != CRYPT_SUCCESS) {
253             return ret;
254         }
255     }
256     // A ^ B == C => A ^ C == B
257     MaskDB(dbBuff->data, dbBuff->len, h - tmpSaltLen, tmpSaltLen, msBit);
258     if (memcmp(dbBuff->data, emData->data, maskedDBLen - tmpSaltLen) != 0) {
259         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
260         return CRYPT_RSA_NOR_VERIFY_FAIL;
261     }
262     *saltLen = tmpSaltLen;
263     return CRYPT_SUCCESS;
264 }
265 
VerifyH(const EAL_MdMethod * hashMethod,const CRYPT_Data * mHash,const CRYPT_Data * salt,const CRYPT_Data * h,const CRYPT_Data * hBuff)266 static int32_t VerifyH(const EAL_MdMethod *hashMethod, const CRYPT_Data *mHash, const CRYPT_Data *salt,
267     const CRYPT_Data *h, const CRYPT_Data *hBuff)
268 {
269     static const uint8_t zeros8[8] = {0};
270     const CRYPT_ConstData hashData[] = {
271         {zeros8, sizeof(zeros8)},
272         {mHash->data, mHash->len},
273         {salt->data, salt->len}
274     };
275 
276     uint32_t hLen = hBuff->len;
277     int32_t ret = CRYPT_CalcHash(hashMethod, hashData, sizeof(hashData) / sizeof(hashData[0]), hBuff->data, &hLen);
278     if (ret != CRYPT_SUCCESS) {
279         BSL_ERR_PUSH_ERROR(ret);
280         return ret;
281     }
282     if (memcmp(h->data, hBuff->data, hLen) != 0) {
283         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
284         return CRYPT_RSA_NOR_VERIFY_FAIL;
285     }
286     return CRYPT_SUCCESS;
287 }
288 
289 // Reverse verification process of EMSA-PSS Encoding Operation:
290 // MGF(H,maskedDBLen) ^ MaskedDB => DB' (PS||0x01||salt'),  H' = Hash(padding1 || mHash || salt') == H ?
CRYPT_RSA_VerifyPss(const EAL_MdMethod * hashMethod,const EAL_MdMethod * mgfMethod,uint32_t keyBits,uint32_t saltLen,const uint8_t * data,uint32_t dataLen,const uint8_t * pad,uint32_t padLen)291 int32_t CRYPT_RSA_VerifyPss(const EAL_MdMethod *hashMethod, const EAL_MdMethod *mgfMethod, uint32_t keyBits,
292     uint32_t saltLen, const uint8_t *data, uint32_t dataLen, const uint8_t *pad, uint32_t padLen)
293 {
294     if (hashMethod == NULL || mgfMethod == NULL || pad == NULL || data == NULL) {
295         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
296         return CRYPT_NULL_INPUT;
297     }
298     uint32_t hLen = hashMethod->mdSize;
299     int32_t ret = PssEncodeLengthCheck(keyBits, hLen, saltLen, dataLen, padLen);
300     if (ret != CRYPT_SUCCESS) {
301         BSL_ERR_PUSH_ERROR(ret);
302         return ret;
303     }
304 
305     //  EM = maskedDB || H || 0xbc
306     if (pad[padLen - 1] != 0xbc) {
307         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
308         return CRYPT_RSA_NOR_VERIFY_FAIL;
309     }
310 
311     const uint8_t *em = pad;
312     uint32_t emLen = BN_BITS_TO_BYTES(keyBits);
313     // the octet length of EM will be one less than k if modBits - 1 is divisible by 8 and equal to k otherwise
314     uint32_t msBit = ((keyBits - 1) & 0x7);
315     if (msBit == 0) {
316         emLen--;
317         em++;
318     }
319     if ((pad[0] >> msBit) != 0) {
320         // if msBit == 0, 8emLen == emBits, pad[0] should be 0
321         // the leftmost 8emLen - emBits bits of the leftmost octet in maskedDB should be 0
322         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
323         return CRYPT_RSA_NOR_VERIFY_FAIL;
324     }
325     uint8_t *tmpBuff = BSL_SAL_Malloc(emLen); // for maskDB' / DB' and H'
326     if (tmpBuff == NULL) {
327         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
328         return CRYPT_MEM_ALLOC_FAIL;
329     }
330 
331     const uint32_t maskedDBLen = emLen - hLen - 1;
332     const CRYPT_Data dbBuff = {tmpBuff, maskedDBLen};
333     const CRYPT_Data emData = {(uint8_t *)(uintptr_t)em, emLen};
334     const CRYPT_Data mHash = {(uint8_t *)(uintptr_t)data, dataLen};
335     const CRYPT_Data h     = {(uint8_t *)(uintptr_t)&em[maskedDBLen], hLen};
336     const CRYPT_Data hBuff = {&tmpBuff[maskedDBLen], hLen};
337     ret = GetAndVerifyDB(mgfMethod, &emData, &dbBuff, &saltLen, msBit);
338     if (ret != CRYPT_SUCCESS) {
339         (void)memset_s(tmpBuff, emLen, 0, emLen);
340         BSL_SAL_FREE(tmpBuff);
341         BSL_ERR_PUSH_ERROR(ret);
342         return ret;
343     }
344     const CRYPT_Data salt  = {&tmpBuff[maskedDBLen - saltLen], saltLen};
345     ret = VerifyH(hashMethod, &mHash, &salt, &h, &hBuff);
346     (void)memset_s(tmpBuff, emLen, 0, emLen);
347     BSL_SAL_FREE(tmpBuff);
348     return ret;
349 }
350 #endif // HITLS_CRYPTO_RSA_VERIFY
351 #endif // HITLS_CRYPTO_RSA_EMSA_PSS
352 
353 #ifdef HITLS_CRYPTO_RSA_EMSA_PKCSV15
PkcsSetLengthCheck(uint32_t emLen,uint32_t hashLen,uint32_t algIdentLen)354 static int32_t PkcsSetLengthCheck(uint32_t emLen, uint32_t hashLen, uint32_t algIdentLen)
355 {
356     if (emLen > RSA_MAX_MODULUS_LEN || hashLen > RSA_MAX_MODULUS_LEN) {
357         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
358         return CRYPT_RSA_ERR_INPUT_VALUE;
359     }
360     if (hashLen == 0) {
361         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
362         return CRYPT_RSA_ERR_INPUT_VALUE;
363     }
364     /* The length of the pad must exceed 11 bytes at least. tLen = hashLen + algIdentLen */
365     if (emLen < hashLen + algIdentLen + 11) {
366         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
367         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
368     }
369     return CRYPT_SUCCESS;
370 }
371 
/*
 * Look up the DigestInfo prefix (the algIdentifier part of T) for a hash algorithm.
 * On success algIdentifier points at a static table entry and holds its length.
 * For an unsupported hashId it is left as {NULL, 0} and CRYPT_RSA_ERR_MD_ALGID is returned.
 */
static int32_t PkcsGetIdentifier(CRYPT_MD_AlgId hashId, CRYPT_Data *algIdentifier)
{
    static uint8_t sha1TInfo[] = {0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05,
        0x00, 0x04, 0x14};
    static uint8_t sha224TInfo[] = {0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
        0x04, 0x02, 0x04, 0x05, 0x00, 0x04, 0x1c};
    static uint8_t sha256TInfo[] = {0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
        0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20};
    static uint8_t sha384TInfo[] = {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
        0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30};
    static uint8_t sha512TInfo[] = {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
        0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40};
    static uint8_t md5TInfo[] = {0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d,
        0x02, 0x05, 0x05, 0x00, 0x04, 0x10};
    static uint8_t sm3TInfo[] = {0x30, 0x30, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x81, 0x1c, 0xcf, 0x55, 0x01,
        0x83, 0x11, 0x05, 0x00, 0x04, 0x20};

    // hash algorithm -> DigestInfo prefix
    static const struct {
        CRYPT_MD_AlgId id;
        uint8_t *info;
        uint32_t len;
    } prefixTable[] = {
        {CRYPT_MD_SHA1, sha1TInfo, sizeof(sha1TInfo)},
        {CRYPT_MD_SHA224, sha224TInfo, sizeof(sha224TInfo)},
        {CRYPT_MD_SHA256, sha256TInfo, sizeof(sha256TInfo)},
        {CRYPT_MD_SHA384, sha384TInfo, sizeof(sha384TInfo)},
        {CRYPT_MD_SHA512, sha512TInfo, sizeof(sha512TInfo)},
        {CRYPT_MD_MD5, md5TInfo, sizeof(md5TInfo)},
        {CRYPT_MD_SM3, sm3TInfo, sizeof(sm3TInfo)},
    };

    algIdentifier->data = NULL;
    algIdentifier->len = 0;
    for (uint32_t k = 0; k < sizeof(prefixTable) / sizeof(prefixTable[0]); k++) {
        if (prefixTable[k].id == hashId) {
            algIdentifier->data = prefixTable[k].info;
            algIdentifier->len = prefixTable[k].len;
            return CRYPT_SUCCESS;
        }
    }
    BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_MD_ALGID);
    return CRYPT_RSA_ERR_MD_ALGID;
}
418 
419 // Pad output format:EM = 00 || 01 || PS || 00 || T; where T = algIdentifier || hash(M);
420 // hash(M) is the input parameter data of this function.
CRYPT_RSA_SetPkcsV15Type1(CRYPT_MD_AlgId hashId,const uint8_t * data,uint32_t dataLen,uint8_t * pad,uint32_t padLen)421 int32_t CRYPT_RSA_SetPkcsV15Type1(CRYPT_MD_AlgId hashId, const uint8_t *data, uint32_t dataLen,
422     uint8_t *pad, uint32_t padLen)
423 {
424     int32_t ret;
425     uint32_t padSize;
426     uint8_t *tmp = pad;
427     uint32_t tmpLen = padLen;
428     if (pad == NULL || data == NULL) {
429         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
430         return CRYPT_NULL_INPUT;
431     }
432     CRYPT_Data algIdentifier = {NULL, 0};
433     ret = PkcsGetIdentifier(hashId, &algIdentifier);
434     if (ret != CRYPT_SUCCESS) {
435         BSL_ERR_PUSH_ERROR(ret);
436         return ret;
437     }
438     ret = PkcsSetLengthCheck(padLen, dataLen, algIdentifier.len);
439     if (ret != CRYPT_SUCCESS) {
440         return ret;
441     }
442 
443     // Considering that the data space and pad space may overlap,
444     // move the data to the specified position(the end of the pad).
445     if (memmove_s(pad + (padLen - dataLen), dataLen, data, dataLen) != EOK) {
446         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
447         return CRYPT_SECUREC_FAIL;
448     }
449     *tmp = 0x0;
450     tmp++;
451     *tmp = 0x1;
452     tmp++;
453     tmpLen -= 2; // Skip the first 2 bytes.
454 
455     // PS length: padSize = padLen - dataLen - algIdentifier.len - 3
456     padSize = padLen - dataLen - algIdentifier.len - 3;
457     if (memset_s(tmp, tmpLen, 0xff, padSize) != EOK) { // 0xff padded in PS
458         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
459         return CRYPT_SECUREC_FAIL;
460     }
461     tmp += padSize;
462     tmpLen -= padSize;
463 
464     *tmp = 0x0;
465     tmp++;
466     tmpLen--;
467 
468     if ((algIdentifier.len > 0) && memcpy_s(tmp, tmpLen, algIdentifier.data, algIdentifier.len) != EOK) {
469         // padding when identifier exit
470         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
471         return CRYPT_SECUREC_FAIL;
472     }
473     return CRYPT_SUCCESS;
474 }
475 
476 #ifdef HITLS_CRYPTO_RSA_VERIFY
/*
 * Verify an EMSA-PKCS1-v1_5 encoding by rebuilding the expected EM from the digest
 * (data/dataLen) and comparing it byte-for-byte with pad (padLen bytes).
 * Returns CRYPT_SUCCESS on a match, CRYPT_RSA_NOR_VERIFY_FAIL on a mismatch, or an encoding error.
 */
int32_t CRYPT_RSA_VerifyPkcsV15Type1(CRYPT_MD_AlgId hashId, const uint8_t *pad, uint32_t padLen,
    const uint8_t *data, uint32_t dataLen)
{
    if (pad == NULL || data == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
        return CRYPT_NULL_INPUT;
    }
    if (padLen == 0 || dataLen == 0) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
        return CRYPT_RSA_ERR_INPUT_VALUE;
    }

    uint8_t *expected = BSL_SAL_Malloc(padLen);
    if (expected == NULL) {
        BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
        return CRYPT_MEM_ALLOC_FAIL;
    }

    int32_t ret = CRYPT_RSA_SetPkcsV15Type1(hashId, data, dataLen, expected, padLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
    } else if (memcmp(pad, expected, padLen) != 0) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
        ret = CRYPT_RSA_NOR_VERIFY_FAIL;
    }
    BSL_SAL_FREE(expected);
    return ret;
}
511 #endif // HITLS_CRYPTO_RSA_VERIFY
512 
CRYPT_RSA_UnPackPkcsV15Type1(uint8_t * data,uint32_t dataLen,uint8_t * out,uint32_t * outLen)513 int32_t CRYPT_RSA_UnPackPkcsV15Type1(uint8_t *data, uint32_t dataLen, uint8_t *out, uint32_t *outLen)
514 {
515     uint8_t *index = data;
516     uint32_t tmpLen = dataLen;
517     // Format of the data to be decrypted is EB = 00 || 01 || PS || 00 || T.
518     // The PS padding is at least 8. Therefore, the length of the data to be decrypted is at least 11.
519     if (dataLen < 11) {
520         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
521         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
522     }
523     if (*index != 0x0 || *(index + 1) != 0x01) {
524         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
525         return CRYPT_RSA_ERR_INPUT_VALUE;
526     }
527 
528     index += 2; // Skip first 2 bytes.
529     tmpLen -= 2; // Skip first 2 bytes.
530     uint32_t padNum = 0;
531     while (*index == 0xff) {
532         index++;
533         tmpLen--;
534         padNum++;
535     }
536     if (padNum < 8) { // The PS padding is at least 8.
537         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_PAD_NUM);
538         return CRYPT_RSA_ERR_PAD_NUM;
539     }
540     if (tmpLen == 0 || *index != 0x0) {
541         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
542         return CRYPT_RSA_ERR_INPUT_VALUE;
543     }
544     index++;
545     tmpLen--;
546 
547     if (memcpy_s(out, *outLen, index, tmpLen) != EOK) {
548         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
549         return CRYPT_SECUREC_FAIL;
550     }
551     *outLen = tmpLen;
552     return CRYPT_SUCCESS;
553 }
554 #endif // HITLS_CRYPTO_RSA_EMSA_PKCSV15
555 
556 #ifdef HITLS_CRYPTO_RSAES_OAEP
557 #ifdef HITLS_CRYPTO_RSA_ENCRYPT
OaepSetLengthCheck(uint32_t outLen,uint32_t inLen,uint32_t hashLen)558 static int32_t OaepSetLengthCheck(uint32_t outLen, uint32_t inLen, uint32_t hashLen)
559 {
560     if (outLen > RSA_MAX_MODULUS_LEN || inLen > RSA_MAX_MODULUS_LEN || hashLen > HASH_MAX_MDSIZE) {
561         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
562         return CRYPT_RSA_ERR_INPUT_VALUE;
563     }
564     if (outLen == 0 || hashLen == 0) {
565         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
566         return CRYPT_RSA_ERR_INPUT_VALUE;
567     }
568     // If mLen > k - 2hLen - 2, output "message too long" and stop.
569     if (inLen + 2 * hashLen + 2 > outLen) {
570         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_ENC_BITS);
571         return CRYPT_RSA_ERR_ENC_BITS;
572     }
573     return CRYPT_SUCCESS;
574 }
575 
OaepSetPs(const uint8_t * in,uint32_t inLen,uint8_t * db,uint32_t padLen,uint32_t hashLen)576 static int32_t OaepSetPs(const uint8_t *in, uint32_t inLen, uint8_t *db, uint32_t padLen, uint32_t hashLen)
577 {
578     uint8_t *ps = db + hashLen;
579     // Generate a padding string PS consisting of k - mLen - 2hLen - 2 zero octets.  The length of PS may be zero
580     // This operation cannot be reversed because the OaepSetLengthCheck has checked the validity of the data.
581     uint32_t psLen = padLen - inLen - 2 * hashLen - 2;
582     // padding 0x00
583     (void)memset_s(ps, psLen, 0, psLen);
584     ps += psLen;
585     *ps = 0x01;
586     ps++;
587     /**
588      * padLen minus twice hashLen, then subtract 2 bytes of fixed data, and subtract the padding length.
589      * The remaining length is the plaintext length.
590      */
591     if (inLen != 0 && memcpy_s(ps, padLen - 2 * hashLen - 2 - psLen, in, inLen) != EOK) {
592         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
593         return CRYPT_SECUREC_FAIL;
594     }
595     return CRYPT_SUCCESS;
596 }
597 
OaepSetMaskedDB(const EAL_MdMethod * mgfMethod,uint8_t * db,uint8_t * seed,uint32_t padLen,uint32_t hashLen)598 static int32_t OaepSetMaskedDB(const EAL_MdMethod *mgfMethod, uint8_t *db, uint8_t *seed, uint32_t padLen,
599     uint32_t hashLen)
600 {
601     int32_t ret;
602     uint32_t i;
603     uint32_t maskedDBLen = padLen - hashLen - 1;
604     uint8_t *maskedDB = (uint8_t *)BSL_SAL_Malloc(maskedDBLen);
605     if (maskedDB == NULL) {
606         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
607         return CRYPT_MEM_ALLOC_FAIL;
608     }
609 
610     ret = CRYPT_Mgf1(mgfMethod, seed, hashLen, maskedDB, maskedDBLen);
611     if (ret != CRYPT_SUCCESS) {
612         BSL_ERR_PUSH_ERROR(ret);
613         goto EXIT;
614     }
615     for (i = 0; i < maskedDBLen; i++) {
616         db[i] ^= maskedDB[i];
617     }
618 EXIT:
619     BSL_SAL_CleanseData(maskedDB, maskedDBLen);
620     BSL_SAL_FREE(maskedDB);
621     return ret;
622 }
623 
OaepSetSeedMask(const EAL_MdMethod * mgfMethod,uint8_t * db,uint8_t * seed,uint32_t padLen,uint32_t hashLen)624 static int32_t OaepSetSeedMask(const EAL_MdMethod *mgfMethod, uint8_t *db, uint8_t *seed, uint32_t padLen,
625     uint32_t hashLen)
626 {
627     uint32_t i;
628     int32_t ret;
629     uint8_t seedmask[HASH_MAX_MDSIZE];
630     uint32_t maskedDBLen = padLen - hashLen - 1;
631 
632     ret = CRYPT_Mgf1(mgfMethod, db, maskedDBLen, seedmask, hashLen);
633     if (ret != CRYPT_SUCCESS) {
634         BSL_ERR_PUSH_ERROR(ret);
635         goto EXIT;
636     }
637     for (i = 0; i < hashLen; i++) {
638         seed[i] ^= seedmask[i];
639     }
640 EXIT:
641     BSL_SAL_CleanseData(seedmask, hashLen);
642     return ret;
643 }
644 
645 /**
646 *    _________________________________________________________________
647 *
648 *                                +----------+------+--+-------+
649 *                           DB = |  lHash   |  PS  |01|   M   |
650 *                                +----------+------+--+-------+
651 *                                               |
652 *                     +----------+              |
653 *                     |   seed   |              |
654 *                     +----------+              |
655 *                           |                   |
656 *                           |-------> MGF ---> xor
657 *                           |                   |
658 *                  +--+     V                   |
659 *                  |00|    xor <----- MGF <-----|
660 *                  +--+     |                   |
661 *                    |      |                   |
662 *                    V      V                   V
663 *                  +--+----------+----------------------------+
664 *            EM =  |00|maskedSeed|          maskedDB          |
665 *                  +--+----------+----------------------------+
666 *    _________________________________________________________________
667 *
668 *                   Figure 1: EME-OAEP Encoding Operation <rfc8017>
669 */
CRYPT_RSA_SetPkcs1Oaep(CRYPT_RSA_Ctx * ctx,const uint8_t * in,uint32_t inLen,uint8_t * pad,uint32_t padLen)670 int32_t CRYPT_RSA_SetPkcs1Oaep(CRYPT_RSA_Ctx *ctx, const uint8_t *in, uint32_t inLen, uint8_t *pad, uint32_t padLen)
671 {
672     int32_t ret;
673     const EAL_MdMethod *hashMethod = ctx->pad.para.oaep.mdMeth;
674     const EAL_MdMethod *mgfMethod = ctx->pad.para.oaep.mgfMeth;
675 
676     if (hashMethod == NULL || mgfMethod == NULL || (in == NULL && inLen != 0) || pad == NULL) {
677         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
678         return CRYPT_NULL_INPUT;
679     }
680 
681     uint32_t hashLen = hashMethod->mdSize;
682 
683     /* If mLen > k - 2hLen - 2, output "message too long" and stop<rfc8017>
684         k is output len, hLen is hashLen, mLen is inLen
685     */
686     ret = OaepSetLengthCheck(padLen, inLen, hashLen);
687     if (ret != CRYPT_SUCCESS) {
688         BSL_ERR_PUSH_ERROR(ret);
689         return ret;
690     }
691 
692     *pad = 0x00;
693     uint8_t *seed = pad + 1;
694     // Generate a random octet string seed of length hLen<rfc8017>
695     ret = CRYPT_RandEx(ctx->libCtx, seed, hashLen);
696     if (ret != CRYPT_SUCCESS) {
697         BSL_ERR_PUSH_ERROR(ret);
698         return ret;
699     }
700     uint8_t *db = seed + hashLen;
701 
702     // Calculate hash
703     const CRYPT_ConstData data = {ctx->label.data, ctx->label.len};
704     ret = CRYPT_CalcHash(hashMethod, &data, 1, db, &hashLen);
705     if (ret != CRYPT_SUCCESS) {
706         BSL_ERR_PUSH_ERROR(ret);
707         return ret;
708     }
709 
710     ret = OaepSetPs(in, inLen, db, padLen, hashLen);
711     if (ret != CRYPT_SUCCESS) {
712         BSL_ERR_PUSH_ERROR(ret);
713         return ret;
714     }
715 
716     // set maskedDB
717     ret = OaepSetMaskedDB(mgfMethod, db, seed, padLen, hashLen);
718     if (ret != CRYPT_SUCCESS) {
719         BSL_ERR_PUSH_ERROR(ret);
720         return ret;
721     }
722 
723     // set seedmask
724     ret = OaepSetSeedMask(mgfMethod, db, seed, padLen, hashLen);
725     if (ret != CRYPT_SUCCESS) {
726         BSL_ERR_PUSH_ERROR(ret);
727     }
728 
729     return ret;
730 }
731 #endif // HITLS_CRYPTO_RSA_ENCRYPT
732 
733 #ifdef HITLS_CRYPTO_RSA_DECRYPT
OaepVerifyLengthCheck(uint32_t outLen,uint32_t inLen,uint32_t hashLen)734 static int32_t OaepVerifyLengthCheck(uint32_t outLen, uint32_t inLen, uint32_t hashLen)
735 {
736     if (outLen > RSA_MAX_MODULUS_LEN || inLen > RSA_MAX_MODULUS_LEN || hashLen > HASH_MAX_MDSIZE) {
737         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
738         return CRYPT_RSA_ERR_INPUT_VALUE;
739     }
740     if (outLen == 0 || hashLen == 0 || inLen == 0) {
741         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
742         return CRYPT_RSA_ERR_INPUT_VALUE;
743     }
744     // If k < 2hLen + 2, output "decryption error" and stop
745     if (inLen < 2 * hashLen + 2) {
746         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
747         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
748     }
749     return CRYPT_SUCCESS;
750 }
751 
/*
 * Recover the OAEP seed (RFC 8017, 7.1.2 steps 3c-3d):
 *   seedMask = MGF(maskedDB, hLen); seed = maskedSeed XOR seedMask.
 * EM (in, inLen) is laid out as Y (1 octet) || maskedSeed (hashLen octets) || maskedDB.
 * On success seedMask->data[0..hashLen) holds the recovered seed.
 * Returns CRYPT_SUCCESS, CRYPT_RSA_ERR_INPUT_VALUE for inconsistent lengths, or the MGF1 error.
 */
static int32_t OaepDecodeSeedMask(const EAL_MdMethod *mgfMethod, const uint8_t *in, uint32_t inLen,
    CRYPT_Data *seedMask, uint32_t hashLen)
{
    /* The seed must fit the output buffer and maskedDB must be non-empty
     * (written without "hashLen + 1" so the comparison cannot wrap). */
    if (hashLen > seedMask->len || inLen < 2 || hashLen >= inLen - 1) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
        return CRYPT_RSA_ERR_INPUT_VALUE;
    }

    const uint8_t *maskedSeed = in + 1;
    const uint8_t *maskedDB = maskedSeed + hashLen;
    uint32_t maskedDBLen = inLen - hashLen - 1;

    int32_t ret = CRYPT_Mgf1(mgfMethod, maskedDB, maskedDBLen, seedMask->data, hashLen);
    if (ret != CRYPT_SUCCESS) {
        /* Record the failure, consistent with OaepDecodeMaskedDB. */
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }
    for (uint32_t i = 0; i < hashLen; i++) {
        seedMask->data[i] ^= maskedSeed[i];
    }
    return CRYPT_SUCCESS;
}
771 
OaepDecodeMaskedDB(const EAL_MdMethod * mgfMethod,const CRYPT_Data * in,const uint8_t * seedMask,uint32_t hashLen,const CRYPT_Data * dbMaskData)772 static int32_t OaepDecodeMaskedDB(const EAL_MdMethod *mgfMethod, const CRYPT_Data *in, const uint8_t *seedMask,
773     uint32_t hashLen, const CRYPT_Data *dbMaskData)
774 {
775     int32_t ret;
776     uint32_t i;
777     const uint8_t *maskedDB = in->data + 1 + hashLen;
778     uint32_t maskedDBLen = in->len - hashLen - 1;
779 
780     ret = CRYPT_Mgf1(mgfMethod, seedMask, hashLen, dbMaskData->data, maskedDBLen);
781     if (ret != CRYPT_SUCCESS) {
782         BSL_ERR_PUSH_ERROR(ret);
783         return ret;
784     }
785     for (i = 0; i < maskedDBLen; i++) {
786         dbMaskData->data[i] ^= maskedDB[i];
787     }
788 
789     return ret;
790 }
791 
/*
 * Check the decoded DB = lHash' || PS || 0x01 || M (RFC 8017, 7.1.2 step 3g).
 *   paramData:  the OAEP label L; lHash = Hash(L) is compared with lHash'.
 *   dbMaskData: the decoded DB.
 *   hashLen:    digest length hLen of hashMethod.
 *   offset:     on success, index in DB of the first octet of M.
 * Returns CRYPT_SUCCESS, the hash error, or CRYPT_RSA_NOR_VERIFY_FAIL.
 *
 * RFC 8017 requires that the step-3g failure causes cannot be distinguished by error or by
 * timing. The lHash comparison and the separator search therefore visit every octet and
 * accumulate failures in one mask instead of returning at the first mismatch.
 */
static int32_t OaepVerifyHashMaskDB(const EAL_MdMethod *hashMethod, CRYPT_Data *paramData, CRYPT_Data *dbMaskData,
    uint32_t hashLen, uint32_t *offset)
{
    uint8_t hashVal[HASH_MAX_MDSIZE];
    CRYPT_ConstData data = {paramData->data, paramData->len};
    int32_t ret = CRYPT_CalcHash(hashMethod, &data, 1, hashVal, &hashLen);
    if (ret != CRYPT_SUCCESS) {
        BSL_ERR_PUSH_ERROR(ret);
        return ret;
    }

    /* DB must hold lHash' plus at least the 0x01 separator (depends only on public lengths). */
    if (hashLen >= dbMaskData->len) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
        return CRYPT_RSA_NOR_VERIFY_FAIL;
    }

    const uint8_t *db = dbMaskData->data;

    /* Constant-time lHash' == lHash. */
    uint32_t diff = 0;
    for (uint32_t i = 0; i < hashLen; i++) {
        diff |= (uint32_t)(db[i] ^ hashVal[i]);
    }
    uint32_t good = Uint32ConstTimeIsZero(diff);

    /* Constant-time search for the first 0x01 after lHash'. */
    uint32_t searching = ~(uint32_t)0; /* set until the separator has been seen */
    uint32_t oneIndex = 0;
    for (uint32_t i = hashLen; i < dbMaskData->len; i++) {
        uint32_t isZero = Uint32ConstTimeIsZero(db[i]);
        uint32_t isOne = Uint32ConstTimeEqual(db[i], 0x01);
        oneIndex = Uint32ConstTimeSelect(searching & isOne, i, oneIndex);
        /* Before the separator, PS may only contain 0x00 octets. */
        good &= ~(searching & ~isZero & ~isOne);
        searching &= ~isOne;
    }
    good &= ~searching; /* the separator must exist */

    if (good == 0) {
        BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
        return CRYPT_RSA_NOR_VERIFY_FAIL;
    }

    *offset = oneIndex + 1;
    return CRYPT_SUCCESS;
}
826 
CRYPT_RSA_VerifyPkcs1Oaep(const EAL_MdMethod * hashMethod,const EAL_MdMethod * mgfMethod,const uint8_t * in,uint32_t inLen,const uint8_t * param,uint32_t paramLen,uint8_t * msg,uint32_t * msgLen)827 int32_t CRYPT_RSA_VerifyPkcs1Oaep(const EAL_MdMethod *hashMethod, const EAL_MdMethod *mgfMethod, const uint8_t *in,
828     uint32_t inLen, const uint8_t *param, uint32_t paramLen, uint8_t *msg, uint32_t *msgLen)
829 {
830     if (hashMethod == NULL || mgfMethod == NULL || in == NULL || msg == NULL) {
831         BSL_ERR_PUSH_ERROR(CRYPT_NULL_INPUT);
832         return CRYPT_NULL_INPUT;
833     }
834     uint32_t hashLen = hashMethod->mdSize;
835     if (inLen <= (hashLen + 1)) {
836         BSL_ERR_PUSH_ERROR(CRYPT_RSA_ERR_INPUT_VALUE);
837         return CRYPT_RSA_ERR_INPUT_VALUE;
838     }
839     uint32_t maskedDBLen = inLen - hashLen - 1;
840     int32_t ret;
841     uint32_t offset;
842     uint8_t seedMask[HASH_MAX_MDSIZE];
843     CRYPT_Data seedData = { (uint8_t *)(uintptr_t)seedMask, HASH_MAX_MDSIZE };
844     CRYPT_Data paramData = { (uint8_t *)(uintptr_t)param, paramLen };
845     CRYPT_Data inData = { (uint8_t *)(uintptr_t)in, inLen };
846     uint8_t *maskDB = (uint8_t *)BSL_SAL_Malloc(maskedDBLen);
847     if (maskDB == NULL) {
848         BSL_ERR_PUSH_ERROR(CRYPT_MEM_ALLOC_FAIL);
849         return CRYPT_MEM_ALLOC_FAIL;
850     }
851     CRYPT_Data dbMaskData = { maskDB, maskedDBLen };
852 
853     /* If k < 2hLen + 2, output "decryption error" and stop.<rfc8017>
854         k is intLen , hLen is hashLen
855     */
856     GOTO_ERR_IF_EX(OaepVerifyLengthCheck(*msgLen, inLen, hashLen), ret);
857 
858     GOTO_ERR_IF_EX(OaepDecodeSeedMask(mgfMethod, in, inLen, &seedData, hashLen), ret);
859 
860     GOTO_ERR_IF_EX(OaepDecodeMaskedDB(mgfMethod, &inData, seedMask, hashLen, &dbMaskData), ret);
861 
862     GOTO_ERR_IF_EX(OaepVerifyHashMaskDB(hashMethod, &paramData, &dbMaskData, hashLen, &offset), ret);
863 
864     if (memcpy_s(msg, *msgLen, maskDB + offset, maskedDBLen - offset) != EOK) {
865         ret = CRYPT_RSA_NOR_VERIFY_FAIL;
866         BSL_ERR_PUSH_ERROR(ret);
867         goto ERR;
868     }
869     *msgLen = maskedDBLen - offset;
870 ERR:
871     BSL_SAL_CleanseData(maskDB, maskedDBLen);
872     BSL_SAL_FREE(maskDB);
873     return ret;
874 }
875 #endif // HITLS_CRYPTO_RSA_DECRYPT
876 #endif // HITLS_CRYPTO_RSAES_OAEP
877 
878 #if defined(HITLS_CRYPTO_RSA_ENCRYPT) && \
879     (defined(HITLS_CRYPTO_RSAES_PKCSV15_TLS) || defined(HITLS_CRYPTO_RSAES_PKCSV15))
880 // Pad output format: EM = 00 || 02 || PS || 00 || M; where M indicates message.
CRYPT_RSA_SetPkcsV15Type2(void * libCtx,const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t outLen)881 int32_t CRYPT_RSA_SetPkcsV15Type2(void *libCtx, const uint8_t *in, uint32_t inLen,
882     uint8_t *out, uint32_t outLen)
883 {
884     // If mLen > k - 11, output "message too long" and stop.<rfc8017>
885     if (inLen + 11 > outLen) {
886         BSL_ERR_PUSH_ERROR(CRYPT_RSA_BUFF_LEN_NOT_ENOUGH);
887         return CRYPT_RSA_BUFF_LEN_NOT_ENOUGH;
888     }
889 
890     int32_t ret;
891     uint32_t i;
892     uint8_t *ps = out + 2;
893     uint32_t psLen = outLen - inLen - 3;
894     uint8_t *msg = out + psLen + 3;
895 
896     *out = 0x00;
897     *(out + 1) = 0x02;
898     *(out + outLen - inLen - 2) = 0x00;
899     // msg padding, outLen minus the 3-byte constant, ps length, and start padding.
900     if (inLen != 0 && memcpy_s(msg, outLen - (psLen + 3), in, inLen) != EOK) {
901         BSL_ERR_PUSH_ERROR(CRYPT_SECUREC_FAIL);
902         return CRYPT_SECUREC_FAIL;
903     }
904 
905     // cal ps
906     ret = CRYPT_RandEx(libCtx, ps, psLen);
907     if (ret != CRYPT_SUCCESS) {
908         return ret;
909     }
910     ps[psLen] = 0;
911     for (i = 0; i < psLen; i++) {
912         if (*(ps + i) != 0) {
913             continue;
914         }
915         do {
916             // no zero
917             ret = CRYPT_RandEx(libCtx, ps + i, 1);
918             if (ret != CRYPT_SUCCESS) {
919                 return ret;
920             }
921         } while (*(ps + i) == 0);
922     }
923 
924     return CRYPT_SUCCESS;
925 }
926 #endif // HITLS_CRYPTO_RSA_ENCRYPT && (EC_PKCSV15_TLS || EC_PKCSV15)
927 
928 #ifdef HITLS_CRYPTO_RSA_DECRYPT
929 #ifdef HITLS_CRYPTO_RSAES_PKCSV15
CRYPT_RSA_VerifyPkcsV15Type2(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t * outLen)930 int32_t CRYPT_RSA_VerifyPkcsV15Type2(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t *outLen)
931 {
932     uint32_t zeroIndex = 0;
933     uint32_t index = ~(0);
934     uint32_t firstZero = Uint32ConstTimeEqual(in[0], 0x00);
935     uint32_t firstTwo = Uint32ConstTimeEqual(in[1], 0x02);
936     // Check the ps starting from subscript 2.
937     for (uint32_t i = 2; i < inLen; i++) {
938         uint32_t equals0 = Uint32ConstTimeIsZero(in[i]);
939         zeroIndex = Uint32ConstTimeSelect(index & equals0, i, zeroIndex);
940         index = Uint32ConstTimeSelect(equals0, 0, index);
941     }
942 
943     uint32_t valid = firstZero & firstTwo & (~index);
944     // Pad output format: EM = 00 || 02 || PS || 00 || M; where M is a message, and PS must be >= 8.
945     // Therefore, the subscript of the second 0 must be greater than or equal to 10.
946     valid &= Uint32ConstTimeGe(zeroIndex, 10);
947 
948     zeroIndex++;
949     if (valid == 0) {
950         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
951         return CRYPT_RSA_NOR_VERIFY_FAIL;
952     }
953 
954     if (inLen - zeroIndex > *outLen) {
955         BSL_ERR_PUSH_ERROR(CRYPT_RSA_NOR_VERIFY_FAIL);
956         return CRYPT_RSA_NOR_VERIFY_FAIL;
957     }
958 
959     (void)memcpy_s(out, *outLen, in + zeroIndex, inLen - zeroIndex);
960     *outLen = inLen - zeroIndex;
961 
962     return CRYPT_SUCCESS;
963 }
964 #endif // HITLS_CRYPTO_RSAES_PKCSV15
965 
966 #ifdef HITLS_CRYPTO_RSAES_PKCSV15_TLS
CRYPT_RSA_VerifyPkcsV15Type2TLS(const uint8_t * in,uint32_t inLen,uint8_t * out,uint32_t * outLen)967 int32_t CRYPT_RSA_VerifyPkcsV15Type2TLS(const uint8_t *in, uint32_t inLen, uint8_t *out, uint32_t *outLen)
968 {
969     uint32_t masterSecretLen = *outLen;
970     uint32_t zeroIndex = 0;
971     uint32_t index = ~(0);
972     uint32_t fist = Uint32ConstTimeEqual(in[0], 0x00);
973     uint32_t second = Uint32ConstTimeEqual(in[1], 0x02);
974     for (uint32_t i = 2; i < inLen; i++) {
975         uint32_t equals0 = Uint32ConstTimeIsZero(in[i]);
976         zeroIndex = Uint32ConstTimeSelect(index & equals0, i, zeroIndex);
977         index = Uint32ConstTimeSelect(equals0, 0, index);
978     }
979 
980     uint32_t valid = fist & second & (~index);
981     // Pad output format: EM = 00 || 02 || PS || 00 || M; where M is a message, and PS must be >= 8.
982     // Therefore, the subscript of the second 0 must be greater than or equal to 10.
983     valid &= Uint32ConstTimeGe(zeroIndex, 10);
984     zeroIndex++;
985     uint32_t secretLen = inLen - zeroIndex;
986     valid &= ~(Uint32ConstTimeGt(secretLen, *outLen));
987     for (uint32_t i = 0; i < masterSecretLen; i++) {
988         uint32_t mask = valid & Uint32ConstTimeLt(i, secretLen);
989         uint32_t inIndex = mask & zeroIndex;
990         out[i] = Uint8ConstTimeSelect(mask, *(in + inIndex + i), 0);
991     }
992     *outLen = secretLen;
993 
994     // if the 'plaintext' is PKCS15 , the valid should be 0xffffffff, else should be 0
995     // so return 0 for success, else return 0xffffffff
996     return ~valid;
997 }
998 #endif // HITLS_CRYPTO_RSAES_PKCSV15_TLS
999 #endif // HITLS_CRYPTO_RSA_DECRYPT
1000 
1001 #endif /* HITLS_CRYPTO_RSA */