1 /*
2 * Copyright (C) 2019 - 2020 Intel Corporation
3 *
4 * SPDX-License-Identifier: BSD-3-Clause
5 */
6 #include <stdlib.h>
7 #include <usfstl/list.h>
8 #include <usfstl/loop.h>
9 #include <usfstl/uds.h>
10 #include <sys/socket.h>
11 #include <sys/mman.h>
12 #include <sys/un.h>
13 #include <stdlib.h>
14 #include <errno.h>
15 #include <usfstl/vhost.h>
16 #include <linux/virtio_ring.h>
17 #include <linux/virtio_config.h>
18 #include <endian.h>
19
/*
 * Copied from the Linux uapi header (linux/virtio_config.h), presumably so
 * the file also builds against older headers that lack it.
 */
#define VIRTIO_F_VERSION_1 32

/* maximum number of guest memory regions accepted via SET_MEM_TABLE */
#define MAX_REGIONS 8
/* iovec entries per direction preallocated on the stack for one buffer */
#define SG_STACK_PREALLOC 5
25
/*
 * Internal per-connection device state. The user-visible device is
 * embedded as 'ext'; container_of() on it recovers this structure.
 */
struct usfstl_vhost_user_dev_int {
	/* list head, initialized on connect; not otherwise used in this file */
	struct usfstl_list fds;
	/* scheduler job that processes kicked ("triggered") virtqueues */
	struct usfstl_job irq_job;

	/* event loop entry for the vhost-user control socket */
	struct usfstl_loop_entry entry;

	/* user-visible part of the device */
	struct usfstl_vhost_user_dev ext;

	/* guest memory regions from VHOST_USER_SET_MEM_TABLE */
	unsigned int n_regions;
	struct vhost_user_region regions[MAX_REGIONS];
	/* region file descriptors, -1 when unused */
	int region_fds[MAX_REGIONS];
	/* local mappings of the regions, NULL when unmapped */
	void *region_vaddr[MAX_REGIONS];

	/* slave request channel from VHOST_USER_SET_SLAVE_REQ_FD, -1 if none */
	int req_fd;

	/* per-virtqueue state; server->max_queues entries allocated on connect */
	struct {
		/* loop entry for the kick fd; entry.fd is -1 when there is none */
		struct usfstl_loop_entry entry;
		bool enabled;
		/* was enabled when VHOST_USER_SLEEP arrived; re-enabled on WAKE */
		bool sleeping;
		/* kick pending, to be processed by irq_job */
		bool triggered;
		/* guest-physical ring addresses, kept for snapshot/restore */
		uint64_t desc_guest_addr;
		uint64_t avail_guest_addr;
		uint64_t used_guest_addr;
		/* ring pointers translated into our address space */
		struct vring virtq;
		/* fd used to notify the guest, -1 if none */
		int call_fd;
		/* next avail ring index to consume */
		uint16_t last_avail_idx;
	} virtqs[];
};
54
55 #define CONV(bits) \
56 static inline uint##bits##_t __attribute__((used)) \
57 cpu_to_virtio##bits(struct usfstl_vhost_user_dev_int *dev, \
58 uint##bits##_t v) \
59 { \
60 if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1)) \
61 return htole##bits(v); \
62 return v; \
63 } \
64 static inline uint##bits##_t __attribute__((used)) \
65 virtio_to_cpu##bits(struct usfstl_vhost_user_dev_int *dev, \
66 uint##bits##_t v) \
67 { \
68 if (dev->ext.features & (1ULL << VIRTIO_F_VERSION_1)) \
69 return le##bits##toh(v); \
70 return v; \
71 }
72
73 CONV(16)
74 CONV(32)
75 CONV(64)
76
77 static struct usfstl_vhost_user_buf *
usfstl_vhost_user_get_virtq_buf(struct usfstl_vhost_user_dev_int * dev,unsigned int virtq_idx,struct usfstl_vhost_user_buf * fixed)78 usfstl_vhost_user_get_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
79 unsigned int virtq_idx,
80 struct usfstl_vhost_user_buf *fixed)
81 {
82 struct usfstl_vhost_user_buf *buf = fixed;
83 struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
84 uint16_t avail_idx = virtio_to_cpu16(dev, virtq->avail->idx);
85 uint16_t idx, desc_idx;
86 struct vring_desc *desc;
87 unsigned int n_in = 0, n_out = 0;
88 bool more;
89
90 if (avail_idx == dev->virtqs[virtq_idx].last_avail_idx)
91 return NULL;
92
93 /* ensure we read the descriptor after checking the index */
94 __sync_synchronize();
95
96 idx = dev->virtqs[virtq_idx].last_avail_idx++;
97 idx %= virtq->num;
98 desc_idx = virtio_to_cpu16(dev, virtq->avail->ring[idx]);
99 USFSTL_ASSERT(desc_idx < virtq->num);
100
101 desc = &virtq->desc[desc_idx];
102 do {
103 more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;
104
105 if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE)
106 n_in++;
107 else
108 n_out++;
109 desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
110 } while (more);
111
112 if (n_in > fixed->n_in_sg || n_out > fixed->n_out_sg) {
113 size_t sz = sizeof(*buf);
114 struct iovec *vec;
115
116 sz += (n_in + n_out) * sizeof(*vec);
117
118 buf = calloc(1, sz);
119 if (!buf)
120 return NULL;
121
122 vec = (void *)(buf + 1);
123 buf->in_sg = vec;
124 buf->out_sg = vec + n_in;
125 buf->allocated = true;
126 }
127
128 buf->n_in_sg = 0;
129 buf->n_out_sg = 0;
130 buf->idx = desc_idx;
131
132 desc = &virtq->desc[desc_idx];
133 do {
134 struct iovec *vec;
135 uint64_t addr;
136
137 more = virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_NEXT;
138
139 if (virtio_to_cpu16(dev, desc->flags) & VRING_DESC_F_WRITE) {
140 vec = &buf->in_sg[buf->n_in_sg];
141 buf->n_in_sg++;
142 } else {
143 vec = &buf->out_sg[buf->n_out_sg];
144 buf->n_out_sg++;
145 }
146
147 addr = virtio_to_cpu64(dev, desc->addr);
148 vec->iov_base = usfstl_vhost_phys_to_va(&dev->ext, addr);
149 vec->iov_len = virtio_to_cpu32(dev, desc->len);
150
151 desc = &virtq->desc[virtio_to_cpu16(dev, desc->next)];
152 } while (more);
153
154 return buf;
155 }
156
usfstl_vhost_user_free_buf(struct usfstl_vhost_user_buf * buf)157 static void usfstl_vhost_user_free_buf(struct usfstl_vhost_user_buf *buf)
158 {
159 if (buf->allocated)
160 free(buf);
161 }
162
/*
 * One-shot loop handler used by usfstl_vhost_user_send_msg() while it waits
 * for a reply: when invoked for the fd, unregister the entry and set
 * entry->fd to -1, which is the condition that wait loop checks.
 */
