• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020 HiSilicon (Shanghai) Technologies CO., LIMITED.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  * Description: Mbedtls Adaptation Layer of MQTT
15  * Author: LiteOS Team
16  * Create: 2020-08-10
17  */
18 #if defined(MBEDTLS)
19 
20 #include "SocketBuffer.h"
21 #include "MQTTClient.h"
22 #include "MQTTProtocolOut.h"
23 #include "SSLSocket.h"
24 #include "Log.h"
25 #include "StackTrace.h"
26 #include "Socket.h"
27 
28 #include "Heap.h"
29 
30 #include "securec.h"
31 
32 #if defined(IOT_CONNECT)
33 #include "atiny_mqtt_commu.h"
34 #include "soc_socket.h"
35 #endif
36 
37 #include <string.h>
38 #include <mbedtls/ctr_drbg.h>
39 #include <mbedtls/entropy.h>
40 #include <mbedtls/ssl.h>
41 #include <mbedtls/x509.h>
42 #include <mbedtls/net_sockets.h>
43 #include <mbedtls/platform.h>
44 #if !defined(IOT_LITEOS_ADAPT)
45 #include <ssl_misc.h>
46 #endif
47 
48 extern Sockets mod_s;
49 static ssl_mutex_type sslCoreMutex;
50 
51 static int SSL_create_mutex(ssl_mutex_type* mutex);
52 static int SSL_lock_mutex(ssl_mutex_type* mutex);
53 static int SSL_unlock_mutex(ssl_mutex_type* mutex);
54 static void SSL_destroy_mutex(ssl_mutex_type* mutex);
55 static int SSLSocket_createContext(networkHandles* net, MQTTClient_SSLOptions* opts);
56 static void SSLSocket_destroyContext(networkHandles* net);
57 static void SSLSocket_addPendingRead(int sock);
58 
59 #if defined(WIN32) || defined(WIN64)
60 #define iov_len len
61 #define iov_base buf
62 #endif
63 
64 
65 #if !defined(ARRAY_SIZE)
66 /**
67  * Macro to calculate the number of entries in an array
68  */
69 #define ARRAY_SIZE(a) (sizeof(a) / sizeof(a[0]))
70 #endif
71 
72 #if defined(SEND_MAX_LEN)
73 #define LOG_LEN 150
74 #define SSL_HEADER_LEN 64
75 #endif
76 
/**
 * Create the platform mutex used by this file.
 * @param mutex where the created mutex / handle is stored
 * @return 0 on success; -1 if mutex is NULL or the Windows/CMSIS handle could not be
 *         created; otherwise the pthread_mutex_init() error number
 */
static int SSL_create_mutex(ssl_mutex_type* mutex)
{
	int rc = 0;
	if (mutex == NULL)
		return -1;

	FUNC_ENTRY;
#if defined(WIN32) || defined(WIN64)
	*mutex = CreateMutex(NULL, 0, NULL);
	/* CreateMutex signals failure with a NULL handle; report it like the CMSIS branch does */
	rc = (*mutex == NULL) ? -1 : 0;
#elif defined(COMPAT_CMSIS)
	*mutex = osMutexNew(NULL);
	rc = (*mutex == NULL) ? -1 : 0;
#else
	rc = pthread_mutex_init(mutex, NULL);
#endif
	FUNC_EXIT_RC(rc);
	return rc;
}
95 
/**
 * Lock a mutex created by SSL_create_mutex(), waiting indefinitely.
 * @param mutex the mutex to acquire
 * @return 0 on success; -1 for a NULL argument or a failed Windows wait;
 *         otherwise the status/error code returned by the CMSIS or pthread lock call
 */
static int SSL_lock_mutex(ssl_mutex_type* mutex)
{
	int status;

	if (mutex == NULL)
		return -1;

	/* No FUNC_ENTRY/FUNC_EXIT: tracing takes a lock of its own and this path is hot. */
#if defined(WIN32) || defined(WIN64)
	status = (WaitForSingleObject(*mutex, INFINITE) == WAIT_FAILED) ? -1 : 0;
#elif defined(COMPAT_CMSIS)
	status = osMutexAcquire(*mutex, osWaitForever);
	if (status == osOK)
		status = 0;
#else
	status = pthread_mutex_lock(mutex);
#endif
	return status;
}
114 
/**
 * Unlock a mutex previously locked with SSL_lock_mutex().
 * @param mutex the mutex to release
 * @return 0 on success; -1 for a NULL argument or a failed Windows release;
 *         otherwise the status/error code returned by the CMSIS or pthread unlock call
 */
static int SSL_unlock_mutex(ssl_mutex_type* mutex)
{
	int status;

	if (mutex == NULL)
		return -1;

	/* No FUNC_ENTRY/FUNC_EXIT: tracing takes a lock of its own and this path is hot. */
#if defined(WIN32) || defined(WIN64)
	status = (ReleaseMutex(*mutex) == 0) ? -1 : 0;
#elif defined (COMPAT_CMSIS)
	status = osMutexRelease(*mutex);
	if (status == osOK)
		status = 0;
#else
	status = pthread_mutex_unlock(mutex);
#endif
	return status;
}
133 
/**
 * Destroy a mutex created by SSL_create_mutex().
 * On Windows/CMSIS the handle is cleared after a successful close/delete, so a repeated
 * call (e.g. SSLSocket_terminate() invoked twice) does not release the same handle again.
 * @param mutex the mutex to destroy; NULL is ignored
 */
