1 /*
2 * common code for virtio vsock
3 *
4 * Copyright (C) 2013-2015 Red Hat, Inc.
5 * Author: Asias He <asias@redhat.com>
6 * Stefan Hajnoczi <stefanha@redhat.com>
7 *
8 * This work is licensed under the terms of the GNU GPL, version 2.
9 */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/sched/signal.h>
13 #include <linux/ctype.h>
14 #include <linux/list.h>
15 #include <linux/virtio.h>
16 #include <linux/virtio_ids.h>
17 #include <linux/virtio_config.h>
18 #include <linux/virtio_vsock.h>
19 #include <uapi/linux/vsockmon.h>
20
21 #include <net/sock.h>
22 #include <net/af_vsock.h>
23
24 #define CREATE_TRACE_POINTS
25 #include <trace/events/vsock_virtio_transport_common.h>
26
/* Grace period allowed for an orderly shutdown before the connection is reset */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)
29
virtio_transport_get_ops(void)30 static const struct virtio_transport *virtio_transport_get_ops(void)
31 {
32 const struct vsock_transport *t = vsock_core_get_transport();
33
34 return container_of(t, struct virtio_transport, transport);
35 }
36
37 static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info * info,size_t len,u32 src_cid,u32 src_port,u32 dst_cid,u32 dst_port)38 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
39 size_t len,
40 u32 src_cid,
41 u32 src_port,
42 u32 dst_cid,
43 u32 dst_port)
44 {
45 struct virtio_vsock_pkt *pkt;
46 int err;
47
48 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
49 if (!pkt)
50 return NULL;
51
52 pkt->hdr.type = cpu_to_le16(info->type);
53 pkt->hdr.op = cpu_to_le16(info->op);
54 pkt->hdr.src_cid = cpu_to_le64(src_cid);
55 pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
56 pkt->hdr.src_port = cpu_to_le32(src_port);
57 pkt->hdr.dst_port = cpu_to_le32(dst_port);
58 pkt->hdr.flags = cpu_to_le32(info->flags);
59 pkt->len = len;
60 pkt->hdr.len = cpu_to_le32(len);
61 pkt->reply = info->reply;
62 pkt->vsk = info->vsk;
63
64 if (info->msg && len > 0) {
65 pkt->buf = kmalloc(len, GFP_KERNEL);
66 if (!pkt->buf)
67 goto out_pkt;
68 err = memcpy_from_msg(pkt->buf, info->msg, len);
69 if (err)
70 goto out;
71 }
72
73 trace_virtio_transport_alloc_pkt(src_cid, src_port,
74 dst_cid, dst_port,
75 len,
76 info->type,
77 info->op,
78 info->flags);
79
80 return pkt;
81
82 out:
83 kfree(pkt->buf);
84 out_pkt:
85 kfree(pkt);
86 return NULL;
87 }
88
89 /* Packet capture */
virtio_transport_build_skb(void * opaque)90 static struct sk_buff *virtio_transport_build_skb(void *opaque)
91 {
92 struct virtio_vsock_pkt *pkt = opaque;
93 struct af_vsockmon_hdr *hdr;
94 struct sk_buff *skb;
95 size_t payload_len;
96 void *payload_buf;
97
98 /* A packet could be split to fit the RX buffer, so we can retrieve
99 * the payload length from the header and the buffer pointer taking
100 * care of the offset in the original packet.
101 */
102 payload_len = le32_to_cpu(pkt->hdr.len);
103 payload_buf = pkt->buf + pkt->off;
104
105 skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
106 GFP_ATOMIC);
107 if (!skb)
108 return NULL;
109
110 hdr = skb_put(skb, sizeof(*hdr));
111
112 /* pkt->hdr is little-endian so no need to byteswap here */
113 hdr->src_cid = pkt->hdr.src_cid;
114 hdr->src_port = pkt->hdr.src_port;
115 hdr->dst_cid = pkt->hdr.dst_cid;
116 hdr->dst_port = pkt->hdr.dst_port;
117
118 hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
119 hdr->len = cpu_to_le16(sizeof(pkt->hdr));
120 memset(hdr->reserved, 0, sizeof(hdr->reserved));
121
122 switch (le16_to_cpu(pkt->hdr.op)) {
123 case VIRTIO_VSOCK_OP_REQUEST:
124 case VIRTIO_VSOCK_OP_RESPONSE:
125 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
126 break;
127 case VIRTIO_VSOCK_OP_RST:
128 case VIRTIO_VSOCK_OP_SHUTDOWN:
129 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
130 break;
131 case VIRTIO_VSOCK_OP_RW:
132 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
133 break;
134 case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
135 case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
136 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
137 break;
138 default:
139 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
140 break;
141 }
142
143 skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
144
145 if (payload_len) {
146 skb_put_data(skb, payload_buf, payload_len);
147 }
148
149 return skb;
150 }
151
virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt * pkt)152 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
153 {
154 vsock_deliver_tap(virtio_transport_build_skb, pkt);
155 }
156 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
157
virtio_transport_send_pkt_info(struct vsock_sock * vsk,struct virtio_vsock_pkt_info * info)158 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
159 struct virtio_vsock_pkt_info *info)
160 {
161 u32 src_cid, src_port, dst_cid, dst_port;
162 struct virtio_vsock_sock *vvs;
163 struct virtio_vsock_pkt *pkt;
164 u32 pkt_len = info->pkt_len;
165
166 src_cid = vm_sockets_get_local_cid();
167 src_port = vsk->local_addr.svm_port;
168 if (!info->remote_cid) {
169 dst_cid = vsk->remote_addr.svm_cid;
170 dst_port = vsk->remote_addr.svm_port;
171 } else {
172 dst_cid = info->remote_cid;
173 dst_port = info->remote_port;
174 }
175
176 vvs = vsk->trans;
177
178 /* we can send less than pkt_len bytes */
179 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
180 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
181
182 /* virtio_transport_get_credit might return less than pkt_len credit */
183 pkt_len = virtio_transport_get_credit(vvs, pkt_len);
184
185 /* Do not send zero length OP_RW pkt */
186 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
187 return pkt_len;
188
189 pkt = virtio_transport_alloc_pkt(info, pkt_len,
190 src_cid, src_port,
191 dst_cid, dst_port);
192 if (!pkt) {
193 virtio_transport_put_credit(vvs, pkt_len);
194 return -ENOMEM;
195 }
196
197 virtio_transport_inc_tx_pkt(vvs, pkt);
198
199 return virtio_transport_get_ops()->send_pkt(pkt);
200 }
201
virtio_transport_inc_rx_pkt(struct virtio_vsock_sock * vvs,struct virtio_vsock_pkt * pkt)202 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
203 struct virtio_vsock_pkt *pkt)
204 {
205 vvs->rx_bytes += pkt->len;
206 }
207
virtio_transport_dec_rx_pkt(struct virtio_vsock_sock * vvs,struct virtio_vsock_pkt * pkt)208 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
209 struct virtio_vsock_pkt *pkt)
210 {
211 vvs->rx_bytes -= pkt->len;
212 vvs->fwd_cnt += pkt->len;
213 }
214
virtio_transport_inc_tx_pkt(struct virtio_vsock_sock * vvs,struct virtio_vsock_pkt * pkt)215 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
216 {
217 spin_lock_bh(&vvs->tx_lock);
218 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
219 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
220 spin_unlock_bh(&vvs->tx_lock);
221 }
222 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
223
/*
 * Reserve up to @credit bytes of the receive space last advertised by the
 * peer.
 *
 * Returns the number of bytes actually reserved (possibly 0), and advances
 * tx_cnt by that amount. A caller that ends up not sending gives the bytes
 * back with virtio_transport_put_credit().
 */
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 in_flight;
	u32 ret = 0;

	spin_lock_bh(&vvs->tx_lock);

	/* tx_cnt and peer_fwd_cnt only ever grow, so their u32 difference is
	 * the number of bytes sent but not yet consumed by the peer.
	 */
	in_flight = vvs->tx_cnt - vvs->peer_fwd_cnt;

	/* The peer may have shrunk buf_alloc below what is already in flight.
	 * In that case there is no credit. The plain u32 subtraction would
	 * wrap to a huge value and over-commit the peer.
	 */
	if (vvs->peer_buf_alloc > in_flight) {
		ret = vvs->peer_buf_alloc - in_flight;
		if (ret > credit)
			ret = credit;
	}

	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