static void usfstl_vhost_user_readable_handler(struct usfstl_loop_entry *entry)
{
	usfstl_loop_unregister(entry);
	/* must come after unregistering; signals the waiter */
	entry->fd = -1;
}
168
usfstl_vhost_user_read_msg(int fd,struct msghdr * msghdr)169 static int usfstl_vhost_user_read_msg(int fd, struct msghdr *msghdr)
170 {
171 struct iovec msg_iov;
172 struct msghdr hdr2 = {
173 .msg_iov = &msg_iov,
174 .msg_iovlen = 1,
175 .msg_control = msghdr->msg_control,
176 .msg_controllen = msghdr->msg_controllen,
177 };
178 struct vhost_user_msg_hdr *hdr;
179 size_t i;
180 size_t maxlen = 0;
181 ssize_t len;
182 ssize_t prev_datalen;
183 size_t prev_iovlen;
184
185 USFSTL_ASSERT(msghdr->msg_iovlen >= 1);
186 USFSTL_ASSERT(msghdr->msg_iov[0].iov_len >= sizeof(*hdr));
187
188 hdr = msghdr->msg_iov[0].iov_base;
189 msg_iov.iov_base = hdr;
190 msg_iov.iov_len = sizeof(*hdr);
191
192 len = recvmsg(fd, &hdr2, 0);
193 if (len < 0)
194 return -errno;
195 if (len == 0)
196 return -ENOTCONN;
197
198 for (i = 0; i < msghdr->msg_iovlen; i++)
199 maxlen += msghdr->msg_iov[i].iov_len;
200 maxlen -= sizeof(*hdr);
201
202 USFSTL_ASSERT_EQ((int)len, (int)sizeof(*hdr), "%d");
203 USFSTL_ASSERT(hdr->size <= maxlen);
204
205 if (!hdr->size)
206 return 0;
207
208 prev_iovlen = msghdr->msg_iovlen;
209 msghdr->msg_iovlen = 1;
210
211 msghdr->msg_control = NULL;
212 msghdr->msg_controllen = 0;
213 msghdr->msg_iov[0].iov_base += sizeof(*hdr);
214 prev_datalen = msghdr->msg_iov[0].iov_len;
215 msghdr->msg_iov[0].iov_len = hdr->size;
216 len = recvmsg(fd, msghdr, 0);
217
218 /* restore just in case the user needs it */
219 msghdr->msg_iov[0].iov_base -= sizeof(*hdr);
220 msghdr->msg_iov[0].iov_len = prev_datalen;
221 msghdr->msg_control = hdr2.msg_control;
222 msghdr->msg_controllen = hdr2.msg_controllen;
223
224 msghdr->msg_iovlen = prev_iovlen;
225
226 if (len < 0)
227 return -errno;
228 if (len == 0)
229 return -ENOTCONN;
230
231 USFSTL_ASSERT_EQ(hdr->size, (uint32_t)len, "%u");
232
233 return 0;
234 }
235
usfstl_vhost_user_send_msg(struct usfstl_vhost_user_dev_int * dev,struct vhost_user_msg * msg)236 static void usfstl_vhost_user_send_msg(struct usfstl_vhost_user_dev_int *dev,
237 struct vhost_user_msg *msg)
238 {
239 size_t msgsz = sizeof(msg->hdr) + msg->hdr.size;
240 bool ack = dev->ext.protocol_features &
241 (1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK);
242 ssize_t written;
243
244 if (ack)
245 msg->hdr.flags |= VHOST_USER_MSG_FLAGS_NEED_REPLY;
246
247 written = write(dev->req_fd, msg, msgsz);
248 USFSTL_ASSERT_EQ(written, (ssize_t)msgsz, "%zd");
249
250 if (ack) {
251 struct usfstl_loop_entry entry = {
252 .fd = dev->req_fd,
253 .priority = 0x7fffffff, // max
254 .handler = usfstl_vhost_user_readable_handler,
255 };
256 struct iovec msg_iov = {
257 .iov_base = msg,
258 .iov_len = sizeof(*msg),
259 };
260 struct msghdr msghdr = {
261 .msg_iovlen = 1,
262 .msg_iov = &msg_iov,
263 };
264
265 /*
266 * Wait for the fd to be readable - we may have to
267 * handle other simulation (time) messages while
268 * waiting ...
269 */
270 usfstl_loop_register(&entry);
271 while (entry.fd != -1)
272 usfstl_loop_wait_and_handle();
273 USFSTL_ASSERT_EQ(usfstl_vhost_user_read_msg(dev->req_fd,
274 &msghdr),
275 0, "%d");
276 }
277 }
278
usfstl_vhost_user_send_virtq_buf(struct usfstl_vhost_user_dev_int * dev,struct usfstl_vhost_user_buf * buf,int virtq_idx)279 static void usfstl_vhost_user_send_virtq_buf(struct usfstl_vhost_user_dev_int *dev,
280 struct usfstl_vhost_user_buf *buf,
281 int virtq_idx)
282 {
283 struct vring *virtq = &dev->virtqs[virtq_idx].virtq;
284 unsigned int idx, widx;
285 int call_fd = dev->virtqs[virtq_idx].call_fd;
286 ssize_t written;
287 uint64_t e = 1;
288
289 if (dev->ext.server->ctrl)
290 usfstl_sched_ctrl_sync_to(dev->ext.server->ctrl);
291
292 idx = virtio_to_cpu16(dev, virtq->used->idx);
293 widx = idx + 1;
294
295 idx %= virtq->num;
296 virtq->used->ring[idx].id = cpu_to_virtio32(dev, buf->idx);
297 virtq->used->ring[idx].len = cpu_to_virtio32(dev, buf->written);
298
299 /* write buffers / used table before flush */
300 __sync_synchronize();
301
302 virtq->used->idx = cpu_to_virtio16(dev, widx);
303
304 if (call_fd < 0 &&
305 dev->ext.protocol_features &
306 (1ULL << VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
307 dev->ext.protocol_features &
308 (1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
309 struct vhost_user_msg msg = {
310 .hdr.request = VHOST_USER_SLAVE_VRING_CALL,
311 .hdr.flags = VHOST_USER_VERSION,
312 .hdr.size = sizeof(msg.payload.vring_state),
313 .payload.vring_state = {
314 .idx = virtq_idx,
315 },
316 };
317
318 usfstl_vhost_user_send_msg(dev, &msg);
319 return;
320 }
321
322 written = write(dev->virtqs[virtq_idx].call_fd, &e, sizeof(e));
323 USFSTL_ASSERT_EQ(written, (ssize_t)sizeof(e), "%zd");
324 }
325
usfstl_vhost_user_handle_queue(struct usfstl_vhost_user_dev_int * dev,unsigned int virtq_idx)326 static void usfstl_vhost_user_handle_queue(struct usfstl_vhost_user_dev_int *dev,
327 unsigned int virtq_idx)
328 {
329 /* preallocate on the stack for most cases */
330 struct iovec in_sg[SG_STACK_PREALLOC] = { };
331 struct iovec out_sg[SG_STACK_PREALLOC] = { };
332 struct usfstl_vhost_user_buf _buf = {
333 .in_sg = in_sg,
334 .n_in_sg = SG_STACK_PREALLOC,
335 .out_sg = out_sg,
336 .n_out_sg = SG_STACK_PREALLOC,
337 };
338 struct usfstl_vhost_user_buf *buf;
339
340 while ((buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf))) {
341 dev->ext.server->ops->handle(&dev->ext, buf, virtq_idx);
342
343 usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
344 usfstl_vhost_user_free_buf(buf);
345 }
346 }
347
usfstl_vhost_user_job_callback(struct usfstl_job * job)348 static void usfstl_vhost_user_job_callback(struct usfstl_job *job)
349 {
350 struct usfstl_vhost_user_dev_int *dev = job->data;
351 unsigned int virtq;
352
353 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
354 if (!dev->virtqs[virtq].triggered)
355 continue;
356 dev->virtqs[virtq].triggered = false;
357
358 usfstl_vhost_user_handle_queue(dev, virtq);
359 }
360 }
361
usfstl_vhost_user_virtq_kick(struct usfstl_vhost_user_dev_int * dev,unsigned int virtq)362 static void usfstl_vhost_user_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
363 unsigned int virtq)
364 {
365 if (!(dev->ext.server->input_queues & (1ULL << virtq)))
366 return;
367
368 dev->virtqs[virtq].triggered = true;
369
370 if (usfstl_job_scheduled(&dev->irq_job))
371 return;
372
373 if (!dev->ext.server->scheduler) {
374 usfstl_vhost_user_job_callback(&dev->irq_job);
375 return;
376 }
377
378 if (dev->ext.server->ctrl)
379 usfstl_sched_ctrl_sync_from(dev->ext.server->ctrl);
380
381 dev->irq_job.start = usfstl_sched_current_time(dev->ext.server->scheduler) +
382 dev->ext.server->interrupt_latency;
383 usfstl_sched_add_job(dev->ext.server->scheduler, &dev->irq_job);
384 }
385
usfstl_vhost_user_virtq_fdkick(struct usfstl_loop_entry * entry)386 static void usfstl_vhost_user_virtq_fdkick(struct usfstl_loop_entry *entry)
387 {
388 struct usfstl_vhost_user_dev_int *dev = entry->data;
389 unsigned int virtq;
390 uint64_t v;
391
392 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
393 if (entry == &dev->virtqs[virtq].entry)
394 break;
395 }
396
397 USFSTL_ASSERT(virtq < dev->ext.server->max_queues);
398
399 USFSTL_ASSERT_EQ((int)read(entry->fd, &v, sizeof(v)),
400 (int)sizeof(v), "%d");
401
402 usfstl_vhost_user_virtq_kick(dev, virtq);
403 }
404
usfstl_vhost_user_clear_mappings(struct usfstl_vhost_user_dev_int * dev)405 static void usfstl_vhost_user_clear_mappings(struct usfstl_vhost_user_dev_int *dev)
406 {
407 unsigned int idx;
408 for (idx = 0; idx < MAX_REGIONS; idx++) {
409 if (dev->region_vaddr[idx]) {
410 munmap(dev->region_vaddr[idx],
411 dev->regions[idx].size + dev->regions[idx].mmap_offset);
412 dev->region_vaddr[idx] = NULL;
413 }
414
415 if (dev->region_fds[idx] != -1) {
416 close(dev->region_fds[idx]);
417 dev->region_fds[idx] = -1;
418 }
419 }
420 }
421
usfstl_vhost_user_setup_mappings(struct usfstl_vhost_user_dev_int * dev)422 static void usfstl_vhost_user_setup_mappings(struct usfstl_vhost_user_dev_int *dev)
423 {
424 unsigned int idx;
425
426 for (idx = 0; idx < dev->n_regions; idx++) {
427 USFSTL_ASSERT(!dev->region_vaddr[idx]);
428
429 /*
430 * Cannot rely on the offset being page-aligned, I think ...
431 * adjust for it later when we translate addresses instead.
432 */
433 dev->region_vaddr[idx] = mmap(NULL,
434 dev->regions[idx].size +
435 dev->regions[idx].mmap_offset,
436 PROT_READ | PROT_WRITE, MAP_SHARED,
437 dev->region_fds[idx], 0);
438 USFSTL_ASSERT(dev->region_vaddr[idx] != (void *)-1,
439 "mmap() failed (%d) for fd %d", errno, dev->region_fds[idx]);
440 }
441 }
442
443 static void
usfstl_vhost_user_update_virtq_kick(struct usfstl_vhost_user_dev_int * dev,unsigned int virtq,int fd)444 usfstl_vhost_user_update_virtq_kick(struct usfstl_vhost_user_dev_int *dev,
445 unsigned int virtq, int fd)
446 {
447 if (dev->virtqs[virtq].entry.fd != -1) {
448 usfstl_loop_unregister(&dev->virtqs[virtq].entry);
449 close(dev->virtqs[virtq].entry.fd);
450 }
451
452 if (fd != -1) {
453 dev->virtqs[virtq].entry.fd = fd;
454 usfstl_loop_register(&dev->virtqs[virtq].entry);
455 }
456 }
457
usfstl_vhost_user_dev_free(struct usfstl_vhost_user_dev_int * dev)458 static void usfstl_vhost_user_dev_free(struct usfstl_vhost_user_dev_int *dev)
459 {
460 unsigned int virtq;
461
462 usfstl_loop_unregister(&dev->entry);
463 usfstl_sched_del_job(&dev->irq_job);
464
465 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
466 usfstl_vhost_user_update_virtq_kick(dev, virtq, -1);
467 if (dev->virtqs[virtq].call_fd != -1)
468 close(dev->virtqs[virtq].call_fd);
469 }
470
471 usfstl_vhost_user_clear_mappings(dev);
472
473 if (dev->req_fd != -1)
474 close(dev->req_fd);
475
476 if (dev->ext.server->ops->disconnected)
477 dev->ext.server->ops->disconnected(&dev->ext);
478
479 if (dev->entry.fd != -1)
480 close(dev->entry.fd);
481
482 free(dev);
483 }
484
/*
 * Copy the fds of the first SOL_SOCKET/SCM_RIGHTS control message in
 * @msghdr into @outfds (at most @max_fds of them). @outfds is left
 * untouched if there is no such message.
 */