static void SSL_destroy_mutex(ssl_mutex_type* mutex)
{
	int rc = 0;
	if (mutex == NULL)
		return;

	FUNC_ENTRY;
#if defined(WIN32) || defined(WIN64)
	if (*mutex != NULL)
	{
		/* CloseHandle returns nonzero on success; normalise to 0/-1 for the trace */
		rc = (CloseHandle(*mutex) != 0) ? 0 : -1;
		if (rc == 0)
			*mutex = NULL;
	}
#elif defined(COMPAT_CMSIS)
	if (*mutex != NULL)
	{
		rc = (osMutexDelete(*mutex) == osOK) ? 0 : -1;
		if (rc == 0)
			*mutex = NULL;
	}
#else
	rc = pthread_mutex_destroy(mutex);
#endif
	FUNC_EXIT_RC(rc);
}
152 
#ifdef MBEDTLS_USE_CRT
/*
 * Custom mbedtls verify callback, installed when the ssl option "verify" is not set.
 * Logs any verification flags and clears them only when the sole problem is a
 * certificate CN mismatch; every other failure is passed through unchanged.
 * Always returns 0 so mbedtls keeps evaluating the (possibly cleared) flags.
 */
static int SSL_verify_discard(void* data, mbedtls_x509_crt* crt, int depth, uint32_t* flags)
{
	char buf[512] = {0};
	(void)data;
	(void)crt;
	(void)depth;

	if (flags == NULL || *flags == 0)
		return 0;

	/*
	 * mbedtls_x509_crt_verify_info() prints the prefix with "%s", so it must not be NULL.
	 * buf is zero-initialised, so it stays a valid string even if formatting fails.
	 */
	(void)mbedtls_x509_crt_verify_info(buf, sizeof(buf), "", *flags);
	Log(TRACE_PROTOCOL, 1,  "Warnning! flags:%u %s", *flags, buf);

	/* Discard CN mismatch when ssl options verify unset */
	if (*flags == MBEDTLS_X509_BADCERT_CN_MISMATCH)
		*flags = 0;

	return 0;
}
#endif
176 
/**
 * Compatibility hook shared with the OpenSSL backend.
 * This mbedtls adaptation has nothing to do here; the flag is accepted and ignored.
 * @param bool_value unused
 */
void SSLSocket_handleOpensslInit(int bool_value)
{
	(void)bool_value; /* intentionally unused */
}
182 
183 
SSLSocket_initialize(void)184 int SSLSocket_initialize(void)
185 {
186 	int rc = 0;
187 	FUNC_ENTRY;
188 
189 	SSL_create_mutex(&sslCoreMutex);
190 
191 	FUNC_EXIT_RC(rc);
192 	return rc;
193 }
194 
SSLSocket_terminate(void)195 void SSLSocket_terminate(void)
196 {
197 	FUNC_ENTRY;
198 
199 	SSL_destroy_mutex(&sslCoreMutex);
200 
201 	FUNC_EXIT;
202 }
203 
/**
 * Load the client certificate and private key into net->ctx and register them as
 * this client's own credentials. Does nothing (returns 0) unless both a certificate
 * and a key are configured.
 * Regular builds read them from files (keyStore / privateKey); IOT_CONNECT or
 * IOT_LITEOS_ADAPT builds parse in-memory buffers (los_keyStore / los_privateKey).
 * @return 0 on success, otherwise the failing mbedtls return code
 */
static int SSL_loadClientCrt(networkHandles* net, const MQTTClient_SSLOptions* opts)
{
	int ret;

#if !defined (IOT_CONNECT) && !defined(IOT_LITEOS_ADAPT)
	if (opts->keyStore == NULL || opts->privateKey == NULL)
		return 0;
	ret = mbedtls_x509_crt_parse_file(&net->ctx->clicert, opts->keyStore);
#else
	if (opts->los_keyStore == NULL || opts->los_privateKey == NULL)
		return 0;
	ret = mbedtls_x509_crt_parse(&net->ctx->clicert, opts->los_keyStore->body, opts->los_keyStore->size);
#endif
	if (ret != 0)
	{
		Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_x509_crt_parse_file");
		return ret;
	}

#if !defined (IOT_CONNECT) && !defined(IOT_LITEOS_ADAPT)
	ret = mbedtls_pk_parse_keyfile(&net->ctx->pkey, opts->privateKey, opts->privateKeyPassword);
#else
	{
		/* A missing password is passed to mbedtls as (NULL, 0). */
		const unsigned char* password = NULL;
		size_t password_len = 0;

		if (opts->privateKeyPassword != NULL)
		{
			password = (const unsigned char *)opts->privateKeyPassword;
			password_len = strlen(opts->privateKeyPassword);
		}
		ret = mbedtls_pk_parse_key(&net->ctx->pkey, opts->los_privateKey->body, opts->los_privateKey->size,
			password, password_len, NULL, NULL);
	}
#endif
	if (ret != 0)
	{
		Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_pk_parse_keyfile");
		return ret;
	}

	ret = mbedtls_ssl_conf_own_cert(&net->ctx->conf, &net->ctx->clicert, &net->ctx->pkey);
	if (ret != 0)
	{
		Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_ssl_conf_own_cert");
		return ret;
	}
	return 0;
}
252 
/**
 * Configure credentials on net->ctx->conf.
 * With MBEDTLS_USE_PSK: the pre-shared key and identity.
 * With MBEDTLS_USE_CRT: the client certificate/key, the optional CA chain, the
 * authentication mode (from enableServerCertAuth) and, when opts->verify is 0, the
 * SSL_verify_discard callback.
 * @return 0 on success, otherwise the failing mbedtls/helper return code
 */
