1 // SPDX-License-Identifier: GPL-2.0
2 /*
3 * fs/hmdfs/comm/transport.c
4 *
5 * Copyright (c) 2020-2021 Huawei Device Co., Ltd.
6 */
7
8 #include "transport.h"
9
10 #include <linux/freezer.h>
11 #include <linux/highmem.h>
12 #include <linux/kthread.h>
13 #include <linux/module.h>
14 #include <linux/net.h>
15 #include <linux/tcp.h>
16 #include <linux/time.h>
17 #include <linux/file.h>
18 #include <linux/sched/mm.h>
19
20 #include "device_node.h"
21 #include "hmdfs_trace.h"
22 #include "socket_adapter.h"
23 #include "authority/authentication.h"
24
25 #ifdef CONFIG_HMDFS_FS_ENCRYPTION
26 #include <net/tls.h>
27 #include "crypto.h"
28 #endif
29
30 typedef void (*connect_recv_handler)(struct connection *, void *, void *,
31 __u32);
32
/*
 * Per-state receive dispatch table: an incoming message is routed by the
 * connection's current status.  All handshake-phase states share one
 * handler; CONNECT_STAT_STOP and CONNECT_STAT_NEGO_FAIL carry NULL, so
 * callers must check the slot before invoking it.
 */
static connect_recv_handler connect_recv_callback[CONNECT_STAT_COUNT] = {
	[CONNECT_STAT_WAIT_REQUEST] = connection_handshake_recv_handler,
	[CONNECT_STAT_WAIT_RESPONSE] = connection_handshake_recv_handler,
	[CONNECT_STAT_WORKING] = connection_working_recv_handler,
	[CONNECT_STAT_STOP] = NULL,
	[CONNECT_STAT_WAIT_ACK] = connection_handshake_recv_handler,
	[CONNECT_STAT_NEGO_FAIL] = NULL,
};
41
/*
 * kernel_recvmsg() wrapper that runs the receive inside a NOFS
 * allocation scope, so any allocation it performs cannot recurse into
 * filesystem reclaim.  Returns whatever kernel_recvmsg() returns.
 */
static int recvmsg_nofs(struct socket *sock, struct msghdr *msg,
			struct kvec *vec, size_t num, size_t size, int flags)
{
	unsigned int saved_flags;
	int err;

	saved_flags = memalloc_nofs_save();
	err = kernel_recvmsg(sock, msg, vec, num, size, flags);
	memalloc_nofs_restore(saved_flags);

	return err;
}
55
/*
 * kernel_sendmsg() wrapper that runs the send inside a NOFS allocation
 * scope, preventing memory allocation from re-entering the filesystem.
 * Returns whatever kernel_sendmsg() returns.
 */
static int sendmsg_nofs(struct socket *sock, struct msghdr *msg,
			struct kvec *vec, size_t num, size_t size)
{
	unsigned int saved_flags;
	int err;

	saved_flags = memalloc_nofs_save();
	err = kernel_sendmsg(sock, msg, vec, num, size);
	memalloc_nofs_restore(saved_flags);

	return err;
}
69
/*
 * Tune @sock for hmdfs traffic: disable Nagle and arm TCP_USER_TIMEOUT
 * so a dead peer is detected after roughly @timeout seconds.
 *
 * NOTE(review): tcp_sock_set_user_timeout() takes milliseconds, but
 * timeout * msecs_to_jiffies(MSEC_PER_SEC) evaluates to timeout * HZ,
 * i.e. jiffies — numerically correct only when HZ == 1000.  Consider
 * timeout * MSEC_PER_SEC instead; confirm TCP_RECV_TIMEOUT's unit.
 */
static int tcp_set_recvtimeo(struct socket *sock, int timeout)
{
	long jiffies_left = timeout * msecs_to_jiffies(MSEC_PER_SEC);

	tcp_sock_set_nodelay(sock->sk);
	tcp_sock_set_user_timeout(sock->sk, jiffies_left);
	return 0;
}
78
hmdfs_tcpi_rtt(struct hmdfs_peer * con)79 uint32_t hmdfs_tcpi_rtt(struct hmdfs_peer *con)
80 {
81 uint32_t rtt_us = 0;
82 struct connection *conn_impl = NULL;
83 struct tcp_handle *tcp = NULL;
84
85 conn_impl = get_conn_impl(con, CONNECT_TYPE_TCP);
86 if (!conn_impl)
87 return rtt_us;
88 tcp = (struct tcp_handle *)(conn_impl->connect_handle);
89 if (tcp->sock)
90 rtt_us = tcp_sk(tcp->sock->sk)->srtt_us >> 3;
91 connection_put(conn_impl);
92 return rtt_us;
93 }
94
/*
 * Read exactly @to_read bytes of message head from @sock.
 *
 * Transient conditions (-EAGAIN/-ETIMEDOUT/-EINTR/-EBADMSG) are
 * reported as -EAGAIN after a short sleep so the caller's loop can
 * retry; any other failure or short read is fatal and returns
 * -ESHUTDOWN.  Returns 0 once the full head has been received.
 */