static void usfstl_vhost_user_get_msg_fds(struct msghdr *msghdr,
					  int *outfds, int max_fds)
{
	struct cmsghdr *cmsg = CMSG_FIRSTHDR(msghdr);
	int fds;

	while (cmsg &&
	       (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS))
		cmsg = CMSG_NXTHDR(msghdr, cmsg);

	if (!cmsg)
		return;

	fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
	USFSTL_ASSERT(fds <= max_fds);
	memcpy(outfds, CMSG_DATA(cmsg), fds * sizeof(int));
}
503
usfstl_vhost_user_handle_msg(struct usfstl_loop_entry * entry)504 static void usfstl_vhost_user_handle_msg(struct usfstl_loop_entry *entry)
505 {
506 struct usfstl_vhost_user_dev_int *dev;
507 struct vhost_user_msg msg;
508 uint8_t data[256]; // limits the config space size ...
509 struct iovec msg_iov[3] = {
510 [0] = {
511 .iov_base = &msg.hdr,
512 .iov_len = sizeof(msg.hdr),
513 },
514 [1] = {
515 .iov_base = &msg.payload,
516 .iov_len = sizeof(msg.payload),
517 },
518 [2] = {
519 .iov_base = data,
520 .iov_len = sizeof(data),
521 },
522 };
523 uint8_t msg_control[CMSG_SPACE(sizeof(int) * MAX_REGIONS)] = { 0 };
524 struct msghdr msghdr = {
525 .msg_iov = msg_iov,
526 .msg_iovlen = 3,
527 .msg_control = msg_control,
528 .msg_controllen = sizeof(msg_control),
529 };
530 ssize_t len;
531 size_t reply_len = 0;
532 unsigned int virtq;
533 int fd;
534
535 dev = container_of(entry, struct usfstl_vhost_user_dev_int, entry);
536
537 if (usfstl_vhost_user_read_msg(entry->fd, &msghdr)) {
538 usfstl_vhost_user_dev_free(dev);
539 return;
540 }
541 len = msg.hdr.size;
542
543 USFSTL_ASSERT((msg.hdr.flags & VHOST_USER_MSG_FLAGS_VERSION) ==
544 VHOST_USER_VERSION);
545
546 switch (msg.hdr.request) {
547 case VHOST_USER_GET_FEATURES:
548 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
549 reply_len = sizeof(uint64_t);
550 msg.payload.u64 = dev->ext.server->features;
551 msg.payload.u64 |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
552 break;
553 case VHOST_USER_SET_FEATURES:
554 USFSTL_ASSERT_EQ(len, (ssize_t)sizeof(msg.payload.u64), "%zd");
555 dev->ext.features = msg.payload.u64;
556 break;
557 case VHOST_USER_SET_OWNER:
558 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
559 /* nothing to be done */
560 break;
561 case VHOST_USER_SET_MEM_TABLE:
562 USFSTL_ASSERT(len <= (int)sizeof(msg.payload.mem_regions));
563 USFSTL_ASSERT(msg.payload.mem_regions.n_regions <= MAX_REGIONS);
564 usfstl_vhost_user_clear_mappings(dev);
565 memcpy(dev->regions, msg.payload.mem_regions.regions,
566 msg.payload.mem_regions.n_regions *
567 sizeof(dev->regions[0]));
568 dev->n_regions = msg.payload.mem_regions.n_regions;
569 usfstl_vhost_user_get_msg_fds(&msghdr, dev->region_fds, MAX_REGIONS);
570 usfstl_vhost_user_setup_mappings(dev);
571 break;
572 case VHOST_USER_SET_VRING_NUM:
573 USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
574 USFSTL_ASSERT(msg.payload.vring_state.idx <
575 dev->ext.server->max_queues);
576 dev->virtqs[msg.payload.vring_state.idx].virtq.num =
577 msg.payload.vring_state.num;
578 break;
579 case VHOST_USER_SET_VRING_ADDR:
580 USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_addr));
581 USFSTL_ASSERT(msg.payload.vring_addr.idx <=
582 dev->ext.server->max_queues);
583 USFSTL_ASSERT_EQ(msg.payload.vring_addr.flags, (uint32_t)0, "0x%x");
584 USFSTL_ASSERT(!dev->virtqs[msg.payload.vring_addr.idx].enabled);
585
586 // Save the guest physical addresses to make snapshotting more convenient.
587 dev->virtqs[msg.payload.vring_addr.idx].desc_guest_addr =
588 usfstl_vhost_user_to_phys(&dev->ext, msg.payload.vring_addr.descriptor);
589 dev->virtqs[msg.payload.vring_addr.idx].used_guest_addr =
590 usfstl_vhost_user_to_phys(&dev->ext, msg.payload.vring_addr.used);
591 dev->virtqs[msg.payload.vring_addr.idx].avail_guest_addr =
592 usfstl_vhost_user_to_phys(&dev->ext, msg.payload.vring_addr.avail);
593
594 dev->virtqs[msg.payload.vring_addr.idx].last_avail_idx = 0;
595 dev->virtqs[msg.payload.vring_addr.idx].virtq.desc =
596 usfstl_vhost_user_to_va(&dev->ext,
597 msg.payload.vring_addr.descriptor);
598 dev->virtqs[msg.payload.vring_addr.idx].virtq.used =
599 usfstl_vhost_user_to_va(&dev->ext,
600 msg.payload.vring_addr.used);
601 dev->virtqs[msg.payload.vring_addr.idx].virtq.avail =
602 usfstl_vhost_user_to_va(&dev->ext,
603 msg.payload.vring_addr.avail);
604 USFSTL_ASSERT(dev->virtqs[msg.payload.vring_addr.idx].virtq.avail &&
605 dev->virtqs[msg.payload.vring_addr.idx].virtq.desc &&
606 dev->virtqs[msg.payload.vring_addr.idx].virtq.used);
607 break;
608 case VHOST_USER_SET_VRING_BASE:
609 /* ignored - logging not supported */
610 /*
611 * FIXME: our Linux UML virtio implementation
612 * shouldn't send this
613 */
614 break;
615 case VHOST_USER_SET_VRING_KICK:
616 USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
617 virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
618 USFSTL_ASSERT(virtq <= dev->ext.server->max_queues);
619 if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
620 fd = -1;
621 else
622 usfstl_vhost_user_get_msg_fds(&msghdr, &fd, 1);
623 usfstl_vhost_user_update_virtq_kick(dev, virtq, fd);
624 break;
625 case VHOST_USER_SET_VRING_CALL:
626 USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
627 virtq = msg.payload.u64 & VHOST_USER_U64_VRING_IDX_MSK;
628 USFSTL_ASSERT(virtq <= dev->ext.server->max_queues);
629 if (dev->virtqs[virtq].call_fd != -1)
630 close(dev->virtqs[virtq].call_fd);
631 if (msg.payload.u64 & VHOST_USER_U64_NO_FD)
632 dev->virtqs[virtq].call_fd = -1;
633 else
634 usfstl_vhost_user_get_msg_fds(&msghdr,
635 &dev->virtqs[virtq].call_fd,
636 1);
637 break;
638 case VHOST_USER_GET_PROTOCOL_FEATURES:
639 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
640 reply_len = sizeof(uint64_t);
641 msg.payload.u64 = dev->ext.server->protocol_features;
642 if (dev->ext.server->config && dev->ext.server->config_len)
643 msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
644 msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ;
645 msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD;
646 msg.payload.u64 |= 1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK;
647 break;
648 case VHOST_USER_SET_VRING_ENABLE:
649 USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
650 USFSTL_ASSERT(msg.payload.vring_state.idx <
651 dev->ext.server->max_queues);
652 dev->virtqs[msg.payload.vring_state.idx].enabled =
653 msg.payload.vring_state.num;
654 break;
655 case VHOST_USER_SET_PROTOCOL_FEATURES:
656 USFSTL_ASSERT(len == (int)sizeof(msg.payload.u64));
657 dev->ext.protocol_features = msg.payload.u64;
658 break;
659 case VHOST_USER_SET_SLAVE_REQ_FD:
660 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
661 if (dev->req_fd != -1)
662 close(dev->req_fd);
663 usfstl_vhost_user_get_msg_fds(&msghdr, &dev->req_fd, 1);
664 USFSTL_ASSERT(dev->req_fd != -1);
665 break;
666 case VHOST_USER_GET_CONFIG:
667 USFSTL_ASSERT(len == (int)(sizeof(msg.payload.cfg_space) +
668 msg.payload.cfg_space.size));
669 USFSTL_ASSERT(dev->ext.server->config && dev->ext.server->config_len);
670 USFSTL_ASSERT(msg.payload.cfg_space.offset == 0);
671 USFSTL_ASSERT(msg.payload.cfg_space.size <= dev->ext.server->config_len);
672 msg.payload.cfg_space.flags = 0;
673 msg_iov[1].iov_len = sizeof(msg.payload.cfg_space);
674 msg_iov[2].iov_base = (void *)dev->ext.server->config;
675 reply_len = len;
676 break;
677 case VHOST_USER_VRING_KICK:
678 USFSTL_ASSERT(len == (int)sizeof(msg.payload.vring_state));
679 USFSTL_ASSERT(msg.payload.vring_state.idx <
680 dev->ext.server->max_queues);
681 USFSTL_ASSERT(msg.payload.vring_state.num == 0);
682 usfstl_vhost_user_virtq_kick(dev, msg.payload.vring_state.idx);
683 break;
684 case VHOST_USER_GET_SHARED_MEMORY_REGIONS:
685 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
686 reply_len = sizeof(uint64_t);
687 msg.payload.u64 = 0;
688 break;
689 case VHOST_USER_SLEEP:
690 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
691 USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
692 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
693 if (dev->virtqs[virtq].enabled) {
694 dev->virtqs[virtq].enabled = false;
695 dev->virtqs[virtq].sleeping = true;
696 usfstl_loop_unregister(&dev->virtqs[virtq].entry);
697 }
698 }
699 msg.payload.i8 = 1; // success
700 reply_len = sizeof(msg.payload.i8);
701 break;
702 case VHOST_USER_WAKE:
703 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
704 USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
705 // enable previously enabled queues on wake
706 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
707 if (dev->virtqs[virtq].sleeping) {
708 dev->virtqs[virtq].enabled = true;
709 dev->virtqs[virtq].sleeping = false;
710 usfstl_loop_register(&dev->virtqs[virtq].entry);
711 // TODO: is this needed?
712 usfstl_vhost_user_virtq_kick(dev, virtq);
713 }
714 }
715 msg.payload.i8 = 1; // success
716 reply_len = sizeof(msg.payload.i8);
717 break;
718 case VHOST_USER_SNAPSHOT: {
719 USFSTL_ASSERT_EQ(len, (ssize_t)0, "%zd");
720 USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
721 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
722 struct vring_snapshot* snapshot = &msg.payload.snapshot_response.snapshot.vrings[virtq];
723 snapshot->enabled = dev->virtqs[virtq].enabled;
724 snapshot->sleeping = dev->virtqs[virtq].sleeping;
725 snapshot->triggered = dev->virtqs[virtq].triggered;
726 snapshot->num = dev->virtqs[virtq].virtq.num;
727 snapshot->desc_guest_addr = dev->virtqs[virtq].desc_guest_addr;
728 snapshot->avail_guest_addr = dev->virtqs[virtq].avail_guest_addr;
729 snapshot->used_guest_addr = dev->virtqs[virtq].used_guest_addr;
730 snapshot->last_avail_idx = dev->virtqs[virtq].last_avail_idx;
731 }
732 msg.payload.snapshot_response.bool_store = 1;
733 reply_len = (int)sizeof(msg.payload.snapshot_response);
734 break;
735 }
736 case VHOST_USER_RESTORE: {
737 int *fds;
738 USFSTL_ASSERT(len == (int)sizeof(msg.payload.restore_request));
739 USFSTL_ASSERT_EQ(dev->ext.server->max_queues, NUM_SNAPSHOT_QUEUES, "%d");
740
741 fds = (int*)malloc(dev->ext.server->max_queues * sizeof(int));
742 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
743 fds[virtq] = -1;
744 }
745 usfstl_vhost_user_get_msg_fds(&msghdr, fds, 2);
746
747 for (virtq = 0; virtq < dev->ext.server->max_queues; virtq++) {
748 const struct vring_snapshot* snapshot = &msg.payload.restore_request.snapshot.vrings[virtq];
749 dev->virtqs[virtq].enabled = snapshot->enabled;
750 dev->virtqs[virtq].sleeping = snapshot->sleeping;
751 dev->virtqs[virtq].triggered = snapshot->triggered;
752 dev->virtqs[virtq].virtq.num = snapshot->num;
753 dev->virtqs[virtq].desc_guest_addr = snapshot->desc_guest_addr;
754 dev->virtqs[virtq].avail_guest_addr = snapshot->avail_guest_addr;
755 dev->virtqs[virtq].used_guest_addr = snapshot->used_guest_addr;
756 dev->virtqs[virtq].last_avail_idx = snapshot->last_avail_idx;
757
758 dev->virtqs[virtq].entry.fd = fds[virtq];
759
760 // Translate vring guest physical addresses.
761 dev->virtqs[virtq].virtq.desc = usfstl_vhost_phys_to_va(&dev->ext, dev->virtqs[virtq].desc_guest_addr);
762 dev->virtqs[virtq].virtq.used = usfstl_vhost_phys_to_va(&dev->ext, dev->virtqs[virtq].used_guest_addr);
763 dev->virtqs[virtq].virtq.avail = usfstl_vhost_phys_to_va(&dev->ext, dev->virtqs[virtq].avail_guest_addr);
764 USFSTL_ASSERT(dev->virtqs[virtq].virtq.avail &&
765 dev->virtqs[virtq].virtq.desc &&
766 dev->virtqs[virtq].virtq.used);
767 }
768
769 free(fds);
770
771 msg.payload.i8 = 1; // success
772 reply_len = sizeof(msg.payload.i8);
773 break;
774 }
775 default:
776 USFSTL_ASSERT(0, "Unsupported message: %d\n", msg.hdr.request);
777 }
778
779 if (reply_len || (msg.hdr.flags & VHOST_USER_MSG_FLAGS_NEED_REPLY)) {
780 size_t i, tmp;
781
782 if (!reply_len) {
783 msg.payload.u64 = 0;
784 reply_len = sizeof(uint64_t);
785 }
786
787 msg.hdr.size = reply_len;
788 msg.hdr.flags &= ~VHOST_USER_MSG_FLAGS_NEED_REPLY;
789 msg.hdr.flags |= VHOST_USER_MSG_FLAGS_REPLY;
790
791 msghdr.msg_control = NULL;
792 msghdr.msg_controllen = 0;
793
794 reply_len += sizeof(msg.hdr);
795
796 tmp = reply_len;
797 for (i = 0; tmp && i < msghdr.msg_iovlen; i++) {
798 if (tmp <= msg_iov[i].iov_len)
799 msg_iov[i].iov_len = tmp;
800 tmp -= msg_iov[i].iov_len;
801 }
802 msghdr.msg_iovlen = i;
803
804 while (reply_len) {
805 len = sendmsg(entry->fd, &msghdr, 0);
806 if (len < 0) {
807 usfstl_vhost_user_dev_free(dev);
808 return;
809 }
810 USFSTL_ASSERT(len != 0);
811 reply_len -= len;
812
813 for (i = 0; len && i < msghdr.msg_iovlen; i++) {
814 unsigned int rm = len;
815
816 if (msg_iov[i].iov_len <= (size_t)len)
817 rm = msg_iov[i].iov_len;
818 len -= rm;
819 msg_iov[i].iov_len -= rm;
820 msg_iov[i].iov_base += rm;
821 }
822 }
823 }
824 }
825
usfstl_vhost_user_connected(int fd,void * data)826 static void usfstl_vhost_user_connected(int fd, void *data)
827 {
828 struct usfstl_vhost_user_server *server = data;
829 struct usfstl_vhost_user_dev_int *dev;
830 unsigned int i;
831
832 dev = calloc(1, sizeof(*dev) +
833 sizeof(dev->virtqs[0]) * server->max_queues);
834
835 USFSTL_ASSERT(dev);
836
837 for (i = 0; i < server->max_queues; i++) {
838 dev->virtqs[i].call_fd = -1;
839 dev->virtqs[i].entry.fd = -1;
840 dev->virtqs[i].entry.data = dev;
841 dev->virtqs[i].entry.handler = usfstl_vhost_user_virtq_fdkick;
842 }
843
844 for (i = 0; i < MAX_REGIONS; i++)
845 dev->region_fds[i] = -1;
846 dev->req_fd = -1;
847
848 dev->ext.server = server;
849 dev->irq_job.data = dev;
850 dev->irq_job.name = "vhost-user-irq";
851 dev->irq_job.priority = 0x10000000;
852 dev->irq_job.callback = usfstl_vhost_user_job_callback;
853 usfstl_list_init(&dev->fds);
854
855 if (server->ops->connected)
856 server->ops->connected(&dev->ext);
857
858 dev->entry.fd = fd;
859 dev->entry.handler = usfstl_vhost_user_handle_msg;
860
861 usfstl_loop_register(&dev->entry);
862 }
863
/*
 * Start a vhost-user server: create the unix socket at server->socket;
 * each accepted connection is set up by usfstl_vhost_user_connected().
 * server->ops and server->socket must be set.
 */