238
/* Return @credit bytes previously reserved with virtio_transport_get_credit(). */
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spinlock_t *lock = &vvs->tx_lock;

	spin_lock_bh(lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
246
virtio_transport_send_credit_update(struct vsock_sock * vsk,int type,struct virtio_vsock_hdr * hdr)247 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
248 int type,
249 struct virtio_vsock_hdr *hdr)
250 {
251 struct virtio_vsock_pkt_info info = {
252 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
253 .type = type,
254 .vsk = vsk,
255 };
256
257 return virtio_transport_send_pkt_info(vsk, &info);
258 }
259
260 static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock * vsk,struct msghdr * msg,size_t len)261 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
262 struct msghdr *msg,
263 size_t len)
264 {
265 struct virtio_vsock_sock *vvs = vsk->trans;
266 struct virtio_vsock_pkt *pkt;
267 size_t bytes, total = 0;
268 int err = -EFAULT;
269
270 spin_lock_bh(&vvs->rx_lock);
271 while (total < len && !list_empty(&vvs->rx_queue)) {
272 pkt = list_first_entry(&vvs->rx_queue,
273 struct virtio_vsock_pkt, list);
274
275 bytes = len - total;
276 if (bytes > pkt->len - pkt->off)
277 bytes = pkt->len - pkt->off;
278
279 /* sk_lock is held by caller so no one else can dequeue.
280 * Unlock rx_lock since memcpy_to_msg() may sleep.
281 */
282 spin_unlock_bh(&vvs->rx_lock);
283
284 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
285 if (err)
286 goto out;
287
288 spin_lock_bh(&vvs->rx_lock);
289
290 total += bytes;
291 pkt->off += bytes;
292 if (pkt->off == pkt->len) {
293 virtio_transport_dec_rx_pkt(vvs, pkt);
294 list_del(&pkt->list);
295 virtio_transport_free_pkt(pkt);
296 }
297 }
298 spin_unlock_bh(&vvs->rx_lock);
299
300 /* Send a credit pkt to peer */
301 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
302 NULL);
303
304 return total;
305
306 out:
307 if (total)
308 err = total;
309 return err;
310 }
311
312 ssize_t
virtio_transport_stream_dequeue(struct vsock_sock * vsk,struct msghdr * msg,size_t len,int flags)313 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
314 struct msghdr *msg,
315 size_t len, int flags)
316 {
317 if (flags & MSG_PEEK)
318 return -EOPNOTSUPP;
319
320 return virtio_transport_stream_do_dequeue(vsk, msg, len);
321 }
322 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
323
324 int
virtio_transport_dgram_dequeue(struct vsock_sock * vsk,struct msghdr * msg,size_t len,int flags)325 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
326 struct msghdr *msg,
327 size_t len, int flags)
328 {
329 return -EOPNOTSUPP;
330 }
331 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
332
virtio_transport_stream_has_data(struct vsock_sock * vsk)333 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
334 {
335 struct virtio_vsock_sock *vvs = vsk->trans;
336 s64 bytes;
337
338 spin_lock_bh(&vvs->rx_lock);
339 bytes = vvs->rx_bytes;
340 spin_unlock_bh(&vvs->rx_lock);
341
342 return bytes;
343 }
344 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
345
virtio_transport_has_space(struct vsock_sock * vsk)346 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
347 {
348 struct virtio_vsock_sock *vvs = vsk->trans;
349 s64 bytes;
350
351 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
352 if (bytes < 0)
353 bytes = 0;
354
355 return bytes;
356 }
357
virtio_transport_stream_has_space(struct vsock_sock * vsk)358 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
359 {
360 struct virtio_vsock_sock *vvs = vsk->trans;
361 s64 bytes;
362
363 spin_lock_bh(&vvs->tx_lock);
364 bytes = virtio_transport_has_space(vsk);
365 spin_unlock_bh(&vvs->tx_lock);
366
367 return bytes;
368 }
369 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
370
virtio_transport_do_socket_init(struct vsock_sock * vsk,struct vsock_sock * psk)371 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
372 struct vsock_sock *psk)
373 {
374 struct virtio_vsock_sock *vvs;
375
376 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
377 if (!vvs)
378 return -ENOMEM;
379
380 vsk->trans = vvs;
381 vvs->vsk = vsk;
382 if (psk) {
383 struct virtio_vsock_sock *ptrans = psk->trans;
384
385 vvs->buf_size = ptrans->buf_size;
386 vvs->buf_size_min = ptrans->buf_size_min;
387 vvs->buf_size_max = ptrans->buf_size_max;
388 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
389 } else {
390 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
391 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
392 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
393 }
394
395 vvs->buf_alloc = vvs->buf_size;
396
397 spin_lock_init(&vvs->rx_lock);
398 spin_lock_init(&vvs->tx_lock);
399 INIT_LIST_HEAD(&vvs->rx_queue);
400
401 return 0;
402 }
403 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
404
virtio_transport_get_buffer_size(struct vsock_sock * vsk)405 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
406 {
407 struct virtio_vsock_sock *vvs = vsk->trans;
408
409 return vvs->buf_size;
410 }
411 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
412
virtio_transport_get_min_buffer_size(struct vsock_sock * vsk)413 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
414 {
415 struct virtio_vsock_sock *vvs = vsk->trans;
416
417 return vvs->buf_size_min;
418 }
419 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
420
virtio_transport_get_max_buffer_size(struct vsock_sock * vsk)421 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
422 {
423 struct virtio_vsock_sock *vvs = vsk->trans;
424
425 return vvs->buf_size_max;
426 }
427 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
428
/*
 * Set the receive buffer size (capped at VIRTIO_VSOCK_MAX_BUF_SIZE). The
 * [min, max] bounds are widened to include it, and buf_alloc follows.
 */
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_size = val;
	vvs->buf_alloc = val;

	if (val < vvs->buf_size_min)
		vvs->buf_size_min = val;
	if (val > vvs->buf_size_max)
		vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