static int tcp_read_head_from_socket(struct socket *sock, void *buf,
				     unsigned int to_read)
{
	int rc = 0;
	struct msghdr hmdfs_msg;
	struct kvec iov;

	iov.iov_base = buf;
	iov.iov_len = to_read;
	memset(&hmdfs_msg, 0, sizeof(hmdfs_msg));
	hmdfs_msg.msg_flags = MSG_WAITALL;
	hmdfs_msg.msg_control = NULL;
	hmdfs_msg.msg_controllen = 0;
	rc = recvmsg_nofs(sock, &hmdfs_msg, &iov, 1, to_read,
			  hmdfs_msg.msg_flags);
	if (rc == -EAGAIN || rc == -ETIMEDOUT || rc == -EINTR ||
	    rc == -EBADMSG) {
		/* back off briefly before the caller retries */
		usleep_range(1000, 2000);
		return -EAGAIN;
	}
	// error occurred
	if (rc != to_read) {
		hmdfs_err("tcp recv error %d", rc);
		return -ESHUTDOWN;
	}
	return 0;
}
122
tcp_read_buffer_from_socket(struct socket * sock,void * buf,unsigned int to_read)123 static int tcp_read_buffer_from_socket(struct socket *sock, void *buf,
124 unsigned int to_read)
125 {
126 int read_cnt = 0;
127 int retry_time = 0;
128 int rc = 0;
129 struct msghdr hmdfs_msg;
130 struct kvec iov;
131
132 do {
133 iov.iov_base = (char *)buf + read_cnt;
134 iov.iov_len = to_read - read_cnt;
135 memset(&hmdfs_msg, 0, sizeof(hmdfs_msg));
136 hmdfs_msg.msg_flags = MSG_WAITALL;
137 hmdfs_msg.msg_control = NULL;
138 hmdfs_msg.msg_controllen = 0;
139 rc = recvmsg_nofs(sock, &hmdfs_msg, &iov, 1,
140 to_read - read_cnt, hmdfs_msg.msg_flags);
141 if (rc == -EBADMSG) {
142 usleep_range(1000, 2000);
143 continue;
144 }
145 if (rc == -EAGAIN || rc == -ETIMEDOUT || rc == -EINTR) {
146 retry_time++;
147 hmdfs_info("read again %d", rc);
148 usleep_range(1000, 2000);
149 continue;
150 }
151 // error occurred
152 if (rc <= 0) {
153 hmdfs_err("tcp recv error %d", rc);
154 return -ESHUTDOWN;
155 }
156 read_cnt += rc;
157 if (read_cnt != to_read)
158 hmdfs_info("read again %d/%d", read_cnt, to_read);
159 } while (read_cnt < to_read && retry_time < MAX_RECV_RETRY_TIMES);
160 if (read_cnt == to_read)
161 return 0;
162 return -ESHUTDOWN;
163 }
164
hmdfs_drop_readpage_buffer(struct socket * sock,struct hmdfs_head_cmd * recv)165 static int hmdfs_drop_readpage_buffer(struct socket *sock,
166 struct hmdfs_head_cmd *recv)
167 {
168 unsigned int len;
169 void *buf = NULL;
170 int err;
171
172 len = le32_to_cpu(recv->data_len) - sizeof(struct hmdfs_head_cmd);
173 if (len > HMDFS_PAGE_SIZE || !len) {
174 hmdfs_err("recv invalid readpage length %u", len);
175 return -EINVAL;
176 }
177
178 /* Abort the connection if no memory */
179 buf = kmalloc(len, GFP_KERNEL);
180 if (!buf)
181 return -ESHUTDOWN;
182
183 err = tcp_read_buffer_from_socket(sock, buf, len);
184 kfree(buf);
185
186 return err;
187 }
188
hmdfs_get_readpage_buffer(struct socket * sock,struct hmdfs_head_cmd * recv,struct page * page)189 static int hmdfs_get_readpage_buffer(struct socket *sock,
190 struct hmdfs_head_cmd *recv,
191 struct page *page)
192 {
193 char *page_buf = NULL;
194 unsigned int out_len;
195 int err;
196
197 out_len = le32_to_cpu(recv->data_len) - sizeof(struct hmdfs_head_cmd);
198 if (out_len > HMDFS_PAGE_SIZE || !out_len) {
199 hmdfs_err("recv invalid readpage length %u", out_len);
200 return -EINVAL;
201 }
202
203 page_buf = kmap(page);
204 err = tcp_read_buffer_from_socket(sock, page_buf, out_len);
205 if (err)
206 goto out_unmap;
207 if (out_len != HMDFS_PAGE_SIZE)
208 memset(page_buf + out_len, 0, HMDFS_PAGE_SIZE - out_len);
209
210 out_unmap:
211 kunmap(page);
212 return err;
213 }
214
/*
 * Handle an F_READPAGE response arriving on a TLS connection.
 *
 * The response is matched to its pending async request by msg_id.  If
 * the request is still live (its delayed timeout work can be
 * cancelled), the page payload is read from the socket into the
 * waiting page and the peer's recvpage hook completes the request.
 * If the request already timed out, the payload must still be drained
 * from the socket so the byte stream stays aligned.
 */
static int tcp_recvpage_tls(struct connection *connect,
			    struct hmdfs_head_cmd *recv)
{
	int ret = 0;
	struct tcp_handle *tcp = NULL;
	struct hmdfs_peer *node = NULL;
	struct page *page = NULL;
	struct hmdfs_async_work *async_work = NULL;
	int rd_err;

	if (!connect) {
		hmdfs_err("tcp connect == NULL");
		return -ESHUTDOWN;
	}
	node = connect->node;
	tcp = (struct tcp_handle *)(connect->connect_handle);

	rd_err = le32_to_cpu(recv->ret_code);
	if (rd_err)
		hmdfs_warning("tcp: readpage from peer %llu ret err %d",
			      node->device_id, rd_err);

	async_work = (struct hmdfs_async_work *)hmdfs_find_msg_head(node,
		le32_to_cpu(recv->msg_id));
	/* no matching request, or the timeout work already ran: drop path */
	if (!async_work || !cancel_delayed_work(&async_work->d_work))
		goto out;

	page = async_work->page;
	if (!page) {
		hmdfs_err("page not found");
		goto out;
	}

	if (!rd_err) {
		ret = hmdfs_get_readpage_buffer(tcp->sock, recv, page);
		if (ret)
			rd_err = ret;
	}
	node->conn_operations->recvpage(node, recv, rd_err, async_work);
	asw_put(async_work);
	return ret;

out:
	/* async_work will be released by recvpage in normal processure */
	if (async_work)
		asw_put(async_work);
	hmdfs_err_ratelimited("timeout and droppage");
	hmdfs_client_resp_statis(node->sbi, F_READPAGE, HMDFS_RESP_DELAY, 0, 0);
	/* peer reported success, so page data is on the wire: discard it */
	if (!rd_err)
		ret = hmdfs_drop_readpage_buffer(tcp->sock, recv);
	return ret;
}
267
aeadcipher_cb(struct crypto_async_request * req,int error)268 static void aeadcipher_cb(struct crypto_async_request *req, int error)
269 {
270 struct aeadcrypt_result *result = req->data;
271
272 if (error == -EINPROGRESS)
273 return;
274 result->err = error;
275 complete(&result->completion);
276 }
277
/*
 * Submit an AEAD request and wait for it to finish.
 * @flag: non-zero (ENCRYPT_FLAG) encrypts, zero decrypts.
 *
 * NOTE(review): @result is passed BY VALUE, so the completion waited on
 * here is a copy of the struct registered with the request callback in
 * set_aeadcipher().  That is only safe when the transform completes
 * synchronously (rc == 0, the common software-gcm case); with a truly
 * async tfm the callback's complete() would target the caller's copy
 * and the wakeup could be lost.  Consider passing a pointer — TODO
 * confirm with the callers in this file.
 */
static int aeadcipher_en_de(struct aead_request *req,
			    struct aeadcrypt_result result, int flag)
{
	int rc = 0;

	if (flag)
		rc = crypto_aead_encrypt(req);
	else
		rc = crypto_aead_decrypt(req);
	switch (rc) {
	case 0:
		break;
	case -EINPROGRESS:
	case -EBUSY:
		/* backlogged or in flight: sleep until the callback fires */
		rc = wait_for_completion_interruptible(&result.completion);
		if (!rc && !result.err)
			reinit_completion(&result.completion);
		break;
	default:
		hmdfs_err("returned rc %d result %d", rc, result.err);
		break;
	}
	return rc;
}
302
set_aeadcipher(struct crypto_aead * tfm,struct aead_request * req,struct aeadcrypt_result * result)303 static int set_aeadcipher(struct crypto_aead *tfm, struct aead_request *req,
304 struct aeadcrypt_result *result)
305 {
306 init_completion(&result->completion);
307 aead_request_set_callback(
308 req, CRYPTO_TFM_REQ_MAY_BACKLOG | CRYPTO_TFM_REQ_MAY_SLEEP,
309 aeadcipher_cb, result);
310 return 0;
311 }
312
/*
 * Encrypt @src_len bytes of @src_buf into @dst_buf with the
 * connection's AEAD transform.
 *
 * Output layout: [random IV (HMDFS_IV_SIZE)][ciphertext + auth tag],
 * so @dst_len must be at least src_len + HMDFS_IV_SIZE +
 * HMDFS_TAG_SIZE.  Both buffers must be linearly mapped
 * (virt_addr_valid) because sg_init_one() is used on them.
 * Returns 0 on success or a negative errno.
 */