void usfstl_vhost_user_server_start(struct usfstl_vhost_user_server *server)
{
	USFSTL_ASSERT(server->ops);
	USFSTL_ASSERT(server->socket);

	usfstl_uds_create(server->socket, usfstl_vhost_user_connected, server);
}
871
/*
 * Stop the server by removing its unix socket (server->socket).
 * NOTE(review): existing connections are not torn down here - confirm
 * whether usfstl_uds_remove() handles them.
 */
void usfstl_vhost_user_server_stop(struct usfstl_vhost_user_server *server)
{
	usfstl_uds_remove(server->socket);
}
876
usfstl_vhost_user_dev_notify(struct usfstl_vhost_user_dev * extdev,unsigned int virtq_idx,const uint8_t * data,size_t datalen)877 void usfstl_vhost_user_dev_notify(struct usfstl_vhost_user_dev *extdev,
878 unsigned int virtq_idx,
879 const uint8_t *data, size_t datalen)
880 {
881 struct usfstl_vhost_user_dev_int *dev;
882 /* preallocate on the stack for most cases */
883 struct iovec in_sg[SG_STACK_PREALLOC] = { };
884 struct iovec out_sg[SG_STACK_PREALLOC] = { };
885 struct usfstl_vhost_user_buf _buf = {
886 .in_sg = in_sg,
887 .n_in_sg = SG_STACK_PREALLOC,
888 .out_sg = out_sg,
889 .n_out_sg = SG_STACK_PREALLOC,
890 };
891 struct usfstl_vhost_user_buf *buf;
892
893 dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);
894
895 USFSTL_ASSERT(virtq_idx <= dev->ext.server->max_queues);
896
897 if (!dev->virtqs[virtq_idx].enabled)
898 return;
899
900 buf = usfstl_vhost_user_get_virtq_buf(dev, virtq_idx, &_buf);
901 if (!buf)
902 return;
903
904 USFSTL_ASSERT(buf->n_in_sg && !buf->n_out_sg);
905 iov_fill(buf->in_sg, buf->n_in_sg, data, datalen);
906 buf->written = datalen;
907
908 usfstl_vhost_user_send_virtq_buf(dev, buf, virtq_idx);
909 usfstl_vhost_user_free_buf(buf);
910 }
911
usfstl_vhost_user_config_changed(struct usfstl_vhost_user_dev * dev)912 void usfstl_vhost_user_config_changed(struct usfstl_vhost_user_dev *dev)
913 {
914 struct usfstl_vhost_user_dev_int *idev;
915 struct vhost_user_msg msg = {
916 .hdr.request = VHOST_USER_SLAVE_CONFIG_CHANGE_MSG,
917 .hdr.flags = VHOST_USER_VERSION,
918 };
919
920 idev = container_of(dev, struct usfstl_vhost_user_dev_int, ext);
921
922 if (!(idev->ext.protocol_features &
923 (1ULL << VHOST_USER_PROTOCOL_F_CONFIG)))
924 return;
925
926 usfstl_vhost_user_send_msg(idev, &msg);
927 }
928
usfstl_vhost_user_to_va(struct usfstl_vhost_user_dev * extdev,uint64_t addr)929 void *usfstl_vhost_user_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
930 {
931 struct usfstl_vhost_user_dev_int *dev;
932 unsigned int region;
933
934 dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);
935
936 for (region = 0; region < dev->n_regions; region++) {
937 if (addr >= dev->regions[region].user_addr &&
938 addr < dev->regions[region].user_addr +
939 dev->regions[region].size)
940 return (uint8_t *)dev->region_vaddr[region] +
941 (addr -
942 dev->regions[region].user_addr +
943 dev->regions[region].mmap_offset);
944 }
945 USFSTL_ASSERT(0, "cannot translate user address %"PRIx64"\n", addr);
946 return NULL;
947 }
948
usfstl_vhost_user_to_phys(struct usfstl_vhost_user_dev * extdev,uint64_t addr)949 uint64_t usfstl_vhost_user_to_phys(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
950 {
951 struct usfstl_vhost_user_dev_int *dev;
952 unsigned int region;
953
954 dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);
955
956 for (region = 0; region < dev->n_regions; region++) {
957 if (addr >= dev->regions[region].user_addr &&
958 addr < dev->regions[region].user_addr +
959 dev->regions[region].size)
960 return addr -
961 dev->regions[region].user_addr +
962 dev->regions[region].guest_phys_addr;
963 }
964 USFSTL_ASSERT(0, "cannot translate user address %"PRIx64"\n", addr);
965 return 0;
966 }
967
usfstl_vhost_phys_to_va(struct usfstl_vhost_user_dev * extdev,uint64_t addr)968 void *usfstl_vhost_phys_to_va(struct usfstl_vhost_user_dev *extdev, uint64_t addr)
969 {
970 struct usfstl_vhost_user_dev_int *dev;
971 unsigned int region;
972
973 dev = container_of(extdev, struct usfstl_vhost_user_dev_int, ext);
974
975 for (region = 0; region < dev->n_regions; region++) {
976 if (addr >= dev->regions[region].guest_phys_addr &&
977 addr < dev->regions[region].guest_phys_addr +
978 dev->regions[region].size)
979 return (uint8_t *)dev->region_vaddr[region] +
980 (addr -
981 dev->regions[region].guest_phys_addr +
982 dev->regions[region].mmap_offset);
983 }
984
985 USFSTL_ASSERT(0, "cannot translate physical address %"PRIx64"\n", addr);
986 return NULL;
987 }
988
/*
 * Return the total number of bytes described by the @nsg entries of @sg.
 */
