• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include <stdio.h>
16 #include <stdlib.h>
17 #include "hvb_hash_sha256.h"
18 #include "hvb_crypto.h"
19 #include "hvb_rsa.h"
20 #include "hvb_util.h"
21 #include "hvb_sysdeps.h"
22 #include "hvb_rsa_verify.h"
23 
24 
/* Size of a SHA-256 digest in bytes. */
#define SHA256_DIGEST_LEN 32
/* Fixed octets of EM besides H and the salt (0x01 separator + 0xBC trailer). */
#define PSS_EM_PADDING_LEN 2
/* Number of leading zero octets in M'. */
#define PSS_MTMP_PADDING_LEN 8
/* Length of a single padding octet (0x01 separator in DB, 0xBC trailer in EM). */
#define PSS_DB_PADDING_LEN 1
/* Trailer octet that terminates EM. */
#define PSS_END_PADDING_UNIT 0xBC
/* All-ones octet; shifted right to build the mask for the leftmost octet of EM. */
#define PSS_LEFTMOST_BIT_MASK 0xFFU

#define PADDING_UNIT_ZERO 0x00
#define PADDING_UNIT_ONE 0x01
/* Largest supported key width in bits. */
#define RSA_WIDTH_MAX 8192

#define WORD_BYTE_SIZE (sizeof(unsigned long))
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All bits of one word set. Written as ~0UL because shifting 1UL by the full
 * width of the type (1UL << WORD_BIT_SIZE) is undefined behaviour.
 */
