• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! vhost-user input device
2 
3 mod buf_reader;
4 mod vhu_input;
5 mod vio_input;
6 
7 use std::fs;
8 use std::os::fd::{FromRawFd, IntoRawFd};
9 use std::str::FromStr;
10 use std::sync::{Arc, Mutex};
11 
12 use anyhow::{anyhow, bail, Context, Result};
13 use clap::Parser;
14 use log::{error, info, LevelFilter};
15 use vhost::vhost_user::Listener;
16 use vhost_user_backend::VhostUserDaemon;
17 use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
18 
19 use vhu_input::VhostUserInput;
20 use vio_input::VirtioInputConfig;
21 
22 /// Vhost-user input server.
23 #[derive(Parser, Debug)]
24 #[command(about = None, long_about = None)]
25 struct Args {
26     /// Log verbosity, one of Off, Error, Warning, Info, Debug, Trace.
27     #[arg(short, long, default_value_t = String::from("Debug") )]
28     verbosity: String,
29     /// File descriptor for the vhost user backend unix socket.
30     #[arg(short, long, required = true)]
31     socket_fd: i32,
32     /// Path to a file specifying the device's config in JSON format.
33     #[arg(short, long, required = true)]
34     device_config: String,
35 }
36 
init_logging(verbosity: &str) -> Result<()>37 fn init_logging(verbosity: &str) -> Result<()> {
38     env_logger::builder()
39         .format_timestamp_secs()
40         .filter_level(
41             LevelFilter::from_str(verbosity)
42                 .with_context(|| format!("Invalid log level: {}", verbosity))?,
43         )
44         .init();
45     Ok(())
46 }
47 
main() -> Result<()>48 fn main() -> Result<()> {
49     // SAFETY: First thing after main
50     unsafe {
51         rustutils::inherited_fd::init_once()
52             .context("Failed to take ownership of process' file descriptors")?
53     };
54     let args = Args::parse();
55     init_logging(&args.verbosity)?;
56 
57     if args.socket_fd < 0 {
58         bail!("Invalid socket file descriptor: {}", args.socket_fd);
59     }
60 
61     let device_config_str =
62         fs::read_to_string(args.device_config).context("Unable to read device config file")?;
63 
64     let device_config = VirtioInputConfig::from_json(device_config_str.as_str())
65         .context("Unable to parse config file")?;
66 
67     // SAFETY: No choice but to trust the caller passed a valid fd representing a unix socket.
68     let server_fd = rustutils::inherited_fd::take_fd_ownership(args.socket_fd)
69         .context("Failed to take ownership of socket fd")?;
70     loop {
71         let backend =
72             Arc::new(Mutex::new(VhostUserInput::new(device_config.clone(), std::io::stdin())));
73         let mut daemon = VhostUserDaemon::new(
74             "vhost-user-input".to_string(),
75             backend.clone(),
76             GuestMemoryAtomic::new(GuestMemoryMmap::new()),
77         )
78         .map_err(|e| anyhow!("Failed to create vhost user daemon: {:?}", e))?;
79 
80         VhostUserInput::<std::io::Stdin>::register_handlers(
81             0i32, // stdin
82             daemon
83                 .get_epoll_handlers()
84                 .first()
85                 .context("Daemon created without epoll handler threads")?,
86         )
87         .context("Failed to register epoll handler")?;
88 
89         let listener = {
90             // vhost::vhost_user::Listener takes ownership of the underlying fd and closes it when
91             // wait returns, so a dup of the original fd is passed to the constructor.
92             let server_dup = server_fd.try_clone().context("Failed to clone socket fd")?;
93             // SAFETY: Safe because we just dupped this fd and don't use it anywhwere else.
94             // Listener takes ownership and ensures it's properly closed when finished with it.
95             unsafe { Listener::from_raw_fd(server_dup.into_raw_fd()) }
96         };
97         info!("Created vhost-user daemon");
98         daemon
99             .start(listener)
100             .map_err(|e| anyhow!("Failed to start vhost-user daemon: {:?}", e))?;
101         info!("Accepted connection in vhost-user daemon");
102         if let Err(e) = daemon.wait() {
103             // This will print an error even when the frontend disconnects to do a restart.
104             error!("Error: {:?}", e);
105         };
106         info!("Daemon exited");
107     }
108 }
109