• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3  *
4  *  Licensed under the Apache License, Version 2.0 (the License); you may
5  *  not use this file except in compliance with the License.
6  *
7  *  http://www.apache.org/licenses/LICENSE-2.0
8  */
9 
10 
11 
12 #include <time.h>
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <string.h>
16 #include <unistd.h>
17 #include <fcntl.h>
18 #include <sys/types.h>
19 #include <arpa/inet.h>
20 #include <sys/socket.h>
21 #include <netinet/in.h>
22 #include <gmssl/rand.h>
23 #include <gmssl/x509.h>
24 #include <gmssl/error.h>
25 #include <gmssl/sm2.h>
26 #include <gmssl/sm3.h>
27 #include <gmssl/sm4.h>
28 #include <gmssl/pem.h>
29 #include <gmssl/mem.h>
30 #include <gmssl/tls.h>
31 
32 
33 
34 static const int tls12_ciphers[] = {
35 	TLS_cipher_ecdhe_sm4_cbc_sm3,
36 };
37 
38 static const size_t tls12_ciphers_count = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]);
39 
/*
 * Pre-encoded ClientHello extensions, each laid out as
 * extension type (2) | extension length (2) | extension body.
 *
 * NOTE(review): this table is not referenced in the visible part of this file;
 * tls12_do_connect() builds its extensions with tls_*_ext_to_bytes() instead.
 * NOTE(review): the group id (30) and signature scheme (0x0707) differ from the
 * 0x29 / 0x08 alternatives that were left commented out next to them; confirm
 * which codepoints are intended before relying on this table.
 */
static const uint8_t tls12_exts[] = {
	/* supported_groups (0x000a): ext len 4, list len 2, group 30 (0x001e) */
	0x00, 0x0a,  0x00, 0x04,  0x00, 0x02,  0x00, 0x1e,
	/* ec_point_formats (0x000b): ext len 2, list len 1, uncompressed (0) */
	0x00, 0x0b,  0x00, 0x02,  0x01,  0x00,
	/* signature_algorithms (0x000d): ext len 4, list len 2, scheme 0x0707 */
	0x00, 0x0d,  0x00, 0x04,  0x00, 0x02,  0x07, 0x07,
};
45 
46 
tls12_record_print(FILE * fp,const uint8_t * record,size_t recordlen,int format,int indent)47 int tls12_record_print(FILE *fp, const uint8_t *record,  size_t recordlen, int format, int indent)
48 {
49 	// 目前只支持TLCP的ECC公钥加密套件,因此不论用哪个套件解析都是一样的
50 	// 如果未来支持ECDHE套件,可以将函数改为宏,直接传入 (conn->cipher_suite << 8)
51 	format |= tls12_ciphers[0] << 8;
52 	return tls_record_print(fp, record, recordlen, format, indent);
53 }
54 
55 
tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t * record,size_t * recordlen,int curve,const SM2_POINT * point,const uint8_t * sig,size_t siglen)56 int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t *recordlen,
57 	int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen)
58 {
59 	int type = TLS_handshake_server_key_exchange;
60 	uint8_t *server_ecdh_params = record + 9;
61 	uint8_t *p = server_ecdh_params + 69;
62 	size_t len = 69;
63 
64 	if (!record || !recordlen || !tls_named_curve_name(curve) || !point
65 		|| !sig || !siglen || siglen > TLS_MAX_SIGNATURE_SIZE) {
66 		error_print();
67 		return -1;
68 	}
69 	server_ecdh_params[0] = TLS_curve_type_named_curve;
70 	server_ecdh_params[1] = curve >> 8;
71 	server_ecdh_params[2] = curve;
72 	server_ecdh_params[3] = 65;
73 	sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
74 	tls_uint16_to_bytes(TLS_sig_sm2sig_sm3, &p, &len);
75 	tls_uint16array_to_bytes(sig, siglen, &p, &len);
76 	tls_record_set_handshake(record, recordlen, type, NULL, len);
77 	return 1;
78 }
79 
80 // 这里返回的应该是一个SM2_POINT吗?
tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t * record,int * curve,SM2_POINT * point,const uint8_t ** sig,size_t * siglen)81 int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record,
82 	int *curve, SM2_POINT *point, const uint8_t **sig, size_t *siglen)
83 {
84 	int type;
85 	const uint8_t *p;
86 	size_t len;
87 	uint8_t curve_type;
88 	uint16_t named_curve;
89 	const uint8_t *octets;
90 	size_t octetslen;
91 	uint16_t sig_alg;
92 
93 	if (!record || !curve || !point || !sig || !siglen) {
94 		error_print();
95 		return -1;
96 	}
97 	if (tls_record_get_handshake(record, &type, &p, &len) != 1
98 		|| type != TLS_handshake_server_key_exchange) {
99 		error_print();
100 		return -1;
101 	}
102 	if (tls_uint8_from_bytes(&curve_type, &p, &len) != 1
103 		|| tls_uint16_from_bytes(&named_curve, &p, &len) != 1
104 		|| tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1
105 		|| tls_uint16_from_bytes(&sig_alg, &p, &len) != 1
106 		|| tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1
107 		|| tls_length_is_zero(len) != 1) {
108 		error_print();
109 		return -1;
110 	}
111 	if (curve_type != TLS_curve_type_named_curve) {
112 		error_print();
113 		return -1;
114 	}
115 	if (named_curve != TLS_curve_sm2p256v1) {
116 		error_print();
117 		return -1;
118 	}
119 	*curve = named_curve;
120 	if (octetslen != 65
121 		|| sm2_point_from_octets(point, octets, octetslen) != 1) {
122 		error_print();
123 		return -1;
124 	}
125 	if (sig_alg != TLS_sig_sm2sig_sm3) {
126 		error_print();
127 		return -1;
128 	}
129 	return 1;
130 }
131 
tls_record_set_handshake_client_key_exchange_ecdhe(uint8_t * record,size_t * recordlen,const SM2_POINT * point)132 int tls_record_set_handshake_client_key_exchange_ecdhe(uint8_t *record, size_t *recordlen,
133 	const SM2_POINT *point)
134 {
135 	int type = TLS_handshake_client_key_exchange;
136 	record[9] = 65;
137 	sm2_point_to_uncompressed_octets(point, record + 9 + 1);
138 	tls_record_set_handshake(record, recordlen, type, NULL, 1 + 65);
139 	return 1;
140 }
141 
/*
 * Parse a TLS 1.2 ECDHE ClientKeyExchange handshake record.
 *
 * The body must be a single 1-byte-length-prefixed array holding exactly 65
 * octets that sm2_point_from_octets() accepts, with no trailing bytes.
 *
 * record: received handshake record.
 * point:  receives the client's ephemeral public key.
 *
 * Returns 1 on success, -1 on invalid arguments or malformed input.
 */
