1 /*
2 * Copyright (c) 2022-2023 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15 #include <stdio.h>
16 #include <stdlib.h>
17 #include "hvb_hash_sha256.h"
18 #include "hvb_crypto.h"
19 #include "hvb_rsa.h"
20 #include "hvb_util.h"
21 #include "hvb_sysdeps.h"
22 #include "hvb_rsa_verify.h"
23
24
/* SHA-256 output length in bytes; the only digest length this file accepts. */
#define SHA256_DIGEST_LEN 32
/* The "+ 2" in the EMSA-PSS size requirement emLen >= hLen + sLen + 2. */
#define PSS_EM_PADDING_LEN 2
/* Number of leading zero octets in M' = 0x00 * 8 || mHash || salt. */
#define PSS_MTMP_PADDING_LEN 8
/* One octet: used for the 0xbc trailer of EM and for the 0x01 separator in DB. */
#define PSS_DB_PADDING_LEN 1
/* Required value of the rightmost octet of EM. */
#define PSS_END_PADDING_UNIT 0xBC
/* Full-octet mask; shifted right by 8*emLen - emBits to select the usable bits of the first octet. */
#define PSS_LEFTMOST_BIT_MASK 0xFFU

/* DB padding octets: PS is all PADDING_UNIT_ZERO and is followed by one PADDING_UNIT_ONE. */
#define PADDING_UNIT_ZERO 0x00
#define PADDING_UNIT_ONE 0x01
/* Largest supported modulus size in bits; bounds the MGF1 mask length. */
#define RSA_WIDTH_MAX 8192

/* Machine-word helpers; lin_create() sizes are expressed in these words (see byte2dword). */
#define WORD_BYTE_SIZE sizeof(unsigned long)
#define WORD_BIT_SIZE (WORD_BYTE_SIZE * 8)
/*
 * All bits of an unsigned long set. Spelled ~0UL because the previous
 * ((1UL << WORD_BIT_SIZE) - 1) shifts by the full width of the type,
 * which is undefined behavior.
 */