int aeadcipher_encrypt_buffer(struct connection *con, __u8 *src_buf,
			      size_t src_len, __u8 *dst_buf, size_t dst_len)
{
	int ret = 0;
	struct scatterlist src, dst;
	struct aead_request *req = NULL;
	struct aeadcrypt_result result;
	__u8 cipher_iv[HMDFS_IV_SIZE];

	if (src_len <= 0)
		return -EINVAL;
	if (!virt_addr_valid(src_buf) || !virt_addr_valid(dst_buf)) {
		WARN_ON(1);
		hmdfs_err("encrypt address is invalid");
		return -EPERM;
	}

	/* fresh IV per message, transmitted in the clear ahead of data */
	get_random_bytes(cipher_iv, HMDFS_IV_SIZE);
	memcpy(dst_buf, cipher_iv, HMDFS_IV_SIZE);
	req = aead_request_alloc(con->tfm, GFP_KERNEL);
	if (!req) {
		hmdfs_err("aead_request_alloc() failed");
		return -ENOMEM;
	}
	ret = set_aeadcipher(con->tfm, req, &result);
	if (ret) {
		hmdfs_err("set_enaeadcipher exit fault");
		goto out;
	}

	sg_init_one(&src, src_buf, src_len);
	sg_init_one(&dst, dst_buf + HMDFS_IV_SIZE, dst_len - HMDFS_IV_SIZE);
	aead_request_set_crypt(req, &src, &dst, src_len, cipher_iv);
	aead_request_set_ad(req, 0);
	ret = aeadcipher_en_de(req, result, ENCRYPT_FLAG);
out:
	aead_request_free(req);
	return ret;
}
352
/*
 * Decrypt @src_len bytes of @src_buf into @dst_buf with the
 * connection's AEAD transform.
 *
 * Input layout mirrors the encrypt side: [IV][ciphertext + tag], so
 * @src_len must exceed HMDFS_IV_SIZE + HMDFS_TAG_SIZE.  Both buffers
 * must be linearly mapped (virt_addr_valid) for sg_init_one().
 * Returns 0 on success (tag verified) or a negative errno.
 */
int aeadcipher_decrypt_buffer(struct connection *con, __u8 *src_buf,
			      size_t src_len, __u8 *dst_buf, size_t dst_len)
{
	int ret = 0;
	struct scatterlist src, dst;
	struct aead_request *req = NULL;
	struct aeadcrypt_result result;
	__u8 cipher_iv[HMDFS_IV_SIZE];

	if (src_len <= HMDFS_IV_SIZE + HMDFS_TAG_SIZE)
		return -EINVAL;
	if (!virt_addr_valid(src_buf) || !virt_addr_valid(dst_buf)) {
		WARN_ON(1);
		hmdfs_err("decrypt address is invalid");
		return -EPERM;
	}

	/* IV travels in the clear at the front of the ciphertext */
	memcpy(cipher_iv, src_buf, HMDFS_IV_SIZE);
	req = aead_request_alloc(con->tfm, GFP_KERNEL);
	if (!req) {
		hmdfs_err("aead_request_alloc() failed");
		return -ENOMEM;
	}
	ret = set_aeadcipher(con->tfm, req, &result);
	if (ret) {
		hmdfs_err("set_deaeadcipher exit fault");
		goto out;
	}

	sg_init_one(&src, src_buf + HMDFS_IV_SIZE, src_len - HMDFS_IV_SIZE);
	sg_init_one(&dst, dst_buf, dst_len);
	aead_request_set_crypt(req, &src, &dst, src_len - HMDFS_IV_SIZE,
			       cipher_iv);
	aead_request_set_ad(req, 0);
	ret = aeadcipher_en_de(req, result, DECRYPT_FLAG);
out:
	aead_request_free(req);
	return ret;
}
392
/*
 * Receive and decrypt a message payload on a pre-TLS (handshake /
 * HMDFS 1.0) connection.
 *
 * Wire layout after the head: [IV][ciphertext][tag].  A head-only
 * message (data_len == sizeof(head)) skips straight to dispatch with
 * outdata == NULL / outlen == 0.  On dispatch the per-state callback
 * takes over @outdata (presumably freeing it there — verify in the
 * handlers); when no callback is registered for the current state it
 * is freed here instead.
 */
static int tcp_recvbuffer_cipher(struct connection *connect,
				 struct hmdfs_head_cmd *recv)
{
	int ret = 0;
	struct tcp_handle *tcp = NULL;
	size_t cipherbuffer_len;
	__u8 *cipherbuffer = NULL;
	size_t outlen = 0;
	__u8 *outdata = NULL;
	__u32 recv_len = le32_to_cpu(recv->data_len);

	tcp = (struct tcp_handle *)(connect->connect_handle);
	if (recv_len == sizeof(struct hmdfs_head_cmd))
		goto out_recv_head;
	else if (recv_len > sizeof(struct hmdfs_head_cmd) &&
		 recv_len <= ADAPTER_MESSAGE_LENGTH)
		cipherbuffer_len = recv_len - sizeof(struct hmdfs_head_cmd) +
				   HMDFS_IV_SIZE + HMDFS_TAG_SIZE;
	else
		return -ENOMSG;
	cipherbuffer = kzalloc(cipherbuffer_len, GFP_KERNEL);
	if (!cipherbuffer) {
		hmdfs_err("zalloc cipherbuffer error");
		return -ESHUTDOWN;
	}
	outlen = cipherbuffer_len - HMDFS_IV_SIZE - HMDFS_TAG_SIZE;
	outdata = kzalloc(outlen, GFP_KERNEL);
	if (!outdata) {
		hmdfs_err("encrypt zalloc outdata error");
		kfree(cipherbuffer);
		return -ESHUTDOWN;
	}

	ret = tcp_read_buffer_from_socket(tcp->sock, cipherbuffer,
					  cipherbuffer_len);
	if (ret)
		goto out_recv;
	ret = aeadcipher_decrypt_buffer(connect, cipherbuffer, cipherbuffer_len,
					outdata, outlen);
	if (ret) {
		hmdfs_err("decrypt_buf fail");
		goto out_recv;
	}
out_recv_head:
	if (connect_recv_callback[connect->status]) {
		/* callback consumes outdata; do not free it here */
		connect_recv_callback[connect->status](connect, recv, outdata,
						       outlen);
	} else {
		kfree(outdata);
		hmdfs_err("encypt callback NULL status %d", connect->status);
	}
	kfree(cipherbuffer);
	return ret;
out_recv:
	kfree(cipherbuffer);
	kfree(outdata);
	return ret;
}
451
/*
 * Receive a message payload on a TLS-working connection (kernel TLS
 * already decrypted the stream, so the payload is read as plaintext).
 *
 * A head-only message dispatches with outdata == NULL / outlen == 0.
 * On dispatch the per-state callback takes over @outdata (presumably
 * freeing it there — verify in the handlers); when no callback is
 * registered it is freed here.  Returns 0 on dispatch, or a negative
 * errno if the payload could not be read.
 */