#define WORD_BIT_MASK (~0UL)
#define bit2byte(bits) ((bits) >> 3)
#define byte2bit(byte) ((byte) << 3)
#define bit_val(x) (1U << (x))
#define bit_mask(x) (bit_val(x) - 1U)
/* Round n up to the next multiple of 2^bit. */
#define bit_align(n, bit) (((n) + bit_mask(bit)) & (~(bit_mask(bit))))
/* Number of whole bytes needed to hold the given number of bits. */
#define bit2byte_align(bits) bit2byte(bit_align(bits, 3))
/* Number of machine words needed to hold the given number of bytes. */
#define byte2dword(bytes) (((bytes) + (WORD_BYTE_SIZE) - 1) / WORD_BYTE_SIZE)
#define dword2byte(words) ((words) * WORD_BYTE_SIZE)
47 
48 /* calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
emsa_pss_calc_m(const uint8_t * pdigest,uint32_t digestlen,uint8_t * salt,uint32_t saltlen,uint8_t ** m)49 static int emsa_pss_calc_m(const uint8_t *pdigest, uint32_t digestlen,
50                            uint8_t *salt, uint32_t saltlen,
51                            uint8_t **m)
52 {
53     uint8_t *m_tmp = NULL;
54     uint32_t m_tmp_len;
55 
56     m_tmp_len = digestlen + saltlen + PSS_MTMP_PADDING_LEN;
57     m_tmp = (uint8_t *)hvb_malloc(m_tmp_len);
58     if (!m_tmp) {
59         return PARAM_EMPTY_ERROR;
60     }
61 
62     hvb_memset(m_tmp, 0, PSS_MTMP_PADDING_LEN);
63     hvb_memcpy(&m_tmp[PSS_MTMP_PADDING_LEN], pdigest, digestlen);
64 
65     if (saltlen != 0 && salt) {
66         hvb_memcpy(&m_tmp[PSS_MTMP_PADDING_LEN + digestlen], salt, saltlen);
67     }
68 
69     *m = m_tmp;
70     return VERIFY_OK;
71 }
72 
73 /* rsa verify last step compare hash value */
emsa_pss_hash_cmp(uint8_t * m_tmp,uint32_t m_tmp_len,uint8_t * hash,uint32_t digestlen)74 static int emsa_pss_hash_cmp(uint8_t *m_tmp, uint32_t m_tmp_len,
75                              uint8_t *hash, uint32_t digestlen)
76 {
77     int ret;
78     uint8_t *hash_tmp = NULL;
79 
80     hash_tmp = (uint8_t *)hvb_malloc(digestlen);
81     if (!hash_tmp) {
82         return HASH_CMP_FAIL;
83     }
84     if (hash_sha256_single(m_tmp, m_tmp_len, hash_tmp, digestlen) != HASH_OK) {
85         ret = HASH_CMP_FAIL;
86         goto rsa_error;
87     }
88     /* compare twice */
89     ret = VERIFY_OK;
90     ret += hvb_memcmp(hash, hash_tmp, digestlen);
91     ret += hvb_memcmp(hash, hash_tmp, digestlen);
92     if (ret != VERIFY_OK)
93         ret = HASH_CMP_FAIL;
94 rsa_error:
95     hvb_free(hash_tmp);
96     return ret;
97 }
98 
rsa_pss_get_emlen(uint32_t klen,struct long_int_num * pn,uint32_t * emlen,uint32_t * embits)99 static int rsa_pss_get_emlen(uint32_t klen, struct long_int_num *pn,
100                              uint32_t *emlen, uint32_t *embits)
101 {
102     *embits = lin_get_bitlen(pn);
103     if (*embits == 0) {
104         return CALC_EMLEN_ERROR;
105     }
106     (*embits)--;
107 
108     *emlen = bit2byte_align(*embits);
109     if (*emlen == 0) {
110         return CALC_EMLEN_ERROR;
111     }
112 
113     if (*emlen > klen) {
114         return CALC_EMLEN_ERROR;
115     }
116 
117     return VERIFY_OK;
118 }
119 
120 /* make generate function V1 */
rsa_gen_mask_mgf_v1(uint8_t * seed,uint32_t seed_len,uint8_t * mask,uint32_t mask_len)121 static int rsa_gen_mask_mgf_v1(uint8_t *seed, uint32_t seed_len,
122                                uint8_t *mask, uint32_t mask_len)
123 {
124     int ret = VERIFY_OK;
125     uint32_t cnt = 0;
126     uint32_t cnt_maxsize = 0;
127     uint8_t *p_tmp = NULL;
128     uint8_t *pt = NULL;
129     uint8_t *pc = NULL;
130     const uint32_t hash_len = SHA256_DIGEST_LEN;
131 
132     /* Step 1: mask length is smaller than the maximum key length */
133     if (mask_len > bit2byte(RSA_WIDTH_MAX)) {
134         return CALC_MASK_ERROR;
135     }
136 
137     /* Step 2:  Let pt and pt_tmp be the empty octet string. */
138     pt = (uint8_t *)hvb_malloc(mask_len + hash_len);
139     if (!pt) {
140         return CALC_MASK_ERROR;
141     }
142 
143     pc = (uint8_t *)hvb_malloc(seed_len + sizeof(uint32_t));
144     if (!pc) {
145         ret = CALC_MASK_ERROR;
146         goto rsa_error;
147     }
148 
149     /*
150      * Step 3:  For counter from 0 to (mask_len + hash_len - 1) / hash_len ,
151      * do the following:
152      * string T:   T = T || Hash (pseed || counter)
153      */
154     p_tmp = pt;
155     hvb_memcpy(pc, seed, seed_len);
156 
157     hvb_memset(pc + seed_len, 0, sizeof(uint32_t));
158     /* step 3.1: count of Hash blocks needed for mask calculation */
159     cnt_maxsize = (uint32_t)((mask_len + hash_len - 1) / hash_len);
160 
161     for (cnt = 0; cnt < cnt_maxsize; cnt++) {
162         /* step 3.2: pt_tmp = pseed ||Counter */
163         pc[seed_len + sizeof(uint32_t) - sizeof(uint8_t)] = cnt;
164 
165         /* step 3.3: calc T, T = T || Hash (pt_tmp) */
166         if (hash_sha256_single(pc, seed_len + sizeof(uint32_t), p_tmp, hash_len) != HASH_OK) {
167         ret = CALC_MASK_ERROR;
168         goto rsa_error;
169         }
170         p_tmp += hash_len;
171     }
172     /* Step 4:  Output the leading L octets of T as the octet string mask. */
173     hvb_memcpy(mask, pt, mask_len);
174 
175 rsa_error:
176     hvb_free(pt);
177     hvb_free(pc);
178     return ret;
179 }
180 
emsa_pss_verify_check_db(uint8_t * db,uint32_t db_len,uint32_t emlen,uint32_t digestlen,uint32_t saltlen)181 static int emsa_pss_verify_check_db(uint8_t *db, uint32_t db_len,
182                                     uint32_t emlen, uint32_t digestlen,
183                                     uint32_t saltlen)
184 {
185     int i;
186 
187     for (i = 0; i < emlen - digestlen - saltlen - PSS_EM_PADDING_LEN; i++) {
188         if (db[i] != PADDING_UNIT_ZERO) {
189             return CHECK_DB_ERROR;
190         }
191     }
192 
193     if (db[db_len - saltlen - PSS_DB_PADDING_LEN] != PADDING_UNIT_ONE) {
194         return CMP_DB_FAIL;
195     }
196 
197     return VERIFY_OK;
198 }
199 
emsa_pss_verify(uint32_t saltlen,const uint8_t * pdigest,uint32_t digestlen,uint32_t emlen,uint32_t embits,uint8_t * pem)200 static int emsa_pss_verify(uint32_t saltlen, const uint8_t *pdigest,
201                            uint32_t digestlen, uint32_t emlen,
202                            uint32_t embits, uint8_t *pem)
203 {
204     int ret;
205     uint32_t i;
206     uint32_t masklen;
207     uint32_t m_tmp_len;
208     uint32_t db_len = 0;
209     uint8_t *hash = NULL;
210     uint8_t *m_tmp = NULL;
211     uint8_t *maskedb = NULL;
212     uint8_t *salt = NULL;
213     uint8_t *db = NULL;
214 
215     masklen = byte2bit(emlen) - embits;
216 
217     /*
218      * Step 1: Skip digest calculate
219      * Step 2: Check sizes, emLen < hLen + sLen + 2
220      */
221     if (emlen < digestlen + PSS_EM_PADDING_LEN || saltlen > (emlen - digestlen - PSS_EM_PADDING_LEN)) {
222         return CALC_EMLEN_ERROR;
223     }
224     /* Step 3: if rightmost of EM is oxbc */
225     if (pem[emlen - PSS_DB_PADDING_LEN] != PSS_END_PADDING_UNIT) {
226         return CALC_0XBC_ERROR;
227     }
228 
229     /* Step 4: set maskedDB and H */
230     maskedb = pem;
231     db_len = emlen - digestlen - PSS_DB_PADDING_LEN;
232     hash = &pem[db_len];
233 
234     /* Step 5: Check that the leftmost bits in the leftmost octet of EM have the value 0 */
235     if ((maskedb[0] & (~(PSS_LEFTMOST_BIT_MASK >> masklen))) != 0) {
236         return CALC_EM_ERROR;
237     }
238 
239     /* Step 6: calc dbMask, MGF(H) */
240     db = (uint8_t *)hvb_malloc(db_len); /* db is dbmask */
241     if (!db) {
242         return CALC_DB_ERROR;
243     }
244     ret = rsa_gen_mask_mgf_v1(hash, digestlen, db, db_len);
245     if (ret != VERIFY_OK) {
246         goto rsa_error;
247     }
248     /* Step 7: calc db, maskedDB ^ db_mask */
249     for (i = 0; i < db_len; i++) {
250         db[i] = maskedb[i] ^ db[i];
251     }
252 
253     /* Step 8: Set the leftmost 8*emLen-emBits bits in DB to zero */
254     db[0] &= PSS_LEFTMOST_BIT_MASK >> masklen;
255 
256     /* Step 9: check db padding data */
257     ret = emsa_pss_verify_check_db(db, db_len, emlen, digestlen, saltlen);
258     if (ret != VERIFY_OK) {
259         goto rsa_error;
260     }
261     /* Step 10: set salt be the last slen of DB */
262     if (saltlen != 0) {
263         salt = &db[db_len - saltlen];
264     }
265 
266     /* Step 11: calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
267     ret = emsa_pss_calc_m(pdigest, digestlen, salt, saltlen, &m_tmp);
268     if (ret != VERIFY_OK) {
269         goto rsa_error;
270     }
271     /* Step 12: hash_tmp = H' = Hash(M') */
272     m_tmp_len = PSS_MTMP_PADDING_LEN + digestlen + saltlen;
273     ret = emsa_pss_hash_cmp(m_tmp, m_tmp_len, hash, digestlen);
274 
275 rsa_error:
276     hvb_free(db);
277     hvb_free(m_tmp);
278     return ret;
279 }
280 
/*
 * Copy len bytes from src to dst in reverse order: dst[0] receives the last
 * byte of src and dst[len - 1] receives the first. Nothing is read or written
 * when len is 0.
 */