#define WORD_BIT_MASK (~0UL)
/* Bit/byte conversions and power-of-two alignment helpers. */
#define bit2byte(bits) ((bits) >> 3)
#define byte2bit(byte) ((byte) << 3)
#define bit_val(x) (1U << (x))
#define bit_mask(x) (bit_val(x) - 1U)
#define bit_align(n, bit) (((n) + bit_mask(bit)) & (~(bit_mask(bit))))
/* Bits to bytes, rounding up: ceil(bits / 8). */
#define bit2byte_align(bits) bit2byte(bit_align(bits, 3))
/* Bytes to machine words (rounding up), and words back to bytes. */
#define byte2dword(bytes) (((bytes) + (WORD_BYTE_SIZE) - 1) / WORD_BYTE_SIZE)
#define dword2byte(words) ((words) * WORD_BYTE_SIZE)
47
48 /* calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
emsa_pss_calc_m(const uint8_t * pdigest,uint32_t digestlen,uint8_t * salt,uint32_t saltlen,uint8_t ** m)49 static int emsa_pss_calc_m(const uint8_t *pdigest, uint32_t digestlen,
50 uint8_t *salt, uint32_t saltlen,
51 uint8_t **m)
52 {
53 uint8_t *m_tmp = NULL;
54 uint32_t m_tmp_len;
55 int ret = VERIFY_OK;
56
57 m_tmp_len = digestlen + saltlen + PSS_MTMP_PADDING_LEN;
58 m_tmp = (uint8_t *)hvb_malloc(m_tmp_len);
59 if (!m_tmp) {
60 return PARAM_EMPTY_ERROR;
61 }
62
63 if (hvb_memset_s(m_tmp, m_tmp_len, 0, PSS_MTMP_PADDING_LEN) != 0) {
64 ret = MEMORY_ERROR;
65 goto error;
66 }
67
68 if (hvb_memcpy_s(&m_tmp[PSS_MTMP_PADDING_LEN], m_tmp_len - PSS_MTMP_PADDING_LEN, pdigest, digestlen) != 0) {
69 ret = MEMORY_ERROR;
70 goto error;
71 }
72
73 if (saltlen != 0 && salt) {
74 if (hvb_memcpy_s(&m_tmp[PSS_MTMP_PADDING_LEN + digestlen], saltlen, salt, saltlen) != 0) {
75 ret = MEMORY_ERROR;
76 goto error;
77 }
78 }
79
80 *m = m_tmp;
81 return ret;
82 error:
83 hvb_free(m_tmp);
84 return ret;
85 }
86
87 /* rsa verify last step compare hash value */
emsa_pss_hash_cmp(uint8_t * m_tmp,uint32_t m_tmp_len,uint8_t * hash,uint32_t digestlen)88 static int emsa_pss_hash_cmp(uint8_t *m_tmp, uint32_t m_tmp_len,
89 uint8_t *hash, uint32_t digestlen)
90 {
91 int ret;
92 uint8_t *hash_tmp = NULL;
93
94 hash_tmp = (uint8_t *)hvb_malloc(digestlen);
95 if (!hash_tmp) {
96 return HASH_CMP_FAIL;
97 }
98 if (hash_sha256_single(m_tmp, m_tmp_len, hash_tmp, digestlen) != HASH_OK) {
99 ret = HASH_CMP_FAIL;
100 goto rsa_error;
101 }
102 /* compare twice */
103 ret = VERIFY_OK;
104 ret += hvb_memcmp(hash, hash_tmp, digestlen);
105 ret += hvb_memcmp(hash, hash_tmp, digestlen);
106 if (ret != VERIFY_OK)
107 ret = HASH_CMP_FAIL;
108 rsa_error:
109 hvb_free(hash_tmp);
110 return ret;
111 }
112
rsa_pss_get_emlen(uint32_t klen,struct long_int_num * pn,uint32_t * emlen,uint32_t * embits)113 static int rsa_pss_get_emlen(uint32_t klen, struct long_int_num *pn,
114 uint32_t *emlen, uint32_t *embits)
115 {
116 *embits = lin_get_bitlen(pn);
117 if (*embits == 0) {
118 return CALC_EMLEN_ERROR;
119 }
120 (*embits)--;
121
122 *emlen = bit2byte_align(*embits);
123 if (*emlen == 0) {
124 return CALC_EMLEN_ERROR;
125 }
126
127 if (*emlen > klen) {
128 return CALC_EMLEN_ERROR;
129 }
130
131 return VERIFY_OK;
132 }
133
/* MGF1 mask generation function (RFC 8017, Appendix B.2.1) */
rsa_gen_mask_mgf_v1(uint8_t * seed,uint32_t seed_len,uint8_t * mask,uint32_t mask_len)135 static int rsa_gen_mask_mgf_v1(uint8_t *seed, uint32_t seed_len,
136 uint8_t *mask, uint32_t mask_len)
137 {
138 int ret = VERIFY_OK;
139 uint32_t cnt = 0;
140 uint32_t cnt_maxsize = 0;
141 uint8_t *p_tmp = NULL;
142 uint8_t *pt = NULL;
143 uint8_t *pc = NULL;
144 const uint32_t hash_len = SHA256_DIGEST_LEN;
145
146 /* Step 1: mask length is smaller than the maximum key length */
147 if (mask_len > bit2byte(RSA_WIDTH_MAX)) {
148 return CALC_MASK_ERROR;
149 }
150
151 /* Step 2: Let pt and pt_tmp be the empty octet string. */
152 pt = (uint8_t *)hvb_malloc(mask_len + hash_len);
153 if (!pt) {
154 return CALC_MASK_ERROR;
155 }
156
157 pc = (uint8_t *)hvb_malloc(seed_len + sizeof(uint32_t));
158 if (!pc) {
159 ret = CALC_MASK_ERROR;
160 goto rsa_error;
161 }
162
163 /*
164 * Step 3: For counter from 0 to (mask_len + hash_len - 1) / hash_len ,
165 * do the following:
166 * string T: T = T || Hash (pseed || counter)
167 */
168 p_tmp = pt;
169 if (hvb_memcpy_s(pc, seed_len, seed, seed_len) != 0) {
170 ret = MEMORY_ERROR;
171 goto rsa_error;
172 }
173
174 if (hvb_memset_s(pc + seed_len, sizeof(uint32_t), 0, sizeof(uint32_t)) != 0) {
175 ret = MEMORY_ERROR;
176 goto rsa_error;
177 }
178 /* step 3.1: count of Hash blocks needed for mask calculation */
179 cnt_maxsize = (uint32_t)((mask_len + hash_len - 1) / hash_len);
180
181 for (cnt = 0; cnt < cnt_maxsize; cnt++) {
182 /* step 3.2: pt_tmp = pseed ||Counter */
183 pc[seed_len + sizeof(uint32_t) - sizeof(uint8_t)] = cnt;
184
185 /* step 3.3: calc T, T = T || Hash (pt_tmp) */
186 if (hash_sha256_single(pc, seed_len + sizeof(uint32_t), p_tmp, hash_len) != HASH_OK) {
187 ret = CALC_MASK_ERROR;
188 goto rsa_error;
189 }
190 p_tmp += hash_len;
191 }
192 /* Step 4: Output the leading L octets of T as the octet string mask. */
193 if (hvb_memcpy_s(mask, mask_len, pt, mask_len) != 0) {
194 ret = MEMORY_ERROR;
195 goto rsa_error;
196 }
197
198 rsa_error:
199 if (pt != NULL)
200 hvb_free(pt);
201 if (pc != NULL)
202 hvb_free(pc);
203 return ret;
204 }
205
emsa_pss_verify_check_db(uint8_t * db,uint32_t db_len,uint32_t emlen,uint32_t digestlen,uint32_t saltlen)206 static int emsa_pss_verify_check_db(uint8_t *db, uint32_t db_len,
207 uint32_t emlen, uint32_t digestlen,
208 uint32_t saltlen)
209 {
210 uint32_t i;
211
212 for (i = 0; i < emlen - digestlen - saltlen - PSS_EM_PADDING_LEN; i++) {
213 if (db[i] != PADDING_UNIT_ZERO) {
214 return CHECK_DB_ERROR;
215 }
216 }
217
218 if (db[db_len - saltlen - PSS_DB_PADDING_LEN] != PADDING_UNIT_ONE) {
219 return CMP_DB_FAIL;
220 }
221
222 return VERIFY_OK;
223 }
224
emsa_pss_verify(uint32_t saltlen,const uint8_t * pdigest,uint32_t digestlen,uint32_t emlen,uint32_t embits,uint8_t * pem)225 static int emsa_pss_verify(uint32_t saltlen, const uint8_t *pdigest,
226 uint32_t digestlen, uint32_t emlen,
227 uint32_t embits, uint8_t *pem)
228 {
229 int ret;
230 uint32_t i;
231 uint32_t masklen;
232 uint32_t m_tmp_len;
233 uint32_t db_len = 0;
234 uint8_t *hash = NULL;
235 uint8_t *m_tmp = NULL;
236 uint8_t *maskedb = NULL;
237 uint8_t *salt = NULL;
238 uint8_t *db = NULL;
239
240 masklen = byte2bit(emlen) - embits;
241
242 /*
243 * Step 1: Skip digest calculate
244 * Step 2: Check sizes, emLen < hLen + sLen + 2
245 */
246 if (emlen < digestlen + PSS_EM_PADDING_LEN || saltlen > (emlen - digestlen - PSS_EM_PADDING_LEN)) {
247 return CALC_EMLEN_ERROR;
248 }
249 /* Step 3: if rightmost of EM is oxbc */
250 if (pem[emlen - PSS_DB_PADDING_LEN] != PSS_END_PADDING_UNIT) {
251 return CALC_0XBC_ERROR;
252 }
253
254 /* Step 4: set maskedDB and H */
255 maskedb = pem;
256 db_len = emlen - digestlen - PSS_DB_PADDING_LEN;
257 hash = &pem[db_len];
258
259 /* Step 5: Check that the leftmost bits in the leftmost octet of EM have the value 0 */
260 if ((maskedb[0] & (~(PSS_LEFTMOST_BIT_MASK >> masklen))) != 0) {
261 return CALC_EM_ERROR;
262 }
263
264 /* Step 6: calc dbMask, MGF(H) */
265 db = (uint8_t *)hvb_malloc(db_len); /* db is dbmask */
266 if (!db) {
267 return CALC_DB_ERROR;
268 }
269 ret = rsa_gen_mask_mgf_v1(hash, digestlen, db, db_len);
270 if (ret != VERIFY_OK) {
271 goto rsa_error;
272 }
273 /* Step 7: calc db, maskedDB ^ db_mask */
274 for (i = 0; i < db_len; i++) {
275 db[i] = maskedb[i] ^ db[i];
276 }
277
278 /* Step 8: Set the leftmost 8*emLen-emBits bits in DB to zero */
279 db[0] &= PSS_LEFTMOST_BIT_MASK >> masklen;
280
281 /* Step 9: check db padding data */
282 ret = emsa_pss_verify_check_db(db, db_len, emlen, digestlen, saltlen);
283 if (ret != VERIFY_OK) {
284 goto rsa_error;
285 }
286 /* Step 10: set salt be the last slen of DB */
287 if (saltlen != 0) {
288 salt = &db[db_len - saltlen];
289 }
290
291 /* Step 11: calc M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt */
292 ret = emsa_pss_calc_m(pdigest, digestlen, salt, saltlen, &m_tmp);
293 if (ret != VERIFY_OK) {
294 goto rsa_error;
295 }
296 /* Step 12: hash_tmp = H' = Hash(M') */
297 m_tmp_len = PSS_MTMP_PADDING_LEN + digestlen + saltlen;
298 ret = emsa_pss_hash_cmp(m_tmp, m_tmp_len, hash, digestlen);
299
300 rsa_error:
301 if (db != NULL)
302 hvb_free(db);
303 if (m_tmp != NULL)
304 hvb_free(m_tmp);
305 return ret;
306 }
307
/*
 * Copy len bytes from src to dst in reverse order: dst[i] = src[len - 1 - i].
 * Callers use it to byte-reverse key/signature strings when loading them into,
 * or reading them out of, a long_int_num buffer. dst and src are assumed not
 * to overlap.
 */