static int tcp_recvbuffer_tls(struct connection *connect,
			      struct hmdfs_head_cmd *recv)
{
	int ret = 0;
	struct tcp_handle *tcp = NULL;
	size_t outlen;
	__u8 *outdata = NULL;
	__u32 recv_len = le32_to_cpu(recv->data_len);

	tcp = (struct tcp_handle *)(connect->connect_handle);
	outlen = recv_len - sizeof(struct hmdfs_head_cmd);
	if (outlen == 0)
		goto out_recv_head;

	/*
	 * NOTE: Up to half of the allocated memory may be wasted due to
	 * the Internal Fragmentation, however the memory allocation times
	 * can be reduced and we don't have to adjust existing message
	 * transporting mechanism
	 */
	outdata = kmalloc(outlen, GFP_KERNEL);
	if (!outdata)
		return -ESHUTDOWN;

	ret = tcp_read_buffer_from_socket(tcp->sock, outdata, outlen);
	if (ret) {
		kfree(outdata);
		return ret;
	}
	tcp->connect->stat.recv_bytes += outlen;
out_recv_head:
	if (connect_recv_callback[connect->status]) {
		/* callback consumes outdata; do not free it here */
		connect_recv_callback[connect->status](connect, recv, outdata,
						       outlen);
	} else {
		kfree(outdata);
		hmdfs_err("callback NULL status %d", connect->status);
	}
	return 0;
}
492
/*
 * Receive and dispatch one complete message from the connection.
 *
 * Reads the fixed-size head, validates magic and data_len, then routes
 * the payload: TLS readpage responses take a fast path straight into
 * the waiting page; other TLS-working traffic goes through
 * tcp_recvbuffer_tls(); handshake / HMDFS 1.0 traffic goes through the
 * software cipher path.  Invalid messages are dropped with ret == 0 so
 * the receive loop keeps running; -ESHUTDOWN tells the loop to stop.
 */
static int tcp_receive_from_sock(struct tcp_handle *tcp)
{
	struct hmdfs_head_cmd *recv = NULL;
	int ret = 0;

	if (!tcp) {
		hmdfs_info("tcp recv thread !tcp");
		return -ESHUTDOWN;
	}

	if (!tcp->sock) {
		hmdfs_info("tcp recv thread !sock");
		return -ESHUTDOWN;
	}

	recv = kmem_cache_alloc(tcp->recv_cache, GFP_KERNEL);
	if (!recv) {
		hmdfs_info("tcp recv thread !cache");
		return -ESHUTDOWN;
	}

	ret = tcp_read_head_from_socket(tcp->sock, recv,
					sizeof(struct hmdfs_head_cmd));
	if (ret)
		goto out;

	tcp->connect->stat.recv_bytes += sizeof(struct hmdfs_head_cmd);
	tcp->connect->stat.recv_message_count++;

	if (recv->magic != HMDFS_MSG_MAGIC) {
		hmdfs_info_ratelimited("tcp recv fd %d wrong magic. drop message",
				       tcp->fd);
		goto out;
	}

	/* data_len must cover at least the head and at most head + max */
	if ((le32_to_cpu(recv->data_len) >
	     HMDFS_MAX_MESSAGE_LEN + sizeof(struct hmdfs_head_cmd)) ||
	    (le32_to_cpu(recv->data_len) < sizeof(struct hmdfs_head_cmd))) {
		hmdfs_info("tcp recv fd %d length error. drop message",
			   tcp->fd);
		goto out;
	}

	/* readpage responses bypass the generic buffer path */
	if (recv->version > USERSPACE_MAX_VER &&
	    tcp->connect->status == CONNECT_STAT_WORKING &&
	    recv->operations.command == F_READPAGE &&
	    recv->operations.cmd_flag == C_RESPONSE) {
		ret = tcp_recvpage_tls(tcp->connect, recv);
		goto out;
	}

	if (tcp->connect->status == CONNECT_STAT_WORKING &&
	    recv->version > USERSPACE_MAX_VER)
		ret = tcp_recvbuffer_tls(tcp->connect, recv);
	else
		ret = tcp_recvbuffer_cipher(tcp->connect, recv);

out:
	kmem_cache_free(tcp->recv_cache, recv);
	return ret;
}
554
/*
 * Check whether the TCP handle is healthy enough to receive on: the
 * handle and socket exist, the TCP state is ESTABLISHED, the socket
 * layer state is sane and (with kTLS enabled) the rx strparser has not
 * stopped.  Returns true when the connection is usable.
 */
static bool tcp_handle_is_available(struct tcp_handle *tcp)
{
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	struct tls_context *tls_ctx = NULL;
	struct tls_sw_context_rx *ctx = NULL;

#endif
	if (!tcp || !tcp->sock || !tcp->sock->sk) {
		hmdfs_err("Invalid tcp connection");
		return false;
	}

	if (tcp->sock->sk->sk_state != TCP_ESTABLISHED) {
		hmdfs_err("TCP conn %d is broken, current sk_state is %d",
			  tcp->fd, tcp->sock->sk->sk_state);
		return false;
	}

	if (tcp->sock->state != SS_CONNECTING &&
	    tcp->sock->state != SS_CONNECTED) {
		hmdfs_err("TCP conn %d is broken, current sock state is %d",
			  tcp->fd, tcp->sock->state);
		return false;
	}

#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	/* a stopped strparser means kTLS rx can never deliver again */
	tls_ctx = tls_get_ctx(tcp->sock->sk);
	if (tls_ctx) {
		ctx = tls_sw_ctx_rx(tls_ctx);
		if (ctx && ctx->strp.stopped) {
			hmdfs_err(
				"TCP conn %d is broken, the strparser has stopped",
				tcp->fd);
			return false;
		}
	}
#endif
	return true;
}
594
/*
 * Per-connection receive kthread body.
 *
 * Runs under the superblock's system credentials.  Receiving is
 * skipped while the connection is healthy but not yet on the peer's
 * list (redundant connection not mounted).  close_mutex serializes
 * receiving against tcp_close_socket(); on -ESHUTDOWN the thread
 * clears recv_task (so tcp_close_socket() won't kthread_stop() a
 * thread that is stopping itself), triggers reconnection handling and
 * exits, dropping the connection ref taken at thread creation.
 */
