1 /*
2 * Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the License); you may
5 * not use this file except in compliance with the License.
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 */
9
10
11
12 #include <time.h>
13 #include <stdio.h>
14 #include <assert.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <unistd.h>
18 #include <fcntl.h>
19 #include <sys/types.h>
20 #include <arpa/inet.h>
21 #include <sys/socket.h>
22 #include <netinet/in.h>
23 #include <gmssl/rand.h>
24 #include <gmssl/x509.h>
25 #include <gmssl/error.h>
26 #include <gmssl/mem.h>
27 #include <gmssl/sm2.h>
28 #include <gmssl/sm3.h>
29 #include <gmssl/sm4.h>
30 #include <gmssl/pem.h>
31 #include <gmssl/tls.h>
32
33
/*
 * Serialize a single byte.
 *
 * When out and *out are both non-NULL, the byte is stored at *out and *out
 * moves forward by one. *outlen is always increased by one, so passing
 * out == NULL (or *out == NULL) only measures the encoded size.
 */
void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen)
{
	uint8_t *dst = out ? *out : NULL;

	if (dst != NULL) {
		dst[0] = a;
		*out = dst + 1;
	}
	*outlen += 1;
}
41
/*
 * Serialize a 16-bit value in network (big-endian) byte order.
 *
 * Bytes are written only when out and *out are non-NULL, in which case *out
 * moves forward by two. *outlen always grows by two.
 */
void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;

		dst[0] = (uint8_t)(a >> 8);
		dst[1] = (uint8_t)(a & 0xff);
		*out = dst + 2;
	}
	*outlen += 2;
}
50
/*
 * Serialize the low 24 bits of a in big-endian order (3 bytes).
 *
 * Bytes are written only when out and *out are non-NULL, in which case *out
 * moves forward by three. *outlen always grows by three.
 */
void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;
		int shift;

		for (shift = 16; shift >= 0; shift -= 8) {
			*dst++ = (uint8_t)(a >> shift);
		}
		*out = dst;
	}
	*outlen += 3;
}
60
/*
 * Serialize a 32-bit value in big-endian order (4 bytes).
 *
 * Bytes are written only when out and *out are non-NULL, in which case *out
 * moves forward by four. *outlen always grows by four.
 */
void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;
		int shift;

		for (shift = 24; shift >= 0; shift -= 8) {
			*dst++ = (uint8_t)(a >> shift);
		}
		*out = dst;
	}
	*outlen += 4;
}
71
/*
 * Append datalen raw bytes.
 *
 * If out or *out is NULL, only *outlen is increased (size measurement).
 * Otherwise *out moves forward by datalen. The bytes are copied only when
 * data is non-NULL; with data == NULL the skipped region is left untouched.
 */
void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	uint8_t *dst;

	if (out == NULL || *out == NULL) {
		*outlen += datalen;
		return;
	}
	dst = *out;
	if (data != NULL) {
		memcpy(dst, data, datalen);
	}
	*out = dst + datalen;
	*outlen += datalen;
}
82
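/*
 * The length-prefixed writers below must distinguish two cases:
 *
 *   data == NULL, datalen == 0:  the array is empty, so only the (zero)
 *                                length prefix is written.
 *   data == NULL, datalen != 0:  the array is not empty, but the caller does
 *                                not want its bytes written. Only the length
 *                                prefix is emitted, while the output pointer
 *                                and total length still advance by datalen,
 *                                leaving those bytes unwritten.
 *
 * The second case should be avoided!
 */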
89
/*
 * Write an opaque vector with a 1-byte length prefix (TLS "<0..2^8-1>").
 * datalen is truncated to 8 bits for the prefix; callers must keep it <= 255.
 * NULL out/*out only measures the size (see tls_array_to_bytes).
 */
void tls_uint8array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const uint8_t prefix = (uint8_t)datalen;

	tls_uint8_to_bytes(prefix, out, outlen);
	tls_array_to_bytes(data, datalen, out, outlen);
}
95
/*
 * Write an opaque vector with a 2-byte big-endian length prefix
 * (TLS "<0..2^16-1>"). datalen is truncated to 16 bits for the prefix.
 * NULL out/*out only measures the size.
 */
void tls_uint16array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const uint16_t prefix = (uint16_t)datalen;

	tls_uint16_to_bytes(prefix, out, outlen);
	tls_array_to_bytes(data, datalen, out, outlen);
}
101
tls_uint24array_to_bytes(const uint8_t * data,size_t datalen,uint8_t ** out,size_t * outlen)102 void tls_uint24array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
103 {
104 tls_uint24_to_bytes((uint24_t)datalen, out, outlen);
105 tls_array_to_bytes(data, datalen, out, outlen);
106 }
107
/*
 * Read one byte from *in and consume it.
 * Returns 1 on success, -1 (with *in/*inlen untouched) if no byte is left.
 */
int tls_uint8_from_bytes(uint8_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen == 0) {
		error_print();
		return -1;
	}
	src = *in;
	*a = src[0];
	*in = src + 1;
	*inlen -= 1;
	return 1;
}
118
/*
 * Read a big-endian 16-bit value from *in and consume 2 bytes.
 * Returns 1 on success, -1 (input untouched) if fewer than 2 bytes remain.
 */
int tls_uint16_from_bytes(uint16_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen < 2) {
		error_print();
		return -1;
	}
	src = *in;
	*a = (uint16_t)((src[0] << 8) | src[1]);
	*in = src + 2;
	*inlen -= 2;
	return 1;
}
131
/*
 * Read a big-endian 24-bit value from *in and consume 3 bytes.
 * Returns 1 on success, -1 (input untouched) if fewer than 3 bytes remain.
 */
int tls_uint24_from_bytes(uint24_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen < 3) {
		error_print();
		return -1;
	}
	src = *in;
	*a = ((uint24_t)src[0] << 16)
		| ((uint24_t)src[1] << 8)
		| (uint24_t)src[2];
	*in = src + 3;
	*inlen -= 3;
	return 1;
}
146
/*
 * Read a big-endian 32-bit value from *in and consume 4 bytes.
 * Returns 1 on success, -1 (input untouched) if fewer than 4 bytes remain.
 */
int tls_uint32_from_bytes(uint32_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen < 4) {
		error_print();
		return -1;
	}
	src = *in;
	*a = ((uint32_t)src[0] << 24)
		| ((uint32_t)src[1] << 16)
		| ((uint32_t)src[2] << 8)
		| (uint32_t)src[3];
	*in = src + 4;
	*inlen -= 4;
	return 1;
}
163
/*
 * Take a datalen-byte slice from the input without copying.
 * *data is set to the current input position, then datalen bytes are
 * consumed. Returns -1 (input untouched) if not enough bytes remain.
 */
int tls_array_from_bytes(const uint8_t **data, size_t datalen, const uint8_t **in, size_t *inlen)
{
	const uint8_t *start;

	if (datalen > *inlen) {
		error_print();
		return -1;
	}
	start = *in;
	*data = start;
	*in = start + datalen;
	*inlen -= datalen;
	return 1;
}
175
/*
 * Parse an opaque vector with a 1-byte length prefix.
 * On success *data points into the input (NULL for an empty vector) and
 * *datalen holds its length.
 */
int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint8_t n;

	if (tls_uint8_from_bytes(&n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	*data = n ? *data : NULL;
	*datalen = n;
	return 1;
}
190
/*
 * Parse an opaque vector with a 2-byte big-endian length prefix.
 * On success *data points into the input (NULL for an empty vector) and
 * *datalen holds its length.
 */
int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint16_t n;

	if (tls_uint16_from_bytes(&n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, n, in, inlen) != 1) {
		error_print();
		return -1;
	}
	*data = n ? *data : NULL;
	*datalen = n;
	return 1;
}
205
tls_uint24array_from_bytes(const uint8_t ** data,size_t * datalen,const uint8_t ** in,size_t * inlen)206 int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
207 {
208 uint24_t len;
209 if (tls_uint24_from_bytes(&len, in, inlen) != 1
210 || tls_array_from_bytes(data, len, in, inlen) != 1) {
211 error_print();
212 return -1;
213 }
214 if (!len) {
215 *data = NULL;
216 }
217 *datalen = len;
218 return 1;
219 }
220
/*
 * Check that no input is left over.
 * Returns 1 when len is 0, otherwise reports an error and returns -1.
 */
int tls_length_is_zero(size_t len)
{
	if (len == 0) {
		return 1;
	}
	error_print();
	return -1;
}
229
/*
 * Store the content type in byte 0 of the record header.
 * Only types known to tls_record_type_name() are accepted.
 * Returns 1 on success, -1 for an unknown type.
 */
int tls_record_set_type(uint8_t *record, int type)
{
	if (tls_record_type_name(type)) {
		record[0] = (uint8_t)type;
		return 1;
	}
	error_print();
	return -1;
}
239
/*
 * Store the protocol version in bytes 1..2 of the record header (big-endian).
 * Only versions known to tls_protocol_name() are accepted.
 * Returns 1 on success, -1 for an unknown version.
 */
int tls_record_set_protocol(uint8_t *record, int protocol)
{
	if (tls_protocol_name(protocol)) {
		record[1] = (uint8_t)(protocol >> 8);
		record[2] = (uint8_t)protocol;
		return 1;
	}
	error_print();
	return -1;
}
250
tls_record_set_length(uint8_t * record,size_t length)251 int tls_record_set_length(uint8_t *record, size_t length)
252 {
253 uint8_t *p = record + 3;
254 size_t len;
255 if (length > TLS_MAX_CIPHERTEXT_SIZE) {
256 error_print();
257 return -1;
258 }
259 tls_uint16_to_bytes(length, &p, &len);
260 return 1;
261 }
262
/*
 * Set the record's length field and copy datalen bytes of payload after
 * the record header.
 *
 * data may be NULL only when datalen is 0. The record is left unmodified
 * when the arguments are rejected.
 * Returns 1 on success, -1 on error.
 */
int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen)
{
	if (!record || (!data && datalen)) {
		error_print();
		return -1;
	}
	if (tls_record_set_length(record, datalen) != 1) {
		error_print();
		return -1;
	}
	// memcpy with a NULL source is undefined even for zero bytes
	if (datalen) {
		memcpy(tls_record_data(record), data, datalen);
	}
	return 1;
}
272
/*
 * Protect one record fragment with HMAC-SM3 + SM4-CBC (MAC-then-encrypt,
 * random explicit IV).
 *
 * inited_hmac_ctx  HMAC-SM3 context already keyed with the MAC key; it is
 *                  copied, never modified.
 * enc_key          SM4 encryption key schedule.
 * seq_num          8-byte record sequence number, included in the MAC.
 * header           5-byte record header, included in the MAC; header[3..4]
 *                  must equal inlen.
 * in, inlen        plaintext fragment, at most 2^14 bytes (in may be NULL
 *                  only when inlen == 0).
 * out              receives IV(16) || CBC(plaintext || MAC(32) || padding).
 *                  It must hold 16 + (inlen - inlen % 16) + 48 bytes.
 * outlen           set to the number of bytes written to out.
 *
 * Returns 1 on success, -1 on invalid arguments or RNG failure.
 *
 * NOTE(review): hmac_ctx and last_blocks (plaintext tail + MAC) are not
 * wiped before returning. With in == NULL and inlen == 0, the memcpy below
 * receives a NULL source, which is technically undefined even for 0 bytes.
 */
int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
	const uint8_t seq_num[8], const uint8_t header[5],
	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
{
	SM3_HMAC_CTX hmac_ctx;
	// Trailing partial plaintext block + 32-byte MAC + padding always fill exactly 3 blocks
	uint8_t last_blocks[32 + 16] = {0};
	uint8_t *mac, *padding, *iv;
	int rem, padding_len;
	int i;

	if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) {
		error_print();
		return -1;
	}
	if (inlen > (1 << 14)) {
		error_print_msg("invalid tls record data length %zu\n", inlen);
		return -1;
	}
	// The header's length field is MACed, so it must describe this fragment
	if ((((size_t)header[3]) << 8) + header[4] != inlen) {
		error_print();
		return -1;
	}

	// rem == inlen % 16 (the 32-byte MAC is block aligned). These trailing
	// plaintext bytes do not fill a whole block, so they are moved to last_blocks.
	rem = (inlen + 32) % 16;
	memcpy(last_blocks, in + inlen - rem, rem);
	mac = last_blocks + rem;

	// MAC = HMAC(seq_num || header || plaintext), placed right after the tail bytes
	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
	sm3_hmac_update(&hmac_ctx, seq_num, 8);
	sm3_hmac_update(&hmac_ctx, header, 5);
	sm3_hmac_update(&hmac_ctx, in, inlen);
	sm3_hmac_finish(&hmac_ctx, mac);

	// 16 - rem padding bytes, every one (including the final length byte) equal to padding_len
	padding = mac + 32;
	padding_len = 16 - rem - 1;
	for (i = 0; i <= padding_len; i++) {
		padding[i] = padding_len;
	}

	// The random explicit IV is emitted as the first output block
	iv = out;
	if (rand_bytes(iv, 16) != 1) {
		error_print();
		return -1;
	}
	out += 16;

	// Encrypt the whole plaintext blocks straight from in, then continue the
	// CBC chain from the last ciphertext block written
	if (inlen >= 16) {
		sm4_cbc_encrypt(enc_key, iv, in, inlen/16, out);
		out += inlen - rem;
		iv = out - 16;
	}
	sm4_cbc_encrypt(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out);
	*outlen = 16 + inlen - rem + sizeof(last_blocks);
	return 1;
}
328
tls_cbc_decrypt(const SM3_HMAC_CTX * inited_hmac_ctx,const SM4_KEY * dec_key,const uint8_t seq_num[8],const uint8_t enced_header[5],const uint8_t * in,size_t inlen,uint8_t * out,size_t * outlen)329 int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
330 const uint8_t seq_num[8], const uint8_t enced_header[5],
331 const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
332 {
333 SM3_HMAC_CTX hmac_ctx;
334 const uint8_t *iv;
335 const uint8_t *padding;
336 const uint8_t *mac;
337 uint8_t header[5];
338 int padding_len;
339 uint8_t hmac[32];
340 int i;
341
342 if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) {
343 error_print();
344 return -1;
345 }
346 if (inlen % 16
347 || inlen < (16 + 0 + 32 + 16) // iv + data + mac + padding
348 || inlen > (16 + (1<<14) + 32 + 256)) {
349 error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen);
350 return -1;
351 }
352
353 iv = in;
354 in += 16;
355 inlen -= 16;
356
357 sm4_cbc_decrypt(dec_key, iv, in, inlen/16, out);
358
359 padding_len = out[inlen - 1];
360 padding = out + inlen - padding_len - 1;
361 if (padding < out + 32) {
362 error_print();
363 return -1;
364 }
365 for (i = 0; i < padding_len; i++) {
366 if (padding[i] != padding_len) {
367 error_puts("tls ciphertext cbc-padding check failure");
368 return -1;
369 }
370 }
371
372 *outlen = inlen - 32 - padding_len - 1;
373
374 header[0] = enced_header[0];
375 header[1] = enced_header[1];
376 header[2] = enced_header[2];
377 header[3] = (*outlen) >> 8;
378 header[4] = (*outlen);
379 mac = padding - 32;
380
381 memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
382 sm3_hmac_update(&hmac_ctx, seq_num, 8);
383 sm3_hmac_update(&hmac_ctx, header, 5);
384 sm3_hmac_update(&hmac_ctx, out, *outlen);
385 sm3_hmac_finish(&hmac_ctx, hmac);
386 if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) {
387 error_puts("tls ciphertext mac check failure\n");
388 return -1;
389 }
390 return 1;
391 }
392
/*
 * Encrypt a complete record.
 *
 * in/inlen is header(5) || plaintext fragment. out receives
 * header(5) || IV || ciphertext, where the header copies type and version
 * from in and carries the ciphertext length. *outlen is the total size.
 * Returns 1 on success, -1 on error.
 */
int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	// Without a full header, in + 5 and inlen - 5 would be out of range / wrap
	if (!in || inlen < 5 || !out || !outlen) {
		error_print();
		return -1;
	}
	if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, in,
		in + 5, inlen - 5,
		out + 5, outlen) != 1) {
		error_print();
		return -1;
	}

	out[0] = in[0];
	out[1] = in[1];
	out[2] = in[2];
	out[3] = (*outlen) >> 8;
	out[4] = (*outlen);
	(*outlen) += 5;
	return 1;
}
412
/*
 * Decrypt and verify a complete record.
 *
 * in/inlen is header(5) || IV || ciphertext. out receives
 * header(5) || plaintext, where the header copies type and version from in
 * and carries the plaintext length. *outlen is the total size.
 * Returns 1 on success, -1 on error.
 */
int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	// Without a full header, in + 5 and inlen - 5 would be out of range / wrap
	if (!in || inlen < 5 || !out || !outlen) {
		error_print();
		return -1;
	}
	if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, in,
		in + 5, inlen - 5,
		out + 5, outlen) != 1) {
		error_print();
		return -1;
	}

	out[0] = in[0];
	out[1] = in[1];
	out[2] = in[2];
	out[3] = (*outlen) >> 8;
	out[4] = (*outlen);
	(*outlen) += 5;

	return 1;
}
433
/*
 * Fill a 32-byte Hello random: the current time(NULL) as a big-endian
 * 32-bit value, followed by 28 bytes from rand_bytes().
 * Returns 1 on success, -1 if the RNG fails.
 */
int tls_random_generate(uint8_t random[32])
{
	const uint32_t now = (uint32_t)time(NULL);

	random[0] = (uint8_t)(now >> 24);
	random[1] = (uint8_t)(now >> 16);
	random[2] = (uint8_t)(now >> 8);
	random[3] = (uint8_t)now;

	if (rand_bytes(random + 4, 28) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
446
tls_prf(const uint8_t * secret,size_t secretlen,const char * label,const uint8_t * seed,size_t seedlen,const uint8_t * more,size_t morelen,size_t outlen,uint8_t * out)447 int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
448 const uint8_t *seed, size_t seedlen,
449 const uint8_t *more, size_t morelen,
450 size_t outlen, uint8_t *out)
451 {
452 SM3_HMAC_CTX inited_hmac_ctx;
453 SM3_HMAC_CTX hmac_ctx;
454 uint8_t A[32];
455 uint8_t hmac[32];
456 size_t len;
457
458 if (!secret || !secretlen || !label || !seed || !seedlen
459 || (!more && morelen) || !outlen || !out) {
460 error_print();
461 return -1;
462 }
463
464 sm3_hmac_init(&inited_hmac_ctx, secret, secretlen);
465
466 memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
467 sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
468 sm3_hmac_update(&hmac_ctx, seed, seedlen);
469 sm3_hmac_update(&hmac_ctx, more, morelen);
470 sm3_hmac_finish(&hmac_ctx, A);
471
472 memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
473 sm3_hmac_update(&hmac_ctx, A, sizeof(A));
474 sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
475 sm3_hmac_update(&hmac_ctx, seed, seedlen);
476 sm3_hmac_update(&hmac_ctx, more, morelen);
477 sm3_hmac_finish(&hmac_ctx, hmac);
478
479 len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
480 memcpy(out, hmac, len);
481 out += len;
482 outlen -= len;
483
484 while (outlen) {
485 memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
486 sm3_hmac_update(&hmac_ctx, A, sizeof(A));
487 sm3_hmac_finish(&hmac_ctx, A);
488
489 memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
490 sm3_hmac_update(&hmac_ctx, A, sizeof(A));
491 sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
492 sm3_hmac_update(&hmac_ctx, seed, seedlen);
493 sm3_hmac_update(&hmac_ctx, more, morelen);
494 sm3_hmac_finish(&hmac_ctx, hmac);
495
496 len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
497 memcpy(out, hmac, len);
498 out += len;
499 outlen -= len;
500 }
501 return 1;
502 }
503
/*
 * Build an RSA/ECC-style 48-byte pre-master secret: the 2-byte protocol
 * version (big-endian) followed by 46 random bytes.
 * Returns 1 on success, -1 for an unknown protocol or RNG failure.
 */
int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol)
{
	if (tls_protocol_name(protocol)) {
		pre_master_secret[0] = (uint8_t)(protocol >> 8);
		pre_master_secret[1] = (uint8_t)protocol;
		if (rand_bytes(pre_master_secret + 2, 46) == 1) {
			return 1;
		}
	}
	error_print();
	return -1;
}
518
// Used when building a CertificateRequest message
tls_cert_type_from_oid(int oid)520 int tls_cert_type_from_oid(int oid)
521 {
522 switch (oid) {
523 case OID_sm2sign_with_sm3:
524 case OID_ecdsa_with_sha1:
525 case OID_ecdsa_with_sha224:
526 case OID_ecdsa_with_sha256:
527 case OID_ecdsa_with_sha512:
528 return TLS_cert_type_ecdsa_sign;
529 case OID_rsasign_with_sm3:
530 case OID_rsasign_with_md5:
531 case OID_rsasign_with_sha1:
532 case OID_rsasign_with_sha224:
533 case OID_rsasign_with_sha256:
534 case OID_rsasign_with_sha384:
535 case OID_rsasign_with_sha512:
536 return TLS_cert_type_rsa_sign;
537 }
538 // TLS_cert_type_xxx 中没有为0的值
539 return 0;
540 }
541
// The following two functions have no TLCP counterpart
tls_sign_server_ecdh_params(const SM2_KEY * server_sign_key,const uint8_t client_random[32],const uint8_t server_random[32],int curve,const SM2_POINT * point,uint8_t * sig,size_t * siglen)543 int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key,
544 const uint8_t client_random[32], const uint8_t server_random[32],
545 int curve, const SM2_POINT *point, uint8_t *sig, size_t *siglen)
546 {
547 uint8_t server_ecdh_params[69];
548 SM2_SIGN_CTX sign_ctx;
549
550 if (!server_sign_key || !client_random || !server_random
551 || curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen) {
552 error_print();
553 return -1;
554 }
555 server_ecdh_params[0] = TLS_curve_type_named_curve;
556 server_ecdh_params[1] = curve >> 8;
557 server_ecdh_params[2] = curve;
558 server_ecdh_params[3] = 65;
559 sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
560
561 sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
562 sm2_sign_update(&sign_ctx, client_random, 32);
563 sm2_sign_update(&sign_ctx, server_random, 32);
564 sm2_sign_update(&sign_ctx, server_ecdh_params, 69);
565 sm2_sign_finish(&sign_ctx, sig, siglen);
566
567 return 1;
568 }
569
tls_verify_server_ecdh_params(const SM2_KEY * server_sign_key,const uint8_t client_random[32],const uint8_t server_random[32],int curve,const SM2_POINT * point,const uint8_t * sig,size_t siglen)570 int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key,
571 const uint8_t client_random[32], const uint8_t server_random[32],
572 int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen)
573 {
574 int ret;
575 uint8_t server_ecdh_params[69];
576 SM2_SIGN_CTX verify_ctx;
577
578 if (!server_sign_key || !client_random || !server_random
579 || curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen
580 || siglen > SM2_MAX_SIGNATURE_SIZE) {
581 error_print();
582 return -1;
583 }
584 server_ecdh_params[0] = TLS_curve_type_named_curve;
585 server_ecdh_params[1] = curve >> 8;
586 server_ecdh_params[2] = curve;
587 server_ecdh_params[3] = 65;
588 sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
589
590 sm2_verify_init(&verify_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
591 sm2_verify_update(&verify_ctx, client_random, 32);
592 sm2_verify_update(&verify_ctx, server_random, 32);
593 sm2_verify_update(&verify_ctx, server_ecdh_params, 69);
594 ret = sm2_verify_finish(&verify_ctx, sig, siglen);
595 if (ret != 1) error_print();
596 return ret;
597 }
598
tls_record_set_handshake(uint8_t * record,size_t * recordlen,int type,const uint8_t * data,size_t datalen)599 int tls_record_set_handshake(uint8_t *record, size_t *recordlen,
600 int type, const uint8_t *data, size_t datalen)
601 {
602 size_t handshakelen;
603
604 if (!record || !recordlen) {
605 error_print();
606 return -1;
607 }
608 // 由于ServerHelloDone没有负载数据,因此允许 data,datalen = NULL,0
609 if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) {
610 error_print();
611 return -1;
612 }
613 if (!tls_protocol_name(tls_record_protocol(record))) {
614 error_print();
615 return -1;
616 }
617 if (!tls_handshake_type_name(type)) {
618 error_print();
619 return -1;
620 }
621 handshakelen = TLS_HANDSHAKE_HEADER_SIZE + datalen;
622 record[0] = TLS_record_handshake;
623 record[3] = handshakelen >> 8;
624 record[4] = handshakelen;
625 record[5] = type;
626 record[6] = datalen >> 16;
627 record[7] = datalen >> 8;
628 record[8] = datalen;
629 if (data) {
630 memcpy(tls_handshake_data(tls_record_data(record)), data, datalen);
631 }
632 *recordlen = TLS_RECORD_HEADER_SIZE + handshakelen;
633 return 1;
634 }
635
tls_record_get_handshake(const uint8_t * record,int * type,const uint8_t ** data,size_t * datalen)636 int tls_record_get_handshake(const uint8_t *record,
637 int *type, const uint8_t **data, size_t *datalen)
638 {
639 const uint8_t *handshake;
640 size_t handshake_len;
641 uint24_t handshake_datalen;
642
643 if (!record || !type || !data || !datalen) {
644 error_print();
645 return -1;
646 }
647 if (!tls_protocol_name(tls_record_protocol(record))) {
648 error_print();
649 return -1;
650 }
651 if (tls_record_type(record) != TLS_record_handshake) {
652 error_print();
653 return -1;
654 }
655 handshake = tls_record_data(record);
656 handshake_len = tls_record_data_length(record);
657
658 if (handshake_len < TLS_HANDSHAKE_HEADER_SIZE) {
659 error_print();
660 return -1;
661 }
662 if (handshake_len > TLS_MAX_PLAINTEXT_SIZE) {
663 // 不支持证书长度超过记录长度的特殊情况
664 error_print();
665 return -1;
666 }
667
668 if (!tls_handshake_type_name(handshake[0])) {
669 error_print();
670 return -1;
671 }
672 *type = handshake[0];
673
674 handshake++;
675 handshake_len--;
676 if (tls_uint24_from_bytes(&handshake_datalen, &handshake, &handshake_len) != 1) {
677 error_print();
678 return -1;
679 }
680 if (handshake_len != handshake_datalen) {
681 error_print();
682 return -1;
683 }
684 *data = handshake;
685 *datalen = handshake_datalen;
686
687 if (*datalen == 0) {
688 *data = NULL;
689 }
690 return 1;
691 }
692
tls_record_set_handshake_client_hello(uint8_t * record,size_t * recordlen,int protocol,const uint8_t random[32],const uint8_t * session_id,size_t session_id_len,const int * cipher_suites,size_t cipher_suites_count,const uint8_t * exts,size_t exts_len)693 int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen,
694 int protocol, const uint8_t random[32],
695 const uint8_t *session_id, size_t session_id_len,
696 const int *cipher_suites, size_t cipher_suites_count,
697 const uint8_t *exts, size_t exts_len)
698 {
699 uint8_t type = TLS_handshake_client_hello;
700 uint8_t *p;
701 size_t len;
702
703 if (!record || !recordlen || !random || !cipher_suites || !cipher_suites_count) {
704 error_print();
705 return -1;
706 }
707 if (session_id) {
708 if (!session_id_len
709 || session_id_len < TLS_MAX_SESSION_ID_SIZE
710 || session_id_len > TLS_MAX_SESSION_ID_SIZE) {
711 error_print();
712 return -1;
713 }
714 }
715 if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) {
716 error_print();
717 return -1;
718 }
719 if (exts && !exts_len) {
720 error_print();
721 return -1;
722 }
723
724
725 p = tls_handshake_data(tls_record_data(record));
726 len = 0;
727
728 if (!tls_protocol_name(protocol)) {
729 error_print();
730 return -1;
731 }
732 tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
733 tls_array_to_bytes(random, 32, &p, &len);
734 tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
735 tls_uint16_to_bytes(cipher_suites_count * 2, &p, &len);
736 while (cipher_suites_count--) {
737 if (!tls_cipher_suite_name(*cipher_suites)) {
738 error_print();
739 return -1;
740 }
741 tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len);
742 cipher_suites++;
743 }
744 tls_uint8_to_bytes(1, &p, &len);
745 tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
746 if (exts) {
747 size_t tmp_len = len;
748 if (protocol < TLS_protocol_tls12) {
749 error_print();
750 return -1;
751 }
752 tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len);
753 if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) {
754 error_print();
755 return -1;
756 }
757 tls_uint16array_to_bytes(exts, exts_len, &p, &len);
758 }
759 if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
760 error_print();
761 return -1;
762 }
763 return 1;
764 }
765
tls_record_get_handshake_client_hello(const uint8_t * record,int * protocol,const uint8_t ** random,const uint8_t ** session_id,size_t * session_id_len,const uint8_t ** cipher_suites,size_t * cipher_suites_len,const uint8_t ** exts,size_t * exts_len)766 int tls_record_get_handshake_client_hello(const uint8_t *record,
767 int *protocol, const uint8_t **random,
768 const uint8_t **session_id, size_t *session_id_len,
769 const uint8_t **cipher_suites, size_t *cipher_suites_len,
770 const uint8_t **exts, size_t *exts_len)
771 {
772 int type;
773 const uint8_t *p;
774 size_t len;
775 uint16_t ver;
776 const uint8_t *comp_meths;
777 size_t comp_meths_len;
778
779 if (!record || !protocol || !random
780 || !session_id || !session_id_len
781 || !cipher_suites || !cipher_suites_len
782 || !exts || !exts_len) {
783 error_print();
784 return -1;
785 }
786 if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
787 error_print();
788 return -1;
789 }
790 if (type != TLS_handshake_client_hello) {
791 error_print();
792 return -1;
793 }
794 if (tls_uint16_from_bytes(&ver, &p, &len) != 1
795 || tls_array_from_bytes(random, 32, &p, &len) != 1
796 || tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1
797 || tls_uint16array_from_bytes(cipher_suites, cipher_suites_len, &p, &len) != 1
798 || tls_uint8array_from_bytes(&comp_meths, &comp_meths_len, &p, &len) != 1) {
799 error_print();
800 return -1;
801 }
802
803 if (!tls_protocol_name(ver)) {
804 error_print();
805 return -1;
806 }
807 *protocol = ver;
808
809 if (*session_id) {
810 if (*session_id_len == 0
811 || *session_id_len < TLS_MIN_SESSION_ID_SIZE
812 || *session_id_len > TLS_MAX_SESSION_ID_SIZE) {
813 error_print();
814 return -1;
815 }
816 }
817
818 if (!cipher_suites) {
819 error_print();
820 return -1;
821 }
822 if (*cipher_suites_len % 2) {
823 error_print();
824 return -1;
825 }
826
827 if (len) {
828 if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
829 error_print();
830 return -1;
831 }
832 if (*exts == NULL) {
833 error_print();
834 return -1;
835 }
836 } else {
837 *exts = NULL;
838 *exts_len = 0;
839 }
840 if (len) {
841 error_print();
842 return -1;
843 }
844 return 1;
845 }
846
tls_record_set_handshake_server_hello(uint8_t * record,size_t * recordlen,int protocol,const uint8_t random[32],const uint8_t * session_id,size_t session_id_len,int cipher_suite,const uint8_t * exts,size_t exts_len)847 int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen,
848 int protocol, const uint8_t random[32],
849 const uint8_t *session_id, size_t session_id_len, int cipher_suite,
850 const uint8_t *exts, size_t exts_len)
851 {
852 uint8_t type = TLS_handshake_server_hello;
853 uint8_t *p;
854 size_t len;
855
856 if (!record || !recordlen || !random) {
857 error_print();
858 return -1;
859 }
860 if (session_id) {
861 if (session_id_len == 0
862 || session_id_len < TLS_MIN_SESSION_ID_SIZE
863 || session_id_len > TLS_MAX_SESSION_ID_SIZE) {
864 error_print();
865 return -1;
866 }
867 }
868 if (!tls_protocol_name(protocol)) {
869 error_print();
870 return -1;
871 }
872 if (!tls_cipher_suite_name(cipher_suite)) {
873 error_print();
874 return -1;
875 }
876
877 p = tls_handshake_data(tls_record_data(record));
878 len = 0;
879
880 tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
881 tls_array_to_bytes(random, 32, &p, &len);
882 tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
883 tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len);
884 tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
885 if (exts) {
886 if (protocol < TLS_protocol_tls12) {
887 error_print();
888 return -1;
889 }
890 tls_uint16array_to_bytes(exts, exts_len, &p, &len);
891 }
892 if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
893 error_print();
894 return -1;
895 }
896 return 1;
897 }
898
tls_record_get_handshake_server_hello(const uint8_t * record,int * protocol,const uint8_t ** random,const uint8_t ** session_id,size_t * session_id_len,int * cipher_suite,const uint8_t ** exts,size_t * exts_len)899 int tls_record_get_handshake_server_hello(const uint8_t *record,
900 int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len,
901 int *cipher_suite, const uint8_t **exts, size_t *exts_len)
902 {
903 int type;
904 const uint8_t *p;
905 size_t len;
906 uint16_t ver;
907 uint16_t cipher;
908 uint8_t comp_meth;
909
910 if (!record || !protocol || !random || !session_id || !session_id_len
911 || !cipher_suite || !exts || !exts_len) {
912 error_print();
913 return -1;
914 }
915 if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
916 error_print();
917 return -1;
918 }
919 if (type != TLS_handshake_server_hello) {
920 error_print();
921 return -1;
922 }
923 if (tls_uint16_from_bytes(&ver, &p, &len) != 1
924 || tls_array_from_bytes(random, 32, &p, &len) != 1
925 || tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1
926 || tls_uint16_from_bytes(&cipher, &p, &len) != 1
927 || tls_uint8_from_bytes(&comp_meth, &p, &len) != 1) {
928 error_print();
929 return -1;
930 }
931
932 if (!tls_protocol_name(ver)) {
933 error_print();
934 return -1;
935 }
936 if (ver < tls_record_protocol(record)) {
937 error_print();
938 return -1;
939 }
940 *protocol = ver;
941
942 if (*session_id) {
943 if (*session_id == 0
944 || *session_id_len < TLS_MIN_SESSION_ID_SIZE
945 || *session_id_len > TLS_MAX_SESSION_ID_SIZE) {
946 error_print();
947 return -1;
948 }
949 }
950
951 if (!tls_cipher_suite_name(cipher)) {
952 error_print();
953 return -1;
954 }
955 *cipher_suite = cipher;
956
957 if (comp_meth != TLS_compression_null) {
958 error_print();
959 return -1;
960 }
961
962 if (len) {
963 if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
964 error_print();
965 return -1;
966 }
967 if (*exts == NULL) {
968 error_print();
969 return -1;
970 }
971 } else {
972 *exts = NULL;
973 *exts_len = 0;
974 }
975 if (len) {
976 error_print();
977 return -1;
978 }
979 return 1;
980 }
981
tls_record_set_handshake_certificate(uint8_t * record,size_t * recordlen,const uint8_t * certs,size_t certslen)982 int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen,
983 const uint8_t *certs, size_t certslen)
984 {
985 int type = TLS_handshake_certificate;
986 uint8_t *data;
987 size_t datalen;
988 uint8_t *p;
989 size_t len;
990
991 if (!record || !recordlen || !certs || !certslen) {
992 error_print();
993 return -1;
994 }
995 data = tls_handshake_data(tls_record_data(record));
996 p = data + tls_uint24_size();
997 datalen = tls_uint24_size();
998 len = 0;
999
1000 while (certslen) {
1001 const uint8_t *cert;
1002 size_t certlen;
1003
1004 if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1) {
1005 error_print();
1006 return -1;
1007 }
1008 tls_uint24array_to_bytes(cert, certlen, NULL, &datalen);
1009 if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
1010 error_print();
1011 return -1;
1012 }
1013 tls_uint24array_to_bytes(cert, certlen, &p, &len);
1014 }
1015 tls_uint24_to_bytes(len, &data, &len);
1016 tls_record_set_handshake(record, recordlen, type, NULL, datalen);
1017 return 1;
1018 }
1019
tls_record_get_handshake_certificate(const uint8_t * record,uint8_t * certs,size_t * certslen)1020 int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, size_t *certslen)
1021 {
1022 int type;
1023 const uint8_t *data;
1024 size_t datalen;
1025 const uint8_t *cp;
1026 size_t len;
1027
1028 if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) {
1029 error_print();
1030 return -1;
1031 }
1032 if (type != TLS_handshake_certificate) {
1033 error_print();
1034 return -1;
1035 }
1036 if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1) {
1037 error_print();
1038 return -1;
1039 }
1040
1041 *certslen = 0;
1042 while (len) {
1043 const uint8_t *a;
1044 size_t alen;
1045 const uint8_t *cert;
1046 size_t certlen;
1047
1048 if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) {
1049 error_print();
1050 return -1;
1051 }
1052 if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1
1053 || asn1_length_is_zero(alen) != 1
1054 || x509_cert_to_der(cert, certlen, &certs, certslen) != 1) {
1055 error_print();
1056 return -1;
1057 }
1058 }
1059 return 1;
1060 }
1061
tls_record_set_handshake_certificate_request(uint8_t * record,size_t * recordlen,const uint8_t * cert_types,size_t cert_types_len,const uint8_t * ca_names,size_t ca_names_len)1062 int tls_record_set_handshake_certificate_request(uint8_t *record, size_t *recordlen,
1063 const uint8_t *cert_types, size_t cert_types_len,
1064 const uint8_t *ca_names, size_t ca_names_len)
1065 {
1066 int type = TLS_handshake_certificate_request;
1067 uint8_t *p;
1068 size_t len =0;
1069 size_t datalen = 0;
1070
1071 if (!record || !recordlen) {
1072 error_print();
1073 return -1;
1074 }
1075 if (cert_types) {
1076 if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) {
1077 error_print();
1078 return -1;
1079 }
1080 }
1081 if (ca_names) {
1082 if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) {
1083 error_print();
1084 return -1;
1085 }
1086 }
1087 tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen);
1088 tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen);
1089 if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
1090 error_print();
1091 return -1;
1092 }
1093 p = tls_handshake_data(tls_record_data(record));
1094 tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len);
1095 tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len);
1096 tls_record_set_handshake(record, recordlen, type, NULL, datalen);
1097 return 1;
1098 }
1099
tls_record_get_handshake_certificate_request(const uint8_t * record,const uint8_t ** cert_types,size_t * cert_types_len,const uint8_t ** ca_names,size_t * ca_names_len)1100 int tls_record_get_handshake_certificate_request(const uint8_t *record,
1101 const uint8_t **cert_types, size_t *cert_types_len,
1102 const uint8_t **ca_names, size_t *ca_names_len)
1103 {
1104 int type;
1105 const uint8_t *cp;
1106 size_t len;
1107 size_t i;
1108
1109 if (!record || !cert_types || !cert_types_len || !ca_names || !ca_names_len) {
1110 error_print();
1111 return -1;
1112 }
1113 if (tls_record_get_handshake(record, &type, &cp, &len) != 1) {
1114 error_print();
1115 return -1;
1116 }
1117 if (type != TLS_handshake_certificate_request) {
1118 error_print();
1119 return -1;
1120 }
1121 if (tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1
1122 || tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1
1123 || tls_length_is_zero(len) != 1) {
1124 error_print();
1125 return -1;
1126 }
1127
1128 if (*cert_types == NULL) {
1129 error_print();
1130 return -1;
1131 }
1132 for (i = 0; i < *cert_types_len; i++) {
1133 if (!tls_cert_type_name((*cert_types)[i])) {
1134 error_print();
1135 return -1;
1136 }
1137 }
1138 if (*ca_names) {
1139 const uint8_t *names = *ca_names;
1140 size_t nameslen = *ca_names_len;
1141 while (nameslen) {
1142 if (tls_uint16array_from_bytes(&cp, &len, &names, &nameslen) != 1) {
1143 error_print();
1144 return -1;
1145 }
1146 }
1147 }
1148 return 1;
1149 }
1150
tls_record_set_handshake_server_hello_done(uint8_t * record,size_t * recordlen)1151 int tls_record_set_handshake_server_hello_done(uint8_t *record, size_t *recordlen)
1152 {
1153 int type = TLS_handshake_server_hello_done;
1154 if (!record || !recordlen) {
1155 error_print();
1156 return -1;
1157 }
1158 tls_record_set_handshake(record, recordlen, type, NULL, 0);
1159 return 1;
1160 }
1161
tls_record_get_handshake_server_hello_done(const uint8_t * record)1162 int tls_record_get_handshake_server_hello_done(const uint8_t *record)
1163 {
1164 int type;
1165 const uint8_t *p;
1166 size_t len;
1167
1168 if (!record) {
1169 error_print();
1170 return -1;
1171 }
1172 if (tls_record_get_handshake(record, &type, &p, &len) != 1
1173 || type != TLS_handshake_server_hello_done) {
1174 error_print();
1175 return -1;
1176 }
1177 if (p != NULL || len != 0) {
1178 error_print();
1179 return -1;
1180 }
1181 return 1;
1182 }
1183
tls_record_set_handshake_client_key_exchange_pke(uint8_t * record,size_t * recordlen,const uint8_t * enced_pms,size_t enced_pms_len)1184 int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record, size_t *recordlen,
1185 const uint8_t *enced_pms, size_t enced_pms_len)
1186 {
1187 int type = TLS_handshake_client_key_exchange;
1188 uint8_t *p;
1189 size_t len = 0;
1190
1191 if (!record || !recordlen || !enced_pms || !enced_pms_len) {
1192 error_print();
1193 return -1;
1194 }
1195 if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) {
1196 error_print();
1197 return -1;
1198 }
1199 p = tls_handshake_data(tls_record_data(record));
1200 tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len);
1201 tls_record_set_handshake(record, recordlen, type, NULL, len);
1202 return 1;
1203 }
1204
tls_record_get_handshake_client_key_exchange_pke(const uint8_t * record,const uint8_t ** enced_pms,size_t * enced_pms_len)1205 int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record,
1206 const uint8_t **enced_pms, size_t *enced_pms_len)
1207 {
1208 int type;
1209 const uint8_t *cp;
1210 size_t len;
1211
1212 if (!record || !enced_pms || !enced_pms_len) {
1213 error_print();
1214 return -1;
1215 }
1216 if (tls_record_get_handshake(record, &type, &cp, &len) != 1) {
1217 error_print();
1218 return -1;
1219 }
1220 if (type != TLS_handshake_client_key_exchange) {
1221 error_print();
1222 return -1;
1223 }
1224 if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &cp, &len) != 1
1225 || tls_length_is_zero(len) != 1) {
1226 error_print();
1227 return -1;
1228 }
1229 return 1;
1230 }
1231
tls_record_set_handshake_certificate_verify(uint8_t * record,size_t * recordlen,const uint8_t * sig,size_t siglen)1232 int tls_record_set_handshake_certificate_verify(uint8_t *record, size_t *recordlen,
1233 const uint8_t *sig, size_t siglen)
1234 {
1235 int type = TLS_handshake_certificate_verify;
1236
1237 if (!record || !recordlen || !sig || !siglen) {
1238 error_print();
1239 return -1;
1240 }
1241 if (siglen > TLS_MAX_SIGNATURE_SIZE) {
1242 error_print();
1243 return -1;
1244 }
1245 tls_record_set_handshake(record, recordlen, type, sig, siglen);
1246 return 1;
1247 }
1248
tls_record_get_handshake_certificate_verify(const uint8_t * record,const uint8_t ** sig,size_t * siglen)1249 int tls_record_get_handshake_certificate_verify(const uint8_t *record,
1250 const uint8_t **sig, size_t *siglen)
1251 {
1252 int type;
1253
1254 if (!record || !sig || !siglen) {
1255 error_print();
1256 return -1;
1257 }
1258 if (tls_record_get_handshake(record, &type, sig, siglen) != 1) {
1259 error_print();
1260 return -1;
1261 }
1262 if (type != TLS_handshake_certificate_verify) {
1263 error_print();
1264 return -1;
1265 }
1266 if (*sig == NULL || *siglen == 0) {
1267 error_print();
1268 return -1;
1269 }
1270 if (*siglen > TLS_MAX_SIGNATURE_SIZE) {
1271 error_print();
1272 return -1;
1273 }
1274 return 1;
1275 }
1276
tls_record_set_handshake_finished(uint8_t * record,size_t * recordlen,const uint8_t * verify_data,size_t verify_data_len)1277 int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen,
1278 const uint8_t *verify_data, size_t verify_data_len)
1279 {
1280 int type = TLS_handshake_finished;
1281
1282 if (!record || !recordlen || !verify_data || !verify_data_len) {
1283 error_print();
1284 return -1;
1285 }
1286 if (verify_data_len != 12 && verify_data_len != 32) {
1287 error_print();
1288 return -1;
1289 }
1290 tls_record_set_handshake(record, recordlen, type, verify_data, verify_data_len);
1291 return 1;
1292 }
1293
tls_record_get_handshake_finished(const uint8_t * record,const uint8_t ** verify_data,size_t * verify_data_len)1294 int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **verify_data, size_t *verify_data_len)
1295 {
1296 int type;
1297
1298 if (!record || !verify_data || !verify_data_len) {
1299 error_print();
1300 return -1;
1301 }
1302 if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) != 1) {
1303 error_print();
1304 return -1;
1305 }
1306 if (type != TLS_handshake_finished) {
1307 error_print();
1308 return -1;
1309 }
1310 if (*verify_data == NULL || *verify_data_len == 0) {
1311 error_print();
1312 return -1;
1313 }
1314 if (*verify_data_len != 12 && *verify_data_len != 32) {
1315 error_print();
1316 return -1;
1317 }
1318 return 1;
1319 }
1320
tls_record_set_alert(uint8_t * record,size_t * recordlen,int alert_level,int alert_description)1321 int tls_record_set_alert(uint8_t *record, size_t *recordlen,
1322 int alert_level,
1323 int alert_description)
1324 {
1325 if (!record || !recordlen) {
1326 error_print();
1327 return -1;
1328 }
1329 if (!tls_alert_level_name(alert_level)) {
1330 error_print();
1331 return -1;
1332 }
1333 if (!tls_alert_description_text(alert_description)) {
1334 error_print();
1335 return -1;
1336 }
1337 record[0] = TLS_record_alert;
1338 record[3] = 0; // length
1339 record[4] = 2; // length
1340 record[5] = (uint8_t)alert_level;
1341 record[6] = (uint8_t)alert_description;
1342 *recordlen = TLS_RECORD_HEADER_SIZE + 2;
1343 return 1;
1344 }
1345
tls_record_get_alert(const uint8_t * record,int * alert_level,int * alert_description)1346 int tls_record_get_alert(const uint8_t *record,
1347 int *alert_level,
1348 int *alert_description)
1349 {
1350 if (!record || !alert_level || !alert_description) {
1351 error_print();
1352 return -1;
1353 }
1354 if (tls_record_type(record) != TLS_record_alert) {
1355 error_print();
1356 return -1;
1357 }
1358 if (record[3] != 0 || record[4] != 2) {
1359 error_print();
1360 return -1;
1361 }
1362 *alert_level = record[5];
1363 *alert_description = record[6];
1364 if (!tls_alert_level_name(*alert_level)) {
1365 error_print();
1366 return -1;
1367 }
1368 if (!tls_alert_description_text(*alert_description)) {
1369 error_puts("warning");
1370 return -1;
1371 }
1372 return 1;
1373 }
1374
tls_record_set_change_cipher_spec(uint8_t * record,size_t * recordlen)1375 int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen)
1376 {
1377 if (!record || !recordlen) {
1378 error_print();
1379 return -1;
1380 }
1381 record[0] = TLS_record_change_cipher_spec;
1382 record[3] = 0;
1383 record[4] = 1;
1384 record[5] = TLS_change_cipher_spec;
1385 *recordlen = TLS_RECORD_HEADER_SIZE + 1;
1386 return 1;
1387 }
1388
tls_record_get_change_cipher_spec(const uint8_t * record)1389 int tls_record_get_change_cipher_spec(const uint8_t *record)
1390 {
1391 if (!record) {
1392 error_print();
1393 return -1;
1394 }
1395 if (tls_record_type(record) != TLS_record_change_cipher_spec) {
1396 error_print();
1397 return -1;
1398 }
1399 if (record[3] != 0 || record[4] != 1) {
1400 error_print();
1401 return -1;
1402 }
1403 if (record[5] != TLS_change_cipher_spec) {
1404 error_print();
1405 return -1;
1406 }
1407 return 1;
1408 }
1409
tls_record_set_application_data(uint8_t * record,size_t * recordlen,const uint8_t * data,size_t datalen)1410 int tls_record_set_application_data(uint8_t *record, size_t *recordlen,
1411 const uint8_t *data, size_t datalen)
1412 {
1413 if (!record || !recordlen || !data || !datalen) {
1414 error_print();
1415 return -1;
1416 }
1417 record[0] = TLS_record_application_data;
1418 record[3] = (datalen >> 8) & 0xff;
1419 record[4] = datalen & 0xff;
1420 memcpy(tls_record_data(record), data, datalen);
1421 *recordlen = TLS_RECORD_HEADER_SIZE + datalen;
1422 return 1;
1423 }
1424
tls_record_get_application_data(uint8_t * record,const uint8_t ** data,size_t * datalen)1425 int tls_record_get_application_data(uint8_t *record,
1426 const uint8_t **data, size_t *datalen)
1427 {
1428 if (!record || !data || !datalen) {
1429 error_print();
1430 return -1;
1431 }
1432 if (tls_record_type(record) != TLS_record_application_data) {
1433 error_print();
1434 return -1;
1435 }
1436 *datalen = ((size_t)record[3] << 8) | record[4];
1437 *data = *datalen ? record + TLS_RECORD_HEADER_SIZE : 0;
1438 return 1;
1439 }
1440
/*
 * Report whether `cipher` appears in `list[0..list_count)`.
 * Returns 1 if found, 0 if not, -1 if the list is NULL or empty.
 */
int tls_cipher_suite_in_list(int cipher, const int *list, size_t list_count)
{
	const int *it;
	const int *end;

	if (list == NULL || list_count == 0) {
		error_print();
		return -1;
	}
	end = list + list_count;
	for (it = list; it != end; it++) {
		if (*it == cipher) {
			return 1;
		}
	}
	return 0;
}
1455
tls_record_send(const uint8_t * record,size_t recordlen,int sock)1456 int tls_record_send(const uint8_t *record, size_t recordlen, int sock)
1457 {
1458 ssize_t r;
1459 if (!record) {
1460 error_print();
1461 return -1;
1462 }
1463 if (recordlen < TLS_RECORD_HEADER_SIZE) {
1464 error_print();
1465 return -1;
1466 }
1467 if (tls_record_length(record) != recordlen) {
1468 error_print();
1469 return -1;
1470 }
1471 if ((r = send(sock, record, recordlen, 0)) < 0) {
1472 perror("tls_record_send");
1473 error_print();
1474 return -1;
1475 } else if (r != recordlen) {
1476 error_print();
1477 return -1;
1478 }
1479 return 1;
1480 }
1481
tls_record_do_recv(uint8_t * record,size_t * recordlen,int sock)1482 int tls_record_do_recv(uint8_t *record, size_t *recordlen, int sock)
1483 {
1484 ssize_t r;
1485 int type;
1486 size_t len;
1487
1488 len = 5;
1489 while (len) {
1490 if ((r = recv(sock, record + 5 - len, len, 0)) < 0) {
1491 perror("tls_record_do_recv");
1492 error_print();
1493 return -1;
1494 }
1495 if (r == 0) {
1496 perror("tls_record_do_recv");
1497 error_print();
1498 return 0;
1499 }
1500
1501 len -= r;
1502 }
1503 if (!tls_record_type_name(tls_record_type(record))) {
1504 error_print();
1505 return -1;
1506 }
1507 if (!tls_protocol_name(tls_record_protocol(record))) {
1508 error_print();
1509 return -1;
1510 }
1511 len = (size_t)record[3] << 8 | record[4];
1512 *recordlen = 5 + len;
1513 if (*recordlen > TLS_MAX_RECORD_SIZE) {
1514 // 这里只检查是否超过最大长度,握手协议的长度检查由上层协议完成
1515 error_print();
1516 return -1;
1517 }
1518 while (len) {
1519 if ((r = recv(sock, record + *recordlen - len, len, 0)) < 0) {
1520 perror("tls_record_do_recv");
1521 error_print();
1522 return -1;
1523 }
1524 len -= r;
1525 }
1526 return 1;
1527 }
1528
tls_record_recv(uint8_t * record,size_t * recordlen,int sock)1529 int tls_record_recv(uint8_t *record, size_t *recordlen, int sock)
1530 {
1531 retry:
1532 if (tls_record_do_recv(record, recordlen, sock) != 1) {
1533 error_print();
1534 return -1;
1535 }
1536
1537 if (tls_record_type(record) == TLS_record_alert) {
1538 int level;
1539 int alert;
1540 if (tls_record_get_alert(record, &level, &alert) != 1) {
1541 error_print();
1542 return -1;
1543 }
1544 tls_record_trace(stderr, record, *recordlen, 0, 0);
1545 if (level == TLS_alert_level_warning) {
1546 // 忽略Warning,读取下一个记录
1547 error_puts("Warning record received!\n");
1548 goto retry;
1549 }
1550 if (alert == TLS_alert_close_notify) {
1551 // close_notify是唯一需要提供反馈的Fatal Alert,其他直接中止连接
1552 uint8_t alert_record[TLS_ALERT_RECORD_SIZE];
1553 size_t alert_record_len;
1554 tls_record_set_type(alert_record, TLS_record_alert);
1555 tls_record_set_protocol(alert_record, tls_record_protocol(record));
1556 tls_record_set_alert(alert_record, &alert_record_len, TLS_alert_level_fatal, TLS_alert_close_notify);
1557
1558 tls_trace("send Alert close_notifiy\n");
1559 tls_record_trace(stderr, alert_record, alert_record_len, 0, 0);
1560 tls_record_send(alert_record, alert_record_len, sock);
1561 }
1562 // 返回错误0通知调用方不再做任何处理(无需再发送Alert)
1563 return 0;
1564 }
1565 return 1;
1566 }
1567
/*
 * Increment a 64-bit big-endian record sequence number in place.
 *
 * seq_num[0] is the most significant byte and seq_num[7] the least. The
 * previous loop stopped before index 0, so the carry never reached the top
 * byte and the counter silently wrapped after 2^56 records.
 *
 * Returns 1 on success. Returns -1 if the counter wrapped from 2^64 - 1 to 0;
 * TLS sequence numbers must not wrap, so the caller should stop using the
 * connection in that case.
 */
int tls_seq_num_incr(uint8_t seq_num[8])
{
	int i;

	// propagate the carry from the least to the most significant byte
	for (i = 7; i >= 0; i--) {
		seq_num[i]++;
		if (seq_num[i]) {
			return 1;
		}
	}
	// every byte rolled over to zero: the 64-bit counter has wrapped
	return -1;
}
1578
tls_compression_methods_has_null_compression(const uint8_t * meths,size_t methslen)1579 int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen)
1580 {
1581 if (!meths || !methslen) {
1582 error_print();
1583 return -1;
1584 }
1585 while (methslen--) {
1586 if (*meths++ == TLS_compression_null) {
1587 return 1;
1588 }
1589 }
1590 error_print();
1591 return -1;
1592 }
1593
/*
 * Send a fatal-level Alert with description `alert` on conn->sock.
 *
 * A TLS 1.3 connection writes the TLS 1.2 version in the record header;
 * every other protocol uses its own version. If `alert` is not a known
 * description, tls_record_set_alert leaves the record and its length unset.
 * That failure is now reported instead of sending an uninitialized buffer.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_send_alert(TLS_CONNECT *conn, int alert)
{
	uint8_t record[TLS_RECORD_HEADER_SIZE + 2];
	size_t recordlen;

	if (!conn) {
		error_print();
		return -1;
	}
	tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol);
	if (tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert) != 1) {
		error_print();
		return -1;
	}
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, record, recordlen, 0, 0);
	return 1;
}
1613
tls_alert_level(int alert)1614 int tls_alert_level(int alert)
1615 {
1616 switch (alert) {
1617 case TLS_alert_bad_certificate:
1618 case TLS_alert_unsupported_certificate:
1619 case TLS_alert_certificate_revoked:
1620 case TLS_alert_certificate_expired:
1621 case TLS_alert_certificate_unknown:
1622 return 0;
1623 case TLS_alert_user_canceled:
1624 case TLS_alert_no_renegotiation:
1625 return TLS_alert_level_warning;
1626 default:
1627 return TLS_alert_level_fatal;
1628 }
1629 return -1;
1630 }
1631
/*
 * Send a warning-level Alert with description `alert` on conn->sock.
 *
 * Alerts that tls_alert_level classifies as fatal are refused. As in
 * tls_send_alert, a TLS 1.3 connection writes the TLS 1.2 version in the
 * record header. A failure from tls_record_set_alert (unknown description)
 * is now reported instead of sending an uninitialized buffer.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_send_warning(TLS_CONNECT *conn, int alert)
{
	uint8_t record[TLS_RECORD_HEADER_SIZE + 2];
	size_t recordlen;

	if (!conn) {
		error_print();
		return -1;
	}
	if (tls_alert_level(alert) == TLS_alert_level_fatal) {
		error_print();
		return -1;
	}
	tls_record_set_protocol(record, conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol);
	if (tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert) != 1) {
		error_print();
		return -1;
	}
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, record, recordlen, 0, 0);
	return 1;
}
1655
/*
 * Encrypt and send one ApplicationData record built from `in`.
 *
 * At most TLS_MAX_PLAINTEXT_SIZE bytes are consumed per call. `*sentlen`
 * reports how many bytes of `in` were sent, so larger buffers need several
 * calls. The MAC context, SM4 key and sequence number for this side's write
 * direction are chosen by conn->is_client. The sequence number is
 * incremented after encryption, before the record is sent.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
	const SM3_HMAC_CTX *mac_ctx;
	const SM4_KEY *key;
	uint8_t *seq;
	uint8_t *rec;
	size_t enclen;
	size_t chunk;

	if (!conn) {
		error_print();
		return -1;
	}
	if (!in || !inlen || !sentlen) {
		error_print();
		return -1;
	}
	chunk = (inlen > TLS_MAX_PLAINTEXT_SIZE) ? TLS_MAX_PLAINTEXT_SIZE : inlen;

	// this endpoint writes with its own keys: client_* on a client, server_* on a server
	mac_ctx = conn->is_client ? &conn->client_write_mac_ctx : &conn->server_write_mac_ctx;
	key = conn->is_client ? &conn->client_write_enc_key : &conn->server_write_enc_key;
	seq = conn->is_client ? conn->client_seq_num : conn->server_seq_num;
	rec = conn->record;

	tls_trace("send ApplicationData\n");

	// the header first carries the plaintext length (presumably consumed by
	// tls_cbc_encrypt for the MAC); it is replaced by the ciphertext length below
	if (tls_record_set_type(rec, TLS_record_application_data) != 1
		|| tls_record_set_protocol(rec, conn->protocol) != 1
		|| tls_record_set_length(rec, chunk) != 1) {
		error_print();
		return -1;
	}
	if (tls_cbc_encrypt(mac_ctx, key, seq, tls_record_header(rec),
		in, chunk, tls_record_data(rec), &enclen) != 1) {
		error_print();
		return -1;
	}
	if (tls_record_set_length(rec, enclen) != 1) {
		error_print();
		return -1;
	}
	tls_seq_num_incr(seq);

	if (tls_record_send(rec, tls_record_length(rec), conn->sock) != 1) {
		error_print();
		return -1;
	}
	*sentlen = chunk;
	tls_record_trace(stderr, rec, tls_record_length(rec), 0, 0);
	return 1;
}
1717
/*
 * Receive one record and decrypt it into conn->databuf.
 *
 * The record is decrypted with the peer's write-direction MAC context, SM4
 * key and sequence number. On success conn->data/conn->datalen describe the
 * plaintext and the sequence number is incremented. The plaintext is also
 * copied back into the record buffer, which is then traced.
 *
 * Returns 1 on success, 0 when tls_record_recv reports a non-warning alert,
 * and -1 on error.
 */
int tls_do_recv(TLS_CONNECT *conn)
{
	int ret;
	const SM3_HMAC_CTX *mac_ctx;
	const SM4_KEY *key;
	uint8_t *seq;
	uint8_t *rec = conn->record;
	size_t reclen;

	// incoming data was written by the peer, so use the peer's write keys
	if (!conn->is_client) {
		mac_ctx = &conn->client_write_mac_ctx;
		key = &conn->client_write_enc_key;
		seq = conn->client_seq_num;
	} else {
		mac_ctx = &conn->server_write_mac_ctx;
		key = &conn->server_write_enc_key;
		seq = conn->server_seq_num;
	}

	tls_trace("recv ApplicationData\n");
	ret = tls_record_recv(rec, &reclen, conn->sock);
	if (ret != 1) {
		if (ret < 0) {
			error_print();
		}
		return ret;
	}
	tls_record_trace(stderr, rec, reclen, 0, 0);

	if (tls_cbc_decrypt(mac_ctx, key, seq, rec,
		tls_record_data(rec), tls_record_data_length(rec),
		conn->databuf, &conn->datalen) != 1) {
		error_print();
		return -1;
	}
	conn->data = conn->databuf;
	tls_seq_num_incr(seq);

	// put the plaintext into the record buffer, presumably so the trace below shows it
	tls_record_set_data(rec, conn->data, conn->datalen);
	tls_trace("decrypt ApplicationData\n");
	tls_record_trace(stderr, rec, tls_record_length(rec), 0, 0);
	return 1;
}
1759
/*
 * Copy up to `outlen` bytes of received application data into `out`.
 *
 * Buffered plaintext from an earlier record is returned first. A new record
 * is read (tls_do_recv) only when the buffer is empty. `*recvlen` receives
 * the number of bytes copied.
 *
 * Returns 1 on success, or the non-1 result of tls_do_recv (0 or -1).
 */
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{
	size_t n;

	if (!conn || !out || !outlen || !recvlen) {
		error_print();
		return -1;
	}
	if (!conn->datalen) {
		int ret = tls_do_recv(conn);
		if (ret != 1) {
			if (ret) {
				error_print();
			}
			return ret;
		}
	}
	n = (conn->datalen < outlen) ? conn->datalen : outlen;
	memcpy(out, conn->data, n);
	conn->data += n;
	conn->datalen -= n;
	*recvlen = n;
	return 1;
}
1779
/*
 * Close the TLS session: send a close_notify alert, then read one record
 * from the peer and trace it.
 *
 * NOTE(review): the record read back is not checked to actually be a
 * close_notify alert.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_shutdown(TLS_CONNECT *conn)
{
	size_t reclen;

	if (conn == NULL) {
		error_print();
		return -1;
	}

	tls_trace("send Alert close_notify\n");
	if (tls_send_alert(conn, TLS_alert_close_notify) != 1) {
		error_print();
		return -1;
	}

	tls_trace("recv Alert close_notify\n");
	if (tls_record_do_recv(conn->record, &reclen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, conn->record, reclen, 0, 0);
	return 1;
}
1802
/*
 * Build a CertificateRequest CA-name list from the subjects of `certs`.
 *
 * For every DER certificate in `certs`, this emits a uint16 length followed
 * by the DER-encoded subject Name into `names`. `*nameslen` receives the
 * total length. At most `maxlen` bytes are written.
 *
 * The remaining budget previously excluded the 2-byte length prefix of each
 * entry, so it was overestimated by 2 bytes per name and `names` could
 * overflow. The prefix is now accounted for. A name longer than 0xffff bytes
 * (not representable in the uint16 prefix) is rejected.
 *
 * Returns 1 on success, -1 on parse error or when the output does not fit.
 */
int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen)
{
	const uint8_t *cert;
	size_t certlen;
	const uint8_t *name;
	size_t namelen;

	if (!nameslen) {
		error_print();
		return -1;
	}

	*nameslen = 0;
	while (certslen) {
		size_t alen = 0;
		size_t entrylen;

		if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1
			|| x509_cert_get_subject(cert, certlen, &name, &namelen) != 1
			|| asn1_sequence_to_der(name, namelen, NULL, &alen) != 1) {
			error_print();
			return -1;
		}
		if (alen > 0xffff) {
			error_print();
			return -1;
		}
		entrylen = tls_uint16_size() + alen;
		if (entrylen > maxlen) {
			error_print();
			return -1;
		}
		tls_uint16_to_bytes((uint16_t)alen, &names, nameslen);
		if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) {
			error_print();
			return -1;
		}
		maxlen -= entrylen;
	}
	return 1;
}
1832
/*
 * Check whether the issuer of the last certificate in `certs` matches one of
 * the uint16-prefixed DER Names in `ca_names`.
 *
 * Returns 1 on a match, 0 if no name matches, and -1 if the chain or the
 * name list cannot be parsed.
 */
int tls_authorities_issued_certificate(const uint8_t *ca_names, size_t ca_names_len, const uint8_t *certs, size_t certslen)
{
	const uint8_t *last;
	size_t lastlen;
	const uint8_t *issuer;
	size_t issuer_len;

	if (x509_certs_get_last(certs, certslen, &last, &lastlen) != 1
		|| x509_cert_get_issuer(last, lastlen, &issuer, &issuer_len) != 1) {
		error_print();
		return -1;
	}

	while (ca_names_len > 0) {
		const uint8_t *entry;
		size_t entrylen;
		const uint8_t *dn;
		size_t dnlen;

		if (tls_uint16array_from_bytes(&entry, &entrylen, &ca_names, &ca_names_len) != 1) {
			error_print();
			return -1;
		}
		// each entry must be exactly one DER SEQUENCE (a Name)
		if (asn1_sequence_from_der(&dn, &dnlen, &entry, &entrylen) != 1
			|| asn1_length_is_zero(entrylen) != 1) {
			error_print();
			return -1;
		}
		if (x509_name_equ(dn, dnlen, issuer, issuer_len) == 1) {
			return 1;
		}
	}

	error_print();
	return 0;
}
1867
tls_cert_types_accepted(const uint8_t * types,size_t types_len,const uint8_t * client_certs,size_t client_certs_len)1868 int tls_cert_types_accepted(const uint8_t *types, size_t types_len, const uint8_t *client_certs, size_t client_certs_len)
1869 {
1870 const uint8_t *cert;
1871 size_t certlen;
1872 int sig_alg;
1873 size_t i;
1874
1875 if (x509_certs_get_cert_by_index(client_certs, client_certs_len, 0, &cert, &certlen) != 1) {
1876 error_print();
1877 return -1;
1878 }
1879 if ((sig_alg = tls_cert_type_from_oid(OID_sm2sign_with_sm3)) < 0) {
1880 error_print();
1881 return -1;
1882 }
1883 for (i = 0; i < types_len; i++) {
1884 if (sig_alg == types[i]) {
1885 return 1;
1886 }
1887 }
1888 return 0;
1889 }
1890
/*
 * Reset a client-verify context to all zeros.
 * Returns 1 on success, -1 if ctx is NULL.
 */
int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx)
{
	if (ctx == NULL) {
		error_print();
		return -1;
	}
	memset(ctx, 0, sizeof(*ctx));
	return 1;
}
1900
/*
 * Append a heap copy of one handshake message to the client-verify context.
 *
 * Up to 8 messages (slots 0..7) can be stored. The copies are released by
 * tls_client_verify_cleanup.
 *
 * Returns 1 on success, -1 on bad arguments, a full context, or allocation
 * failure.
 */
int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *handshake, size_t handshake_len)
{
	uint8_t *copy;

	if (!ctx || !handshake || !handshake_len) {
		error_print();
		return -1;
	}
	if (ctx->index < 0 || ctx->index > 7) {
		error_print();
		return -1;
	}
	copy = malloc(handshake_len);
	if (copy == NULL) {
		error_print();
		return -1;
	}
	memcpy(copy, handshake, handshake_len);

	ctx->handshake[ctx->index] = copy;
	ctx->handshake_len[ctx->index] = handshake_len;
	ctx->index++;
	return 1;
}
1922
/*
 * Verify an SM2 signature (default SM2 ID) over the 8 stored handshake
 * messages, fed in insertion order.
 *
 * Returns the result of sm2_verify_finish (1 valid, 0 invalid), or -1 on
 * bad arguments, a context without exactly 8 messages, or an internal error.
 */
int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, size_t siglen, const SM2_KEY *public_key)
{
	SM2_SIGN_CTX verify_ctx;
	int ret;
	int k;

	if (!ctx || !sig || !siglen || !public_key) {
		error_print();
		return -1;
	}
	if (ctx->index != 8) {
		error_print();
		return -1;
	}
	if (sm2_verify_init(&verify_ctx, public_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) {
		error_print();
		return -1;
	}
	for (k = 0; k < 8; k++) {
		if (sm2_verify_update(&verify_ctx, ctx->handshake[k], ctx->handshake_len[k]) != 1) {
			error_print();
			return -1;
		}
	}
	ret = sm2_verify_finish(&verify_ctx, sig, siglen);
	if (ret < 0) {
		error_print();
		return -1;
	}
	return ret;
}
1954
/*
 * Free every handshake copy held by the context and clear its slot.
 * ctx->index itself is left unchanged; calling this twice is safe because
 * freed slots are set to NULL. A NULL ctx is ignored.
 */
void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx)
{
	int k;

	if (!ctx) {
		return;
	}
	for (k = 0; k < ctx->index; k++) {
		if (!ctx->handshake[k]) {
			continue;
		}
		free(ctx->handshake[k]);
		ctx->handshake[k] = NULL;
		ctx->handshake_len[k] = 0;
	}
}
1968
/*
 * Pick the first server-preferred cipher suite that the client also offers.
 *
 * `client_ciphers` is a sequence of big-endian uint16 suite IDs. The server
 * list is walked in order, so server preference wins.
 *
 * Returns 1 with *selected_cipher set on a match, 0 if nothing matches, and
 * -1 on bad arguments or a malformed client list.
 */
int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_ciphers_len,
	const int *server_ciphers, size_t server_ciphers_cnt,
	int *selected_cipher)
{
	size_t i;

	if (!client_ciphers || !client_ciphers_len
		|| !server_ciphers || !server_ciphers_cnt || !selected_cipher) {
		error_print();
		return -1;
	}
	for (i = 0; i < server_ciphers_cnt; i++) {
		const uint8_t *cur = client_ciphers;
		size_t left = client_ciphers_len;

		while (left) {
			uint16_t offered;

			if (tls_uint16_from_bytes(&offered, &cur, &left) != 1) {
				error_print();
				return -1;
			}
			if (offered == server_ciphers[i]) {
				*selected_cipher = server_ciphers[i];
				return 1;
			}
		}
	}
	return 0;
}
1996
/*
 * Release everything owned by a TLS_CTX and zero it.
 * The private keys are wiped with gmssl_secure_clear; the certificate
 * buffers are freed. A NULL ctx is ignored.
 */
void tls_ctx_cleanup(TLS_CTX *ctx)
{
	if (!ctx) {
		return;
	}
	gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY));
	gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY));
	free(ctx->certs);
	free(ctx->cacerts);
	memset(ctx, 0, sizeof(TLS_CTX));
}
2007
/*
 * Zero `ctx` and set its protocol (TLCP, TLS 1.2 or TLS 1.3) and role.
 * The context is zeroed even if the protocol is rejected.
 * Returns 1 on success, -1 on NULL ctx or an unsupported protocol.
 */
int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client)
{
	if (!ctx) {
		error_print();
		return -1;
	}
	memset(ctx, 0, sizeof(*ctx));

	if (protocol != TLS_protocol_tlcp
		&& protocol != TLS_protocol_tls12
		&& protocol != TLS_protocol_tls13) {
		error_print();
		return -1;
	}
	ctx->protocol = protocol;
	ctx->is_client = !!is_client;
	return 1;
}
2029
/*
 * Configure the context's cipher suite list.
 *
 * All suites are validated before any is stored, so on failure the context
 * is left untouched. Returns 1 on success, -1 on bad arguments, too many
 * suites, or an unknown suite.
 */
int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt)
{
	size_t i;

	if (!ctx || !cipher_suites || !cipher_suites_cnt) {
		error_print();
		return -1;
	}
	if (cipher_suites_cnt > TLS_MAX_CIPHER_SUITES_COUNT) {
		error_print();
		return -1;
	}
	for (i = 0; i < cipher_suites_cnt; i++) {
		if (tls_cipher_suite_name(cipher_suites[i]) == NULL) {
			error_print();
			return -1;
		}
	}

	for (i = 0; i < cipher_suites_cnt; i++) {
		ctx->cipher_suites[i] = cipher_suites[i];
	}
	ctx->cipher_suites_cnt = cipher_suites_cnt;
	return 1;
}
2054
/*
 * Load trusted CA certificates from `cacertsfile` into `ctx` and set the
 * verification depth (0..TLS_MAX_VERIFY_DEPTH).
 *
 * May only be called once per context: an already-loaded CA list is
 * refused. If the file yields no certificates, the buffer that was loaded
 * is released and ctx->cacerts reset to NULL. Previously it was left
 * allocated, which leaked it and made every later call fail the
 * "already set" check.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth)
{
	if (!ctx || !cacertsfile) {
		error_print();
		return -1;
	}
	if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->cacerts) {
		error_print();
		return -1;
	}
	if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen, cacertsfile) != 1) {
		error_print();
		return -1;
	}
	if (ctx->cacertslen == 0) {
		free(ctx->cacerts);
		ctx->cacerts = NULL;
		error_print();
		return -1;
	}

	ctx->verify_depth = depth;
	return 1;
}
2085
/*
 * Load a certificate chain and its password-protected SM2 private key into
 * `ctx`.
 *
 * The public key of the first certificate in the chain must match the
 * private key. On success `ctx` takes ownership of the chain buffer and a
 * copy of the key.
 *
 * Every failure after loading now goes through the common cleanup path.
 * Two checks used to `return -1` directly, which leaked the chain buffer,
 * left the key file open, and skipped wiping the decrypted private key from
 * the stack.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile,
	const char *keyfile, const char *keypass)
{
	int ret = -1;
	uint8_t *certs = NULL;
	size_t certslen = 0;
	FILE *keyfp = NULL;
	SM2_KEY key;
	const uint8_t *cert;
	size_t certlen;
	SM2_KEY public_key;

	if (!ctx || !chainfile || !keyfile || !keypass) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->certs) {
		error_print();
		return -1;
	}

	if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
		error_print();
		goto end;
	}
	if (!(keyfp = fopen(keyfile, "r"))) {
		error_print();
		goto end;
	}
	if (sm2_private_key_info_decrypt_from_pem(&key, keypass, keyfp) != 1) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
		error_print();
		goto end;
	}
	if (sm2_public_key_equ(&key, &public_key) != 1) {
		error_print();
		goto end;
	}

	// success: transfer ownership of the chain and a copy of the key to ctx
	ctx->certs = certs;
	ctx->certslen = certslen;
	ctx->signkey = key;
	certs = NULL;
	ret = 1;

end:
	gmssl_secure_clear(&key, sizeof(key));
	free(certs);
	if (keyfp) {
		fclose(keyfp);
	}
	return ret;
}
2144
/*
 * Load a TLCP server's double-certificate chain and its two private keys.
 *
 * Certificate #0 must match the signing key in `signkeyfile`, and
 * certificate #1 the key-encipherment key in `kenckeyfile`. Both key files
 * are password protected. On success `ctx` takes ownership of the chain and
 * copies of both keys. The local key copies are always wiped before return.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile,
	const char *signkeyfile, const char *signkeypass,
	const char *kenckeyfile, const char *kenckeypass)
{
	int ret = -1;
	uint8_t *chain = NULL;
	size_t chainlen;
	FILE *sign_fp = NULL;
	FILE *kenc_fp = NULL;
	SM2_KEY sign_key;
	SM2_KEY kenc_key;
	const uint8_t *cert;
	size_t certlen;
	SM2_KEY cert_pub;

	if (!ctx || !chainfile || !signkeyfile || !signkeypass || !kenckeyfile || !kenckeypass) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->certs) {
		error_print();
		return -1;
	}
	if (x509_certs_new_from_file(&chain, &chainlen, chainfile) != 1) {
		error_print();
		return -1;
	}

	// signing key <-> certificate #0
	sign_fp = fopen(signkeyfile, "r");
	if (!sign_fp
		|| sm2_private_key_info_decrypt_from_pem(&sign_key, signkeypass, sign_fp) != 1) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(chain, chainlen, 0, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &cert_pub) != 1
		|| sm2_public_key_equ(&sign_key, &cert_pub) != 1) {
		error_print();
		goto end;
	}

	// key-encipherment key <-> certificate #1
	kenc_fp = fopen(kenckeyfile, "r");
	if (!kenc_fp
		|| sm2_private_key_info_decrypt_from_pem(&kenc_key, kenckeypass, kenc_fp) != 1) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(chain, chainlen, 1, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &cert_pub) != 1
		|| sm2_public_key_equ(&kenc_key, &cert_pub) != 1) {
		error_print();
		goto end;
	}

	// success: hand ownership of the chain and copies of the keys to ctx
	ctx->certs = chain;
	ctx->certslen = chainlen;
	ctx->signkey = sign_key;
	ctx->kenckey = kenc_key;
	chain = NULL;
	ret = 1;

end:
	gmssl_secure_clear(&sign_key, sizeof(sign_key));
	gmssl_secure_clear(&kenc_key, sizeof(kenc_key));
	free(chain);
	if (sign_fp) {
		fclose(sign_fp);
	}
	if (kenc_fp) {
		fclose(kenc_fp);
	}
	return ret;
}
2224
/*
 * Initialize a connection object from a context.
 *
 * conn is zeroed, then the following are copied from ctx:
 *   - the protocol and client/server role;
 *   - the cipher suite list;
 *   - the local certificate chain, stored as the client or server
 *     certificates according to ctx->is_client;
 *   - the CA certificates;
 *   - the signing and encryption SM2 keys.
 *
 * Returns 1 on success, -1 on error. Errors are NULL arguments, or a ctx
 * whose cipher suite count or certificate sizes exceed the connection's
 * buffers. All sizes are validated before anything is copied, so on error
 * a non-NULL conn is left fully zeroed rather than half-populated.
 */
int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx)
{
	size_t i;

	if (!conn || !ctx) {
		error_print();
		return -1;
	}
	memset(conn, 0, sizeof(*conn));

	// conn->cipher_suites is an in-struct array (it is written to after the
	// memset above), so bound the copy by its real capacity
	if (ctx->cipher_suites_cnt > sizeof(conn->cipher_suites) / sizeof(conn->cipher_suites[0])
		|| ctx->certslen > TLS_MAX_CERTIFICATES_SIZE
		|| ctx->cacertslen > TLS_MAX_CERTIFICATES_SIZE) {
		error_print();
		return -1;
	}

	conn->protocol = ctx->protocol;
	conn->is_client = ctx->is_client;
	for (i = 0; i < ctx->cipher_suites_cnt; i++) {
		conn->cipher_suites[i] = ctx->cipher_suites[i];
	}
	conn->cipher_suites_cnt = ctx->cipher_suites_cnt;

	// ctx->certs / ctx->cacerts may be NULL when nothing was configured;
	// memcpy from a NULL source is undefined even with length 0, and the
	// *_len fields are already 0 from the memset
	if (ctx->certslen) {
		if (conn->is_client) {
			memcpy(conn->client_certs, ctx->certs, ctx->certslen);
			conn->client_certs_len = ctx->certslen;
		} else {
			memcpy(conn->server_certs, ctx->certs, ctx->certslen);
			conn->server_certs_len = ctx->certslen;
		}
	}
	if (ctx->cacertslen) {
		memcpy(conn->ca_certs, ctx->cacerts, ctx->cacertslen);
		conn->ca_certs_len = ctx->cacertslen;
	}

	conn->sign_key = ctx->signkey;
	conn->kenc_key = ctx->kenckey;

	return 1;
}
2262
/*
 * Securely wipe a connection object, including the key material that
 * tls_init() copied into it. A NULL conn is accepted and ignored, in the
 * same way as free(NULL).
 */
void tls_cleanup(TLS_CONNECT *conn)
{
	if (!conn) {
		return;
	}
	gmssl_secure_clear(conn, sizeof(*conn));
}
2267
2268
/*
 * Attach the socket descriptor sock to conn, switching the descriptor to
 * blocking mode (O_NONBLOCK is cleared with fcntl).
 *
 * Returns 1 on success, -1 if conn is NULL, sock is negative, or either
 * fcntl() call fails. On fcntl failure the system error is reported with
 * perror().
 */
int tls_set_socket(TLS_CONNECT *conn, int sock)
{
	int opts;

	if (!conn || sock < 0) {
		error_print();
		return -1;
	}
	if ((opts = fcntl(sock, F_GETFL)) < 0) {
		// call perror() first: error_print() writes through stdio, and
		// library calls may modify errno even when they succeed
		perror("tls_set_socket");
		error_print();
		return -1;
	}
	opts &= ~O_NONBLOCK;
	if (fcntl(sock, F_SETFL, opts) < 0) {
		perror("tls_set_socket");
		error_print();
		return -1;
	}
	conn->sock = sock;
	return 1;
}
2286
/*
 * Run the handshake for the protocol configured on conn.
 *
 * For TLCP, TLS 1.2 and TLS 1.3, dispatches to the matching *_do_connect()
 * when conn->is_client is set and to *_do_accept() otherwise, and returns
 * that routine's result. Any other protocol value yields -1.
 */
int tls_do_handshake(TLS_CONNECT *conn)
{
	switch (conn->protocol) {
	case TLS_protocol_tlcp:
		return conn->is_client ? tlcp_do_connect(conn) : tlcp_do_accept(conn);
	case TLS_protocol_tls12:
		return conn->is_client ? tls12_do_connect(conn) : tls12_do_accept(conn);
	case TLS_protocol_tls13:
		return conn->is_client ? tls13_do_connect(conn) : tls13_do_accept(conn);
	default:
		break;
	}
	// unknown or unset protocol
	error_print();
	return -1;
}
2303
/*
 * Read the verification result stored on conn (conn->verify_result).
 *
 * On success writes it to *result and returns 1. Returns -1, without
 * writing anything, if conn or result is NULL.
 */
int tls_get_verify_result(TLS_CONNECT *conn, int *result)
{
	if (!conn || !result) {
		error_print();
		return -1;
	}
	*result = conn->verify_result;
	return 1;
}
2309