// SPDX-License-Identifier: GPL-2.0-only
/*
 * Hyper-V transport for vsock
 *
 * Hyper-V Sockets supplies a byte-stream based communication mechanism
 * between the host and the VM. This driver implements the necessary
 * support in the VM by introducing the new vsock transport.
 *
 * Copyright (c) 2017, Microsoft Corporation.
 */
#include <linux/module.h>
#include <linux/vmalloc.h>
#include <linux/hyperv.h>
#include <net/sock.h>
#include <net/af_vsock.h>

/* Older Windows hosts (VMBus version 'VERSION_WIN10' or before) impose a
 * stricter requirement on the hv_sock ring buffer size: six 4K pages. Newer
 * hosts don't have this limitation, but keep the defaults the same for
 * compatibility.
 */
#define PAGE_SIZE_4K		4096
#define RINGBUFFER_HVS_RCV_SIZE (PAGE_SIZE_4K * 6)
#define RINGBUFFER_HVS_SND_SIZE (PAGE_SIZE_4K * 6)
#define RINGBUFFER_HVS_MAX_SIZE (PAGE_SIZE_4K * 64)

/* The MTU is 16KB per the host side's design */
#define HVS_MTU_SIZE		(1024 * 16)

/* How long to wait for graceful shutdown of a connection */
#define HVS_CLOSE_TIMEOUT	(8 * HZ)

struct vmpipe_proto_header {
        u32 pkt_type;
        u32 data_size;
};

/* For recv, we use the VMBus in-place packet iterator APIs to directly copy
 * data from the ringbuffer into the userspace buffer.
 */
struct hvs_recv_buf {
        /* The header before the payload data */
        struct vmpipe_proto_header hdr;

        /* The payload */
        u8 data[HVS_MTU_SIZE];
};
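
/* Note: on receive there is no intermediate copy into a private buffer of
 * this type; hvs_stream_dequeue() overlays struct hvs_recv_buf directly on
 * the in-place ring buffer payload that follows the packet's
 * vmpacket_descriptor (see the 'hvs->recv_desc + 1' casts below).
 */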

/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use
 * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the
 * guest and the host processing, since one VMBus packet is the smallest unit
 * the host processes.
 *
 * Note: the buffer can be eliminated in the future when we add new VMBus
 * ringbuffer APIs that allow us to directly copy data from a userspace buffer
 * to the VMBus ringbuffer.
 */
#define HVS_SEND_BUF_SIZE (PAGE_SIZE_4K - sizeof(struct vmpipe_proto_header))

struct hvs_send_buf {
        /* The header before the payload data */
        struct vmpipe_proto_header hdr;

        /* The payload */
        u8 data[HVS_SEND_BUF_SIZE];
};

#define HVS_HEADER_LEN	(sizeof(struct vmpacket_descriptor) + \
                         sizeof(struct vmpipe_proto_header))

/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and
 * __hv_pkt_iter_next().
 */
#define VMBUS_PKT_TRAILER_SIZE	(sizeof(u64))

#define HVS_PKT_LEN(payload_len)	(HVS_HEADER_LEN + \
                                         ALIGN((payload_len), 8) + \
                                         VMBUS_PKT_TRAILER_SIZE)
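
/* For orientation, a worked example (assuming the 16-byte
 * struct vmpacket_descriptor used by VMBus plus the 8-byte
 * vmpipe_proto_header above, i.e. HVS_HEADER_LEN == 24):
 *
 *	HVS_PKT_LEN(0) = 24 + ALIGN(0, 8) + 8 = 32 bytes (a FIN packet)
 *	HVS_PKT_LEN(1) = 24 + ALIGN(1, 8) + 8 = 40 bytes
 *
 * hvs_channel_writable_bytes() below keeps HVS_PKT_LEN(1) + HVS_PKT_LEN(0)
 * = 72 bytes in reserve, so the ring buffer never becomes 100% full and a
 * zero-length FIN packet can always be sent.
 */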

union hvs_service_id {
        guid_t srv_id;

        struct {
                unsigned int svm_port;
                unsigned char b[sizeof(guid_t) - sizeof(unsigned int)];
        };
};
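
/* Example of the overlay above: writing vsock port 0x00000808 to svm_port
 * fills the first 32-bit field of the GUID, which (with guid_t's
 * little-endian first field) prints as 00000808-...; combined with
 * srv_id_template below this yields the Service GUID
 * 00000808-facb-11e6-bd58-64006a7986d3, matching the <port>-facb-... format
 * documented later in this file. The port value 0x808 is just an
 * illustration.
 */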

/* Per-socket state (accessed via vsk->trans) */
struct hvsock {
        struct vsock_sock *vsk;

        guid_t vm_srv_id;
        guid_t host_srv_id;

        struct vmbus_channel *chan;
        struct vmpacket_descriptor *recv_desc;

        /* The length of the payload not delivered to userland yet */
        u32 recv_data_len;
        /* The offset of the payload */
        u32 recv_data_off;

        /* Have we sent the zero-length packet (FIN)? */
        bool fin_sent;
};

/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is
 * <cid, port> (see struct sockaddr_vm). Note: cid is not really used here:
 * when we write apps to connect to the host, we can only use VMADDR_CID_ANY
 * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we
 * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY
 * as the local cid.
 *
 * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV:
 * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-
 * guide/make-integration-service, and the endpoint is <VmID, ServiceId> with
 * the below sockaddr:
 *
 * struct SOCKADDR_HV
 * {
 *	ADDRESS_FAMILY Family;
 *	USHORT Reserved;
 *	GUID VmId;
 *	GUID ServiceId;
 * };
 * Note: the VmID is not used by the Linux VM, and it isn't actually
 * transmitted via VMBus, because the host and the VM can trivially identify
 * each other here. Though the VmID is useful on the host, especially in the
 * case of Windows containers, the Linux VM doesn't need it at all.
 *
 * To make use of the AF_VSOCK infrastructure in the Linux VM, we have to
 * limit the available GUID space of SOCKADDR_HV so that we can create a
 * mapping between an AF_VSOCK port and a SOCKADDR_HV Service GUID. The rule
 * for writing Hyper-V Sockets apps on the host and in the Linux VM is:
 *
 ****************************************************************************
 * The only valid Service GUIDs, from the perspectives of both the host and *
 * the Linux VM, that can be connected by the other end, must conform to    *
 * this format: <port>-facb-11e6-bd58-64006a7986d3.                         *
 ****************************************************************************
 *
 * When we write apps on the host to connect(), the GUID ServiceID is used.
 * When we write apps in the Linux VM to connect(), we only need to specify
 * the port; the driver will form the GUID and use it to request the
 * connection from the host.
 */
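
/* For illustration only (not part of this driver): a minimal guest-side
 * userspace sketch, assuming a host service listening on the hypothetical
 * port 0x808, i.e. on the Service GUID
 * 00000808-facb-11e6-bd58-64006a7986d3 (error handling omitted):
 *
 *	#include <sys/socket.h>
 *	#include <linux/vm_sockets.h>
 *
 *	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
 *	struct sockaddr_vm sa = {
 *		.svm_family = AF_VSOCK,
 *		.svm_cid = VMADDR_CID_HOST,
 *		.svm_port = 0x808,
 *	};
 *	connect(fd, (struct sockaddr *)&sa, sizeof(sa));
 *
 * hvs_connect() below then forms both GUIDs from srv_id_template and asks
 * the host for the connection via vmbus_send_tl_connect_request().
 */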

/* 00000000-facb-11e6-bd58-64006a7986d3 */
static const guid_t srv_id_template =
        GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
                  0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);

static bool is_valid_srv_id(const guid_t *id)
{
        return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
}

static unsigned int get_port_by_srv_id(const guid_t *svr_id)
{
        return *((unsigned int *)svr_id);
}

static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id)
{
        unsigned int port = get_port_by_srv_id(svr_id);

        vsock_addr_init(addr, VMADDR_CID_ANY, port);
}

static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan)
{
        set_channel_pending_send_size(chan,
                                      HVS_PKT_LEN(HVS_SEND_BUF_SIZE));

        virt_mb();
}
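
/* A note on the virt_mb() above: as we read it, the full virtual memory
 * barrier is meant to make the updated pending-send size visible to the
 * host before any subsequent check of the ring buffer state, since the
 * host only signals when the pending-size boundary is crossed (see the
 * comment above hvs_set_channel_pending_send_size() in
 * hvs_open_connection()).
 */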

static bool hvs_channel_readable(struct vmbus_channel *chan)
{
        u32 readable = hv_get_bytes_to_read(&chan->inbound);

        /* 0-size payload means FIN */
        return readable >= HVS_PKT_LEN(0);
}

static int hvs_channel_readable_payload(struct vmbus_channel *chan)
{
        u32 readable = hv_get_bytes_to_read(&chan->inbound);

        if (readable > HVS_PKT_LEN(0)) {
                /* At least we have 1 byte to read. We don't need to return
                 * the exact readable bytes: see vsock_stream_recvmsg() ->
                 * vsock_stream_has_data().
                 */
                return 1;
        }

        if (readable == HVS_PKT_LEN(0)) {
                /* 0-size payload means FIN */
                return 0;
        }

        /* No payload or FIN */
        return -1;
}

static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan)
{
        u32 writeable = hv_get_bytes_to_write(&chan->outbound);
        size_t ret;

        /* The ringbuffer mustn't be 100% full, and we should reserve a
         * zero-length-payload packet for the FIN: see hv_ringbuffer_write()
         * and hvs_shutdown().
         */
        if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0))
                return 0;

        ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0);

        return round_down(ret, 8);
}

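/* Each send is one VM_PKT_DATA_INBAND VMBus packet: a vmpipe_proto_header
 * (pkt_type 1 marks ordinary data, as far as this driver is concerned)
 * followed by up to HVS_SEND_BUF_SIZE bytes of payload. One packet is the
 * smallest unit the host processes (see the comment above
 * HVS_SEND_BUF_SIZE).
 */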
static int hvs_send_data(struct vmbus_channel *chan,
                         struct hvs_send_buf *send_buf, size_t to_write)
{
        send_buf->hdr.pkt_type = 1;
        send_buf->hdr.data_size = to_write;
        return vmbus_sendpacket(chan, &send_buf->hdr,
                                sizeof(send_buf->hdr) + to_write,
                                0, VM_PKT_DATA_INBAND, 0);
}

static void hvs_channel_cb(void *ctx)
{
        struct sock *sk = (struct sock *)ctx;
        struct vsock_sock *vsk = vsock_sk(sk);
        struct hvsock *hvs = vsk->trans;
        struct vmbus_channel *chan = hvs->chan;

        if (hvs_channel_readable(chan))
                sk->sk_data_ready(sk);

        if (hv_get_bytes_to_write(&chan->outbound) > 0)
                sk->sk_write_space(sk);
}

static void hvs_do_close_lock_held(struct vsock_sock *vsk,
                                   bool cancel_timeout)
{
        struct sock *sk = sk_vsock(vsk);

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = TCP_CLOSING;
        sk->sk_state_change(sk);
        if (vsk->close_work_scheduled &&
            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
                vsk->close_work_scheduled = false;
                vsock_remove_sock(vsk);

                /* Release the reference taken while scheduling the timeout */
                sock_put(sk);
        }
}

static void hvs_close_connection(struct vmbus_channel *chan)
{
        struct sock *sk = get_per_channel_state(chan);

        lock_sock(sk);
        hvs_do_close_lock_held(vsock_sk(sk), true);
        release_sock(sk);

        /* Release the refcnt for the channel that's opened in
         * hvs_open_connection().
         */
        sock_put(sk);
}

static void hvs_open_connection(struct vmbus_channel *chan)
{
        guid_t *if_instance, *if_type;
        unsigned char conn_from_host;

        struct sockaddr_vm addr;
        struct sock *sk, *new = NULL;
        struct vsock_sock *vnew = NULL;
        struct hvsock *hvs = NULL;
        struct hvsock *hvs_new = NULL;
        int rcvbuf;
        int ret;
        int sndbuf;

        if_type = &chan->offermsg.offer.if_type;
        if_instance = &chan->offermsg.offer.if_instance;
        conn_from_host = chan->offermsg.offer.u.pipe.user_def[0];
        if (!is_valid_srv_id(if_type))
                return;

        hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
        sk = vsock_find_bound_socket(&addr);
        if (!sk)
                return;

        lock_sock(sk);
        if ((conn_from_host && sk->sk_state != TCP_LISTEN) ||
            (!conn_from_host && sk->sk_state != TCP_SYN_SENT))
                goto out;

        if (conn_from_host) {
                if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
                        goto out;

                new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
                                     sk->sk_type, 0);
                if (!new)
                        goto out;

                new->sk_state = TCP_SYN_SENT;
                vnew = vsock_sk(new);

                hvs_addr_init(&vnew->local_addr, if_type);

                /* Remote peer is always the host */
                vsock_addr_init(&vnew->remote_addr,
                                VMADDR_CID_HOST, VMADDR_PORT_ANY);
                vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance);
                hvs_new = vnew->trans;
                hvs_new->chan = chan;
        } else {
                hvs = vsock_sk(sk)->trans;
                hvs->chan = chan;
        }

        set_channel_read_mode(chan, HV_CALL_DIRECT);

        /* Use the socket buffer sizes as hints for the VMBUS ring size. For
         * server side sockets, 'sk' is the parent socket and thus, this will
         * allow the child sockets to inherit the size from the parent. Keep
         * the mins at the default value and align to page size as per VMBUS
         * requirements.
         * For the max, the socket core library will limit the socket buffer
         * size that can be set by the user, but, since currently the hv_sock
         * VMBUS ring buffer is a physically contiguous allocation, restrict
         * it further.
         * Older versions of the hv_sock host side code cannot handle bigger
         * VMBUS ring buffer sizes. Use the version number to limit the change
         * to newer versions.
         */
        if (vmbus_proto_version < VERSION_WIN10_V5) {
                sndbuf = RINGBUFFER_HVS_SND_SIZE;
                rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
        } else {
                sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
                sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
                sndbuf = ALIGN(sndbuf, PAGE_SIZE);
                rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
                rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
                rcvbuf = ALIGN(rcvbuf, PAGE_SIZE);
        }

        ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
                         conn_from_host ? new : sk);
        if (ret != 0) {
                if (conn_from_host) {
                        hvs_new->chan = NULL;
                        sock_put(new);
                } else {
                        hvs->chan = NULL;
                }
                goto out;
        }

        set_per_channel_state(chan, conn_from_host ? new : sk);

        /* This reference will be dropped by hvs_close_connection(). */
        sock_hold(conn_from_host ? new : sk);
        vmbus_set_chn_rescind_callback(chan, hvs_close_connection);

        /* Set the pending send size to max packet size to always get
         * notifications from the host when there is enough writable space.
         * The host is optimized to send notifications only when the pending
         * size boundary is crossed, and not always.
         */
        hvs_set_channel_pending_send_size(chan);

        if (conn_from_host) {
                new->sk_state = TCP_ESTABLISHED;
                sk->sk_ack_backlog++;

                hvs_addr_init(&vnew->local_addr, if_type);
                hvs_new->vm_srv_id = *if_type;
                hvs_new->host_srv_id = *if_instance;

                vsock_insert_connected(vnew);

                vsock_enqueue_accept(sk, new);
        } else {
                sk->sk_state = TCP_ESTABLISHED;
                sk->sk_socket->state = SS_CONNECTED;

                vsock_insert_connected(vsock_sk(sk));
        }

        sk->sk_state_change(sk);

out:
        /* Release refcnt obtained when we called vsock_find_bound_socket() */
        sock_put(sk);

        release_sock(sk);
}

static u32 hvs_get_local_cid(void)
{
        return VMADDR_CID_ANY;
}

static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk)
{
        struct hvsock *hvs;
        struct sock *sk = sk_vsock(vsk);

        hvs = kzalloc(sizeof(*hvs), GFP_KERNEL);
        if (!hvs)
                return -ENOMEM;

        vsk->trans = hvs;
        hvs->vsk = vsk;
        sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE;
        sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE;
        return 0;
}

static int hvs_connect(struct vsock_sock *vsk)
{
        union hvs_service_id vm, host;
        struct hvsock *h = vsk->trans;

        vm.srv_id = srv_id_template;
        vm.svm_port = vsk->local_addr.svm_port;
        h->vm_srv_id = vm.srv_id;

        host.srv_id = srv_id_template;
        host.svm_port = vsk->remote_addr.svm_port;
        h->host_srv_id = host.srv_id;

        return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id);
}

static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode)
{
        struct vmpipe_proto_header hdr;

        if (hvs->fin_sent || !hvs->chan)
                return;

        /* It can't fail: see hvs_channel_writable_bytes(). */
        (void)hvs_send_data(hvs->chan, (struct hvs_send_buf *)&hdr, 0);
        hvs->fin_sent = true;
}
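
/* The zero-length send above is the hv_sock FIN: the peer sees exactly
 * HVS_PKT_LEN(0) readable bytes (hvs_channel_readable_payload() returns 0)
 * and hvs_update_recv_data() maps data_size == 0 to SEND_SHUTDOWN. Casting
 * a bare vmpipe_proto_header to struct hvs_send_buf is safe here because a
 * zero-length send only touches the header.
 */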

static int hvs_shutdown(struct vsock_sock *vsk, int mode)
{
        if (!(mode & SEND_SHUTDOWN))
                return 0;

        hvs_shutdown_lock_held(vsk->trans, mode);
        return 0;
}

static void hvs_close_timeout(struct work_struct *work)
{
        struct vsock_sock *vsk =
                container_of(work, struct vsock_sock, close_work.work);
        struct sock *sk = sk_vsock(vsk);

        sock_hold(sk);
        lock_sock(sk);
        if (!sock_flag(sk, SOCK_DONE))
                hvs_do_close_lock_held(vsk, false);

        vsk->close_work_scheduled = false;
        release_sock(sk);
        sock_put(sk);
}

/* Returns true if it is safe to remove the socket; false otherwise */
static bool hvs_close_lock_held(struct vsock_sock *vsk)
{
        struct sock *sk = sk_vsock(vsk);

        if (!(sk->sk_state == TCP_ESTABLISHED ||
              sk->sk_state == TCP_CLOSING))
                return true;

        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
                hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK);

        if (sock_flag(sk, SOCK_DONE))
                return true;

        /* This reference will be dropped by the delayed close routine */
        sock_hold(sk);
        INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout);
        vsk->close_work_scheduled = true;
        schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT);
        return false;
}
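
/* Close path summary: hvs_release() attempts the graceful close above; if
 * the host hasn't rescinded the channel yet, hvs_close_timeout() runs after
 * HVS_CLOSE_TIMEOUT (8 seconds) and finishes the close, unless
 * hvs_close_connection() (the rescind callback) got there first and
 * cancelled the delayed work via hvs_do_close_lock_held().
 */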

static void hvs_release(struct vsock_sock *vsk)
{
        struct sock *sk = sk_vsock(vsk);
        bool remove_sock;

        lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
        remove_sock = hvs_close_lock_held(vsk);
        release_sock(sk);
        if (remove_sock)
                vsock_remove_sock(vsk);
}

static void hvs_destruct(struct vsock_sock *vsk)
{
        struct hvsock *hvs = vsk->trans;
        struct vmbus_channel *chan = hvs->chan;

        if (chan)
                vmbus_hvsock_device_unregister(chan);

        kfree(hvs);
}

static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr)
{
        return -EOPNOTSUPP;
}

static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
                             size_t len, int flags)
{
        return -EOPNOTSUPP;
}

static int hvs_dgram_enqueue(struct vsock_sock *vsk,
                             struct sockaddr_vm *remote, struct msghdr *msg,
                             size_t dgram_len)
{
        return -EOPNOTSUPP;
}

static bool hvs_dgram_allow(u32 cid, u32 port)
{
        return false;
}

static int hvs_update_recv_data(struct hvsock *hvs)
{
        struct hvs_recv_buf *recv_buf;
        u32 payload_len;

        recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
        payload_len = recv_buf->hdr.data_size;

        if (payload_len > HVS_MTU_SIZE)
                return -EIO;

        if (payload_len == 0)
                hvs->vsk->peer_shutdown |= SEND_SHUTDOWN;

        hvs->recv_data_len = payload_len;
        hvs->recv_data_off = 0;

        return 0;
}

static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
                                  size_t len, int flags)
{
        struct hvsock *hvs = vsk->trans;
        bool need_refill = !hvs->recv_desc;
        struct hvs_recv_buf *recv_buf;
        u32 to_read;
        int ret;

        if (flags & MSG_PEEK)
                return -EOPNOTSUPP;

        if (need_refill) {
                hvs->recv_desc = hv_pkt_iter_first(hvs->chan);
                ret = hvs_update_recv_data(hvs);
                if (ret)
                        return ret;
        }

        recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1);
        to_read = min_t(u32, len, hvs->recv_data_len);
        ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read);
        if (ret != 0)
                return ret;

        hvs->recv_data_len -= to_read;
        if (hvs->recv_data_len == 0) {
                hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc);
                if (hvs->recv_desc) {
                        ret = hvs_update_recv_data(hvs);
                        if (ret)
                                return ret;
                }
        } else {
                hvs->recv_data_off += to_read;
        }

        return to_read;
}

static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
                                  size_t len)
{
        struct hvsock *hvs = vsk->trans;
        struct vmbus_channel *chan = hvs->chan;
        struct hvs_send_buf *send_buf;
        ssize_t to_write, max_writable;
        ssize_t ret = 0;
        ssize_t bytes_written = 0;

        BUILD_BUG_ON(sizeof(*send_buf) != PAGE_SIZE_4K);

        send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
        if (!send_buf)
                return -ENOMEM;

        /* Reader(s) could be draining data from the channel as we write.
         * Maximize bandwidth, by iterating until the channel is found to be
         * full.
         */
        while (len) {
                max_writable = hvs_channel_writable_bytes(chan);
                if (!max_writable)
                        break;
                to_write = min_t(ssize_t, len, max_writable);
                to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE);
                /* memcpy_from_msg is safe for the loop as it advances the
                 * offsets within the message iterator.
                 */
                ret = memcpy_from_msg(send_buf->data, msg, to_write);
                if (ret < 0)
                        goto out;

                ret = hvs_send_data(hvs->chan, send_buf, to_write);
                if (ret < 0)
                        goto out;

                bytes_written += to_write;
                len -= to_write;
        }
out:
        /* If any data has been sent, return that */
        if (bytes_written)
                ret = bytes_written;
        kfree(send_buf);
        return ret;
}

static s64 hvs_stream_has_data(struct vsock_sock *vsk)
{
        struct hvsock *hvs = vsk->trans;
        s64 ret;

        if (hvs->recv_data_len > 0)
                return 1;

        switch (hvs_channel_readable_payload(hvs->chan)) {
        case 1:
                ret = 1;
                break;
        case 0:
                vsk->peer_shutdown |= SEND_SHUTDOWN;
                ret = 0;
                break;
        default: /* -1 */
                ret = 0;
                break;
        }

        return ret;
}

static s64 hvs_stream_has_space(struct vsock_sock *vsk)
{
        struct hvsock *hvs = vsk->trans;

        return hvs_channel_writable_bytes(hvs->chan);
}

static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk)
{
        return HVS_MTU_SIZE + 1;
}

static bool hvs_stream_is_active(struct vsock_sock *vsk)
{
        struct hvsock *hvs = vsk->trans;

        return hvs->chan != NULL;
}

static bool hvs_stream_allow(u32 cid, u32 port)
{
        if (cid == VMADDR_CID_HOST)
                return true;

        return false;
}

static
int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable)
{
        struct hvsock *hvs = vsk->trans;

        *readable = hvs_channel_readable(hvs->chan);
        return 0;
}

static
int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable)
{
        *writable = hvs_stream_has_space(vsk) > 0;

        return 0;
}

static
int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target,
                         struct vsock_transport_recv_notify_data *d)
{
        return 0;
}

static
int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target,
                              struct vsock_transport_recv_notify_data *d)
{
        return 0;
}

static
int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target,
                                struct vsock_transport_recv_notify_data *d)
{
        return 0;
}

static
int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target,
                                 ssize_t copied, bool data_read,
                                 struct vsock_transport_recv_notify_data *d)
{
        return 0;
}

static
int hvs_notify_send_init(struct vsock_sock *vsk,
                         struct vsock_transport_send_notify_data *d)
{
        return 0;
}

static
int hvs_notify_send_pre_block(struct vsock_sock *vsk,
                              struct vsock_transport_send_notify_data *d)
{
        return 0;
}

static
int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk,
                                struct vsock_transport_send_notify_data *d)
{
        return 0;
}

static
int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
                                 struct vsock_transport_send_notify_data *d)
{
        return 0;
}

static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
        /* Ignored. */
}

static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
        /* Ignored. */
}

static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
        /* Ignored. */
}

static u64 hvs_get_buffer_size(struct vsock_sock *vsk)
{
        return -ENOPROTOOPT;
}

static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk)
{
        return -ENOPROTOOPT;
}

static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk)
{
        return -ENOPROTOOPT;
}

static struct vsock_transport hvs_transport = {
        .get_local_cid = hvs_get_local_cid,

        .init = hvs_sock_init,
        .destruct = hvs_destruct,
        .release = hvs_release,
        .connect = hvs_connect,
        .shutdown = hvs_shutdown,

        .dgram_bind = hvs_dgram_bind,
        .dgram_dequeue = hvs_dgram_dequeue,
        .dgram_enqueue = hvs_dgram_enqueue,
        .dgram_allow = hvs_dgram_allow,

        .stream_dequeue = hvs_stream_dequeue,
        .stream_enqueue = hvs_stream_enqueue,
        .stream_has_data = hvs_stream_has_data,
        .stream_has_space = hvs_stream_has_space,
        .stream_rcvhiwat = hvs_stream_rcvhiwat,
        .stream_is_active = hvs_stream_is_active,
        .stream_allow = hvs_stream_allow,

        .notify_poll_in = hvs_notify_poll_in,
        .notify_poll_out = hvs_notify_poll_out,
        .notify_recv_init = hvs_notify_recv_init,
        .notify_recv_pre_block = hvs_notify_recv_pre_block,
        .notify_recv_pre_dequeue = hvs_notify_recv_pre_dequeue,
        .notify_recv_post_dequeue = hvs_notify_recv_post_dequeue,
        .notify_send_init = hvs_notify_send_init,
        .notify_send_pre_block = hvs_notify_send_pre_block,
        .notify_send_pre_enqueue = hvs_notify_send_pre_enqueue,
        .notify_send_post_enqueue = hvs_notify_send_post_enqueue,

        .set_buffer_size = hvs_set_buffer_size,
        .set_min_buffer_size = hvs_set_min_buffer_size,
        .set_max_buffer_size = hvs_set_max_buffer_size,
        .get_buffer_size = hvs_get_buffer_size,
        .get_min_buffer_size = hvs_get_min_buffer_size,
        .get_max_buffer_size = hvs_get_max_buffer_size,
};

static int hvs_probe(struct hv_device *hdev,
                     const struct hv_vmbus_device_id *dev_id)
{
        struct vmbus_channel *chan = hdev->channel;

        hvs_open_connection(chan);

        /* Always return success to suppress the unnecessary error message
         * in vmbus_probe(): on error the host will rescind the device in
         * 30 seconds and we can do cleanup at that time in
         * vmbus_onoffer_rescind().
         */
        return 0;
}

static int hvs_remove(struct hv_device *hdev)
{
        struct vmbus_channel *chan = hdev->channel;

        vmbus_close(chan);

        return 0;
}

/* This isn't really used. See vmbus_match() and vmbus_probe() */
static const struct hv_vmbus_device_id id_table[] = {
        {},
};

static struct hv_driver hvs_drv = {
        .name = "hv_sock",
        .hvsock = true,
        .id_table = id_table,
        .probe = hvs_probe,
        .remove = hvs_remove,
};

static int __init hvs_init(void)
{
        int ret;

        if (vmbus_proto_version < VERSION_WIN10)
                return -ENODEV;

        ret = vmbus_driver_register(&hvs_drv);
        if (ret != 0)
                return ret;

        ret = vsock_core_init(&hvs_transport);
        if (ret) {
                vmbus_driver_unregister(&hvs_drv);
                return ret;
        }

        return 0;
}

static void __exit hvs_exit(void)
{
        vsock_core_exit();
        vmbus_driver_unregister(&hvs_drv);
}

module_init(hvs_init);
module_exit(hvs_exit);

MODULE_DESCRIPTION("Hyper-V Sockets");
MODULE_VERSION("1.0.0");
MODULE_LICENSE("GPL");
MODULE_ALIAS_NETPROTO(PF_VSOCK);