• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Driver for VirtIO socket devices.
2 #![deny(unsafe_op_in_unsafe_fn)]
3 
4 use super::error::SocketError;
5 use super::protocol::{
6     Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7 };
8 use super::DEFAULT_RX_BUFFER_SIZE;
9 use crate::config::read_config;
10 use crate::hal::Hal;
11 use crate::queue::{owning::OwningQueue, VirtQueue};
12 use crate::transport::Transport;
13 use crate::Result;
14 use core::mem::size_of;
15 use log::debug;
16 use zerocopy::{FromBytes, IntoBytes};
17 
18 pub(crate) const RX_QUEUE_IDX: u16 = 0;
19 pub(crate) const TX_QUEUE_IDX: u16 = 1;
20 const EVENT_QUEUE_IDX: u16 = 2;
21 
22 pub(crate) const QUEUE_SIZE: usize = 8;
23 const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX.union(Feature::RING_INDIRECT_DESC);
24 
25 /// Information about a particular vsock connection.
26 #[derive(Clone, Debug, Default, PartialEq, Eq)]
27 pub struct ConnectionInfo {
28     /// The address of the peer.
29     pub dst: VsockAddr,
30     /// The local port number associated with the connection.
31     pub src_port: u32,
32     /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
33     /// bytes it has allocated for packet bodies.
34     peer_buf_alloc: u32,
35     /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
36     /// has finished processing.
37     peer_fwd_cnt: u32,
38     /// The number of bytes of packet bodies which we have sent to the peer.
39     tx_cnt: u32,
40     /// The number of bytes of buffer space we have allocated to receive packet bodies from the
41     /// peer.
42     pub buf_alloc: u32,
43     /// The number of bytes of packet bodies which we have received from the peer and handled.
44     fwd_cnt: u32,
45     /// Whether we have recently requested credit from the peer.
46     ///
47     /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
48     /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
49     has_pending_credit_request: bool,
50 }
51 
52 impl ConnectionInfo {
53     /// Creates a new `ConnectionInfo` for the given peer address and local port, and default values
54     /// for everything else.
new(destination: VsockAddr, src_port: u32) -> Self55     pub fn new(destination: VsockAddr, src_port: u32) -> Self {
56         Self {
57             dst: destination,
58             src_port,
59             ..Default::default()
60         }
61     }
62 
63     /// Updates this connection info with the peer buffer allocation and forwarded count from the
64     /// given event.
update_for_event(&mut self, event: &VsockEvent)65     pub fn update_for_event(&mut self, event: &VsockEvent) {
66         self.peer_buf_alloc = event.buffer_status.buffer_allocation;
67         self.peer_fwd_cnt = event.buffer_status.forward_count;
68 
69         if let VsockEventType::CreditUpdate = event.event_type {
70             self.has_pending_credit_request = false;
71         }
72     }
73 
74     /// Increases the forwarded count recorded for this connection by the given number of bytes.
75     ///
76     /// This should be called once received data has been passed to the client, so there is buffer
77     /// space available for more.
done_forwarding(&mut self, length: usize)78     pub fn done_forwarding(&mut self, length: usize) {
79         self.fwd_cnt += length as u32;
80     }
81 
82     /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
83     /// data from us.
peer_free(&self) -> u3284     fn peer_free(&self) -> u32 {
85         self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
86     }
87 
new_header(&self, src_cid: u64) -> VirtioVsockHdr88     fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
89         VirtioVsockHdr {
90             src_cid: src_cid.into(),
91             dst_cid: self.dst.cid.into(),
92             src_port: self.src_port.into(),
93             dst_port: self.dst.port.into(),
94             buf_alloc: self.buf_alloc.into(),
95             fwd_cnt: self.fwd_cnt.into(),
96             ..Default::default()
97         }
98     }
99 }
100 
101 /// An event received from a VirtIO socket device.
102 #[derive(Clone, Debug, Eq, PartialEq)]
103 pub struct VsockEvent {
104     /// The source of the event, i.e. the peer who sent it.
105     pub source: VsockAddr,
106     /// The destination of the event, i.e. the CID and port on our side.
107     pub destination: VsockAddr,
108     /// The peer's buffer status for the connection.
109     pub buffer_status: VsockBufferStatus,
110     /// The type of event.
111     pub event_type: VsockEventType,
112 }
113 
114 impl VsockEvent {
115     /// Returns whether the event matches the given connection.
matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool116     pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
117         self.source == connection_info.dst
118             && self.destination.cid == guest_cid
119             && self.destination.port == connection_info.src_port
120     }
121 
from_header(header: &VirtioVsockHdr) -> Result<Self>122     fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
123         let op = header.op()?;
124         let buffer_status = VsockBufferStatus {
125             buffer_allocation: header.buf_alloc.into(),
126             forward_count: header.fwd_cnt.into(),
127         };
128         let source = header.source();
129         let destination = header.destination();
130 
131         let event_type = match op {
132             VirtioVsockOp::Request => {
133                 header.check_data_is_empty()?;
134                 VsockEventType::ConnectionRequest
135             }
136             VirtioVsockOp::Response => {
137                 header.check_data_is_empty()?;
138                 VsockEventType::Connected
139             }
140             VirtioVsockOp::CreditUpdate => {
141                 header.check_data_is_empty()?;
142                 VsockEventType::CreditUpdate
143             }
144             VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
145                 header.check_data_is_empty()?;
146                 debug!("Disconnected from the peer");
147                 let reason = if op == VirtioVsockOp::Rst {
148                     DisconnectReason::Reset
149                 } else {
150                     DisconnectReason::Shutdown
151                 };
152                 VsockEventType::Disconnected { reason }
153             }
154             VirtioVsockOp::Rw => VsockEventType::Received {
155                 length: header.len() as usize,
156             },
157             VirtioVsockOp::CreditRequest => {
158                 header.check_data_is_empty()?;
159                 VsockEventType::CreditRequest
160             }
161             VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
162         };
163 
164         Ok(VsockEvent {
165             source,
166             destination,
167             buffer_status,
168             event_type,
169         })
170     }
171 }
172 
/// The peer's credit information for a connection, as carried in the header of each packet
/// received from it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
    /// The peer's `buf_alloc` value: how many bytes of receive buffer space it has allocated for
    /// packet bodies.
    pub buffer_allocation: u32,
    /// The peer's `fwd_cnt` value: how many bytes of packet bodies it has finished processing.
    pub forward_count: u32,
}
178 
/// Why a vsock connection was closed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DisconnectReason {
    /// The peer reset the connection: either in reply to a shutdown we requested, or on its own
    /// initiative.
    Reset,
    /// The peer requested that the connection be shut down.
    Shutdown,
}

