1 /*
2 * Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the License); you may
5 * not use this file except in compliance with the License.
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 */
9
10
11
12 #include <time.h>
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <string.h>
16 #include <unistd.h>
17 #include <fcntl.h>
18 #include <sys/types.h>
19 #include <arpa/inet.h>
20 #include <sys/socket.h>
21 #include <netinet/in.h>
22 #include <gmssl/rand.h>
23 #include <gmssl/x509.h>
24 #include <gmssl/error.h>
25 #include <gmssl/sm2.h>
26 #include <gmssl/sm3.h>
27 #include <gmssl/sm4.h>
28 #include <gmssl/pem.h>
29 #include <gmssl/mem.h>
30 #include <gmssl/tls.h>
31
32
33
34 static const int tls12_ciphers[] = {
35 TLS_cipher_ecdhe_sm4_cbc_sm3,
36 };
37
38 static const size_t tls12_ciphers_count = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]);
39
40 static const uint8_t tls12_exts[] = {
41 /* supported_groups */ 0x00,0x0A, 0x00,0x04, 0x00,0x02, 0x00,30,//0x29, // curveSM2
42 /* ec_point_formats */ 0x00,0x0B, 0x00,0x02, 0x01, 0x00, // uncompressed
43 /* signature_algors */ 0x00,0x0D, 0x00,0x04, 0x00,0x02, 0x07,0x07,//0x08, // sm2sig_sm3
44 };
45
46
tls12_record_print(FILE * fp,const uint8_t * record,size_t recordlen,int format,int indent)47 int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent)
48 {
49 // 目前只支持TLCP的ECC公钥加密套件,因此不论用哪个套件解析都是一样的
50 // 如果未来支持ECDHE套件,可以将函数改为宏,直接传入 (conn->cipher_suite << 8)
51 format |= tls12_ciphers[0] << 8;
52 return tls_record_print(fp, record, recordlen, format, indent);
53 }
54
55
tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t * record,size_t * recordlen,int curve,const SM2_POINT * point,const uint8_t * sig,size_t siglen)56 int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t *recordlen,
57 int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen)
58 {
59 int type = TLS_handshake_server_key_exchange;
60 uint8_t *server_ecdh_params = record + 9;
61 uint8_t *p = server_ecdh_params + 69;
62 size_t len = 69;
63
64 if (!record || !recordlen || !tls_named_curve_name(curve) || !point
65 || !sig || !siglen || siglen > TLS_MAX_SIGNATURE_SIZE) {
66 error_print();
67 return -1;
68 }
69 server_ecdh_params[0] = TLS_curve_type_named_curve;
70 server_ecdh_params[1] = curve >> 8;
71 server_ecdh_params[2] = curve;
72 server_ecdh_params[3] = 65;
73 sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
74 tls_uint16_to_bytes(TLS_sig_sm2sig_sm3, &p, &len);
75 tls_uint16array_to_bytes(sig, siglen, &p, &len);
76 tls_record_set_handshake(record, recordlen, type, NULL, len);
77 return 1;
78 }
79
80 // 这里返回的应该是一个SM2_POINT吗?
tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t * record,int * curve,SM2_POINT * point,const uint8_t ** sig,size_t * siglen)81 int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record,
82 int *curve, SM2_POINT *point, const uint8_t **sig, size_t *siglen)
83 {
84 int type;
85 const uint8_t *p;
86 size_t len;
87 uint8_t curve_type;
88 uint16_t named_curve;
89 const uint8_t *octets;
90 size_t octetslen;
91 uint16_t sig_alg;
92
93 if (!record || !curve || !point || !sig || !siglen) {
94 error_print();
95 return -1;
96 }
97 if (tls_record_get_handshake(record, &type, &p, &len) != 1
98 || type != TLS_handshake_server_key_exchange) {
99 error_print();
100 return -1;
101 }
102 if (tls_uint8_from_bytes(&curve_type, &p, &len) != 1
103 || tls_uint16_from_bytes(&named_curve, &p, &len) != 1
104 || tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1
105 || tls_uint16_from_bytes(&sig_alg, &p, &len) != 1
106 || tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1
107 || tls_length_is_zero(len) != 1) {
108 error_print();
109 return -1;
110 }
111 if (curve_type != TLS_curve_type_named_curve) {
112 error_print();
113 return -1;
114 }
115 if (named_curve != TLS_curve_sm2p256v1) {
116 error_print();
117 return -1;
118 }
119 *curve = named_curve;
120 if (octetslen != 65
121 || sm2_point_from_octets(point, octets, octetslen) != 1) {
122 error_print();
123 return -1;
124 }
125 if (sig_alg != TLS_sig_sm2sig_sm3) {
126 error_print();
127 return -1;
128 }
129 return 1;
130 }
131
tls_record_set_handshake_client_key_exchange_ecdhe(uint8_t * record,size_t * recordlen,const SM2_POINT * point)132 int tls_record_set_handshake_client_key_exchange_ecdhe(uint8_t *record, size_t *recordlen,
133 const SM2_POINT *point)
134 {
135 int type = TLS_handshake_client_key_exchange;
136 record[9] = 65;
137 sm2_point_to_uncompressed_octets(point, record + 9 + 1);
138 tls_record_set_handshake(record, recordlen, type, NULL, 1 + 65);
139 return 1;
140 }
141
/*
 * Parse an ECDHE ClientKeyExchange handshake record into the client's
 * ephemeral public point.
 *
 * The body must be a single uint8-length-prefixed 65-byte point with no
 * trailing bytes. Returns 1 on success, -1 on NULL arguments or malformed input.
 */