static inline void invert_copy(uint8_t *dst, uint8_t *src, uint32_t len)
{
    uint32_t remaining = len;
    uint8_t *out = dst;

    while (remaining > 0) {
        remaining--;
        *out++ = src[remaining];
    }
}
314
hvb_rsa_verify_pss_param_check(const struct hvb_rsa_pubkey * pkey,const uint8_t * pdigest,uint32_t digestlen,uint8_t * psign,uint32_t signlen)315 static int hvb_rsa_verify_pss_param_check(const struct hvb_rsa_pubkey *pkey, const uint8_t *pdigest,
316 uint32_t digestlen, uint8_t *psign, uint32_t signlen)
317 {
318 uint32_t klen;
319 uint32_t n_validlen;
320
321 if (!pkey || !pdigest || !psign) {
322 return PARAM_EMPTY_ERROR;
323 }
324 if (!pkey->pn || !pkey->p_rr || pkey->n_n0_i == 0) {
325 return PUBKEY_EMPTY_ERROR;
326 }
327 klen = bit2byte(pkey->width);
328 n_validlen = bn_get_valid_len(pkey->pn, pkey->nlen);
329 if (digestlen != SHA256_DIGEST_LEN) {
330 return DIGEST_LEN_ERROR;
331 }
332 if (n_validlen != klen || pkey->rlen > pkey->nlen) {
333 return PUBKEY_LEN_ERROR;
334 }
335 if (signlen > klen) {
336 return SIGN_LEN_ERROR;
337 }
338
339 return VERIFY_OK;
340 }
341
hvb_rsa_verify_pss_param_convert(const struct hvb_rsa_pubkey * pkey,uint8_t * psign,uint32_t signlen,struct long_int_num * p_n,struct long_int_num * p_rr,struct long_int_num * p_m)342 static int hvb_rsa_verify_pss_param_convert(const struct hvb_rsa_pubkey *pkey, uint8_t *psign,
343 uint32_t signlen, struct long_int_num *p_n,
344 struct long_int_num *p_rr, struct long_int_num *p_m)
345 {
346 if (!p_n)
347 return PUBKEY_EMPTY_ERROR;
348 invert_copy((uint8_t *)p_n->p_uint, pkey->pn, pkey->nlen);
349 p_n->valid_word_len = byte2dword(pkey->nlen);
350 lin_update_valid_len(p_n);
351
352 if (!p_m)
353 return SIGN_EMPTY_ERROR;
354 invert_copy((uint8_t *)p_m->p_uint, psign, signlen);
355 p_m->valid_word_len = byte2dword(pkey->nlen);
356 lin_update_valid_len(p_m);
357
358 if (!p_rr)
359 return PUBKEY_EMPTY_ERROR;
360 invert_copy((uint8_t *)p_rr->p_uint, pkey->p_rr, pkey->rlen);
361 p_rr->valid_word_len = byte2dword(pkey->nlen);
362 lin_update_valid_len(p_rr);
363
364 return VERIFY_OK;
365 }
366
hvb_rsa_verify_pss(const struct hvb_rsa_pubkey * pkey,const uint8_t * pdigest,uint32_t digestlen,uint8_t * psign,uint32_t signlen,uint32_t saltlen)367 int hvb_rsa_verify_pss(const struct hvb_rsa_pubkey
368 *pkey, const uint8_t *pdigest,
369 uint32_t digestlen, uint8_t *psign,
370 uint32_t signlen, uint32_t saltlen)
371 {
372 int ret;
373 uint32_t klen;
374 uint32_t emlen;
375 uint32_t embits;
376 unsigned long n_n0_i;
377 struct long_int_num *p_n = NULL;
378 struct long_int_num *p_m = NULL;
379 struct long_int_num *p_rr = NULL;
380 struct long_int_num *em = NULL;
381 uint8_t *em_data = NULL;
382
383 ret = hvb_rsa_verify_pss_param_check(pkey, pdigest, digestlen, psign, signlen);
384 if (ret != VERIFY_OK) {
385 return ret;
386 }
387
388 n_n0_i = (unsigned long)pkey->n_n0_i;
389 klen = bit2byte(pkey->width);
390 p_n = lin_create(byte2dword(pkey->nlen));
391 if (!p_n) {
392 return MEMORY_ERROR;
393 }
394 p_m = lin_create(byte2dword(pkey->nlen));
395 if (!p_m) {
396 ret = MEMORY_ERROR;
397 goto rsa_error;
398 }
399 p_rr = lin_create(byte2dword(pkey->nlen));
400 if (!p_rr) {
401 ret = MEMORY_ERROR;
402 goto rsa_error;
403 }
404 ret = hvb_rsa_verify_pss_param_convert(pkey, psign, signlen, p_n, p_rr, p_m);
405 if (ret != VERIFY_OK) {
406 goto rsa_error;
407 }
408 /* Step 1: RSA prim decrypt */
409 em = montgomery_mod_exp(p_m, p_n, n_n0_i, p_rr, pkey->e);
410 if (!em) {
411 ret = MOD_EXP_CALC_FAIL;
412 goto rsa_error;
413 }
414 lin_update_valid_len(em);
415 em_data = hvb_malloc(klen);
416 if (!em_data) {
417 ret = MOD_EXP_CALC_FAIL;
418 goto rsa_error;
419 }
420
421 if (hvb_memset_s(em_data, klen, 0, klen) != 0) {
422 ret = MEMORY_ERROR;
423 goto rsa_error;
424 }
425 invert_copy(em_data, (uint8_t *)em->p_uint, klen);
426 /* Step 2: emsa pss verify */
427 ret = rsa_pss_get_emlen(klen, p_n, &emlen, &embits);
428 if (ret != VERIFY_OK) {
429 goto rsa_error;
430 }
431 if (klen - emlen == 1 && em_data[0] != 0) {
432 ret = MOD_EXP_CALC_FAIL;
433 goto rsa_error;
434 }
435 ret = emsa_pss_verify(saltlen, pdigest, digestlen, emlen, embits, em_data + klen - emlen);
436
437 rsa_error:
438 lin_free(em);
439 lin_free(p_n);
440 lin_free(p_m);
441 lin_free(p_rr);
442 if (em_data) {
443 hvb_free(em_data);
444 }
445
446 return ret;
447 }
448