1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Implement a struct that works as a `vmm_vhost`'s backend.
6 
7 use std::cmp::Ordering;
8 use std::io::Error as IoError;
9 use std::io::IoSlice;
10 use std::io::IoSliceMut;
11 use std::mem;
12 use std::os::unix::prelude::RawFd;
13 use std::sync::mpsc::channel;
14 use std::sync::mpsc::Receiver;
15 use std::sync::mpsc::Sender;
16 use std::sync::Arc;
17 use std::thread;
18 
19 use anyhow::anyhow;
20 use anyhow::bail;
21 use anyhow::Context;
22 use anyhow::Result;
23 use base::error;
24 use base::info;
25 use base::AsRawDescriptor;
26 use base::Descriptor;
27 use base::Event;
28 use base::EventExt;
29 use base::MappedRegion;
30 use base::MemoryMappingBuilder;
31 use base::MemoryMappingBuilderUnix;
32 use base::RawDescriptor;
33 use base::SafeDescriptor;
34 use cros_async::EventAsync;
35 use cros_async::Executor;
36 use futures::pin_mut;
37 use futures::select;
38 use futures::FutureExt;
39 use resources::Alloc;
40 use sync::Mutex;
41 use vmm_vhost::connection::vfio::Device as VfioDeviceTrait;
42 use vmm_vhost::connection::vfio::Endpoint as VfioEndpoint;
43 use vmm_vhost::connection::vfio::RecvIntoBufsError;
44 use vmm_vhost::connection::Endpoint;
45 use vmm_vhost::message::*;
46 
47 use crate::vfio::VfioDevice;
48 use crate::virtio::vhost::user::device::vvu::pci::QueueNotifier;
49 use crate::virtio::vhost::user::device::vvu::pci::VvuPciDevice;
50 use crate::virtio::vhost::user::device::vvu::queue::UserQueue;
51 use crate::virtio::vhost::vhost_body_from_message_bytes;
52 use crate::virtio::vhost::vhost_header_from_bytes;
53 use crate::virtio::vhost::HEADER_LEN;
54 
// Helper class for forwarding messages from the virtqueue thread to the main worker thread.
struct VfioSender {
    // Channel carrying complete vhost-user message buffers.
    sender: Sender<Vec<u8>>,
    // Counter event incremented once per buffer sent (see `send`).
    evt: Event,
}
60 
61 impl VfioSender {
new(sender: Sender<Vec<u8>>, evt: Event) -> Self62     fn new(sender: Sender<Vec<u8>>, evt: Event) -> Self {
63         Self { sender, evt }
64     }
65 
send(&self, buf: Vec<u8>) -> Result<()>66     fn send(&self, buf: Vec<u8>) -> Result<()> {
67         self.sender.send(buf)?;
68         // Increment the event counter as we sent one buffer.
69         self.evt.write_count(1).context("failed to signal event")
70     }
71 }
72 
struct VfioReceiver {
    // Source of complete vhost-user message buffers (fed by a VfioSender).
    receiver: Receiver<Vec<u8>>,
    // Message currently being consumed; empty when no message is in progress.
    buf: Vec<u8>,
    // Number of bytes of `buf` already handed out to callers.
    offset: usize,
    // Counter event decremented once per buffer received.
    evt: Event,
}
79 
80 // Utility class for converting discrete vhost user messages received by a
81 // VfioSender into a byte stream.
82 impl VfioReceiver {
new(receiver: Receiver<Vec<u8>>, evt: Event) -> Self83     fn new(receiver: Receiver<Vec<u8>>, evt: Event) -> Self {
84         Self {
85             receiver,
86             buf: Vec::new(),
87             offset: 0,
88             evt,
89         }
90     }
91 
92     // Reads the vhost user message into a byte stream. After each discrete message has
93     // been consumed, returns the message for post-processing.
recv_into_buf( &mut self, out: &mut IoSliceMut, ) -> Result<(usize, Option<Vec<u8>>), RecvIntoBufsError>94     fn recv_into_buf(
95         &mut self,
96         out: &mut IoSliceMut,
97     ) -> Result<(usize, Option<Vec<u8>>), RecvIntoBufsError> {
98         let len = out.len();
99 
100         if self.buf.is_empty() {
101             let data = self
102                 .receiver
103                 .recv()
104                 .context("failed to receive data")
105                 .map_err(RecvIntoBufsError::Fatal)?;
106 
107             if data.len() == 0 {
108                 // TODO(b/216407443): We should change `self.state` and exit gracefully.
109                 info!("VVU connection is closed");
110                 return Err(RecvIntoBufsError::Disconnect);
111             }
112 
113             self.buf = data;
114             self.offset = 0;
115             // Decrement the event counter as we received one buffer.
116             self.evt
117                 .read_count()
118                 .and_then(|c| self.evt.write_count(c - 1))
119                 .context("failed to decrease event counter")
120                 .map_err(RecvIntoBufsError::Fatal)?;
121         }
122 
123         if self.offset + len > self.buf.len() {
124             // VVU rxq runs at message granularity. If there's not enough bytes to fill
125             // |out|, then that means we're being asked to merge bytes from multiple messages
126             // into a single buffer. That almost certainly indicates a message framing error
127             // higher up the stack, so reject the request.
128             return Err(RecvIntoBufsError::Fatal(anyhow!(
129                 "recv underflow {} {} {}",
130                 self.offset,
131                 len,
132                 self.buf.len()
133             )));
134         }
135         out.clone_from_slice(&self.buf[self.offset..(self.offset + len)]);
136 
137         self.offset += len;
138         let ret_vec = if self.offset == self.buf.len() {
139             Some(std::mem::take(&mut self.buf))
140         } else {
141             None
142         };
143 
144         Ok((len, ret_vec))
145     }
146 
recv_into_bufs( &mut self, bufs: &mut [IoSliceMut], mut processor: Option<&mut BackendChannelInner>, ) -> Result<usize, RecvIntoBufsError>147     fn recv_into_bufs(
148         &mut self,
149         bufs: &mut [IoSliceMut],
150         mut processor: Option<&mut BackendChannelInner>,
151     ) -> Result<usize, RecvIntoBufsError> {
152         let mut size = 0;
153         for buf in bufs {
154             let (len, msg) = self.recv_into_buf(buf)?;
155             size += len;
156 
157             if let (Some(processor), Some(msg)) = (processor.as_mut(), msg) {
158                 processor
159                     .postprocess_rx(msg)
160                     .map_err(RecvIntoBufsError::Fatal)?;
161             }
162         }
163 
164         Ok(size)
165     }
166 }
167 
// Data queued to send on an endpoint.
#[derive(Default)]
struct EndpointTxBuffer {
    // Accumulated message bytes; flushed once a complete vhost-user message is present.
    bytes: Vec<u8>,
    // File descriptors to transmit alongside the current message.
    files: Vec<SafeDescriptor>,
}
174 
// Utility class for writing an input vhost-user byte stream to the vvu
// tx virtqueue as discrete vhost-user messages.
struct Queue {
    // The tx virtqueue.
    txq: UserQueue,
    // Notifier kicked after each write to `txq`.
    txq_notifier: QueueNotifier,
}
181 
182 impl Queue {
send_bufs( &mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>, tx_state: &mut EndpointTxBuffer, processor: Option<&mut BackendChannelInner>, ) -> Result<usize>183     fn send_bufs(
184         &mut self,
185         iovs: &[IoSlice],
186         fds: Option<&[RawDescriptor]>,
187         tx_state: &mut EndpointTxBuffer,
188         processor: Option<&mut BackendChannelInner>,
189     ) -> Result<usize> {
190         if let Some(fds) = fds {
191             if processor.is_none() {
192                 bail!("cannot send FDs");
193             }
194 
195             let fds: std::result::Result<Vec<_>, IoError> = fds
196                 .iter()
197                 .map(|fd| SafeDescriptor::try_from(&Descriptor(*fd) as &dyn AsRawDescriptor))
198                 .collect();
199             tx_state.files = fds?;
200         }
201 
202         let mut size = 0;
203         for iov in iovs {
204             let mut vec = iov.to_vec();
205             size += iov.len();
206             tx_state.bytes.append(&mut vec);
207         }
208 
209         if let Some(hdr) = vhost_header_from_bytes::<MasterReq>(&tx_state.bytes) {
210             let bytes_needed = hdr.get_size() as usize + HEADER_LEN;
211             match bytes_needed.cmp(&tx_state.bytes.len()) {
212                 Ordering::Greater => (),
213                 Ordering::Equal => {
214                     let msg = mem::take(&mut tx_state.bytes);
215                     let files = std::mem::take(&mut tx_state.files);
216 
217                     let msg = if let Some(processor) = processor {
218                         processor
219                             .preprocess_tx(msg, files)
220                             .context("failed to preprocess message")?
221                     } else {
222                         msg
223                     };
224 
225                     self.txq.write(&msg).context("Failed to send data")?;
226                     self.txq_notifier.notify();
227                 }
228                 Ordering::Less => bail!("sent bytes larger than message size"),
229             }
230         }
231 
232         Ok(size)
233     }
234 }
235 
process_rxq( evt: EventAsync, mut rxq: UserQueue, rxq_notifier: QueueNotifier, frontend_sender: VfioSender, backend_sender: VfioSender, ) -> Result<()>236 async fn process_rxq(
237     evt: EventAsync,
238     mut rxq: UserQueue,
239     rxq_notifier: QueueNotifier,
240     frontend_sender: VfioSender,
241     backend_sender: VfioSender,
242 ) -> Result<()> {
243     loop {
244         if let Err(e) = evt.next_val().await {
245             error!("Failed to read the next queue event: {}", e);
246             continue;
247         }
248 
249         while let Some(slice) = rxq.read_data()? {
250             if slice.size() < HEADER_LEN {
251                 bail!("rxq message too short: {}", slice.size());
252             }
253 
254             let mut buf = vec![0_u8; slice.size()];
255             slice.copy_to(&mut buf);
256 
257             // The inbound message may be a SlaveReq message. However, the values
258             // of all SlaveReq enum values can be safely interpreted as MasterReq
259             // enum values.
260             let hdr =
261                 vhost_header_from_bytes::<MasterReq>(&buf).context("rxq message too short")?;
262             if HEADER_LEN + hdr.get_size() as usize != slice.size() {
263                 bail!(
264                     "rxq message size mismatch: {} vs {}",
265                     slice.size(),
266                     hdr.get_size()
267                 );
268             }
269 
270             if hdr.is_reply() {
271                 &backend_sender
272             } else {
273                 &frontend_sender
274             }
275             .send(buf)
276             .context("send failed")?;
277         }
278         rxq_notifier.notify();
279     }
280 }
281 
process_txq(evt: EventAsync, txq: Arc<Mutex<Queue>>) -> Result<()>282 async fn process_txq(evt: EventAsync, txq: Arc<Mutex<Queue>>) -> Result<()> {
283     loop {
284         if let Err(e) = evt.next_val().await {
285             error!("Failed to read the next queue event: {}", e);
286             continue;
287         }
288 
289         txq.lock().txq.ack_used()?;
290     }
291 }
292 
run_worker( ex: Executor, rx_queue: UserQueue, rx_irq: Event, rx_notifier: QueueNotifier, frontend_sender: VfioSender, backend_sender: VfioSender, tx_queue: Arc<Mutex<Queue>>, tx_irq: Event, ) -> Result<()>293 fn run_worker(
294     ex: Executor,
295     rx_queue: UserQueue,
296     rx_irq: Event,
297     rx_notifier: QueueNotifier,
298     frontend_sender: VfioSender,
299     backend_sender: VfioSender,
300     tx_queue: Arc<Mutex<Queue>>,
301     tx_irq: Event,
302 ) -> Result<()> {
303     let rx_irq = EventAsync::new(rx_irq, &ex).context("failed to create async event")?;
304     let rxq = process_rxq(
305         rx_irq,
306         rx_queue,
307         rx_notifier,
308         frontend_sender,
309         backend_sender,
310     );
311     pin_mut!(rxq);
312 
313     let tx_irq = EventAsync::new(tx_irq, &ex).context("failed to create async event")?;
314     let txq = process_txq(tx_irq, Arc::clone(&tx_queue));
315     pin_mut!(txq);
316 
317     let done = async {
318         select! {
319             res = rxq.fuse() => res.context("failed to handle rxq"),
320             res = txq.fuse() => res.context("failed to handle txq"),
321         }
322     };
323 
324     match ex.run_until(done) {
325         Ok(_) => Ok(()),
326         Err(e) => {
327             bail!("failed to process virtio-vhost-user queues: {}", e);
328         }
329     }
330 }
331 
// Lifecycle state of a `VvuDevice`.
enum DeviceState {
    // Constructed but `start()` has not been called yet.
    Initialized {
        // TODO(keiichiw): Update `VfioDeviceTrait::start()` to take `VvuPciDevice` so that we can
        // drop this field.
        device: VvuPciDevice,
    },
    // `start()` has run; the worker thread now owns the underlying PCI device.
    Running {
        // Stream of frontend messages forwarded from the worker's rx queue.
        rxq_receiver: VfioReceiver,
        // Partially-accumulated outgoing message bytes.
        tx_state: EndpointTxBuffer,

        // Shared handle to the tx virtqueue (also used by the backend channel).
        txq: Arc<Mutex<Queue>>,
    },
}
345 
// `vmm_vhost` backend implemented on top of a virtio-vhost-user PCI device.
pub struct VvuDevice {
    state: DeviceState,
    // Signaled when a frontend message is available to read.
    frontend_rxq_evt: Event,