static inline void invert_copy(uint8_t *dst, uint8_t *src, uint32_t len)
{
    uint32_t out = 0;
    uint32_t in = len;

    while (in > 0) {
        in--;
        dst[out] = src[in];
        out++;
    }
}
287 
hvb_rsa_verify_pss_param_check(const struct hvb_rsa_pubkey * pkey,const uint8_t * pdigest,uint32_t digestlen,uint8_t * psign,uint32_t signlen)288 static int hvb_rsa_verify_pss_param_check(const struct hvb_rsa_pubkey *pkey, const uint8_t *pdigest,
289                                           uint32_t digestlen, uint8_t *psign, uint32_t signlen)
290 {
291     uint32_t klen;
292 
293     if (!pkey || !pdigest || !psign) {
294         return PARAM_EMPTY_ERROR;
295     }
296     if (!pkey->pn || !pkey->p_rr || pkey->n_n0_i == 0) {
297         return PUBKEY_EMPTY_ERROR;
298     }
299     klen = bit2byte(pkey->width);
300     if (digestlen != SHA256_DIGEST_LEN) {
301         return DIGEST_LEN_ERROR;
302     }
303     if (pkey->nlen != klen || pkey->rlen > pkey->nlen) {
304         return PUBKEY_LEN_ERROR;
305     }
306     if (signlen > klen) {
307         return SIGN_LEN_ERROR;
308     }
309 
310     return VERIFY_OK;
311 }
312 
hvb_rsa_verify_pss_param_convert(const struct hvb_rsa_pubkey * pkey,uint8_t * psign,uint32_t signlen,struct long_int_num * p_n,struct long_int_num * p_rr,struct long_int_num * p_m)313 static int hvb_rsa_verify_pss_param_convert(const struct hvb_rsa_pubkey *pkey, uint8_t *psign,
314                                             uint32_t signlen, struct long_int_num *p_n,
315                                             struct long_int_num *p_rr, struct long_int_num *p_m)
316 {
317     invert_copy((uint8_t *)p_n->p_uint, pkey->pn, pkey->nlen);
318     p_n->valid_word_len = byte2dword(pkey->nlen);
319     lin_update_valid_len(p_n);
320     if (!p_n) {
321         return PUBKEY_EMPTY_ERROR;
322     }
323 
324     invert_copy((uint8_t *)p_m->p_uint, psign, signlen);
325     p_m->valid_word_len = byte2dword(pkey->nlen);
326     lin_update_valid_len(p_m);
327     if (!p_m) {
328         return SIGN_EMPTY_ERROR;
329     }
330 
331     invert_copy((uint8_t *)p_rr->p_uint, pkey->p_rr, pkey->rlen);
332     p_rr->valid_word_len = byte2dword(pkey->nlen);
333     lin_update_valid_len(p_rr);
334     if (!p_rr) {
335         return PUBKEY_EMPTY_ERROR;
336     }
337 
338     return VERIFY_OK;
339 }
340 
hvb_rsa_verify_pss(const struct hvb_rsa_pubkey * pkey,const uint8_t * pdigest,uint32_t digestlen,uint8_t * psign,uint32_t signlen,uint32_t saltlen)341 int hvb_rsa_verify_pss(const struct hvb_rsa_pubkey
342                        *pkey, const uint8_t *pdigest,
343                        uint32_t digestlen, uint8_t *psign,
344                        uint32_t signlen, uint32_t saltlen)
345 {
346     int ret;
347     uint32_t klen;
348     uint32_t emlen;
349     uint32_t embits;
350     unsigned long n_n0_i;
351     struct long_int_num *p_n = NULL;
352     struct long_int_num *p_m = NULL;
353     struct long_int_num *p_rr = NULL;
354     struct long_int_num *em = NULL;
355     uint8_t *em_data = NULL;
356 
357     ret = hvb_rsa_verify_pss_param_check(pkey, pdigest, digestlen, psign, signlen);
358     if (ret != VERIFY_OK) {
359         return ret;
360     }
361 
362     n_n0_i = (unsigned long)pkey->n_n0_i;
363     klen = bit2byte(pkey->width);
364     p_n = lin_create(byte2dword(pkey->nlen));
365     if (!p_n) {
366         return MEMORY_ERROR;
367     }
368     p_m = lin_create(byte2dword(pkey->nlen));
369     if (!p_m) {
370         ret = MEMORY_ERROR;
371         goto rsa_error;
372     }
373     p_rr = lin_create(byte2dword(pkey->nlen));
374     if (!p_rr) {
375         ret = MEMORY_ERROR;
376         goto rsa_error;
377     }
378     ret = hvb_rsa_verify_pss_param_convert(pkey, psign, signlen, p_n, p_rr, p_m);
379     if (ret != VERIFY_OK) {
380         goto rsa_error;
381     }
382     /* Step 1: RSA prim decrypt */
383     em = montgomery_mod_exp(p_m, p_n, n_n0_i, p_rr, pkey->e);
384     if (!em) {
385         ret = MOD_EXP_CALC_FAIL;
386         goto rsa_error;
387     }
388     em->valid_word_len = byte2dword(pkey->nlen);
389     lin_update_valid_len(em);
390     em_data = hvb_malloc(dword2byte(em->valid_word_len));
391     if (!em_data) {
392         ret = MOD_EXP_CALC_FAIL;
393         goto rsa_error;
394     }
395     invert_copy(em_data, (uint8_t *)em->p_uint, dword2byte(em->valid_word_len));
396     /* Step 2: emsa pss verify */
397     ret = rsa_pss_get_emlen(klen, p_n, &emlen, &embits);
398     if (ret != VERIFY_OK) {
399         goto rsa_error;
400     }
401     if (klen - emlen == 1 && em_data[0] != 0) {
402         ret = MOD_EXP_CALC_FAIL;
403         goto rsa_error;
404     }
405     ret = emsa_pss_verify(saltlen, pdigest, digestlen, emlen, embits, em_data + klen - emlen);
406 
407 rsa_error:
408     lin_free(em);
409     lin_free(p_n);
410     lin_free(p_m);
411     lin_free(p_rr);
412     if (em_data) {
413         hvb_free(em_data);
414     }
415 
416     return ret;
417 }
418