• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 Huawei Technologies Co., Ltd.
3  * Licensed under the Mulan PSL v2.
4  * You can use this software according to the terms and conditions of the Mulan PSL v2.
5  * You may obtain a copy of Mulan PSL v2 at:
6  *     http://license.coscl.org.cn/MulanPSL2
7  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
8  * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
9  * PURPOSE.
10  * See the Mulan PSL v2 for more details.
11  */
12 #include "tee_elf_verify_openssl.h"
13 #include "tee_log.h"
14 #include "securec.h"
15 #include "ta_load_key.h"
16 #include "tee_v3_elf_verify.h"
17 #include "wb_aes_decrypt.h"
18 #include <openssl/obj_mac.h>
19 #include <openssl/evp.h>
20 #include <evp/evp_local.h>
21 #include <openssl/sha.h>
22 #include <openssl/bn.h>
23 #include <openssl/rsa.h>
24 #include <openssl/ecdh.h>
25 #include <openssl/ec.h>
26 #include <openssl/hmac.h>
27 #include <openssl/rsa.h>
28 #include "tee_load_key_ops.h"
29 #include <tee_crypto_signature_verify.h>
30 #include "wb_tool_root_key.h"
31 
32 static const char *g_ecies_hmac_salt = "salt for ecies kdf";
33 
tee_secure_img_decrypt_cipher_layer(const uint8_t * cipher_layer,uint32_t cipher_size,uint8_t * plaintext_layer,uint32_t * plaintext_size)34 TEE_Result tee_secure_img_decrypt_cipher_layer(const uint8_t *cipher_layer, uint32_t cipher_size,
35     uint8_t *plaintext_layer, uint32_t *plaintext_size)
36 {
37     bool check = (cipher_layer == NULL || cipher_size == 0 || plaintext_layer == NULL ||
38         plaintext_size == NULL || *plaintext_size < cipher_size);
39     if (check)
40         return TEE_ERROR_BAD_PARAMETERS;
41 
42     enum ta_type type = V3_TYPE_2048;
43     if (judge_rsa_key_type(cipher_size, &type) != TEE_SUCCESS)
44         return TEE_ERROR_BAD_PARAMETERS;
45 
46     RSA *ta_load_priv_key = NULL;
47 
48     ta_load_priv_key = get_private_key(CIPHER_LAYER_VERSION, type);
49     if (ta_load_priv_key == NULL)
50         return TEE_ERROR_GENERIC;
51     /* key size 2048, RSA OAEP mode */
52     int32_t out_len = RSA_private_decrypt(cipher_size, (const uint8_t *)cipher_layer, (uint8_t *)plaintext_layer,
53                                           ta_load_priv_key, RSA_PKCS1_OAEP_PADDING);
54 
55     free_private_key(ta_load_priv_key);
56     ta_load_priv_key = NULL;
57 
58     if (out_len < 0) {
59         tloge("Failed to decrypt cipher layer of TA image: cipher len=%u\n", get_v3_cipher_layer_len());
60         return TEE_ERROR_GENERIC;
61     }
62 
63     *plaintext_size = (uint32_t)out_len;
64     return TEE_SUCCESS;
65 }
66 
/*
 * Release up to four BIGNUMs (used for n, e, d, p and for secret temporaries).
 * NULL arguments are accepted and ignored.
 * The values are wiped before the memory is returned because callers pass private key
 * material (d, p) through this helper.
 */
static void free_rsa_bn_n(BIGNUM *bn_n, BIGNUM *bn_e, BIGNUM *bn_d, BIGNUM *bn_p)
{
    BN_clear_free(bn_n);
    BN_clear_free(bn_e);
    BN_clear_free(bn_d);
    BN_clear_free(bn_p);
}
74 
/*
 * Release up to four BIGNUMs (used for q, dp, dq, qinv).
 * NULL arguments are accepted and ignored.
 * All four are private key components, so they are wiped before being freed.
 */
static void free_rsa_bn_q(BIGNUM *bn_q, BIGNUM *bn_dp, BIGNUM *bn_dq, BIGNUM *bn_qinv)
{
    BN_clear_free(bn_q);
    BN_clear_free(bn_dp);
    BN_clear_free(bn_dq);
    BN_clear_free(bn_qinv);
}
82 
/*
 * Derive the RSA exponents from the two primes.
 *
 * e is fixed to RSA_F4 and d is computed as e^-1 mod ((p - 1) * (q - 1) / gcd(p - 1, q - 1)).
 *
 * bn_p, bn_q: the primes (not modified).
 * ctx:        scratch BN_CTX supplied by the caller.
 * bn_e, bn_d: point to already allocated BIGNUMs that receive the results.
 *
 * Returns 1 on success and 0 on failure.
 */