static int SSL_loadKey(networkHandles* net, const MQTTClient_SSLOptions* opts)
{
	int rc;

#ifdef MBEDTLS_USE_PSK
	if ((rc = mbedtls_ssl_conf_psk(&net->ctx->conf, opts->psk, opts->psk_len, opts->psk_id, opts->psk_id_len)) != 0)
	{
		Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_ssl_conf_psk");
		return rc;
	}
#endif /* MBEDTLS_USE_PSK */

#ifdef MBEDTLS_USE_CRT
	if ((rc = SSL_loadClientCrt(net, opts)) != 0)
		return rc;

	/* Optional trust anchors: a file path in regular builds, an in-memory buffer otherwise. */
#if !defined (IOT_CONNECT) && !defined(IOT_LITEOS_ADAPT)
	if (opts->trustStore != NULL)
	{
		rc = mbedtls_x509_crt_parse_file(&net->ctx->cacert, opts->trustStore);
#else
	if (opts->los_trustStore != NULL)
	{
		rc = mbedtls_x509_crt_parse(&net->ctx->cacert, opts->los_trustStore->body, opts->los_trustStore->size);
#endif
		if (rc != 0)
		{
			Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_x509_crt_parse_file");
			return rc;
		}
		mbedtls_ssl_conf_ca_chain(&net->ctx->conf, &net->ctx->cacert, NULL);
	}

	mbedtls_ssl_conf_authmode(&net->ctx->conf,
		(opts->enableServerCertAuth != 0) ? MBEDTLS_SSL_VERIFY_REQUIRED : MBEDTLS_SSL_VERIFY_NONE);

	/* Without "verify", a lone CN mismatch is tolerated by SSL_verify_discard. */
	if (opts->verify == 0)
		mbedtls_ssl_conf_verify(&net->ctx->conf, SSL_verify_discard, NULL);
#endif /* MBEDTLS_USE_CRT */
	return 0;
}
303 
/**
 * Pin the TLS protocol version on net->ctx->conf when the options request one.
 * opts->sslVersion is only consulted for struct_version >= 1. TLS 1.0/1.1 are not
 * offered in IOT_LITEOS_ADAPT builds. MQTT_SSL_VERSION_DEFAULT and any unhandled
 * value leave the mbedtls defaults untouched.
 */
static void SSL_setVersion(networkHandles* net, const MQTTClient_SSLOptions* opts)
{
	int requested = (opts->struct_version >= 1) ? opts->sslVersion : MQTT_SSL_VERSION_DEFAULT;
	int pin = 1;
	int minor = 0;

	switch (requested)
	{
#if !defined(IOT_LITEOS_ADAPT)
	case MQTT_SSL_VERSION_TLS_1_0:
		minor = MBEDTLS_SSL_MINOR_VERSION_1;
		break;
	case MQTT_SSL_VERSION_TLS_1_1:
		minor = MBEDTLS_SSL_MINOR_VERSION_2;
		break;
#endif
	case MQTT_SSL_VERSION_TLS_1_2:
		minor = MBEDTLS_SSL_MINOR_VERSION_3;
		break;
	default:
		/* MQTT_SSL_VERSION_DEFAULT or a version not handled here */
		pin = 0;
		break;
	}

	if (pin)
	{
		/* min == max: exactly one protocol version is allowed */
		mbedtls_ssl_conf_min_version(&net->ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, minor);
		mbedtls_ssl_conf_max_version(&net->ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, minor);
	}
}
331 
/**
 * Allocate net->ctx and prepare a client TLS configuration:
 *  - initialise every mbedtls object in the context;
 *  - seed the CTR-DRBG from the entropy source;
 *  - load the default client/stream configuration;
 *  - apply the requested protocol version.
 * On any mbedtls failure the partially built context is released via
 * SSLSocket_destroyContext(), which also resets net->ctx to NULL.
 * @return 0 on success, PAHO_MEMORY_ERROR if allocation fails, else the mbedtls error
 */
static int SSL_tlsInit(networkHandles* net, const MQTTClient_SSLOptions* opts)
{
	/* RNG related string */
	static const char personalization[] = "paho_mbedtls_entropy";
	SSL_CTX* ctx;
	int rc;

	ctx = (SSL_CTX*)malloc(sizeof(SSL_CTX));
	net->ctx = ctx;
	if (ctx == NULL)
	{
		Log(TRACE_PROTOCOL, -1, "allocate context failed.");
		return PAHO_MEMORY_ERROR;
	}

	mbedtls_ssl_config_init(&ctx->conf);
	mbedtls_entropy_init(&ctx->entropy);
	mbedtls_ctr_drbg_init(&ctx->ctr_drbg);
#ifdef MBEDTLS_USE_CRT
	mbedtls_x509_crt_init(&ctx->cacert);
	mbedtls_x509_crt_init(&ctx->clicert);
	mbedtls_pk_init(&ctx->pkey);
#endif /* MBEDTLS_USE_CRT */

	rc = mbedtls_ctr_drbg_seed(&ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy,
		(const unsigned char*)personalization, sizeof(personalization));
	if (rc != 0)
	{
		Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_ctr_drbg_seed returned %d", rc);
		goto fail;
	}

	mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg);

	rc = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_CLIENT,
		MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT);
	if (rc != 0)
	{
		Log(TRACE_PROTOCOL, -1, "failed ! mbedtls_ssl_config_defaults returned %d", rc);
		goto fail;
	}

	SSL_setVersion(net, opts);
	return 0;

