1 /*
2 * This file is part of the openHiTLS project.
3 *
4 * openHiTLS is licensed under the Mulan PSL v2.
5 * You can use this software according to the terms and conditions of the Mulan PSL v2.
6 * You may obtain a copy of Mulan PSL v2 at:
7 *
8 * http://license.coscl.org.cn/MulanPSL2
9 *
10 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13 * See the Mulan PSL v2 for more details.
14 */
15
#include <stdint.h>
#include <stdlib.h>
#include <malloc.h>
#include <stdatomic.h>
#include "securec.h"
#include "hitls_crypt_type.h"
#include "hitls_session.h"
#include "logger.h"
#include "bsl_sal.h"
#include "hitls_error.h"
#include "hitls_sni.h"
#include "sni.h"
#include "hitls_alpn.h"
#include "hitls_type.h"
#include "common_func.h"
30
/* Generic status codes; ERROR is what ASSERT_RETURN returns on failure.
 * NOTE(review): SUCCESS/ERROR are very generic macro names and may clash with other headers. */
#define SUCCESS 0
#define ERROR (-1)
/* NOTE(review): MAX_CERT_PATH_LENGTH and SINGLE_CERT_LEN are not referenced in the visible code. */
#define MAX_CERT_PATH_LENGTH (128)
#define SINGLE_CERT_LEN (120)

/* Sizes, in bytes, of the fixed session-ticket key material defined below. */
#define KEY_NAME_SIZE 16
#define IV_SIZE 16
#define KEY_SIZE 32
/* NOTE(review): not referenced in the visible code. */
#define RENEGOTIATE_FAIL 1
40
/* Fixed ticket key name: copied into new tickets and required on incoming tickets by the
 * ExampleTicketKey*Cb callbacks. */
static uint8_t g_keyName[KEY_NAME_SIZE] = {
    0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A
};

/* Fixed key that SetCipherInfo installs as both the cipher key and the HMAC key.
 * Hard-coded, so only suitable for tests/examples. */
static uint8_t g_key[KEY_SIZE] = {
    0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A,
    0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A
};

/* Fixed IV installed by SetCipherInfo. */
static uint8_t g_iv[IV_SIZE] = {
    0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A, 0x1A
};
53
/* Name -> callback entry used by the GetTicketKeyCb / GetExtensionCb lookup tables.
 * cb holds a function pointer converted to void * (implementation-defined in ISO C, permitted by POSIX). */
typedef struct {
    char *name;
    void *cb;
} ExampleCb;

/* Name -> data-provider entry used by GetExampleData; data() is invoked to produce the value returned. */
typedef struct {
    char *name;
    void *(*data)(void);
} ExampleData;
63
/* If condition is false, log `log` through LOG_ERROR and return ERROR from the enclosing function.
 * Only usable inside functions whose return type can hold ERROR (-1). */
#define ASSERT_RETURN(condition, log) \
    do { \
        if (!(condition)) { \
            LOG_ERROR(log); \
            return ERROR; \
        } \
    } while (0)
71
/* PSK identity sent by ExampleClientCb and the only identity accepted by ExampleServerCb. */
static char g_localIdentity[PSK_MAX_LEN] = "Client_identity";
/* PSK as a hex string; decoded to raw bytes by the PSK callbacks and replaced via ExampleSetPsk. */
static char g_localPsk[PSK_MAX_LEN] = "1A1A1A1A1A";
74
ExampleSetPsk(char * psk)75 int32_t ExampleSetPsk(char *psk)
76 {
77 if (psk == NULL) {
78 LOG_DEBUG("input error.");
79 return -1;
80 }
81 (void)memset_s(g_localPsk, PSK_MAX_LEN, 0, PSK_MAX_LEN);
82 if (strcpy_s(g_localPsk, PSK_MAX_LEN, psk) != EOK) {
83 LOG_DEBUG("ExampleSetPsk failed.");
84 return -1;
85 }
86 return 0;
87 }
88
/* Decode one ASCII hex digit; returns its value 0-15, or -1 if c is not a hex digit. */
static int32_t ExampleHexCharToNibble(uint8_t c)
{
    if (c >= '0' && c <= '9') {
        return (int32_t)(c - '0');
    }
    if (c >= 'a' && c <= 'f') {
        return (int32_t)(c - 'a' + 10);
    }
    if (c >= 'A' && c <= 'F') {
        return (int32_t)(c - 'A' + 10);
    }
    return -1;
}