static int32_t compute_rsa_big_num_ed(const BIGNUM *bn_p, const BIGNUM *bn_q, BN_CTX *ctx, BIGNUM **bn_e,
                                      BIGNUM **bn_d)
{
    int32_t result = 0;
    int32_t ret1;
    int32_t ret2;
    BIGNUM *tmp1 = NULL;
    BIGNUM *tmp2 = NULL;
    BIGNUM *tmp3 = NULL;
    BIGNUM *gcd = NULL;

    if (bn_e == NULL || bn_d == NULL || *bn_e == NULL || *bn_d == NULL) {
        tloge("Invalid output big num\n");
        return 0;
    }

    tmp1 = BN_dup(bn_p);
    tmp2 = BN_dup(bn_q);
    tmp3 = BN_new();
    gcd = BN_new();
    if (tmp1 == NULL || tmp2 == NULL || tmp3 == NULL || gcd == NULL) {
        tloge("Duplicate or new big num failed\n");
        goto out;
    }

    /* tmp1 = p - 1, tmp2 = q - 1 */
    ret1 = BN_sub_word(tmp1, 1);
    ret2 = BN_sub_word(tmp2, 1);
    if (ret1 != 1 || ret2 != 1) {
        tloge("Big num sub 1 failed, ret1=%d, ret2=%d\n", ret1, ret2);
        goto out;
    }

    /* gcd = gcd(p - 1, q - 1); tmp1 = (p - 1) / gcd, remainder goes to tmp3 */
    ret1 = BN_gcd(gcd, tmp1, tmp2, ctx);
    ret2 = BN_div(tmp1, tmp3, tmp1, gcd, ctx);
    if (ret1 != 1 || ret2 != 1) {
        tloge("Big num gcd div failed, ret1=%d, ret2=%d\n", ret1, ret2);
        goto out;
    }

    /* e = RSA_F4; tmp2 = (q - 1) * (p - 1) / gcd */
    ret1 = BN_set_word(*bn_e, RSA_F4);
    ret2 = BN_mul(tmp2, tmp2, tmp1, ctx);
    if (ret1 != 1 || ret2 != 1) {
        tloge("compute e and d failed, ret1=%d, ret2=%d\n", ret1, ret2);
        goto out;
    }

    /* BN_mod_inverse return value is not new allocated, can not be free */
    if (BN_mod_inverse(*bn_d, *bn_e, tmp2, ctx) == NULL) {
        tloge("Get big num d by mod inverse failed\n");
        goto out;
    }

    result = 1;
out:
    /* the temporaries are derived from the secret primes: wipe them before release */
    BN_clear_free(tmp1);
    BN_clear_free(tmp2);
    BN_clear_free(tmp3);
    BN_clear_free(gcd);
    return result;
}
130 
/*
 * Fill n, e and d from the primes p and q: n = p * q, e and d come from
 * compute_rsa_big_num_ed(). *bn_n, *bn_e and *bn_d must already be allocated.
 * Returns 1 on success and 0 on failure.
 */
static int32_t get_rsa_big_num_ned(const BIGNUM *bn_p, const BIGNUM *bn_q, BIGNUM **bn_n, BIGNUM **bn_e,
                                   BIGNUM **bn_d)
{
    BN_CTX *bn_ctx = BN_CTX_new();
    if (bn_ctx == NULL) {
        tloge("New bn ctx fail\n");
        return 0;
    }

    if (BN_mul(*bn_n, bn_p, bn_q, bn_ctx) != 1) {
        tloge("Mul big num fail\n");
        BN_CTX_free(bn_ctx);
        return 0;
    }

    /* the scratch context is not needed once e and d have been computed */
    int32_t ed_ok = compute_rsa_big_num_ed(bn_p, bn_q, bn_ctx, bn_e, bn_d);
    BN_CTX_free(bn_ctx);
    if (ed_ok == 1)
        return 1;

    tloge("Compute big num e and d failed\n");
    return 0;
}
156 
/*
 * Compute the modulus n = p * q into the already allocated *bn_n.
 * Returns 1 on success and 0 on failure.
 */
