• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4 
//! Support for virtual sockets.
6 use std::fmt;
7 use std::{
8     io,
9     mem::{
10         size_of, {self},
11     },
12     num::ParseIntError,
13     os::{
14         raw::{c_uchar, c_uint, c_ushort},
15         unix::io::{AsRawFd, IntoRawFd, RawFd},
16     },
17     result,
18     str::FromStr,
19 };
20 
21 use libc::{
22     c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK, VMADDR_CID_ANY,
23     VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, {self},
24 };
25 use thiserror::Error;
26 
27 // The domain for vsock sockets.
28 const AF_VSOCK: sa_family_t = 40;
29 
30 // Vsock loopback address.
31 const VMADDR_CID_LOCAL: c_uint = 1;
32 
/// Vsock equivalent of binding on port 0: using this as the port in `bind` binds to a random
/// port, which can afterwards be read back with `local_port`.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
35 
36 // The number of bytes of padding to be added to the sockaddr_vm struct.  Taken directly
37 // from linux/vm_sockets.h.
38 const PADDING: usize = size_of::<sockaddr>()
39     - size_of::<sa_family_t>()
40     - size_of::<c_ushort>()
41     - (2 * size_of::<c_uint>());
42 
43 #[repr(C)]
44 #[derive(Default)]
45 struct sockaddr_vm {
46     svm_family: sa_family_t,
47     svm_reserved1: c_ushort,
48     svm_port: c_uint,
49     svm_cid: c_uint,
50     svm_zero: [c_uchar; PADDING],
51 }
52 
53 #[derive(Error, Debug)]
54 #[error("failed to parse vsock address")]
55 pub struct AddrParseError;
56 
/// The vsock equivalent of an IP address: a context identifier (CID).
///
/// The named variants stand for the well-known `VMADDR_CID_*` values; every other CID is
/// carried in [`VsockCid::Cid`]. The derived ordering follows declaration order, so the
/// variants must stay in this sequence.
#[derive(Debug, Copy, Clone)]
#[derive(Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum VsockCid {
    /// Vsock equivalent of `INADDR_ANY`; indicates the context id of the current endpoint.
    Any,
    /// The bare-metal machine that serves as the hypervisor.
    Hypervisor,
    /// The loopback address.
    Local,
    /// The parent machine, which may not be the hypervisor for nested VMs.
    Host,
    /// Any other CID, assigned to a specific vsock endpoint.
    Cid(c_uint),
}
71 
72 impl fmt::Display for VsockCid {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result73     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
74         match &self {
75             VsockCid::Any => write!(fmt, "Any"),
76             VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
77             VsockCid::Local => write!(fmt, "Local"),
78             VsockCid::Host => write!(fmt, "Host"),
79             VsockCid::Cid(c) => write!(fmt, "'{}'", c),
80         }
81     }
82 }
83 
84 impl From<c_uint> for VsockCid {
from(c: c_uint) -> Self85     fn from(c: c_uint) -> Self {
86         match c {
87             VMADDR_CID_ANY => VsockCid::Any,
88             VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
89             VMADDR_CID_LOCAL => VsockCid::Local,
90             VMADDR_CID_HOST => VsockCid::Host,
91             _ => VsockCid::Cid(c),
92         }
93     }
94 }
95 
96 impl FromStr for VsockCid {
97     type Err = ParseIntError;
98 
from_str(s: &str) -> Result<Self, Self::Err>99     fn from_str(s: &str) -> Result<Self, Self::Err> {
100         let c: c_uint = s.parse()?;
101         Ok(c.into())
102     }
103 }
104 
105 impl From<VsockCid> for c_uint {
from(cid: VsockCid) -> c_uint106     fn from(cid: VsockCid) -> c_uint {
107         match cid {
108             VsockCid::Any => VMADDR_CID_ANY,
109             VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
110             VsockCid::Local => VMADDR_CID_LOCAL,
111             VsockCid::Host => VMADDR_CID_HOST,
112             VsockCid::Cid(c) => c,
113         }
114     }
115 }
116 
117 /// An address associated with a virtual socket.
118 #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
119 pub struct SocketAddr {
120     pub cid: VsockCid,
121     pub port: c_uint,
122 }
123 
124 pub trait ToSocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>125     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
126 }
127 
128 impl ToSocketAddr for SocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>129     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
130         Ok(*self)
131     }
132 }
133 
134 impl ToSocketAddr for str {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>135     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136         self.parse()
137     }
138 }
139 
140 impl ToSocketAddr for (VsockCid, c_uint) {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>141     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142         let (cid, port) = *self;
143         Ok(SocketAddr { cid, port })
144     }
145 }
146 
147 impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>148     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
149         (**self).to_socket_addr()
150     }
151 }
152 
153 impl FromStr for SocketAddr {
154     type Err = AddrParseError;
155 
156     /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
157     /// "vsock:cid:port".
from_str(s: &str) -> Result<SocketAddr, AddrParseError>158     fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
159         let components: Vec<&str> = s.split(':').collect();
160         if components.len() != 3 || components[0] != "vsock" {
161             return Err(AddrParseError);
162         }
163 
164         Ok(SocketAddr {
165             cid: components[1].parse().map_err(|_| AddrParseError)?,
166             port: components[2].parse().map_err(|_| AddrParseError)?,
167         })
168     }
169 }
170 
171 impl fmt::Display for SocketAddr {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result172     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
173         write!(fmt, "{}:{}", self.cid, self.port)
174     }
175 }
176 
177 /// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
178 /// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()>179 unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
180     let flags = libc::fcntl(fd, F_GETFL, 0);
181     if flags < 0 {
182         return Err(io::Error::last_os_error());
183     }
184 
185     let flags = if nonblocking {
186         flags | O_NONBLOCK
187     } else {
188         flags & !O_NONBLOCK
189     };
190 
191     let ret = libc::fcntl(fd, F_SETFL, flags);
192     if ret < 0 {
193         return Err(io::Error::last_os_error());
194     }
195 
196     Ok(())
197 }
198 
/// A virtual socket that owns its file descriptor.
///
/// Only use this type directly when you need to change socket options or query the state of
/// the socket before calling `listen` or `connect`; otherwise use `VsockStream` or
/// `VsockListener`. The descriptor is closed when the value is dropped.
#[derive(Debug)]
pub struct VsockSocket {
    /// The owned socket descriptor.
    fd: RawFd,
}
208 
209 impl VsockSocket {
new() -> io::Result<Self>210     pub fn new() -> io::Result<Self> {
211         let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
212         if fd < 0 {
213             Err(io::Error::last_os_error())
214         } else {
215             Ok(VsockSocket { fd })
216         }
217     }
218 
bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()>219     pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
220         let sockaddr = addr
221             .to_socket_addr()
222             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
223 
224         // The compiler should optimize this out since these are both compile-time constants.
225         assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
226 
227         let svm = sockaddr_vm {
228             svm_family: AF_VSOCK,
229             svm_cid: sockaddr.cid.into(),
230             svm_port: sockaddr.port,
231             ..Default::default()
232         };
233 
234         // Safe because this doesn't modify any memory and we check the return value.
235         let ret = unsafe {
236             libc::bind(
237                 self.fd,
238                 &svm as *const sockaddr_vm as *const sockaddr,
239                 size_of::<sockaddr_vm>() as socklen_t,
240             )
241         };
242         if ret < 0 {
243             let bind_err = io::Error::last_os_error();
244             Err(bind_err)
245         } else {
246             Ok(())
247         }
248     }
249 
connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream>250     pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
251         let sockaddr = addr
252             .to_socket_addr()
253             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
254 
255         let svm = sockaddr_vm {
256             svm_family: AF_VSOCK,
257             svm_cid: sockaddr.cid.into(),
258             svm_port: sockaddr.port,
259             ..Default::default()
260         };
261 
262         // Safe because this just connects a vsock socket, and the return value is checked.
263         let ret = unsafe {
264             libc::connect(
265                 self.fd,
266                 &svm as *const sockaddr_vm as *const sockaddr,
267                 size_of::<sockaddr_vm>() as socklen_t,
268             )
269         };
270         if ret < 0 {
271             let connect_err = io::Error::last_os_error();
272             Err(connect_err)
273         } else {
274             Ok(VsockStream { sock: self })
275         }
276     }
277 
listen(self) -> io::Result<VsockListener>278     pub fn listen(self) -> io::Result<VsockListener> {
279         // Safe because this doesn't modify any memory and we check the return value.
280         let ret = unsafe { libc::listen(self.fd, 1) };
281         if ret < 0 {
282             let listen_err = io::Error::last_os_error();
283             return Err(listen_err);
284         }
285         Ok(VsockListener { sock: self })
286     }
287 
288     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u32>289     pub fn local_port(&self) -> io::Result<u32> {
290         let mut svm: sockaddr_vm = Default::default();
291 
292         // Safe because we give a valid pointer for addrlen and check the length.
293         let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
294         let ret = unsafe {
295             // Get the socket address that was actually bound.
296             libc::getsockname(
297                 self.fd,
298                 &mut svm as *mut sockaddr_vm as *mut sockaddr,
299                 &mut addrlen as *mut socklen_t,
300             )
301         };
302         if ret < 0 {
303             let getsockname_err = io::Error::last_os_error();
304             Err(getsockname_err)
305         } else {
306             // If this doesn't match, it's not safe to get the port out of the sockaddr.
307             assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
308 
309             Ok(svm.svm_port)
310         }
311     }
312 
try_clone(&self) -> io::Result<Self>313     pub fn try_clone(&self) -> io::Result<Self> {
314         // Safe because this doesn't modify any memory and we check the return value.
315         let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
316         if dup_fd < 0 {
317             Err(io::Error::last_os_error())
318         } else {
319             Ok(Self { fd: dup_fd })
320         }
321     }
322 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>323     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
324         // Safe because the fd is valid and owned by this stream.
325         unsafe { set_nonblocking(self.fd, nonblocking) }
326     }
327 }
328 
329 impl IntoRawFd for VsockSocket {
into_raw_fd(self) -> RawFd330     fn into_raw_fd(self) -> RawFd {
331         let fd = self.fd;
332         mem::forget(self);
333         fd
334     }
335 }
336 
337 impl AsRawFd for VsockSocket {
as_raw_fd(&self) -> RawFd338     fn as_raw_fd(&self) -> RawFd {
339         self.fd
340     }
341 }
342 
343 impl Drop for VsockSocket {
drop(&mut self)344     fn drop(&mut self) {
345         // Safe because this doesn't modify any memory and we are the only
346         // owner of the file descriptor.
347         unsafe { libc::close(self.fd) };
348     }
349 }
350 
351 /// A virtual stream socket.
352 #[derive(Debug)]
353 pub struct VsockStream {
354     sock: VsockSocket,
355 }
356 
357 impl VsockStream {
connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream>358     pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
359         let sock = VsockSocket::new()?;
360         sock.connect(addr)
361     }
362 
363     /// Returns the port that this stream is bound to.
local_port(&self) -> io::Result<u32>364     pub fn local_port(&self) -> io::Result<u32> {
365         self.sock.local_port()
366     }
367 
try_clone(&self) -> io::Result<VsockStream>368     pub fn try_clone(&self) -> io::Result<VsockStream> {
369         self.sock.try_clone().map(|f| VsockStream { sock: f })
370     }
371 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>372     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
373         self.sock.set_nonblocking(nonblocking)
374     }
375 }
376 
377 impl io::Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>378     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
379         // Safe because this will only modify the contents of |buf| and we check the return value.
380         let ret = unsafe {
381             libc::read(
382                 self.sock.as_raw_fd(),
383                 buf as *mut [u8] as *mut c_void,
384                 buf.len() as size_t,
385             )
386         };
387         if ret < 0 {
388             return Err(io::Error::last_os_error());
389         }
390 
391         Ok(ret as usize)
392     }
393 }
394 
395 impl io::Write for VsockStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>396     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
397         // Safe because this doesn't modify any memory and we check the return value.
398         let ret = unsafe {
399             libc::write(
400                 self.sock.as_raw_fd(),
401                 buf as *const [u8] as *const c_void,
402                 buf.len() as size_t,
403             )
404         };
405         if ret < 0 {
406             return Err(io::Error::last_os_error());
407         }
408 
409         Ok(ret as usize)
410     }
411 
flush(&mut self) -> io::Result<()>412     fn flush(&mut self) -> io::Result<()> {
413         // No buffered data so nothing to do.
414         Ok(())
415     }
416 }
417 
418 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd419     fn as_raw_fd(&self) -> RawFd {
420         self.sock.as_raw_fd()
421     }
422 }
423 
424 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd425     fn into_raw_fd(self) -> RawFd {
426         self.sock.into_raw_fd()
427     }
428 }
429 
430 /// Represents a virtual socket server.
431 #[derive(Debug)]
432 pub struct VsockListener {
433     sock: VsockSocket,
434 }
435 
436 impl VsockListener {
437     /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
438     /// endpoint.
bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener>439     pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
440         let mut sock = VsockSocket::new()?;
441         sock.bind(addr)?;
442         sock.listen()
443     }
444 
445     /// Returns the port that this listener is bound to.
local_port(&self) -> io::Result<u32>446     pub fn local_port(&self) -> io::Result<u32> {
447         self.sock.local_port()
448     }
449 
450     /// Accepts a new incoming connection on this listener.  Blocks the calling thread until a
451     /// new connection is established.  When established, returns the corresponding `VsockStream`
452     /// and the remote peer's address.
accept(&self) -> io::Result<(VsockStream, SocketAddr)>453     pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
454         let mut svm: sockaddr_vm = Default::default();
455 
456         // Safe because this will only modify |svm| and we check the return value.
457         let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
458         let fd = unsafe {
459             libc::accept4(
460                 self.sock.as_raw_fd(),
461                 &mut svm as *mut sockaddr_vm as *mut sockaddr,
462                 &mut socklen as *mut socklen_t,
463                 libc::SOCK_CLOEXEC,
464             )
465         };
466         if fd < 0 {
467             return Err(io::Error::last_os_error());
468         }
469 
470         if svm.svm_family != AF_VSOCK {
471             return Err(io::Error::new(
472                 io::ErrorKind::InvalidData,
473                 format!("unexpected address family: {}", svm.svm_family),
474             ));
475         }
476 
477         Ok((
478             VsockStream {
479                 sock: VsockSocket { fd },
480             },
481             SocketAddr {
482                 cid: svm.svm_cid.into(),
483                 port: svm.svm_port,
484             },
485         ))
486     }
487 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>488     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
489         self.sock.set_nonblocking(nonblocking)
490     }
491 }
492 
493 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd494     fn as_raw_fd(&self) -> RawFd {
495         self.sock.as_raw_fd()
496     }
497 }
498