/*
 * Decode a hex string such as "1A2B" into raw bytes {0x1A, 0x2B}.
 *
 * input   - hex digits, terminated by '\0' or by the end of the buffer (inLen bytes), whichever is first.
 * inLen   - size of the input buffer in bytes.
 * out     - destination buffer of outLen bytes.
 * usedLen - receives the number of bytes written to out (only set on success).
 *
 * Returns 0 on success; -1 on NULL arguments, an odd number of digits, a non-hex character,
 * or when the decoded data would not fit in outLen bytes. out may be partially written on failure.
 */
int32_t ExampleHexStr2BufHelper(const uint8_t *input, uint32_t inLen, uint8_t *out, uint32_t outLen, uint32_t *usedLen)
{
    if ((input == NULL) || (out == NULL) || (usedLen == NULL)) {
        return -1;
    }

    uint32_t inPos = 0;
    uint32_t outPos = 0;
    while (inPos < inLen && input[inPos] != '\0') {
        /* Every byte needs two digits; a dangling high nibble is an error. */
        if (inPos + 1 >= inLen || input[inPos + 1] == '\0') {
            return -1;
        }
        /* Never write past the caller's output buffer. */
        if (outPos >= outLen) {
            return -1;
        }

        int32_t high = ExampleHexCharToNibble(input[inPos]);
        int32_t low = ExampleHexCharToNibble(input[inPos + 1]);
        if (high < 0 || low < 0) {
            return -1;
        }

        /* First digit is the upper four bits, second digit the lower four bits. */
        out[outPos++] = (uint8_t)(((uint32_t)high << 4) | (uint32_t)low);
        inPos += 2;
    }

    *usedLen = outPos;
    return 0;
}
123
/*
 * Client-side PSK callback: supplies the fixed identity g_localIdentity (with its NUL terminator)
 * and the raw bytes decoded from the hex string g_localPsk.
 *
 * Returns the PSK length in bytes, or 0 if decoding or either copy fails.
 */
uint32_t ExampleClientCb(HITLS_Ctx *ctx, const uint8_t *hint, uint8_t *identity, uint32_t maxIdentityLen, uint8_t *psk,
    uint32_t maxPskLen)
{
    (void)hint;
    (void)ctx;
    uint8_t decoded[PSK_MAX_LEN] = {0};
    uint32_t decodedLen = 0u;

    if (ExampleHexStr2BufHelper((uint8_t *)g_localPsk, sizeof(g_localPsk), decoded, PSK_MAX_LEN, &decodedLen) != 0) {
        return 0;
    }

    /* The identity is copied together with its terminating NUL. */
    size_t identityBytes = strlen(g_localIdentity) + 1;
    if (memcpy_s(identity, maxIdentityLen, g_localIdentity, identityBytes) != EOK ||
        memcpy_s(psk, maxPskLen, decoded, decodedLen) != EOK) {
        return 0;
    }
    return decodedLen;
}
147
/*
 * Server-side PSK callback: accepts only the identity stored in g_localIdentity and returns
 * the raw bytes decoded from the hex string g_localPsk.
 *
 * Returns the PSK length in bytes, or 0 for an unknown identity or any decode/copy failure.
 */