int tls_record_get_handshake_client_key_exchange_ecdhe(const uint8_t *record, SM2_POINT *point)
{
	int type;
	const uint8_t *p;
	size_t len;
	const uint8_t *octets;
	size_t octetslen;

	/* Reject NULL arguments, matching the ServerKeyExchange parser. */
	if (!record || !point) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &p, &len) != 1
		|| type != TLS_handshake_client_key_exchange) {
		error_print();
		return -1;
	}
	if (tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1
		|| tls_length_is_zero(len) != 1) {
		error_print();
		return -1;
	}
	if (octetslen != 65
		|| sm2_point_from_octets(point, octets, octetslen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
167 
168 /*
169       Client                                               Server
170 
171       ClientHello                  -------->
172                                                       ServerHello
173                                                       Certificate
174                                                 ServerKeyExchange
175                                               CertificateRequest*
176                                    <--------      ServerHelloDone
177       Certificate*
178       ClientKeyExchange
179       CertificateVerify*
180       [ChangeCipherSpec]
181       Finished                     -------->
182                                                [ChangeCipherSpec]
183                                    <--------             Finished
184       Application Data             <------->     Application Data
185 
186 
187 */
188 
tls12_do_connect(TLS_CONNECT * conn)189 int tls12_do_connect(TLS_CONNECT *conn)
190 {
191 	int ret = -1;
192 	uint8_t *record = conn->record;
193 	uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
194 	size_t recordlen, finished_record_len;
195 
196 	uint8_t client_random[32];
197 	uint8_t server_random[32];
198 	int protocol;
199 	int cipher_suite;
200 	const uint8_t *random;
201 	const uint8_t *session_id;
202 	size_t session_id_len;
203 
204 	uint8_t client_exts[TLS_MAX_EXTENSIONS_SIZE];
205 	size_t client_exts_len = 0;
206 	const uint8_t *server_exts;
207 	size_t server_exts_len;
208 
209 	// 扩展的协商结果,-1 表示服务器不支持该扩展(未给出响应)
210 	int ec_point_format = -1;
211 	int supported_group = -1;
212 	int signature_algor = -1;
213 
214 
215 	SM2_KEY server_sign_key;
216 	SM2_SIGN_CTX verify_ctx;
217 	SM2_SIGN_CTX sign_ctx;
218 	const uint8_t *sig;
219 	size_t siglen;
220 	uint8_t pre_master_secret[48];
221 	SM3_CTX sm3_ctx;
222 	SM3_CTX tmp_sm3_ctx;
223 	uint8_t sm3_hash[32];
224 	const uint8_t *verify_data;
225 	size_t verify_data_len;
226 	uint8_t local_verify_data[12];
227 
228 	int handshake_type;
229 	const uint8_t *server_enc_cert; // 这几个值也是不需要的
230 	size_t server_enc_cert_len;
231 	uint8_t server_enc_cert_lenbuf[3];
232 	const uint8_t *cp;
233 	uint8_t *p;
234 	size_t len;
235 
236 	int depth = 5;
237 	int alert = 0;
238 	int verify_result;
239 
240 
241 	// 初始化记录缓冲
242 	tls_record_set_protocol(record, TLS_protocol_tls1); // ClientHello的记录层协议版本是TLSv1.0
243 	tls_record_set_protocol(finished_record, conn->protocol);
244 
245 	// 准备Finished Context(和ClientVerify)
246 	sm3_init(&sm3_ctx);
247 	if (conn->client_certs_len)
248 		sm2_sign_init(&sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
249 
250 
251 	// send ClientHello
252 	tls_random_generate(client_random);
253 	int ec_point_formats[] = { TLS_point_uncompressed };
254 	size_t ec_point_formats_cnt = 1;
255 	int supported_groups[] = { TLS_curve_sm2p256v1 };
256 	size_t supported_groups_cnt = 1;
257 	int signature_algors[] = { TLS_sig_sm2sig_sm3 };
258 	size_t signature_algors_cnt = 1;
259 
260 
261 	p = client_exts;
262 	client_exts_len = 0;
263 
264 	tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &p, &client_exts_len);
265 	tls_supported_groups_ext_to_bytes(supported_groups, supported_groups_cnt, &p, &client_exts_len);
266 	tls_signature_algorithms_ext_to_bytes(signature_algors, signature_algors_cnt, &p, &client_exts_len);
267 
268 	if (tls_record_set_handshake_client_hello(record, &recordlen,
269 		conn->protocol, client_random, NULL, 0,
270 		tls12_ciphers, tls12_ciphers_count,
271 		client_exts, client_exts_len) != 1) {
272 		error_print();
273 		goto end;
274 	}
275 	tls_trace("send ClientHello\n");
276 	tls12_record_trace(stderr, record, recordlen, 0, 0);
277 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
278 		error_print();
279 		goto end;
280 	}
281 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
282 	if (conn->client_certs_len)
283 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
284 
285 	// recv ServerHello
286 	tls_trace("recv ServerHello\n");
287 	if (tls_record_recv(record, &recordlen, conn->sock) != 1) {
288 		error_print();
289 		tls_send_alert(conn, TLS_alert_unexpected_message);
290 		goto end;
291 	}
292 	tls12_record_trace(stderr, record, recordlen, 0, 0);
293 	if (tls_record_protocol(record) != conn->protocol) {
294 		error_print();
295 		tls_send_alert(conn, TLS_alert_protocol_version);
296 		goto end;
297 	}
298 	if (tls_record_get_handshake_server_hello(record,
299 		&protocol, &random, &session_id, &session_id_len, &cipher_suite,
300 		&server_exts, &server_exts_len) != 1) {
301 		error_print();
302 		tls_send_alert(conn, TLS_alert_unexpected_message);
303 		goto end;
304 	}
305 	if (protocol != conn->protocol) {
306 		error_print();
307 		tls_send_alert(conn, TLS_alert_protocol_version);
308 		goto end;
309 	}
310 	// tls12_ciphers 应该改为conn的内部变量
311 	if (tls_cipher_suite_in_list(cipher_suite, tls12_ciphers, tls12_ciphers_count) != 1) {
312 		error_print();
313 		tls_send_alert(conn, TLS_alert_handshake_failure);
314 		goto end;
315 	}
316 	if (!server_exts) {
317 		error_print();
318 		tls_send_alert(conn, TLS_alert_unexpected_message);
319 		goto end;
320 	}
321 	if (tls_process_server_hello_exts(server_exts, server_exts_len, &ec_point_format, &supported_group, &signature_algor) != 1
322 		|| ec_point_format < 0
323 		|| supported_group < 0
324 		|| signature_algor < 0) {
325 		error_print();
326 		tls_send_alert(conn, TLS_alert_unexpected_message);
327 		goto end;
328 	}
329 	memcpy(server_random, random, 32);
330 	memcpy(conn->session_id, session_id, session_id_len);
331 	conn->cipher_suite = cipher_suite;
332 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
333 	if (conn->client_certs_len)
334 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
335 
336 	// recv ServerCertificate
337 	tls_trace("recv ServerCertificate\n");
338 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
339 		|| tls_record_protocol(record) != conn->protocol) {
340 		error_print();
341 		tls_send_alert(conn, TLS_alert_unexpected_message);
342 		goto end;
343 	}
344 	tls12_record_trace(stderr, record, recordlen, 0, 0);
345 
346 	if (tls_record_get_handshake_certificate(record,
347 		conn->server_certs, &conn->server_certs_len) != 1) {
348 		error_print();
349 		tls_send_alert(conn, TLS_alert_unexpected_message);
350 		goto end;
351 	}
352 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
353 	if (conn->client_certs_len)
354 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
355 
356 	// verify ServerCertificate
357 	if (x509_certs_verify(conn->server_certs, conn->server_certs_len,
358 		conn->ca_certs, conn->ca_certs_len, depth, &verify_result) != 1) {
359 		error_print();
360 		tls_send_alert(conn, alert);
361 		goto end;
362 	}
363 
364 	// recv ServerKeyExchange
365 	tls_trace("recv ServerKeyExchange\n");
366 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
367 		|| tls_record_protocol(record) != conn->protocol) {
368 		error_print();
369 		tls_send_alert(conn, TLS_alert_unexpected_message);
370 		goto end;
371 	}
372 	tls12_record_trace(stderr, record, recordlen, 0, 0);
373 
374 	int curve;
375 	SM2_POINT server_ecdhe_public;
376 	if (tls_record_get_handshake_server_key_exchange_ecdhe(record, &curve, &server_ecdhe_public, &sig, &siglen) != 1) {
377 		error_print();
378 		tls_send_alert(conn, TLS_alert_unexpected_message);
379 		goto end;
380 	}
381 	if (curve != TLS_curve_sm2p256v1) {
382 		error_print();
383 		tls_send_alert(conn, TLS_alert_unexpected_message);
384 		goto end;
385 	}
386 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
387 	if (conn->client_certs_len)
388 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
389 
390 	// verify ServerKeyExchange
391 	if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, &cp, &len) != 1
392 		|| x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1) {
393 		error_print();
394 		tls_send_alert(conn, TLS_alert_bad_certificate);
395 		goto end;
396 	}
397 	if (tls_verify_server_ecdh_params(&server_sign_key, // 这应该是签名公钥
398 		client_random, server_random, curve, &server_ecdhe_public, sig, siglen) != 1) {
399 		error_print();
400 		tls_send_alert(conn, TLS_alert_internal_error);
401 		goto end;
402 	}
403 
404 	// recv CertificateRequest or ServerHelloDone
405 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
406 		|| tls_record_protocol(record) != conn->protocol
407 		|| tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) {
408 		error_print();
409 		tls_send_alert(conn, TLS_alert_unexpected_message);
410 		goto end;
411 	}
412 	if (handshake_type == TLS_handshake_certificate_request) {
413 		const uint8_t *cert_types;
414 		size_t cert_types_len;
415 		const uint8_t *ca_names;
416 		size_t ca_names_len;
417 
418 		// recv CertificateRequest
419 		tls_trace("recv CertificateRequest\n");
420 		tls12_record_trace(stderr, record, recordlen, 0, 0);
421 		if (tls_record_get_handshake_certificate_request(record,
422 			&cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) {
423 			error_print();
424 			tls_send_alert(conn, TLS_alert_unexpected_message);
425 			goto end;
426 		}
427 		if(!conn->client_certs_len) {
428 			error_print();
429 			tls_send_alert(conn, TLS_alert_internal_error);
430 			goto end;
431 		}
432 		if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1
433 			|| tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) {
434 			error_print();
435 			tls_send_alert(conn, TLS_alert_unsupported_certificate);
436 			goto end;
437 		}
438 		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
439 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
440 
441 		// recv ServerHelloDone
442 		if (tls_record_recv(record, &recordlen, conn->sock) != 1
443 			|| tls_record_protocol(record) != conn->protocol) {
444 			error_print();
445 			tls_send_alert(conn, TLS_alert_unexpected_message);
446 			goto end;
447 		}
448 	} else {
449 		// 这个得处理一下
450 		conn->client_certs_len = 0;
451 		gmssl_secure_clear(&conn->sign_key, sizeof(SM2_KEY));
452 	}
453 	tls_trace("recv ServerHelloDone\n");
454 	tls12_record_trace(stderr, record, recordlen, 0, 0);
455 	if (tls_record_get_handshake_server_hello_done(record) != 1) {
456 		error_print();
457 		tls_send_alert(conn, TLS_alert_unexpected_message);
458 		goto end;
459 	}
460 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
461 	if (conn->client_certs_len)
462 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
463 
464 	// send ClientCertificate
465 	if (conn->client_certs_len) {
466 		tls_trace("send ClientCertificate\n");
467 		if (tls_record_set_handshake_certificate(record, &recordlen, conn->client_certs, conn->client_certs_len) != 1) {
468 			error_print();
469 			tls_send_alert(conn, TLS_alert_internal_error);
470 			goto end;
471 		}
472 		tls12_record_trace(stderr, record, recordlen, 0, 0);
473 		if (tls_record_send(record, recordlen, conn->sock) != 1) {
474 			error_print();
475 			goto end;
476 		}
477 		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
478 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
479 	}
480 
481 	// generate MASTER_SECRET
482 	tls_trace("generate secrets\n");
483 	SM2_KEY client_ecdh;
484 	sm2_key_generate(&client_ecdh);
485 	sm2_ecdh(&client_ecdh, &server_ecdhe_public, &server_ecdhe_public);
486 	memcpy(pre_master_secret, &server_ecdhe_public, 32); // 这个做法很不优雅
487 	// ECDHE和ECC的PMS结构是不一样的吗?
488 
489 	if (tls_prf(pre_master_secret, 32, "master secret",
490 			client_random, 32, server_random, 32,
491 			48, conn->master_secret) != 1
492 		|| tls_prf(conn->master_secret, 48, "key expansion",
493 			server_random, 32, client_random, 32,
494 			96, conn->key_block) != 1) {
495 		error_print();
496 		tls_send_alert(conn, TLS_alert_internal_error);
497 		goto end;
498 	}
499 	sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
500 	sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
501 	sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
502 	sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
503 	/*
504 	tls_secrets_print(stderr,
505 		pre_master_secret, 48,
506 		client_random, server_random,
507 		conn->master_secret,
508 		conn->key_block, 96,
509 		0, 4);
510 	*/
511 
512 	// send ClientKeyExchange
513 	tls_trace("send ClientKeyExchange\n");
514 	if (tls_record_set_handshake_client_key_exchange_ecdhe(record, &recordlen, &client_ecdh.public_key) != 1) {
515 		error_print();
516 		tls_send_alert(conn, TLS_alert_internal_error);
517 		goto end;
518 	}
519 	tls12_record_trace(stderr, record, recordlen, 0, 0);
520 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
521 		error_print();
522 		goto end;
523 	}
524 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
525 	if (conn->client_certs_len)
526 		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
527 
528 	// send CertificateVerify
529 	if (conn->client_certs_len) {
530 		tls_trace("send CertificateVerify\n");
531 		uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
532 		if (sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1
533 			|| tls_record_set_handshake_certificate_verify(record, &recordlen, sigbuf, siglen) != 1) {
534 			error_print();
535 			tls_send_alert(conn, TLS_alert_internal_error);
536 			goto end;
537 		}
538 		tls12_record_trace(stderr, record, recordlen, 0, 0);
539 		if (tls_record_send(record, recordlen, conn->sock) != 1) {
540 			error_print();
541 			goto end;
542 		}
543 		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
544 	}
545 
546 	// send [ChangeCipherSpec]
547 	tls_trace("send [ChangeCipherSpec]\n");
548 	if (tls_record_set_change_cipher_spec(record, &recordlen) !=1) {
549 		error_print();
550 		tls_send_alert(conn, TLS_alert_internal_error);
551 		goto end;
552 	}
553 	tls12_record_trace(stderr, record, recordlen, 0, 0);
554 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
555 		error_print();
556 		goto end;
557 	}
558 
559 	// send Client Finished
560 	tls_trace("send Finished\n");
561 	memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(sm3_ctx));
562 	sm3_finish(&tmp_sm3_ctx, sm3_hash);
563 	if (tls_prf(conn->master_secret, 48, "client finished",
564 			sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1
565 		|| tls_record_set_handshake_finished(finished_record, &finished_record_len,
566 			local_verify_data, sizeof(local_verify_data)) != 1) {
567 		error_print();
568 		tls_send_alert(conn, TLS_alert_internal_error);
569 		goto end;
570 	}
571 	tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
572 	sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);
573 
574 	// encrypt Client Finished
575 	tls_trace("encrypt Finished\n");
576 	if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
577 		conn->client_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
578 		error_print();
579 		tls_send_alert(conn, TLS_alert_internal_error);
580 		goto end;
581 	}
582 	tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
583 	tls_seq_num_incr(conn->client_seq_num);
584 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
585 		error_print();
586 		goto end;
587 	}
588 
589 	// [ChangeCipherSpec]
590 	tls_trace("recv [ChangeCipherSpec]\n");
591 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
592 		|| tls_record_protocol(record) != conn->protocol) {
593 		error_print();
594 		tls_send_alert(conn, TLS_alert_unexpected_message);
595 		goto end;
596 	}
597 	tls12_record_trace(stderr, record, recordlen, 0, 0);
598 	if (tls_record_get_change_cipher_spec(record) != 1) {
599 		error_print();
600 		tls_send_alert(conn, TLS_alert_unexpected_message);
601 		goto end;
602 	}
603 
604 	// Finished
605 	tls_trace("recv Finished\n");
606 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
607 		|| tls_record_protocol(record) != conn->protocol) {
608 		error_print();
609 		tls_send_alert(conn, TLS_alert_unexpected_message);
610 		goto end;
611 	}
612 	if (recordlen > sizeof(finished_record)) {
613 		error_print(); // 解密可能导致 finished_record 溢出
614 		tls_send_alert(conn, TLS_alert_bad_record_mac);
615 		goto end;
616 	}
617 	tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
618 	tls_trace("decrypt Finished\n");
619 	if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
620 		conn->server_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
621 		error_print();
622 		tls_send_alert(conn, TLS_alert_bad_record_mac);
623 		goto end;
624 	}
625 	tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
626 	tls_seq_num_incr(conn->server_seq_num);
627 	if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
628 		error_print();
629 		tls_send_alert(conn, TLS_alert_unexpected_message);
630 		goto end;
631 	}
632 	if (verify_data_len != sizeof(local_verify_data)) {
633 		error_print();
634 		tls_send_alert(conn, TLS_alert_unexpected_message);
635 		goto end;
636 	}
637 	sm3_finish(&sm3_ctx, sm3_hash);
638 	if (tls_prf(conn->master_secret, 48, "server finished",
639 		sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) {
640 		error_print();
641 		tls_send_alert(conn, TLS_alert_internal_error);
642 		goto end;
643 	}
644 	if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
645 		error_print();
646 		tls_send_alert(conn, TLS_alert_decrypt_error);
647 		goto end;
648 	}
649 	fprintf(stderr, "Connection established!\n");
650 
651 
652 	conn->protocol = conn->protocol;
653 	conn->cipher_suite = cipher_suite;
654 
655 	ret = 1;
656 
657 end:
658 	gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
659 	gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
660 	return 1;
661 }
662 
tls12_do_accept(TLS_CONNECT * conn)663 int tls12_do_accept(TLS_CONNECT *conn)
664 {
665 	int ret = -1;
666 
667 	int client_verify = 0;
668 
669 	uint8_t *record = conn->record;
670 	uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; // 解密可能导致前面的record被覆盖
671 	size_t recordlen, finished_record_len;
672 
673 	// 这个ciphers不是应该在CTX中设置的吗
674 	const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件
675 
676 	// ClientHello, ServerHello
677 	uint8_t client_random[32];
678 	uint8_t server_random[32];
679 	int protocol;
680 	const uint8_t *random;
681 	const uint8_t *session_id; // TLCP服务器忽略客户端SessionID,也不主动设置SessionID
682 	size_t session_id_len;
683 	const uint8_t *client_ciphers;
684 	size_t client_ciphers_len;
685 	const uint8_t *client_exts;
686 	size_t client_exts_len;
687 	uint8_t server_exts[TLS_MAX_EXTENSIONS_SIZE];
688 	size_t server_exts_len;
689 	int curve = TLS_curve_sm2p256v1; // 这个是否应该在conn中设置?
690 
691 	// ServerKeyExchange
692 	SM2_KEY server_ecdhe_key;
693 	SM2_SIGN_CTX sign_ctx;
694 	uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
695 	size_t siglen;
696 
697 	// ClientCertificate, CertificateVerify
698 	TLS_CLIENT_VERIFY_CTX client_verify_ctx;
699 	SM2_KEY client_sign_key;
700 	const uint8_t *sig;
701 	const int verify_depth = 5;
702 	int verify_result;
703 
704 	// ClientKeyExchange
705 	SM2_POINT client_ecdhe_point;
706 	uint8_t pre_master_secret[SM2_MAX_PLAINTEXT_SIZE]; // sm2_decrypt 保证输出不会溢出
707 	size_t pre_master_secret_len;
708 
709 	// Finished
710 	SM3_CTX sm3_ctx;
711 	SM3_CTX tmp_sm3_ctx;
712 	uint8_t sm3_hash[32];
713 	uint8_t local_verify_data[12];
714 	const uint8_t *verify_data;
715 	size_t verify_data_len;
716 
717 	uint8_t *p;
718 	const uint8_t *cp;
719 	size_t len;
720 
721 
722 	// 服务器端如果设置了CA
723 	if (conn->ca_certs_len)
724 		client_verify = 1;
725 
726 	// 初始化Finished和客户端验证环境
727 	sm3_init(&sm3_ctx);
728 	if (client_verify)
729 		tls_client_verify_init(&client_verify_ctx);
730 
731 
732 	// recv ClientHello
733 	tls_trace("recv ClientHello\n");
734 	if (tls_record_recv(record, &recordlen, conn->sock) != 1) {
735 		error_print();
736 		tls_send_alert(conn, TLS_alert_unexpected_message);
737 		goto end;
738 	}
739 	tls12_record_trace(stderr, record, recordlen, 0, 0);
740 	if (tls_record_protocol(record) != conn->protocol
741 		&& tls_record_protocol(record) != TLS_protocol_tls1) {
742 		error_print();
743 		tls_send_alert(conn, TLS_alert_protocol_version);
744 		goto end;
745 	}
746 	if (tls_record_get_handshake_client_hello(record,
747 		&protocol, &random, &session_id, &session_id_len,
748 		&client_ciphers, &client_ciphers_len,
749 		&client_exts, &client_exts_len) != 1) {
750 		error_print();
751 		tls_send_alert(conn, TLS_alert_unexpected_message);
752 		goto end;
753 	}
754 	if (protocol != conn->protocol) {
755 		error_print();
756 		tls_send_alert(conn, TLS_alert_protocol_version);
757 		goto end;
758 	}
759 	memcpy(client_random, random, 32);
760 	if (tls_cipher_suites_select(client_ciphers, client_ciphers_len,
761 		server_ciphers, sizeof(server_ciphers)/sizeof(server_ciphers[0]),
762 		&conn->cipher_suite) != 1) {
763 		error_print();
764 		tls_send_alert(conn, TLS_alert_insufficient_security);
765 		goto end;
766 	}
767 	if (client_exts) {
768 		server_exts_len = 0;
769 		curve = TLS_curve_sm2p256v1;
770 
771 		tls_process_client_hello_exts(client_exts, client_exts_len, server_exts, &server_exts_len, sizeof(server_exts));
772 
773 
774 
775 	}
776 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
777 	if (client_verify)
778 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
779 
780 
781 	// send ServerHello
782 	tls_trace("send ServerHello\n");
783 	tls_random_generate(server_random);
784 	tls_record_set_protocol(record, conn->protocol);
785 	if (tls_record_set_handshake_server_hello(record, &recordlen,
786 		conn->protocol, server_random, NULL, 0,
787 		conn->cipher_suite, server_exts, server_exts_len) != 1) {
788 		error_print();
789 		tls_send_alert(conn, TLS_alert_internal_error);
790 		goto end;
791 	}
792 	tls12_record_trace(stderr, record, recordlen, 0, 0);
793 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
794 		error_print();
795 		goto end;
796 	}
797 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
798 	if (client_verify)
799 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
800 
801 	// send ServerCertificate
802 	tls_trace("send ServerCertificate\n");
803 	if (tls_record_set_handshake_certificate(record, &recordlen,
804 		conn->server_certs, conn->server_certs_len) != 1) {
805 		error_print();
806 		tls_send_alert(conn, TLS_alert_internal_error);
807 		goto end;
808 	}
809 	tls12_record_trace(stderr, record, recordlen, 0, 0);
810 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
811 		error_print();
812 		goto end;
813 	}
814 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
815 	if (client_verify)
816 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
817 
818 	// send ServerKeyExchange
819 	tls_trace("send ServerKeyExchange\n");
820 	sm2_key_generate(&server_ecdhe_key);
821 	if (tls_sign_server_ecdh_params(&conn->sign_key,
822 		client_random, server_random, TLS_curve_sm2p256v1, &server_ecdhe_key.public_key,
823 		sigbuf, &siglen) != 1) {
824 		error_print();
825 		tls_send_alert(conn, TLS_alert_internal_error);
826 		return -1;
827 	}
828 	if (tls_record_set_handshake_server_key_exchange_ecdhe(record, &recordlen,
829 		curve, &server_ecdhe_key.public_key, sigbuf, siglen) != 1) {
830 		error_print();
831 		tls_send_alert(conn, TLS_alert_internal_error);
832 		goto end;
833 	}
834 	tls12_record_trace(stderr, record, recordlen, 0, 0);
835 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
836 		error_print();
837 		goto end;
838 	}
839 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
840 	if (client_verify)
841 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
842 
843 	// send CertificateRequest
844 	if (client_verify) {
845 		const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign };
846 		uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; // TODO: 根据客户端验证CA证书列计算缓冲大小,或直接输出到record缓冲
847 		size_t ca_names_len = 0;
848 
849 		tls_trace("send CertificateRequest\n");
850 		if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names),
851 			conn->ca_certs, conn->ca_certs_len) != 1) {
852 			error_print();
853 			goto end;
854 		}
855 		if (tls_record_set_handshake_certificate_request(record, &recordlen,
856 			cert_types, sizeof(cert_types),
857 			ca_names, ca_names_len) != 1) {
858 			error_print();
859 			goto end;
860 		}
861 		tls12_record_trace(stderr, record, recordlen, 0, 0);
862 		if (tls_record_send(record, recordlen, conn->sock) != 1) {
863 			error_print();
864 			goto end;
865 		}
866 		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
867 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
868 	}
869 
870 	// send ServerHelloDone
871 	tls_trace("send ServerHelloDone\n");
872 	tls_record_set_handshake_server_hello_done(record, &recordlen);
873 	tls12_record_trace(stderr, record, recordlen, 0, 0);
874 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
875 		error_print();
876 		goto end;
877 	}
878 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
879 	if (client_verify)
880 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
881 
882 	// recv ClientCertificate
883 	if (conn->ca_certs_len) {
884 		tls_trace("recv ClientCertificate\n");
885 		if (tls_record_recv(record, &recordlen, conn->sock) != 1
886 			|| tls_record_protocol(record) != conn->protocol) { // protocol检查应该在trace之后
887 			error_print();
888 			tls_send_alert(conn, TLS_alert_unexpected_message);
889 			goto end;
890 		}
891 		tls12_record_trace(stderr, record, recordlen, 0, 0);
892 		if (tls_record_get_handshake_certificate(record, conn->client_certs, &conn->client_certs_len) != 1) {
893 			error_print();
894 			tls_send_alert(conn, TLS_alert_unexpected_message);
895 			goto end;
896 		}
897 		if (x509_certs_verify(conn->client_certs, conn->client_certs_len,
898 			conn->ca_certs, conn->ca_certs_len, verify_depth, &verify_result) != 1) {
899 			error_print();
900 			tls_send_alert(conn, TLS_alert_bad_certificate);
901 			goto end;
902 		}
903 		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
904 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
905 	}
906 
907 	// recv ClientKeyExchange
908 	tls_trace("recv ClientKeyExchange\n");
909 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
910 		|| tls_record_protocol(record) != conn->protocol) {
911 		error_print();
912 		tls_send_alert(conn, TLS_alert_unexpected_message);
913 		goto end;
914 	}
915 	tls12_record_trace(stderr, record, recordlen, 0, 0); // 应该给tls12一个独立的trace
916 	if (tls_record_get_handshake_client_key_exchange_ecdhe(record, &client_ecdhe_point) != 1) {
917 		error_print();
918 		tls_send_alert(conn, TLS_alert_unexpected_message);
919 		goto end;
920 	}
921 
922 	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
923 	if (client_verify)
924 		tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
925 
926 	// recv CertificateVerify
927 	if (client_verify) {
928 		tls_trace("recv CertificateVerify\n");
929 		if (tls_record_recv(record, &recordlen, conn->sock) != 1
930 			|| tls_record_protocol(record) != conn->protocol) {
931 			tls_send_alert(conn, TLS_alert_unexpected_message);
932 			error_print();
933 			goto end;
934 		}
935 		tls12_record_trace(stderr, record, recordlen, 0, 0);
936 		if (tls_record_get_handshake_certificate_verify(record, &sig, &siglen) != 1) {
937 			tls_send_alert(conn, TLS_alert_unexpected_message);
938 			error_print();
939 			goto end;
940 		}
941 		if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0, &cp, &len) != 1
942 			|| x509_cert_get_subject_public_key(cp, len, &client_sign_key) != 1) {
943 			error_print();
944 			tls_send_alert(conn, TLS_alert_bad_certificate);
945 			goto end;
946 		}
947 		if (tls_client_verify_finish(&client_verify_ctx, sig, siglen, &client_sign_key) != 1) {
948 			error_print();
949 			tls_send_alert(conn, TLS_alert_decrypt_error);
950 			goto end;
951 		}
952 		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
953 	}
954 
955 	// generate secrets
956 	tls_trace("generate secrets\n");
957 	sm2_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point);
958 	memcpy(pre_master_secret, (uint8_t *)&client_ecdhe_point, 32); // 这里应该修改一下表示方式,比如get_xy()
959 	tls_prf(pre_master_secret, 32, "master secret",
960 		client_random, 32, server_random, 32,
961 		48, conn->master_secret);
962 	tls_prf(conn->master_secret, 48, "key expansion",
963 		server_random, 32, client_random, 32,
964 		96, conn->key_block);
965 	sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
966 	sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
967 	sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
968 	sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
969 	/*
970 	tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random,
971 		conn->master_secret, conn->key_block, 96, 0, 4);
972 	*/
973 
974 	// recv [ChangeCipherSpec]
975 	tls_trace("recv [ChangeCipherSpec]\n");
976 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
977 		|| tls_record_protocol(record) != conn->protocol) {
978 		error_print();
979 		tls_send_alert(conn, TLS_alert_unexpected_message);
980 		goto end;
981 	}
982 	tls12_record_trace(stderr, record, recordlen, 0, 0);
983 	if (tls_record_get_change_cipher_spec(record) != 1) {
984 		error_print();
985 		tls_send_alert(conn, TLS_alert_unexpected_message);
986 		goto end;
987 	}
988 
989 	// recv ClientFinished
990 	tls_trace("recv Finished\n");
991 	if (tls_record_recv(record, &recordlen, conn->sock) != 1
992 		|| tls_record_protocol(record) != conn->protocol) {
993 		error_print();
994 		tls_send_alert(conn, TLS_alert_unexpected_message);
995 		goto end;
996 	}
997 	if (recordlen > sizeof(finished_record)) {
998 		error_print();
999 		tls_send_alert(conn, TLS_alert_unexpected_message);
1000 		goto end;
1001 	}
1002 	tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
1003 
1004 	// decrypt ClientFinished
1005 	tls_trace("decrypt Finished\n");
1006 	if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
1007 		conn->client_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
1008 		error_print();
1009 		tls_send_alert(conn, TLS_alert_bad_record_mac);
1010 		goto end;
1011 	}
1012 	tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
1013 	tls_seq_num_incr(conn->client_seq_num);
1014 	if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
1015 		error_print();
1016 		tls_send_alert(conn, TLS_alert_bad_record_mac);
1017 		goto end;
1018 	}
1019 	if (verify_data_len != sizeof(local_verify_data)) {
1020 		error_print();
1021 		tls_send_alert(conn, TLS_alert_bad_record_mac);
1022 		goto end;
1023 	}
1024 
1025 	// verify ClientFinished
1026 	memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(SM3_CTX));
1027 	sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);
1028 	sm3_finish(&tmp_sm3_ctx, sm3_hash);
1029 	if (tls_prf(conn->master_secret, 48, "client finished", sm3_hash, 32, NULL, 0,
1030 		sizeof(local_verify_data), local_verify_data) != 1) {
1031 		error_print();
1032 		tls_send_alert(conn, TLS_alert_internal_error);
1033 		goto end;
1034 	}
1035 	if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
1036 		error_puts("client_finished.verify_data verification failure");
1037 		tls_send_alert(conn, TLS_alert_decrypt_error);
1038 		goto end;
1039 	}
1040 
1041 	// send [ChangeCipherSpec]
1042 	tls_trace("send [ChangeCipherSpec]\n");
1043 	if (tls_record_set_change_cipher_spec(record, &recordlen) != 1) {
1044 		error_print();
1045 		tls_send_alert(conn, TLS_alert_internal_error);
1046 		goto end;
1047 	}
1048 	tls12_record_trace(stderr, record, recordlen, 0, 0);
1049 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
1050 		error_print();
1051 		goto end;
1052 	}
1053 
1054 	// send ServerFinished
1055 	tls_trace("send Finished\n");
1056 	sm3_finish(&sm3_ctx, sm3_hash);
1057 	if (tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0,
1058 			sizeof(local_verify_data), local_verify_data) != 1
1059 		|| tls_record_set_handshake_finished(finished_record, &finished_record_len,
1060 			local_verify_data, sizeof(local_verify_data)) != 1) {
1061 		error_print();
1062 		tls_send_alert(conn, TLS_alert_internal_error);
1063 		goto end;
1064 	}
1065 	tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
1066 	if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
1067 		conn->server_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
1068 		error_print();
1069 		tls_send_alert(conn, TLS_alert_internal_error);
1070 		goto end;
1071 	}
1072 	tls_trace("encrypt Finished\n");
1073 	tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
1074 	tls_seq_num_incr(conn->server_seq_num);
1075 	if (tls_record_send(record, recordlen, conn->sock) != 1) {
1076 		error_print();
1077 		goto end;
1078 	}
1079 
1080 	conn->protocol = conn->protocol;
1081 
1082 	fprintf(stderr, "Connection Established!\n\n");
1083 	ret = 1;
1084 
1085 end:
1086 	gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
1087 	gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
1088 	if (client_verify) tls_client_verify_cleanup(&client_verify_ctx);
1089 	return ret;
1090 }
1091