use super::{ protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket, VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE, }; use crate::{transport::Transport, Hal, Result}; use alloc::{boxed::Box, vec::Vec}; use core::cmp::min; use core::convert::TryInto; use core::hint::spin_loop; use log::debug; use zerocopy::FromZeroes; const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1024; /// A higher level interface for VirtIO socket (vsock) devices. /// /// This keeps track of multiple vsock connections. /// /// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be /// bigger than `size_of::()`. /// /// # Example /// /// ``` /// # use virtio_drivers::{Error, Hal}; /// # use virtio_drivers::transport::Transport; /// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager}; /// /// # fn example(transport: T) -> Result<(), Error> { /// let mut socket = VsockConnectionManager::new(VirtIOSocket::::new(transport)?); /// /// // Start a thread to call `socket.poll()` and handle events. /// /// let remote_address = VsockAddr { cid: 2, port: 42 }; /// let local_port = 1234; /// socket.connect(remote_address, local_port)?; /// /// // Wait until `socket.poll()` returns an event indicating that the socket is connected. /// /// socket.send(remote_address, local_port, "Hello world".as_bytes())?; /// /// socket.shutdown(remote_address, local_port)?; /// # Ok(()) /// # } /// ``` pub struct VsockConnectionManager< H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE, > { driver: VirtIOSocket, per_connection_buffer_capacity: u32, connections: Vec, listening_ports: Vec, } #[derive(Debug)] struct Connection { info: ConnectionInfo, buffer: RingBuffer, /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is /// still data in the buffer. peer_requested_shutdown: bool, } impl Connection { fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self { let mut info = ConnectionInfo::new(peer, local_port); info.buf_alloc = buffer_capacity; Self { info, buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()), peer_requested_shutdown: false, } } } impl VsockConnectionManager { /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. pub fn new(driver: VirtIOSocket) -> Self { Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY) } /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with /// the given per-connection buffer capacity. pub fn new_with_capacity( driver: VirtIOSocket, per_connection_buffer_capacity: u32, ) -> Self { Self { driver, connections: Vec::new(), listening_ports: Vec::new(), per_connection_buffer_capacity, } } /// Returns the CID which has been assigned to this guest. pub fn guest_cid(&self) -> u64 { self.driver.guest_cid() } /// Allows incoming connections on the given port number. pub fn listen(&mut self, port: u32) { if !self.listening_ports.contains(&port) { self.listening_ports.push(port); } } /// Stops allowing incoming connections on the given port number. pub fn unlisten(&mut self, port: u32) { self.listening_ports.retain(|p| *p != port); } /// Sends a request to connect to the given destination. /// /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Connected` event indicating that the peer has accepted the connection /// before sending data. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result { if self.connections.iter().any(|connection| { connection.info.dst == destination && connection.info.src_port == src_port }) { return Err(SocketError::ConnectionExists.into()); } let new_connection = Connection::new(destination, src_port, self.per_connection_buffer_capacity); self.driver.connect(&new_connection.info)?; debug!("Connection requested: {:?}", new_connection.info); self.connections.push(new_connection); Ok(()) } /// Sends the buffer to the destination. pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result { let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; self.driver.send(buffer, &mut connection.info) } /// Polls the vsock device to receive data or other updates. pub fn poll(&mut self) -> Result> { let guest_cid = self.driver.guest_cid(); let connections = &mut self.connections; let per_connection_buffer_capacity = self.per_connection_buffer_capacity; let result = self.driver.poll(|event, body| { let connection = get_connection_for_event(connections, &event, guest_cid); // Skip events which don't match any connection we know about, unless they are a // connection request. let connection = if let Some((_, connection)) = connection { connection } else if let VsockEventType::ConnectionRequest = event.event_type { // If the requested connection already exists or the CID isn't ours, ignore it. if connection.is_some() || event.destination.cid != guest_cid { return Ok(None); } // Add the new connection to our list, at least for now. It will be removed again // below if we weren't listening on the port. connections.push(Connection::new( event.source, event.destination.port, per_connection_buffer_capacity, )); connections.last_mut().unwrap() } else { return Ok(None); }; // Update stored connection info. connection.info.update_for_event(&event); if let VsockEventType::Received { length } = event.event_type { // Copy to buffer if !connection.buffer.add(body) { return Err(SocketError::OutputBufferTooShort(length).into()); } } Ok(Some(event)) })?; let Some(event) = result else { return Ok(None); }; // The connection must exist because we found it above in the callback. let (connection_index, connection) = get_connection_for_event(connections, &event, guest_cid).unwrap(); match event.event_type { VsockEventType::ConnectionRequest => { if self.listening_ports.contains(&event.destination.port) { self.driver.accept(&connection.info)?; } else { // Reject the connection request and remove it from our list. self.driver.force_close(&connection.info)?; self.connections.swap_remove(connection_index); // No need to pass the request on to the client, as we've already rejected it. return Ok(None); } } VsockEventType::Connected => {} VsockEventType::Disconnected { reason } => { // Wait until client reads all data before removing connection. if connection.buffer.is_empty() { if reason == DisconnectReason::Shutdown { self.driver.force_close(&connection.info)?; } self.connections.swap_remove(connection_index); } else { connection.peer_requested_shutdown = true; } } VsockEventType::Received { .. } => { // Already copied the buffer in the callback above. } VsockEventType::CreditRequest => { // If the peer requested credit, send an update. self.driver.credit_update(&connection.info)?; // No need to pass the request on to the client, we've already handled it. return Ok(None); } VsockEventType::CreditUpdate => {} } Ok(Some(event)) } /// Reads data received from the given connection. pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result { let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?; // Copy from ring buffer let bytes_read = connection.buffer.drain(buffer); connection.info.done_forwarding(bytes_read); // If buffer is now empty and the peer requested shutdown, finish shutting down the // connection. if connection.peer_requested_shutdown && connection.buffer.is_empty() { self.driver.force_close(&connection.info)?; self.connections.swap_remove(connection_index); } Ok(bytes_read) } /// Returns the number of bytes in the receive buffer available to be read by `recv`. /// /// When the available bytes is 0, it indicates that the receive buffer is empty and does not /// contain any data. pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result { let (_, connection) = get_connection(&mut self.connections, peer, src_port)?; Ok(connection.buffer.used()) } /// Sends a credit update to the given peer. pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result { let (_, connection) = get_connection(&mut self.connections, peer, src_port)?; self.driver.credit_update(&connection.info) } /// Blocks until we get some event from the vsock device. pub fn wait_for_event(&mut self) -> Result { loop { if let Some(event) = self.poll()? { return Ok(event); } else { spin_loop(); } } } /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive /// any more data. /// /// This returns as soon as the request is sent; you should wait until `poll` returns a /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// shutdown. pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result { let (_, connection) = get_connection(&mut self.connections, destination, src_port)?; self.driver.shutdown(&connection.info) } /// Forcibly closes the connection without waiting for the peer. pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result { let (index, connection) = get_connection(&mut self.connections, destination, src_port)?; self.driver.force_close(&connection.info)?; self.connections.swap_remove(index); Ok(()) } } /// Returns the connection from the given list matching the given peer address and local port, and /// its index. /// /// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list. fn get_connection( connections: &mut [Connection], peer: VsockAddr, local_port: u32, ) -> core::result::Result<(usize, &mut Connection), SocketError> { connections .iter_mut() .enumerate() .find(|(_, connection)| { connection.info.dst == peer && connection.info.src_port == local_port }) .ok_or(SocketError::NotConnected) } /// Returns the connection from the given list matching the event, if any, and its index. fn get_connection_for_event<'a>( connections: &'a mut [Connection], event: &VsockEvent, local_cid: u64, ) -> Option<(usize, &'a mut Connection)> { connections .iter_mut() .enumerate() .find(|(_, connection)| event.matches_connection(&connection.info, local_cid)) } #[derive(Debug)] struct RingBuffer { buffer: Box<[u8]>, /// The number of bytes currently in the buffer. used: usize, /// The index of the first used byte in the buffer. start: usize, } impl RingBuffer { pub fn new(capacity: usize) -> Self { Self { buffer: FromZeroes::new_box_slice_zeroed(capacity), used: 0, start: 0, } } /// Returns the number of bytes currently used in the buffer. pub fn used(&self) -> usize { self.used } /// Returns true iff there are currently no bytes in the buffer. pub fn is_empty(&self) -> bool { self.used == 0 } /// Returns the number of bytes currently free in the buffer. pub fn free(&self) -> usize { self.buffer.len() - self.used } /// Adds the given bytes to the buffer if there is enough capacity for them all. /// /// Returns true if they were added, or false if they were not. pub fn add(&mut self, bytes: &[u8]) -> bool { if bytes.len() > self.free() { return false; } // The index of the first available position in the buffer. let first_available = (self.start + self.used) % self.buffer.len(); // The number of bytes to copy from `bytes` to `buffer` between `first_available` and // `buffer.len()`. let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available); self.buffer[first_available..first_available + copy_length_before_wraparound] .copy_from_slice(&bytes[0..copy_length_before_wraparound]); if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) { self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound); } self.used += bytes.len(); true } /// Reads and removes as many bytes as possible from the buffer, up to the length of the given /// buffer. pub fn drain(&mut self, out: &mut [u8]) -> usize { let bytes_read = min(self.used, out.len()); // The number of bytes to copy out between `start` and the end of the buffer. let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start); // The number of bytes to copy out from the beginning of the buffer after wrapping around. let read_after_wraparound = bytes_read .checked_sub(read_before_wraparound) .unwrap_or_default(); out[0..read_before_wraparound] .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]); out[read_before_wraparound..bytes_read] .copy_from_slice(&self.buffer[0..read_after_wraparound]); self.used -= bytes_read; self.start = (self.start + bytes_read) % self.buffer.len(); bytes_read } } #[cfg(test)] mod tests { use super::*; use crate::{ device::socket::{ protocol::{ SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, }, vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX}, }, hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, DeviceType, }, volatile::ReadOnly, }; use alloc::{sync::Arc, vec}; use core::{mem::size_of, ptr::NonNull}; use std::{sync::Mutex, thread}; use zerocopy::{AsBytes, FromBytes}; #[test] fn send_recv() { let host_cid = 2; let guest_cid = 66; let host_port = 1234; let guest_port = 4321; let host_address = VsockAddr { cid: host_cid, port: host_port, }; let hello_from_guest = "Hello from guest"; let hello_from_host = "Hello from host"; let mut config_space = VirtioVsockConfig { guest_cid_low: ReadOnly::new(66), guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Socket, max_queue_size: 32, device_features: 0, config_space: NonNull::from(&mut config_space), state: state.clone(), }; let mut socket = VsockConnectionManager::new( VirtIOSocket::>::new(transport).unwrap(), ); // Start a thread to simulate the device. let handle = thread::spawn(move || { // Wait for connection request. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Request.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 1024.into(), fwd_cnt: 0.into(), } ); // Accept connection and give the peer enough credit to send the message. state.lock().unwrap().write_to_queue::( RX_QUEUE_IDX, VirtioVsockHdr { op: VirtioVsockOp::Response.into(), src_cid: host_cid.into(), dst_cid: guest_cid.into(), src_port: host_port.into(), dst_port: guest_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 50.into(), fwd_cnt: 0.into(), } .as_bytes(), ); // Expect the guest to send some data. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); let request = state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX); assert_eq!( request.len(), size_of::() + hello_from_guest.len() ); assert_eq!( VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Rw.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: (hello_from_guest.len() as u32).into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 1024.into(), fwd_cnt: 0.into(), } ); assert_eq!( &request[size_of::()..], hello_from_guest.as_bytes() ); println!("Host sending"); // Send a response. let mut response = vec![0; size_of::() + hello_from_host.len()]; VirtioVsockHdr { op: VirtioVsockOp::Rw.into(), src_cid: host_cid.into(), dst_cid: guest_cid.into(), src_port: host_port.into(), dst_port: guest_port.into(), len: (hello_from_host.len() as u32).into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 50.into(), fwd_cnt: (hello_from_guest.len() as u32).into(), } .write_to_prefix(response.as_mut_slice()); response[size_of::()..].copy_from_slice(hello_from_host.as_bytes()); state .lock() .unwrap() .write_to_queue::(RX_QUEUE_IDX, &response); // Expect a shutdown. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Shutdown.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(), buf_alloc: 1024.into(), fwd_cnt: (hello_from_host.len() as u32).into(), } ); }); socket.connect(host_address, guest_port).unwrap(); assert_eq!( socket.wait_for_event().unwrap(), VsockEvent { source: host_address, destination: VsockAddr { cid: guest_cid, port: guest_port, }, event_type: VsockEventType::Connected, buffer_status: VsockBufferStatus { buffer_allocation: 50, forward_count: 0, }, } ); println!("Guest sending"); socket .send(host_address, guest_port, "Hello from guest".as_bytes()) .unwrap(); println!("Guest waiting to receive."); assert_eq!( socket.wait_for_event().unwrap(), VsockEvent { source: host_address, destination: VsockAddr { cid: guest_cid, port: guest_port, }, event_type: VsockEventType::Received { length: hello_from_host.len() }, buffer_status: VsockBufferStatus { buffer_allocation: 50, forward_count: hello_from_guest.len() as u32, }, } ); println!("Guest getting received data."); let mut buffer = [0u8; 64]; assert_eq!( socket.recv(host_address, guest_port, &mut buffer).unwrap(), hello_from_host.len() ); assert_eq!( &buffer[0..hello_from_host.len()], hello_from_host.as_bytes() ); socket.shutdown(host_address, guest_port).unwrap(); handle.join().unwrap(); } #[test] fn incoming_connection() { let host_cid = 2; let guest_cid = 66; let host_port = 1234; let guest_port = 4321; let wrong_guest_port = 4444; let host_address = VsockAddr { cid: host_cid, port: host_port, }; let mut config_space = VirtioVsockConfig { guest_cid_low: ReadOnly::new(66), guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], ..Default::default() })); let transport = FakeTransport { device_type: DeviceType::Socket, max_queue_size: 32, device_features: 0, config_space: NonNull::from(&mut config_space), state: state.clone(), }; let mut socket = VsockConnectionManager::new( VirtIOSocket::>::new(transport).unwrap(), ); socket.listen(guest_port); // Start a thread to simulate the device. let handle = thread::spawn(move || { // Send a connection request for a port the guest isn't listening on. println!("Host sending connection request to wrong port"); state.lock().unwrap().write_to_queue::( RX_QUEUE_IDX, VirtioVsockHdr { op: VirtioVsockOp::Request.into(), src_cid: host_cid.into(), dst_cid: guest_cid.into(), src_port: host_port.into(), dst_port: wrong_guest_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 50.into(), fwd_cnt: 0.into(), } .as_bytes(), ); // Expect a rejection. println!("Host waiting for rejection"); State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Rst.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: wrong_guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 1024.into(), fwd_cnt: 0.into(), } ); // Send a connection request for a port the guest is listening on. println!("Host sending connection request to right port"); state.lock().unwrap().write_to_queue::( RX_QUEUE_IDX, VirtioVsockHdr { op: VirtioVsockOp::Request.into(), src_cid: host_cid.into(), dst_cid: guest_cid.into(), src_port: host_port.into(), dst_port: guest_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 50.into(), fwd_cnt: 0.into(), } .as_bytes(), ); // Expect a response. println!("Host waiting for response"); State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Response.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 1024.into(), fwd_cnt: 0.into(), } ); println!("Host finished"); }); // Expect an incoming connection. println!("Guest expecting incoming connection."); assert_eq!( socket.wait_for_event().unwrap(), VsockEvent { source: host_address, destination: VsockAddr { cid: guest_cid, port: guest_port, }, event_type: VsockEventType::ConnectionRequest, buffer_status: VsockBufferStatus { buffer_allocation: 50, forward_count: 0, }, } ); handle.join().unwrap(); } }