uint32_t ExampleServerCb(HITLS_Ctx *ctx, const uint8_t *identity, uint8_t *psk, uint32_t maxPskLen)
{
    (void)ctx;

    int knownIdentity = (identity != NULL) && (strcmp((const char *)identity, g_localIdentity) == 0);
    if (!knownIdentity) {
        return 0;
    }

    uint8_t decoded[PSK_MAX_LEN] = {0};
    uint32_t decodedLen = 0u;
    int32_t rc = ExampleHexStr2BufHelper((uint8_t *)g_localPsk, sizeof(g_localPsk), decoded, PSK_MAX_LEN,
        &decodedLen);
    if (rc != 0 || memcpy_s(psk, maxPskLen, decoded, decodedLen) != EOK) {
        return 0;
    }
    return decodedLen;
}
171
SetCipherInfo(void * cipher)172 static void SetCipherInfo(void *cipher)
173 {
174 HITLS_CipherParameters *cipherPara = cipher;
175 cipherPara->type = HITLS_CBC_CIPHER;
176 cipherPara->algo = HITLS_CIPHER_AES_256_CBC;
177 cipherPara->key = g_key;
178 cipherPara->keyLen = sizeof(g_key);
179 cipherPara->hmacKey = g_key;
180 cipherPara->hmacKeyLen = sizeof(g_key);
181 cipherPara->iv = g_iv;
182 cipherPara->ivLen = sizeof(g_iv);
183 return;
184 }
185
ExampleTicketKeySuccessCb(uint8_t * keyName,uint32_t keyNameSize,void * cipher,uint8_t isEncrypt)186 int32_t ExampleTicketKeySuccessCb(uint8_t *keyName, uint32_t keyNameSize, void *cipher, uint8_t isEncrypt)
187 {
188 if (isEncrypt) {
189 if (memcpy_s(keyName, keyNameSize, g_keyName, KEY_NAME_SIZE) != EOK) {
190 return HITLS_TICKET_KEY_RET_FAIL;
191 }
192 SetCipherInfo(cipher);
193 return HITLS_TICKET_KEY_RET_SUCCESS;
194 }
195
196 if (memcmp(keyName, g_keyName, KEY_NAME_SIZE) != 0) {
197 return HITLS_TICKET_KEY_RET_FAIL;
198 }
199 SetCipherInfo(cipher);
200 return HITLS_TICKET_KEY_RET_SUCCESS;
201 }
202
ExampleTicketKeyRenewCb(uint8_t * keyName,uint32_t keyNameSize,void * cipher,uint8_t isEncrypt)203 int32_t ExampleTicketKeyRenewCb(uint8_t *keyName, uint32_t keyNameSize, void *cipher, uint8_t isEncrypt)
204 {
205 if (isEncrypt) {
206 if (memcpy_s(keyName, keyNameSize, g_keyName, KEY_NAME_SIZE) != EOK) {
207 return HITLS_TICKET_KEY_RET_FAIL;
208 }
209 SetCipherInfo(cipher);
210 return HITLS_TICKET_KEY_RET_SUCCESS_RENEW;
211 }
212
213 if (memcmp(keyName, g_keyName, KEY_NAME_SIZE) != 0) {
214 return HITLS_TICKET_KEY_RET_FAIL;
215 }
216 SetCipherInfo(cipher);
217 return HITLS_TICKET_KEY_RET_SUCCESS_RENEW;
218 }
219
/*
 * Ticket key callback that lets tickets be issued but never accepted.
 * Encrypt: writes g_keyName into keyName, fills cipher and returns HITLS_TICKET_KEY_RET_SUCCESS_RENEW;
 *          returns HITLS_TICKET_KEY_RET_FAIL if keyName (keyNameSize bytes) cannot hold the key name.
 * Decrypt: always returns HITLS_TICKET_KEY_RET_NEED_ALERT.
 */
int32_t ExampleTicketKeyAlertCb(uint8_t *keyName, uint32_t keyNameSize, HITLS_CipherParameters *cipher,
    uint8_t isEncrypt)
{
    if (!isEncrypt) {
        return HITLS_TICKET_KEY_RET_NEED_ALERT;
    }

    /* Don't hand out cipher parameters for a ticket whose key name could not be written. */
    if (memcpy_s(keyName, keyNameSize, g_keyName, KEY_NAME_SIZE) != EOK) {
        return HITLS_TICKET_KEY_RET_FAIL;
    }
    SetCipherInfo(cipher);
    return HITLS_TICKET_KEY_RET_SUCCESS_RENEW;
}
231
/*
 * Ticket key callback whose decrypt path always fails.
 * Encrypt: writes g_keyName into keyName, fills cipher and returns HITLS_TICKET_KEY_RET_SUCCESS_RENEW;
 *          returns HITLS_TICKET_KEY_RET_FAIL if keyName (keyNameSize bytes) cannot hold the key name.
 * Decrypt: fills cipher anyway, but returns HITLS_TICKET_KEY_RET_FAIL.
 */