    // Endpoint for backend (SlaveReq) messages; handed out once via
    // `create_slave_request_endpoint`.
    backend_channel: Option<VfioEndpoint<SlaveReq, BackendChannel>>,
}
352 
353 impl VvuDevice {
new(device: VvuPciDevice) -> Self354     pub fn new(device: VvuPciDevice) -> Self {
355         Self {
356             state: DeviceState::Initialized { device },
357             frontend_rxq_evt: Event::new().expect("failed to create VvuDevice's rxq_evt"),
358             backend_channel: None,
359         }
360     }
361 }
362 
363 impl VfioDeviceTrait for VvuDevice {
event(&self) -> &Event364     fn event(&self) -> &Event {
365         &self.frontend_rxq_evt
366     }
367 
start(&mut self) -> Result<()>368     fn start(&mut self) -> Result<()> {
369         let device = match &mut self.state {
370             DeviceState::Initialized { device } => device,
371             DeviceState::Running { .. } => {
372                 bail!("VvuDevice has already started");
373             }
374         };
375         let ex = Executor::new().expect("Failed to create an executor");
376 
377         let mut irqs = mem::take(&mut device.irqs);
378         let mut queues = mem::take(&mut device.queues);
379         let mut queue_notifiers = mem::take(&mut device.queue_notifiers);
380 
381         let rxq = queues.remove(0);
382         let rxq_irq = irqs.remove(0);
383         let rxq_notifier = queue_notifiers.remove(0);
384         // TODO: Can we use async channel instead so we don't need `rxq_evt`?
385         let (rxq_sender, rxq_receiver) = channel();
386         let rxq_evt = self.frontend_rxq_evt.try_clone().expect("rxq_evt clone");
387 
388         let txq = Arc::new(Mutex::new(Queue {
389             txq: queues.remove(0),
390             txq_notifier: queue_notifiers.remove(0),
391         }));
392         let txq_cloned = Arc::clone(&txq);
393         let txq_irq = irqs.remove(0);
394 
395         let (backend_rxq_sender, backend_rxq_receiver) = channel();
396         let backend_rxq_evt = Event::new().expect("failed to create VvuDevice's rxq_evt");
397         let backend_rxq_evt2 = backend_rxq_evt.try_clone().expect("rxq_evt clone");
398         self.backend_channel = Some(VfioEndpoint::from(BackendChannel {
399             receiver: VfioReceiver::new(backend_rxq_receiver, backend_rxq_evt),
400             queue: txq.clone(),
401             inner: BackendChannelInner {
402                 pending_unmap: None,
403                 vfio: device.vfio_dev.clone(),
404             },
405             tx_state: EndpointTxBuffer::default(),
406         }));
407 
408         let old_state = std::mem::replace(
409             &mut self.state,
410             DeviceState::Running {
411                 rxq_receiver: VfioReceiver::new(
412                     rxq_receiver,
413                     self.frontend_rxq_evt
414                         .try_clone()
415                         .expect("frontend_rxq_evt clone"),
416                 ),
417                 tx_state: EndpointTxBuffer::default(),
418                 txq,
419             },
420         );
421 
422         let device = match old_state {
423             DeviceState::Initialized { device } => device,
424             _ => unreachable!(),
425         };
426 
427         let frontend_sender = VfioSender::new(rxq_sender, rxq_evt);
428         let backend_sender = VfioSender::new(backend_rxq_sender, backend_rxq_evt2);
429         thread::Builder::new()
430             .name("vvu_driver".to_string())
431             .spawn(move || {
432                 device.start().expect("failed to start device");
433                 if let Err(e) = run_worker(
434                     ex,
435                     rxq,
436                     rxq_irq,
437                     rxq_notifier,
438                     frontend_sender,
439                     backend_sender,
440                     txq_cloned,
441                     txq_irq,
442                 ) {
443                     error!("worker thread exited with error: {}", e);
444                 }
445             })?;
446 
447         Ok(())
448     }
449 
send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>) -> Result<usize>450     fn send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>) -> Result<usize> {
451         match &mut self.state {
452             DeviceState::Initialized { .. } => {
453                 bail!("VvuDevice hasn't started yet");
454             }
455             DeviceState::Running { txq, tx_state, .. } => {
456                 let mut queue = txq.lock();
457                 queue.send_bufs(iovs, fds, tx_state, None)
458             }
459         }
460     }
461 
recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result<usize, RecvIntoBufsError>462     fn recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result<usize, RecvIntoBufsError> {
463         match &mut self.state {
464             DeviceState::Initialized { .. } => Err(RecvIntoBufsError::Fatal(anyhow!(
465                 "VvuDevice hasn't started yet"
466             ))),
467             DeviceState::Running { rxq_receiver, .. } => rxq_receiver.recv_into_bufs(bufs, None),
468         }
469     }
470 
create_slave_request_endpoint(&mut self) -> Result<Box<dyn Endpoint<SlaveReq>>>471     fn create_slave_request_endpoint(&mut self) -> Result<Box<dyn Endpoint<SlaveReq>>> {
472         self.backend_channel
473             .take()
474             .map_or(Err(anyhow!("missing backend endpoint")), |c| {
475                 Ok(Box::new(c))
476             })
477     }
478 }
479 
// State of the backend channel not directly related to sending/receiving data.
struct BackendChannelInner {
    // VFIO device used for iova allocation and DMA mapping of SHMEM regions.
    vfio: Arc<VfioDevice>,

