• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  *  Copyright 2014-2022 The GmSSL Project. All Rights Reserved.
3  *
4  *  Licensed under the Apache License, Version 2.0 (the License); you may
5  *  not use this file except in compliance with the License.
6  *
7  *  http://www.apache.org/licenses/LICENSE-2.0
8  */
9 
10 
11 
12 #include <time.h>
13 #include <stdio.h>
14 #include <assert.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <unistd.h>
18 #include <fcntl.h>
19 #include <sys/types.h>
20 #include <arpa/inet.h>
21 #include <sys/socket.h>
22 #include <netinet/in.h>
23 #include <gmssl/rand.h>
24 #include <gmssl/x509.h>
25 #include <gmssl/error.h>
26 #include <gmssl/mem.h>
27 #include <gmssl/sm2.h>
28 #include <gmssl/sm3.h>
29 #include <gmssl/sm4.h>
30 #include <gmssl/pem.h>
31 #include <gmssl/tls.h>
32 
33 
/*
 * Append a single byte to a serialization buffer.
 *
 * If both `out` and `*out` are non-NULL, `a` is stored at `*out` and the
 * cursor is advanced by one byte. `*outlen` is incremented unconditionally,
 * so calling with a NULL buffer only accumulates the encoded size.
 */
void tls_uint8_to_bytes(uint8_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;

		dst[0] = a;
		*out = dst + 1;
	}
	*outlen += 1;
}
41 
/*
 * Append a 16-bit value in big-endian byte order.
 *
 * Bytes are written only when both `out` and `*out` are non-NULL, in which
 * case the cursor advances by two. `*outlen` always grows by two.
 */
void tls_uint16_to_bytes(uint16_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;

		dst[0] = (uint8_t)(a >> 8);
		dst[1] = (uint8_t)(a & 0xff);
		*out = dst + 2;
	}
	*outlen += 2;
}
50 
/*
 * Append the low 24 bits of `a` as three big-endian bytes.
 *
 * Bytes are written only when both `out` and `*out` are non-NULL, in which
 * case the cursor advances by three. `*outlen` always grows by three.
 */
void tls_uint24_to_bytes(uint24_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;

		dst[0] = (uint8_t)(a >> 16);
		dst[1] = (uint8_t)(a >> 8);
		dst[2] = (uint8_t)(a);
		*out = dst + 3;
	}
	*outlen += 3;
}
60 
/*
 * Append a 32-bit value in big-endian byte order.
 *
 * Bytes are written only when both `out` and `*out` are non-NULL, in which
 * case the cursor advances by four. `*outlen` always grows by four.
 */
void tls_uint32_to_bytes(uint32_t a, uint8_t **out, size_t *outlen)
{
	if (out != NULL && *out != NULL) {
		uint8_t *dst = *out;
		int shift;

		/* most significant byte first */
		for (shift = 24; shift >= 0; shift -= 8) {
			*dst++ = (uint8_t)(a >> shift);
		}
		*out = dst;
	}
	*outlen += 4;
}
71 
/*
 * Append `datalen` raw bytes to a serialization buffer.
 *
 * When an output buffer is present (`out` and `*out` non-NULL) the cursor
 * advances by `datalen`; the bytes themselves are copied only if `data` is
 * non-NULL, so a NULL `data` skips over space without writing it.
 * `*outlen` always grows by `datalen`.
 */
void tls_array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const int have_buffer = (out != NULL && *out != NULL);

	if (have_buffer && data != NULL) {
		memcpy(*out, data, datalen);
	}
	if (have_buffer) {
		*out += datalen;
	}
	*outlen += datalen;
}
82 
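/*
 * These functions must distinguish between (data = NULL, datalen = 0) and
 * (data = NULL, datalen != 0).
 * The former means the data is empty, so only the length field is output.
 * The latter means the data is not empty but we do not want to output it:
 * only the length header is written and the total output length is updated.
 * This latter case should be avoided!
 */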
89 
/*
 * Append a byte string preceded by a one-byte length field.
 *
 * The length field is `datalen` cast to uint8_t, so the caller must make
 * sure `datalen` fits in one byte. Buffer and `*outlen` handling follow
 * tls_uint8_to_bytes() and tls_array_to_bytes().
 */
void tls_uint8array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const uint8_t prefix = (uint8_t)datalen;

	tls_uint8_to_bytes(prefix, out, outlen);
	tls_array_to_bytes(data, datalen, out, outlen);
}
95 
/*
 * Append a byte string preceded by a two-byte big-endian length field.
 *
 * The length field is `datalen` cast to uint16_t, so the caller must make
 * sure `datalen` fits in 16 bits. Buffer and `*outlen` handling follow
 * tls_uint16_to_bytes() and tls_array_to_bytes().
 */
void tls_uint16array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
{
	const uint16_t prefix = (uint16_t)datalen;

	tls_uint16_to_bytes(prefix, out, outlen);
	tls_array_to_bytes(data, datalen, out, outlen);
}
101 
tls_uint24array_to_bytes(const uint8_t * data,size_t datalen,uint8_t ** out,size_t * outlen)102 void tls_uint24array_to_bytes(const uint8_t *data, size_t datalen, uint8_t **out, size_t *outlen)
103 {
104 	tls_uint24_to_bytes((uint24_t)datalen, out, outlen);
105 	tls_array_to_bytes(data, datalen, out, outlen);
106 }
107 
/*
 * Read one byte from the input cursor.
 *
 * On success stores the byte in `*a`, advances `*in` by one, decrements
 * `*inlen` and returns 1. Returns -1 (input untouched) when no byte is left.
 */
int tls_uint8_from_bytes(uint8_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;

	if (*inlen == 0) {
		error_print();
		return -1;
	}
	src = *in;
	*a = src[0];
	*in = src + 1;
	*inlen -= 1;
	return 1;
}
118 
/*
 * Read a big-endian 16-bit value from the input cursor.
 *
 * On success stores the value in `*a`, advances `*in` by two, reduces
 * `*inlen` by two and returns 1. Returns -1 (input untouched) when fewer
 * than two bytes remain.
 */
int tls_uint16_from_bytes(uint16_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;
	uint16_t value = 0;
	int i;

	if (*inlen < 2) {
		error_print();
		return -1;
	}
	src = *in;
	for (i = 0; i < 2; i++) {
		value = (uint16_t)((value << 8) | src[i]);
	}
	*a = value;
	*in = src + 2;
	*inlen -= 2;
	return 1;
}
131 
/*
 * Read a big-endian three-byte value from the input cursor.
 *
 * On success stores the value in `*a`, advances `*in` by three, reduces
 * `*inlen` by three and returns 1. Returns -1 (input untouched) when fewer
 * than three bytes remain.
 */
int tls_uint24_from_bytes(uint24_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;
	uint24_t value = 0;
	int i;

	if (*inlen < 3) {
		error_print();
		return -1;
	}
	src = *in;
	for (i = 0; i < 3; i++) {
		value <<= 8;
		value |= src[i];
	}
	*a = value;
	*in = src + 3;
	*inlen -= 3;
	return 1;
}
146 
/*
 * Read a big-endian 32-bit value from the input cursor.
 *
 * On success stores the value in `*a`, advances `*in` by four, reduces
 * `*inlen` by four and returns 1. Returns -1 (input untouched) when fewer
 * than four bytes remain.
 */
int tls_uint32_from_bytes(uint32_t *a, const uint8_t **in, size_t *inlen)
{
	const uint8_t *src;
	uint32_t value = 0;
	int i;

	if (*inlen < 4) {
		error_print();
		return -1;
	}
	src = *in;
	for (i = 0; i < 4; i++) {
		value <<= 8;
		value |= src[i];
	}
	*a = value;
	*in = src + 4;
	*inlen -= 4;
	return 1;
}
163 
/*
 * Take `datalen` bytes from the input cursor without copying.
 *
 * On success `*data` points at the current input position, `*in` advances
 * by `datalen`, `*inlen` shrinks by `datalen`, and 1 is returned.
 * Returns -1 (input untouched) when fewer than `datalen` bytes remain.
 */
int tls_array_from_bytes(const uint8_t **data, size_t datalen, const uint8_t **in, size_t *inlen)
{
	const uint8_t *start;

	if (datalen > *inlen) {
		error_print();
		return -1;
	}
	start = *in;
	*data = start;
	*in = start + datalen;
	*inlen -= datalen;
	return 1;
}
175 
/*
 * Parse a byte string preceded by a one-byte length field.
 *
 * On success `*data`/`*datalen` describe the payload inside the input
 * buffer (no copy) and the cursor moves past it. An empty payload is
 * reported as `*data == NULL`, `*datalen == 0`. Returns 1 on success and
 * -1 when the length field or the payload is truncated.
 */
int tls_uint8array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint8_t prefix;

	if (tls_uint8_from_bytes(&prefix, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, prefix, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (prefix == 0) {
		*data = NULL;
	}
	*datalen = prefix;
	return 1;
}
190 
/*
 * Parse a byte string preceded by a two-byte big-endian length field.
 *
 * On success `*data`/`*datalen` describe the payload inside the input
 * buffer (no copy) and the cursor moves past it. An empty payload is
 * reported as `*data == NULL`, `*datalen == 0`. Returns 1 on success and
 * -1 when the length field or the payload is truncated.
 */
int tls_uint16array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
{
	uint16_t prefix;

	if (tls_uint16_from_bytes(&prefix, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (tls_array_from_bytes(data, prefix, in, inlen) != 1) {
		error_print();
		return -1;
	}
	if (prefix == 0) {
		*data = NULL;
	}
	*datalen = prefix;
	return 1;
}
205 
tls_uint24array_from_bytes(const uint8_t ** data,size_t * datalen,const uint8_t ** in,size_t * inlen)206 int tls_uint24array_from_bytes(const uint8_t **data, size_t *datalen, const uint8_t **in, size_t *inlen)
207 {
208 	uint24_t len;
209 	if (tls_uint24_from_bytes(&len, in, inlen) != 1
210 		|| tls_array_from_bytes(data, len, in, inlen) != 1) {
211 		error_print();
212 		return -1;
213 	}
214 	if (!len) {
215 		*data = NULL;
216 	}
217 	*datalen = len;
218 	return 1;
219 }
220 
/*
 * Assert that no bytes are left over.
 *
 * Returns 1 when `len` is zero; otherwise reports an error and returns -1.
 */
int tls_length_is_zero(size_t len)
{
	if (len == 0) {
		return 1;
	}
	error_print();
	return -1;
}
229 
/*
 * Store the record content type in byte 0 of the record header.
 *
 * The type is accepted only if tls_record_type_name() knows it.
 * Returns 1 on success, -1 for an unrecognized type.
 */
int tls_record_set_type(uint8_t *record, int type)
{
	if (tls_record_type_name(type)) {
		record[0] = (uint8_t)type;
		return 1;
	}
	error_print();
	return -1;
}
239 
/*
 * Store the protocol version in bytes 1..2 of the record header
 * (big-endian).
 *
 * The version is accepted only if tls_protocol_name() knows it.
 * Returns 1 on success, -1 for an unrecognized protocol.
 */
int tls_record_set_protocol(uint8_t *record, int protocol)
{
	if (tls_protocol_name(protocol)) {
		record[1] = (uint8_t)(protocol >> 8);
		record[2] = (uint8_t)protocol;
		return 1;
	}
	error_print();
	return -1;
}
250 
tls_record_set_length(uint8_t * record,size_t length)251 int tls_record_set_length(uint8_t *record, size_t length)
252 {
253 	uint8_t *p = record + 3;
254 	size_t len;
255 	if (length > TLS_MAX_CIPHERTEXT_SIZE) {
256 		error_print();
257 		return -1;
258 	}
259 	tls_uint16_to_bytes(length, &p, &len);
260 	return 1;
261 }
262 
/*
 * Set the record length field and copy `datalen` payload bytes into the
 * record's data area.
 *
 * `data` may be NULL only when `datalen` is 0 (empty payload); in that case
 * just the length field is written. Returns 1 on success, -1 on invalid
 * arguments or when tls_record_set_length() rejects the length.
 */
int tls_record_set_data(uint8_t *record, const uint8_t *data, size_t datalen)
{
	if (!record || (!data && datalen)) {
		error_print();
		return -1;
	}
	if (tls_record_set_length(record, datalen) != 1) {
		error_print();
		return -1;
	}
	/* memcpy() must not be handed a NULL source, even for zero bytes */
	if (datalen) {
		memcpy(tls_record_data(record), data, datalen);
	}
	return 1;
}
272 
tls_cbc_encrypt(const SM3_HMAC_CTX * inited_hmac_ctx,const SM4_KEY * enc_key,const uint8_t seq_num[8],const uint8_t header[5],const uint8_t * in,size_t inlen,uint8_t * out,size_t * outlen)273 int tls_cbc_encrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *enc_key,
274 	const uint8_t seq_num[8], const uint8_t header[5],
275 	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
276 {
277 	SM3_HMAC_CTX hmac_ctx;
278 	uint8_t last_blocks[32 + 16] = {0};
279 	uint8_t *mac, *padding, *iv;
280 	int rem, padding_len;
281 	int i;
282 
283 	if (!inited_hmac_ctx || !enc_key || !seq_num || !header || (!in && inlen) || !out || !outlen) {
284 		error_print();
285 		return -1;
286 	}
287 	if (inlen > (1 << 14)) {
288 		error_print_msg("invalid tls record data length %zu\n", inlen);
289 		return -1;
290 	}
291 	if ((((size_t)header[3]) << 8) + header[4] != inlen) {
292 		error_print();
293 		return -1;
294 	}
295 
296 	rem = (inlen + 32) % 16;
297 	memcpy(last_blocks, in + inlen - rem, rem);
298 	mac = last_blocks + rem;
299 
300 	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
301 	sm3_hmac_update(&hmac_ctx, seq_num, 8);
302 	sm3_hmac_update(&hmac_ctx, header, 5);
303 	sm3_hmac_update(&hmac_ctx, in, inlen);
304 	sm3_hmac_finish(&hmac_ctx, mac);
305 
306 	padding = mac + 32;
307 	padding_len = 16 - rem - 1;
308 	for (i = 0; i <= padding_len; i++) {
309 		padding[i] = padding_len;
310 	}
311 
312 	iv = out;
313 	if (rand_bytes(iv, 16) != 1) {
314 		error_print();
315 		return -1;
316 	}
317 	out += 16;
318 
319 	if (inlen >= 16) {
320 		sm4_cbc_encrypt(enc_key, iv, in, inlen/16, out);
321 		out += inlen - rem;
322 		iv = out - 16;
323 	}
324 	sm4_cbc_encrypt(enc_key, iv, last_blocks, sizeof(last_blocks)/16, out);
325 	*outlen = 16 + inlen - rem + sizeof(last_blocks);
326 	return 1;
327 }
328 
tls_cbc_decrypt(const SM3_HMAC_CTX * inited_hmac_ctx,const SM4_KEY * dec_key,const uint8_t seq_num[8],const uint8_t enced_header[5],const uint8_t * in,size_t inlen,uint8_t * out,size_t * outlen)329 int tls_cbc_decrypt(const SM3_HMAC_CTX *inited_hmac_ctx, const SM4_KEY *dec_key,
330 	const uint8_t seq_num[8], const uint8_t enced_header[5],
331 	const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen)
332 {
333 	SM3_HMAC_CTX hmac_ctx;
334 	const uint8_t *iv;
335 	const uint8_t *padding;
336 	const uint8_t *mac;
337 	uint8_t header[5];
338 	int padding_len;
339 	uint8_t hmac[32];
340 	int i;
341 
342 	if (!inited_hmac_ctx || !dec_key || !seq_num || !enced_header || !in || !inlen || !out || !outlen) {
343 		error_print();
344 		return -1;
345 	}
346 	if (inlen % 16
347 		|| inlen < (16 + 0 + 32 + 16) // iv + data +  mac + padding
348 		|| inlen > (16 + (1<<14) + 32 + 256)) {
349 		error_print_msg("invalid tls cbc ciphertext length %zu\n", inlen);
350 		return -1;
351 	}
352 
353 	iv = in;
354 	in += 16;
355 	inlen -= 16;
356 
357 	sm4_cbc_decrypt(dec_key, iv, in, inlen/16, out);
358 
359 	padding_len = out[inlen - 1];
360 	padding = out + inlen - padding_len - 1;
361 	if (padding < out + 32) {
362 		error_print();
363 		return -1;
364 	}
365 	for (i = 0; i < padding_len; i++) {
366 		if (padding[i] != padding_len) {
367 			error_puts("tls ciphertext cbc-padding check failure");
368 			return -1;
369 		}
370 	}
371 
372 	*outlen = inlen - 32 - padding_len - 1;
373 
374 	header[0] = enced_header[0];
375 	header[1] = enced_header[1];
376 	header[2] = enced_header[2];
377 	header[3] = (*outlen) >> 8;
378 	header[4] = (*outlen);
379 	mac = padding - 32;
380 
381 	memcpy(&hmac_ctx, inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
382 	sm3_hmac_update(&hmac_ctx, seq_num, 8);
383 	sm3_hmac_update(&hmac_ctx, header, 5);
384 	sm3_hmac_update(&hmac_ctx, out, *outlen);
385 	sm3_hmac_finish(&hmac_ctx, hmac);
386 	if (gmssl_secure_memcmp(mac, hmac, sizeof(hmac)) != 0) {
387 		error_puts("tls ciphertext mac check failure\n");
388 		return -1;
389 	}
390 	return 1;
391 }
392 
/*
 * Encrypt a complete plaintext record (5-byte header + fragment).
 *
 * The fragment is protected with tls_cbc_encrypt(), using the plaintext
 * header as MAC input. The output record reuses the first three header
 * bytes of the input and gets its length field from the ciphertext length.
 * On success `*outlen` is the total output record size (header included).
 *
 * Returns 1 on success, -1 if tls_cbc_encrypt() fails.
 */
int tls_record_encrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	const uint8_t *header = in;
	size_t body_len;
	int i;

	if (tls_cbc_encrypt(hmac_ctx, cbc_key, seq_num, header,
		in + 5, inlen - 5,
		out + 5, outlen) != 1) {
		error_print();
		return -1;
	}

	/* type and protocol version are carried over unchanged */
	for (i = 0; i < 3; i++) {
		out[i] = in[i];
	}
	body_len = *outlen;
	out[3] = (uint8_t)(body_len >> 8);
	out[4] = (uint8_t)body_len;
	*outlen = body_len + 5;
	return 1;
}
412 
/*
 * Decrypt a complete encrypted record (5-byte header + ciphertext).
 *
 * The ciphertext is processed with tls_cbc_decrypt(), which also verifies
 * padding and MAC. The output record reuses the first three header bytes
 * of the input and gets its length field from the recovered plaintext
 * length. On success `*outlen` is the total output record size (header
 * included).
 *
 * Returns 1 on success, -1 if tls_cbc_decrypt() fails.
 */
int tls_record_decrypt(const SM3_HMAC_CTX *hmac_ctx, const SM4_KEY *cbc_key,
	const uint8_t seq_num[8], const uint8_t *in, size_t inlen,
	uint8_t *out, size_t *outlen)
{
	const uint8_t *enced_header = in;
	size_t body_len;
	int i;

	if (tls_cbc_decrypt(hmac_ctx, cbc_key, seq_num, enced_header,
		in + 5, inlen - 5,
		out + 5, outlen) != 1) {
		error_print();
		return -1;
	}

	/* type and protocol version are carried over unchanged */
	for (i = 0; i < 3; i++) {
		out[i] = in[i];
	}
	body_len = *outlen;
	out[3] = (uint8_t)(body_len >> 8);
	out[4] = (uint8_t)body_len;
	*outlen = body_len + 5;

	return 1;
}
433 
/*
 * Fill a 32-byte hello random: the current time() truncated to 32 bits in
 * big-endian order, followed by 28 bytes from rand_bytes().
 *
 * Returns 1 on success, -1 if rand_bytes() fails.
 */
int tls_random_generate(uint8_t random[32])
{
	const uint32_t now = (uint32_t)time(NULL);
	uint8_t *cursor = random;
	size_t written = 0;

	tls_uint32_to_bytes(now, &cursor, &written);
	if (rand_bytes(random + 4, 28) == 1) {
		return 1;
	}
	error_print();
	return -1;
}
446 
tls_prf(const uint8_t * secret,size_t secretlen,const char * label,const uint8_t * seed,size_t seedlen,const uint8_t * more,size_t morelen,size_t outlen,uint8_t * out)447 int tls_prf(const uint8_t *secret, size_t secretlen, const char *label,
448 	const uint8_t *seed, size_t seedlen,
449 	const uint8_t *more, size_t morelen,
450 	size_t outlen, uint8_t *out)
451 {
452 	SM3_HMAC_CTX inited_hmac_ctx;
453 	SM3_HMAC_CTX hmac_ctx;
454 	uint8_t A[32];
455 	uint8_t hmac[32];
456 	size_t len;
457 
458 	if (!secret || !secretlen || !label || !seed || !seedlen
459 		|| (!more && morelen) || !outlen || !out) {
460 		error_print();
461 		return -1;
462 	}
463 
464 	sm3_hmac_init(&inited_hmac_ctx, secret, secretlen);
465 
466 	memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
467 	sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
468 	sm3_hmac_update(&hmac_ctx, seed, seedlen);
469 	sm3_hmac_update(&hmac_ctx, more, morelen);
470 	sm3_hmac_finish(&hmac_ctx, A);
471 
472 	memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
473 	sm3_hmac_update(&hmac_ctx, A, sizeof(A));
474 	sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
475 	sm3_hmac_update(&hmac_ctx, seed, seedlen);
476 	sm3_hmac_update(&hmac_ctx, more, morelen);
477 	sm3_hmac_finish(&hmac_ctx, hmac);
478 
479 	len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
480 	memcpy(out, hmac, len);
481 	out += len;
482 	outlen -= len;
483 
484 	while (outlen) {
485 		memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
486 		sm3_hmac_update(&hmac_ctx, A, sizeof(A));
487 		sm3_hmac_finish(&hmac_ctx, A);
488 
489 		memcpy(&hmac_ctx, &inited_hmac_ctx, sizeof(SM3_HMAC_CTX));
490 		sm3_hmac_update(&hmac_ctx, A, sizeof(A));
491 		sm3_hmac_update(&hmac_ctx, (uint8_t *)label, strlen(label));
492 		sm3_hmac_update(&hmac_ctx, seed, seedlen);
493 		sm3_hmac_update(&hmac_ctx, more, morelen);
494 		sm3_hmac_finish(&hmac_ctx, hmac);
495 
496 		len = outlen < sizeof(hmac) ? outlen : sizeof(hmac);
497 		memcpy(out, hmac, len);
498 		out += len;
499 		outlen -= len;
500 	}
501 	return 1;
502 }
503 
/*
 * Build a 48-byte pre-master secret: the two-byte protocol version
 * (big-endian) followed by 46 bytes from rand_bytes().
 *
 * Returns 1 on success, -1 if the protocol is not recognized by
 * tls_protocol_name() or rand_bytes() fails.
 */
int tls_pre_master_secret_generate(uint8_t pre_master_secret[48], int protocol)
{
	if (!tls_protocol_name(protocol)) {
		error_print();
		return -1;
	}
	pre_master_secret[0] = (uint8_t)(protocol >> 8);
	pre_master_secret[1] = (uint8_t)protocol;
	if (rand_bytes(pre_master_secret + 2, 46) == 1) {
		return 1;
	}
	error_print();
	return -1;
}
518 
519 // 用于设置CertificateRequest
tls_cert_type_from_oid(int oid)520 int tls_cert_type_from_oid(int oid)
521 {
522 	switch (oid) {
523 	case OID_sm2sign_with_sm3:
524 	case OID_ecdsa_with_sha1:
525 	case OID_ecdsa_with_sha224:
526 	case OID_ecdsa_with_sha256:
527 	case OID_ecdsa_with_sha512:
528 		return TLS_cert_type_ecdsa_sign;
529 	case OID_rsasign_with_sm3:
530 	case OID_rsasign_with_md5:
531 	case OID_rsasign_with_sha1:
532 	case OID_rsasign_with_sha224:
533 	case OID_rsasign_with_sha256:
534 	case OID_rsasign_with_sha384:
535 	case OID_rsasign_with_sha512:
536 		return TLS_cert_type_rsa_sign;
537 	}
538 	// TLS_cert_type_xxx 中没有为0的值
539 	return 0;
540 }
541 
542 // 这两个函数没有对应的TLCP版本
tls_sign_server_ecdh_params(const SM2_KEY * server_sign_key,const uint8_t client_random[32],const uint8_t server_random[32],int curve,const SM2_POINT * point,uint8_t * sig,size_t * siglen)543 int tls_sign_server_ecdh_params(const SM2_KEY *server_sign_key,
544 	const uint8_t client_random[32], const uint8_t server_random[32],
545 	int curve, const SM2_POINT *point, uint8_t *sig, size_t *siglen)
546 {
547 	uint8_t server_ecdh_params[69];
548 	SM2_SIGN_CTX sign_ctx;
549 
550 	if (!server_sign_key || !client_random || !server_random
551 		|| curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen) {
552 		error_print();
553 		return -1;
554 	}
555 	server_ecdh_params[0] = TLS_curve_type_named_curve;
556 	server_ecdh_params[1] = curve >> 8;
557 	server_ecdh_params[2] = curve;
558 	server_ecdh_params[3] = 65;
559 	sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
560 
561 	sm2_sign_init(&sign_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
562 	sm2_sign_update(&sign_ctx, client_random, 32);
563 	sm2_sign_update(&sign_ctx, server_random, 32);
564 	sm2_sign_update(&sign_ctx, server_ecdh_params, 69);
565 	sm2_sign_finish(&sign_ctx, sig, siglen);
566 
567 	return 1;
568 }
569 
tls_verify_server_ecdh_params(const SM2_KEY * server_sign_key,const uint8_t client_random[32],const uint8_t server_random[32],int curve,const SM2_POINT * point,const uint8_t * sig,size_t siglen)570 int tls_verify_server_ecdh_params(const SM2_KEY *server_sign_key,
571 	const uint8_t client_random[32], const uint8_t server_random[32],
572 	int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen)
573 {
574 	int ret;
575 	uint8_t server_ecdh_params[69];
576 	SM2_SIGN_CTX verify_ctx;
577 
578 	if (!server_sign_key || !client_random || !server_random
579 		|| curve != TLS_curve_sm2p256v1 || !point || !sig || !siglen
580 		|| siglen > SM2_MAX_SIGNATURE_SIZE) {
581 		error_print();
582 		return -1;
583 	}
584 	server_ecdh_params[0] = TLS_curve_type_named_curve;
585 	server_ecdh_params[1] = curve >> 8;
586 	server_ecdh_params[2] = curve;
587 	server_ecdh_params[3] = 65;
588 	sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
589 
590 	sm2_verify_init(&verify_ctx, server_sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
591 	sm2_verify_update(&verify_ctx, client_random, 32);
592 	sm2_verify_update(&verify_ctx, server_random, 32);
593 	sm2_verify_update(&verify_ctx, server_ecdh_params, 69);
594 	ret = sm2_verify_finish(&verify_ctx, sig, siglen);
595 	if (ret != 1) error_print();
596 	return ret;
597 }
598 
tls_record_set_handshake(uint8_t * record,size_t * recordlen,int type,const uint8_t * data,size_t datalen)599 int tls_record_set_handshake(uint8_t *record, size_t *recordlen,
600 	int type, const uint8_t *data, size_t datalen)
601 {
602 	size_t handshakelen;
603 
604 	if (!record || !recordlen) {
605 		error_print();
606 		return -1;
607 	}
608 	// 由于ServerHelloDone没有负载数据,因此允许 data,datalen = NULL,0
609 	if (datalen > TLS_MAX_PLAINTEXT_SIZE - TLS_HANDSHAKE_HEADER_SIZE) {
610 		error_print();
611 		return -1;
612 	}
613 	if (!tls_protocol_name(tls_record_protocol(record))) {
614 		error_print();
615 		return -1;
616 	}
617 	if (!tls_handshake_type_name(type)) {
618 		error_print();
619 		return -1;
620 	}
621 	handshakelen = TLS_HANDSHAKE_HEADER_SIZE + datalen;
622 	record[0] = TLS_record_handshake;
623 	record[3] = handshakelen >> 8;
624 	record[4] = handshakelen;
625 	record[5] = type;
626 	record[6] = datalen >> 16;
627 	record[7] = datalen >> 8;
628 	record[8] = datalen;
629 	if (data) {
630 		memcpy(tls_handshake_data(tls_record_data(record)), data, datalen);
631 	}
632 	*recordlen = TLS_RECORD_HEADER_SIZE + handshakelen;
633 	return 1;
634 }
635 
tls_record_get_handshake(const uint8_t * record,int * type,const uint8_t ** data,size_t * datalen)636 int tls_record_get_handshake(const uint8_t *record,
637 	int *type, const uint8_t **data, size_t *datalen)
638 {
639 	const uint8_t *handshake;
640 	size_t handshake_len;
641 	uint24_t handshake_datalen;
642 
643 	if (!record || !type || !data || !datalen) {
644 		error_print();
645 		return -1;
646 	}
647 	if (!tls_protocol_name(tls_record_protocol(record))) {
648 		error_print();
649 		return -1;
650 	}
651 	if (tls_record_type(record) != TLS_record_handshake) {
652 		error_print();
653 		return -1;
654 	}
655 	handshake = tls_record_data(record);
656 	handshake_len = tls_record_data_length(record);
657 
658 	if (handshake_len < TLS_HANDSHAKE_HEADER_SIZE) {
659 		error_print();
660 		return -1;
661 	}
662 	if (handshake_len > TLS_MAX_PLAINTEXT_SIZE) {
663 		// 不支持证书长度超过记录长度的特殊情况
664 		error_print();
665 		return -1;
666 	}
667 
668 	if (!tls_handshake_type_name(handshake[0])) {
669 		error_print();
670 		return -1;
671 	}
672 	*type = handshake[0];
673 
674 	handshake++;
675 	handshake_len--;
676 	if (tls_uint24_from_bytes(&handshake_datalen, &handshake, &handshake_len) != 1) {
677 		error_print();
678 		return -1;
679 	}
680 	if (handshake_len != handshake_datalen) {
681 		error_print();
682 		return -1;
683 	}
684 	*data = handshake;
685 	*datalen = handshake_datalen;
686 
687 	if (*datalen == 0) {
688 		*data = NULL;
689 	}
690 	return 1;
691 }
692 
tls_record_set_handshake_client_hello(uint8_t * record,size_t * recordlen,int protocol,const uint8_t random[32],const uint8_t * session_id,size_t session_id_len,const int * cipher_suites,size_t cipher_suites_count,const uint8_t * exts,size_t exts_len)693 int tls_record_set_handshake_client_hello(uint8_t *record, size_t *recordlen,
694 	int protocol, const uint8_t random[32],
695 	const uint8_t *session_id, size_t session_id_len,
696 	const int *cipher_suites, size_t cipher_suites_count,
697 	const uint8_t *exts, size_t exts_len)
698 {
699 	uint8_t type = TLS_handshake_client_hello;
700 	uint8_t *p;
701 	size_t len;
702 
703 	if (!record || !recordlen || !random || !cipher_suites || !cipher_suites_count) {
704 		error_print();
705 		return -1;
706 	}
707 	if (session_id) {
708 		if (!session_id_len
709 			|| session_id_len < TLS_MAX_SESSION_ID_SIZE
710 			|| session_id_len > TLS_MAX_SESSION_ID_SIZE) {
711 			error_print();
712 			return -1;
713 		}
714 	}
715 	if (cipher_suites_count > TLS_MAX_CIPHER_SUITES_COUNT) {
716 		error_print();
717 		return -1;
718 	}
719 	if (exts && !exts_len) {
720 		error_print();
721 		return -1;
722 	}
723 
724 
725 	p = tls_handshake_data(tls_record_data(record));
726 	len = 0;
727 
728 	if (!tls_protocol_name(protocol)) {
729 		error_print();
730 		return -1;
731 	}
732 	tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
733 	tls_array_to_bytes(random, 32, &p, &len);
734 	tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
735 	tls_uint16_to_bytes(cipher_suites_count * 2, &p, &len);
736 	while (cipher_suites_count--) {
737 		if (!tls_cipher_suite_name(*cipher_suites)) {
738 			error_print();
739 			return -1;
740 		}
741 		tls_uint16_to_bytes((uint16_t)*cipher_suites, &p, &len);
742 		cipher_suites++;
743 	}
744 	tls_uint8_to_bytes(1, &p, &len);
745 	tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
746 	if (exts) {
747 		size_t tmp_len = len;
748 		if (protocol < TLS_protocol_tls12) {
749 			error_print();
750 			return -1;
751 		}
752 		tls_uint16array_to_bytes(exts, exts_len, NULL, &tmp_len);
753 		if (tmp_len > TLS_MAX_HANDSHAKE_DATA_SIZE) {
754 			error_print();
755 			return -1;
756 		}
757 		tls_uint16array_to_bytes(exts, exts_len, &p, &len);
758 	}
759 	if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
760 		error_print();
761 		return -1;
762 	}
763 	return 1;
764 }
765 
tls_record_get_handshake_client_hello(const uint8_t * record,int * protocol,const uint8_t ** random,const uint8_t ** session_id,size_t * session_id_len,const uint8_t ** cipher_suites,size_t * cipher_suites_len,const uint8_t ** exts,size_t * exts_len)766 int tls_record_get_handshake_client_hello(const uint8_t *record,
767 	int *protocol, const uint8_t **random,
768 	const uint8_t **session_id, size_t *session_id_len,
769 	const uint8_t **cipher_suites, size_t *cipher_suites_len,
770 	const uint8_t **exts, size_t *exts_len)
771 {
772 	int type;
773 	const uint8_t *p;
774 	size_t len;
775 	uint16_t ver;
776 	const uint8_t *comp_meths;
777 	size_t comp_meths_len;
778 
779 	if (!record || !protocol || !random
780 		|| !session_id || !session_id_len
781 		|| !cipher_suites || !cipher_suites_len
782 		|| !exts || !exts_len) {
783 		error_print();
784 		return -1;
785 	}
786 	if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
787 		error_print();
788 		return -1;
789 	}
790 	if (type != TLS_handshake_client_hello) {
791 		error_print();
792 		return -1;
793 	}
794 	if (tls_uint16_from_bytes(&ver, &p, &len) != 1
795 		|| tls_array_from_bytes(random, 32, &p, &len) != 1
796 		|| tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1
797 		|| tls_uint16array_from_bytes(cipher_suites, cipher_suites_len, &p, &len) != 1
798 		|| tls_uint8array_from_bytes(&comp_meths, &comp_meths_len, &p, &len) != 1) {
799 		error_print();
800 		return -1;
801 	}
802 
803 	if (!tls_protocol_name(ver)) {
804 		error_print();
805 		return -1;
806 	}
807 	*protocol = ver;
808 
809 	if (*session_id) {
810 		if (*session_id_len == 0
811 			|| *session_id_len < TLS_MIN_SESSION_ID_SIZE
812 			|| *session_id_len > TLS_MAX_SESSION_ID_SIZE) {
813 			error_print();
814 			return -1;
815 		}
816 	}
817 
818 	if (!cipher_suites) {
819 		error_print();
820 		return -1;
821 	}
822 	if (*cipher_suites_len % 2) {
823 		error_print();
824 		return -1;
825 	}
826 
827 	if (len) {
828 		if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
829 			error_print();
830 			return -1;
831 		}
832 		if (*exts == NULL) {
833 			error_print();
834 			return -1;
835 		}
836 	} else {
837 		*exts = NULL;
838 		*exts_len = 0;
839 	}
840 	if (len) {
841 		error_print();
842 		return -1;
843 	}
844 	return 1;
845 }
846 
tls_record_set_handshake_server_hello(uint8_t * record,size_t * recordlen,int protocol,const uint8_t random[32],const uint8_t * session_id,size_t session_id_len,int cipher_suite,const uint8_t * exts,size_t exts_len)847 int tls_record_set_handshake_server_hello(uint8_t *record, size_t *recordlen,
848 	int protocol, const uint8_t random[32],
849 	const uint8_t *session_id, size_t session_id_len, int cipher_suite,
850 	const uint8_t *exts, size_t exts_len)
851 {
852 	uint8_t type = TLS_handshake_server_hello;
853 	uint8_t *p;
854 	size_t len;
855 
856 	if (!record || !recordlen || !random) {
857 		error_print();
858 		return -1;
859 	}
860 	if (session_id) {
861 		if (session_id_len == 0
862 			|| session_id_len < TLS_MIN_SESSION_ID_SIZE
863 			|| session_id_len > TLS_MAX_SESSION_ID_SIZE) {
864 			error_print();
865 			return -1;
866 		}
867 	}
868 	if (!tls_protocol_name(protocol)) {
869 		error_print();
870 		return -1;
871 	}
872 	if (!tls_cipher_suite_name(cipher_suite)) {
873 		error_print();
874 		return -1;
875 	}
876 
877 	p = tls_handshake_data(tls_record_data(record));
878 	len = 0;
879 
880 	tls_uint16_to_bytes((uint16_t)protocol, &p, &len);
881 	tls_array_to_bytes(random, 32, &p, &len);
882 	tls_uint8array_to_bytes(session_id, session_id_len, &p, &len);
883 	tls_uint16_to_bytes((uint16_t)cipher_suite, &p, &len);
884 	tls_uint8_to_bytes((uint8_t)TLS_compression_null, &p, &len);
885 	if (exts) {
886 		if (protocol < TLS_protocol_tls12) {
887 			error_print();
888 			return -1;
889 		}
890 		tls_uint16array_to_bytes(exts, exts_len, &p, &len);
891 	}
892 	if (tls_record_set_handshake(record, recordlen, type, NULL, len) != 1) {
893 		error_print();
894 		return -1;
895 	}
896 	return 1;
897 }
898 
tls_record_get_handshake_server_hello(const uint8_t * record,int * protocol,const uint8_t ** random,const uint8_t ** session_id,size_t * session_id_len,int * cipher_suite,const uint8_t ** exts,size_t * exts_len)899 int tls_record_get_handshake_server_hello(const uint8_t *record,
900 	int *protocol, const uint8_t **random, const uint8_t **session_id, size_t *session_id_len,
901 	int *cipher_suite, const uint8_t **exts, size_t *exts_len)
902 {
903 	int type;
904 	const uint8_t *p;
905 	size_t len;
906 	uint16_t ver;
907 	uint16_t cipher;
908 	uint8_t comp_meth;
909 
910 	if (!record || !protocol || !random || !session_id || !session_id_len
911 		|| !cipher_suite || !exts || !exts_len) {
912 		error_print();
913 		return -1;
914 	}
915 	if (tls_record_get_handshake(record, &type, &p, &len) != 1) {
916 		error_print();
917 		return -1;
918 	}
919 	if (type != TLS_handshake_server_hello) {
920 		error_print();
921 		return -1;
922 	}
923 	if (tls_uint16_from_bytes(&ver, &p, &len) != 1
924 		|| tls_array_from_bytes(random, 32, &p, &len) != 1
925 		|| tls_uint8array_from_bytes(session_id, session_id_len, &p, &len) != 1
926 		|| tls_uint16_from_bytes(&cipher, &p, &len) != 1
927 		|| tls_uint8_from_bytes(&comp_meth, &p, &len) != 1) {
928 		error_print();
929 		return -1;
930 	}
931 
932 	if (!tls_protocol_name(ver)) {
933 		error_print();
934 		return -1;
935 	}
936 	if (ver < tls_record_protocol(record)) {
937 		error_print();
938 		return -1;
939 	}
940 	*protocol = ver;
941 
942 	if (*session_id) {
943 		if (*session_id == 0
944 			|| *session_id_len < TLS_MIN_SESSION_ID_SIZE
945 			|| *session_id_len > TLS_MAX_SESSION_ID_SIZE) {
946 			error_print();
947 			return -1;
948 		}
949 	}
950 
951 	if (!tls_cipher_suite_name(cipher)) {
952 		error_print();
953 		return -1;
954 	}
955 	*cipher_suite = cipher;
956 
957 	if (comp_meth != TLS_compression_null) {
958 		error_print();
959 		return -1;
960 	}
961 
962 	if (len) {
963 		if (tls_uint16array_from_bytes(exts, exts_len, &p, &len) != 1) {
964 			error_print();
965 			return -1;
966 		}
967 		if (*exts == NULL) {
968 			error_print();
969 			return -1;
970 		}
971 	} else {
972 		*exts = NULL;
973 		*exts_len = 0;
974 	}
975 	if (len) {
976 		error_print();
977 		return -1;
978 	}
979 	return 1;
980 }
981 
tls_record_set_handshake_certificate(uint8_t * record,size_t * recordlen,const uint8_t * certs,size_t certslen)982 int tls_record_set_handshake_certificate(uint8_t *record, size_t *recordlen,
983 	const uint8_t *certs, size_t certslen)
984 {
985 	int type = TLS_handshake_certificate;
986 	uint8_t *data;
987 	size_t datalen;
988 	uint8_t *p;
989 	size_t len;
990 
991 	if (!record || !recordlen || !certs || !certslen) {
992 		error_print();
993 		return -1;
994 	}
995 	data = tls_handshake_data(tls_record_data(record));
996 	p = data + tls_uint24_size();
997 	datalen = tls_uint24_size();
998 	len = 0;
999 
1000 	while (certslen) {
1001 		const uint8_t *cert;
1002 		size_t certlen;
1003 
1004 		if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1) {
1005 			error_print();
1006 			return -1;
1007 		}
1008 		tls_uint24array_to_bytes(cert, certlen, NULL, &datalen);
1009 		if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
1010 			error_print();
1011 			return -1;
1012 		}
1013 		tls_uint24array_to_bytes(cert, certlen, &p, &len);
1014 	}
1015 	tls_uint24_to_bytes(len, &data, &len);
1016 	tls_record_set_handshake(record, recordlen, type, NULL, datalen);
1017 	return 1;
1018 }
1019 
tls_record_get_handshake_certificate(const uint8_t * record,uint8_t * certs,size_t * certslen)1020 int tls_record_get_handshake_certificate(const uint8_t *record, uint8_t *certs, size_t *certslen)
1021 {
1022 	int type;
1023 	const uint8_t *data;
1024 	size_t datalen;
1025 	const uint8_t *cp;
1026 	size_t len;
1027 
1028 	if (tls_record_get_handshake(record, &type, &data, &datalen) != 1) {
1029 		error_print();
1030 		return -1;
1031 	}
1032 	if (type != TLS_handshake_certificate) {
1033 		error_print();
1034 		return -1;
1035 	}
1036 	if (tls_uint24array_from_bytes(&cp, &len, &data, &datalen) != 1) {
1037 		error_print();
1038 		return -1;
1039 	}
1040 
1041 	*certslen = 0;
1042 	while (len) {
1043 		const uint8_t *a;
1044 		size_t alen;
1045 		const uint8_t *cert;
1046 		size_t certlen;
1047 
1048 		if (tls_uint24array_from_bytes(&a, &alen, &cp, &len) != 1) {
1049 			error_print();
1050 			return -1;
1051 		}
1052 		if (x509_cert_from_der(&cert, &certlen, &a, &alen) != 1
1053 			|| asn1_length_is_zero(alen) != 1
1054 			|| x509_cert_to_der(cert, certlen, &certs, certslen) != 1) {
1055 			error_print();
1056 			return -1;
1057 		}
1058 	}
1059 	return 1;
1060 }
1061 
tls_record_set_handshake_certificate_request(uint8_t * record,size_t * recordlen,const uint8_t * cert_types,size_t cert_types_len,const uint8_t * ca_names,size_t ca_names_len)1062 int tls_record_set_handshake_certificate_request(uint8_t *record, size_t *recordlen,
1063 	const uint8_t *cert_types, size_t cert_types_len,
1064 	const uint8_t *ca_names, size_t ca_names_len)
1065 {
1066 	int type = TLS_handshake_certificate_request;
1067 	uint8_t *p;
1068 	size_t len =0;
1069 	size_t datalen = 0;
1070 
1071 	if (!record || !recordlen) {
1072 		error_print();
1073 		return -1;
1074 	}
1075 	if (cert_types) {
1076 		if (cert_types_len == 0 || cert_types_len > TLS_MAX_CERTIFICATE_TYPES) {
1077 			error_print();
1078 			return -1;
1079 		}
1080 	}
1081 	if (ca_names) {
1082 		if (ca_names_len == 0 || ca_names_len > TLS_MAX_CA_NAMES_SIZE) {
1083 			error_print();
1084 			return -1;
1085 		}
1086 	}
1087 	tls_uint8array_to_bytes(cert_types, cert_types_len, NULL, &datalen);
1088 	tls_uint16array_to_bytes(ca_names, ca_names_len, NULL, &datalen);
1089 	if (datalen > TLS_MAX_HANDSHAKE_DATA_SIZE) {
1090 		error_print();
1091 		return -1;
1092 	}
1093 	p = tls_handshake_data(tls_record_data(record));
1094 	tls_uint8array_to_bytes(cert_types, cert_types_len, &p, &len);
1095 	tls_uint16array_to_bytes(ca_names, ca_names_len, &p, &len);
1096 	tls_record_set_handshake(record, recordlen, type, NULL, datalen);
1097 	return 1;
1098 }
1099 
tls_record_get_handshake_certificate_request(const uint8_t * record,const uint8_t ** cert_types,size_t * cert_types_len,const uint8_t ** ca_names,size_t * ca_names_len)1100 int tls_record_get_handshake_certificate_request(const uint8_t *record,
1101 	const uint8_t **cert_types, size_t *cert_types_len,
1102 	const uint8_t **ca_names, size_t *ca_names_len)
1103 {
1104 	int type;
1105 	const uint8_t *cp;
1106 	size_t len;
1107 	size_t i;
1108 
1109 	if (!record || !cert_types || !cert_types_len || !ca_names || !ca_names_len) {
1110 		error_print();
1111 		return -1;
1112 	}
1113 	if (tls_record_get_handshake(record, &type, &cp, &len) != 1) {
1114 		error_print();
1115 		return -1;
1116 	}
1117 	if (type != TLS_handshake_certificate_request) {
1118 		error_print();
1119 		return -1;
1120 	}
1121 	if (tls_uint8array_from_bytes(cert_types, cert_types_len, &cp, &len) != 1
1122 		|| tls_uint16array_from_bytes(ca_names, ca_names_len, &cp, &len) != 1
1123 		|| tls_length_is_zero(len) != 1) {
1124 		error_print();
1125 		return -1;
1126 	}
1127 
1128 	if (*cert_types == NULL) {
1129 		error_print();
1130 		return -1;
1131 	}
1132 	for (i = 0; i < *cert_types_len; i++) {
1133 		if (!tls_cert_type_name((*cert_types)[i])) {
1134 			error_print();
1135 			return -1;
1136 		}
1137 	}
1138 	if (*ca_names) {
1139 		const uint8_t *names = *ca_names;
1140 		size_t nameslen = *ca_names_len;
1141 		while (nameslen) {
1142 			if (tls_uint16array_from_bytes(&cp, &len, &names, &nameslen) != 1) {
1143 				error_print();
1144 				return -1;
1145 			}
1146 		}
1147 	}
1148 	return 1;
1149 }
1150 
tls_record_set_handshake_server_hello_done(uint8_t * record,size_t * recordlen)1151 int tls_record_set_handshake_server_hello_done(uint8_t *record, size_t *recordlen)
1152 {
1153 	int type = TLS_handshake_server_hello_done;
1154 	if (!record || !recordlen) {
1155 		error_print();
1156 		return -1;
1157 	}
1158 	tls_record_set_handshake(record, recordlen, type, NULL, 0);
1159 	return 1;
1160 }
1161 
tls_record_get_handshake_server_hello_done(const uint8_t * record)1162 int tls_record_get_handshake_server_hello_done(const uint8_t *record)
1163 {
1164 	int type;
1165 	const uint8_t *p;
1166 	size_t len;
1167 
1168 	if (!record) {
1169 		error_print();
1170 		return -1;
1171 	}
1172 	if (tls_record_get_handshake(record, &type, &p, &len) != 1
1173 		|| type != TLS_handshake_server_hello_done) {
1174 		error_print();
1175 		return -1;
1176 	}
1177 	if (p != NULL || len != 0) {
1178 		error_print();
1179 		return -1;
1180 	}
1181 	return 1;
1182 }
1183 
tls_record_set_handshake_client_key_exchange_pke(uint8_t * record,size_t * recordlen,const uint8_t * enced_pms,size_t enced_pms_len)1184 int tls_record_set_handshake_client_key_exchange_pke(uint8_t *record, size_t *recordlen,
1185 	const uint8_t *enced_pms, size_t enced_pms_len)
1186 {
1187 	int type = TLS_handshake_client_key_exchange;
1188 	uint8_t *p;
1189 	size_t len = 0;
1190 
1191 	if (!record || !recordlen || !enced_pms || !enced_pms_len) {
1192 		error_print();
1193 		return -1;
1194 	}
1195 	if (enced_pms_len > TLS_MAX_HANDSHAKE_DATA_SIZE - tls_uint16_size()) {
1196 		error_print();
1197 		return -1;
1198 	}
1199 	p = tls_handshake_data(tls_record_data(record));
1200 	tls_uint16array_to_bytes(enced_pms, enced_pms_len, &p, &len);
1201 	tls_record_set_handshake(record, recordlen, type, NULL, len);
1202 	return 1;
1203 }
1204 
tls_record_get_handshake_client_key_exchange_pke(const uint8_t * record,const uint8_t ** enced_pms,size_t * enced_pms_len)1205 int tls_record_get_handshake_client_key_exchange_pke(const uint8_t *record,
1206 	const uint8_t **enced_pms, size_t *enced_pms_len)
1207 {
1208 	int type;
1209 	const uint8_t *cp;
1210 	size_t len;
1211 
1212 	if (!record || !enced_pms || !enced_pms_len) {
1213 		error_print();
1214 		return -1;
1215 	}
1216 	if (tls_record_get_handshake(record, &type, &cp, &len) != 1) {
1217 		error_print();
1218 		return -1;
1219 	}
1220 	if (type != TLS_handshake_client_key_exchange) {
1221 		error_print();
1222 		return -1;
1223 	}
1224 	if (tls_uint16array_from_bytes(enced_pms, enced_pms_len, &cp, &len) != 1
1225 		|| tls_length_is_zero(len) != 1) {
1226 		error_print();
1227 		return -1;
1228 	}
1229 	return 1;
1230 }
1231 
tls_record_set_handshake_certificate_verify(uint8_t * record,size_t * recordlen,const uint8_t * sig,size_t siglen)1232 int tls_record_set_handshake_certificate_verify(uint8_t *record, size_t *recordlen,
1233 	const uint8_t *sig, size_t siglen)
1234 {
1235 	int type = TLS_handshake_certificate_verify;
1236 
1237 	if (!record || !recordlen || !sig || !siglen) {
1238 		error_print();
1239 		return -1;
1240 	}
1241 	if (siglen > TLS_MAX_SIGNATURE_SIZE) {
1242 		error_print();
1243 		return -1;
1244 	}
1245 	tls_record_set_handshake(record, recordlen, type, sig, siglen);
1246 	return 1;
1247 }
1248 
tls_record_get_handshake_certificate_verify(const uint8_t * record,const uint8_t ** sig,size_t * siglen)1249 int tls_record_get_handshake_certificate_verify(const uint8_t *record,
1250 	const uint8_t **sig, size_t *siglen)
1251 {
1252 	int type;
1253 
1254 	if (!record || !sig || !siglen) {
1255 		error_print();
1256 		return -1;
1257 	}
1258 	if (tls_record_get_handshake(record, &type, sig, siglen) != 1) {
1259 		error_print();
1260 		return -1;
1261 	}
1262 	if (type != TLS_handshake_certificate_verify) {
1263 		error_print();
1264 		return -1;
1265 	}
1266 	if (*sig == NULL || *siglen == 0) {
1267 		error_print();
1268 		return -1;
1269 	}
1270 	if (*siglen > TLS_MAX_SIGNATURE_SIZE) {
1271 		error_print();
1272 		return -1;
1273 	}
1274 	return 1;
1275 }
1276 
tls_record_set_handshake_finished(uint8_t * record,size_t * recordlen,const uint8_t * verify_data,size_t verify_data_len)1277 int tls_record_set_handshake_finished(uint8_t *record, size_t *recordlen,
1278 	const uint8_t *verify_data, size_t verify_data_len)
1279 {
1280 	int type = TLS_handshake_finished;
1281 
1282 	if (!record || !recordlen || !verify_data || !verify_data_len) {
1283 		error_print();
1284 		return -1;
1285 	}
1286 	if (verify_data_len != 12 && verify_data_len != 32) {
1287 		error_print();
1288 		return -1;
1289 	}
1290 	tls_record_set_handshake(record, recordlen, type, verify_data, verify_data_len);
1291 	return 1;
1292 }
1293 
tls_record_get_handshake_finished(const uint8_t * record,const uint8_t ** verify_data,size_t * verify_data_len)1294 int tls_record_get_handshake_finished(const uint8_t *record, const uint8_t **verify_data, size_t *verify_data_len)
1295 {
1296 	int type;
1297 
1298 	if (!record || !verify_data || !verify_data_len) {
1299 		error_print();
1300 		return -1;
1301 	}
1302 	if (tls_record_get_handshake(record, &type, verify_data, verify_data_len) != 1) {
1303 		error_print();
1304 		return -1;
1305 	}
1306 	if (type != TLS_handshake_finished) {
1307 		error_print();
1308 		return -1;
1309 	}
1310 	if (*verify_data == NULL || *verify_data_len == 0) {
1311 		error_print();
1312 		return -1;
1313 	}
1314 	if (*verify_data_len != 12 && *verify_data_len != 32) {
1315 		error_print();
1316 		return -1;
1317 	}
1318 	return 1;
1319 }
1320 
tls_record_set_alert(uint8_t * record,size_t * recordlen,int alert_level,int alert_description)1321 int tls_record_set_alert(uint8_t *record, size_t *recordlen,
1322 	int alert_level,
1323 	int alert_description)
1324 {
1325 	if (!record || !recordlen) {
1326 		error_print();
1327 		return -1;
1328 	}
1329 	if (!tls_alert_level_name(alert_level)) {
1330 		error_print();
1331 		return -1;
1332 	}
1333 	if (!tls_alert_description_text(alert_description)) {
1334 		error_print();
1335 		return -1;
1336 	}
1337 	record[0] = TLS_record_alert;
1338 	record[3] = 0; // length
1339 	record[4] = 2; // length
1340 	record[5] = (uint8_t)alert_level;
1341 	record[6] = (uint8_t)alert_description;
1342 	*recordlen = TLS_RECORD_HEADER_SIZE + 2;
1343 	return 1;
1344 }
1345 
tls_record_get_alert(const uint8_t * record,int * alert_level,int * alert_description)1346 int tls_record_get_alert(const uint8_t *record,
1347 	int *alert_level,
1348 	int *alert_description)
1349 {
1350 	if (!record || !alert_level || !alert_description) {
1351 		error_print();
1352 		return -1;
1353 	}
1354 	if (tls_record_type(record) != TLS_record_alert) {
1355 		error_print();
1356 		return -1;
1357 	}
1358 	if (record[3] != 0 || record[4] != 2) {
1359 		error_print();
1360 		return -1;
1361 	}
1362 	*alert_level = record[5];
1363 	*alert_description = record[6];
1364 	if (!tls_alert_level_name(*alert_level)) {
1365 		error_print();
1366 		return -1;
1367 	}
1368 	if (!tls_alert_description_text(*alert_description)) {
1369 		error_puts("warning");
1370 		return -1;
1371 	}
1372 	return 1;
1373 }
1374 
tls_record_set_change_cipher_spec(uint8_t * record,size_t * recordlen)1375 int tls_record_set_change_cipher_spec(uint8_t *record, size_t *recordlen)
1376 {
1377 	if (!record || !recordlen) {
1378 		error_print();
1379 		return -1;
1380 	}
1381 	record[0] = TLS_record_change_cipher_spec;
1382 	record[3] = 0;
1383 	record[4] = 1;
1384 	record[5] = TLS_change_cipher_spec;
1385 	*recordlen = TLS_RECORD_HEADER_SIZE + 1;
1386 	return 1;
1387 }
1388 
tls_record_get_change_cipher_spec(const uint8_t * record)1389 int tls_record_get_change_cipher_spec(const uint8_t *record)
1390 {
1391 	if (!record) {
1392 		error_print();
1393 		return -1;
1394 	}
1395 	if (tls_record_type(record) != TLS_record_change_cipher_spec) {
1396 		error_print();
1397 		return -1;
1398 	}
1399 	if (record[3] != 0 || record[4] != 1) {
1400 		error_print();
1401 		return -1;
1402 	}
1403 	if (record[5] != TLS_change_cipher_spec) {
1404 		error_print();
1405 		return -1;
1406 	}
1407 	return 1;
1408 }
1409 
tls_record_set_application_data(uint8_t * record,size_t * recordlen,const uint8_t * data,size_t datalen)1410 int tls_record_set_application_data(uint8_t *record, size_t *recordlen,
1411 	const uint8_t *data, size_t datalen)
1412 {
1413 	if (!record || !recordlen || !data || !datalen) {
1414 		error_print();
1415 		return -1;
1416 	}
1417 	record[0] = TLS_record_application_data;
1418 	record[3] = (datalen >> 8) & 0xff;
1419 	record[4] = datalen & 0xff;
1420 	memcpy(tls_record_data(record), data, datalen);
1421 	*recordlen = TLS_RECORD_HEADER_SIZE + datalen;
1422 	return 1;
1423 }
1424 
tls_record_get_application_data(uint8_t * record,const uint8_t ** data,size_t * datalen)1425 int tls_record_get_application_data(uint8_t *record,
1426 	const uint8_t **data, size_t *datalen)
1427 {
1428 	if (!record || !data || !datalen) {
1429 		error_print();
1430 		return -1;
1431 	}
1432 	if (tls_record_type(record) != TLS_record_application_data) {
1433 		error_print();
1434 		return -1;
1435 	}
1436 	*datalen = ((size_t)record[3] << 8) | record[4];
1437 	*data = *datalen ? record + TLS_RECORD_HEADER_SIZE : 0;
1438 	return 1;
1439 }
1440 
/*
 * Linear search for `cipher` in `list` (`list_count` entries).
 * Returns 1 if found, 0 if absent, -1 if the list is NULL or empty.
 */
int tls_cipher_suite_in_list(int cipher, const int *list, size_t list_count)
{
	const int *entry;
	const int *end;

	if (list == NULL || list_count == 0) {
		error_print();
		return -1;
	}
	end = list + list_count;
	for (entry = list; entry != end; entry++) {
		if (*entry == cipher) {
			return 1;
		}
	}
	return 0;
}
1455 
tls_record_send(const uint8_t * record,size_t recordlen,int sock)1456 int tls_record_send(const uint8_t *record, size_t recordlen, int sock)
1457 {
1458 	ssize_t r;
1459 	if (!record) {
1460 		error_print();
1461 		return -1;
1462 	}
1463 	if (recordlen < TLS_RECORD_HEADER_SIZE) {
1464 		error_print();
1465 		return -1;
1466 	}
1467 	if (tls_record_length(record) != recordlen) {
1468 		error_print();
1469 		return -1;
1470 	}
1471 	if ((r = send(sock, record, recordlen, 0)) < 0) {
1472 		perror("tls_record_send");
1473 		error_print();
1474 		return -1;
1475 	} else if (r != recordlen) {
1476 		error_print();
1477 		return -1;
1478 	}
1479 	return 1;
1480 }
1481 
tls_record_do_recv(uint8_t * record,size_t * recordlen,int sock)1482 int tls_record_do_recv(uint8_t *record, size_t *recordlen, int sock)
1483 {
1484 	ssize_t r;
1485 	int type;
1486 	size_t len;
1487 
1488 	len = 5;
1489 	while (len) {
1490 		if ((r = recv(sock, record + 5 - len, len, 0)) < 0) {
1491 			perror("tls_record_do_recv");
1492 			error_print();
1493 			return -1;
1494 		}
1495 		if (r == 0) {
1496 			perror("tls_record_do_recv");
1497 			error_print();
1498 			return 0;
1499 		}
1500 
1501 		len -= r;
1502 	}
1503 	if (!tls_record_type_name(tls_record_type(record))) {
1504 		error_print();
1505 		return -1;
1506 	}
1507 	if (!tls_protocol_name(tls_record_protocol(record))) {
1508 		error_print();
1509 		return -1;
1510 	}
1511 	len = (size_t)record[3] << 8 | record[4];
1512 	*recordlen = 5 + len;
1513 	if (*recordlen > TLS_MAX_RECORD_SIZE) {
1514 		// 这里只检查是否超过最大长度,握手协议的长度检查由上层协议完成
1515 		error_print();
1516 		return -1;
1517 	}
1518 	while (len) {
1519 		if ((r = recv(sock, record + *recordlen - len, len, 0)) < 0) {
1520 			perror("tls_record_do_recv");
1521 			error_print();
1522 			return -1;
1523 		}
1524 		len -= r;
1525 	}
1526 	return 1;
1527 }
1528 
tls_record_recv(uint8_t * record,size_t * recordlen,int sock)1529 int tls_record_recv(uint8_t *record, size_t *recordlen, int sock)
1530 {
1531 retry:
1532 	if (tls_record_do_recv(record, recordlen, sock) != 1) {
1533 		error_print();
1534 		return -1;
1535 	}
1536 
1537 	if (tls_record_type(record) == TLS_record_alert) {
1538 		int level;
1539 		int alert;
1540 		if (tls_record_get_alert(record, &level, &alert) != 1) {
1541 			error_print();
1542 			return -1;
1543 		}
1544 		tls_record_trace(stderr, record, *recordlen, 0, 0);
1545 		if (level == TLS_alert_level_warning) {
1546 			// 忽略Warning,读取下一个记录
1547 			error_puts("Warning record received!\n");
1548 			goto retry;
1549 		}
1550 		if (alert == TLS_alert_close_notify) {
1551 			// close_notify是唯一需要提供反馈的Fatal Alert,其他直接中止连接
1552 			uint8_t alert_record[TLS_ALERT_RECORD_SIZE];
1553 			size_t alert_record_len;
1554 			tls_record_set_type(alert_record, TLS_record_alert);
1555 			tls_record_set_protocol(alert_record, tls_record_protocol(record));
1556 			tls_record_set_alert(alert_record, &alert_record_len, TLS_alert_level_fatal, TLS_alert_close_notify);
1557 
1558 			tls_trace("send Alert close_notifiy\n");
1559 			tls_record_trace(stderr, alert_record, alert_record_len, 0, 0);
1560 			tls_record_send(alert_record, alert_record_len, sock);
1561 		}
1562 		// 返回错误0通知调用方不再做任何处理(无需再发送Alert)
1563 		return 0;
1564 	}
1565 	return 1;
1566 }
1567 
/*
 * Increment the 8-byte big-endian sequence number in place.
 *
 * The carry now propagates through all eight bytes, including seq_num[0].
 * The previous loop stopped at index 1, so the counter silently wrapped
 * after 2^56 increments.
 *
 * Returns 1 on success. Returns -1 if the counter overflowed: every byte
 * was 0xff and the value has wrapped to zero. A TLS sequence number must
 * never wrap, so callers should treat -1 as fatal for the connection.
 */
int tls_seq_num_incr(uint8_t seq_num[8])
{
	int i;
	for (i = 7; i >= 0; i--) {
		seq_num[i]++;
		if (seq_num[i]) {
			/* No carry out of this byte: done. */
			return 1;
		}
	}
	/* Carry out of the most significant byte: 64-bit overflow. */
	return -1;
}
1578 
tls_compression_methods_has_null_compression(const uint8_t * meths,size_t methslen)1579 int tls_compression_methods_has_null_compression(const uint8_t *meths, size_t methslen)
1580 {
1581 	if (!meths || !methslen) {
1582 		error_print();
1583 		return -1;
1584 	}
1585 	while (methslen--) {
1586 		if (*meths++ == TLS_compression_null) {
1587 			return 1;
1588 		}
1589 	}
1590 	error_print();
1591 	return -1;
1592 }
1593 
/*
 * Send a fatal-level Alert with description `alert` on the connection.
 *
 * The record protocol is conn->protocol, except that a TLS 1.3 connection
 * uses TLS_protocol_tls12 in the record header.
 *
 * The record-building calls are now checked. Previously an invalid `alert`
 * (or protocol) left the stack buffer partly uninitialised and it was sent
 * anyway.
 *
 * Returns 1 on success, -1 on a NULL `conn`, a failure to build the
 * record, or a send failure.
 */
int tls_send_alert(TLS_CONNECT *conn, int alert)
{
	uint8_t record[5 + 2];
	size_t recordlen;

	if (!conn) {
		error_print();
		return -1;
	}
	if (tls_record_set_protocol(record,
			conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol) != 1
		|| tls_record_set_alert(record, &recordlen, TLS_alert_level_fatal, alert) != 1) {
		error_print();
		return -1;
	}

	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, record, recordlen, 0, 0);
	return 1;
}
1613 
tls_alert_level(int alert)1614 int tls_alert_level(int alert)
1615 {
1616 	switch (alert) {
1617 	case TLS_alert_bad_certificate:
1618 	case TLS_alert_unsupported_certificate:
1619 	case TLS_alert_certificate_revoked:
1620 	case TLS_alert_certificate_expired:
1621 	case TLS_alert_certificate_unknown:
1622 		return 0;
1623 	case TLS_alert_user_canceled:
1624 	case TLS_alert_no_renegotiation:
1625 		return TLS_alert_level_warning;
1626 	default:
1627 		return TLS_alert_level_fatal;
1628 	}
1629 	return -1;
1630 }
1631 
/*
 * Send a warning-level Alert with description `alert` on the connection.
 *
 * Alerts that tls_alert_level() classifies as fatal are refused. The
 * record protocol is conn->protocol, except that a TLS 1.3 connection uses
 * TLS_protocol_tls12 in the record header.
 *
 * The record-building calls are now checked. Previously a failure there
 * left the stack buffer partly uninitialised and it was sent anyway.
 *
 * Returns 1 on success, -1 on a NULL `conn`, a fatal-only alert, a failure
 * to build the record, or a send failure.
 */
int tls_send_warning(TLS_CONNECT *conn, int alert)
{
	uint8_t record[5 + 2];
	size_t recordlen;

	if (!conn) {
		error_print();
		return -1;
	}
	if (tls_alert_level(alert) == TLS_alert_level_fatal) {
		error_print();
		return -1;
	}
	if (tls_record_set_protocol(record,
			conn->protocol == TLS_protocol_tls13 ? TLS_protocol_tls12 : conn->protocol) != 1
		|| tls_record_set_alert(record, &recordlen, TLS_alert_level_warning, alert) != 1) {
		error_print();
		return -1;
	}

	if (tls_record_send(record, recordlen, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, record, recordlen, 0, 0);
	return 1;
}
1655 
/*
 * Encrypt and send up to TLS_MAX_PLAINTEXT_SIZE bytes of application data
 * as a single record.
 *
 * conn    - connection; the write keys and sequence number of the local
 *           role (client or server, per conn->is_client) are used.
 * in      - plaintext to send.
 * inlen   - plaintext length; anything beyond TLS_MAX_PLAINTEXT_SIZE is
 *           not sent by this call.
 * sentlen - receives the number of plaintext bytes actually consumed.
 *
 * Returns 1 on success, -1 on invalid arguments or any failure while
 * building, encrypting, or sending the record.
 */
int tls_send(TLS_CONNECT *conn, const uint8_t *in, size_t inlen, size_t *sentlen)
{
	const SM3_HMAC_CTX *mac_ctx;
	const SM4_KEY *cipher_key;
	uint8_t *seq;
	uint8_t *rec;
	size_t cipher_len;

	if (conn == NULL) {
		error_print();
		return -1;
	}
	if (in == NULL || inlen == 0 || sentlen == NULL) {
		error_print();
		return -1;
	}

	/* One record per call: clamp to the maximum plaintext size. */
	if (inlen > TLS_MAX_PLAINTEXT_SIZE) {
		inlen = TLS_MAX_PLAINTEXT_SIZE;
	}

	/* Pick the write-side keys for our own role. */
	mac_ctx = conn->is_client ? &conn->client_write_mac_ctx : &conn->server_write_mac_ctx;
	cipher_key = conn->is_client ? &conn->client_write_enc_key : &conn->server_write_enc_key;
	seq = conn->is_client ? conn->client_seq_num : conn->server_seq_num;
	rec = conn->record;

	tls_trace("send ApplicationData\n");

	/* The header is first built with the plaintext length; it is passed to
	 * tls_cbc_encrypt() in that form. */
	if (tls_record_set_type(rec, TLS_record_application_data) != 1
		|| tls_record_set_protocol(rec, conn->protocol) != 1
		|| tls_record_set_length(rec, inlen) != 1) {
		error_print();
		return -1;
	}

	if (tls_cbc_encrypt(mac_ctx, cipher_key, seq, tls_record_header(rec),
		in, inlen, tls_record_data(rec), &cipher_len) != 1) {
		error_print();
		return -1;
	}

	/* Now that the ciphertext size is known, fix up the header length. */
	if (tls_record_set_length(rec, cipher_len) != 1) {
		error_print();
		return -1;
	}
	tls_seq_num_incr(seq);

	if (tls_record_send(rec, tls_record_length(rec), conn->sock) != 1) {
		error_print();
		return -1;
	}
	*sentlen = inlen;
	tls_record_trace(stderr, rec, tls_record_length(rec), 0, 0);
	return 1;
}
1717 
/*
 * Receive one record, decrypt it into conn->databuf, and make the
 * plaintext available through conn->data / conn->datalen.
 *
 * The peer's write keys and sequence number are used (server keys when we
 * are the client, and vice versa). After decryption the plaintext is also
 * written back into conn->record for tracing.
 *
 * Returns 1 on success; the tls_record_recv() result (0 or negative) if no
 * record could be obtained; -1 if decryption fails.
 */
int tls_do_recv(TLS_CONNECT *conn)
{
	const SM3_HMAC_CTX *mac_ctx;
	const SM4_KEY *cipher_key;
	uint8_t *seq;
	uint8_t *rec = conn->record;
	size_t rec_len;
	int rv;

	/* Pick the peer's write-side keys. */
	mac_ctx = conn->is_client ? &conn->server_write_mac_ctx : &conn->client_write_mac_ctx;
	cipher_key = conn->is_client ? &conn->server_write_enc_key : &conn->client_write_enc_key;
	seq = conn->is_client ? conn->server_seq_num : conn->client_seq_num;

	tls_trace("recv ApplicationData\n");
	rv = tls_record_recv(rec, &rec_len, conn->sock);
	if (rv != 1) {
		if (rv < 0) {
			error_print();
		}
		return rv;
	}

	tls_record_trace(stderr, rec, rec_len, 0, 0);
	if (tls_cbc_decrypt(mac_ctx, cipher_key, seq, rec,
		tls_record_data(rec), tls_record_data_length(rec),
		conn->databuf, &conn->datalen) != 1) {
		error_print();
		return -1;
	}
	conn->data = conn->databuf;
	tls_seq_num_incr(seq);

	/* Put the plaintext back into the record so it can be traced. */
	tls_record_set_data(rec, conn->data, conn->datalen);
	tls_trace("decrypt ApplicationData\n");
	tls_record_trace(stderr, rec, tls_record_length(rec), 0, 0);
	return 1;
}
1759 
/*
 * Copy up to `outlen` bytes of decrypted application data into `out`.
 *
 * If no buffered plaintext remains (conn->datalen == 0), one more record
 * is received and decrypted via tls_do_recv() first. Unread plaintext
 * stays buffered in `conn` for the next call.
 *
 * Returns 1 on success with *recvlen set to the number of bytes copied;
 * otherwise the tls_do_recv() result, or -1 on invalid arguments.
 */
int tls_recv(TLS_CONNECT *conn, uint8_t *out, size_t outlen, size_t *recvlen)
{
	size_t ncopy;

	if (conn == NULL || out == NULL || outlen == 0 || recvlen == NULL) {
		error_print();
		return -1;
	}

	if (conn->datalen == 0) {
		int rv = tls_do_recv(conn);
		if (rv != 1) {
			if (rv != 0) {
				error_print();
			}
			return rv;
		}
	}

	/* Hand out whichever is smaller: the caller's space or the buffered data. */
	if (outlen <= conn->datalen) {
		ncopy = outlen;
	} else {
		ncopy = conn->datalen;
	}
	*recvlen = ncopy;
	memcpy(out, conn->data, ncopy);
	conn->data += ncopy;
	conn->datalen -= ncopy;
	return 1;
}
1779 
/*
 * Close the TLS session: send a close_notify alert, then read one record
 * from the peer into conn->record and trace it.
 *
 * The received record is not inspected here beyond what
 * tls_record_do_recv() itself validates.
 *
 * Returns 1 on success, -1 on a NULL `conn`, a send failure, or a failure
 * to read the peer's record.
 */
int tls_shutdown(TLS_CONNECT *conn)
{
	size_t reply_len;

	if (conn == NULL) {
		error_print();
		return -1;
	}

	tls_trace("send Alert close_notify\n");
	if (tls_send_alert(conn, TLS_alert_close_notify) != 1) {
		error_print();
		return -1;
	}

	tls_trace("recv Alert close_notify\n");
	if (tls_record_do_recv(conn->record, &reply_len, conn->sock) != 1) {
		error_print();
		return -1;
	}
	tls_record_trace(stderr, conn->record, reply_len, 0, 0);

	return 1;
}
1802 
/*
 * Build a list of authority names from a buffer of DER certificates.
 *
 * For each certificate in `certs`, its subject name is re-encoded as a DER
 * SEQUENCE and appended to `names` behind a uint16 length prefix.
 *
 * names    - output buffer.
 * nameslen - receives the total number of bytes produced.
 * maxlen   - capacity of `names` in bytes.
 *
 * Each entry consumes tls_uint16_size() + alen bytes, and the remaining
 * capacity is now reduced by that full amount. Previously only alen was
 * subtracted, so the bound drifted by 2 bytes per name and the output
 * buffer could be overrun.
 *
 * Returns 1 on success; -1 on a NULL `nameslen`, a parse/encode failure,
 * or when the next entry does not fit in the remaining capacity.
 */
int tls_authorities_from_certs(uint8_t *names, size_t *nameslen, size_t maxlen, const uint8_t *certs, size_t certslen)
{
	const uint8_t *cert;
	size_t certlen;
	const uint8_t *name;
	size_t namelen;

	if (!nameslen) {
		error_print();
		return -1;
	}

	*nameslen = 0;
	while (certslen) {
		size_t alen = 0;
		/* Dry-run encode (NULL output) to learn the DER length of the name. */
		if (x509_cert_from_der(&cert, &certlen, &certs, &certslen) != 1
			|| x509_cert_get_subject(cert, certlen, &name, &namelen) != 1
			|| asn1_sequence_to_der(name, namelen, NULL, &alen) != 1) {
			error_print();
			return -1;
		}
		if (tls_uint16_size() + alen > maxlen) {
			error_print();
			return -1;
		}
		tls_uint16_to_bytes(alen, &names, nameslen);
		if (asn1_sequence_to_der(name, namelen, &names, nameslen) != 1) {
			error_print();
			return -1;
		}
		/* Account for the length prefix as well as the name itself. */
		maxlen -= tls_uint16_size() + alen;
	}
	return 1;
}
1832 
/*
 * Check whether the issuer of the last certificate in `certs` matches one
 * of the names in `ca_names`.
 *
 * `ca_names` is a sequence of uint16-length-prefixed entries; each entry
 * must contain exactly one DER SEQUENCE.
 *
 * Returns 1 on a match, 0 when no listed name matches, -1 on any parse
 * failure.
 */
int tls_authorities_issued_certificate(const uint8_t *ca_names, size_t ca_names_len, const uint8_t *certs, size_t certslen)
{
	const uint8_t *last_cert;
	size_t last_cert_len;
	const uint8_t *issuer;
	size_t issuer_len;

	if (x509_certs_get_last(certs, certslen, &last_cert, &last_cert_len) != 1
		|| x509_cert_get_issuer(last_cert, last_cert_len, &issuer, &issuer_len) != 1) {
		error_print();
		return -1;
	}

	while (ca_names_len > 0) {
		const uint8_t *entry;
		size_t entry_len;
		const uint8_t *ca_name;
		size_t ca_name_len;

		if (tls_uint16array_from_bytes(&entry, &entry_len, &ca_names, &ca_names_len) != 1) {
			error_print();
			return -1;
		}
		/* The entry must be a single SEQUENCE with nothing after it. */
		if (asn1_sequence_from_der(&ca_name, &ca_name_len, &entry, &entry_len) != 1
			|| asn1_length_is_zero(entry_len) != 1) {
			error_print();
			return -1;
		}
		if (x509_name_equ(ca_name, ca_name_len, issuer, issuer_len) == 1) {
			return 1;
		}
	}

	error_print();
	return 0;
}
1867 
tls_cert_types_accepted(const uint8_t * types,size_t types_len,const uint8_t * client_certs,size_t client_certs_len)1868 int tls_cert_types_accepted(const uint8_t *types, size_t types_len, const uint8_t *client_certs, size_t client_certs_len)
1869 {
1870 	const uint8_t *cert;
1871 	size_t certlen;
1872 	int sig_alg;
1873 	size_t i;
1874 
1875 	if (x509_certs_get_cert_by_index(client_certs, client_certs_len, 0, &cert, &certlen) != 1) {
1876 		error_print();
1877 		return -1;
1878 	}
1879 	if ((sig_alg = tls_cert_type_from_oid(OID_sm2sign_with_sm3)) < 0) {
1880 		error_print();
1881 		return -1;
1882 	}
1883 	for (i = 0; i < types_len; i++) {
1884 		if (sig_alg == types[i]) {
1885 			return 1;
1886 		}
1887 	}
1888 	return 0;
1889 }
1890 
/*
 * Reset a client-verify context to all zeroes.
 * Returns 1 on success, -1 if `ctx` is NULL.
 */
int tls_client_verify_init(TLS_CLIENT_VERIFY_CTX *ctx)
{
	if (ctx == NULL) {
		error_print();
		return -1;
	}
	memset(ctx, 0, sizeof(*ctx));
	return 1;
}
1900 
/*
 * Store a heap copy of one handshake message in the next free slot of
 * `ctx` (slots 0..7) and advance ctx->index.
 *
 * The copies are owned by `ctx`; tls_client_verify_cleanup() frees them.
 *
 * Returns 1 on success; -1 on a NULL/empty argument, when all eight slots
 * are in use (or the index is out of range), or when allocation fails.
 */
int tls_client_verify_update(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *handshake, size_t handshake_len)
{
	uint8_t *copy;

	if (ctx == NULL || handshake == NULL || handshake_len == 0) {
		error_print();
		return -1;
	}
	if (!(ctx->index >= 0 && ctx->index <= 7)) {
		error_print();
		return -1;
	}

	copy = malloc(handshake_len);
	if (copy == NULL) {
		error_print();
		return -1;
	}
	memcpy(copy, handshake, handshake_len);

	ctx->handshake[ctx->index] = copy;
	ctx->handshake_len[ctx->index] = handshake_len;
	ctx->index++;
	return 1;
}
1922 
/*
 * Verify `sig` over the eight handshake messages stored in `ctx`, using
 * `public_key` and the default SM2 signer ID.
 *
 * Exactly eight messages must have been recorded (ctx->index == 8).
 *
 * Returns the non-negative result of sm2_verify_finish() on completion, or
 * -1 on invalid arguments, a wrong message count, or an SM2 failure.
 */
int tls_client_verify_finish(TLS_CLIENT_VERIFY_CTX *ctx, const uint8_t *sig, size_t siglen, const SM2_KEY *public_key)
{
	SM2_SIGN_CTX verify_ctx;
	int result;
	int slot;

	if (ctx == NULL || sig == NULL || siglen == 0 || public_key == NULL) {
		error_print();
		return -1;
	}
	if (ctx->index != 8) {
		error_print();
		return -1;
	}

	if (sm2_verify_init(&verify_ctx, public_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH) != 1) {
		error_print();
		return -1;
	}
	/* Feed the stored messages in the order they were recorded. */
	for (slot = 0; slot < 8; slot++) {
		if (sm2_verify_update(&verify_ctx, ctx->handshake[slot], ctx->handshake_len[slot]) != 1) {
			error_print();
			return -1;
		}
	}

	result = sm2_verify_finish(&verify_ctx, sig, siglen);
	if (result < 0) {
		error_print();
		return -1;
	}
	return result;
}
1954 
/*
 * Free every handshake copy stored in `ctx` (slots 0 .. ctx->index - 1)
 * and clear the freed slots. ctx->index itself is left unchanged.
 * A NULL `ctx` is ignored.
 */
void tls_client_verify_cleanup(TLS_CLIENT_VERIFY_CTX *ctx)
{
	int slot;

	if (ctx == NULL) {
		return;
	}
	for (slot = 0; slot < ctx->index; slot++) {
		if (ctx->handshake[slot] == NULL) {
			continue;
		}
		free(ctx->handshake[slot]);
		ctx->handshake[slot] = NULL;
		ctx->handshake_len[slot] = 0;
	}
}
1968 
/*
 * Pick the first server cipher suite that the client also offered.
 *
 * client_ciphers / client_ciphers_len - the client's list as raw bytes,
 *     decoded one uint16 at a time.
 * server_ciphers / server_ciphers_cnt - the server's list, in preference
 *     order (the outer loop runs over it).
 *
 * Returns 1 with *selected_cipher set on a match, 0 when there is no
 * common suite, -1 on invalid arguments or a malformed client list.
 */
int tls_cipher_suites_select(const uint8_t *client_ciphers, size_t client_ciphers_len,
	const int *server_ciphers, size_t server_ciphers_cnt,
	int *selected_cipher)
{
	size_t s;

	if (client_ciphers == NULL || client_ciphers_len == 0
		|| server_ciphers == NULL || server_ciphers_cnt == 0
		|| selected_cipher == NULL) {
		error_print();
		return -1;
	}

	for (s = 0; s < server_ciphers_cnt; s++) {
		/* Rescan the client's list from the start for each server suite. */
		const uint8_t *cursor = client_ciphers;
		size_t remaining = client_ciphers_len;

		while (remaining > 0) {
			uint16_t offered;

			if (tls_uint16_from_bytes(&offered, &cursor, &remaining) != 1) {
				error_print();
				return -1;
			}
			if (offered == server_ciphers[s]) {
				*selected_cipher = server_ciphers[s];
				return 1;
			}
		}
	}
	return 0;
}
1996 
/*
 * Release everything owned by `ctx`: securely wipe both private keys, free
 * the certificate and CA certificate buffers, then zero the whole
 * structure. A NULL `ctx` is ignored.
 */
void tls_ctx_cleanup(TLS_CTX *ctx)
{
	if (ctx == NULL) {
		return;
	}
	gmssl_secure_clear(&ctx->signkey, sizeof(SM2_KEY));
	gmssl_secure_clear(&ctx->kenckey, sizeof(SM2_KEY));
	/* free(NULL) is a no-op, so unset buffers need no guard. */
	free(ctx->certs);
	free(ctx->cacerts);
	memset(ctx, 0, sizeof(*ctx));
}
2007 
/*
 * Initialise `ctx` for the given protocol and role.
 *
 * The structure is zeroed first, so on a -1 return for an unsupported
 * protocol it has still been cleared. Accepted protocols are
 * TLS_protocol_tlcp, TLS_protocol_tls12 and TLS_protocol_tls13. Any
 * non-zero `is_client` is stored as 1.
 *
 * Returns 1 on success, -1 on a NULL `ctx` or an unsupported protocol.
 */
int tls_ctx_init(TLS_CTX *ctx, int protocol, int is_client)
{
	if (ctx == NULL) {
		error_print();
		return -1;
	}
	memset(ctx, 0, sizeof(*ctx));

	if (protocol != TLS_protocol_tlcp
		&& protocol != TLS_protocol_tls12
		&& protocol != TLS_protocol_tls13) {
		error_print();
		return -1;
	}
	ctx->protocol = protocol;
	ctx->is_client = (is_client != 0);
	return 1;
}
2029 
/*
 * Store the list of cipher suites in `ctx`.
 *
 * The count must be in 1..TLS_MAX_CIPHER_SUITES_COUNT and every suite must
 * be known to tls_cipher_suite_name(). The list is validated completely
 * before anything is copied, so `ctx` is untouched on failure.
 *
 * Returns 1 on success, -1 otherwise.
 */
int tls_ctx_set_cipher_suites(TLS_CTX *ctx, const int *cipher_suites, size_t cipher_suites_cnt)
{
	size_t k;

	if (ctx == NULL || cipher_suites == NULL || cipher_suites_cnt == 0) {
		error_print();
		return -1;
	}
	if (cipher_suites_cnt > TLS_MAX_CIPHER_SUITES_COUNT) {
		error_print();
		return -1;
	}

	/* Pass 1: validate. */
	for (k = 0; k < cipher_suites_cnt; k++) {
		if (tls_cipher_suite_name(cipher_suites[k]) == NULL) {
			error_print();
			return -1;
		}
	}
	/* Pass 2: commit. */
	for (k = 0; k < cipher_suites_cnt; k++) {
		ctx->cipher_suites[k] = cipher_suites[k];
	}
	ctx->cipher_suites_cnt = cipher_suites_cnt;
	return 1;
}
2054 
/*
 * Load the CA certificates from `cacertsfile` into `ctx` and record the
 * verification depth.
 *
 * ctx         - must already have a valid protocol and must not yet hold
 *               CA certificates.
 * cacertsfile - file read by x509_certs_new_from_file().
 * depth       - must be in 0..TLS_MAX_VERIFY_DEPTH.
 *
 * If the file loads but yields no certificate data, the buffer is released
 * and ctx->cacerts is reset. Previously it stayed set, so every later call
 * failed the "already set" check.
 *
 * Returns 1 on success, -1 on invalid arguments or a load failure.
 */
int tls_ctx_set_ca_certificates(TLS_CTX *ctx, const char *cacertsfile, int depth)
{
	if (!ctx || !cacertsfile) {
		error_print();
		return -1;
	}
	if (depth < 0 || depth > TLS_MAX_VERIFY_DEPTH) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->cacerts) {
		error_print();
		return -1;
	}
	if (x509_certs_new_from_file(&ctx->cacerts, &ctx->cacertslen, cacertsfile) != 1) {
		error_print();
		return -1;
	}
	if (ctx->cacertslen == 0) {
		/* Nothing usable was loaded: drop the buffer so ctx stays reusable. */
		free(ctx->cacerts);
		ctx->cacerts = NULL;
		error_print();
		return -1;
	}

	ctx->verify_depth = depth;
	return 1;
}
2085 
/*
 * Load a certificate chain and its private key into the context.
 *
 * ctx        context to update; must have a valid protocol and must not
 *            already hold a certificate chain
 * chainfile  path handed to x509_certs_new_from_file()
 * keyfile    path of the PEM file holding the encrypted private key
 * keypass    password passed to sm2_private_key_info_decrypt_from_pem()
 *
 * The private key must match the subject public key of the first
 * certificate (index 0) in the chain.
 *
 * Returns 1 on success, -1 on error. Once loading has started, every exit
 * goes through the common cleanup: the local key copy is wiped, the key
 * file is closed and an uncommitted chain buffer is freed.
 */
int tls_ctx_set_certificate_and_key(TLS_CTX *ctx, const char *chainfile,
	const char *keyfile, const char *keypass)
{
	int ret = -1;
	uint8_t *certs = NULL;
	size_t certslen = 0;
	FILE *keyfp = NULL;
	SM2_KEY key;
	const uint8_t *cert;
	size_t certlen;
	SM2_KEY public_key;

	if (!ctx || !chainfile || !keyfile || !keypass) {
		error_print();
		return -1;
	}
	if (!tls_protocol_name(ctx->protocol)) {
		error_print();
		return -1;
	}
	if (ctx->certs) {
		error_print();
		return -1;
	}

	if (x509_certs_new_from_file(&certs, &certslen, chainfile) != 1) {
		error_print();
		goto end;
	}
	if (!(keyfp = fopen(keyfile, "r"))) {
		error_print();
		goto end;
	}
	if (sm2_private_key_info_decrypt_from_pem(&key, keypass, keyfp) != 1) {
		error_print();
		goto end;
	}
	/* from here on `key` holds private key material, so every failure must
	 * leave through `end` to wipe it and release certs/keyfp */
	if (x509_certs_get_cert_by_index(certs, certslen, 0, &cert, &certlen) != 1
		|| x509_cert_get_subject_public_key(cert, certlen, &public_key) != 1) {
		error_print();
		goto end;
	}
	if (sm2_public_key_equ(&key, &public_key) != 1) {
		error_print();
		goto end;
	}

	/* commit: ownership of the chain buffer moves to ctx */
	ctx->certs = certs;
	ctx->certslen = certslen;
	ctx->signkey = key;
	certs = NULL;
	ret = 1;

end:
	gmssl_secure_clear(&key, sizeof(key));
	free(certs);
	if (keyfp) fclose(keyfp);
	return ret;
}
2144 
/*
 * Load a certificate chain together with a signing key and a key-encryption
 * key into the context.
 *
 * ctx          context to update; must have a valid protocol and must not
 *              already hold a certificate chain
 * chainfile    path handed to x509_certs_new_from_file()
 * signkeyfile  PEM file with the encrypted signing private key
 * signkeypass  password for signkeyfile
 * kenckeyfile  PEM file with the encrypted key-encryption private key
 * kenckeypass  password for kenckeyfile
 *
 * The signing key must match the subject public key of the certificate at
 * index 0 of the chain, the key-encryption key that of the certificate at
 * index 1.
 *
 * Returns 1 on success, -1 on error.
 */
int tls_ctx_set_tlcp_server_certificate_and_keys(TLS_CTX *ctx, const char *chainfile,
	const char *signkeyfile, const char *signkeypass,
	const char *kenckeyfile, const char *kenckeypass)
{
	int rv = -1;
	uint8_t *chain = NULL;
	size_t chainlen;
	FILE *sign_fp = NULL;
	FILE *kenc_fp = NULL;
	SM2_KEY sign_priv;
	SM2_KEY kenc_priv;
	SM2_KEY cert_pub;
	const uint8_t *der;
	size_t derlen;

	if (ctx == NULL || chainfile == NULL
		|| signkeyfile == NULL || signkeypass == NULL
		|| kenckeyfile == NULL || kenckeypass == NULL) {
		error_print();
		return -1;
	}
	if (tls_protocol_name(ctx->protocol) == NULL) {
		error_print();
		return -1;
	}
	if (ctx->certs != NULL) {
		error_print();
		return -1;
	}

	/* nothing has been acquired yet, so a plain return is enough here */
	if (x509_certs_new_from_file(&chain, &chainlen, chainfile) != 1) {
		error_print();
		return -1;
	}

	/* signing key: decrypt it, then compare with certificate #0 */
	sign_fp = fopen(signkeyfile, "r");
	if (sign_fp == NULL) {
		error_print();
		goto end;
	}
	if (sm2_private_key_info_decrypt_from_pem(&sign_priv, signkeypass, sign_fp) != 1) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(chain, chainlen, 0, &der, &derlen) != 1) {
		error_print();
		goto end;
	}
	if (x509_cert_get_subject_public_key(der, derlen, &cert_pub) != 1) {
		error_print();
		goto end;
	}
	if (sm2_public_key_equ(&sign_priv, &cert_pub) != 1) {
		error_print();
		goto end;
	}

	/* key-encryption key: decrypt it, then compare with certificate #1 */
	kenc_fp = fopen(kenckeyfile, "r");
	if (kenc_fp == NULL) {
		error_print();
		goto end;
	}
	if (sm2_private_key_info_decrypt_from_pem(&kenc_priv, kenckeypass, kenc_fp) != 1) {
		error_print();
		goto end;
	}
	if (x509_certs_get_cert_by_index(chain, chainlen, 1, &der, &derlen) != 1) {
		error_print();
		goto end;
	}
	if (x509_cert_get_subject_public_key(der, derlen, &cert_pub) != 1) {
		error_print();
		goto end;
	}
	if (sm2_public_key_equ(&kenc_priv, &cert_pub) != 1) {
		error_print();
		goto end;
	}

	/* commit: ownership of the chain buffer moves to ctx */
	ctx->certs = chain;
	ctx->certslen = chainlen;
	ctx->signkey = sign_priv;
	ctx->kenckey = kenc_priv;
	chain = NULL;
	rv = 1;

end:
	/* wipe the local key copies whether or not they were committed */
	gmssl_secure_clear(&sign_priv, sizeof(sign_priv));
	gmssl_secure_clear(&kenc_priv, sizeof(kenc_priv));
	free(chain);
	if (sign_fp != NULL) {
		fclose(sign_fp);
	}
	if (kenc_fp != NULL) {
		fclose(kenc_fp);
	}
	return rv;
}
2224 
/*
 * Initialize a connection object from a context.
 *
 * conn  connection to initialize; it is zeroed first
 * ctx   context whose protocol, role, cipher suites, certificates,
 *       CA certificates and keys are copied into conn
 *
 * The context's certificate chain is copied into client_certs or
 * server_certs depending on ctx->is_client. A context without certificates
 * or without CA certificates is accepted; the corresponding buffer in conn
 * then stays empty.
 *
 * Returns 1 on success, -1 when an argument is NULL or when the context
 * data does not fit into the connection's fixed-size buffers.
 */
int tls_init(TLS_CONNECT *conn, const TLS_CTX *ctx)
{
	size_t i;

	if (!conn || !ctx) {
		error_print();
		return -1;
	}
	memset(conn, 0, sizeof(*conn));

	conn->protocol = ctx->protocol;
	conn->is_client = ctx->is_client;

	/* never write past the connection's own cipher suite array */
	if (ctx->cipher_suites_cnt
		> sizeof(conn->cipher_suites)/sizeof(conn->cipher_suites[0])) {
		error_print();
		return -1;
	}
	for (i = 0; i < ctx->cipher_suites_cnt; i++) {
		conn->cipher_suites[i] = ctx->cipher_suites[i];
	}
	conn->cipher_suites_cnt = ctx->cipher_suites_cnt;

	if (ctx->certslen > TLS_MAX_CERTIFICATES_SIZE
		|| (ctx->certslen && !ctx->certs)) {
		error_print();
		return -1;
	}
	/* memcpy() must not be given a NULL source, even for a zero length,
	 * so copy only when the context really holds certificates */
	if (ctx->certslen) {
		if (conn->is_client) {
			memcpy(conn->client_certs, ctx->certs, ctx->certslen);
			conn->client_certs_len = ctx->certslen;
		} else {
			memcpy(conn->server_certs, ctx->certs, ctx->certslen);
			conn->server_certs_len = ctx->certslen;
		}
	}

	if (ctx->cacertslen > TLS_MAX_CERTIFICATES_SIZE
		|| (ctx->cacertslen && !ctx->cacerts)) {
		error_print();
		return -1;
	}
	if (ctx->cacertslen) {
		memcpy(conn->ca_certs, ctx->cacerts, ctx->cacertslen);
		conn->ca_certs_len = ctx->cacertslen;
	}

	conn->sign_key = ctx->signkey;
	conn->kenc_key = ctx->kenckey;

	return 1;
}
2262 
/*
 * Wipe a connection object with gmssl_secure_clear().
 *
 * A NULL conn is ignored, matching tls_ctx_cleanup().
 */
void tls_cleanup(TLS_CONNECT *conn)
{
	if (conn) {
		gmssl_secure_clear(conn, sizeof(TLS_CONNECT));
	}
}
2267 
2268 
/*
 * Attach a socket to the connection and switch it to blocking mode.
 *
 * conn  connection that will use the socket
 * sock  socket descriptor; its O_NONBLOCK flag is cleared
 *
 * Returns 1 on success, -1 when conn is NULL or an fcntl() call fails.
 * conn is checked before the socket flags are touched, so a rejected
 * call has no side effect on the descriptor.
 */
int tls_set_socket(TLS_CONNECT *conn, int sock)
{
	int opts;

	if (!conn) {
		error_print();
		return -1;
	}

	if ((opts = fcntl(sock, F_GETFL)) < 0) {
		/* perror() goes first so errno still belongs to the failed
		 * fcntl() call, whatever error_print() does internally */
		perror("tls_set_socket");
		error_print();
		return -1;
	}
	opts &= ~O_NONBLOCK;
	if (fcntl(sock, F_SETFL, opts) < 0) {
		perror("tls_set_socket");
		error_print();
		return -1;
	}
	conn->sock = sock;
	return 1;
}
2286 
/*
 * Run the handshake that matches the connection's protocol and role.
 *
 * Dispatches to the connect function when conn->is_client is set and to
 * the accept function otherwise, for TLCP, TLS 1.2 and TLS 1.3.
 *
 * Returns the result of the selected handshake function, or -1 when conn
 * is NULL or conn->protocol is none of the three supported protocols.
 */
int tls_do_handshake(TLS_CONNECT *conn)
{
	if (!conn) {
		error_print();
		return -1;
	}

	switch (conn->protocol) {
	case TLS_protocol_tlcp:
		return conn->is_client ? tlcp_do_connect(conn) : tlcp_do_accept(conn);
	case TLS_protocol_tls12:
		return conn->is_client ? tls12_do_connect(conn) : tls12_do_accept(conn);
	case TLS_protocol_tls13:
		return conn->is_client ? tls13_do_connect(conn) : tls13_do_accept(conn);
	default:
		/* unknown protocol: report it below */
		break;
	}
	error_print();
	return -1;
}
2303 
/*
 * Read the verification result stored in the connection.
 *
 * conn    connection to query
 * result  receives conn->verify_result
 *
 * Returns 1 on success, -1 when conn or result is NULL.
 */
int tls_get_verify_result(TLS_CONNECT *conn, int *result)
{
	if (!conn || !result) {
		error_print();
		return -1;
	}
	*result = conn->verify_result;
	return 1;
}
2309