int32_t ExampleTicketKeyFailCb(uint8_t *keyName, uint32_t keyNameSize, HITLS_CipherParameters *cipher,
    uint8_t isEncrypt)
{
    if (isEncrypt) {
        /* Don't hand out cipher parameters for a ticket whose key name could not be written. */
        if (memcpy_s(keyName, keyNameSize, g_keyName, KEY_NAME_SIZE) != EOK) {
            return HITLS_TICKET_KEY_RET_FAIL;
        }
        SetCipherInfo(cipher);
        return HITLS_TICKET_KEY_RET_SUCCESS_RENEW;
    }

    /* Cipher parameters are populated even though the lookup is reported as failed. */
    SetCipherInfo(cipher);
    return HITLS_TICKET_KEY_RET_FAIL;
}
244
/* SNI callback that accepts any server name: stores HITLS_ACCEPT_SNI_ERR_OK in *alert and returns it. */
int32_t ExampleServerNameCb(HITLS_Ctx *ctx, int *alert, void *arg)
{
    (void)arg;
    (void)ctx;
    const int32_t result = HITLS_ACCEPT_SNI_ERR_OK;
    *alert = result;
    return result;
}
252
/* SNI callback that never acknowledges the server name; alert is left untouched. */
int32_t ExampleServerNameCbNOACK(HITLS_Ctx *ctx, int *alert, void *arg)
{
    (void)arg;
    (void)alert;
    (void)ctx;
    const int32_t verdict = HITLS_ACCEPT_SNI_ERR_NOACK;
    return verdict;
}
260
/* SNI callback that always rejects with a fatal alert; alert is left untouched. */
int32_t ExampleServerNameCbALERT(HITLS_Ctx *ctx, int *alert, void *arg)
{
    (void)arg;
    (void)alert;
    (void)ctx;
    const int32_t verdict = HITLS_ACCEPT_SNI_ERR_ALERT_FATAL;
    return verdict;
}
268
269 SNI_Arg *g_sniArg;
ExampleServerNameArg(void)270 void *ExampleServerNameArg(void)
271 {
272 return g_sniArg;
273 }
274
/* Protocol name ExampleAlpnCb selects for the "ftp"/"mml"/"www" cases.
 * Stored as a writable array rather than a char * to a string literal, so handing it out
 * through a non-const char * can never lead to a write into read-only literal storage. */
static char g_alpnhttp[] = "http";
276
ExampleAlpnParseProtocolList1(uint8_t * out,uint8_t * outLen,uint8_t * in,uint8_t inLen)277 int32_t ExampleAlpnParseProtocolList1(uint8_t *out, uint8_t *outLen, uint8_t *in, uint8_t inLen)
278 {
279 if (out == NULL || outLen == NULL || in == NULL) {
280 return HITLS_NULL_INPUT;
281 }
282
283 if (inLen == 0) {
284 return HITLS_CONFIG_INVALID_LENGTH;
285 }
286
287 uint8_t i = 0u;
288 uint8_t commaNum = 0u;
289 uint8_t startPos = 0u;
290
291 for (i = 0u; i <= inLen; ++i) {
292 if (i == inLen || in[i] == ',') {
293 if (i == startPos) {
294 ++startPos;
295 ++commaNum;
296 continue;
297 }
298 out[startPos - commaNum] = (uint8_t)(i - startPos);
299 startPos = i + 1;
300 } else {
301 out[i + 1 - commaNum] = in[i];
302 }
303 }
304
305 *outLen = inLen + 1 - commaNum;
306
307 return HITLS_SUCCESS;
308 }
309
/*
 * ALPN selection callback driven by the first entry of the client's wire-format list
 * ([len][name]...). Only the name bytes after the first length byte are compared; the length byte
 * itself is not checked, so e.g. a first entry "https" also matches "http".
 *   "http"        -> selects the client's own "http" bytes, HITLS_ALPN_ERR_OK
 *   "ftp" / "www" -> selects g_alpnhttp ("http"),        HITLS_ALPN_ERR_OK
 *   "mml"         -> selects g_alpnhttp ("http"),        HITLS_ALPN_ERR_ALERT_FATAL
 *   anything else -> HITLS_ALPN_ERR_NOACK, outputs untouched
 */