fail:
	SSLSocket_destroyContext(net);
	return rc;
}
384 
/**
 * Build net->ctx if it does not exist yet, then load the configured credentials
 * into it. A credential failure destroys the context so no half-configured
 * context is left behind.
 * @return 0 on success, otherwise the error from SSL_tlsInit() or SSL_loadKey()
 */
int SSLSocket_createContext(networkHandles* net, MQTTClient_SSLOptions* opts)
{
	int rc = 0;

	FUNC_ENTRY;
	if (net->ctx != NULL || (rc = SSL_tlsInit(net, opts)) == 0)
	{
		if ((rc = SSL_loadKey(net, opts)) != 0)
			SSLSocket_destroyContext(net);
	}
	FUNC_EXIT_RC(rc);
	return rc;
}
406 
407 #define SSL_MAX_COUNT 65535
408 /**
409  * SSLSocket_setSocketForSSL
410  * @return 0 is failure, 1 is success
411  */
SSLSocket_setSocketForSSL(networkHandles * net,MQTTClient_SSLOptions * opts,const char * hostname,size_t hostname_len)412 int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
413 	const char* hostname, size_t hostname_len)
414 {
415 	int rc = 0;
416 	int ret_state = 0;
417 
418 	if (net == NULL || opts == NULL || hostname == NULL)
419 		goto exit;
420 	if (hostname_len > SSL_MAX_COUNT)
421 		goto exit;
422 
423 	FUNC_ENTRY;
424 	if (net->ctx != NULL || (ret_state = SSLSocket_createContext(net, opts)) == 0)
425 	{
426 		if (net->ssl == NULL)
427 		{
428 			net->ssl = malloc(sizeof(mbedtls_ssl_context)); // free in SSLSocket_close
429 			if (net->ssl == NULL)
430 			{
431 				Log(TRACE_PROTOCOL, -1, "allocate ssl context failed.\n");
432 				goto exit;
433 			}
434 			mbedtls_ssl_init(net->ssl);
435 		}
436 		if ((ret_state = mbedtls_ssl_setup(net->ssl, &net->ctx->conf)) != 0)
437 		{
438 			Log(TRACE_PROTOCOL, 1, "failed! mbedtls_ssl_setup returned %d \n", ret_state);
439 			goto exit;
440 		}
441 #ifdef MBEDTLS_USE_CRT
442 		char *hostname_plus_null;
443 		hostname_plus_null = malloc(hostname_len + 1u );
444 		if (hostname_plus_null == NULL)
445 		{
446 			Log(TRACE_PROTOCOL, -1, "allocate hostname_plus_null failed.\n");
447 			goto exit;
448 		}
449 		MQTTStrncpy(hostname_plus_null, hostname, hostname_len + 1u);
450 		if ((ret_state = mbedtls_ssl_set_hostname(net->ssl, hostname_plus_null)) != 0)
451 		{
452 			Log(TRACE_PROTOCOL, 1, "failed! mbedtls_ssl_set_hostname returned %d \n", ret_state);
453 			free(hostname_plus_null);
454 			goto exit;
455 		}
456 		free(hostname_plus_null);
457 #endif
458 		mbedtls_ssl_set_bio(net->ssl, &net->socket, mbedtls_net_send, mbedtls_net_recv, NULL);
459 	}
460 
461 	if (ret_state == 0)
462 		rc = 1;
463 
464 exit:
465 	FUNC_EXIT_RC(rc);
466 	return rc;
467 }
468 
469 /*
470  * Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure
471  */
SSLSocket_connect(SSL * ssl,int sock,const char * hostname,int verify,int (* cb)(const char * str,size_t len,void * u),void * u)472 int SSLSocket_connect(SSL* ssl, int sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u)
473 {
474 	int rc = 1;
475 	int ret_state = 0;
476 	(void)sock;
477 	(void)hostname;
478 	(void)verify;
479 	(void)cb;
480 	(void)u;
481 	FUNC_ENTRY;
482 
483 	ret_state = mbedtls_ssl_handshake(ssl);
484 
485 	if (ret_state == MBEDTLS_ERR_SSL_WANT_READ ||
486 			ret_state == MBEDTLS_ERR_SSL_WANT_WRITE ||
487 			ret_state == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
488 			ret_state == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS)
489 	{
490 		rc = TCPSOCKET_INTERRUPTED;
491 	}
492 	else if (ret_state == 0)
493 	{
494 		/* handshake complete check server certificate */
495 		Log(TRACE_MIN, -1, "ssl handshake complete.\n");
496 #if !defined(IOT_LITEOS_ADAPT)
497 		rc = 1;
498 #else
499         rc = 0;
500 #endif
501 	} else {
502 		rc = SSL_FATAL;
503 		Log(TRACE_PROTOCOL, -1, "failed! mbedtls_ssl_handshake returned -0x%x\n", ret_state);
504 	}
505 	FUNC_EXIT_RC(rc);
506 	return rc;
507 }
508 
509 
510 /**
511  *  Reads one byte from a socket
512  *  @param socket the socket to read from
513  *  @param c the character read, returned
514  *  @return completion code
515  */
SSLSocket_getch(SSL * ssl,int socket,char * c)516 int SSLSocket_getch(SSL* ssl, int socket, char* c)
517 {
518 	int rc = SOCKET_ERROR;
519 	if (ssl == NULL || c == NULL)
520 		goto exit;
521 
522 	FUNC_ENTRY;
523 	if ((rc = SocketBuffer_getQueuedChar(socket, c)) != SOCKETBUFFER_INTERRUPTED)
524 		goto exit;
525 
526 	if ((rc = mbedtls_ssl_read(ssl, (unsigned char *)c, (size_t)1)) < 0)
527 	{
528 		Log(TRACE_MIN, -1, "[%s,%d]rc = %d", __func__, __LINE__, rc);
529 		if (rc == MBEDTLS_ERR_SSL_WANT_READ || rc == MBEDTLS_ERR_SSL_WANT_WRITE)
530 		{
531 			rc = TCPSOCKET_INTERRUPTED;
532 			SocketBuffer_interrupted(socket, 0);
533 		}
534 		else
535 		{
536 			rc = SOCKET_ERROR;
537 		}
538 	}
539 	else if (rc == 0)
540 	{
541 		rc = SOCKET_ERROR;  /* The return value from recv is 0 when the peer has performed an orderly shutdown. */
542 	}
543 	else if (rc == 1)
544 	{
545 		SocketBuffer_queueChar(socket, *c);
546 		rc = TCPSOCKET_COMPLETE;
547 	}
548 exit:
549 	FUNC_EXIT_RC(rc);
550 	return rc;
551 }
552 
553 
554 /**
555  *  Attempts to read a number of bytes from a socket, non-blocking. If a previous read did not
556  *  finish, then retrieve that data.
557  *  @param socket the socket to read from
558  *  @param bytes the number of bytes to read
559  *  @param actual_len the actual number of bytes read
560  *  @return completion code
561  */
SSLSocket_getdata(SSL * ssl,int socket,size_t bytes,size_t * actual_len,int * rc)562 char *SSLSocket_getdata(SSL* ssl, int socket, size_t bytes, size_t* actual_len, int* rc)
563 {
564 	char* buf = NULL;
565 	if (ssl == NULL || actual_len == NULL || rc == NULL)
566 		goto exit;
567 
568 	FUNC_ENTRY;
569 	if (bytes == 0)
570 	{
571 		buf = SocketBuffer_complete(socket);
572 		goto exit;
573 	}
574 
575 	buf = SocketBuffer_getQueuedData(socket, bytes, actual_len);
576 
577 	if ((*rc = mbedtls_ssl_read(ssl, (unsigned char *)(buf + (*actual_len)), (int)(bytes - (*actual_len)))) < 0)
578 	{
579 		Log(TRACE_MIN, -1, "[%s,%d]rc = %d", __func__, __LINE__, *rc);
580 		if (*rc != MBEDTLS_ERR_SSL_WANT_READ && *rc != MBEDTLS_ERR_SSL_WANT_WRITE)
581 		{
582 			buf = NULL;
583 			goto exit;
584 		}
585 	}
586 	else if (*rc == 0) /* rc 0 means the other end closed the socket */
587 	{
588 		buf = NULL;
589 		goto exit;
590 	}
591 	else
592 	{
593 		*actual_len += *rc;
594 	}
595 
596 	if (*actual_len == bytes)
597 	{
598 		SocketBuffer_complete(socket);
599 		/* if we read the whole packet, there might still be data waiting in the SSL buffer, which
600 		isn't picked up by select.  So here we should check for any data remaining in the SSL buffer, and
601 		if so, add this socket to a new "pending SSL reads" list.
602 		*/
603 		if (mbedtls_ssl_get_bytes_avail(ssl) > 0) /* return no of bytes pending */
604 			SSLSocket_addPendingRead(socket);
605 	}
606 	else /* we didn't read the whole packet */
607 	{
608 		SocketBuffer_interrupted(socket, *actual_len);
609 		Log(TRACE_MAX, -1, "SSL_read: %u bytes expected but %u bytes now received", bytes, *actual_len);
610 	}
611 exit:
612 	FUNC_EXIT;
613 	return buf;
614 }
615 
/**
 * Release everything held in net->ctx (TLS config, CTR-DRBG, entropy and, with
 * MBEDTLS_USE_CRT, the CA/client certificates and private key), then clear net->ctx.
 * Does nothing when net or net->ctx is NULL.
 */