static int tcp_recv_thread(void *arg)
{
	int ret = 0;
	struct tcp_handle *tcp = (struct tcp_handle *)arg;
	const struct cred *old_cred;

	WARN_ON(!tcp);
	WARN_ON(!tcp->sock);
	set_freezable();

	old_cred = hmdfs_override_creds(tcp->connect->node->sbi->system_cred);

	while (!kthread_should_stop()) {
		/*
		 * 1. In case the redundant connection has not been mounted on
		 *    a peer
		 * 2. Lock is unnecessary since a transient state is acceptable
		 */
		if (tcp_handle_is_available(tcp) &&
		    list_empty(&tcp->connect->list))
			goto freeze;
		if (!mutex_trylock(&tcp->close_mutex))
			continue;
		if (tcp_handle_is_available(tcp))
			ret = tcp_receive_from_sock(tcp);
		else
			ret = -ESHUTDOWN;
		/*
		 * This kthread will exit if ret is -ESHUTDOWN, thus we need to
		 * set recv_task to NULL to avoid calling kthread_stop() from
		 * tcp_close_socket().
		 */
		if (ret == -ESHUTDOWN)
			tcp->recv_task = NULL;
		mutex_unlock(&tcp->close_mutex);
		if (ret == -ESHUTDOWN) {
			hmdfs_node_inc_evt_seq(tcp->connect->node);
			tcp->connect->status = CONNECT_STAT_STOP;
			if (tcp->connect->node->status != NODE_STAT_OFFLINE)
				hmdfs_reget_connection(tcp->connect);
			break;
		}
freeze:
		schedule();
		try_to_freeze();
	}

	hmdfs_info("Exiting. Now, sock state = %d", tcp->sock->state);
	hmdfs_revert_creds(old_cred);
	/* drop the ref taken for this thread in tcp_update_socket() */
	connection_put(tcp->connect);
	return 0;
}
647
/*
 * Send a message on a pre-TLS (handshake / HMDFS 1.0) connection.
 *
 * The head is sent as-is; the data payload (if any) is AEAD-encrypted
 * into [IV][ciphertext][tag] form first.  Both parts go out in one
 * vectored sendmsg under send_mutex.  Returns 0 on full transmission,
 * -EAGAIN on a partial send, -ESHUTDOWN on a send failure, or a
 * negative errno from allocation/encryption.
 */
static int tcp_send_message_sock_cipher(struct tcp_handle *tcp,
					struct hmdfs_send_data *msg)
{
	int ret = 0;
	__u8 *outdata = NULL;
	size_t outlen = 0;
	int send_len = 0;
	int send_vec_cnt = 0;
	struct msghdr tcp_msg;
	struct kvec iov[TCP_KVEC_ELE_DOUBLE];

	memset(&tcp_msg, 0, sizeof(tcp_msg));
	if (!tcp || !tcp->sock) {
		hmdfs_err("encrypt tcp socket = NULL");
		return -ESHUTDOWN;
	}
	iov[0].iov_base = msg->head;
	iov[0].iov_len = msg->head_len;
	send_vec_cnt = TCP_KVEC_HEAD;
	/* head-only message: nothing to encrypt */
	if (msg->len == 0)
		goto send;

	outlen = msg->len + HMDFS_IV_SIZE + HMDFS_TAG_SIZE;
	outdata = kzalloc(outlen, GFP_KERNEL);
	if (!outdata) {
		hmdfs_err("tcp send message encrypt fail to alloc outdata");
		return -ENOMEM;
	}
	ret = aeadcipher_encrypt_buffer(tcp->connect, msg->data, msg->len,
					outdata, outlen);
	if (ret) {
		hmdfs_err("encrypt_buf fail");
		goto out;
	}
	iov[1].iov_base = outdata;
	iov[1].iov_len = outlen;
	send_vec_cnt = TCP_KVEC_ELE_DOUBLE;
send:
	mutex_lock(&tcp->send_mutex);
	send_len = sendmsg_nofs(tcp->sock, &tcp_msg, iov, send_vec_cnt,
				msg->head_len + outlen);
	mutex_unlock(&tcp->send_mutex);
	if (send_len <= 0) {
		hmdfs_err("error %d", send_len);
		ret = -ESHUTDOWN;
	} else if (send_len != msg->head_len + outlen) {
		hmdfs_err("send part of message. %d/%zu", send_len,
			  msg->head_len + outlen);
		ret = -EAGAIN;
	} else {
		ret = 0;
	}
out:
	/* kfree(NULL) is a no-op on the head-only path */
	kfree(outdata);
	return ret;
}
704
tcp_send_message_sock_tls(struct tcp_handle * tcp,struct hmdfs_send_data * msg)705 static int tcp_send_message_sock_tls(struct tcp_handle *tcp,
706 struct hmdfs_send_data *msg)
707 {
708 int send_len = 0;
709 int send_vec_cnt = 0;
710 struct msghdr tcp_msg;
711 struct kvec iov[TCP_KVEC_ELE_TRIPLE];
712
713 memset(&tcp_msg, 0, sizeof(tcp_msg));
714 if (!tcp || !tcp->sock) {
715 hmdfs_err("tcp socket = NULL");
716 return -ESHUTDOWN;
717 }
718 iov[TCP_KVEC_HEAD].iov_base = msg->head;
719 iov[TCP_KVEC_HEAD].iov_len = msg->head_len;
720 if (msg->len == 0 && msg->sdesc_len == 0) {
721 send_vec_cnt = TCP_KVEC_ELE_SINGLE;
722 } else if (msg->sdesc_len == 0) {
723 iov[TCP_KVEC_DATA].iov_base = msg->data;
724 iov[TCP_KVEC_DATA].iov_len = msg->len;
725 send_vec_cnt = TCP_KVEC_ELE_DOUBLE;
726 } else {
727 iov[TCP_KVEC_FILE_PARA].iov_base = msg->sdesc;
728 iov[TCP_KVEC_FILE_PARA].iov_len = msg->sdesc_len;
729 iov[TCP_KVEC_FILE_CONTENT].iov_base = msg->data;
730 iov[TCP_KVEC_FILE_CONTENT].iov_len = msg->len;
731 send_vec_cnt = TCP_KVEC_ELE_TRIPLE;
732 }
733 mutex_lock(&tcp->send_mutex);
734 send_len = sendmsg_nofs(tcp->sock, &tcp_msg, iov, send_vec_cnt,
735 msg->head_len + msg->len + msg->sdesc_len);
736 mutex_unlock(&tcp->send_mutex);
737 if (send_len == -EBADMSG) {
738 return -EBADMSG;
739 } else if (send_len <= 0) {
740 hmdfs_err("error %d", send_len);
741 return -ESHUTDOWN;
742 } else if (send_len != msg->head_len + msg->len + msg->sdesc_len) {
743 hmdfs_err("send part of message. %d/%zu", send_len,
744 msg->head_len + msg->len);
745 tcp->connect->stat.send_bytes += send_len;
746 return -EAGAIN;
747 }
748 tcp->connect->stat.send_bytes += send_len;
749 tcp->connect->stat.send_message_count++;
750 return 0;
751 }
752
753 #ifdef CONFIG_HMDFS_FS_ENCRYPTION
/*
 * Build and send an F_CONNECT_REKEY request to the peer.
 *
 * The head and the rekey parameter block are allocated in one buffer;
 * the parameter's update_request is UPDATE_NOT_REQUESTED (the sender
 * rotates its own send key separately via set_crypto_info()).
 * Returns 0 on success or the tcp_send_message_sock_tls() error.
 */
