1 //! Driver for VirtIO socket devices.
2 #![deny(unsafe_op_in_unsafe_fn)]
3
4 use super::error::SocketError;
5 use super::protocol::{
6 Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7 };
8 use super::DEFAULT_RX_BUFFER_SIZE;
9 use crate::config::read_config;
10 use crate::hal::Hal;
11 use crate::queue::{owning::OwningQueue, VirtQueue};
12 use crate::transport::Transport;
13 use crate::Result;
14 use core::mem::size_of;
15 use log::debug;
16 use zerocopy::{FromBytes, IntoBytes};
17
18 pub(crate) const RX_QUEUE_IDX: u16 = 0;
19 pub(crate) const TX_QUEUE_IDX: u16 = 1;
20 const EVENT_QUEUE_IDX: u16 = 2;
21
22 pub(crate) const QUEUE_SIZE: usize = 8;
23 const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX.union(Feature::RING_INDIRECT_DESC);
24
25 /// Information about a particular vsock connection.
26 #[derive(Clone, Debug, Default, PartialEq, Eq)]
27 pub struct ConnectionInfo {
28 /// The address of the peer.
29 pub dst: VsockAddr,
30 /// The local port number associated with the connection.
31 pub src_port: u32,
32 /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
33 /// bytes it has allocated for packet bodies.
34 peer_buf_alloc: u32,
35 /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
36 /// has finished processing.
37 peer_fwd_cnt: u32,
38 /// The number of bytes of packet bodies which we have sent to the peer.
39 tx_cnt: u32,
40 /// The number of bytes of buffer space we have allocated to receive packet bodies from the
41 /// peer.
42 pub buf_alloc: u32,
43 /// The number of bytes of packet bodies which we have received from the peer and handled.
44 fwd_cnt: u32,
45 /// Whether we have recently requested credit from the peer.
46 ///
47 /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
48 /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
49 has_pending_credit_request: bool,
50 }
51
52 impl ConnectionInfo {
53 /// Creates a new `ConnectionInfo` for the given peer address and local port, and default values
54 /// for everything else.
new(destination: VsockAddr, src_port: u32) -> Self55 pub fn new(destination: VsockAddr, src_port: u32) -> Self {
56 Self {
57 dst: destination,
58 src_port,
59 ..Default::default()
60 }
61 }
62
63 /// Updates this connection info with the peer buffer allocation and forwarded count from the
64 /// given event.
update_for_event(&mut self, event: &VsockEvent)65 pub fn update_for_event(&mut self, event: &VsockEvent) {
66 self.peer_buf_alloc = event.buffer_status.buffer_allocation;
67 self.peer_fwd_cnt = event.buffer_status.forward_count;
68
69 if let VsockEventType::CreditUpdate = event.event_type {
70 self.has_pending_credit_request = false;
71 }
72 }
73
74 /// Increases the forwarded count recorded for this connection by the given number of bytes.
75 ///
76 /// This should be called once received data has been passed to the client, so there is buffer
77 /// space available for more.
done_forwarding(&mut self, length: usize)78 pub fn done_forwarding(&mut self, length: usize) {
79 self.fwd_cnt += length as u32;
80 }
81
82 /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
83 /// data from us.
peer_free(&self) -> u3284 fn peer_free(&self) -> u32 {
85 self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
86 }
87
new_header(&self, src_cid: u64) -> VirtioVsockHdr88 fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
89 VirtioVsockHdr {
90 src_cid: src_cid.into(),
91 dst_cid: self.dst.cid.into(),
92 src_port: self.src_port.into(),
93 dst_port: self.dst.port.into(),
94 buf_alloc: self.buf_alloc.into(),
95 fwd_cnt: self.fwd_cnt.into(),
96 ..Default::default()
97 }
98 }
99 }
100
101 /// An event received from a VirtIO socket device.
102 #[derive(Clone, Debug, Eq, PartialEq)]
103 pub struct VsockEvent {
104 /// The source of the event, i.e. the peer who sent it.
105 pub source: VsockAddr,
106 /// The destination of the event, i.e. the CID and port on our side.
107 pub destination: VsockAddr,
108 /// The peer's buffer status for the connection.
109 pub buffer_status: VsockBufferStatus,
110 /// The type of event.
111 pub event_type: VsockEventType,
112 }
113
114 impl VsockEvent {
115 /// Returns whether the event matches the given connection.
matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool116 pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
117 self.source == connection_info.dst
118 && self.destination.cid == guest_cid
119 && self.destination.port == connection_info.src_port
120 }
121
from_header(header: &VirtioVsockHdr) -> Result<Self>122 fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
123 let op = header.op()?;
124 let buffer_status = VsockBufferStatus {
125 buffer_allocation: header.buf_alloc.into(),
126 forward_count: header.fwd_cnt.into(),
127 };
128 let source = header.source();
129 let destination = header.destination();
130
131 let event_type = match op {
132 VirtioVsockOp::Request => {
133 header.check_data_is_empty()?;
134 VsockEventType::ConnectionRequest
135 }
136 VirtioVsockOp::Response => {
137 header.check_data_is_empty()?;
138 VsockEventType::Connected
139 }
140 VirtioVsockOp::CreditUpdate => {
141 header.check_data_is_empty()?;
142 VsockEventType::CreditUpdate
143 }
144 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
145 header.check_data_is_empty()?;
146 debug!("Disconnected from the peer");
147 let reason = if op == VirtioVsockOp::Rst {
148 DisconnectReason::Reset
149 } else {
150 DisconnectReason::Shutdown
151 };
152 VsockEventType::Disconnected { reason }
153 }
154 VirtioVsockOp::Rw => VsockEventType::Received {
155 length: header.len() as usize,
156 },
157 VirtioVsockOp::CreditRequest => {
158 header.check_data_is_empty()?;
159 VsockEventType::CreditRequest
160 }
161 VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
162 };
163
164 Ok(VsockEvent {
165 source,
166 destination,
167 buffer_status,
168 event_type,
169 })
170 }
171 }
172
/// Buffer accounting reported by the peer in a packet header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VsockBufferStatus {
    /// The peer's `buf_alloc`: bytes of receive buffer space it has allocated for packet bodies.
    pub buffer_allocation: u32,
    /// The peer's `fwd_cnt`: bytes of packet bodies it has finished processing.
    pub forward_count: u32,
}
178
/// Why a vsock connection ended.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The peer sent a reset. It does this either to acknowledge our shutdown request or to
    /// close the connection forcibly on its own.
    Reset,
    /// The peer requested that the connection be shut down.
    Shutdown,
}
188
189 /// Details of the type of an event received from a VirtIO socket.
190 #[derive(Clone, Debug, Eq, PartialEq)]
191 pub enum VsockEventType {
192 /// The peer requests to establish a connection with us.
193 ConnectionRequest,
194 /// The connection was successfully established.
195 Connected,
196 /// The connection was closed.
197 Disconnected {
198 /// The reason for the disconnection.
199 reason: DisconnectReason,
200 },
201 /// Data was received on the connection.
202 Received {
203 /// The length of the data in bytes.
204 length: usize,
205 },
206 /// The peer requests us to send a credit update.
207 CreditRequest,
208 /// The peer just sent us a credit update with nothing else.
209 CreditUpdate,
210 }
211
212 /// Low-level driver for a VirtIO socket device.
213 ///
214 /// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
215 /// using this directly.
216 ///
217 /// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
218 /// bigger than `size_of::<VirtioVsockHdr>()`.
219 pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
220 {
221 transport: T,
222 /// Virtqueue to receive packets.
223 rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
224 tx: VirtQueue<H, { QUEUE_SIZE }>,
225 /// Virtqueue to receive events from the device.
226 event: VirtQueue<H, { QUEUE_SIZE }>,
227 /// The guest_cid field contains the guest’s context ID, which uniquely identifies
228 /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
229 guest_cid: u64,
230 }
231
232 impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
233 for VirtIOSocket<H, T, RX_BUFFER_SIZE>
234 {
drop(&mut self)235 fn drop(&mut self) {
236 // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
237 // after they have been freed.
238 self.transport.queue_unset(RX_QUEUE_IDX);
239 self.transport.queue_unset(TX_QUEUE_IDX);
240 self.transport.queue_unset(EVENT_QUEUE_IDX);
241 }
242 }
243
244 impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
245 /// Create a new VirtIO Vsock driver.
new(mut transport: T) -> Result<Self>246 pub fn new(mut transport: T) -> Result<Self> {
247 assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
248
249 let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
250
251 let guest_cid = transport.read_consistent(|| {
252 Ok(
253 read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64
254 | (read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32,
255 )
256 })?;
257 debug!("guest cid: {guest_cid:?}");
258
259 let rx = VirtQueue::new(
260 &mut transport,
261 RX_QUEUE_IDX,
262 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
263 negotiated_features.contains(Feature::RING_EVENT_IDX),
264 )?;
265 let tx = VirtQueue::new(
266 &mut transport,
267 TX_QUEUE_IDX,
268 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
269 negotiated_features.contains(Feature::RING_EVENT_IDX),
270 )?;
271 let event = VirtQueue::new(
272 &mut transport,
273 EVENT_QUEUE_IDX,
274 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
275 negotiated_features.contains(Feature::RING_EVENT_IDX),
276 )?;
277
278 let rx = OwningQueue::new(rx)?;
279
280 transport.finish_init();
281 if rx.should_notify() {
282 transport.notify(RX_QUEUE_IDX);
283 }
284
285 Ok(Self {
286 transport,
287 rx,
288 tx,
289 event,
290 guest_cid,
291 })
292 }
293
294 /// Returns the CID which has been assigned to this guest.
guest_cid(&self) -> u64295 pub fn guest_cid(&self) -> u64 {
296 self.guest_cid
297 }
298
299 /// Sends a request to connect to the given destination.
300 ///
301 /// This returns as soon as the request is sent; you should wait until `poll` returns a
302 /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
303 /// before sending data.
connect(&mut self, connection_info: &ConnectionInfo) -> Result304 pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
305 let header = VirtioVsockHdr {
306 op: VirtioVsockOp::Request.into(),
307 ..connection_info.new_header(self.guest_cid)
308 };
309 // Sends a header only packet to the TX queue to connect the device to the listening socket
310 // at the given destination.
311 self.send_packet_to_tx_queue(&header, &[])
312 }
313
314 /// Accepts the given connection from a peer.
accept(&mut self, connection_info: &ConnectionInfo) -> Result315 pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
316 let header = VirtioVsockHdr {
317 op: VirtioVsockOp::Response.into(),
318 ..connection_info.new_header(self.guest_cid)
319 };
320 self.send_packet_to_tx_queue(&header, &[])
321 }
322
323 /// Requests the peer to send us a credit update for the given connection.
request_credit(&mut self, connection_info: &ConnectionInfo) -> Result324 fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
325 let header = VirtioVsockHdr {
326 op: VirtioVsockOp::CreditRequest.into(),
327 ..connection_info.new_header(self.guest_cid)
328 };
329 self.send_packet_to_tx_queue(&header, &[])
330 }
331
332 /// Sends the buffer to the destination.
send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result333 pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
334 self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
335
336 let len = buffer.len() as u32;
337 let header = VirtioVsockHdr {
338 op: VirtioVsockOp::Rw.into(),
339 len: len.into(),
340 ..connection_info.new_header(self.guest_cid)
341 };
342 connection_info.tx_cnt += len;
343 self.send_packet_to_tx_queue(&header, buffer)
344 }
345
check_peer_buffer_is_sufficient( &mut self, connection_info: &mut ConnectionInfo, buffer_len: usize, ) -> Result346 fn check_peer_buffer_is_sufficient(
347 &mut self,
348 connection_info: &mut ConnectionInfo,
349 buffer_len: usize,
350 ) -> Result {
351 if connection_info.peer_free() as usize >= buffer_len {
352 Ok(())
353 } else {
354 // Request an update of the cached peer credit, if we haven't already done so, and tell
355 // the caller to try again later.
356 if !connection_info.has_pending_credit_request {
357 self.request_credit(connection_info)?;
358 connection_info.has_pending_credit_request = true;
359 }
360 Err(SocketError::InsufficientBufferSpaceInPeer.into())
361 }
362 }
363
364 /// Tells the peer how much buffer space we have to receive data.
credit_update(&mut self, connection_info: &ConnectionInfo) -> Result365 pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
366 let header = VirtioVsockHdr {
367 op: VirtioVsockOp::CreditUpdate.into(),
368 ..connection_info.new_header(self.guest_cid)
369 };
370 self.send_packet_to_tx_queue(&header, &[])
371 }
372
373 /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
374 /// it.
poll( &mut self, handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>, ) -> Result<Option<VsockEvent>>375 pub fn poll(
376 &mut self,
377 handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
378 ) -> Result<Option<VsockEvent>> {
379 self.rx.poll(&mut self.transport, |buffer| {
380 let (header, body) = read_header_and_body(buffer)?;
381 VsockEvent::from_header(&header).and_then(|event| handler(event, body))
382 })
383 }
384
385 /// Requests to shut down the connection cleanly, sending hints about whether we will send or
386 /// receive more data.
387 ///
388 /// This returns as soon as the request is sent; you should wait until `poll` returns a
389 /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
390 /// shutdown.
shutdown_with_hints( &mut self, connection_info: &ConnectionInfo, hints: StreamShutdown, ) -> Result391 pub fn shutdown_with_hints(
392 &mut self,
393 connection_info: &ConnectionInfo,
394 hints: StreamShutdown,
395 ) -> Result {
396 let header = VirtioVsockHdr {
397 op: VirtioVsockOp::Shutdown.into(),
398 flags: hints.into(),
399 ..connection_info.new_header(self.guest_cid)
400 };
401 self.send_packet_to_tx_queue(&header, &[])
402 }
403
404 /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
405 /// any more data.
406 ///
407 /// This returns as soon as the request is sent; you should wait until `poll` returns a
408 /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
409 /// shutdown.
shutdown(&mut self, connection_info: &ConnectionInfo) -> Result410 pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
411 self.shutdown_with_hints(
412 connection_info,
413 StreamShutdown::SEND | StreamShutdown::RECEIVE,
414 )
415 }
416
417 /// Forcibly closes the connection without waiting for the peer.
force_close(&mut self, connection_info: &ConnectionInfo) -> Result418 pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
419 let header = VirtioVsockHdr {
420 op: VirtioVsockOp::Rst.into(),
421 ..connection_info.new_header(self.guest_cid)
422 };
423 self.send_packet_to_tx_queue(&header, &[])?;
424 Ok(())
425 }
426
send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result427 fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
428 let _len = if buffer.is_empty() {
429 self.tx
430 .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
431 } else {
432 self.tx.add_notify_wait_pop(
433 &[header.as_bytes(), buffer],
434 &mut [],
435 &mut self.transport,
436 )?
437 };
438 Ok(())
439 }
440 }
441
read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])>442 fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
443 // This could fail if the device returns a buffer used length shorter than the header size.
444 let header = VirtioVsockHdr::read_from_prefix(buffer)
445 .map_err(|_| SocketError::BufferTooShort)?
446 .0;
447 let body_length = header.len() as usize;
448
449 // This could fail if the device returns an unreasonably long body length.
450 let data_end = size_of::<VirtioVsockHdr>()
451 .checked_add(body_length)
452 .ok_or(SocketError::InvalidNumber)?;
453 // This could fail if the device returns a body length longer than buffer used length it
454 // returned.
455 let data = buffer
456 .get(size_of::<VirtioVsockHdr>()..data_end)
457 .ok_or(SocketError::BufferTooShort)?;
458 Ok((header, data))
459 }
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    /// The driver reports the guest CID read from the device config space.
    #[test]
    fn config() {
        let queues = vec![
            QueueStatus::default(),
            QueueStatus::default(),
            QueueStatus::default(),
        ];
        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(0x42),
            guest_cid_high: ReadOnly::new(0),
        };
        let shared_state = Arc::new(Mutex::new(State::new(queues, config_space)));
        let fake_transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: shared_state.clone(),
        };

        let socket =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(fake_transport).unwrap();

        assert_eq!(socket.guest_cid(), 0x00_0000_0042);
    }
}
500