void SSLSocket_destroyContext(networkHandles* net)
{
	SSL_CTX* ctx = NULL;

	FUNC_ENTRY;
	if (net != NULL)
		ctx = net->ctx;

	if (ctx != NULL)
	{
		mbedtls_ssl_config_free(&ctx->conf);
		mbedtls_ctr_drbg_free(&ctx->ctr_drbg);
		mbedtls_entropy_free(&ctx->entropy);
#ifdef MBEDTLS_USE_CRT
		mbedtls_x509_crt_free(&ctx->cacert);
		mbedtls_x509_crt_free(&ctx->clicert);
		mbedtls_pk_free(&ctx->pkey);
#endif
		net->ctx = NULL;
		free(ctx);
	}
	FUNC_EXIT;
}
634 
635 static List pending_reads = {NULL, NULL, NULL, 0, 0};
636 
SSLSocket_close(networkHandles * net)637 int SSLSocket_close(networkHandles* net)
638 {
639 	int rc = 1;
640 	if (net == NULL)
641 		return rc;
642 
643 	FUNC_ENTRY;
644 	/* clean up any pending reads for this socket */
645 	if (pending_reads.count > 0 && ListFindItem(&pending_reads, &net->socket, intcompare) != NULL)
646 		ListRemoveItem(&pending_reads, &net->socket, intcompare);
647 
648 	if (net->ssl != NULL)
649 	{
650 		rc = mbedtls_ssl_close_notify(net->ssl);
651 		mbedtls_ssl_free(net->ssl);
652 		free(net->ssl);
653 		net->ssl = NULL;
654 	}
655 	SSLSocket_destroyContext(net);
656 	FUNC_EXIT_RC(rc);
657 	return rc;
658 }
659 
SSLSocket_buflenCheck(const iobuf * iovec)660 static int SSLSocket_buflenCheck(const iobuf* iovec)
661 {
662 
663 	if (iovec->iov_len > SSL_MAX_COUNT)
664 	{
665 		return SOCKET_ERROR;
666 	}
667 
668 #if defined(SEND_MAX_LEN)
669 	if (iovec->iov_len + SSL_HEADER_LEN > SEND_MAX_LEN)
670 	{
671 		char array[LOG_LEN];
672 		int rc = sprintf_s(array, LOG_LEN, "[Error]:Please don't send a message longer than %d bytes."
673 				" Message length which contains header and payload is %u bytes.\r\n", SEND_MAX_LEN,
674 				iovec->iov_len + SSL_HEADER_LEN);
675 		if (rc != -1)
676 			app_at_send_at_rsp_string_lines_with_claim_and_log_restricted(array);
677 		return  EXT_SOCKET_RET_MESSAGE_TOO_LONG;
678 	}
679 #endif
680 	return 0;
681 }
682 
/**
 * Copy buf0 followed by every non-empty buffer in bufs into one newly allocated
 * buffer of iovec->iov_len bytes, stored in iovec->iov_base.
 * The remaining space is tracked as size_t; it used to be narrowed to int and then
 * passed as the memcpy_s bound.
 * @return 0 on success; PAHO_MEMORY_ERROR if allocation fails or the pieces do not fit
 *         (iovec->iov_base is then NULL)
 */