int tcp_send_rekey_request(struct connection *connect)
{
	int ret = 0;
	struct hmdfs_send_data msg;
	struct tcp_handle *tcp = connect->connect_handle;
	struct hmdfs_head_cmd *head = NULL;
	struct connection_rekey_request *rekey_request_param = NULL;
	struct hmdfs_cmd operations;

	hmdfs_init_cmd(&operations, F_CONNECT_REKEY);
	/* one allocation holds the head followed by the rekey params */
	head = kzalloc(sizeof(struct hmdfs_head_cmd) +
		       sizeof(struct connection_rekey_request),
		       GFP_KERNEL);
	if (!head)
		return -ENOMEM;
	rekey_request_param =
		(struct connection_rekey_request
		 *)((uint8_t *)head + sizeof(struct hmdfs_head_cmd));

	rekey_request_param->update_request = cpu_to_le32(UPDATE_NOT_REQUESTED);

	head->magic = HMDFS_MSG_MAGIC;
	head->version = DFS_2_0;
	head->operations = operations;
	head->data_len =
		cpu_to_le32(sizeof(*head) + sizeof(*rekey_request_param));
	head->reserved = 0;
	head->reserved1 = 0;
	head->ret_code = 0;

	msg.head = head;
	msg.head_len = sizeof(*head);
	msg.data = rekey_request_param;
	msg.len = sizeof(*rekey_request_param);
	msg.sdesc = NULL;
	msg.sdesc_len = 0;
	ret = tcp_send_message_sock_tls(tcp, &msg);
	if (ret != 0)
		hmdfs_err("return error %d", ret);
	kfree(head);
	return ret;
}
796 #endif
797
/*
 * Connection-level send entry point (struct connection ->send_message).
 *
 * Validates the message, then routes it through the TLS path
 * (working connection, new-protocol peer) or the software cipher path
 * (handshake or HMDFS 1.0 peer).  With encryption enabled, a
 * successful send may additionally trigger a periodic session rekey
 * once REKEY_LIFETIME has elapsed.
 */
static int tcp_send_message(struct connection *connect,
			    struct hmdfs_send_data *msg)
{
	int ret = 0;
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	unsigned long nowtime = jiffies;
#endif
	struct tcp_handle *tcp = NULL;

	if (!connect) {
		hmdfs_err("tcp connection = NULL ");
		return -ESHUTDOWN;
	}
	if (!msg) {
		hmdfs_err("msg = NULL");
		return -EINVAL;
	}
	if (msg->len > HMDFS_MAX_MESSAGE_LEN) {
		hmdfs_err("message->len error: %zu", msg->len);
		return -EINVAL;
	}
	tcp = (struct tcp_handle *)(connect->connect_handle);
	if (connect->status == CONNECT_STAT_STOP)
		return -EAGAIN;

	trace_hmdfs_tcp_send_message(msg->head);

	if (connect->status == CONNECT_STAT_WORKING &&
	    connect->node->version > USERSPACE_MAX_VER)
		ret = tcp_send_message_sock_tls(tcp, msg);
	else
		// Handshake status or version HMDFS1.0
		ret = tcp_send_message_sock_cipher(tcp, msg);

	if (ret != 0) {
		hmdfs_err("return error %d", ret);
		return ret;
	}
#ifdef CONFIG_HMDFS_FS_ENCRYPTION
	/* rotate session keys periodically on healthy 2.0 connections */
	if (nowtime - connect->stat.rekey_time >= REKEY_LIFETIME &&
	    connect->status == CONNECT_STAT_WORKING &&
	    connect->node->version >= DFS_2_0) {
		hmdfs_info("send rekey message to devid %llu",
			   connect->node->device_id);
		ret = tcp_send_rekey_request(connect);
		if (ret == 0)
			set_crypto_info(connect, SET_CRYPTO_SEND);
		connect->stat.rekey_time = nowtime;
	}
#endif
	return ret;
}
850
tcp_close_socket(struct tcp_handle * tcp)851 void tcp_close_socket(struct tcp_handle *tcp)
852 {
853 if (!tcp)
854 return;
855 mutex_lock(&tcp->close_mutex);
856 if (tcp->recv_task) {
857 kthread_stop(tcp->recv_task);
858 tcp->recv_task = NULL;
859 }
860 mutex_unlock(&tcp->close_mutex);
861 }
862
/*
 * Configure the AEAD transform: install @master_key, set the auth tag
 * size and verify the transform's IV width matches the wire format.
 * Returns 0 on success or a negative errno.
 */
static int set_tfm(__u8 *master_key, struct crypto_aead *tfm)
{
	int err;
	int iv_len;

	crypto_aead_clear_flags(tfm, ~0);
	err = crypto_aead_setkey(tfm, master_key, HMDFS_KEY_SIZE);
	if (err) {
		hmdfs_err("failed to set the key");
		return err;
	}

	err = crypto_aead_setauthsize(tfm, HMDFS_TAG_SIZE);
	if (err) {
		hmdfs_err("authsize length is error");
		return err;
	}

	/* the protocol serializes exactly HMDFS_IV_SIZE bytes of IV */
	iv_len = crypto_aead_ivsize(tfm);
	if (iv_len != HMDFS_IV_SIZE) {
		hmdfs_err("IV recommended value should be set %d", iv_len);
		return -ENODATA;
	}

	return 0;
}
890
/*
 * Attach a userspace-provided socket to the TCP handle and bring the
 * connection up: validate the socket, create the receive-buffer slab,
 * set socket options, install the session keys and AEAD transform,
 * and create (but do not wake) the receive kthread.
 *
 * On success the connection holds an extra reference for the receive
 * thread.  On failure everything is unwound; note the error labels
 * deliberately fall through (put_conn -> free_crypto ->
 * free_mem_cache -> put_sock).  Returns 0 or a negative errno.
 */
