1 // Copyright (C) 2019 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
3
4 //! Virtio Vhost Backend Drivers
5 //!
6 //! Virtio devices use virtqueues to transport data efficiently. The first generation of virtqueue
7 //! is a set of three different single-producer, single-consumer ring structures designed to store
8 //! generic scatter-gather I/O. The virtio specification 1.1 introduces an alternative compact
//! virtqueue layout named "Packed Virtqueue", which is friendlier to memory caches and to
//! hardware-implemented virtio devices. The packed virtqueue uses read-write memory, meaning
//! the memory is both read and written by both host and guest. The new Packed Virtqueue is
12 //! preferred for performance.
13 //!
//! Vhost is a mechanism to improve the performance of Virtio devices by delegating data plane
//! operations to dedicated I/O service processes. Only the configuration, I/O submission
//! notifications, and I/O completion interrupts are piped through the hypervisor.
17 //! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to
18 //! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a
19 //! hypervisor process with an existing Virtio (PCI) driver.
20 //!
//! The initial vhost implementation is part of the Linux kernel and uses an ioctl interface to
//! communicate with userspace applications. Dedicated kernel worker threads are created to handle
//! I/O requests from the guest.
24 //!
//! Later, the vhost-user protocol was introduced to complement the ioctl interface used to
//! control the vhost implementation in the Linux kernel. It implements the control plane needed
//! to establish virtqueue sharing with a user space process on the same host. It communicates
//! over a Unix domain socket, passing file descriptors in the ancillary data of the messages.
//! The protocol defines two sides of the communication: frontend and backend. The frontend is
//! the application that shares its virtqueues; the backend is the consumer of the virtqueues.
//! Either side can act as the client (connecting) or the server (listening) of the socket.
32
33 use std::fs::File;
34 use std::io::Error as IOError;
35 use std::num::TryFromIntError;
36
37 use remain::sorted;
38 use thiserror::Error as ThisError;
39
40 mod backend;
41 pub use backend::*;
42
43 pub mod message;
44 pub use message::VHOST_USER_F_PROTOCOL_FEATURES;
45
46 pub mod connection;
47
48 mod sys;
49 pub use connection::Connection;
50 pub use message::BackendReq;
51 pub use message::FrontendReq;
52 pub use sys::SystemStream;
53 pub use sys::*;
54
55 pub(crate) mod backend_client;
56 pub use backend_client::BackendClient;
57 mod frontend_server;
58 pub use self::frontend_server::Frontend;
59 mod backend_server;
60 mod frontend_client;
61 pub use self::backend_server::Backend;
62 pub use self::backend_server::BackendServer;
63 pub use self::frontend_client::FrontendClient;
64 pub use self::frontend_server::FrontendServer;
65
66 /// Errors for vhost-user operations
67 #[sorted]
68 #[derive(Debug, ThisError)]
69 pub enum Error {
70 /// Failure from the backend side.
71 #[error("backend internal error")]
72 BackendInternalError,
73 /// client exited properly.
74 #[error("client exited properly")]
75 ClientExit,
76 /// Failure to deserialize data.
77 #[error("failed to deserialize data")]
78 DeserializationFailed,
79 /// client disconnected.
80 /// If connection is closed properly, use `ClientExit` instead.
81 #[error("client closed the connection")]
82 Disconnect,
83 /// Virtio/protocol features mismatch.
84 #[error("virtio features mismatch")]
85 FeatureMismatch,
86 /// Failure from the frontend side.
87 #[error("frontend Internal error")]
88 FrontendInternalError,
89 /// Fd array in question is too big or too small
90 #[error("wrong number of attached fds")]
91 IncorrectFds,
92 /// Invalid cast to int.
93 #[error("invalid cast to int: {0}")]
94 InvalidCastToInt(TryFromIntError),
95 /// Invalid message format, flag or content.
96 #[error("invalid message")]
97 InvalidMessage,
98 /// Unsupported operations due to that the protocol feature hasn't been negotiated.
99 #[error("invalid operation")]
100 InvalidOperation,
101 /// Invalid parameters.
102 #[error("invalid parameters")]
103 InvalidParam,
104 /// Message is too large
105 #[error("oversized message")]
106 OversizedMsg,
107 /// Only part of a message have been sent or received successfully
108 #[error("partial message")]
109 PartialMessage,
110 /// Provided recv buffer was too small, and data was dropped.
111 #[error("buffer for recv was too small, data was dropped: got size {got}, needed {want}")]
112 RecvBufferTooSmall {
113 /// The size of the buffer received.
114 got: usize,
115 /// The expected size of the buffer.
116 want: usize,
117 },
118 /// Error from request handler
119 #[error("handler failed to handle request: {0}")]
120 ReqHandlerError(IOError),
121 /// Failure to restore.
122 #[error("Failed to restore")]
123 RestoreError(anyhow::Error),
124 /// Failure to serialize data.
125 #[error("failed to serialize data")]
126 SerializationFailed,
127 /// Failure to run device specific sleep.
128 #[error("Failed to run device specific sleep: {0}")]
129 SleepError(anyhow::Error),
130 /// Failure to snapshot.
131 #[error("Failed to snapshot")]
132 SnapshotError(anyhow::Error),
133 /// The socket is broken or has been closed.
134 #[error("socket is broken: {0}")]
135 SocketBroken(std::io::Error),
136 /// Can't connect to peer.
137 #[error("can't connect to peer: {0}")]
138 SocketConnect(std::io::Error),
139 /// Generic socket errors.
140 #[error("socket error: {0}")]
141 SocketError(std::io::Error),
142 /// Should retry the socket operation again.
143 #[error("temporary socket error: {0}")]
144 SocketRetry(std::io::Error),
145 /// Failure to stop a queue.
146 #[error("failed to stop queue")]
147 StopQueueError(anyhow::Error),
148 /// Error from tx/rx on a Tube.
149 #[error("failed to read/write on Tube: {0}")]
150 TubeError(base::TubeError),
151 /// Error from VFIO device.
152 #[error("error occurred in VFIO device: {0}")]
153 VfioDeviceError(anyhow::Error),
154 /// Error from invalid vring index.
155 #[error("Vring index not found: {0}")]
156 VringIndexNotFound(usize),
157 /// Failure to run device specific wake.
158 #[error("Failed to run device specific wake: {0}")]
159 WakeError(anyhow::Error),
160 }
161
162 impl From<base::TubeError> for Error {
from(err: base::TubeError) -> Self163 fn from(err: base::TubeError) -> Self {
164 match err {
165 base::TubeError::Disconnected => Error::Disconnect,
166 err => Error::TubeError(err),
167 }
168 }
169 }
170
171 impl From<std::io::Error> for Error {
from(err: std::io::Error) -> Self172 fn from(err: std::io::Error) -> Self {
173 Error::SocketError(err)
174 }
175 }
176
177 impl From<base::Error> for Error {
178 /// Convert raw socket errors into meaningful vhost-user errors.
179 ///
180 /// The base::Error is a simple wrapper over the raw errno, which doesn't means
181 /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify
182 /// the connection manager logic.
183 ///
184 /// # Return:
185 /// * - Error::SocketRetry: temporary error caused by signals or short of resources.
186 /// * - Error::SocketBroken: the underline socket is broken.
187 /// * - Error::SocketError: other socket related errors.
188 #[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux
from(err: base::Error) -> Self189 fn from(err: base::Error) -> Self {
190 match err.errno() {
191 // Retry:
192 // * EAGAIN, EWOULDBLOCK: The socket is marked nonblocking and the requested operation
193 // would block.
194 // * EINTR: A signal occurred before any data was transmitted
195 // * ENOBUFS: The output queue for a network interface was full. This generally
196 // indicates that the interface has stopped sending, but may be caused by transient
197 // congestion.
198 // * ENOMEM: No memory available.
199 libc::EAGAIN | libc::EWOULDBLOCK | libc::EINTR | libc::ENOBUFS | libc::ENOMEM => {
200 Error::SocketRetry(err.into())
201 }
202 // Broken:
203 // * ECONNRESET: Connection reset by peer.
204 // * EPIPE: The local end has been shut down on a connection oriented socket. In this
205 // case the process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
206 libc::ECONNRESET | libc::EPIPE => Error::SocketBroken(err.into()),
207 // Write permission is denied on the destination socket file, or search permission is
208 // denied for one of the directories the path prefix.
209 libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
210 // Catch all other errors
211 e => Error::SocketError(IOError::from_raw_os_error(e)),
212 }
213 }
214 }
215
216 /// Result of vhost-user operations
217 pub type Result<T> = std::result::Result<T, Error>;
218
/// Outcome of a request handler; failures are reported as plain I/O errors.
pub type HandlerResult<T> = core::result::Result<T, IOError>;
221
/// Collapses a vector of files into its only element.
///
/// Yields `Some(file)` when `files` holds exactly one file, and `None` when it is empty or
/// holds more than one; in the latter case every file in the vector is dropped.
pub(crate) fn into_single_file(files: Vec<File>) -> Option<File> {
    // Converting into a one-element array fails (and drops the vector) for any other length.
    let [only] = <[File; 1]>::try_from(files).ok()?;
    Some(only)
}
230
231 #[cfg(test)]
232 mod test_backend;
233
#[cfg(test)]
mod tests {
    use std::thread;

