/*
 * Copyright (C) 2019 - 2020 Intel Corporation
 *
 * SPDX-License-Identifier: BSD-3-Clause
 */
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <inttypes.h>
#include <usfstl/list.h>
#include <usfstl/loop.h>
#include <usfstl/uds.h>
#include <sys/socket.h>
#include <sys/mman.h>
#include <sys/un.h>
#include <errno.h>
#include <usfstl/vhost.h>
#include <linux/virtio_ring.h>
#include <linux/virtio_config.h>
#include <endian.h>

/* copied from uapi */
#define VIRTIO_F_VERSION_1		32

#define MAX_REGIONS 8
#define SG_STACK_PREALLOC 5

struct usfstl_vhost_user_dev_int {
	struct usfstl_list fds;
	struct usfstl_job irq_job;

	struct usfstl_loop_entry entry;

	struct usfstl_vhost_user_dev ext;

	unsigned int n_regions;
	struct vhost_user_region regions[MAX_REGIONS];
	int region_fds[MAX_REGIONS];
	void *region_vaddr[MAX_REGIONS];

	int req_fd;

	struct {
		struct usfstl_loop_entry entry;
		bool enabled;
		bool triggered;
		struct vring virtq;
		int call_fd;
		uint16_t last_avail_idx;
	} virtqs[];
};

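/*
 * Endianness conversion helpers: if the driver negotiated
 * VIRTIO_F_VERSION_1 the rings are little-endian, otherwise
 * values are passed through in native (legacy) byte order.
 */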
#define CONV(bits)							\
static inline uint##bits##_t __attribute__((used))			\
cpu_to_virtio##bits(struct usfstl_vhost_user_dev_int *dev,		\
		    uint##bits##_t v)					\
{									\
	if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1))		\
		return htole##bits(v);					\
	return v;							\
}									\
static inline uint##bits##_t __attribute__((used))			\
virtio_to_cpu##bits(struct usfstl_vhost_user_dev_int *dev,		\
		    uint##bits##_t v)					\
{									\
	if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1))		\
		return le##bits##toh(v);				\
	return v;							\
}

CONV(16)
CONV(32)
CONV(64)

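/*
 * Pull the next available descriptor chain off a virtqueue and turn
 * it into a scatter-gather buffer. The caller-provided (typically
 * stack-allocated) buffer is used when the chain fits; otherwise a
 * sufficiently large buffer is allocated and marked so that
 * usfstl_vhost_user_free_buf() will free it.
 */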
static struct usfstl_vhost_user_buf *
usfstl_vhost_user_get_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
				unsigned int virtq_idx,
				struct usfstl_vhost_user_buf *fixed)
{
	struct usfstl_vhost_user_buf *buf = fixed;
	struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
	uint16_t avail_idx = virtio_to_cpu16(dev, virtq->avail->idx);
	uint16_t idx, desc_idx;
	struct vring_desc *desc;
	unsigned int n_in = 0, n_out = 0;
	bool more;

	if (avail_idx == dev->virtqs[virtq_idx].last_avail_idx)
		return NULL;

	/* ensure we read the descriptor after checking the index */
	__sync_synchronize();

	idx = dev->virtqs[virtq_idx].last_avail_idx++;
	idx %= virtq->num;
	desc_idx = virtio_to_cpu16(dev, virtq->avail->ring[idx]);
	USFSTL_ASSERT(desc_idx < virtq->num);

	desc = &virtq->desc[desc_idx];
	do {
		more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;

		if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE)
			n_in++;
		else
			n_out++;
		desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
	} while (more);

	if (n_in > fixed->n_in_sg || n_out > fixed->n_out_sg) {
		size_t sz = sizeof(*buf);
		struct iovec *vec;

		sz += (n_in + n_out) * sizeof(*vec);

		buf = calloc(1, sz);
		if (!buf)
			return NULL;

		vec = (void *)(buf + 1);
		buf->in_sg = vec;
		buf->out_sg = vec + n_in;
		buf->allocated = true;
	}

	buf->n_in_sg = 0;
	buf->n_out_sg = 0;
	buf->idx = desc_idx;

	desc = &virtq->desc[desc_idx];
	do {
		struct iovec *vec;
		uint64_t addr;

		more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;

		if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE) {
			vec = &buf->in_sg[buf->n_in_sg];
			buf->n_in_sg++;
		} else {
			vec = &buf->out_sg[buf->n_out_sg];
			buf->n_out_sg++;
		}

		addr = virtio_to_cpu64(dev, desc->addr);
		vec->iov_base = usfstl_vhost_phys_to_va(&dev->ext, addr);
		vec->iov_len = virtio_to_cpu32(dev, desc->len);

		desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
	} while (more);

	return buf;
}

static void usfstl_vhost_user_free_buf(struct usfstl_vhost_user_buf *buf)
{
	if (buf->allocated)
		free(buf);
}

static void usfstl_vhost_user_readable_handler(struct usfstl_loop_entry *entry)
{
	usfstl_loop_unregister(entry);
	entry->fd = -1;
}

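/*
 * Read one vhost-user message from the socket: first the fixed-size
 * header (together with any SCM_RIGHTS control data), then exactly
 * hdr->size payload bytes, which land directly behind the header in
 * the first iovec's buffer. Returns 0 on success or a negative errno.
 */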
static int usfstl_vhost_user_read_msg(int fd, struct msghdr *msghdr)
{
	struct iovec msg_iov;
	struct msghdr hdr2 = {
		.msg_iov = &msg_iov,
		.msg_iovlen = 1,
		.msg_control = msghdr->msg_control,
		.msg_controllen = msghdr->msg_controllen,
	};
	struct vhost_user_msg_hdr *hdr;
	size_t i;
	size_t maxlen = 0;
	ssize_t len;
	ssize_t prev_datalen;
	size_t prev_iovlen;

	USFSTL_ASSERT(msghdr->msg_iovlen >= 1);
	USFSTL_ASSERT(msghdr->msg_iov[0].iov_len >= sizeof(*hdr));

	hdr = msghdr->msg_iov[0].iov_base;
	msg_iov.iov_base = hdr;
	msg_iov.iov_len = sizeof(*hdr);

	len = recvmsg(fd, &hdr2, 0);
	if (len < 0)
		return -errno;
	if (len == 0)
		return -ENOTCONN;

	for (i = 0; i < msghdr->msg_iovlen; i++)
		maxlen += msghdr->msg_iov[i].iov_len;
	maxlen -= sizeof(*hdr);

	USFSTL_ASSERT_EQ((int)len, (int)sizeof(*hdr), "%d");
	USFSTL_ASSERT(hdr->size <= maxlen);

	if (!hdr->size)
		return 0;

	prev_iovlen = msghdr->msg_iovlen;
	msghdr->msg_iovlen = 1;

	msghdr->msg_control = NULL;
	msghdr->msg_controllen = 0;
	msghdr->msg_iov[0].iov_base += sizeof(*hdr);
	prev_datalen = msghdr->msg_iov[0].iov_len;
	msghdr->msg_iov[0].iov_len = hdr->size;
	len = recvmsg(fd, msghdr, 0);

	/* restore just in case the user needs it */
	msghdr->msg_iov[0].iov_base -= sizeof(*hdr);
	msghdr->msg_iov[0].iov_len = prev_datalen;
	msghdr->msg_control = hdr2.msg_control;
	msghdr->msg_controllen = hdr2.msg_controllen;

	msghdr->msg_iovlen = prev_iovlen;

	if (len < 0)
		return -errno;
	if (len == 0)
		return -ENOTCONN;

	USFSTL_ASSERT_EQ(hdr->size, (uint32_t)len, "%u");

	return 0;
}

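/*
 * Send a message to the driver on the slave request fd. If the
 * driver negotiated VHOST_USER_PROTOCOL_F_REPLY_ACK, request a reply
 * and wait for it, handling other event loop (e.g. simulation time)
 * messages in the meantime.
 */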
static void usfstl_vhost_user_send_msg(struct usfstl_vhost_user_dev_int *dev,
				       struct vhost_user_msg *msg)
{
	size_t msgsz = sizeof(msg->hdr) + msg->hdr.size;
	bool ack = dev->ext.protocol_features &
		   (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
	ssize_t written;

	if (ack)
		msg->hdr.flags |= VHOST_USER_MSG_FLAGS_NEED_REPLY;

	written = write(dev->req_fd, msg, msgsz);
	USFSTL_ASSERT_EQ(written, (ssize_t)msgsz, "%zd");

	if (ack) {
		struct usfstl_loop_entry entry = {
			.fd = dev->req_fd,
			.priority = 0x7fffffff, // max
			.handler = usfstl_vhost_user_readable_handler,
		};
		struct iovec msg_iov = {
			.iov_base = msg,
			.iov_len = sizeof(*msg),
		};
		struct msghdr msghdr = {
			.msg_iovlen = 1,
			.msg_iov = &msg_iov,
		};

		/*
		 * Wait for the fd to be readable - we may have to
		 * handle other simulation (time) messages while
		 * waiting ...
		 */
		usfstl_loop_register(&entry);
		while (entry.fd != -1)
			usfstl_loop_wait_and_handle();
		USFSTL_ASSERT_EQ(usfstl_vhost_user_read_msg(dev->req_fd,
							    &msghdr),
				 0, "%d");
	}
}

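/*
 * Return a completed buffer to the driver: fill in the used ring
 * entry, publish the new used index and notify the driver via the
 * queue's call eventfd, or via an in-band VHOST_USER_SLAVE_VRING_CALL
 * message if no call fd was set up and the driver supports that.
 */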
static void usfstl_vhost_user_send_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
					     struct usfstl_vhost_user_buf *buf,
					     int virtq_idx)
{
	struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
	unsigned int idx, widx;
	int call_fd = dev->virtqs[virtq_idx].call_fd;
	ssize_t written;
	uint64_t e = 1;

	if (dev->ext.server->ctrl)
		usfstl_sched_ctrl_sync_to(dev->ext.server->ctrl);

	idx = virtio_to_cpu16(dev, virtq->used->idx);
	widx = idx + 1;

	idx %= virtq->num;
	virtq->used->ring[idx].id = cpu_to_virtio32(dev, buf->idx);
	virtq->used->ring[idx].len = cpu_to_virtio32(dev, buf->written);

	/* write buffers / used table before flush */
	__sync_synchronize();

	virtq->used->idx = cpu_to_virtio16(dev, widx);

	if (call_fd < 0 &&
	    dev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
	    dev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
		struct vhost_user_msg msg = {
			.hdr.request = VHOST_USER_SLAVE_VRING_CALL,
			.hdr.flags = VHOST_USER_VERSION,
			.hdr.size = sizeof(msg.payload.vring_state),
			.payload.vring_state = {
				.idx = virtq_idx,
			},
		};

		usfstl_vhost_user_send_msg(dev, &msg);
		return;
	}

	written = write(dev->virtqs[virtq_idx].call_fd, &e, sizeof(e));
	USFSTL_ASSERT_EQ(written, (ssize_t)sizeof(e), "%zd");
}

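/*
 * Drain a virtqueue: pass each pending buffer to the device
 * implementation's handle() callback, then return it to the driver
 * as used.
 */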
static void usfstl_vhost_user_handle_queue(struct usfstl_vhost_user_dev_int *dev,
					   unsigned int virtq_idx)
{
	/* preallocate on the stack for most cases */
	struct iovec in_sg[SG_STACK_PREALLOC] = { };
	struct iovec out_sg[SG_STACK_PREALLOC] = { };
	struct usfstl_vhost_user_buf _buf = {
		.in_sg = in_sg,
		.n_in_sg = SG_STACK_PREALLOC,
		.out_sg = out_sg,
		.n_out_sg = SG_STACK_PREALLOC,
	};
	struct usfstl_vhost_user_buf *buf;

	while ((buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf))) {
		dev->ext.server->ops->handle(&dev->ext, buf, virtq_idx);

		usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
		usfstl_vhost_user_free_buf(buf);
	}
}

static void usfstl_vhost_user_job_callback(struct usfstl_job *job)
{
	struct usfstl_vhost_user_dev_int *dev = job->data;
	unsigned int virtq;

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		if (!dev->virtqs[virtq].triggered)
			continue;
		dev->virtqs[virtq].triggered = false;

		usfstl_vhost_user_handle_queue(dev, virtq);
	}
}

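/*
 * React to a queue kick. Only queues marked as input queues are
 * handled; without a scheduler they are processed immediately,
 * otherwise the "IRQ" job is scheduled after the configured
 * interrupt latency.
 */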
static void usfstl_vhost_user_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
					 unsigned int virtq)
{
	if (!(dev->ext.server->input_queues & (1ULL << virtq)))
		return;

	dev->virtqs[virtq].triggered = true;

	if (usfstl_job_scheduled(&dev->irq_job))
		return;

	if (!dev->ext.server->scheduler) {
		usfstl_vhost_user_job_callback(&dev->irq_job);
		return;
	}

	if (dev->ext.server->ctrl)
		usfstl_sched_ctrl_sync_from(dev->ext.server->ctrl);

	dev->irq_job.start = usfstl_sched_current_time(dev->ext.server->scheduler) +
			     dev->ext.server->interrupt_latency;
	usfstl_sched_add_job(dev->ext.server->scheduler, &dev->irq_job);
}

static void usfstl_vhost_user_virtq_fdkick(struct usfstl_loop_entry *entry)
{
	struct usfstl_vhost_user_dev_int *dev = entry->data;
	unsigned int virtq;
	uint64_t v;

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		if (entry == &dev->virtqs[virtq].entry)
			break;
	}

	USFSTL_ASSERT(virtq < dev->ext.server->max_queues);

	USFSTL_ASSERT_EQ((int)read(entry->fd, &v, sizeof(v)),
		       (int)sizeof(v), "%d");

	usfstl_vhost_user_virtq_kick(dev, virtq);
}

static void usfstl_vhost_user_clear_mappings(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int idx;
	for (idx = 0; idx < MAX_REGIONS; idx++) {
		if (dev->region_vaddr[idx]) {
			munmap(dev->region_vaddr[idx],
			       dev->regions[idx].size + dev->regions[idx].mmap_offset);
			dev->region_vaddr[idx] = NULL;
		}

		if (dev->region_fds[idx] != -1) {
			close(dev->region_fds[idx]);
			dev->region_fds[idx] = -1;
		}
	}
}

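/*
 * mmap() all memory regions received via VHOST_USER_SET_MEM_TABLE.
 * Each fd is mapped from offset 0 for size + mmap_offset bytes since
 * the offset need not be page-aligned; the offset is applied during
 * address translation instead.
 */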
static void usfstl_vhost_user_setup_mappings(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int idx;

	for (idx = 0; idx < dev->n_regions; idx++) {
		USFSTL_ASSERT(!dev->region_vaddr[idx]);

		/*
		 * Cannot rely on the offset being page-aligned, I think ...
		 * adjust for it later when we translate addresses instead.
		 */
		dev->region_vaddr[idx] = mmap(NULL,
					      dev->regions[idx].size +
					      dev->regions[idx].mmap_offset,
					      PROT_READ | PROT_WRITE, MAP_SHARED,
					      dev->region_fds[idx], 0);
		USFSTL_ASSERT(dev->region_vaddr[idx] != (void *)-1,
			      "mmap() failed (%d) for fd %d", errno, dev->region_fds[idx]);
	}
}

static void
usfstl_vhost_user_update_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
				  unsigned int virtq, int fd)
{
	if (dev->virtqs[virtq].entry.fd != -1) {
		usfstl_loop_unregister(&dev->virtqs[virtq].entry);
		close(dev->virtqs[virtq].entry.fd);
	}

	if (fd != -1) {
		dev->virtqs[virtq].entry.fd = fd;
		usfstl_loop_register(&dev->virtqs[virtq].entry);
	}
}

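/*
 * Tear down a device instance: unregister it from the event loop,
 * cancel any pending IRQ job, close all kick/call/region/request
 * fds, unmap the memory regions and tell the device implementation
 * that the driver disconnected.
 */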
static void usfstl_vhost_user_dev_free(struct usfstl_vhost_user_dev_int *dev)
{
	unsigned int virtq;

	usfstl_loop_unregister(&dev->entry);
	usfstl_sched_del_job(&dev->irq_job);

	for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
		usfstl_vhost_user_update_virtq_kick(dev, virtq, -1);
		if (dev->virtqs[virtq].call_fd != -1)
			close(dev->virtqs[virtq].call_fd);
	}

	usfstl_vhost_user_clear_mappings(dev);

	if (dev->req_fd != -1)
		close(dev->req_fd);

	if (dev->ext.server->ops->disconnected)
		dev->ext.server->ops->disconnected(&dev->ext);

	if (dev->entry.fd != -1)
		close(dev->entry.fd);

	free(dev);
}

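/*
 * Copy up to max_fds file descriptors passed as SCM_RIGHTS ancillary
 * data out of a received message.
 */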
static void usfstl_vhost_user_get_msg_fds(struct msghdr *msghdr,
					  int *outfds, int max_fds)
{
	struct cmsghdr *msg;
	int fds;

	for (msg = CMSG_FIRSTHDR(msghdr); msg; msg = CMSG_NXTHDR(msghdr, msg)) {
		if (msg->cmsg_level != SOL_SOCKET)
			continue;
		if (msg->cmsg_type != SCM_RIGHTS)
			continue;

		fds = (msg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
		USFSTL_ASSERT(fds <= max_fds);
		memcpy(outfds, CMSG_DATA(msg), fds * sizeof(int));
		break;
	}
}

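/*
 * Main vhost-user message dispatcher: read one request from the
 * driver socket, act on it, and send a reply where the request (or a
 * REPLY_ACK flag) demands one. A failed read means the driver went
 * away, so the device is freed.
 */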
static void usfstl_vhost_user_handle_msg(struct usfstl_loop_entry *entry)
{
	struct usfstl_vhost_user_dev_int *dev;
	struct vhost_user_msg msg;
	uint8_t data[256]; // limits the config space size ...
	struct iovec msg_iov[3] = {
		[0] = {
			.iov_base = &msg.hdr,
			.iov_len = sizeof(msg.hdr),
		},
		[1] = {
			.iov_base = &msg.payload,
			.iov_len = sizeof(msg.payload),
		},
		[2] = {
			.iov_base = data,
			.iov_len = sizeof(data),
		},
	};
	uint8_t msg_control[CMSG_SPACE(sizeof(int) * MAX_REGIONS)] = { 0 };
	struct msghdr msghdr = {
		.msg_iov = msg_iov,
		.msg_iovlen = 3,
		.msg_control = msg_control,
		.msg_controllen = sizeof(msg_control),
	};
	ssize_t len;
	size_t reply_len = 0;
	unsigned int virtq;
	int fd;

	dev = container_of(entry, struct usfstl_vhost_user_dev_int, entry);

	if (usfstl_vhost_user_read_msg(entry->fd, &msghdr)) {
		usfstl_vhost_user_dev_free(dev);
		return;
	}
	len = msg.hdr.size;

	USFSTL_ASSERT((msg.hdr.flags & VHOST_USER_MSG_FLAGS_VERSION) ==
		    VHOST_USER_VERSION);

	switch (msg.hdr.request) {
	case VHOST_USER_GET_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = dev->ext.server->features;
		msg.payload.u64 |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
		break;
	case VHOST_USER_SET_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)sizeof(msg.payload.u64), "%zd");
		dev->ext.features = msg.payload.u64;
		break;
	case VHOST_USER_SET_OWNER:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		/* nothing to be done */
		break;
	case VHOST_USER_SET_MEM_TABLE:
		USFSTL_ASSERT(len <= (int)sizeof(msg.payload.mem_regions));
		USFSTL_ASSERT(msg.payload.mem_regions.n_regions <= MAX_REGIONS);
		usfstl_vhost_user_clear_mappings(dev);
		memcpy(dev->regions, msg.payload.mem_regions.regions,
		       msg.payload.mem_regions.n_regions *
		       sizeof(dev->regions[0]));
		dev->n_regions = msg.payload.mem_regions.n_regions;
		usfstl_vhost_user_get_msg_fds(&msghdr, dev->region_fds, MAX_REGIONS);
		usfstl_vhost_user_setup_mappings(dev);
		break;
	case VHOST_USER_SET_VRING_NUM:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_state.idx].virtq.num =
			msg.payload.vring_state.num;
		break;
	case VHOST_USER_SET_VRING_ADDR:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_addr));
		USFSTL_ASSERT(msg.payload.vring_addr.idx <
			      dev->ext.server->max_queues);
		USFSTL_ASSERT_EQ(msg.payload.vring_addr.flags, (uint32_t)0, "0x%x");
		USFSTL_ASSERT(!dev->virtqs[msg.payload.vring_addr.idx].enabled);
		dev->virtqs[msg.payload.vring_addr.idx].last_avail_idx = 0;
		dev->virtqs[msg.payload.vring_addr.idx].virtq.desc =
			usfstl_vhost_user_to_va(&dev->ext,
					      msg.payload.vring_addr.descriptor);
		dev->virtqs[msg.payload.vring_addr.idx].virtq.used =
			usfstl_vhost_user_to_va(&dev->ext,
					      msg.payload.vring_addr.used);
		dev->virtqs[msg.payload.vring_addr.idx].virtq.avail =
			usfstl_vhost_user_to_va(&dev->ext,
					      msg.payload.vring_addr.avail);
		USFSTL_ASSERT(dev->virtqs[msg.payload.vring_addr.idx].virtq.avail &&
			    dev->virtqs[msg.payload.vring_addr.idx].virtq.desc &&
			    dev->virtqs[msg.payload.vring_addr.idx].virtq.used);
		break;
	case VHOST_USER_SET_VRING_BASE:
		/* ignored - logging not supported */
		/*
		 * FIXME: our Linux UML virtio implementation
		 *        shouldn't send this
		 */
		break;
	case VHOST_USER_SET_VRING_KICK:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
		USFSTL_ASSERT(virtq < dev->ext.server->max_queues);
		if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
			fd = -1;
		else
			usfstl_vhost_user_get_msg_fds(&msghdr, &fd, 1);
		usfstl_vhost_user_update_virtq_kick(dev, virtq, fd);
		break;
	case VHOST_USER_SET_VRING_CALL:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
		USFSTL_ASSERT(virtq < dev->ext.server->max_queues);
		if (dev->virtqs[virtq].call_fd != -1)
			close(dev->virtqs[virtq].call_fd);
		if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
			dev->virtqs[virtq].call_fd = -1;
		else
			usfstl_vhost_user_get_msg_fds(&msghdr,
						    &dev->virtqs[virtq].call_fd,
						    1);
		break;
	case VHOST_USER_GET_PROTOCOL_FEATURES:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = dev->ext.server->protocol_features;
		if (dev->ext.server->config && dev->ext.server->config_len)
			msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD;
		msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK;
		break;
	case VHOST_USER_SET_VRING_ENABLE:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		dev->virtqs[msg.payload.vring_state.idx].enabled =
			msg.payload.vring_state.num;
		break;
	case VHOST_USER_SET_PROTOCOL_FEATURES:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
		dev->ext.protocol_features = msg.payload.u64;
		break;
	case VHOST_USER_SET_SLAVE_REQ_FD:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		if (dev->req_fd != -1)
			close(dev->req_fd);
		usfstl_vhost_user_get_msg_fds(&msghdr, &dev->req_fd, 1);
		USFSTL_ASSERT(dev->req_fd != -1);
		break;
	case VHOST_USER_GET_CONFIG:
		USFSTL_ASSERT(len == (int)(sizeof(msg.payload.cfg_space) +
					msg.payload.cfg_space.size));
		USFSTL_ASSERT(dev->ext.server->config && dev->ext.server->config_len);
		USFSTL_ASSERT(msg.payload.cfg_space.offset == 0);
		USFSTL_ASSERT(msg.payload.cfg_space.size <= dev->ext.server->config_len);
		msg.payload.cfg_space.flags = 0;
		msg_iov[1].iov_len = sizeof(msg.payload.cfg_space);
		msg_iov[2].iov_base = (void *)dev->ext.server->config;
		reply_len = len;
		break;
	case VHOST_USER_VRING_KICK:
		USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
		USFSTL_ASSERT(msg.payload.vring_state.idx <
			      dev->ext.server->max_queues);
		USFSTL_ASSERT(msg.payload.vring_state.num == 0);
		usfstl_vhost_user_virtq_kick(dev, msg.payload.vring_state.idx);
		break;
	case VHOST_USER_GET_SHARED_MEMORY_REGIONS:
		USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
		reply_len = sizeof(uint64_t);
		msg.payload.u64 = 0;
		break;
	default:
		USFSTL_ASSERT(0, "Unsupported message: %d\n", msg.hdr.request);
	}

	if (reply_len || (msg.hdr.flags & VHOST_USER_MSG_FLAGS_NEED_REPLY)) {
		size_t i, tmp;

		if (!reply_len) {
			msg.payload.u64 = 0;
			reply_len = sizeof(uint64_t);
		}

		msg.hdr.size = reply_len;
		msg.hdr.flags &= ~VHOST_USER_MSG_FLAGS_NEED_REPLY;
		msg.hdr.flags |= VHOST_USER_MSG_FLAGS_REPLY;

		msghdr.msg_control = NULL;
		msghdr.msg_controllen = 0;

		reply_len += sizeof(msg.hdr);

		tmp = reply_len;
		for (i = 0; tmp && i < msghdr.msg_iovlen; i++) {
			if (tmp <= msg_iov[i].iov_len)
				msg_iov[i].iov_len = tmp;
			tmp -= msg_iov[i].iov_len;
		}
		msghdr.msg_iovlen = i;

		while (reply_len) {
			len = sendmsg(entry->fd, &msghdr, 0);
			if (len < 0) {
				usfstl_vhost_user_dev_free(dev);
				return;
			}
			USFSTL_ASSERT(len != 0);
			reply_len -= len;

			for (i = 0; len && i < msghdr.msg_iovlen; i++) {
				unsigned int rm = len;

				if (msg_iov[i].iov_len <= (size_t)len)
					rm = msg_iov[i].iov_len;
				len -= rm;
				msg_iov[i].iov_len -= rm;
				msg_iov[i].iov_base += rm;
			}
		}
	}
}

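/*
 * Accept callback for the listening socket: allocate a device
 * instance for the new driver connection, initialise its queues and
 * IRQ job, and register the connection with the event loop.
 */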
static void usfstl_vhost_user_connected(int fd, void *data)
{
	struct usfstl_vhost_user_server *server = data;
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int i;

	dev = calloc(1, sizeof(*dev) +
			sizeof(dev->virtqs[0]) * server->max_queues);

	USFSTL_ASSERT(dev);

	for (i = 0; i < server->max_queues; i++) {
		dev->virtqs[i].call_fd = -1;
		dev->virtqs[i].entry.fd = -1;
		dev->virtqs[i].entry.data = dev;
		dev->virtqs[i].entry.handler = usfstl_vhost_user_virtq_fdkick;
	}

	for (i = 0; i < MAX_REGIONS; i++)
		dev->region_fds[i] = -1;
	dev->req_fd = -1;

	dev->ext.server = server;
	dev->irq_job.data = dev;
	dev->irq_job.name = "vhost-user-irq";
	dev->irq_job.priority = 0x10000000;
	dev->irq_job.callback = usfstl_vhost_user_job_callback;
	usfstl_list_init(&dev->fds);

	if (server->ops->connected)
		server->ops->connected(&dev->ext);

	dev->entry.fd = fd;
	dev->entry.handler = usfstl_vhost_user_handle_msg;

	usfstl_loop_register(&dev->entry);
}

void usfstl_vhost_user_server_start(struct usfstl_vhost_user_server *server)
{
	USFSTL_ASSERT(server->ops);
	USFSTL_ASSERT(server->socket);

	usfstl_uds_create(server->socket, usfstl_vhost_user_connected, server);
}

void usfstl_vhost_user_server_stop(struct usfstl_vhost_user_server *server)
{
	usfstl_uds_remove(server->socket);
}

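/*
 * Push data to the driver on a device-to-driver queue: take the next
 * available buffer, copy the data into its device-writable
 * scatter-gather list and complete it. The data is silently dropped
 * if the queue is disabled or no buffer is available.
 */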
void usfstl_vhost_user_dev_notify(struct usfstl_vhost_user_dev *extdev,
				  unsigned int virtq_idx,
				  const uint8_t *data, size_t datalen)
{
	struct usfstl_vhost_user_dev_int *dev;
	/* preallocate on the stack for most cases */
	struct iovec in_sg[SG_STACK_PREALLOC] = { };
	struct iovec out_sg[SG_STACK_PREALLOC] = { };
	struct usfstl_vhost_user_buf _buf = {
		.in_sg = in_sg,
		.n_in_sg = SG_STACK_PREALLOC,
		.out_sg = out_sg,
		.n_out_sg = SG_STACK_PREALLOC,
	};
	struct usfstl_vhost_user_buf *buf;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	USFSTL_ASSERT(virtq_idx < dev->ext.server->max_queues);

	if (!dev->virtqs[virtq_idx].enabled)
		return;

	buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf);
	if (!buf)
		return;

	USFSTL_ASSERT(buf->n_in_sg && !buf->n_out_sg);
	iov_fill(buf->in_sg, buf->n_in_sg, data, datalen);
	buf->written = datalen;

	usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
	usfstl_vhost_user_free_buf(buf);
}

void usfstl_vhost_user_config_changed(struct usfstl_vhost_user_dev *dev)
{
	struct usfstl_vhost_user_dev_int *idev;
	struct vhost_user_msg msg = {
		.hdr.request = VHOST_USER_SLAVE_CONFIG_CHANGE_MSG,
		.hdr.flags = VHOST_USER_VERSION,
	};

	idev = container_of(dev, struct usfstl_vhost_user_dev_int, ext);

	if (!(idev->ext.protocol_features &
			(1ULL << VHOST_USER_PROTOCOL_F_CONFIG)))
		return;

	usfstl_vhost_user_send_msg(idev, &msg);
}

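/*
 * Translate a driver (user-space) address to a local pointer using
 * the mapped memory regions.
 */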
void *usfstl_vhost_user_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].user_addr &&
		    addr < dev->regions[region].user_addr +
			   dev->regions[region].size)
			return (uint8_t *)dev->region_vaddr[region] +
			       (addr -
				dev->regions[region].user_addr +
				dev->regions[region].mmap_offset);
	}

	USFSTL_ASSERT(0, "cannot translate user address %"PRIx64"\n", addr);
	return NULL;
}

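/*
 * Translate a guest-physical address (as found in vring descriptors)
 * to a local pointer using the mapped memory regions.
 */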
void *usfstl_vhost_phys_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
{
	struct usfstl_vhost_user_dev_int *dev;
	unsigned int region;

	dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);

	for (region = 0; region < dev->n_regions; region++) {
		if (addr >= dev->regions[region].guest_phys_addr &&
		    addr < dev->regions[region].guest_phys_addr +
			   dev->regions[region].size)
			return (uint8_t *)dev->region_vaddr[region] +
			       (addr -
				dev->regions[region].guest_phys_addr +
				dev->regions[region].mmap_offset);
	}

	USFSTL_ASSERT(0, "cannot translate physical address %"PRIx64"\n", addr);
	return NULL;
}

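/* Total number of bytes covered by a scatter-gather list. */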
size_t iov_len(struct iovec *sg, unsigned int nsg)
{
	size_t len = 0;
	unsigned int i;

	for (i = 0; i < nsg; i++)
		len += sg[i].iov_len;

	return len;
}

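/*
 * Copy a linear buffer into a scatter-gather list; returns the
 * number of bytes actually copied.
 */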
size_t iov_fill(struct iovec *sg, unsigned int nsg,
		const void *_buf, size_t buflen)
{
	const char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

#define min(a, b) ({ typeof(a) _a = (a); typeof(b) _b = (b); _a < _b ? _a : _b; })
	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = min(buflen, sg[i].iov_len);

		memcpy(sg[i].iov_base, buf, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}

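/*
 * Copy data out of a scatter-gather list into a linear buffer;
 * returns the number of bytes actually copied.
 */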
size_t iov_read(void *_buf, size_t buflen,
		struct iovec *sg, unsigned int nsg)
{
	char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

#define min(a, b) ({ typeof(a) _a = (a); typeof(b) _b = (b); _a < _b ? _a : _b; })
	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = min(buflen, sg[i].iov_len);

		memcpy(buf, sg[i].iov_base, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}