static int SSLSocket_putdatasub(iobuf* iovec, const char* buf0, size_t buf0len, PacketBuffers* bufs)
{
	int i;
	char* ptr;
	size_t remaining = (size_t)iovec->iov_len;

	ptr = (char *)malloc(remaining);
	iovec->iov_base = ptr;
	if (ptr == NULL)
		return PAHO_MEMORY_ERROR;

	/* memcpy_s rejects a copy larger than the space left, so remaining never underflows */
	if (memcpy_s(ptr, remaining, buf0, buf0len) != 0)
		goto fail;
	ptr += buf0len;
	remaining -= buf0len;

	for (i = 0; i < bufs->count; i++)
	{
		if (bufs->buffers[i] == NULL || bufs->buflens[i] == 0)
			continue;
		if (memcpy_s(ptr, remaining, bufs->buffers[i], bufs->buflens[i]) != 0)
			goto fail;
		ptr += bufs->buflens[i];
		remaining -= bufs->buflens[i];
	}
	return 0;

fail:
	free(iovec->iov_base);
	iovec->iov_base = NULL;
	return PAHO_MEMORY_ERROR;
}
722 
/**
 * Pass as much of data[0..len) to mbedtls as it accepts right now.
 * mbedtls_ssl_write() writes at most one TLS record per call and returns a short count
 * for longer payloads, so it is called repeatedly until everything is accepted or it
 * reports an error / WANT_* condition.
 * @param sent receives the number of bytes accepted
 * @return the last mbedtls_ssl_write() result (0 when len is 0)
 */