size_t iov_len(struct iovec *sg, unsigned int nsg)
{
	const struct iovec *end = sg + nsg;
	size_t total = 0;

	for (; sg < end; sg++)
		total += sg->iov_len;

	return total;
}
999
/*
 * Copy up to @buflen bytes from @_buf into the @nsg segments of @sg, in
 * order.
 *
 * Returns the number of bytes copied, which is less than @buflen if the
 * segments are too small to hold all the data.
 *
 * (The minimum is computed inline rather than through a function-local
 * GNU-typeof 'min' macro, which leaked to file scope and is not ISO C.)
 */
size_t iov_fill(struct iovec *sg, unsigned int nsg,
		const void *_buf, size_t buflen)
{
	const char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = buflen < sg[i].iov_len ? buflen : sg[i].iov_len;

		memcpy(sg[i].iov_base, buf, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}
1019
/*
 * Gather up to @buflen bytes from the @nsg segments of @sg, in order,
 * into @_buf.
 *
 * Returns the number of bytes copied, which is less than @buflen if the
 * segments hold less data.
 *
 * (The minimum is computed inline rather than through a function-local
 * GNU-typeof 'min' macro, which leaked to file scope and is not ISO C.)
 */
size_t iov_read(void *_buf, size_t buflen,
		struct iovec *sg, unsigned int nsg)
{
	char *buf = _buf;
	unsigned int i;
	size_t copied = 0;

	for (i = 0; buflen && i < nsg; i++) {
		size_t cpy = buflen < sg[i].iov_len ? buflen : sg[i].iov_len;

		memcpy(buf, sg[i].iov_base, cpy);
		buflen -= cpy;
		copied += cpy;
		buf += cpy;
	}

	return copied;
}
1039