443
/* Set the minimum buffer size (capped); raise buf_size if it is now below the minimum. */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_size_min = val;
	if (vvs->buf_size < val)
		vvs->buf_size = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
455
/* Set the maximum buffer size (capped); lower buf_size if it is now above the maximum. */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;

	vvs->buf_size_max = val;
	if (vvs->buf_size > val)
		vvs->buf_size = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
467
468 int
virtio_transport_notify_poll_in(struct vsock_sock * vsk,size_t target,bool * data_ready_now)469 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
470 size_t target,
471 bool *data_ready_now)
472 {
473 if (vsock_stream_has_data(vsk))
474 *data_ready_now = true;
475 else
476 *data_ready_now = false;
477
478 return 0;
479 }
480 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
481
482 int
virtio_transport_notify_poll_out(struct vsock_sock * vsk,size_t target,bool * space_avail_now)483 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
484 size_t target,
485 bool *space_avail_now)
486 {
487 s64 free_space;
488
489 free_space = vsock_stream_has_space(vsk);
490 if (free_space > 0)
491 *space_avail_now = true;
492 else if (free_space == 0)
493 *space_avail_now = false;
494
495 return 0;
496 }
497 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
498
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk, size_t target,
				      struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
505
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
					   struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
512
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
					     struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