static int SSL_writeLoop(SSL* ssl, const unsigned char* data, size_t len, size_t* sent)
{
	int rc = 0;

	*sent = 0;
	while (*sent < len)
	{
		rc = mbedtls_ssl_write(ssl, data + *sent, len - *sent);
		if (rc <= 0)
			break;
		*sent += (size_t)rc;
	}
	return rc;
}

/**
 * Write one contiguous packet over TLS.
 * @return TCPSOCKET_COMPLETE; TCPSOCKET_INTERRUPTED when nothing could be written yet and
 *         a pending write was registered (it then owns iovec->iov_base);
 *         PAHO_MEMORY_ERROR; EXT_SOCKET_RET_SLIDING_WINDOW_FULL (IOT_CONNECT builds);
 *         otherwise SOCKET_ERROR
 */
static int SSL_write(SSL* ssl, int socket, iobuf* iovec)
{
	size_t len = (size_t)iovec->iov_len;
	size_t sent = 0;
	int rc = SSL_writeLoop(ssl, (const unsigned char *)iovec->iov_base, len, &sent);

	if (sent == len)
		return TCPSOCKET_COMPLETE;

	Log(TRACE_MIN, -1, "[%s,%d]rc = %d", __func__, __LINE__, rc);
	if (sent > 0)
	{
		/*
		 * Part of the packet is already on the wire. The pending write below is registered
		 * with 0 bytes written, so it cannot resume mid-packet; fail rather than resend data.
		 */
		Log(TRACE_MIN, -1, "SSL_write: only %u of %u bytes written on SSL socket %d",
			(unsigned int)sent, (unsigned int)len, socket);
		return SOCKET_ERROR;
	}

	if (rc == MBEDTLS_ERR_SSL_WANT_READ ||
		rc == MBEDTLS_ERR_SSL_WANT_WRITE ||
		rc == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
		rc == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
		rc == MBEDTLS_ERR_SSL_CLIENT_RECONNECT)
	{
		int* sockmem = (int*)malloc(sizeof(int));
		int frees = 1; /* the pending write takes ownership of iovec->iov_base */

		if (sockmem == NULL)
			return PAHO_MEMORY_ERROR;

		Log(TRACE_MIN, -1, "Partial write: incomplete write of %d bytes on SSL socket %d",
			(int)iovec->iov_len, socket);
		SocketBuffer_pendingWrite(socket, ssl, 1, iovec, &frees, iovec->iov_len, 0);
		*sockmem = socket;
		ListAppend(mod_s.write_pending, sockmem, sizeof(int));
#if defined(USE_SELECT)
		FD_SET(socket, &(mod_s.pending_wset));
#endif
		return TCPSOCKET_INTERRUPTED;
	}
#if defined(IOT_CONNECT)
	if (rc == EXT_SOCKET_RET_SLIDING_WINDOW_FULL)
		return EXT_SOCKET_RET_SLIDING_WINDOW_FULL;
#endif
	return SOCKET_ERROR;
}
772 
773 /* No SSL_writev() provided by OpenSSL. Boo. */
SSLSocket_putdatas(SSL * ssl,int socket,char * buf0,size_t buf0len,PacketBuffers bufs)774 int SSLSocket_putdatas(SSL* ssl, int socket, char* buf0, size_t buf0len, PacketBuffers bufs)
775 {
776 	int rc = 0;
777 	int i;
778 	iobuf iovec;
779 
780 	if (ssl == NULL || buf0 == NULL || bufs.buffers == NULL || bufs.buflens == NULL || bufs.frees == NULL)
781 	{
782 		rc = SOCKET_ERROR;
783 		goto exit;
784 	}
785 	if (bufs.count > SSL_MAX_COUNT)
786 	{
787 		rc = SOCKET_ERROR;
788 		goto exit;
789 	}
790 
791 	FUNC_ENTRY;
792 	iovec.iov_len = (ULONG)buf0len;
793 	for (i = 0; i < bufs.count; i++)
794 		iovec.iov_len += (ULONG)bufs.buflens[i];
795 
796 	rc = SSLSocket_buflenCheck(&iovec);
797 	if (rc != 0)
798 		goto exit;
799 
800 	rc = SSLSocket_putdatasub(&iovec, buf0, buf0len, &bufs);
801 	if (rc != 0)
802 		goto exit;
803 
804 	SSL_lock_mutex(&sslCoreMutex);
805 	rc = SSL_write(ssl, socket, &iovec);
806 	SSL_unlock_mutex(&sslCoreMutex);
807 
808 	if (rc != TCPSOCKET_INTERRUPTED)
809 	{
810 		free(iovec.iov_base);
811 		iovec.iov_base = NULL;
812 	}
813 	else
814 	{
815 		free(buf0);
816 		for (i = 0; i < bufs.count; ++i)
817 		{
818 			if (bufs.frees[i] != 0)
819 			{
820 				free(bufs.buffers[i]);
821 				bufs.buffers[i] = NULL;
822 			}
823 		}
824 	}
825 
826 exit:
827 	FUNC_EXIT_RC(rc);
828 	return rc;
829 }
830 
831 
SSLSocket_addPendingRead(int sock)832 void SSLSocket_addPendingRead(int sock)
833 {
834 	FUNC_ENTRY;
835 	if (ListFindItem(&pending_reads, &sock, intcompare) == NULL) /* make sure we don't add the same socket twice */
836 	{
837 		int* psock = (int*)malloc(sizeof(sock));
838 		if (psock != NULL)
839 		{
840 			*psock = sock;
841 			ListAppend(&pending_reads, psock, sizeof(sock));
842 		}
843 	}
844 	else
845 	{
846 		Log(TRACE_MIN, -1, "SSLSocket_addPendingRead: socket %d already in the list", sock);
847 	}
848 
849 	FUNC_EXIT;
850 }
851 
852 
SSLSocket_getPendingRead(void)853 int SSLSocket_getPendingRead(void)
854 {
855 	int sock = -1;
856 
857 	if (pending_reads.count > 0)
858 	{
859 		sock = *(int*)(pending_reads.first->content);
860 #if defined(IOT_CONNECT) || defined(IOT_LITEOS_ADAPT)
861 		_ListRemoveHead(&pending_reads); // conflict with libbt_host.a
862 #else
863 		ListRemoveHead(&pending_reads);
864 #endif
865 	}
866 	return sock;
867 }
868 
869 
/**
 * Retry a TLS write that SSL_write() left pending.
 * mbedtls_ssl_write() writes at most one record per call, so the buffer may need
 * several calls. Progress is kept in pw->bytes, which SSL_write registers as 0. A
 * later call resumes where the previous one stopped and repeats exactly the
 * arguments mbedtls asked to have retried after WANT_READ/WANT_WRITE.
 * @param pw the pending write (a single contiguous buffer in iovecs[0])
 * @return 1 when the whole buffer was written (the buffer is then freed), 0 when it must be
 *         retried later, EXT_SOCKET_RET_MESSAGE_TOO_LONG for oversize buffers in SEND_MAX_LEN
 *         builds, otherwise the mbedtls error
 */