int tls_record_get_handshake_client_key_exchange_ecdhe(const uint8_t *record, SM2_POINT *point)
{
	int type;
	const uint8_t *p;
	size_t len;
	const uint8_t *octets;
	size_t octetslen;

	if (!record || !point) {
		error_print();
		return -1;
	}
	if (tls_record_get_handshake(record, &type, &p, &len) != 1
		|| type != TLS_handshake_client_key_exchange) {
		error_print();
		return -1;
	}
	// Same "fully consumed" check as the other handshake parsers in this file.
	if (tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1
		|| tls_length_is_zero(len) != 1) {
		error_print();
		return -1;
	}
	if (octetslen != 65
		|| sm2_point_from_octets(point, octets, octetslen) != 1) {
		error_print();
		return -1;
	}
	return 1;
}
167
168 /*
169 Client Server
170
171 ClientHello -------->
172 ServerHello
173 Certificate
174 ServerKeyExchange
175 CertificateRequest*
176 <-------- ServerHelloDone
177 Certificate*
178 ClientKeyExchange
179 CertificateVerify*
180 [ChangeCipherSpec]
181 Finished -------->
182 [ChangeCipherSpec]
183 <-------- Finished
184 Application Data <-------> Application Data
185
186
187 */
188
/*
 * Run the client side of a TLS 1.2 ECDHE-SM2 handshake over conn->sock.
 *
 * Message flow (* = only when the server sends CertificateRequest):
 *   -> ClientHello
 *   <- ServerHello, Certificate, ServerKeyExchange, [CertificateRequest*], ServerHelloDone
 *   -> [Certificate*], ClientKeyExchange, [CertificateVerify*], ChangeCipherSpec, Finished
 *   <- ChangeCipherSpec, Finished
 *
 * On success the negotiated cipher suite, session id, master secret, key block
 * and record-protection contexts are stored in conn. If the server does not
 * request a client certificate, conn->client_certs_len is reset to 0 and
 * conn->sign_key is wiped.
 *
 * Returns 1 when the handshake completes and -1 on any failure. An alert is
 * sent to the peer for most protocol errors.
 */
