• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Linux vhost kernel API wrapper.
6 
7 #[cfg(unix)]
8 pub mod net;
9 mod vsock;
10 
11 use std::alloc::Layout;
12 use std::io::Error as IoError;
13 use std::ptr::null;
14 
15 use base::ioctl;
16 use base::ioctl_with_mut_ref;
17 use base::ioctl_with_ptr;
18 use base::ioctl_with_ref;
19 use base::AsRawDescriptor;
20 use base::Event;
21 use base::LayoutAllocation;
22 use remain::sorted;
23 use static_assertions::const_assert;
24 use thiserror::Error;
25 use vm_memory::GuestAddress;
26 use vm_memory::GuestMemory;
27 use vm_memory::GuestMemoryError;
28 use vm_memory::MemoryRegionInformation;
29 
30 #[cfg(unix)]
31 pub use crate::net::Net;
32 #[cfg(unix)]
33 pub use crate::net::NetT;
34 pub use crate::vsock::Vsock;
35 
36 #[sorted]
37 #[derive(Error, Debug)]
38 pub enum Error {
39     /// Invalid available address.
40     #[error("invalid available address: {0}")]
41     AvailAddress(GuestMemoryError),
42     /// Invalid descriptor table address.
43     #[error("invalid descriptor table address: {0}")]
44     DescriptorTableAddress(GuestMemoryError),
45     /// Invalid queue.
46     #[error("invalid queue")]
47     InvalidQueue,
48     /// Error while running ioctl.
49     #[error("failed to run ioctl: {0}")]
50     IoctlError(IoError),
51     /// Invalid log address.
52     #[error("invalid log address: {0}")]
53     LogAddress(GuestMemoryError),
54     /// Invalid used address.
55     #[error("invalid used address: {0}")]
56     UsedAddress(GuestMemoryError),
57     /// Error opening vhost device.
58     #[error("failed to open vhost device: {0}")]
59     VhostOpen(IoError),
60 }
61 
62 pub type Result<T> = std::result::Result<T, Error>;
63 
ioctl_result<T>() -> Result<T>64 fn ioctl_result<T>() -> Result<T> {
65     Err(Error::IoctlError(IoError::last_os_error()))
66 }
67 
68 /// An interface for setting up vhost-based virtio devices.  Vhost-based devices are different
69 /// from regular virtio devices because the host kernel takes care of handling all the data
70 /// transfer.  The device itself only needs to deal with setting up the kernel driver and
71 /// managing the control channel.
72 pub trait Vhost: AsRawDescriptor + std::marker::Sized {
73     /// Set the current process as the owner of this file descriptor.
74     /// This must be run before any other vhost ioctls.
set_owner(&self) -> Result<()>75     fn set_owner(&self) -> Result<()> {
76         // This ioctl is called on a valid vhost_net descriptor and has its
77         // return value checked.
78         let ret = unsafe { ioctl(self, virtio_sys::VHOST_SET_OWNER()) };
79         if ret < 0 {
80             return ioctl_result();
81         }
82         Ok(())
83     }
84 
85     /// Give up ownership and reset the device to default values. Allows a subsequent call to
86     /// `set_owner` to succeed.
reset_owner(&self) -> Result<()>87     fn reset_owner(&self) -> Result<()> {
88         // This ioctl is called on a valid vhost fd and has its
89         // return value checked.
90         let ret = unsafe { ioctl(self, virtio_sys::VHOST_RESET_OWNER()) };
91         if ret < 0 {
92             return ioctl_result();
93         }
94         Ok(())
95     }
96 
97     /// Get a bitmask of supported virtio/vhost features.
get_features(&self) -> Result<u64>98     fn get_features(&self) -> Result<u64> {
99         let mut avail_features: u64 = 0;
100         // This ioctl is called on a valid vhost_net descriptor and has its
101         // return value checked.
102         let ret = unsafe {
103             ioctl_with_mut_ref(self, virtio_sys::VHOST_GET_FEATURES(), &mut avail_features)
104         };
105         if ret < 0 {
106             return ioctl_result();
107         }
108         Ok(avail_features)
109     }
110 
111     /// Inform the vhost subsystem which features to enable. This should be a subset of
112     /// supported features from VHOST_GET_FEATURES.
113     ///
114     /// # Arguments
115     /// * `features` - Bitmask of features to set.
set_features(&self, features: u64) -> Result<()>116     fn set_features(&self, features: u64) -> Result<()> {
117         // This ioctl is called on a valid vhost_net descriptor and has its
118         // return value checked.
119         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_FEATURES(), &features) };
120         if ret < 0 {
121             return ioctl_result();
122         }
123         Ok(())
124     }
125 
126     /// Set the guest memory mappings for vhost to use.
set_mem_table(&self, mem: &GuestMemory) -> Result<()>127     fn set_mem_table(&self, mem: &GuestMemory) -> Result<()> {
128         const SIZE_OF_MEMORY: usize = std::mem::size_of::<virtio_sys::vhost::vhost_memory>();
129         const SIZE_OF_REGION: usize = std::mem::size_of::<virtio_sys::vhost::vhost_memory_region>();
130         const ALIGN_OF_MEMORY: usize = std::mem::align_of::<virtio_sys::vhost::vhost_memory>();
131         const_assert!(
132             ALIGN_OF_MEMORY >= std::mem::align_of::<virtio_sys::vhost::vhost_memory_region>()
133         );
134 
135         let num_regions = mem.num_regions() as usize;
136         let size = SIZE_OF_MEMORY + num_regions * SIZE_OF_REGION;
137         let layout = Layout::from_size_align(size, ALIGN_OF_MEMORY).expect("impossible layout");
138         let mut allocation = LayoutAllocation::zeroed(layout);
139 
140         // Safe to obtain an exclusive reference because there are no other
141         // references to the allocation yet and all-zero is a valid bit pattern.
142         let vhost_memory = unsafe { allocation.as_mut::<virtio_sys::vhost::vhost_memory>() };
143 
144         vhost_memory.nregions = num_regions as u32;
145         // regions is a zero-length array, so taking a mut slice requires that
146         // we correctly specify the size to match the amount of backing memory.
147         let vhost_regions = unsafe { vhost_memory.regions.as_mut_slice(num_regions as usize) };
148 
149         let _ = mem.with_regions::<_, ()>(
150             |MemoryRegionInformation {
151                  index,
152                  guest_addr,
153                  size,
154                  host_addr,
155                  ..
156              }| {
157                 vhost_regions[index] = virtio_sys::vhost::vhost_memory_region {
158                     guest_phys_addr: guest_addr.offset() as u64,
159                     memory_size: size as u64,
160                     userspace_addr: host_addr as u64,
161                     flags_padding: 0u64,
162                 };
163                 Ok(())
164             },
165         );
166 
167         // This ioctl is called with a pointer that is valid for the lifetime
168         // of this function. The kernel will make its own copy of the memory
169         // tables. As always, check the return value.
170         let ret = unsafe { ioctl_with_ptr(self, virtio_sys::VHOST_SET_MEM_TABLE(), vhost_memory) };
171         if ret < 0 {
172             return ioctl_result();
173         }
174 
175         Ok(())
176 
177         // vhost_memory allocation is deallocated.
178     }
179 
180     /// Set the number of descriptors in the vring.
181     ///
182     /// # Arguments
183     /// * `queue_index` - Index of the queue to set descriptor count for.
184     /// * `num` - Number of descriptors in the queue.
set_vring_num(&self, queue_index: usize, num: u16) -> Result<()>185     fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
186         let vring_state = virtio_sys::vhost::vhost_vring_state {
187             index: queue_index as u32,
188             num: num as u32,
189         };
190 
191         // This ioctl is called on a valid vhost_net descriptor and has its
192         // return value checked.
193         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_NUM(), &vring_state) };
194         if ret < 0 {
195             return ioctl_result();
196         }
197         Ok(())
198     }
199 
200     // TODO(smbarber): This is copypasta. Eliminate the copypasta.
201     #[allow(clippy::if_same_then_else)]
is_valid( &self, mem: &GuestMemory, queue_max_size: u16, queue_size: u16, desc_addr: GuestAddress, avail_addr: GuestAddress, used_addr: GuestAddress, ) -> bool202     fn is_valid(
203         &self,
204         mem: &GuestMemory,
205         queue_max_size: u16,
206         queue_size: u16,
207         desc_addr: GuestAddress,
208         avail_addr: GuestAddress,
209         used_addr: GuestAddress,
210     ) -> bool {
211         let desc_table_size = 16 * queue_size as usize;
212         let avail_ring_size = 6 + 2 * queue_size as usize;
213         let used_ring_size = 6 + 8 * queue_size as usize;
214         if queue_size > queue_max_size || queue_size == 0 || (queue_size & (queue_size - 1)) != 0 {
215             false
216         } else if desc_addr
217             .checked_add(desc_table_size as u64)
218             .map_or(true, |v| !mem.address_in_range(v))
219         {
220             false
221         } else if avail_addr
222             .checked_add(avail_ring_size as u64)
223             .map_or(true, |v| !mem.address_in_range(v))
224         {
225             false
226         } else if used_addr
227             .checked_add(used_ring_size as u64)
228             .map_or(true, |v| !mem.address_in_range(v))
229         {
230             false
231         } else {
232             true
233         }
234     }
235 
236     /// Set the addresses for a given vring.
237     ///
238     /// # Arguments
239     /// * `queue_max_size` - Maximum queue size supported by the device.
240     /// * `queue_size` - Actual queue size negotiated by the driver.
241     /// * `queue_index` - Index of the queue to set addresses for.
242     /// * `flags` - Bitmask of vring flags.
243     /// * `desc_addr` - Descriptor table address.
244     /// * `used_addr` - Used ring buffer address.
245     /// * `avail_addr` - Available ring buffer address.
246     /// * `log_addr` - Optional address for logging.
set_vring_addr( &self, mem: &GuestMemory, queue_max_size: u16, queue_size: u16, queue_index: usize, flags: u32, desc_addr: GuestAddress, used_addr: GuestAddress, avail_addr: GuestAddress, log_addr: Option<GuestAddress>, ) -> Result<()>247     fn set_vring_addr(
248         &self,
249         mem: &GuestMemory,
250         queue_max_size: u16,
251         queue_size: u16,
252         queue_index: usize,
253         flags: u32,
254         desc_addr: GuestAddress,
255         used_addr: GuestAddress,
256         avail_addr: GuestAddress,
257         log_addr: Option<GuestAddress>,
258     ) -> Result<()> {
259         // TODO(smbarber): Refactor out virtio from crosvm so we can
260         // validate a Queue struct directly.
261         if !self.is_valid(
262             mem,
263             queue_max_size,
264             queue_size,
265             desc_addr,
266             used_addr,
267             avail_addr,
268         ) {
269             return Err(Error::InvalidQueue);
270         }
271 
272         let desc_addr = mem
273             .get_host_address(desc_addr)
274             .map_err(Error::DescriptorTableAddress)?;
275         let used_addr = mem
276             .get_host_address(used_addr)
277             .map_err(Error::UsedAddress)?;
278         let avail_addr = mem
279             .get_host_address(avail_addr)
280             .map_err(Error::AvailAddress)?;
281         let log_addr = match log_addr {
282             None => null(),
283             Some(a) => mem.get_host_address(a).map_err(Error::LogAddress)?,
284         };
285 
286         let vring_addr = virtio_sys::vhost::vhost_vring_addr {
287             index: queue_index as u32,
288             flags,
289             desc_user_addr: desc_addr as u64,
290             used_user_addr: used_addr as u64,
291             avail_user_addr: avail_addr as u64,
292             log_guest_addr: log_addr as u64,
293         };
294 
295         // This ioctl is called on a valid vhost_net descriptor and has its
296         // return value checked.
297         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_ADDR(), &vring_addr) };
298         if ret < 0 {
299             return ioctl_result();
300         }
301         Ok(())
302     }
303 
304     /// Set the first index to look for available descriptors.
305     ///
306     /// # Arguments
307     /// * `queue_index` - Index of the queue to modify.
308     /// * `num` - Index where available descriptors start.
set_vring_base(&self, queue_index: usize, num: u16) -> Result<()>309     fn set_vring_base(&self, queue_index: usize, num: u16) -> Result<()> {
310         let vring_state = virtio_sys::vhost::vhost_vring_state {
311             index: queue_index as u32,
312             num: num as u32,
313         };
314 
315         // This ioctl is called on a valid vhost_net descriptor and has its
316         // return value checked.
317         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_BASE(), &vring_state) };
318         if ret < 0 {
319             return ioctl_result();
320         }
321         Ok(())
322     }
323 
324     /// Gets the index of the next available descriptor in the queue.
325     ///
326     /// # Arguments
327     /// * `queue_index` - Index of the queue to query.
get_vring_base(&self, queue_index: usize) -> Result<u16>328     fn get_vring_base(&self, queue_index: usize) -> Result<u16> {
329         let mut vring_state = virtio_sys::vhost::vhost_vring_state {
330             index: queue_index as u32,
331             num: 0,
332         };
333 
334         // Safe because this will only modify `vring_state` and we check the return value.
335         let ret = unsafe {
336             ioctl_with_mut_ref(self, virtio_sys::VHOST_GET_VRING_BASE(), &mut vring_state)
337         };
338         if ret < 0 {
339             return ioctl_result();
340         }
341 
342         Ok(vring_state.num as u16)
343     }
344 
345     /// Set the event to trigger when buffers have been used by the host.
346     ///
347     /// # Arguments
348     /// * `queue_index` - Index of the queue to modify.
349     /// * `event` - Event to trigger.
set_vring_call(&self, queue_index: usize, event: &Event) -> Result<()>350     fn set_vring_call(&self, queue_index: usize, event: &Event) -> Result<()> {
351         let vring_file = virtio_sys::vhost::vhost_vring_file {
352             index: queue_index as u32,
353             fd: event.as_raw_descriptor() as i32,
354         };
355 
356         // This ioctl is called on a valid vhost_net descriptor and has its
357         // return value checked.
358         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_CALL(), &vring_file) };
359         if ret < 0 {
360             return ioctl_result();
361         }
362         Ok(())
363     }
364 
365     /// Set the event to trigger to signal an error.
366     ///
367     /// # Arguments
368     /// * `queue_index` - Index of the queue to modify.
369     /// * `event` - Event to trigger.
set_vring_err(&self, queue_index: usize, event: &Event) -> Result<()>370     fn set_vring_err(&self, queue_index: usize, event: &Event) -> Result<()> {
371         let vring_file = virtio_sys::vhost::vhost_vring_file {
372             index: queue_index as u32,
373             fd: event.as_raw_descriptor() as i32,
374         };
375 
376         // This ioctl is called on a valid vhost_net fd and has its
377         // return value checked.
378         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_ERR(), &vring_file) };
379         if ret < 0 {
380             return ioctl_result();
381         }
382         Ok(())
383     }
384 
385     /// Set the event that will be signaled by the guest when buffers are
386     /// available for the host to process.
387     ///
388     /// # Arguments
389     /// * `queue_index` - Index of the queue to modify.
390     /// * `event` - Event that will be signaled from guest.
set_vring_kick(&self, queue_index: usize, event: &Event) -> Result<()>391     fn set_vring_kick(&self, queue_index: usize, event: &Event) -> Result<()> {
392         let vring_file = virtio_sys::vhost::vhost_vring_file {
393             index: queue_index as u32,
394             fd: event.as_raw_descriptor() as i32,
395         };
396 
397         // This ioctl is called on a valid vhost_net descriptor and has its
398         // return value checked.
399         let ret = unsafe { ioctl_with_ref(self, virtio_sys::VHOST_SET_VRING_KICK(), &vring_file) };
400         if ret < 0 {
401             return ioctl_result();
402         }
403         Ok(())
404     }
405 }
406