int SSLSocket_continueWrite(pending_writes* pw)
{
	int rc = 0;
	const unsigned char* base;
	size_t len;

	if (pw == NULL)
		return rc;

#if defined(SEND_MAX_LEN)
	if (pw->iovecs[0].iov_len + SSL_HEADER_LEN > SEND_MAX_LEN)
	{
		char array[LOG_LEN];
		/* %u needs an unsigned int; iov_len may be size_t/ULONG */
		int ret = sprintf_s(array, LOG_LEN, "[Error]:Please don't send a message longer than %d bytes."
				" Message length which contains header and payload is %u bytes.\r\n", SEND_MAX_LEN,
				(unsigned int)(pw->iovecs[0].iov_len + SSL_HEADER_LEN));
		if (ret != -1)
			app_at_send_at_rsp_string_lines_with_claim_and_log_restricted(array);
		return EXT_SOCKET_RET_MESSAGE_TOO_LONG;
	}
#endif

	FUNC_ENTRY;
	base = (const unsigned char *)pw->iovecs[0].iov_base;
	len = (size_t)pw->iovecs[0].iov_len;
	while (pw->bytes < len)
	{
		rc = mbedtls_ssl_write(pw->ssl, base + pw->bytes, len - pw->bytes);
		if (rc <= 0)
			break;
		pw->bytes += (size_t)rc;
	}

	if (pw->bytes >= len)
	{
		/* topic and payload buffers are freed elsewhere, when all references to them have been removed */
		free(pw->iovecs[0].iov_base);
		pw->iovecs[0].iov_base = NULL;
		Log(TRACE_MIN, -1, "SSL continueWrite: partial write now complete for socket %d", pw->socket);
		rc = 1;
	}
	else
	{
		Log(TRACE_MIN, -1, "[%s,%d]rc = %d", __func__, __LINE__, rc);
		if (rc == MBEDTLS_ERR_SSL_WANT_READ ||
			rc == MBEDTLS_ERR_SSL_WANT_WRITE ||
			rc == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS ||
			rc == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
			rc == MBEDTLS_ERR_SSL_CLIENT_RECONNECT)
			rc = 0; /* indicate we haven't finished writing the payload yet */
	}
	FUNC_EXIT_RC(rc);
	return rc;
}
911 
912 #endif
913