519
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
					      ssize_t copied, bool data_read,
					      struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
527
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
				      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
534
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
					   struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
541
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
					     struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
548
/* No-op notification hook; always returns 0. */
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
					      struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
555
virtio_transport_stream_rcvhiwat(struct vsock_sock * vsk)556 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
557 {
558 struct virtio_vsock_sock *vvs = vsk->trans;
559
560 return vvs->buf_size;
561 }
562 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
563
virtio_transport_stream_is_active(struct vsock_sock * vsk)564 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
565 {
566 return true;
567 }
568 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
569
virtio_transport_stream_allow(u32 cid,u32 port)570 bool virtio_transport_stream_allow(u32 cid, u32 port)
571 {
572 return true;
573 }
574 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
575
virtio_transport_dgram_bind(struct vsock_sock * vsk,struct sockaddr_vm * addr)576 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
577 struct sockaddr_vm *addr)
578 {
579 return -EOPNOTSUPP;
580 }
581 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
582
virtio_transport_dgram_allow(u32 cid,u32 port)583 bool virtio_transport_dgram_allow(u32 cid, u32 port)
584 {
585 return false;
586 }
587 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
588
virtio_transport_connect(struct vsock_sock * vsk)589 int virtio_transport_connect(struct vsock_sock *vsk)
590 {
591 struct virtio_vsock_pkt_info info = {
592 .op = VIRTIO_VSOCK_OP_REQUEST,
593 .type = VIRTIO_VSOCK_TYPE_STREAM,
594 .vsk = vsk,
595 };
596
597 return virtio_transport_send_pkt_info(vsk, &info);
598 }
599 EXPORT_SYMBOL_GPL(virtio_transport_connect);
600
virtio_transport_shutdown(struct vsock_sock * vsk,int mode)601 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
602 {
603 struct virtio_vsock_pkt_info info = {
604 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
605 .type = VIRTIO_VSOCK_TYPE_STREAM,
606 .flags = (mode & RCV_SHUTDOWN ?
607 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
608 (mode & SEND_SHUTDOWN ?
609 VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
610 .vsk = vsk,
611 };
612
613 return virtio_transport_send_pkt_info(vsk, &info);
614 }
615 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
616
617 int
virtio_transport_dgram_enqueue(struct vsock_sock * vsk,struct sockaddr_vm * remote_addr,struct msghdr * msg,size_t dgram_len)618 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
619 struct sockaddr_vm *remote_addr,
620 struct msghdr *msg,
621 size_t dgram_len)
622 {
623 return -EOPNOTSUPP;
624 }
625 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
626
627 ssize_t
virtio_transport_stream_enqueue(struct vsock_sock * vsk,struct msghdr * msg,size_t len)628 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
629 struct msghdr *msg,
630 size_t len)
631 {
632 struct virtio_vsock_pkt_info info = {
633 .op = VIRTIO_VSOCK_OP_RW,
634 .type = VIRTIO_VSOCK_TYPE_STREAM,
635 .msg = msg,
636 .pkt_len = len,
637 .vsk = vsk,
638 };
639
640 return virtio_transport_send_pkt_info(vsk, &info);
641 }
642 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
643
virtio_transport_destruct(struct vsock_sock * vsk)644 void virtio_transport_destruct(struct vsock_sock *vsk)
645 {
646 struct virtio_vsock_sock *vvs = vsk->trans;
647
648 kfree(vvs);
649 }
650 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
651
virtio_transport_reset(struct vsock_sock * vsk,struct virtio_vsock_pkt * pkt)652 static int virtio_transport_reset(struct vsock_sock *vsk,
653 struct virtio_vsock_pkt *pkt)
654 {
655 struct virtio_vsock_pkt_info info = {
656 .op = VIRTIO_VSOCK_OP_RST,
657 .type = VIRTIO_VSOCK_TYPE_STREAM,
658 .reply = !!pkt,
659 .vsk = vsk,
660 };
661
662 /* Send RST only if the original pkt is not a RST pkt */
663 if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
664 return 0;
665
666 return virtio_transport_send_pkt_info(vsk, &info);
667 }
668
669 /* Normally packets are associated with a socket. There may be no socket if an
670 * attempt was made to connect to a socket that does not exist.
671 */
virtio_transport_reset_no_sock(struct virtio_vsock_pkt * pkt)672 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
673 {
674 const struct virtio_transport *t;
675 struct virtio_vsock_pkt *reply;
676 struct virtio_vsock_pkt_info info = {
677 .op = VIRTIO_VSOCK_OP_RST,
678 .type = le16_to_cpu(pkt->hdr.type),
679 .reply = true,
680 };
681
682 /* Send RST only if the original pkt is not a RST pkt */
683 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
684 return 0;
685
686 reply = virtio_transport_alloc_pkt(&info, 0,
687 le64_to_cpu(pkt->hdr.dst_cid),
688 le32_to_cpu(pkt->hdr.dst_port),
689 le64_to_cpu(pkt->hdr.src_cid),
690 le32_to_cpu(pkt->hdr.src_port));
691 if (!reply)
692 return -ENOMEM;
693
694 t = virtio_transport_get_ops();
695 if (!t) {
696 virtio_transport_free_pkt(reply);
697 return -ENOTCONN;
698 }
699
700 return t->send_pkt(reply);
701 }
702
virtio_transport_wait_close(struct sock * sk,long timeout)703 static void virtio_transport_wait_close(struct sock *sk, long timeout)
704 {
705 if (timeout) {
706 DEFINE_WAIT_FUNC(wait, woken_wake_function);
707
708 add_wait_queue(sk_sleep(sk), &wait);
709
710 do {
711 if (sk_wait_event(sk, &timeout,
712 sock_flag(sk, SOCK_DONE), &wait))
713 break;
714 } while (!signal_pending(current) && timeout);
715
716 remove_wait_queue(sk_sleep(sk), &wait);
717 }
718 }
719
/*
 * Mark the connection as finished by the peer and wake up waiters. If a
 * close timeout was scheduled, also remove the socket and drop the reference
 * taken for it. With @cancel_timeout set, this only happens if the pending
 * work was actually cancelled.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = TCP_CLOSING;
	sk->sk_state_change(sk);

	if (!vsk->close_work_scheduled)
		return;

	if (cancel_timeout && !cancel_delayed_work(&vsk->close_work))
		return;

	vsk->close_work_scheduled = false;

	vsock_remove_sock(vsk);

	/* Release refcnt obtained when we scheduled the timeout */
	sock_put(sk);
}
741
virtio_transport_close_timeout(struct work_struct * work)742 static void virtio_transport_close_timeout(struct work_struct *work)
743 {
744 struct vsock_sock *vsk =
745 container_of(work, struct vsock_sock, close_work.work);
746 struct sock *sk = sk_vsock(vsk);
747
748 sock_hold(sk);
749 lock_sock(sk);
750
751 if (!sock_flag(sk, SOCK_DONE)) {
752 (void)virtio_transport_reset(vsk, NULL);
753
754 virtio_transport_do_close(vsk, false);
755 }
756
757 vsk->close_work_scheduled = false;
758
759 release_sock(sk);
760 sock_put(sk);
761 }
762
763 /* User context, vsk->sk is locked */
virtio_transport_close(struct vsock_sock * vsk)764 static bool virtio_transport_close(struct vsock_sock *vsk)
765 {
766 struct sock *sk = &vsk->sk;
767
768 if (!(sk->sk_state == TCP_ESTABLISHED ||
769 sk->sk_state == TCP_CLOSING))
770 return true;
771
772 /* Already received SHUTDOWN from peer, reply with RST */
773 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
774 (void)virtio_transport_reset(vsk, NULL);
775 return true;
776 }
777
778 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
779 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
780
781 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
782 virtio_transport_wait_close(sk, sk->sk_lingertime);
783
784 if (sock_flag(sk, SOCK_DONE)) {
785 return true;
786 }
787
788 sock_hold(sk);
789 INIT_DELAYED_WORK(&vsk->close_work,
790 virtio_transport_close_timeout);
791 vsk->close_work_scheduled = true;
792 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
793 return false;
794 }
795
virtio_transport_release(struct vsock_sock * vsk)796 void virtio_transport_release(struct vsock_sock *vsk)
797 {
798 struct virtio_vsock_sock *vvs = vsk->trans;
799 struct virtio_vsock_pkt *pkt, *tmp;
800 struct sock *sk = &vsk->sk;
801 bool remove_sock = true;
802
803 lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
804 if (sk->sk_type == SOCK_STREAM)
805 remove_sock = virtio_transport_close(vsk);
806
807 list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
808 list_del(&pkt->list);
809 virtio_transport_free_pkt(pkt);
810 }
811 release_sock(sk);
812
813 if (remove_sock)
814 vsock_remove_sock(vsk);
815 }
816 EXPORT_SYMBOL_GPL(virtio_transport_release);
817
818 static int
virtio_transport_recv_connecting(struct sock * sk,struct virtio_vsock_pkt * pkt)819 virtio_transport_recv_connecting(struct sock *sk,
820 struct virtio_vsock_pkt *pkt)
821 {
822 struct vsock_sock *vsk = vsock_sk(sk);
823 int err;
824 int skerr;
825
826 switch (le16_to_cpu(pkt->hdr.op)) {
827 case VIRTIO_VSOCK_OP_RESPONSE:
828 sk->sk_state = TCP_ESTABLISHED;
829 sk->sk_socket->state = SS_CONNECTED;
830 vsock_insert_connected(vsk);
831 sk->sk_state_change(sk);
832 break;
833 case VIRTIO_VSOCK_OP_INVALID:
834 break;
835 case VIRTIO_VSOCK_OP_RST:
836 skerr = ECONNRESET;
837 err = 0;
838 goto destroy;
839 default:
840 skerr = EPROTO;
841 err = -EINVAL;
842 goto destroy;
843 }
844 return 0;
845
846 destroy:
847 virtio_transport_reset(vsk, pkt);
848 sk->sk_state = TCP_CLOSE;
849 sk->sk_err = skerr;
850 sk->sk_error_report(sk);
851 return err;
852 }
853
854 static int
virtio_transport_recv_connected(struct sock * sk,struct virtio_vsock_pkt * pkt)855 virtio_transport_recv_connected(struct sock *sk,
856 struct virtio_vsock_pkt *pkt)
857 {
858 struct vsock_sock *vsk = vsock_sk(sk);
859 struct virtio_vsock_sock *vvs = vsk->trans;
860 int err = 0;
861
862 switch (le16_to_cpu(pkt->hdr.op)) {
863 case VIRTIO_VSOCK_OP_RW:
864 pkt->len = le32_to_cpu(pkt->hdr.len);
865 pkt->off = 0;
866
867 spin_lock_bh(&vvs->rx_lock);
868 virtio_transport_inc_rx_pkt(vvs, pkt);
869 list_add_tail(&pkt->list, &vvs->rx_queue);
870 spin_unlock_bh(&vvs->rx_lock);
871
872 sk->sk_data_ready(sk);
873 return err;
874 case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
875 sk->sk_write_space(sk);
876 break;
877 case VIRTIO_VSOCK_OP_SHUTDOWN:
878 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
879 vsk->peer_shutdown |= RCV_SHUTDOWN;
880 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
881 vsk->peer_shutdown |= SEND_SHUTDOWN;
882 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
883 vsock_stream_has_data(vsk) <= 0)
884 sk->sk_state = TCP_CLOSING;
885 if (le32_to_cpu(pkt->hdr.flags))
886 sk->sk_state_change(sk);
887 break;
888 case VIRTIO_VSOCK_OP_RST:
889 virtio_transport_do_close(vsk, true);
890 break;
891 default:
892 err = -EINVAL;
893 break;
894 }
895
896 virtio_transport_free_pkt(pkt);
897 return err;
898 }
899
900 static void
virtio_transport_recv_disconnecting(struct sock * sk,struct virtio_vsock_pkt * pkt)901 virtio_transport_recv_disconnecting(struct sock *sk,
902 struct virtio_vsock_pkt *pkt)
903 {
904 struct vsock_sock *vsk = vsock_sk(sk);
905
906 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
907 virtio_transport_do_close(vsk, true);
908 }
909
910 static int
virtio_transport_send_response(struct vsock_sock * vsk,struct virtio_vsock_pkt * pkt)911 virtio_transport_send_response(struct vsock_sock *vsk,
912 struct virtio_vsock_pkt *pkt)
913 {
914 struct virtio_vsock_pkt_info info = {
915 .op = VIRTIO_VSOCK_OP_RESPONSE,
916 .type = VIRTIO_VSOCK_TYPE_STREAM,
917 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
918 .remote_port = le32_to_cpu(pkt->hdr.src_port),
919 .reply = true,
920 .vsk = vsk,
921 };
922
923 return virtio_transport_send_pkt_info(vsk, &info);
924 }
925
926 /* Handle server socket */
927 static int
virtio_transport_recv_listen(struct sock * sk,struct virtio_vsock_pkt * pkt)928 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
929 {
930 struct vsock_sock *vsk = vsock_sk(sk);
931 struct vsock_sock *vchild;
932 struct sock *child;
933
934 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
935 virtio_transport_reset(vsk, pkt);
936 return -EINVAL;
937 }
938
939 if (sk_acceptq_is_full(sk)) {
940 virtio_transport_reset(vsk, pkt);
941 return -ENOMEM;
942 }
943
944 child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
945 sk->sk_type, 0);
946 if (!child) {
947 virtio_transport_reset(vsk, pkt);
948 return -ENOMEM;
949 }
950
951 sk->sk_ack_backlog++;
952
953 lock_sock_nested(child, SINGLE_DEPTH_NESTING);
954
955 child->sk_state = TCP_ESTABLISHED;
956
957 vchild = vsock_sk(child);
958 vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
959 le32_to_cpu(pkt->hdr.dst_port));
960 vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
961 le32_to_cpu(pkt->hdr.src_port));
962
963 vsock_insert_connected(vchild);
964 vsock_enqueue_accept(sk, child);
965 virtio_transport_send_response(vchild, pkt);
966
967 release_sock(child);
968
969 sk->sk_data_ready(sk);
970 return 0;
971 }
972
virtio_transport_space_update(struct sock * sk,struct virtio_vsock_pkt * pkt)973 static bool virtio_transport_space_update(struct sock *sk,
974 struct virtio_vsock_pkt *pkt)
975 {
976 struct vsock_sock *vsk = vsock_sk(sk);
977 struct virtio_vsock_sock *vvs = vsk->trans;
978 bool space_available;
979
980 /* buf_alloc and fwd_cnt is always included in the hdr */
981 spin_lock_bh(&vvs->tx_lock);
982 vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
983 vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
984 space_available = virtio_transport_has_space(vsk);
985 spin_unlock_bh(&vvs->tx_lock);
986 return space_available;
987 }
988
989 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
990 * lock.
991 */
virtio_transport_recv_pkt(struct virtio_vsock_pkt * pkt)992 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
993 {
994 struct sockaddr_vm src, dst;
995 struct vsock_sock *vsk;
996 struct sock *sk;
997 bool space_available;
998
999 vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
1000 le32_to_cpu(pkt->hdr.src_port));
1001 vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
1002 le32_to_cpu(pkt->hdr.dst_port));
1003
1004 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1005 dst.svm_cid, dst.svm_port,
1006 le32_to_cpu(pkt->hdr.len),
1007 le16_to_cpu(pkt->hdr.type),
1008 le16_to_cpu(pkt->hdr.op),
1009 le32_to_cpu(pkt->hdr.flags),
1010 le32_to_cpu(pkt->hdr.buf_alloc),
1011 le32_to_cpu(pkt->hdr.fwd_cnt));
1012
1013 if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
1014 (void)virtio_transport_reset_no_sock(pkt);
1015 goto free_pkt;
1016 }
1017
1018 /* The socket must be in connected or bound table
1019 * otherwise send reset back
1020 */
1021 sk = vsock_find_connected_socket(&src, &dst);
1022 if (!sk) {
1023 sk = vsock_find_bound_socket(&dst);
1024 if (!sk) {
1025 (void)virtio_transport_reset_no_sock(pkt);
1026 goto free_pkt;
1027 }
1028 }
1029
1030 vsk = vsock_sk(sk);
1031
1032 space_available = virtio_transport_space_update(sk, pkt);
1033
1034 lock_sock(sk);
1035
1036 /* Update CID in case it has changed after a transport reset event */
1037 vsk->local_addr.svm_cid = dst.svm_cid;
1038
1039 if (space_available)
1040 sk->sk_write_space(sk);
1041
1042 switch (sk->sk_state) {
1043 case TCP_LISTEN:
1044 virtio_transport_recv_listen(sk, pkt);
1045 virtio_transport_free_pkt(pkt);
1046 break;
1047 case TCP_SYN_SENT:
1048 virtio_transport_recv_connecting(sk, pkt);
1049 virtio_transport_free_pkt(pkt);
1050 break;
1051 case TCP_ESTABLISHED:
1052 virtio_transport_recv_connected(sk, pkt);
1053 break;
1054 case TCP_CLOSING:
1055 virtio_transport_recv_disconnecting(sk, pkt);
1056 virtio_transport_free_pkt(pkt);
1057 break;
1058 default:
1059 virtio_transport_free_pkt(pkt);
1060 break;
1061 }
1062 release_sock(sk);
1063
1064 /* Release refcnt obtained when we fetched this socket out of the
1065 * bound or connected list.
1066 */
1067 sock_put(sk);
1068 return;
1069
1070 free_pkt:
1071 virtio_transport_free_pkt(pkt);
1072 }
1073 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1074
virtio_transport_free_pkt(struct virtio_vsock_pkt * pkt)1075 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1076 {
1077 kfree(pkt->buf);
1078 kfree(pkt);
1079 }
1080 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1081
1082 MODULE_LICENSE("GPL v2");
1083 MODULE_AUTHOR("Asias He");
1084 MODULE_DESCRIPTION("common code for virtio vsock");
1085