int tls12_do_connect(TLS_CONNECT *conn)
{
	int ret = -1;
	uint8_t *record = conn->record;
	uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
	size_t recordlen, finished_record_len;

	// ClientHello / ServerHello
	uint8_t client_random[32];
	uint8_t server_random[32];
	int protocol;
	int cipher_suite;
	const uint8_t *random;
	const uint8_t *session_id;
	size_t session_id_len;

	// Extensions offered by the client
	int ec_point_formats[] = { TLS_point_uncompressed };
	size_t ec_point_formats_cnt = 1;
	int supported_groups[] = { TLS_curve_sm2p256v1 };
	size_t supported_groups_cnt = 1;
	int signature_algors[] = { TLS_sig_sm2sig_sm3 };
	size_t signature_algors_cnt = 1;
	uint8_t client_exts[TLS_MAX_EXTENSIONS_SIZE];
	size_t client_exts_len = 0;
	const uint8_t *server_exts;
	size_t server_exts_len;

	// Extension negotiation results; -1 means the server did not answer that extension
	int ec_point_format = -1;
	int supported_group = -1;
	int signature_algor = -1;

	// Server authentication and ECDHE key agreement
	const int depth = 5;
	int verify_result;
	int curve;
	SM2_POINT server_ecdhe_public;
	SM2_KEY server_sign_key;
	SM2_KEY client_ecdh;            // ephemeral key pair, wiped at the end
	SM2_POINT shared_point;         // ECDH result, wiped at the end
	uint8_t pre_master_secret[32];  // x-coordinate of shared_point
	const uint8_t *sig;
	size_t siglen;

	// Client authentication (used only when conn->client_certs_len != 0)
	SM2_SIGN_CTX sign_ctx;

	// Finished
	SM3_CTX sm3_ctx;
	SM3_CTX tmp_sm3_ctx;
	uint8_t sm3_hash[32];
	const uint8_t *verify_data;
	size_t verify_data_len;
	uint8_t local_verify_data[12];

	int handshake_type;
	const uint8_t *cp;
	uint8_t *p;
	size_t len;


	// Initialize the record buffers. The ClientHello record-layer version is TLS 1.0.
	tls_record_set_protocol(record, TLS_protocol_tls1);
	tls_record_set_protocol(finished_record, conn->protocol);

	// Transcript hash for Finished, plus the CertificateVerify signing context
	// when we hold a client certificate.
	sm3_init(&sm3_ctx);
	if (conn->client_certs_len)
		sm2_sign_init(&sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);


	// send ClientHello
	if (tls_random_generate(client_random) != 1) {
		error_print();
		goto end;
	}
	p = client_exts;
	client_exts_len = 0;
	if (tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &p, &client_exts_len) != 1
		|| tls_supported_groups_ext_to_bytes(supported_groups, supported_groups_cnt, &p, &client_exts_len) != 1
		|| tls_signature_algorithms_ext_to_bytes(signature_algors, signature_algors_cnt, &p, &client_exts_len) != 1) {
		error_print();
		goto end;
	}
	if (tls_record_set_handshake_client_hello(record, &recordlen,
		conn->protocol, client_random, NULL, 0,
		tls12_ciphers, tls12_ciphers_count,
		client_exts, client_exts_len) != 1) {
		error_print();
		goto end;
	}
	tls_trace("send ClientHello\n");
	tls12_record_trace(stderr, record, recordlen, 0, 0);
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		goto end;
	}
	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

	// recv ServerHello
	tls_trace("recv ServerHello\n");
	if (tls_record_recv(record, &recordlen, conn->sock) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, 0, 0);
	if (tls_record_protocol(record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		goto end;
	}
	if (tls_record_get_handshake_server_hello(record,
		&protocol, &random, &session_id, &session_id_len, &cipher_suite,
		&server_exts, &server_exts_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	if (protocol != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_protocol_version);
		goto end;
	}
	// TODO: tls12_ciphers should become a per-connection setting in conn.
	if (tls_cipher_suite_in_list(cipher_suite, tls12_ciphers, tls12_ciphers_count) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_handshake_failure);
		goto end;
	}
	if (!server_exts) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	if (tls_process_server_hello_exts(server_exts, server_exts_len, &ec_point_format, &supported_group, &signature_algor) != 1
		|| ec_point_format < 0
		|| supported_group < 0
		|| signature_algor < 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	memcpy(server_random, random, 32);
	// NOTE(review): relies on tls_record_get_handshake_server_hello() bounding
	// session_id_len by the size of conn->session_id -- confirm.
	memcpy(conn->session_id, session_id, session_id_len);
	conn->cipher_suite = cipher_suite;
	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

	// recv ServerCertificate
	tls_trace("recv ServerCertificate\n");
	if (tls_record_recv(record, &recordlen, conn->sock) != 1
		|| tls_record_protocol(record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, 0, 0);

	if (tls_record_get_handshake_certificate(record,
		conn->server_certs, &conn->server_certs_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

	// verify ServerCertificate
	if (x509_certs_verify(conn->server_certs, conn->server_certs_len,
		conn->ca_certs, conn->ca_certs_len, depth, &verify_result) != 1) {
		error_print();
		// Previously sent alert 0 (close_notify); a chain failure is bad_certificate.
		tls_send_alert(conn, TLS_alert_bad_certificate);
		goto end;
	}

	// recv ServerKeyExchange
	tls_trace("recv ServerKeyExchange\n");
	if (tls_record_recv(record, &recordlen, conn->sock) != 1
		|| tls_record_protocol(record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, 0, 0);

	if (tls_record_get_handshake_server_key_exchange_ecdhe(record, &curve, &server_ecdhe_public, &sig, &siglen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	if (curve != TLS_curve_sm2p256v1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

	// verify ServerKeyExchange with the public key of the server's first certificate
	if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, &cp, &len) != 1
		|| x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_certificate);
		goto end;
	}
	if (tls_verify_server_ecdh_params(&server_sign_key, // must be the server's signing public key
		client_random, server_random, curve, &server_ecdhe_public, sig, siglen) != 1) {
		error_print();
		// A failed handshake signature check is decrypt_error (RFC 5246, 7.2.2).
		tls_send_alert(conn, TLS_alert_decrypt_error);
		goto end;
	}

	// recv CertificateRequest or ServerHelloDone
	if (tls_record_recv(record, &recordlen, conn->sock) != 1
		|| tls_record_protocol(record) != conn->protocol
		|| tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	if (handshake_type == TLS_handshake_certificate_request) {
		const uint8_t *cert_types;
		size_t cert_types_len;
		const uint8_t *ca_names;
		size_t ca_names_len;

		// recv CertificateRequest
		tls_trace("recv CertificateRequest\n");
		tls12_record_trace(stderr, record, recordlen, 0, 0);
		if (tls_record_get_handshake_certificate_request(record,
			&cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_unexpected_message);
			goto end;
		}
		if (!conn->client_certs_len) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			goto end;
		}
		if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1
			|| tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_unsupported_certificate);
			goto end;
		}
		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

		// recv ServerHelloDone
		if (tls_record_recv(record, &recordlen, conn->sock) != 1
			|| tls_record_protocol(record) != conn->protocol) {
			error_print();
			tls_send_alert(conn, TLS_alert_unexpected_message);
			goto end;
		}
	} else {
		// No client authentication requested: drop the client credentials.
		// TODO: revisit -- this mutates the caller's connection settings.
		conn->client_certs_len = 0;
		gmssl_secure_clear(&conn->sign_key, sizeof(SM2_KEY));
	}
	tls_trace("recv ServerHelloDone\n");
	tls12_record_trace(stderr, record, recordlen, 0, 0);
	if (tls_record_get_handshake_server_hello_done(record) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

	// send ClientCertificate
	if (conn->client_certs_len) {
		tls_trace("send ClientCertificate\n");
		if (tls_record_set_handshake_certificate(record, &recordlen, conn->client_certs, conn->client_certs_len) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			goto end;
		}
		tls12_record_trace(stderr, record, recordlen, 0, 0);
		if (tls_record_send(record, recordlen, conn->sock) != 1) {
			error_print();
			goto end;
		}
		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
	}

	// generate MASTER_SECRET
	// For ECDHE the pre-master secret is the x-coordinate of the ECDH shared point.
	tls_trace("generate secrets\n");
	if (sm2_key_generate(&client_ecdh) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	if (sm2_ecdh(&client_ecdh, &server_ecdhe_public, &shared_point) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_handshake_failure);
		goto end;
	}
	memcpy(pre_master_secret, shared_point.x, sizeof(pre_master_secret));

	if (tls_prf(pre_master_secret, sizeof(pre_master_secret), "master secret",
			client_random, 32, server_random, 32,
			48, conn->master_secret) != 1
		|| tls_prf(conn->master_secret, 48, "key expansion",
			server_random, 32, client_random, 32,
			96, conn->key_block) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	// key_block: client MAC key (32) | server MAC key (32) | client enc key (16) | server enc key (16)
	sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
	sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
	sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
	sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80);

	// send ClientKeyExchange
	tls_trace("send ClientKeyExchange\n");
	if (tls_record_set_handshake_client_key_exchange_ecdhe(record, &recordlen, &client_ecdh.public_key) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, 0, 0);
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		goto end;
	}
	sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	if (conn->client_certs_len)
		sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);

	// send CertificateVerify
	if (conn->client_certs_len) {
		uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];

		tls_trace("send CertificateVerify\n");
		if (sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1
			|| tls_record_set_handshake_certificate_verify(record, &recordlen, sigbuf, siglen) != 1) {
			error_print();
			tls_send_alert(conn, TLS_alert_internal_error);
			goto end;
		}
		tls12_record_trace(stderr, record, recordlen, 0, 0);
		if (tls_record_send(record, recordlen, conn->sock) != 1) {
			error_print();
			goto end;
		}
		sm3_update(&sm3_ctx, record + 5, recordlen - 5);
	}

	// send [ChangeCipherSpec]
	tls_trace("send [ChangeCipherSpec]\n");
	if (tls_record_set_change_cipher_spec(record, &recordlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, 0, 0);
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		goto end;
	}

	// send Client Finished; hash a copy so sm3_ctx can continue for the server Finished
	tls_trace("send Finished\n");
	memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(sm3_ctx));
	sm3_finish(&tmp_sm3_ctx, sm3_hash);
	if (tls_prf(conn->master_secret, 48, "client finished",
			sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1
		|| tls_record_set_handshake_finished(finished_record, &finished_record_len,
			local_verify_data, sizeof(local_verify_data)) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
	sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);

	// encrypt Client Finished
	tls_trace("encrypt Finished\n");
	if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
		conn->client_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // force printing the raw ciphertext
	tls_seq_num_incr(conn->client_seq_num);
	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		goto end;
	}

	// recv [ChangeCipherSpec]
	tls_trace("recv [ChangeCipherSpec]\n");
	if (tls_record_recv(record, &recordlen, conn->sock) != 1
		|| tls_record_protocol(record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, 0, 0);
	if (tls_record_get_change_cipher_spec(record) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}

	// recv Server Finished
	tls_trace("recv Finished\n");
	if (tls_record_recv(record, &recordlen, conn->sock) != 1
		|| tls_record_protocol(record) != conn->protocol) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	if (recordlen > sizeof(finished_record)) {
		error_print(); // decrypting this record could overflow finished_record
		tls_send_alert(conn, TLS_alert_bad_record_mac);
		goto end;
	}
	tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // force printing the raw ciphertext
	tls_trace("decrypt Finished\n");
	if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
		conn->server_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_bad_record_mac);
		goto end;
	}
	tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
	tls_seq_num_incr(conn->server_seq_num);
	if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	if (verify_data_len != sizeof(local_verify_data)) {
		error_print();
		tls_send_alert(conn, TLS_alert_unexpected_message);
		goto end;
	}
	sm3_finish(&sm3_ctx, sm3_hash);
	if (tls_prf(conn->master_secret, 48, "server finished",
		sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) {
		error_print();
		tls_send_alert(conn, TLS_alert_internal_error);
		goto end;
	}
	if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
		error_print();
		tls_send_alert(conn, TLS_alert_decrypt_error);
		goto end;
	}
	fprintf(stderr, "Connection established!\n");

	ret = 1;

