/*
 * Copyright (C) 2019 - 2020 Intel Corporation
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include <stdlib.h>
#include <usfstl/list.h>
#include <usfstl/loop.h>
#include <usfstl/uds.h>
#include <sys/socket.h>
#include <sys/mman.h>
#include <sys/un.h>
/* for memcpy(), read()/write()/close() and PRIx64 */
#include <string.h>
#include <unistd.h>
#include <inttypes.h>
#include <errno.h>
#include <usfstl/vhost.h>
#include <linux/virtio_ring.h>
#include <linux/virtio_config.h>
#include <endian.h>

/* copied from uapi */
#define VIRTIO_F_VERSION_1		32

#define MAX_REGIONS 8
#define SG_STACK_PREALLOC 5

struct usfstl_vhost_user_dev_int {
	struct usfstl_list fds;
	struct usfstl_job irq_job;

	struct usfstl_loop_entry entry;

	struct usfstl_vhost_user_dev ext;

	unsigned int n_regions;
	struct vhost_user_region regions[MAX_REGIONS];
	int region_fds[MAX_REGIONS];
	void *region_vaddr[MAX_REGIONS];

	int req_fd;

	struct {
		struct usfstl_loop_entry entry;
		bool enabled;
		bool sleeping;
		bool triggered;
		uint64_t desc_guest_addr;
		uint64_t avail_guest_addr;
		uint64_t used_guest_addr;
		struct vring virtq;
		int call_fd;
		uint16_t last_avail_idx;
	} virtqs[];
};

#define CONV(bits)							\
static inline uint##bits##_t __attribute__((used))			\
cpu_to_virtio##bits(struct usfstl_vhost_user_dev_int *dev,		\
		    uint##bits##_t v)					\
{									\
	if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1))		\
		return htole##bits(v);					\
	return v;							\
}									\
static inline uint##bits##_t __attribute__((used))			\
virtio_to_cpu##bits(struct usfstl_vhost_user_dev_int *dev,		\
		    uint##bits##_t v)					\
{									\
	if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1))		\
		return le##bits##toh(v);				\
	return v;							\
}

CONV(16)
CONV(32)
CONV(64)

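/*
 * Pull the next available descriptor chain off the given virtqueue and
 * build a scatter/gather buffer for it. The caller-provided "fixed"
 * buffer (with stack-preallocated iovecs) is used when it is large
 * enough; otherwise a buffer big enough for the whole chain is
 * heap-allocated. Returns NULL when the queue has no new entries
 * (or on allocation failure).
 */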
static struct usfstl_vhost_user_buf *
usfstl_vhost_user_get_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
				unsigned int virtq_idx,
				struct usfstl_vhost_user_buf *fixed)
{
	struct usfstl_vhost_user_buf *buf = fixed;
	struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
	uint16_t avail_idx = virtio_to_cpu16(dev, virtq->avail->idx);
	uint16_t idx, desc_idx;
	struct vring_desc *desc;
	unsigned int n_in = 0, n_out = 0;
	bool more;

	if (avail_idx == dev->virtqs[virtq_idx].last_avail_idx)
		return NULL;

	/* ensure we read the descriptor after checking the index */
	__sync_synchronize();

	idx = dev->virtqs[virtq_idx].last_avail_idx++;
	idx %= virtq->num;
	desc_idx = virtio_to_cpu16(dev, virtq->avail->ring[idx]);
	USFSTL_ASSERT(desc_idx < virtq->num);

	desc = &virtq->desc[desc_idx];
	do {
		more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;

		if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE)
			n_in++;
		else
			n_out++;
		desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
	} while (more);

	if (n_in > fixed->n_in_sg || n_out > fixed->n_out_sg) {
		size_t sz = sizeof(*buf);
		struct iovec *vec;

		sz += (n_in + n_out) * sizeof(*vec);

		buf = calloc(1, sz);
		if (!buf)
			return NULL;

		vec = (void *)(buf + 1);
		buf->in_sg = vec;
		buf->out_sg = vec + n_in;
		buf->allocated = true;
	}

	buf->n_in_sg = 0;
	buf->n_out_sg = 0;
	buf->idx = desc_idx;

	desc = &virtq->desc[desc_idx];
	do {
		struct iovec *vec;
		uint64_t addr;

		more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;

		if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE) {
			vec = &buf->in_sg[buf->n_in_sg];
			buf->n_in_sg++;
		} else {
			vec = &buf->out_sg[buf->n_out_sg];
			buf->n_out_sg++;
		}

		addr = virtio_to_cpu64(dev, desc->addr);
		vec->iov_base = usfstl_vhost_phys_to_va(&dev->ext, addr);
		vec->iov_len = virtio_to_cpu32(dev, desc->len);

		desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
	} while (more);

	return buf;
}

static void usfstl_vhost_user_free_buf(struct usfstl_vhost_user_buf *buf)
{
	if (buf->allocated)
		free(buf);
}

static void usfstl_vhost_user_readable_handler(struct usfstl_loop_entry *entry)
{
	usfstl_loop_unregister(entry);
	entry->fd = -1;
}

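/*
 * Read a single vhost-user message: first the fixed header (together with
 * any ancillary fd data), then exactly hdr->size payload bytes into the
 * remaining iovec space. Returns 0 on success or a negative errno value.
 */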
static int usfstl_vhost_user_read_msg(int fd, struct msghdr *msghdr)
{
	struct iovec msg_iov;
	struct msghdr hdr2 = {
		.msg_iov = &msg_iov,
		.msg_iovlen = 1,
		.msg_control = msghdr->msg_control,
		.msg_controllen = msghdr->msg_controllen,
	};
	struct vhost_user_msg_hdr *hdr;
	size_t i;
	size_t maxlen = 0;
	ssize_t len;
	ssize_t prev_datalen;
	size_t prev_iovlen;

	USFSTL_ASSERT(msghdr->msg_iovlen >= 1);
	USFSTL_ASSERT(msghdr->msg_iov[0].iov_len >= sizeof(*hdr));

	hdr = msghdr->msg_iov[0].iov_base;
	msg_iov.iov_base = hdr;
	msg_iov.iov_len = sizeof(*hdr);

	len = recvmsg(fd, &hdr2, 0);
	if (len < 0)
		return -errno;
	if (len == 0)
		return -ENOTCONN;

	for (i = 0; i < msghdr->msg_iovlen; i++)
		maxlen += msghdr->msg_iov[i].iov_len;
	maxlen -= sizeof(*hdr);

	USFSTL_ASSERT_EQ((int)len, (int)sizeof(*hdr), "%d");
	USFSTL_ASSERT(hdr->size <= maxlen);

	if (!hdr->size)
		return 0;

	prev_iovlen = msghdr->msg_iovlen;
	msghdr->msg_iovlen = 1;

	msghdr->msg_control = NULL;
	msghdr->msg_controllen = 0;
	msghdr->msg_iov[0].iov_base += sizeof(*hdr);
	prev_datalen = msghdr->msg_iov[0].iov_len;
	msghdr->msg_iov[0].iov_len = hdr->size;
	len = recvmsg(fd, msghdr, 0);

	/* restore just in case the user needs it */
	msghdr->msg_iov[0].iov_base -= sizeof(*hdr);
	msghdr->msg_iov[0].iov_len = prev_datalen;
	msghdr->msg_control = hdr2.msg_control;
	msghdr->msg_controllen = hdr2.msg_controllen;

	msghdr->msg_iovlen = prev_iovlen;

	if (len < 0)
		return -errno;
	if (len == 0)
		return -ENOTCONN;

	USFSTL_ASSERT_EQ(hdr->size, (uint32_t)len, "%u");

	return 0;
}

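/*
 * Send a message on the slave/backend request channel. If the master
 * negotiated REPLY_ACK, wait (while still handling other loop events)
 * until the reply arrives and read it back into the same message buffer.
 */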
static void usfstl_vhost_user_send_msg(struct usfstl_vhost_user_dev_int *dev,
				       struct vhost_user_msg *msg)
{
	size_t msgsz = sizeof(msg->hdr) + msg->hdr.size;
	bool ack = dev->ext.protocol_features &
		   (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
	ssize_t written;

	if (ack)
		msg->hdr.flags |= VHOST_USER_MSG_FLAGS_NEED_REPLY;

	written = write(dev->req_fd, msg, msgsz);
	USFSTL_ASSERT_EQ(written, (ssize_t)msgsz, "%zd");

	if (ack) {
		struct usfstl_loop_entry entry = {
			.fd = dev->req_fd,
			.priority = 0x7fffffff, // max
			.handler = usfstl_vhost_user_readable_handler,
		};
		struct iovec msg_iov = {
			.iov_base = msg,
			.iov_len = sizeof(*msg),
		};
		struct msghdr msghdr = {
			.msg_iovlen = 1,
			.msg_iov = &msg_iov,
		};

		/*
		 * Wait for the fd to be readable - we may have to
		 * handle other simulation (time) messages while
		 * waiting ...
		 */
		usfstl_loop_register(&entry);
		while (entry.fd != -1)
			usfstl_loop_wait_and_handle();
		USFSTL_ASSERT_EQ(usfstl_vhost_user_read_msg(dev->req_fd,
							    &msghdr),
				 0, "%d");
	}
}

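/*
 * Return a completed buffer to the used ring and notify the driver,
 * either via an in-band SLAVE_VRING_CALL message (if negotiated and no
 * call eventfd is set) or by signalling the virtqueue's call fd.
 */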
static void usfstl_vhost_user_send_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
					     struct usfstl_vhost_user_buf *buf,
					     int virtq_idx)
{
	struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
	unsigned int idx, widx;
	int call_fd = dev->virtqs[virtq_idx].call_fd;
	ssize_t written;
	uint64_t e = 1;

	if (dev->ext.server->ctrl)
		usfstl_sched_ctrl_sync_to(dev->ext.server->ctrl);

	idx = virtio_to_cpu16(dev, virtq->used->idx);
	widx = idx + 1;

	idx %= virtq->num;
	virtq->used->ring[idx].id = cpu_to_virtio32(dev, buf->idx);
	virtq->used->ring[idx].len = cpu_to_virtio32(dev, buf->written);

	/* write buffers / used table before flush */
	__sync_synchronize();

	virtq->used->idx = cpu_to_virtio16(dev, widx);

	if (call_fd < 0 &&
	    dev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
	    dev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
		struct vhost_user_msg msg = {
			.hdr.request = VHOST_USER_SLAVE_VRING_CALL,
			.hdr.flags = VHOST_USER_VERSION,
			.hdr.size = sizeof(msg.payload.vring_state),
			.payload.vring_state = {
				.idx = virtq_idx,
			},
		};

		usfstl_vhost_user_send_msg(dev, &msg);
		return;
	}

	written = write(dev->virtqs[virtq_idx].call_fd, &e, sizeof(e));
	USFSTL_ASSERT_EQ(written, (ssize_t)sizeof(e), "%zd");
}

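/*
 * Drain a virtqueue: hand every pending buffer to the device's handler
 * and complete it back to the driver.
 */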
static void usfstl_vhost_user_handle_queue(struct usfstl_vhost_user_dev_int *dev,
					   unsigned int virtq_idx)
{
	/* preallocate on the stack for most cases */
	struct iovec in_sg[SG_STACK_PREALLOC] = { };
	struct iovec out_sg[SG_STACK_PREALLOC] = { };
	struct usfstl_vhost_user_buf _buf = {
		.in_sg = in_sg,
		.n_in_sg = SG_STACK_PREALLOC,
		.out_sg = out_sg,
		.n_out_sg = SG_STACK_PREALLOC,
	};
	struct usfstl_vhost_user_buf *buf;

	while ((buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf))) {
		dev->ext.server->ops->handle(&dev->ext, buf, virtq_idx);

		usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
		usfstl_vhost_user_free_buf(buf);
	}
}

static void usfstl_vhost_user_job_callback(struct usfstl_job *job)
{
	struct usfstl_vhost_user_dev_int *dev = job->data;
	unsigned int virtq;

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		if (!dev->virtqs[virtq].triggered)
			continue;
		dev->virtqs[virtq].triggered = false;

		usfstl_vhost_user_handle_queue(dev, virtq);
	}
}

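/*
 * Mark a virtqueue as triggered and schedule the IRQ job that will
 * process it. Without a scheduler the queue is handled immediately;
 * otherwise handling is delayed by the configured interrupt latency.
 */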
static void usfstl_vhost_user_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
					 unsigned int virtq)
{
	if (!(dev->ext.server->input_queues & (1ULL << virtq)))
		return;

	dev->virtqs[virtq].triggered = true;

	if (usfstl_job_scheduled(&dev->irq_job))
		return;

	if (!dev->ext.server->scheduler) {
		usfstl_vhost_user_job_callback(&dev->irq_job);
		return;
	}

	if (dev->ext.server->ctrl)
		usfstl_sched_ctrl_sync_from(dev->ext.server->ctrl);

	dev->irq_job.start = usfstl_sched_current_time(dev->ext.server->scheduler) +
			     dev->ext.server->interrupt_latency;
	usfstl_sched_add_job(dev->ext.server->scheduler, &dev->irq_job);
}

static void usfstl_vhost_user_virtq_fdkick(struct usfstl_loop_entry *entry)
{
	struct usfstl_vhost_user_dev_int *dev = entry->data;
	unsigned int virtq;
	uint64_t v;

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		if (entry == &dev->virtqs[virtq].entry)
			break;
	}

	USFSTL_ASSERT(virtq < dev->ext.server->max_queues);

	USFSTL_ASSERT_EQ((int)read(entry->fd, &v, sizeof(v)),
			 (int)sizeof(v), "%d");

	usfstl_vhost_user_virtq_kick(dev, virtq);
}

static void usfstl_vhost_user_clear_mappings(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int idx;

	for (idx = 0; idx < MAX_REGIONS; idx++) {
		if (dev->region_vaddr[idx]) {
			munmap(dev->region_vaddr[idx],
			       dev->regions[idx].size + dev->regions[idx].mmap_offset);
			dev->region_vaddr[idx] = NULL;
		}

		if (dev->region_fds[idx] != -1) {
			close(dev->region_fds[idx]);
			dev->region_fds[idx] = -1;
		}
	}
}

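/*
 * mmap() all memory regions announced via VHOST_USER_SET_MEM_TABLE so
 * that guest addresses can be translated to local virtual addresses.
 */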
static void usfstl_vhost_user_setup_mappings(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int idx;

	for (idx = 0; idx < dev->n_regions; idx++) {
		USFSTL_ASSERT(!dev->region_vaddr[idx]);

		/*
		 * Cannot rely on the offset being page-aligned, I think ...
		 * adjust for it later when we translate addresses instead.
		 */
		dev->region_vaddr[idx] = mmap(NULL,
					      dev->regions[idx].size +
					      dev->regions[idx].mmap_offset,
					      PROT_READ | PROT_WRITE, MAP_SHARED,
					      dev->region_fds[idx], 0);
		USFSTL_ASSERT(dev->region_vaddr[idx] != (void *)-1,
			      "mmap() failed (%d) for fd %d", errno, dev->region_fds[idx]);
	}
}

static void
usfstl_vhost_user_update_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
				    unsigned int virtq, int fd)
{
	if (dev->virtqs[virtq].entry.fd != -1) {
		usfstl_loop_unregister(&dev->virtqs[virtq].entry);
		close(dev->virtqs[virtq].entry.fd);
	}

	if (fd != -1) {
		dev->virtqs[virtq].entry.fd = fd;
		usfstl_loop_register(&dev->virtqs[virtq].entry);
	}
}

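/*
 * Tear down a connected device instance: unregister it from the event
 * loop, drop all fds and memory mappings, notify the server's
 * disconnected() callback and free the structure.
 */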
static void usfstl_vhost_user_dev_free(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int virtq;

	usfstl_loop_unregister(&dev->entry);
	usfstl_sched_del_job(&dev->irq_job);

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		usfstl_vhost_user_update_virtq_kick(dev, virtq, -1);
		if (dev->virtqs[virtq].call_fd != -1)
			close(dev->virtqs[virtq].call_fd);
	}

	usfstl_vhost_user_clear_mappings(dev);

	if (dev->req_fd != -1)
		close(dev->req_fd);

	if (dev->ext.server->ops->disconnected)
		dev->ext.server->ops->disconnected(&dev->ext);

	if (dev->entry.fd != -1)
		close(dev->entry.fd);

	free(dev);
}

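/*
 * Extract up to max_fds file descriptors passed as SCM_RIGHTS ancillary
 * data along with a vhost-user message.
 */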
static void usfstl_vhost_user_get_msg_fds(struct msghdr *msghdr,
					  int *outfds, int max_fds)
{
	struct cmsghdr *msg;
	int fds;

	for (msg = CMSG_FIRSTHDR(msghdr); msg; msg = CMSG_NXTHDR(msghdr, msg)) {
		if (msg->cmsg_level != SOL_SOCKET)
			continue;
		if (msg->cmsg_type != SCM_RIGHTS)
			continue;

		fds = (msg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		USFSTL_ASSERT(fds <= max_fds);
		memcpy(outfds, CMSG_DATA(msg), fds * sizeof(int));
		break;
	}
}

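/*
 * Read and dispatch one vhost-user message from the master. Handles the
 * standard setup requests (features, memory table, vring configuration,
 * kick/call fds, config space) as well as the sleep/wake and
 * snapshot/restore extensions, and sends a reply when one is required.
 */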
static void usfstl_vhost_user_handle_msg(struct usfstl_loop_entry *entry)
{
	struct usfstl_vhost_user_dev_int *dev;
	struct vhost_user_msg msg;
	uint8_t data[256]; // limits the config space size ...
	struct iovec msg_iov[3] = {
		[0] = {
			.iov_base = &msg.hdr,
			.iov_len = sizeof(msg.hdr),
		},
		[1] = {
			.iov_base = &msg.payload,
			.iov_len = sizeof(msg.payload),
		},
		[2] = {
			.iov_base = data,
			.iov_len = sizeof(data),
		},
	};
	uint8_t msg_control[CMSG_SPACE(sizeof(int) * MAX_REGIONS)] = { 0 };
	struct msghdr msghdr = {
		.msg_iov = msg_iov,
		.msg_iovlen = 3,
		.msg_control = msg_control,
		.msg_controllen = sizeof(msg_control),
	};
	ssize_t len;
	size_t reply_len = 0;
	unsigned int virtq;
	int fd;

	dev = container_of(entry, struct usfstl_vhost_user_dev_int, entry);

	if (usfstl_vhost_user_read_msg(entry->fd, &msghdr)) {
		usfstl_vhost_user_dev_free(dev);
		return;
	}
	len = msg.hdr.size;

	USFSTL_ASSERT((msg.hdr.flags & VHOST_USER_MSG_FLAGS_VERSION) ==
		    VHOST_USER_VERSION);

	switch (msg.hdr.request) {
	case VHOST_USER_GET_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = dev->ext.server->features;
		msg.payload.u64 |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
		break;
	case VHOST_USER_SET_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)sizeof(msg.payload.u64), "%zd");
		dev->ext.features = msg.payload.u64;
		break;
	case VHOST_USER_SET_OWNER:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		/* nothing to be done */
		break;
	case VHOST_USER_SET_MEM_TABLE:
		USFSTL_ASSERT(len <= (int)sizeof(msg.payload.mem_regions));
		USFSTL_ASSERT(msg.payload.mem_regions.n_regions <= MAX_REGIONS);
		usfstl_vhost_user_clear_mappings(dev);
		memcpy(dev->regions, msg.payload.mem_regions.regions,
		       msg.payload.mem_regions.n_regions *
		       sizeof(dev->regions[0]));
		dev->n_regions = msg.payload.mem_regions.n_regions;
		usfstl_vhost_user_get_msg_fds(&msghdr, dev->region_fds, MAX_REGIONS);
		usfstl_vhost_user_setup_mappings(dev);
		break;
	case VHOST_USER_SET_VRING_NUM:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_state.idx].virtq.num =
			msg.payload.vring_state.num;
		break;
	case VHOST_USER_SET_VRING_ADDR:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_addr));
		USFSTL_ASSERT(msg.payload.vring_addr.idx <=
			      dev->ext.server->max_queues);
		USFSTL_ASSERT_EQ(msg.payload.vring_addr.flags, (uint32_t)0, "0x%x");
		USFSTL_ASSERT(!dev->virtqs[msg.payload.vring_addr.idx].enabled);

		// Save the guest physical addresses to make snapshotting more convenient.
		dev->virtqs[msg.payload.vring_addr.idx].desc_guest_addr =
			usfstl_vhost_user_to_phys(&dev->ext, msg.payload.vring_addr.descriptor);
		dev->virtqs[msg.payload.vring_addr.idx].used_guest_addr =
			usfstl_vhost_user_to_phys(&dev->ext, msg.payload.vring_addr.used);
		dev->virtqs[msg.payload.vring_addr.idx].avail_guest_addr =
			usfstl_vhost_user_to_phys(&dev->ext, msg.payload.vring_addr.avail);

		dev->virtqs[msg.payload.vring_addr.idx].last_avail_idx = 0;
		dev->virtqs[msg.payload.vring_addr.idx].virtq.desc =
			usfstl_vhost_user_to_va(&dev->ext,
						msg.payload.vring_addr.descriptor);
		dev->virtqs[msg.payload.vring_addr.idx].virtq.used =
			usfstl_vhost_user_to_va(&dev->ext,
						msg.payload.vring_addr.used);
		dev->virtqs[msg.payload.vring_addr.idx].virtq.avail =
			usfstl_vhost_user_to_va(&dev->ext,
						msg.payload.vring_addr.avail);
		USFSTL_ASSERT(dev->virtqs[msg.payload.vring_addr.idx].virtq.avail &&
			    dev->virtqs[msg.payload.vring_addr.idx].virtq.desc &&
			    dev->virtqs[msg.payload.vring_addr.idx].virtq.used);
		break;
	case VHOST_USER_SET_VRING_BASE:
		/* ignored - logging not supported */
		/*
		 * FIXME: our Linux UML virtio implementation
		 *        shouldn't send this
		 */
		break;
	case VHOST_USER_SET_VRING_KICK:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
		USFSTL_ASSERT(virtq <= dev->ext.server->max_queues);
		if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
			fd = -1;
		else
			usfstl_vhost_user_get_msg_fds(&msghdr, &fd, 1);
		usfstl_vhost_user_update_virtq_kick(dev, virtq, fd);
		break;
	case VHOST_USER_SET_VRING_CALL:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
		USFSTL_ASSERT(virtq <= dev->ext.server->max_queues);
		if (dev->virtqs[virtq].call_fd != -1)
			close(dev->virtqs[virtq].call_fd);
		if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
			dev->virtqs[virtq].call_fd = -1;
		else
			usfstl_vhost_user_get_msg_fds(&msghdr,
						    &dev->virtqs[virtq].call_fd,
						    1);
		break;
	case VHOST_USER_GET_PROTOCOL_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = dev->ext.server->protocol_features;
		if (dev->ext.server->config && dev->ext.server->config_len)
			msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK;
		break;
	case VHOST_USER_SET_VRING_ENABLE:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_state.idx].enabled =
			msg.payload.vring_state.num;
		break;
	case VHOST_USER_SET_PROTOCOL_FEATURES:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		dev->ext.protocol_features = msg.payload.u64;
		break;
	case VHOST_USER_SET_SLAVE_REQ_FD:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		if (dev->req_fd != -1)
			close(dev->req_fd);
		usfstl_vhost_user_get_msg_fds(&msghdr, &dev->req_fd, 1);
		USFSTL_ASSERT(dev->req_fd != -1);
		break;
	case VHOST_USER_GET_CONFIG:
		USFSTL_ASSERT(len == (int)(sizeof(msg.payload.cfg_space) +
					msg.payload.cfg_space.size));
		USFSTL_ASSERT(dev->ext.server->config && dev->ext.server->config_len);
		USFSTL_ASSERT(msg.payload.cfg_space.offset == 0);
		USFSTL_ASSERT(msg.payload.cfg_space.size <= dev->ext.server->config_len);
		msg.payload.cfg_space.flags = 0;
		msg_iov[1].iov_len = sizeof(msg.payload.cfg_space);
		msg_iov[2].iov_base = (void *)dev->ext.server->config;
		reply_len = len;
		break;
	case VHOST_USER_VRING_KICK:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		USFSTL_ASSERT(msg.payload.vring_state.num == 0);
		usfstl_vhost_user_virtq_kick(dev, msg.payload.vring_state.idx);
		break;
	case VHOST_USER_GET_SHARED_MEMORY_REGIONS:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = 0;
		break;
	case VHOST_USER_SLEEP:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
		for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
			if (dev->virtqs[virtq].enabled) {
				dev->virtqs[virtq].enabled = false;
				dev->virtqs[virtq].sleeping = true;
				usfstl_loop_unregister(&dev->virtqs[virtq].entry);
			}
		}
		msg.payload.i8 = 1; // success
		reply_len = sizeof(msg.payload.i8);
		break;
	case VHOST_USER_WAKE:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
		// enable previously enabled queues on wake
		for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
			if (dev->virtqs[virtq].sleeping) {
				dev->virtqs[virtq].enabled = true;
				dev->virtqs[virtq].sleeping = false;
				usfstl_loop_register(&dev->virtqs[virtq].entry);
				// TODO: is this needed?
				usfstl_vhost_user_virtq_kick(dev, virtq);
			}
		}
		msg.payload.i8 = 1; // success
		reply_len = sizeof(msg.payload.i8);
		break;
	case VHOST_USER_SNAPSHOT: {
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
		for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
			struct vring_snapshot *snapshot =
				&msg.payload.snapshot_response.snapshot.vrings[virtq];

			snapshot->enabled = dev->virtqs[virtq].enabled;
			snapshot->sleeping = dev->virtqs[virtq].sleeping;
			snapshot->triggered = dev->virtqs[virtq].triggered;
			snapshot->num = dev->virtqs[virtq].virtq.num;
			snapshot->desc_guest_addr = dev->virtqs[virtq].desc_guest_addr;
			snapshot->avail_guest_addr = dev->virtqs[virtq].avail_guest_addr;
			snapshot->used_guest_addr = dev->virtqs[virtq].used_guest_addr;
			snapshot->last_avail_idx = dev->virtqs[virtq].last_avail_idx;
		}
		msg.payload.snapshot_response.bool_store = 1;
		reply_len = sizeof(msg.payload.snapshot_response);
		break;
	}
	case VHOST_USER_RESTORE: {
		int *fds;

		USFSTL_ASSERT(len == (int)sizeof(msg.payload.restore_request));
		USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");

		fds = malloc(dev->ext.server->max_queues * sizeof(int));
		for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++)
			fds[virtq] = -1;
		usfstl_vhost_user_get_msg_fds(&msghdr, fds, 2);

		for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
			const struct vring_snapshot *snapshot =
				&msg.payload.restore_request.snapshot.vrings[virtq];

			dev->virtqs[virtq].enabled = snapshot->enabled;
			dev->virtqs[virtq].sleeping = snapshot->sleeping;
			dev->virtqs[virtq].triggered = snapshot->triggered;
			dev->virtqs[virtq].virtq.num = snapshot->num;
			dev->virtqs[virtq].desc_guest_addr = snapshot->desc_guest_addr;
			dev->virtqs[virtq].avail_guest_addr = snapshot->avail_guest_addr;
			dev->virtqs[virtq].used_guest_addr = snapshot->used_guest_addr;
			dev->virtqs[virtq].last_avail_idx = snapshot->last_avail_idx;

			dev->virtqs[virtq].entry.fd = fds[virtq];

			// Translate vring guest physical addresses.
			dev->virtqs[virtq].virtq.desc =
				usfstl_vhost_phys_to_va(&dev->ext,
							dev->virtqs[virtq].desc_guest_addr);
			dev->virtqs[virtq].virtq.used =
				usfstl_vhost_phys_to_va(&dev->ext,
							dev->virtqs[virtq].used_guest_addr);
			dev->virtqs[virtq].virtq.avail =
				usfstl_vhost_phys_to_va(&dev->ext,
							dev->virtqs[virtq].avail_guest_addr);
			USFSTL_ASSERT(dev->virtqs[virtq].virtq.avail &&
				      dev->virtqs[virtq].virtq.desc &&
				      dev->virtqs[virtq].virtq.used);
		}

		free(fds);

		msg.payload.i8 = 1; // success
		reply_len = sizeof(msg.payload.i8);
		break;
	}
	default:
		USFSTL_ASSERT(0, "Unsupported message: %d\n", msg.hdr.request);
	}

	if (reply_len || (msg.hdr.flags & VHOST_USER_MSG_FLAGS_NEED_REPLY)) {
		size_t i, tmp;

		if (!reply_len) {
			msg.payload.u64 = 0;
			reply_len = sizeof(uint64_t);
		}

		msg.hdr.size = reply_len;
		msg.hdr.flags &= ~VHOST_USER_MSG_FLAGS_NEED_REPLY;
		msg.hdr.flags |= VHOST_USER_MSG_FLAGS_REPLY;

		msghdr.msg_control = NULL;
		msghdr.msg_controllen = 0;

		reply_len += sizeof(msg.hdr);

		tmp = reply_len;
		for (i = 0; tmp && i < msghdr.msg_iovlen; i++) {
			if (tmp <= msg_iov[i].iov_len)
				msg_iov[i].iov_len = tmp;
			tmp -= msg_iov[i].iov_len;
		}
		msghdr.msg_iovlen = i;

		while (reply_len) {
			len = sendmsg(entry->fd, &msghdr, 0);
			if (len < 0) {
				usfstl_vhost_user_dev_free(dev);
				return;
			}
			USFSTL_ASSERT(len != 0);
			reply_len -= len;

			for (i = 0; len && i < msghdr.msg_iovlen; i++) {
				unsigned int rm = len;

				if (msg_iov[i].iov_len <= (size_t)len)
					rm = msg_iov[i].iov_len;
				len -= rm;
				msg_iov[i].iov_len -= rm;
				msg_iov[i].iov_base += rm;
			}
		}
	}
}

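/*
 * Accept callback for the unix domain socket: allocate and initialize a
 * per-connection device instance and register it with the event loop.
 */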
static void usfstl_vhost_user_connected(int fd, void *data)
{
	struct usfstl_vhost_user_server *server = data;
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int i;

	dev = calloc(1, sizeof(*dev) +
			sizeof(dev->virtqs[0]) * server->max_queues);

	USFSTL_ASSERT(dev);

	for (i = 0; i < server->max_queues; i++) {
		dev->virtqs[i].call_fd = -1;
		dev->virtqs[i].entry.fd = -1;
		dev->virtqs[i].entry.data = dev;
		dev->virtqs[i].entry.handler = usfstl_vhost_user_virtq_fdkick;
	}

	for (i = 0; i < MAX_REGIONS; i++)
		dev->region_fds[i] = -1;
	dev->req_fd = -1;

	dev->ext.server = server;
	dev->irq_job.data = dev;
	dev->irq_job.name = "vhost-user-irq";
	dev->irq_job.priority = 0x10000000;
	dev->irq_job.callback = usfstl_vhost_user_job_callback;
	usfstl_list_init(&dev->fds);

	if (server->ops->connected)
		server->ops->connected(&dev->ext);

	dev->entry.fd = fd;
	dev->entry.handler = usfstl_vhost_user_handle_msg;

	usfstl_loop_register(&dev->entry);
}

void usfstl_vhost_user_server_start(struct usfstl_vhost_user_server *server)
{
	USFSTL_ASSERT(server->ops);
	USFSTL_ASSERT(server->socket);

	usfstl_uds_create(server->socket, usfstl_vhost_user_connected, server);
}

void usfstl_vhost_user_server_stop(struct usfstl_vhost_user_server *server)
{
	usfstl_uds_remove(server->socket);
}

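/*
 * Send a device-to-driver notification: copy the given data into the
 * next available (device-writable) buffer of the virtqueue and complete
 * it. Silently does nothing if the queue is disabled or empty.
 */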
void usfstl_vhost_user_dev_notify(struct usfstl_vhost_user_dev *extdev,
				  unsigned int virtq_idx,
				  const uint8_t *data, size_t datalen)
{
	struct usfstl_vhost_user_dev_int *dev;
	/* preallocate on the stack for most cases */
	struct iovec in_sg[SG_STACK_PREALLOC] = { };
	struct iovec out_sg[SG_STACK_PREALLOC] = { };
	struct usfstl_vhost_user_buf _buf = {
		.in_sg = in_sg,
		.n_in_sg = SG_STACK_PREALLOC,
		.out_sg = out_sg,
		.n_out_sg = SG_STACK_PREALLOC,
	};
	struct usfstl_vhost_user_buf *buf;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	USFSTL_ASSERT(virtq_idx <= dev->ext.server->max_queues);

	if (!dev->virtqs[virtq_idx].enabled)
		return;

	buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf);
	if (!buf)
		return;

	USFSTL_ASSERT(buf->n_in_sg && !buf->n_out_sg);
	iov_fill(buf->in_sg, buf->n_in_sg, data, datalen);
	buf->written = datalen;

	usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
	usfstl_vhost_user_free_buf(buf);
}

void usfstl_vhost_user_config_changed(struct usfstl_vhost_user_dev *dev)
{
	struct usfstl_vhost_user_dev_int *idev;
	struct vhost_user_msg msg = {
		.hdr.request = VHOST_USER_SLAVE_CONFIG_CHANGE_MSG,
		.hdr.flags = VHOST_USER_VERSION,
	};

	idev = container_of(dev, struct usfstl_vhost_user_dev_int, ext);

	if (!(idev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_CONFIG)))
		return;

	usfstl_vhost_user_send_msg(idev, &msg);
}

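/*
 * Address translation helpers: map vhost-user "user addresses" and guest
 * physical addresses to pointers within the mmap()ed regions, using the
 * region table set up by VHOST_USER_SET_MEM_TABLE.
 */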
void *usfstl_vhost_user_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].user_addr &&
		    addr < dev->regions[region].user_addr +
			   dev->regions[region].size)
			return (uint8_t *)dev->region_vaddr[region] +
			       (addr -
				dev->regions[region].user_addr +
				dev->regions[region].mmap_offset);
	}

	USFSTL_ASSERT(0, "cannot translate user address %"PRIx64"\n", addr);
	return NULL;
}

uint64_t usfstl_vhost_user_to_phys(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].user_addr &&
		    addr < dev->regions[region].user_addr +
			   dev->regions[region].size)
			return addr -
				dev->regions[region].user_addr +
				dev->regions[region].guest_phys_addr;
	}

	USFSTL_ASSERT(0, "cannot translate user address %"PRIx64"\n", addr);
	return 0;
}

void *usfstl_vhost_phys_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].guest_phys_addr &&
		    addr < dev->regions[region].guest_phys_addr +
			   dev->regions[region].size)
			return (uint8_t *)dev->region_vaddr[region] +
			       (addr -
				dev->regions[region].guest_phys_addr +
				dev->regions[region].mmap_offset);
	}

	USFSTL_ASSERT(0, "cannot translate physical address %"PRIx64"\n", addr);
	return NULL;
}

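/* Small scatter/gather helpers used by device implementations. */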
size_t iov_len(struct iovec *sg, unsigned int nsg)
{
	size_t len = 0;
	unsigned int i;

	for (i = 0; i < nsg; i++)
		len += sg[i].iov_len;

	return len;
}

size_t iov_fill(struct iovec *sg, unsigned int nsg,
		const void *_buf, size_t buflen)
{
	const char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

#define min(a, b) ({ typeof(a) _a = (a); typeof(b) _b = (b); _a < _b ? _a : _b; })
	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = min(buflen, sg[i].iov_len);

		memcpy(sg[i].iov_base, buf, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}

size_t iov_read(void *_buf, size_t buflen,
		struct iovec *sg, unsigned int nsg)
{
	char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = min(buflen, sg[i].iov_len);

		memcpy(buf, sg[i].iov_base, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}