static int tcp_update_socket(struct tcp_handle *tcp, int fd,
			     uint8_t *master_key, struct socket *socket)
{
	int err = 0;
	struct hmdfs_peer *node = NULL;

	if (!master_key || fd == 0)
		return -EAGAIN;

	tcp->sock = socket;
	tcp->fd = fd;
	if (!tcp_handle_is_available(tcp)) {
		err = -EPIPE;
		goto put_sock;
	}

	hmdfs_info("socket fd %d, state %d, refcount %ld",
		   fd, socket->state, file_count(socket->file));

	tcp->recv_cache = kmem_cache_create("hmdfs_socket",
					    tcp->recvbuf_maxsize,
					    0, SLAB_HWCACHE_ALIGN, NULL);
	if (!tcp->recv_cache) {
		err = -ENOMEM;
		goto put_sock;
	}

	err = tcp_set_recvtimeo(socket, TCP_RECV_TIMEOUT);
	if (err) {
		hmdfs_err("tcp set timeout error");
		goto free_mem_cache;
	}

	/* send key and recv key, default MASTER KEY */
	memcpy(tcp->connect->master_key, master_key, HMDFS_KEY_SIZE);
	memcpy(tcp->connect->send_key, master_key, HMDFS_KEY_SIZE);
	memcpy(tcp->connect->recv_key, master_key, HMDFS_KEY_SIZE);
	tcp->connect->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
	if (IS_ERR(tcp->connect->tfm)) {
		err = PTR_ERR(tcp->connect->tfm);
		tcp->connect->tfm = NULL;
		hmdfs_err("failed to load transform for gcm(aes):%d", err);
		goto free_mem_cache;
	}

	err = set_tfm(master_key, tcp->connect->tfm);
	if (err) {
		hmdfs_err("tfm seting exit fault");
		goto free_crypto;
	}

	/* reference owned by the receive thread, dropped when it exits */
	connection_get(tcp->connect);

	node = tcp->connect->node;
	tcp->recv_task = kthread_create(tcp_recv_thread, (void *)tcp,
					"dfs_rcv%u_%llu_%d",
					node->owner, node->device_id, fd);
	if (IS_ERR(tcp->recv_task)) {
		err = PTR_ERR(tcp->recv_task);
		hmdfs_err("tcp->rcev_task %d", err);
		goto put_conn;
	}

	return 0;

put_conn:
	tcp->recv_task = NULL;
	connection_put(tcp->connect);
free_crypto:
	crypto_free_aead(tcp->connect->tfm);
	tcp->connect->tfm = NULL;
free_mem_cache:
	kmem_cache_destroy(tcp->recv_cache);
	tcp->recv_cache = NULL;
put_sock:
	tcp->sock = NULL;
	tcp->fd = 0;

	return err;
}
971
/*
 * Allocate and initialize a tcp_handle for @connect, attaching the
 * given socket.  The handle is linked into connect->connect_handle
 * before socket setup so the receive thread can find it.
 *
 * Returns the handle, or NULL on failure (connect->connect_handle is
 * reset to NULL in that case).
 */
static struct tcp_handle *tcp_alloc_handle(struct connection *connect,
	int socket_fd, uint8_t *master_key, struct socket *socket)
{
	int ret = 0;
	struct tcp_handle *tcp = kzalloc(sizeof(*tcp), GFP_KERNEL);

	if (!tcp)
		return NULL;
	tcp->connect = connect;
	tcp->connect->connect_handle = (void *)tcp;
	tcp->recvbuf_maxsize = MAX_RECV_SIZE;
	tcp->recv_task = NULL;
	tcp->recv_cache = NULL;
	tcp->sock = NULL;
	mutex_init(&tcp->close_mutex);
	mutex_init(&tcp->send_mutex);
	ret = tcp_update_socket(tcp, socket_fd, master_key, socket);
	if (ret) {
		/*
		 * Clear the back-pointer before freeing: leaving
		 * connect->connect_handle pointing at freed memory is a
		 * use-after-free hazard for anyone inspecting the
		 * connection (e.g. hmdfs_tcpi_rtt()).
		 */
		connect->connect_handle = NULL;
		kfree(tcp);
		return NULL;
	}
	return tcp;
}
995
hmdfs_get_connection(struct hmdfs_peer * peer)996 void hmdfs_get_connection(struct hmdfs_peer *peer)
997 {
998 struct notify_param param;
999
1000 if (!peer)
1001 return;
1002 param.notify = NOTIFY_GET_SESSION;
1003 param.fd = INVALID_SOCKET_FD;
1004 memcpy(param.remote_cid, peer->cid, HMDFS_CID_SIZE);
1005 notify(peer, ¶m);
1006 }
1007
connection_notify_to_close(struct connection * conn)1008 static void connection_notify_to_close(struct connection *conn)
1009 {
1010 struct notify_param param;
1011 struct hmdfs_peer *peer = NULL;
1012 struct tcp_handle *tcp = NULL;
1013
1014 tcp = conn->connect_handle;
1015 peer = conn->node;
1016
1017 // libdistbus/src/TcpSession.cpp will close the socket
1018 param.notify = NOTIFY_GET_SESSION;
1019 param.fd = tcp->fd;
1020 memcpy(param.remote_cid, peer->cid, HMDFS_CID_SIZE);
1021 notify(peer, ¶m);
1022 }
1023
/*
 * Tear down a broken connection and hand its slot back for
 * re-establishment.
 *
 * The connection is first moved from conn_impl_list to
 * conn_deleting_list under the list lock — whoever wins that race owns
 * the final put.  The socket is shut down, the receive thread is
 * stopped (unless this *is* the receive thread, which must not stop
 * itself), and userspace is asked to close its end of the session.
 */
void hmdfs_reget_connection(struct connection *conn)
{
	struct tcp_handle *tcp = NULL;
	struct connection *conn_impl = NULL;
	struct connection *next = NULL;
	struct task_struct *recv_task = NULL;
	bool should_put = false;
	bool stop_thread = true;

	if (!conn)
		return;

	// One may put a connection if and only if he took it out of the list
	mutex_lock(&conn->node->conn_impl_list_lock);
	list_for_each_entry_safe(conn_impl, next, &conn->node->conn_impl_list,
				 list) {
		if (conn_impl == conn) {
			should_put = true;
			list_move(&conn->list, &conn->node->conn_deleting_list);
			break;
		}
	}
	if (!should_put) {
		/* someone else already removed it; they own the teardown */
		mutex_unlock(&conn->node->conn_impl_list_lock);
		return;
	}

	tcp = conn->connect_handle;
	if (tcp) {
		recv_task = tcp->recv_task;
		/*
		 * To avoid the receive thread to stop itself. Ensure receive
		 * thread stop before process offline event
		 */
		if (!recv_task || recv_task->pid == current->pid)
			stop_thread = false;
	}
	mutex_unlock(&conn->node->conn_impl_list_lock);

	if (tcp) {
		if (tcp->sock) {
			hmdfs_info("shudown sock: fd = %d, sockref = %ld, connref = %u stop_thread = %d",
				   tcp->fd, file_count(tcp->sock->file),
				   kref_read(&conn->ref_cnt), stop_thread);
			kernel_sock_shutdown(tcp->sock, SHUT_RDWR);
		}

		if (stop_thread)
			tcp_close_socket(tcp);

		if (tcp->fd != INVALID_SOCKET_FD)
			connection_notify_to_close(conn);
	}
	connection_put(conn);
}
1079
1080 static struct connection *
lookup_conn_by_socketfd_unsafe(struct hmdfs_peer * node,struct socket * socket)1081 lookup_conn_by_socketfd_unsafe(struct hmdfs_peer *node, struct socket *socket)
1082 {
1083 struct connection *tcp_conn = NULL;
1084 struct tcp_handle *tcp = NULL;
1085
1086 list_for_each_entry(tcp_conn, &node->conn_impl_list, list) {
1087 if (tcp_conn->connect_handle) {
1088 tcp = (struct tcp_handle *)(tcp_conn->connect_handle);
1089 if (tcp->sock == socket) {
1090 connection_get(tcp_conn);
1091 return tcp_conn;
1092 }
1093 }
1094 }
1095 return NULL;
1096 }
1097
hmdfs_reget_connection_work_fn(struct work_struct * work)1098 static void hmdfs_reget_connection_work_fn(struct work_struct *work)
1099 {
1100 struct connection *conn =
1101 container_of(work, struct connection, reget_work);
1102
1103 hmdfs_reget_connection(conn);
1104 connection_put(conn);
1105 }
1106
/*
 * Allocate and initialize a TCP-backed struct connection for @node.
 *
 * The returned connection holds one reference (kref_init) and is not yet
 * linked on any list — the caller is expected to publish it via
 * add_conn_tcp_unsafe(). The tcp_handle is created around @socket_fd /
 * @socket with @master_key for the session crypto.
 *
 * Returns NULL if either the connection or its tcp_handle cannot be
 * allocated.
 */