end:
	// Wipe key material on every path; these objects may be unset on early exits,
	// which is harmless because clearing only writes to them.
	gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
	gmssl_secure_clear(&client_ecdh, sizeof(client_ecdh));
	gmssl_secure_clear(&shared_point, sizeof(shared_point));
	gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
	return ret;
}
662
tls12_do_accept(TLS_CONNECT * conn)663 int tls12_do_accept(TLS_CONNECT *conn)
664 {
665 int ret = -1;
666
667 int client_verify = 0;
668
669 uint8_t *record = conn->record;
670 uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; // 解密可能导致前面的record被覆盖
671 size_t recordlen, finished_record_len;
672
673 // 这个ciphers不是应该在CTX中设置的吗
674 const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件
675
676 // ClientHello, ServerHello
677 uint8_t client_random[32];
678 uint8_t server_random[32];
679 int protocol;
680 const uint8_t *random;
681 const uint8_t *session_id; // TLCP服务器忽略客户端SessionID,也不主动设置SessionID
682 size_t session_id_len;
683 const uint8_t *client_ciphers;
684 size_t client_ciphers_len;
685 const uint8_t *client_exts;
686 size_t client_exts_len;
687 uint8_t server_exts[TLS_MAX_EXTENSIONS_SIZE];
688 size_t server_exts_len;
689 int curve = TLS_curve_sm2p256v1; // 这个是否应该在conn中设置?
690
691 // ServerKeyExchange
692 SM2_KEY server_ecdhe_key;
693 SM2_SIGN_CTX sign_ctx;
694 uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
695 size_t siglen;
696
697 // ClientCertificate, CertificateVerify
698 TLS_CLIENT_VERIFY_CTX client_verify_ctx;
699 SM2_KEY client_sign_key;
700 const uint8_t *sig;
701 const int verify_depth = 5;
702 int verify_result;
703
704 // ClientKeyExchange
705 SM2_POINT client_ecdhe_point;
706 uint8_t pre_master_secret[SM2_MAX_PLAINTEXT_SIZE]; // sm2_decrypt 保证输出不会溢出
707 size_t pre_master_secret_len;
708
709 // Finished
710 SM3_CTX sm3_ctx;
711 SM3_CTX tmp_sm3_ctx;
712 uint8_t sm3_hash[32];
713 uint8_t local_verify_data[12];
714 const uint8_t *verify_data;
715 size_t verify_data_len;
716
717 uint8_t *p;
718 const uint8_t *cp;
719 size_t len;
720
721
722 // 服务器端如果设置了CA
723 if (conn->ca_certs_len)
724 client_verify = 1;
725
726 // 初始化Finished和客户端验证环境
727 sm3_init(&sm3_ctx);
728 if (client_verify)
729 tls_client_verify_init(&client_verify_ctx);
730
731
732 // recv ClientHello
733 tls_trace("recv ClientHello\n");
734 if (tls_record_recv(record, &recordlen, conn->sock) != 1) {
735 error_print();
736 tls_send_alert(conn, TLS_alert_unexpected_message);
737 goto end;
738 }
739 tls12_record_trace(stderr, record, recordlen, 0, 0);
740 if (tls_record_protocol(record) != conn->protocol
741 && tls_record_protocol(record) != TLS_protocol_tls1) {
742 error_print();
743 tls_send_alert(conn, TLS_alert_protocol_version);
744 goto end;
745 }
746 if (tls_record_get_handshake_client_hello(record,
747 &protocol, &random, &session_id, &session_id_len,
748 &client_ciphers, &client_ciphers_len,
749 &client_exts, &client_exts_len) != 1) {
750 error_print();
751 tls_send_alert(conn, TLS_alert_unexpected_message);
752 goto end;
753 }
754 if (protocol != conn->protocol) {
755 error_print();
756 tls_send_alert(conn, TLS_alert_protocol_version);
757 goto end;
758 }
759 memcpy(client_random, random, 32);
760 if (tls_cipher_suites_select(client_ciphers, client_ciphers_len,
761 server_ciphers, sizeof(server_ciphers)/sizeof(server_ciphers[0]),
762 &conn->cipher_suite) != 1) {
763 error_print();
764 tls_send_alert(conn, TLS_alert_insufficient_security);
765 goto end;
766 }
767 if (client_exts) {
768 server_exts_len = 0;
769 curve = TLS_curve_sm2p256v1;
770
771 tls_process_client_hello_exts(client_exts, client_exts_len, server_exts, &server_exts_len, sizeof(server_exts));
772
773
774
775 }
776 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
777 if (client_verify)
778 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
779
780
781 // send ServerHello
782 tls_trace("send ServerHello\n");
783 tls_random_generate(server_random);
784 tls_record_set_protocol(record, conn->protocol);
785 if (tls_record_set_handshake_server_hello(record, &recordlen,
786 conn->protocol, server_random, NULL, 0,
787 conn->cipher_suite, server_exts, server_exts_len) != 1) {
788 error_print();
789 tls_send_alert(conn, TLS_alert_internal_error);
790 goto end;
791 }
792 tls12_record_trace(stderr, record, recordlen, 0, 0);
793 if (tls_record_send(record, recordlen, conn->sock) != 1) {
794 error_print();
795 goto end;
796 }
797 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
798 if (client_verify)
799 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
800
801 // send ServerCertificate
802 tls_trace("send ServerCertificate\n");
803 if (tls_record_set_handshake_certificate(record, &recordlen,
804 conn->server_certs, conn->server_certs_len) != 1) {
805 error_print();
806 tls_send_alert(conn, TLS_alert_internal_error);
807 goto end;
808 }
809 tls12_record_trace(stderr, record, recordlen, 0, 0);
810 if (tls_record_send(record, recordlen, conn->sock) != 1) {
811 error_print();
812 goto end;
813 }
814 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
815 if (client_verify)
816 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
817
818 // send ServerKeyExchange
819 tls_trace("send ServerKeyExchange\n");
820 sm2_key_generate(&server_ecdhe_key);
821 if (tls_sign_server_ecdh_params(&conn->sign_key,
822 client_random, server_random, TLS_curve_sm2p256v1, &server_ecdhe_key.public_key,
823 sigbuf, &siglen) != 1) {
824 error_print();
825 tls_send_alert(conn, TLS_alert_internal_error);
826 return -1;
827 }
828 if (tls_record_set_handshake_server_key_exchange_ecdhe(record, &recordlen,
829 curve, &server_ecdhe_key.public_key, sigbuf, siglen) != 1) {
830 error_print();
831 tls_send_alert(conn, TLS_alert_internal_error);
832 goto end;
833 }
834 tls12_record_trace(stderr, record, recordlen, 0, 0);
835 if (tls_record_send(record, recordlen, conn->sock) != 1) {
836 error_print();
837 goto end;
838 }
839 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
840 if (client_verify)
841 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
842
843 // send CertificateRequest
844 if (client_verify) {
845 const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign };
846 uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; // TODO: 根据客户端验证CA证书列计算缓冲大小,或直接输出到record缓冲
847 size_t ca_names_len = 0;
848
849 tls_trace("send CertificateRequest\n");
850 if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names),
851 conn->ca_certs, conn->ca_certs_len) != 1) {
852 error_print();
853 goto end;
854 }
855 if (tls_record_set_handshake_certificate_request(record, &recordlen,
856 cert_types, sizeof(cert_types),
857 ca_names, ca_names_len) != 1) {
858 error_print();
859 goto end;
860 }
861 tls12_record_trace(stderr, record, recordlen, 0, 0);
862 if (tls_record_send(record, recordlen, conn->sock) != 1) {
863 error_print();
864 goto end;
865 }
866 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
867 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
868 }
869
870 // send ServerHelloDone
871 tls_trace("send ServerHelloDone\n");
872 tls_record_set_handshake_server_hello_done(record, &recordlen);
873 tls12_record_trace(stderr, record, recordlen, 0, 0);
874 if (tls_record_send(record, recordlen, conn->sock) != 1) {
875 error_print();
876 goto end;
877 }
878 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
879 if (client_verify)
880 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
881
882 // recv ClientCertificate
883 if (conn->ca_certs_len) {
884 tls_trace("recv ClientCertificate\n");
885 if (tls_record_recv(record, &recordlen, conn->sock) != 1
886 || tls_record_protocol(record) != conn->protocol) { // protocol检查应该在trace之后
887 error_print();
888 tls_send_alert(conn, TLS_alert_unexpected_message);
889 goto end;
890 }
891 tls12_record_trace(stderr, record, recordlen, 0, 0);
892 if (tls_record_get_handshake_certificate(record, conn->client_certs, &conn->client_certs_len) != 1) {
893 error_print();
894 tls_send_alert(conn, TLS_alert_unexpected_message);
895 goto end;
896 }
897 if (x509_certs_verify(conn->client_certs, conn->client_certs_len,
898 conn->ca_certs, conn->ca_certs_len, verify_depth, &verify_result) != 1) {
899 error_print();
900 tls_send_alert(conn, TLS_alert_bad_certificate);
901 goto end;
902 }
903 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
904 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
905 }
906
907 // recv ClientKeyExchange
908 tls_trace("recv ClientKeyExchange\n");
909 if (tls_record_recv(record, &recordlen, conn->sock) != 1
910 || tls_record_protocol(record) != conn->protocol) {
911 error_print();
912 tls_send_alert(conn, TLS_alert_unexpected_message);
913 goto end;
914 }
915 tls12_record_trace(stderr, record, recordlen, 0, 0); // 应该给tls12一个独立的trace
916 if (tls_record_get_handshake_client_key_exchange_ecdhe(record, &client_ecdhe_point) != 1) {
917 error_print();
918 tls_send_alert(conn, TLS_alert_unexpected_message);
919 goto end;
920 }
921
922 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
923 if (client_verify)
924 tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
925
926 // recv CertificateVerify
927 if (client_verify) {
928 tls_trace("recv CertificateVerify\n");
929 if (tls_record_recv(record, &recordlen, conn->sock) != 1
930 || tls_record_protocol(record) != conn->protocol) {
931 tls_send_alert(conn, TLS_alert_unexpected_message);
932 error_print();
933 goto end;
934 }
935 tls12_record_trace(stderr, record, recordlen, 0, 0);
936 if (tls_record_get_handshake_certificate_verify(record, &sig, &siglen) != 1) {
937 tls_send_alert(conn, TLS_alert_unexpected_message);
938 error_print();
939 goto end;
940 }
941 if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0, &cp, &len) != 1
942 || x509_cert_get_subject_public_key(cp, len, &client_sign_key) != 1) {
943 error_print();
944 tls_send_alert(conn, TLS_alert_bad_certificate);
945 goto end;
946 }
947 if (tls_client_verify_finish(&client_verify_ctx, sig, siglen, &client_sign_key) != 1) {
948 error_print();
949 tls_send_alert(conn, TLS_alert_decrypt_error);
950 goto end;
951 }
952 sm3_update(&sm3_ctx, record + 5, recordlen - 5);
953 }
954
955 // generate secrets
956 tls_trace("generate secrets\n");
957 sm2_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point);
958 memcpy(pre_master_secret, (uint8_t *)&client_ecdhe_point, 32); // 这里应该修改一下表示方式,比如get_xy()
959 tls_prf(pre_master_secret, 32, "master secret",
960 client_random, 32, server_random, 32,
961 48, conn->master_secret);
962 tls_prf(conn->master_secret, 48, "key expansion",
963 server_random, 32, client_random, 32,
964 96, conn->key_block);
965 sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
966 sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
967 sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
968 sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
969 /*
970 tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random,
971 conn->master_secret, conn->key_block, 96, 0, 4);
972 */
973
974 // recv [ChangeCipherSpec]
975 tls_trace("recv [ChangeCipherSpec]\n");
976 if (tls_record_recv(record, &recordlen, conn->sock) != 1
977 || tls_record_protocol(record) != conn->protocol) {
978 error_print();
979 tls_send_alert(conn, TLS_alert_unexpected_message);
980 goto end;
981 }
982 tls12_record_trace(stderr, record, recordlen, 0, 0);
983 if (tls_record_get_change_cipher_spec(record) != 1) {
984 error_print();
985 tls_send_alert(conn, TLS_alert_unexpected_message);
986 goto end;
987 }
988
989 // recv ClientFinished
990 tls_trace("recv Finished\n");
991 if (tls_record_recv(record, &recordlen, conn->sock) != 1
992 || tls_record_protocol(record) != conn->protocol) {
993 error_print();
994 tls_send_alert(conn, TLS_alert_unexpected_message);
995 goto end;
996 }
997 if (recordlen > sizeof(finished_record)) {
998 error_print();
999 tls_send_alert(conn, TLS_alert_unexpected_message);
1000 goto end;
1001 }
1002 tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
1003
1004 // decrypt ClientFinished
1005 tls_trace("decrypt Finished\n");
1006 if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
1007 conn->client_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
1008 error_print();
1009 tls_send_alert(conn, TLS_alert_bad_record_mac);
1010 goto end;
1011 }
1012 tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
1013 tls_seq_num_incr(conn->client_seq_num);
1014 if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
1015 error_print();
1016 tls_send_alert(conn, TLS_alert_bad_record_mac);
1017 goto end;
1018 }
1019 if (verify_data_len != sizeof(local_verify_data)) {
1020 error_print();
1021 tls_send_alert(conn, TLS_alert_bad_record_mac);
1022 goto end;
1023 }
1024
1025 // verify ClientFinished
1026 memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(SM3_CTX));
1027 sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);
1028 sm3_finish(&tmp_sm3_ctx, sm3_hash);
1029 if (tls_prf(conn->master_secret, 48, "client finished", sm3_hash, 32, NULL, 0,
1030 sizeof(local_verify_data), local_verify_data) != 1) {
1031 error_print();
1032 tls_send_alert(conn, TLS_alert_internal_error);
1033 goto end;
1034 }
1035 if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
1036 error_puts("client_finished.verify_data verification failure");
1037 tls_send_alert(conn, TLS_alert_decrypt_error);
1038 goto end;
1039 }
1040
1041 // send [ChangeCipherSpec]
1042 tls_trace("send [ChangeCipherSpec]\n");
1043 if (tls_record_set_change_cipher_spec(record, &recordlen) != 1) {
1044 error_print();
1045 tls_send_alert(conn, TLS_alert_internal_error);
1046 goto end;
1047 }
1048 tls12_record_trace(stderr, record, recordlen, 0, 0);
1049 if (tls_record_send(record, recordlen, conn->sock) != 1) {
1050 error_print();
1051 goto end;
1052 }
1053
1054 // send ServerFinished
1055 tls_trace("send Finished\n");
1056 sm3_finish(&sm3_ctx, sm3_hash);
1057 if (tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0,
1058 sizeof(local_verify_data), local_verify_data) != 1
1059 || tls_record_set_handshake_finished(finished_record, &finished_record_len,
1060 local_verify_data, sizeof(local_verify_data)) != 1) {
1061 error_print();
1062 tls_send_alert(conn, TLS_alert_internal_error);
1063 goto end;
1064 }
1065 tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
1066 if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
1067 conn->server_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
1068 error_print();
1069 tls_send_alert(conn, TLS_alert_internal_error);
1070 goto end;
1071 }
1072 tls_trace("encrypt Finished\n");
1073 tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
1074 tls_seq_num_incr(conn->server_seq_num);
1075 if (tls_record_send(record, recordlen, conn->sock) != 1) {
1076 error_print();
1077 goto end;
1078 }
1079
1080 conn->protocol = conn->protocol;
1081
1082 fprintf(stderr, "Connection Established!\n\n");
1083 ret = 1;
1084
1085 end:
1086 gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
1087 gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
1088 if (client_verify) tls_client_verify_cleanup(&client_verify_ctx);
1089 return ret;
1090 }
1091