    // Offset of the pending unmap operation. Set when an unmap message is sent,
    // and cleared when the reply is received.
    pending_unmap: Option<u64>,
}
488 
// Struct which implements the Endpoint for backend messages.
struct BackendChannel {
    // Incoming backend replies, fed by the worker's rx queue handler.
    receiver: VfioReceiver,
    // Shared tx virtqueue used to send backend messages.
    queue: Arc<Mutex<Queue>>,
    // Non-I/O state: VFIO handle and pending-unmap bookkeeping.
    inner: BackendChannelInner,
    // Partially-accumulated outgoing message bytes and FDs.
    tx_state: EndpointTxBuffer,
}
496 
497 impl VfioDeviceTrait for BackendChannel {
event(&self) -> &Event498     fn event(&self) -> &Event {
499         &self.receiver.evt
500     }
501 
start(&mut self) -> Result<()>502     fn start(&mut self) -> Result<()> {
503         Ok(())
504     }
505 
send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawFd]>) -> Result<usize>506     fn send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawFd]>) -> Result<usize> {
507         self.queue.lock().send_bufs(
508             iovs,
509             fds,
510             &mut self.tx_state,
511             Some(&mut self.inner as &mut BackendChannelInner),
512         )
513     }
514 
recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result<usize, RecvIntoBufsError>515     fn recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result<usize, RecvIntoBufsError> {
516         self.receiver.recv_into_bufs(bufs, Some(&mut self.inner))
517     }
518 
create_slave_request_endpoint(&mut self) -> Result<Box<dyn Endpoint<SlaveReq>>>519     fn create_slave_request_endpoint(&mut self) -> Result<Box<dyn Endpoint<SlaveReq>>> {
520         Err(anyhow!(
521             "can't construct backend endpoint from backend endpoint"
522         ))
523     }
524 }
525 
526 impl BackendChannelInner {
527     // Preprocess messages before forwarding them to the virtqueue. Returns the bytes to
528     // send to the host.
preprocess_tx( &mut self, mut msg: Vec<u8>, mut files: Vec<SafeDescriptor>, ) -> Result<Vec<u8>>529     fn preprocess_tx(
530         &mut self,
531         mut msg: Vec<u8>,
532         mut files: Vec<SafeDescriptor>,
533     ) -> Result<Vec<u8>> {
534         // msg came from a ProtocolReader, so this can't fail.
535         let hdr = vhost_header_from_bytes::<SlaveReq>(&msg).expect("framing error");
536         let msg_type = hdr.get_code();
537 
538         match msg_type {
539             SlaveReq::SHMEM_MAP => {
540                 let file = files.pop().context("missing file to mmap")?;
541 
542                 // msg came from a ProtoclReader, so this can't fail.
543                 let mut msg = vhost_body_from_message_bytes::<VhostUserShmemMapMsg>(&mut msg)
544                     .expect("framing error");
545 
546                 let mapping = MemoryMappingBuilder::new(msg.len as usize)
547                     .from_descriptor(&file)
548                     .offset(msg.fd_offset)
549                     .protection(msg.flags.into())
550                     .build()
551                     .context("failed to map file")?;
552 
553                 let iova = self
554                     .vfio
555                     .alloc_iova(msg.len, 4096, Alloc::Anon(msg.shm_offset as usize))
556                     .context("failed to allocate iova")?;
557                 // Safe because we're mapping an external file.
558                 unsafe {
559                     self.vfio
560                         .vfio_dma_map(iova, msg.len, mapping.as_ptr() as u64, true)
561                         .context("failed to map into IO address space")?;
562                 }
563 
564                 // The udmabuf constructed in the hypervisor corresponds to the region
565                 // we mmap'ed, so fd_offset is no longer necessary. Reuse it for the
566                 // iova.
567                 msg.fd_offset = iova;
568             }
569             SlaveReq::SHMEM_UNMAP => {
570                 if self.pending_unmap.is_some() {
571                     bail!("overlapping unmap requests");
572                 }
573 
574                 let msg = vhost_body_from_message_bytes::<VhostUserShmemUnmapMsg>(&mut msg)
575                     .expect("framing error");
576                 match self.vfio.get_iova(&Alloc::Anon(msg.shm_offset as usize)) {
577                     None => bail!("unmap doesn't match mapped allocation"),
578                     Some(range) => {
579                         if !range.len().map_or(false, |l| l == msg.len) {
580                             bail!("unmap size mismatch");
581                         }
582                     }
583                 }
584 
585                 self.pending_unmap = Some(msg.shm_offset)
586             }
587             _ => (),
588         }
589 
590         if !files.is_empty() {
591             bail!("{} unhandled files for {:?}", files.len(), msg_type);
592         }
593 
594         Ok(msg)
595     }
596 
597     // Postprocess replies recieved from the virtqueue. This occurs after the
598     // replies have been forwarded to the endpoint.
postprocess_rx(&mut self, msg: Vec<u8>) -> Result<()>599     fn postprocess_rx(&mut self, msg: Vec<u8>) -> Result<()> {
600         // msg are provided by ProtocolReader, so this can't fail.
601         let hdr = vhost_header_from_bytes::<SlaveReq>(&msg).unwrap();
602 
603         if hdr.get_code() == SlaveReq::SHMEM_UNMAP {
604             let offset = self
605                 .pending_unmap
606                 .take()
607                 .ok_or(RecvIntoBufsError::Fatal(anyhow!(
608                     "unexpected unmap response"
609                 )))?;
610 
611             let r = self
612                 .vfio
613                 .release_iova(Alloc::Anon(offset as usize))
614                 .expect("corrupted IOVA address space");
615             self.vfio
616                 .vfio_dma_unmap(r.start, r.len().unwrap())
617                 .context("failed to unmap memory")?;
618         }
619 
620         Ok(())
621     }
622 }
623