    use base::AsRawDescriptor;
    use tempfile::tempfile;

    use super::*;
    use crate::message::*;
    pub(crate) use crate::sys::tests::create_client_server_pair;
    pub(crate) use crate::sys::tests::create_connection_pair;
    pub(crate) use crate::sys::tests::create_pair;
    use crate::test_backend::TestBackend;
    use crate::test_backend::VIRTIO_FEATURES;
    use crate::VhostUserMemoryRegionInfo;
    use crate::VringConfigData;

    /// Receives one request header (with any attached files) and processes its message body.
    fn handle_request(h: &mut BackendServer<TestBackend>) -> Result<()> {
        // We assume that a header comes together with message body in tests so we don't wait
        // before calling `process_message()`.
        let (hdr, files) = h.recv_header()?;
        h.process_message(hdr, files)
    }

    #[test]
    fn create_test_backend() {
        let mut backend = TestBackend::new();

        backend.set_owner().unwrap();
        // A second set_owner() on an already-owned backend must be rejected.
        assert!(backend.set_owner().is_err());
    }

    #[test]
    fn test_set_owner() {
        let test_backend = TestBackend::new();
        let (backend_client, mut backend_server) = create_client_server_pair(test_backend);

        assert!(!backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        handle_request(&mut backend_server).unwrap();
        assert!(backend_server.as_ref().owned);
        backend_client.set_owner().unwrap();
        assert!(handle_request(&mut backend_server).is_err());
        assert!(backend_server.as_ref().owned);
    }

    #[test]
    fn test_set_features() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // The server thread is joined at the end (instead of meeting at a barrier) so that a
        // panic on the server side fails the test rather than leaving it blocked forever.
        let server = thread::spawn(move || {
            // set_owner()
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get_features() / set_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features() / set_protocol_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
        });

        backend_client.set_owner().unwrap();

        // set virtio features
        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        // set vhost protocol features
        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_client_server_process() {
        let test_backend = TestBackend::new();
        let (mut backend_client, mut backend_server) = create_client_server_pair(test_backend);

        // The server thread handles requests in exactly the order the client below sends them.
        // It is joined at the end so a server-side panic fails the test instead of hanging it.
        let server = thread::spawn(move || {
            // set_owner()
            handle_request(&mut backend_server).unwrap();
            assert!(backend_server.as_ref().owned);

            // get_features() / set_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_features,
                VIRTIO_FEATURES & !0x1
            );

            // get_protocol_features() / set_protocol_features()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();
            assert_eq!(
                backend_server.as_ref().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );

            // get_inflight_fd()
            handle_request(&mut backend_server).unwrap();
            // set_inflight_fd()
            handle_request(&mut backend_server).unwrap();

            // get_queue_num()
            handle_request(&mut backend_server).unwrap();

            // set_mem_table()
            handle_request(&mut backend_server).unwrap();

            // set_config() / get_config()
            handle_request(&mut backend_server).unwrap();
            handle_request(&mut backend_server).unwrap();

            // set_backend_req_fd()
            handle_request(&mut backend_server).unwrap();

            // set_vring_enable()
            handle_request(&mut backend_server).unwrap();

            // set_log_base() / set_log_fd(): expected to fail on the server side.
            handle_request(&mut backend_server).unwrap_err();
            handle_request(&mut backend_server).unwrap_err();

            // set_vring_num/base/addr/call/kick/err()
            for _ in 0..6 {
                handle_request(&mut backend_server).unwrap();
            }

            // get_max_mem_slots()
            handle_request(&mut backend_server).unwrap();

            // add_mem_region()
            handle_request(&mut backend_server).unwrap();

            // remove_mem_region()
            handle_request(&mut backend_server).unwrap();
        });

        backend_client.set_owner().unwrap();

        // set virtio features
        let features = backend_client.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        backend_client.set_features(VIRTIO_FEATURES & !0x1).unwrap();

        // set vhost protocol features
        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        backend_client.set_protocol_features(features).unwrap();

        // Retrieve inflight I/O tracking information
        let (inflight_info, inflight_file) = backend_client
            .get_inflight_fd(&VhostUserInflight {
                num_queues: 2,
                queue_size: 256,
                ..Default::default()
            })
            .unwrap();
        // Set the buffer back to the backend
        backend_client
            .set_inflight_fd(&inflight_info, inflight_file.as_raw_descriptor())
            .unwrap();

        let num = backend_client.get_queue_num().unwrap();
        assert_eq!(num, 2);

        let event = base::Event::new().unwrap();
        let mem = [VhostUserMemoryRegionInfo {
            guest_phys_addr: 0,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: event.as_raw_descriptor(),
        }];
        backend_client.set_mem_table(&mem).unwrap();

        backend_client
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
            .unwrap();
        let buf = [0x0u8; 4];
        let (reply_body, reply_payload) = backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
            .unwrap();
        let offset = reply_body.offset;
        assert_eq!(offset, 0x100);
        assert_eq!(reply_payload[0], 0xa5);

        #[cfg(windows)]
        let tubes = base::Tube::pair().unwrap();
        #[cfg(windows)]
        let descriptor =
            // SAFETY:
            // Safe because we will be importing the Tube in the other thread.
            unsafe { tube_transporter::packed_tube::pack(tubes.0, std::process::id()).unwrap() };

        #[cfg(unix)]
        let descriptor = base::Event::new().unwrap();

        backend_client.set_backend_req_fd(&descriptor).unwrap();
        backend_client.set_vring_enable(0, true).unwrap();

        // Sending succeeds on the client side; the server thread expects these requests to
        // fail (see the matching `unwrap_err()` calls above).
        backend_client
            .set_log_base(0, Some(event.as_raw_descriptor()))
            .unwrap();
        backend_client
            .set_log_fd(event.as_raw_descriptor())
            .unwrap();

        backend_client.set_vring_num(0, 256).unwrap();
        backend_client.set_vring_base(0, 0).unwrap();
        let config = VringConfigData {
            queue_size: 128,
            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: Some(0x4000),
        };
        backend_client.set_vring_addr(0, &config).unwrap();
        backend_client.set_vring_call(0, &event).unwrap();
        backend_client.set_vring_kick(0, &event).unwrap();
        backend_client.set_vring_err(0, &event).unwrap();

        let max_mem_slots = backend_client.get_max_mem_slots().unwrap();
        assert_eq!(max_mem_slots, 32);

        let region_file = tempfile().unwrap();
        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x10_0000,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: region_file.as_raw_descriptor(),
        };
        backend_client.add_mem_region(&region).unwrap();

        backend_client.remove_mem_region(&region).unwrap();

        server.join().expect("backend server thread panicked");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
        assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
    }

    #[test]
    fn test_error_from_base_error() {
        match Error::from(base::Error::new(libc::EAGAIN)) {
            Error::SocketRetry(e) => assert_eq!(e.raw_os_error(), Some(libc::EAGAIN)),
            other => panic!("invalid error code conversion: {:?}", other),
        }
    }
}
518