int32_t ExampleAlpnCb(HITLS_Ctx *ctx, char **selectedProto, uint8_t *selectedProtoSize, char *clientAlpnList,
    uint32_t clientAlpnListSize, void *userData)
{
    (void)userData;
    (void)ctx;

    /* Every recognized entry needs a length byte plus at least three name bytes. */
    if (clientAlpnListSize < 4) {
        return HITLS_ALPN_ERR_NOACK;
    }

    char *firstName = clientAlpnList + 1;
    if (clientAlpnListSize >= 5 && memcmp(firstName, "http", 4) == 0) {
        *selectedProto = firstName;
        *selectedProtoSize = 4;
        return HITLS_ALPN_ERR_OK;
    }

    int32_t verdict;
    if (memcmp(firstName, "ftp", 3) == 0 || memcmp(firstName, "www", 3) == 0) {
        verdict = HITLS_ALPN_ERR_OK;
    } else if (memcmp(firstName, "mml", 3) == 0) {
        verdict = HITLS_ALPN_ERR_ALERT_FATAL;
    } else {
        return HITLS_ALPN_ERR_NOACK;
    }

    /* These cases answer with "http" regardless of what the client offered. */
    *selectedProto = g_alpnhttp;
    *selectedProtoSize = 4;
    return verdict;
}
335
/* ALPN callback that ignores all inputs, selects nothing and always returns HITLS_ALPN_ERR_ALERT_WARNING. */
int32_t AlpnCbWARN1(HITLS_Ctx *ctx, uint8_t **selectedProto, uint8_t *selectedProtoSize, uint8_t *clientAlpnList,
    uint32_t clientAlpnListSize, void *userData)
{
    (void)userData;
    (void)clientAlpnListSize;
    (void)clientAlpnList;
    (void)selectedProtoSize;
    (void)selectedProto;
    (void)ctx;

    const int32_t verdict = HITLS_ALPN_ERR_ALERT_WARNING;
    return verdict;
}
348
/* ALPN callback that ignores all inputs, selects nothing and always returns HITLS_ALPN_ERR_ALERT_FATAL. */
int32_t AlpnCbALERT1(HITLS_Ctx *ctx, uint8_t **selectedProto, uint8_t *selectedProtoSize, uint8_t *clientAlpnList,
    uint32_t clientAlpnListSize, void *userData)
{
    (void)userData;
    (void)clientAlpnListSize;
    (void)clientAlpnList;
    (void)selectedProtoSize;
    (void)selectedProto;
    (void)ctx;

    const int32_t verdict = HITLS_ALPN_ERR_ALERT_FATAL;
    return verdict;
}
361
/*
 * Data provider registered as "ExampleAlpnData": returns the static string "audata".
 * The string is a literal, so callers must not write through the returned pointer.
 */
void *ExampleAlpnData(void)
{
    const char *data = "audata";
    return (void *)data;
}
367
GetTicketKeyCb(char * str)368 void *GetTicketKeyCb(char *str)
369 {
370 const ExampleCb cbList[] = {
371 {"ExampleTicketKeySuccessCb", ExampleTicketKeySuccessCb},
372 {"ExampleTicketKeyRenewCb", ExampleTicketKeyRenewCb},
373 {"ExampleTicketKeyAlertCb", ExampleTicketKeyAlertCb},
374 {"ExampleTicketKeyFailCb", ExampleTicketKeyFailCb},
375 };
376
377 int len = sizeof(cbList) / sizeof(cbList[0]);
378 for (int i = 0; i < len; i++) {
379 if (strcmp(str, cbList[i].name) == 0) {
380 return cbList[i].cb;
381 }
382 }
383 return NULL;
384 }
385
GetExtensionCb(const char * str)386 void *GetExtensionCb(const char *str)
387 {
388 const ExampleCb cbList[] = {
389 {"ExampleSNICb", ExampleServerNameCb},
390 {"ExampleAlpnCb", ExampleAlpnCb},
391 {"ExampleAlpnWarnCb", AlpnCbWARN1},
392 {"ExampleAlpAlertCb", AlpnCbALERT1},
393 {"ExampleSNICbnoack", ExampleServerNameCbNOACK},
394 {"ExampleSNICbAlert", ExampleServerNameCbALERT},
395 };
396
397 int len = sizeof(cbList) / sizeof(cbList[0]);
398 for (int i = 0; i < len; i++) {
399 if (strcmp(str, cbList[i].name) == 0) {
400 return cbList[i].cb;
401 }
402 }
403 return NULL;
404 }
405
GetExampleData(const char * str)406 void *GetExampleData(const char *str)
407 {
408 const ExampleData cbList[] = {
409 {"ExampleSNIArg", ExampleServerNameArg},
410 {"ExampleAlpnData", ExampleAlpnData},
411 };
412
413 int len = sizeof(cbList) / sizeof(cbList[0]);
414 for (int i = 0; i < len; i++) {
415 if (strcmp(str, cbList[i].name) == 0) {
416 return cbList[i].data();
417 }
418 }
419 return NULL;
420 }
421