/// The kind of an event received from a VirtIO socket, with any details it carries.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum VsockEventType {
    /// The peer wants to open a connection to us.
    ConnectionRequest,
    /// The connection has been established.
    Connected,
    /// The connection has been closed.
    Disconnected {
        /// Why the connection was closed.
        reason: DisconnectReason,
    },
    /// Data arrived on the connection.
    Received {
        /// How many bytes of data arrived.
        length: usize,
    },
    /// The peer wants us to send it a credit update.
    CreditRequest,
    /// The peer sent us a credit update and nothing else.
    CreditUpdate,
}
211 
212 /// Low-level driver for a VirtIO socket device.
213 ///
214 /// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
215 /// using this directly.
216 ///
217 /// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
218 /// bigger than `size_of::<VirtioVsockHdr>()`.
219 pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
220 {
221     transport: T,
222     /// Virtqueue to receive packets.
223     rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
224     tx: VirtQueue<H, { QUEUE_SIZE }>,
225     /// Virtqueue to receive events from the device.
226     event: VirtQueue<H, { QUEUE_SIZE }>,
227     /// The guest_cid field contains the guest’s context ID, which uniquely identifies
228     /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
229     guest_cid: u64,
230 }
231 
232 impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
233     for VirtIOSocket<H, T, RX_BUFFER_SIZE>
234 {
drop(&mut self)235     fn drop(&mut self) {
236         // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
237         // after they have been freed.
238         self.transport.queue_unset(RX_QUEUE_IDX);
239         self.transport.queue_unset(TX_QUEUE_IDX);
240         self.transport.queue_unset(EVENT_QUEUE_IDX);
241     }
242 }
243 
244 impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
245     /// Create a new VirtIO Vsock driver.
new(mut transport: T) -> Result<Self>246     pub fn new(mut transport: T) -> Result<Self> {
247         assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
248 
249         let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
250 
251         let guest_cid = transport.read_consistent(|| {
252             Ok(
253                 read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64
254                     | (read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32,
255             )
256         })?;
257         debug!("guest cid: {guest_cid:?}");
258 
259         let rx = VirtQueue::new(
260             &mut transport,
261             RX_QUEUE_IDX,
262             negotiated_features.contains(Feature::RING_INDIRECT_DESC),
263             negotiated_features.contains(Feature::RING_EVENT_IDX),
264         )?;
265         let tx = VirtQueue::new(
266             &mut transport,
267             TX_QUEUE_IDX,
268             negotiated_features.contains(Feature::RING_INDIRECT_DESC),
269             negotiated_features.contains(Feature::RING_EVENT_IDX),
270         )?;
271         let event = VirtQueue::new(
272             &mut transport,
273             EVENT_QUEUE_IDX,
274             negotiated_features.contains(Feature::RING_INDIRECT_DESC),
275             negotiated_features.contains(Feature::RING_EVENT_IDX),
276         )?;
277 
278         let rx = OwningQueue::new(rx)?;
279 
280         transport.finish_init();
281         if rx.should_notify() {
282             transport.notify(RX_QUEUE_IDX);
283         }
284 
285         Ok(Self {
286             transport,
287             rx,
288             tx,
289             event,
290             guest_cid,
291         })
292     }
293 
294     /// Returns the CID which has been assigned to this guest.
guest_cid(&self) -> u64295     pub fn guest_cid(&self) -> u64 {
296         self.guest_cid
297     }
298 
299     /// Sends a request to connect to the given destination.
300     ///
301     /// This returns as soon as the request is sent; you should wait until `poll` returns a
302     /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
303     /// before sending data.
connect(&mut self, connection_info: &ConnectionInfo) -> Result304     pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
305         let header = VirtioVsockHdr {
306             op: VirtioVsockOp::Request.into(),
307             ..connection_info.new_header(self.guest_cid)
308         };
309         // Sends a header only packet to the TX queue to connect the device to the listening socket
310         // at the given destination.
311         self.send_packet_to_tx_queue(&header, &[])
312     }
313 
314     /// Accepts the given connection from a peer.
accept(&mut self, connection_info: &ConnectionInfo) -> Result315     pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
316         let header = VirtioVsockHdr {
317             op: VirtioVsockOp::Response.into(),
318             ..connection_info.new_header(self.guest_cid)
319         };
320         self.send_packet_to_tx_queue(&header, &[])
321     }
322 
323     /// Requests the peer to send us a credit update for the given connection.
request_credit(&mut self, connection_info: &ConnectionInfo) -> Result324     fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
325         let header = VirtioVsockHdr {
326             op: VirtioVsockOp::CreditRequest.into(),
327             ..connection_info.new_header(self.guest_cid)
328         };
329         self.send_packet_to_tx_queue(&header, &[])
330     }
331 
332     /// Sends the buffer to the destination.
send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result333     pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
334         self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
335 
336         let len = buffer.len() as u32;
337         let header = VirtioVsockHdr {
338             op: VirtioVsockOp::Rw.into(),
339             len: len.into(),
340             ..connection_info.new_header(self.guest_cid)
341         };
342         connection_info.tx_cnt += len;
343         self.send_packet_to_tx_queue(&header, buffer)
344     }
345 
check_peer_buffer_is_sufficient( &mut self, connection_info: &mut ConnectionInfo, buffer_len: usize, ) -> Result346     fn check_peer_buffer_is_sufficient(
347         &mut self,
348         connection_info: &mut ConnectionInfo,
349         buffer_len: usize,
350     ) -> Result {
351         if connection_info.peer_free() as usize >= buffer_len {
352             Ok(())
353         } else {
354             // Request an update of the cached peer credit, if we haven't already done so, and tell
355             // the caller to try again later.
356             if !connection_info.has_pending_credit_request {
357                 self.request_credit(connection_info)?;
358                 connection_info.has_pending_credit_request = true;
359             }
360             Err(SocketError::InsufficientBufferSpaceInPeer.into())
361         }
362     }
363 
364     /// Tells the peer how much buffer space we have to receive data.
credit_update(&mut self, connection_info: &ConnectionInfo) -> Result365     pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
366         let header = VirtioVsockHdr {
367             op: VirtioVsockOp::CreditUpdate.into(),
368             ..connection_info.new_header(self.guest_cid)
369         };
370         self.send_packet_to_tx_queue(&header, &[])
371     }
372 
373     /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
374     /// it.
poll( &mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>, ) -> Result<Option<VsockEvent>>375     pub fn poll(
376         &mut self,
377         handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
378     ) -> Result<Option<VsockEvent>> {
379         self.rx.poll(&mut self.transport, |buffer| {
380             let (header, body) = read_header_and_body(buffer)?;
381             VsockEvent::from_header(&header).and_then(|event| handler(event, body))
382         })
383     }
384 
385     /// Requests to shut down the connection cleanly, sending hints about whether we will send or
386     /// receive more data.
387     ///
388     /// This returns as soon as the request is sent; you should wait until `poll` returns a
389     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
390     /// shutdown.
shutdown_with_hints( &mut self, connection_info: &ConnectionInfo, hints: StreamShutdown, ) -> Result391     pub fn shutdown_with_hints(
392         &mut self,
393         connection_info: &ConnectionInfo,
394         hints: StreamShutdown,
395     ) -> Result {
396         let header = VirtioVsockHdr {
397             op: VirtioVsockOp::Shutdown.into(),
398             flags: hints.into(),
399             ..connection_info.new_header(self.guest_cid)
400         };
401         self.send_packet_to_tx_queue(&header, &[])
402     }
403 
404     /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
405     /// any more data.
406     ///
407     /// This returns as soon as the request is sent; you should wait until `poll` returns a
408     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
409     /// shutdown.
shutdown(&mut self, connection_info: &ConnectionInfo) -> Result410     pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
411         self.shutdown_with_hints(
412             connection_info,
413             StreamShutdown::SEND | StreamShutdown::RECEIVE,
414         )
415     }
416 
417     /// Forcibly closes the connection without waiting for the peer.
force_close(&mut self, connection_info: &ConnectionInfo) -> Result418     pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
419         let header = VirtioVsockHdr {
420             op: VirtioVsockOp::Rst.into(),
421             ..connection_info.new_header(self.guest_cid)
422         };
423         self.send_packet_to_tx_queue(&header, &[])?;
424         Ok(())
425     }
426 
send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result427     fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
428         let _len = if buffer.is_empty() {
429             self.tx
430                 .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
431         } else {
432             self.tx.add_notify_wait_pop(
433                 &[header.as_bytes(), buffer],
434                 &mut [],
435                 &mut self.transport,
436             )?
437         };
438         Ok(())
439     }
440 }
441 
read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])>442 fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
443     // This could fail if the device returns a buffer used length shorter than the header size.
444     let header = VirtioVsockHdr::read_from_prefix(buffer)
445         .map_err(|_| SocketError::BufferTooShort)?
446         .0;
447     let body_length = header.len() as usize;
448 
449     // This could fail if the device returns an unreasonably long body length.
450     let data_end = size_of::<VirtioVsockHdr>()
451         .checked_add(body_length)
452         .ok_or(SocketError::InvalidNumber)?;
453     // This could fail if the device returns a body length longer than buffer used length it
454     // returned.
455     let data = buffer
456         .get(size_of::<VirtioVsockHdr>()..data_end)
457         .ok_or(SocketError::BufferTooShort)?;
458     Ok((header, data))
459 }
460 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    /// The driver assembles the guest CID from the two config-space words.
    #[test]
    fn config() {
        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // One status each for the RX, TX and event queues.
        let queues = vec![
            QueueStatus::default(),
            QueueStatus::default(),
            QueueStatus::default(),
        ];
        let state = Arc::new(Mutex::new(State::new(queues, config_space)));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };

        let socket =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
        assert_eq!(socket.guest_cid(), 66);
    }
}
500