1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
//! Support for virtual sockets.
6 use std::fmt;
7 use std::{
8 io,
9 mem::{
10 size_of, {self},
11 },
12 num::ParseIntError,
13 os::{
14 raw::{c_uchar, c_uint, c_ushort},
15 unix::io::{AsRawFd, IntoRawFd, RawFd},
16 },
17 result,
18 str::FromStr,
19 };
20
21 use libc::{
22 c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK, VMADDR_CID_ANY,
23 VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, {self},
24 };
25 use thiserror::Error;
26
27 // The domain for vsock sockets.
28 const AF_VSOCK: sa_family_t = 40;
29
// Vsock loopback address.
const VMADDR_CID_LOCAL: c_uint = 1;

/// Vsock equivalent of binding on port 0. Binds to a random port.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
35
36 // The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
37 // from linux/vm_sockets.h.
38 const PADDING: usize = size_of::<sockaddr>()
39 - size_of::<sa_family_t>()
40 - size_of::<c_ushort>()
41 - (2 * size_of::<c_uint>());
42
/// Rust mirror of the kernel's `struct sockaddr_vm` (see linux/vm_sockets.h).
///
/// `#[repr(C)]` keeps the field order and layout C-compatible, which lets a pointer to it be cast
/// to `*const sockaddr` / `*mut sockaddr` for `bind`, `connect`, `getsockname` and `accept4`.
/// Field order must not change.
#[repr(C)]
#[derive(Default)]
struct sockaddr_vm {
    // Address family; set to AF_VSOCK when the struct is handed to the kernel.
    svm_family: sa_family_t,
    // Reserved; left zeroed via `Default` by every constructor in this file.
    svm_reserved1: c_ushort,
    // Port number.
    svm_port: c_uint,
    // Context id (CID) of the endpoint.
    svm_cid: c_uint,
    // Zero padding so the struct is exactly as large as `sockaddr` (see PADDING).
    svm_zero: [c_uchar; PADDING],
}
52
53 #[derive(Error, Debug)]
54 #[error("failed to parse vsock address")]
55 pub struct AddrParseError;
56
/// Context identifier (CID) of a vsock endpoint: the vsock counterpart of an IP address.
///
/// Well-known CIDs get dedicated variants; every other value is carried in `Cid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum VsockCid {
    /// Wildcard CID, the vsock counterpart of INADDR_ANY. Indicates the context id of the
    /// current endpoint.
    Any,
    /// The bare-metal machine acting as the hypervisor.
    Hypervisor,
    /// The loopback address.
    Local,
    /// The parent machine, which need not be the hypervisor when VMs are nested.
    Host,
    /// Any other, explicitly assigned CID.
    Cid(c_uint),
}
71
72 impl fmt::Display for VsockCid {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result73 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
74 match &self {
75 VsockCid::Any => write!(fmt, "Any"),
76 VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
77 VsockCid::Local => write!(fmt, "Local"),
78 VsockCid::Host => write!(fmt, "Host"),
79 VsockCid::Cid(c) => write!(fmt, "'{}'", c),
80 }
81 }
82 }
83
84 impl From<c_uint> for VsockCid {
from(c: c_uint) -> Self85 fn from(c: c_uint) -> Self {
86 match c {
87 VMADDR_CID_ANY => VsockCid::Any,
88 VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
89 VMADDR_CID_LOCAL => VsockCid::Local,
90 VMADDR_CID_HOST => VsockCid::Host,
91 _ => VsockCid::Cid(c),
92 }
93 }
94 }
95
96 impl FromStr for VsockCid {
97 type Err = ParseIntError;
98
from_str(s: &str) -> Result<Self, Self::Err>99 fn from_str(s: &str) -> Result<Self, Self::Err> {
100 let c: c_uint = s.parse()?;
101 Ok(c.into())
102 }
103 }
104
105 impl From<VsockCid> for c_uint {
from(cid: VsockCid) -> c_uint106 fn from(cid: VsockCid) -> c_uint {
107 match cid {
108 VsockCid::Any => VMADDR_CID_ANY,
109 VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
110 VsockCid::Local => VMADDR_CID_LOCAL,
111 VsockCid::Host => VMADDR_CID_HOST,
112 VsockCid::Cid(c) => c,
113 }
114 }
115 }
116
117 /// An address associated with a virtual socket.
118 #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
119 pub struct SocketAddr {
120 pub cid: VsockCid,
121 pub port: c_uint,
122 }
123
124 pub trait ToSocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>125 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
126 }
127
128 impl ToSocketAddr for SocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>129 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
130 Ok(*self)
131 }
132 }
133
134 impl ToSocketAddr for str {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>135 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
136 self.parse()
137 }
138 }
139
140 impl ToSocketAddr for (VsockCid, c_uint) {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>141 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
142 let (cid, port) = *self;
143 Ok(SocketAddr { cid, port })
144 }
145 }
146
147 impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>148 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
149 (**self).to_socket_addr()
150 }
151 }
152
153 impl FromStr for SocketAddr {
154 type Err = AddrParseError;
155
156 /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
157 /// "vsock:cid:port".
from_str(s: &str) -> Result<SocketAddr, AddrParseError>158 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
159 let components: Vec<&str> = s.split(':').collect();
160 if components.len() != 3 || components[0] != "vsock" {
161 return Err(AddrParseError);
162 }
163
164 Ok(SocketAddr {
165 cid: components[1].parse().map_err(|_| AddrParseError)?,
166 port: components[2].parse().map_err(|_| AddrParseError)?,
167 })
168 }
169 }
170
171 impl fmt::Display for SocketAddr {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result172 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
173 write!(fmt, "{}:{}", self.cid, self.port)
174 }
175 }
176
177 /// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
178 /// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()>179 unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
180 let flags = libc::fcntl(fd, F_GETFL, 0);
181 if flags < 0 {
182 return Err(io::Error::last_os_error());
183 }
184
185 let flags = if nonblocking {
186 flags | O_NONBLOCK
187 } else {
188 flags & !O_NONBLOCK
189 };
190
191 let ret = libc::fcntl(fd, F_SETFL, flags);
192 if ret < 0 {
193 return Err(io::Error::last_os_error());
194 }
195
196 Ok(())
197 }
198
/// A virtual socket.
///
/// Only use this type directly when you need to change socket options or query the socket's
/// state before calling `listen` or `connect`. Otherwise prefer `VsockStream` or
/// `VsockListener`.
///
/// The wrapped descriptor is closed when the value is dropped.
#[derive(Debug)]
pub struct VsockSocket {
    // Descriptor obtained from socket(2), fcntl(F_DUPFD_CLOEXEC) or accept4(2).
    fd: RawFd,
}
208
209 impl VsockSocket {
new() -> io::Result<Self>210 pub fn new() -> io::Result<Self> {
211 let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
212 if fd < 0 {
213 Err(io::Error::last_os_error())
214 } else {
215 Ok(VsockSocket { fd })
216 }
217 }
218
bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()>219 pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
220 let sockaddr = addr
221 .to_socket_addr()
222 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
223
224 // The compiler should optimize this out since these are both compile-time constants.
225 assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
226
227 let svm = sockaddr_vm {
228 svm_family: AF_VSOCK,
229 svm_cid: sockaddr.cid.into(),
230 svm_port: sockaddr.port,
231 ..Default::default()
232 };
233
234 // Safe because this doesn't modify any memory and we check the return value.
235 let ret = unsafe {
236 libc::bind(
237 self.fd,
238 &svm as *const sockaddr_vm as *const sockaddr,
239 size_of::<sockaddr_vm>() as socklen_t,
240 )
241 };
242 if ret < 0 {
243 let bind_err = io::Error::last_os_error();
244 Err(bind_err)
245 } else {
246 Ok(())
247 }
248 }
249
connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream>250 pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
251 let sockaddr = addr
252 .to_socket_addr()
253 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
254
255 let svm = sockaddr_vm {
256 svm_family: AF_VSOCK,
257 svm_cid: sockaddr.cid.into(),
258 svm_port: sockaddr.port,
259 ..Default::default()
260 };
261
262 // Safe because this just connects a vsock socket, and the return value is checked.
263 let ret = unsafe {
264 libc::connect(
265 self.fd,
266 &svm as *const sockaddr_vm as *const sockaddr,
267 size_of::<sockaddr_vm>() as socklen_t,
268 )
269 };
270 if ret < 0 {
271 let connect_err = io::Error::last_os_error();
272 Err(connect_err)
273 } else {
274 Ok(VsockStream { sock: self })
275 }
276 }
277
listen(self) -> io::Result<VsockListener>278 pub fn listen(self) -> io::Result<VsockListener> {
279 // Safe because this doesn't modify any memory and we check the return value.
280 let ret = unsafe { libc::listen(self.fd, 1) };
281 if ret < 0 {
282 let listen_err = io::Error::last_os_error();
283 return Err(listen_err);
284 }
285 Ok(VsockListener { sock: self })
286 }
287
288 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u32>289 pub fn local_port(&self) -> io::Result<u32> {
290 let mut svm: sockaddr_vm = Default::default();
291
292 // Safe because we give a valid pointer for addrlen and check the length.
293 let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
294 let ret = unsafe {
295 // Get the socket address that was actually bound.
296 libc::getsockname(
297 self.fd,
298 &mut svm as *mut sockaddr_vm as *mut sockaddr,
299 &mut addrlen as *mut socklen_t,
300 )
301 };
302 if ret < 0 {
303 let getsockname_err = io::Error::last_os_error();
304 Err(getsockname_err)
305 } else {
306 // If this doesn't match, it's not safe to get the port out of the sockaddr.
307 assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
308
309 Ok(svm.svm_port)
310 }
311 }
312
try_clone(&self) -> io::Result<Self>313 pub fn try_clone(&self) -> io::Result<Self> {
314 // Safe because this doesn't modify any memory and we check the return value.
315 let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
316 if dup_fd < 0 {
317 Err(io::Error::last_os_error())
318 } else {
319 Ok(Self { fd: dup_fd })
320 }
321 }
322
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>323 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
324 // Safe because the fd is valid and owned by this stream.
325 unsafe { set_nonblocking(self.fd, nonblocking) }
326 }
327 }
328
329 impl IntoRawFd for VsockSocket {
into_raw_fd(self) -> RawFd330 fn into_raw_fd(self) -> RawFd {
331 let fd = self.fd;
332 mem::forget(self);
333 fd
334 }
335 }
336
337 impl AsRawFd for VsockSocket {
as_raw_fd(&self) -> RawFd338 fn as_raw_fd(&self) -> RawFd {
339 self.fd
340 }
341 }
342
343 impl Drop for VsockSocket {
drop(&mut self)344 fn drop(&mut self) {
345 // Safe because this doesn't modify any memory and we are the only
346 // owner of the file descriptor.
347 unsafe { libc::close(self.fd) };
348 }
349 }
350
351 /// A virtual stream socket.
352 #[derive(Debug)]
353 pub struct VsockStream {
354 sock: VsockSocket,
355 }
356
357 impl VsockStream {
connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream>358 pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
359 let sock = VsockSocket::new()?;
360 sock.connect(addr)
361 }
362
363 /// Returns the port that this stream is bound to.
local_port(&self) -> io::Result<u32>364 pub fn local_port(&self) -> io::Result<u32> {
365 self.sock.local_port()
366 }
367
try_clone(&self) -> io::Result<VsockStream>368 pub fn try_clone(&self) -> io::Result<VsockStream> {
369 self.sock.try_clone().map(|f| VsockStream { sock: f })
370 }
371
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>372 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
373 self.sock.set_nonblocking(nonblocking)
374 }
375 }
376
377 impl io::Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>378 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
379 // Safe because this will only modify the contents of |buf| and we check the return value.
380 let ret = unsafe {
381 libc::read(
382 self.sock.as_raw_fd(),
383 buf as *mut [u8] as *mut c_void,
384 buf.len() as size_t,
385 )
386 };
387 if ret < 0 {
388 return Err(io::Error::last_os_error());
389 }
390
391 Ok(ret as usize)
392 }
393 }
394
395 impl io::Write for VsockStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>396 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
397 // Safe because this doesn't modify any memory and we check the return value.
398 let ret = unsafe {
399 libc::write(
400 self.sock.as_raw_fd(),
401 buf as *const [u8] as *const c_void,
402 buf.len() as size_t,
403 )
404 };
405 if ret < 0 {
406 return Err(io::Error::last_os_error());
407 }
408
409 Ok(ret as usize)
410 }
411
flush(&mut self) -> io::Result<()>412 fn flush(&mut self) -> io::Result<()> {
413 // No buffered data so nothing to do.
414 Ok(())
415 }
416 }
417
418 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd419 fn as_raw_fd(&self) -> RawFd {
420 self.sock.as_raw_fd()
421 }
422 }
423
424 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd425 fn into_raw_fd(self) -> RawFd {
426 self.sock.into_raw_fd()
427 }
428 }
429
430 /// Represents a virtual socket server.
431 #[derive(Debug)]
432 pub struct VsockListener {
433 sock: VsockSocket,
434 }
435
436 impl VsockListener {
437 /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
438 /// endpoint.
bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener>439 pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
440 let mut sock = VsockSocket::new()?;
441 sock.bind(addr)?;
442 sock.listen()
443 }
444
445 /// Returns the port that this listener is bound to.
local_port(&self) -> io::Result<u32>446 pub fn local_port(&self) -> io::Result<u32> {
447 self.sock.local_port()
448 }
449
450 /// Accepts a new incoming connection on this listener. Blocks the calling thread until a
451 /// new connection is established. When established, returns the corresponding `VsockStream`
452 /// and the remote peer's address.
accept(&self) -> io::Result<(VsockStream, SocketAddr)>453 pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
454 let mut svm: sockaddr_vm = Default::default();
455
456 // Safe because this will only modify |svm| and we check the return value.
457 let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
458 let fd = unsafe {
459 libc::accept4(
460 self.sock.as_raw_fd(),
461 &mut svm as *mut sockaddr_vm as *mut sockaddr,
462 &mut socklen as *mut socklen_t,
463 libc::SOCK_CLOEXEC,
464 )
465 };
466 if fd < 0 {
467 return Err(io::Error::last_os_error());
468 }
469
470 if svm.svm_family != AF_VSOCK {
471 return Err(io::Error::new(
472 io::ErrorKind::InvalidData,
473 format!("unexpected address family: {}", svm.svm_family),
474 ));
475 }
476
477 Ok((
478 VsockStream {
479 sock: VsockSocket { fd },
480 },
481 SocketAddr {
482 cid: svm.svm_cid.into(),
483 port: svm.svm_port,
484 },
485 ))
486 }
487
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>488 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
489 self.sock.set_nonblocking(nonblocking)
490 }
491 }
492
493 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd494 fn as_raw_fd(&self) -> RawFd {
495 self.sock.as_raw_fd()
496 }
497 }
498