static int32_t get_rsa_big_num_n(BIGNUM *bn_p, BIGNUM *bn_q, BIGNUM **bn_n)
{
    BN_CTX *bn_ctx = BN_CTX_new();
    if (bn_ctx == NULL) {
        tloge("New bn ctx failed\n");
        return 0;
    }

    int32_t mul_ok = BN_mul(*bn_n, bn_p, bn_q, bn_ctx);
    if (mul_ok != 1)
        tloge("Mul big num failed\n");

    BN_CTX_free(bn_ctx);
    return (mul_ok == 1) ? 1 : 0;
}
176 
177 struct boringssl_priv_key_st {
178     BIGNUM *bn_n;
179     BIGNUM *bn_e;
180     BIGNUM *bn_d;
181     BIGNUM *bn_p;
182     BIGNUM *bn_q;
183     BIGNUM *bn_dp;
184     BIGNUM *bn_dq;
185     BIGNUM *bn_qinv;
186 };
187 
rsa_priv_key_transform(const struct rsa_priv_key * priv_key,struct boringssl_priv_key_st * key)188 static TEE_Result rsa_priv_key_transform(const struct rsa_priv_key *priv_key, struct boringssl_priv_key_st *key)
189 {
190     key->bn_n = BN_new();
191     key->bn_e = BN_new();
192     key->bn_d = BN_new();
193     bool is_abnormal = (key->bn_n == NULL) || (key->bn_e == NULL) || (key->bn_d == NULL);
194     if (is_abnormal) {
195         tloge("New big num n or e or d failed\n");
196         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, NULL);
197         return TEE_ERROR_OUT_OF_MEMORY;
198     }
199 
200     key->bn_p = BN_bin2bn(priv_key->p, get_effective_size(priv_key->p, priv_key->p_size), NULL);
201     key->bn_q = BN_bin2bn(priv_key->q, get_effective_size(priv_key->q, priv_key->q_size), NULL);
202     key->bn_dq = BN_bin2bn(priv_key->dq, get_effective_size(priv_key->dq, priv_key->dq_size), NULL);
203     key->bn_dp = BN_bin2bn(priv_key->dp, get_effective_size(priv_key->dp, priv_key->dp_size), NULL);
204     key->bn_qinv = BN_bin2bn(priv_key->qinv, get_effective_size(priv_key->qinv, priv_key->qinv_size), NULL);
205     is_abnormal =
206         (key->bn_p == NULL || key->bn_q == NULL || key->bn_dp == NULL || key->bn_dq == NULL || key->bn_qinv == NULL);
207     if (is_abnormal) {
208         tloge("change buffer to BIGNUM is error!");
209         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, key->bn_p);
210         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
211         return TEE_ERROR_GENERIC;
212     }
213 
214     int32_t result = get_rsa_big_num_ned(key->bn_p, key->bn_q, &key->bn_n, &key->bn_e, &key->bn_d);
215     if (result != 1) {
216         tloge("get rsa key error!");
217         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, key->bn_p);
218         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
219         return TEE_ERROR_GENERIC;
220     }
221 
222     return TEE_SUCCESS;
223 }
224 
proc_build_rsa_key(struct boringssl_priv_key_st * key)225 static RSA *proc_build_rsa_key(struct boringssl_priv_key_st *key)
226 {
227     RSA *rsa_key = RSA_new();
228     if (rsa_key == NULL) {
229         tloge("Malloc memory for rsa key failed\n");
230         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, key->bn_p);
231         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
232         return NULL;
233     }
234 
235     int32_t result = RSA_set0_key(rsa_key, key->bn_n, key->bn_e, key->bn_d);
236     if (result != 1) {
237         tloge("RSA_set0_key failed\n");
238         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, key->bn_p);
239         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
240         RSA_free(rsa_key);
241         return NULL;
242     }
243 
244     result = RSA_set0_factors(rsa_key, key->bn_p, key->bn_q);
245     if (result != 1) {
246         tloge("RSA_set0_factors failed\n");
247         free_rsa_bn_n(NULL, NULL, NULL, key->bn_p);
248         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
249         RSA_free(rsa_key);
250         return NULL;
251     }
252     result = RSA_set0_crt_params(rsa_key, key->bn_dp, key->bn_dq, key->bn_qinv);
253     if (result != 1) {
254         tloge("RSA_set0_crt_params\n");
255         free_rsa_bn_q(NULL, key->bn_dp, key->bn_dq, key->bn_qinv);
256         RSA_free(rsa_key);
257         return NULL;
258     }
259     return rsa_key;
260 }
261 
rsa_build_key(const struct rsa_priv_key * priv_key)262 static RSA *rsa_build_key(const struct rsa_priv_key *priv_key)
263 {
264     struct boringssl_priv_key_st key = {0};
265 
266     TEE_Result ret = rsa_priv_key_transform(priv_key, &key);
267     if (ret != TEE_SUCCESS) {
268         tloge("Failed to transform private key to boringssl format!");
269         return NULL;
270     }
271 
272     return proc_build_rsa_key(&key);
273 }
274 
rsa_priv_key_transform_with_ed(const struct rsa_priv_key * priv_key,struct boringssl_priv_key_st * key)275 static TEE_Result rsa_priv_key_transform_with_ed(const struct rsa_priv_key *priv_key,
276                                                  struct boringssl_priv_key_st *key)
277 {
278     key->bn_n = BN_new();
279     bool is_abnormal = (key->bn_n == NULL);
280     if (is_abnormal) {
281         tloge("New big num n or e or d failed\n");
282         free_rsa_bn_n(key->bn_n, NULL, NULL, NULL);
283         return TEE_ERROR_OUT_OF_MEMORY;
284     }
285     key->bn_p = BN_bin2bn(priv_key->p, get_effective_size(priv_key->p, priv_key->p_size), NULL);
286     key->bn_q = BN_bin2bn(priv_key->q, get_effective_size(priv_key->q, priv_key->q_size), NULL);
287     key->bn_dq = BN_bin2bn(priv_key->dq, get_effective_size(priv_key->dq, priv_key->dq_size), NULL);
288     key->bn_dp = BN_bin2bn(priv_key->dp, get_effective_size(priv_key->dp, priv_key->dp_size), NULL);
289     key->bn_qinv = BN_bin2bn(priv_key->qinv, get_effective_size(priv_key->qinv, priv_key->qinv_size), NULL);
290     key->bn_d = BN_bin2bn(priv_key->d, get_effective_size(priv_key->d, priv_key->d_size), NULL);
291     key->bn_e = BN_bin2bn(priv_key->e, get_effective_size(priv_key->e, priv_key->e_size), NULL);
292     is_abnormal  = (key->bn_p == NULL || key->bn_q == NULL || key->bn_dp == NULL || key->bn_dq == NULL ||
293                    key->bn_qinv == NULL || key->bn_d == NULL || key->bn_e == NULL);
294     if (is_abnormal) {
295         tloge("change buffer to BIGNUM is error!");
296         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, key->bn_p);
297         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
298         return TEE_ERROR_GENERIC;
299     }
300 
301     int32_t result = get_rsa_big_num_n(key->bn_p, key->bn_q, &key->bn_n);
302     if (result != 1) {
303         tloge("get rsa key error!");
304         free_rsa_bn_n(key->bn_n, key->bn_e, key->bn_d, key->bn_p);
305         free_rsa_bn_q(key->bn_q, key->bn_dp, key->bn_dq, key->bn_qinv);
306         return TEE_ERROR_GENERIC;
307     }
308 
309     return TEE_SUCCESS;
310 }
311 
rsa_build_key_with_ed(const struct rsa_priv_key * priv_key)312 static RSA *rsa_build_key_with_ed(const struct rsa_priv_key *priv_key)
313 {
314     struct boringssl_priv_key_st key = {0};
315 
316     TEE_Result ret = rsa_priv_key_transform_with_ed(priv_key, &key);
317     if (ret != TEE_SUCCESS) {
318         tloge("Failed to transform private key to boringssl format!");
319         return NULL;
320     }
321 
322     return proc_build_rsa_key(&key);
323 }
324 
325 struct ecies_kem_data_st {
326     BIGNUM *d;
327     EC_KEY *ec1;
328     EC_KEY *ec2;
329     EC_POINT *ecp;
330     EC_GROUP *group;
331     uint8_t secret[AES_KEY_LEN];
332 };
333 
ecies_kem_cleanup(struct ecies_kem_data_st * ctx)334 static void ecies_kem_cleanup(struct ecies_kem_data_st *ctx)
335 {
336     BN_clear_free(ctx->d);
337     EC_POINT_free(ctx->ecp);
338     EC_GROUP_free(ctx->group);
339     EC_KEY_free(ctx->ec1);
340     EC_KEY_free(ctx->ec2);
341 }
342 
ecies_kem_init(const struct ecc_derive_data_st * ecc_data,struct ecies_kem_data_st * ctx)343 static TEE_Result ecies_kem_init(const struct ecc_derive_data_st *ecc_data, struct ecies_kem_data_st *ctx)
344 {
345     ctx->ec1 = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1);
346     ctx->ec2 = EC_KEY_new();
347     ctx->group = EC_GROUP_new_by_curve_name(NID_X9_62_prime256v1);
348 
349     bool check = (ctx->ec1 == NULL || ctx->ec2 == NULL || ctx->group == NULL);
350     if (check) {
351         ecies_kem_cleanup(ctx);
352         tloge("new ec key failed\n");
353         return TEE_ERROR_GENERIC;
354     }
355 
356     ctx->d = BN_bin2bn(ecc_data->ec1_priv, ECIES_PRIV_LEN, ctx->d);
357     if (ctx->d == NULL) {
358         tloge("bin2bn failed\n");
359         ecies_kem_cleanup(ctx);
360         return TEE_ERROR_GENERIC;
361     }
362 
363     int32_t ret = EC_KEY_set_private_key(ctx->ec1, ctx->d);
364     if (ret == 0) {
365         tloge("set private key failed\n");
366         ecies_kem_cleanup(ctx);
367         return TEE_ERROR_GENERIC;
368     }
369 
370     ret = EC_KEY_set_group(ctx->ec2, ctx->group);
371     if (ret == 0) {
372         tloge("set ec group failed\n");
373         ecies_kem_cleanup(ctx);
374         return TEE_ERROR_GENERIC;
375     }
376 
377     ctx->ecp = EC_POINT_new(ctx->group);
378     if (ctx->ecp == NULL) {
379         tloge("new ec point failed\n");
380         ecies_kem_cleanup(ctx);
381         return TEE_ERROR_GENERIC;
382     }
383 
384     ret = EC_POINT_oct2point(ctx->group, ctx->ecp, ecc_data->ec2_pub, ecc_data->ec2_len, NULL);
385     if (ret == 0) {
386         tloge("ec oct2point failed\n");
387         ecies_kem_cleanup(ctx);
388         return TEE_ERROR_GENERIC;
389     }
390 
391     return TEE_SUCCESS;
392 }
393 
394 /* generate a 256 bytes AES key with ECDH + HMAC */
ecies_kem_decrypt(const struct ecc_derive_data_st * ecc_data,uint8_t * key,uint32_t key_len)395 int32_t ecies_kem_decrypt(const struct ecc_derive_data_st *ecc_data, uint8_t *key, uint32_t key_len)
396 {
397     struct ecies_kem_data_st ctx = {0};
398     uint8_t *hmac = NULL;
399 
400     bool check = (ecc_data == NULL || key == NULL || ecc_data->ec1_len != ECIES_PRIV_LEN
401         || ecc_data->ec2_len != ECIES_PUB_LEN);
402     if (check) {
403         tloge("key len invalid, %u/%u\n", ecc_data->ec1_len, ecc_data->ec2_len);
404         return -1;
405     }
406 
407     TEE_Result ret = ecies_kem_init(ecc_data, &ctx);
408     if (ret != TEE_SUCCESS) {
409         tloge("Failed to initialize ctx\n");
410         return -1;
411     }
412 
413     int32_t result = ECDH_compute_key(ctx.secret, sizeof(ctx.secret), ctx.ecp, ctx.ec1, NULL);
414     if (result <= 0) {
415         tloge("ecdh compute failed\n");
416         ecies_kem_cleanup(&ctx);
417         return -1;
418     }
419 
420     hmac = HMAC(EVP_sha256(), ctx.secret, sizeof(ctx.secret), (uint8_t *)g_ecies_hmac_salt,
421                 strlen(g_ecies_hmac_salt) + 1, key, &key_len);
422     ecies_kem_cleanup(&ctx);
423     if (hmac == NULL) {
424         tloge("hkdf failed\n");
425         return -1;
426     }
427 
428     return 0;
429 }
430 
aes_cbc_256_decrypt(const uint8_t * key,const uint8_t * iv,const uint8_t * in,uint32_t in_len,uint8_t * out)431 int32_t aes_cbc_256_decrypt(const uint8_t *key, const uint8_t *iv,
432     const uint8_t *in, uint32_t in_len, uint8_t *out)
433 {
434     EVP_CIPHER_CTX ctx = {0};
435     int32_t len, len2;
436 
437     if (key == NULL || iv == NULL || in == NULL || out == NULL)
438         goto clean;
439 
440     int32_t ret = EVP_DecryptInit(&ctx, EVP_aes_256_cbc(), key, iv);
441     if (ret == 0) {
442         tloge("decrypt init failed\n");
443         goto clean;
444     }
445 
446     ret = EVP_DecryptUpdate(&ctx, out, &len, in, in_len);
447     if (ret == 0) {
448         tloge("decrypt update failed\n");
449         goto clean;
450     }
451 
452     ret = EVP_DecryptFinal_ex(&ctx, out + len, &len2);
453     if (ret == 0) {
454         tloge("decrypt final failed\n");
455         goto clean;
456     }
457 
458     bool check = (len < 0 || len2 < 0);
459     if (check) {
460         tloge("error decrypt len,update:%d, final:%d\n", len, len2);
461         goto clean;
462     }
463 
464     if (len + len2 < len) {
465         tloge("len and len2's addition may overflow\n");
466         goto clean;
467     }
468     EVP_CIPHER_CTX_reset(&ctx);
469     return len + len2;
470 clean:
471     EVP_CIPHER_CTX_reset(&ctx);
472     return -1;
473 }
474 
get_private_key_ecies(int32_t img_version,enum ta_type type)475 static RSA *get_private_key_ecies(int32_t img_version, enum ta_type type)
476 {
477     uint8_t aes_key[AES_KEY_LEN];
478     const struct ecies_key_struct *ecies_key_data = NULL;
479     struct rsa_priv_key *priv_key = NULL;
480     RSA *ret_key = NULL;
481 
482     ecies_key_data = get_ecies_key_data(img_version, type);
483     if (ecies_key_data == NULL) {
484         tloge("Failed to get ecies key data\n");
485         return NULL;
486     }
487 
488     TEE_Result ret = get_rsa_priv_aes_key(ecies_key_data, aes_key, sizeof(aes_key));
489     if (ret != TEE_SUCCESS) {
490         tloge("Failed to get AES key to decrypt RSA private components\n");
491         return NULL;
492     }
493 
494     priv_key = TEE_Malloc(sizeof(struct rsa_priv_key), 0);
495     if (priv_key == NULL)
496         return NULL;
497 
498     ret = aes_decrypt_rsa_private(ecies_key_data, aes_key, sizeof(aes_key), priv_key);
499     (void)memset_s(aes_key, sizeof(aes_key), 0, sizeof(aes_key));
500     if (ret != TEE_SUCCESS) {
501         tloge("Failed to decrypt RSA private components\n");
502         TEE_Free(priv_key);
503         return NULL;
504     }
505 
506     ret_key = rsa_build_key_with_ed(priv_key);
507     (void)memset_s(priv_key, sizeof(*priv_key), 0, sizeof(*priv_key));
508     TEE_Free(priv_key);
509     return ret_key;
510 }
511 
get_wb_key_data(int32_t img_version,enum ta_type type)512 static struct wb_key_struct *get_wb_key_data(int32_t img_version, enum ta_type type)
513 {
514     TEE_Result ret;
515     struct key_data key_data = {
516         .pro_type = WB_KEY,
517         .ta_type = type,
518         .key = NULL,
519         .key_len = 0,
520     };
521 
522     ret = get_key_data(img_version, &key_data);
523     if (ret != TEE_SUCCESS) {
524         tloge("get wb key failed for version:%d\n", img_version);
525         return NULL;
526     }
527 
528     if (key_data.key_len != sizeof(struct wb_key_struct)) {
529         tloge("get wb key len error\n");
530         return NULL;
531     }
532 
533     return (struct wb_key_struct *)key_data.key;
534 }
535 
get_wb_tool_internal_key(int32_t img_version,struct wb_tool_inter_key * inter_key)536 static TEE_Result get_wb_tool_internal_key(int32_t img_version, struct wb_tool_inter_key *inter_key)
537 {
538     struct wb_tool_key tool_key = {0};
539 
540     tool_key.tool_ver = WB_TOOL_KEY_128;
541 
542     /* Only v3 use new white box table2 key. */
543     if (img_version == CIPHER_LAYER_VERSION)
544         tool_key.tool_ver = WB_TOOL_KEY_256;
545 
546     if (get_wb_tool_key(&tool_key) != TEE_SUCCESS)
547         return TEE_ERROR_GENERIC;
548 
549     inter_key->iv = tool_key.iv;
550     inter_key->table2 = tool_key.table2;
551     inter_key->round_num = tool_key.round_num;
552 
553     return TEE_SUCCESS;
554 }
555 
get_white_box_private_key(int32_t img_version,enum ta_type type)556 static RSA *get_white_box_private_key(int32_t img_version, enum ta_type type)
557 {
558     RSA *ret_key = NULL;
559     struct rsa_priv_key *priv_key = NULL;
560     struct wb_key_struct *wb_key = NULL;
561     struct wb_tool_inter_key inter_key = {0};
562 
563     wb_key = get_wb_key_data(img_version, type);
564     if (wb_key == NULL) {
565         tloge("get wb key data failed\n");
566         return NULL;
567     }
568 
569     if (get_wb_tool_internal_key(img_version, &inter_key) != TEE_SUCCESS)
570         return NULL;
571 
572     priv_key = TEE_Malloc(sizeof(struct rsa_priv_key), 0);
573     if (priv_key == NULL)
574         return NULL;
575 
576     bool temp_check =
577         (wb_aes_decrypt_cbc(&inter_key, wb_key->wb_rsa_priv_p,
578             wb_key->wb_rsa_priv_p_len, priv_key->p, &priv_key->p_size) != 0) ||
579         (wb_aes_decrypt_cbc(&inter_key, wb_key->wb_rsa_priv_q,
580             wb_key->wb_rsa_priv_q_len, priv_key->q, &priv_key->q_size) != 0) ||
581         (wb_aes_decrypt_cbc(&inter_key, wb_key->wb_rsa_priv_dp,
582             wb_key->wb_rsa_priv_dp_len, priv_key->dp, &priv_key->dp_size) != 0) ||
583         (wb_aes_decrypt_cbc(&inter_key, wb_key->wb_rsa_priv_dq,
584             wb_key->wb_rsa_priv_dq_len, priv_key->dq, &priv_key->dq_size) != 0) ||
585         (wb_aes_decrypt_cbc(&inter_key, wb_key->wb_rsa_priv_qinv,
586             wb_key->wb_rsa_priv_qinv_len, priv_key->qinv, &priv_key->qinv_size) != 0);
587     if (temp_check) {
588         tloge("whitebox generate private key failed\n");
589         TEE_Free(priv_key);
590         return NULL;
591     }
592 
593     temp_check = (priv_key->p_size > sizeof(wb_key->wb_rsa_priv_p) ||
594                   priv_key->q_size > sizeof(wb_key->wb_rsa_priv_q) ||
595                   priv_key->dp_size > sizeof(wb_key->wb_rsa_priv_dp) ||
596                   priv_key->dq_size > sizeof(wb_key->wb_rsa_priv_dq) ||
597                   priv_key->qinv_size > sizeof(wb_key->wb_rsa_priv_qinv));
598     if (temp_check) {
599         tloge("generate private key len failed\n");
600         TEE_Free(priv_key);
601         return NULL;
602     }
603 
604     ret_key = rsa_build_key(priv_key);
605     (void)memset_s(priv_key, sizeof(*priv_key), 0, sizeof(*priv_key));
606     TEE_Free(priv_key);
607     return ret_key;
608 }
609 
get_private_key_v2(int32_t img_version,enum ta_type type)610 static RSA *get_private_key_v2(int32_t img_version, enum ta_type type)
611 {
612     bool is_wb_key = is_wb_protecd_ta_key();
613     if (is_wb_key)
614         return get_white_box_private_key(img_version, type);
615     else
616         return get_private_key_ecies(img_version, type);
617 }
618 
fill_priv_key_size(struct rsa_priv_key * priv_key)619 static void fill_priv_key_size(struct rsa_priv_key *priv_key)
620 {
621     priv_key->p_size = WITHOUT_ZERO;
622     priv_key->q_size = WITHOUT_ZERO;
623     priv_key->dp_size = WITHOUT_ZERO;
624     priv_key->dq_size = WITHOUT_ZERO;
625     priv_key->qinv_size = WITHOUT_ZERO;
626 }
627 
convert_v1_key(const uint8_t * key_buffer,uint32_t key_len,struct rsa_priv_key * priv_key)628 static TEE_Result convert_v1_key(const uint8_t *key_buffer, uint32_t key_len, struct rsa_priv_key *priv_key)
629 {
630     errno_t eret;
631     uint32_t off_set = 0;
632 
633     if (key_len < RESULT1)
634         return TEE_ERROR_BAD_PARAMETERS;
635 
636     /* WITH_ZERO = WITHOUT_ZERO + 1, makesure cpy never overflow */
637     eret = memcpy_s(priv_key->p, sizeof(priv_key->p), key_buffer + off_set, WITHOUT_ZERO);
638     if (eret != EOK)
639         return TEE_ERROR_SECURITY;
640     off_set += WITHOUT_ZERO;
641 
642     eret = memcpy_s(priv_key->q, sizeof(priv_key->q), key_buffer + off_set, WITHOUT_ZERO);
643     if (eret != EOK)
644         return TEE_ERROR_SECURITY;
645     off_set += WITHOUT_ZERO;
646 
647     eret = memcpy_s(priv_key->dp, sizeof(priv_key->dp), key_buffer + off_set, WITHOUT_ZERO);
648     if (eret != EOK)
649         return TEE_ERROR_SECURITY;
650     off_set += WITHOUT_ZERO;
651 
652     eret = memcpy_s(priv_key->dq, sizeof(priv_key->dq), key_buffer + off_set, WITHOUT_ZERO);
653     if (eret != EOK)
654         return TEE_ERROR_SECURITY;
655     off_set += WITHOUT_ZERO;
656 
657     eret = memcpy_s(priv_key->qinv, sizeof(priv_key->qinv), key_buffer + off_set, WITHOUT_ZERO);
658     if (eret != EOK)
659         return TEE_ERROR_SECURITY;
660 
661     fill_priv_key_size(priv_key);
662     return TEE_SUCCESS;
663 }
664 
665 #define V1_WB_KEY_LEN 336U
get_private_key_v1(void)666 static RSA *get_private_key_v1(void)
667 {
668     RSA *ret_key = NULL;
669     struct rsa_priv_key *priv_key = NULL;
670     uint8_t key_buffer[RESULT1] = {0};
671     uint32_t key_len = 0;
672     struct wb_tool_inter_key inter_key = {0};
673     struct key_data key_data = {
674         .pro_type = WB_KEY,
675         .ta_type = V1_TYPE,
676         .key = NULL,
677         .key_len = 0,
678     };
679 
680     if (get_ta_load_key(&key_data) != TEE_SUCCESS || key_data.key_len != V1_WB_KEY_LEN || key_data.key == NULL) {
681         tloge("get v1 key failed, wb key len is %zu\n", key_data.key_len);
682         return NULL;
683     }
684 
685     if (get_wb_tool_internal_key(TA_SIGN_VERSION, &inter_key) != TEE_SUCCESS)
686         return NULL;
687 
688     int32_t iret = wb_aes_decrypt_cbc(&inter_key, key_data.key, key_data.key_len, key_buffer, &key_len);
689     if (iret != 0 || key_len > key_data.key_len) {
690         tloge("Whitebox Generate PrivateKey failed:%d, or decrypt len error:%u", iret, key_len);
691         return NULL;
692     }
693 
694     priv_key = TEE_Malloc(sizeof(struct rsa_priv_key), 0);
695     if (priv_key == NULL)
696         return NULL;
697 
698     if (convert_v1_key(key_buffer, key_len, priv_key) != TEE_SUCCESS) {
699         (void)memset_s(key_buffer, sizeof(key_buffer), 0, sizeof(key_buffer));
700         TEE_Free(priv_key);
701         return NULL;
702     }
703     (void)memset_s(key_buffer, sizeof(key_buffer), 0, sizeof(key_buffer));
704 
705     /* get the RSA private key  */
706     ret_key = rsa_build_key(priv_key);
707     (void)memset_s(priv_key, sizeof(*priv_key), 0, sizeof(*priv_key));
708     TEE_Free(priv_key);
709     return ret_key;
710 }
711 
get_private_key(int32_t img_version,enum ta_type type)712 RSA *get_private_key(int32_t img_version, enum ta_type type)
713 {
714     switch (img_version) {
715     case TA_SIGN_VERSION:
716         return get_private_key_v1();
717     case TA_RSA2048_VERSION:
718     case CIPHER_LAYER_VERSION:
719         return get_private_key_v2(img_version, type);
720     default:
721         tloge("Unsupported secure image version!\n");
722         return NULL;
723     }
724 }
725 
/* Release a key obtained from get_private_key(); a NULL key is ignored. */
void free_private_key(RSA *priv_key)
{
    if (priv_key == NULL)
        return;

    RSA_free(priv_key);
}
731