struct connection *alloc_conn_tcp(struct hmdfs_peer *node, int socket_fd,
				  uint8_t *master_key, uint8_t status,
				  struct socket *socket)
{
	struct connection *tcp_conn = NULL;
	unsigned long nowtime = jiffies;

	tcp_conn = kzalloc(sizeof(*tcp_conn), GFP_KERNEL);
	if (!tcp_conn)
		goto out_err;

	kref_init(&tcp_conn->ref_cnt);
	mutex_init(&tcp_conn->ref_lock);
	INIT_LIST_HEAD(&tcp_conn->list);
	tcp_conn->node = node;
	tcp_conn->close = tcp_stop_connect;
	tcp_conn->send_message = tcp_send_message;
	tcp_conn->type = CONNECT_TYPE_TCP;
	tcp_conn->status = status;
	tcp_conn->stat.rekey_time = nowtime;
	tcp_conn->connect_handle =
		(void *)tcp_alloc_handle(tcp_conn, socket_fd, master_key, socket);
	INIT_WORK(&tcp_conn->reget_work, hmdfs_reget_connection_work_fn);
	if (!tcp_conn->connect_handle) {
		/* fixed log typo: "strcut" -> "struct" */
		hmdfs_err("Failed to alloc tcp_handle for struct conn");
		goto out_err;
	}
	return tcp_conn;

out_err:
	kfree(tcp_conn);
	return NULL;
}
1139
add_conn_tcp_unsafe(struct hmdfs_peer * node,struct socket * socket,struct connection * conn2add)1140 static struct connection *add_conn_tcp_unsafe(struct hmdfs_peer *node,
1141 struct socket *socket,
1142 struct connection *conn2add)
1143 {
1144 struct connection *conn;
1145
1146 conn = lookup_conn_by_socketfd_unsafe(node, socket);
1147 if (conn) {
1148 hmdfs_info("socket already in list");
1149 return conn;
1150 }
1151
1152 /* Prefer to use socket opened by local device */
1153 if (conn2add->status == CONNECT_STAT_WAIT_REQUEST)
1154 list_add(&conn2add->list, &node->conn_impl_list);
1155 else
1156 list_add_tail(&conn2add->list, &node->conn_impl_list);
1157 connection_get(conn2add);
1158 return conn2add;
1159 }
1160
/*
 * Get (or create) the TCP connection for socket @fd on peer @node.
 *
 * Resolves @fd to a struct socket, returns an existing connection that
 * already wraps that socket, or allocates a new one and publishes it on
 * the peer's list. On the new-connection path the receive thread is
 * started and, for CONNECT_STAT_WAIT_RESPONSE, a handshake request is
 * sent. Returns NULL on lookup/allocation failure.
 *
 * Reference notes:
 * - The socket reference from sockfd_lookup() is dropped explicitly on
 *   both early-exit paths; on the success path it is presumably retained
 *   by the tcp_handle created in alloc_conn_tcp() — TODO confirm against
 *   tcp_alloc_handle()/tcp_close_socket().
 * - The returned connection carries the reference taken inside
 *   add_conn_tcp_unsafe()/lookup_conn_by_socketfd_unsafe(); the caller
 *   owns it.
 */
struct connection *hmdfs_get_conn_tcp(struct hmdfs_peer *node, int fd,
				      uint8_t *master_key, uint8_t status)
{
	struct connection *tcp_conn = NULL, *on_peer_conn = NULL;
	struct tcp_handle *tcp = NULL;
	struct socket *socket = NULL;
	int err = 0;

	socket = sockfd_lookup(fd, &err);
	if (!socket) {
		hmdfs_err("lookup socket fail, socket_fd %d, err %d", fd, err);
		return NULL;
	}
	mutex_lock(&node->conn_impl_list_lock);
	tcp_conn = lookup_conn_by_socketfd_unsafe(node, socket);
	mutex_unlock(&node->conn_impl_list_lock);
	if (tcp_conn) {
		hmdfs_info("Got a existing tcp conn: fsocket_fd = %d",
			   fd);
		/* existing conn owns its own socket ref; drop ours */
		sockfd_put(socket);
		goto out;
	}

	tcp_conn = alloc_conn_tcp(node, fd, master_key, status, socket);
	if (!tcp_conn) {
		hmdfs_info("Failed to alloc a tcp conn, socket_fd %d", fd);
		sockfd_put(socket);
		goto out;
	}

	mutex_lock(&node->conn_impl_list_lock);
	on_peer_conn = add_conn_tcp_unsafe(node, socket, tcp_conn);
	mutex_unlock(&node->conn_impl_list_lock);
	tcp = tcp_conn->connect_handle;
	if (on_peer_conn == tcp_conn) {
		/* our conn won the race and is now on the list */
		hmdfs_info("Got a newly allocated tcp conn: socket_fd = %d", fd);
		wake_up_process(tcp->recv_task);
		if (status == CONNECT_STAT_WAIT_RESPONSE)
			connection_send_handshake(
				on_peer_conn, CONNECT_MESG_HANDSHAKE_REQUEST,
				0);
	} else {
		/*
		 * Someone else published a conn for this socket first:
		 * dismantle ours. Invalidating fd first keeps
		 * connection_notify_to_close() from being triggered for
		 * this discarded handle.
		 */
		hmdfs_info("Got a existing tcp conn: socket_fd = %d", fd);
		tcp->fd = INVALID_SOCKET_FD;
		tcp_close_socket(tcp);
		connection_put(tcp_conn);

		tcp_conn = on_peer_conn;
	}

out:
	return tcp_conn;
}
1214
/*
 * struct connection ->close hook for TCP transports.
 *
 * Actual socket teardown is driven elsewhere (hmdfs_reget_connection /
 * tcp_close_socket in this file), so this hook only logs that it ran.
 */
void tcp_stop_connect(struct connection *connect)
{
	hmdfs_info("now nothing to do");
}
1219