1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
//! Support for virtual sockets.
6 use std::fmt;
7 use std::io;
8 use std::mem::{self, size_of};
9 use std::num::ParseIntError;
10 use std::os::raw::{c_uchar, c_uint, c_ushort};
11 use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
12 use std::result;
13 use std::str::FromStr;
14
15 use libc::{
16 self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
17 VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
18 };
19
20 // The domain for vsock sockets.
21 const AF_VSOCK: sa_family_t = 40;
22
/// Context ID of the vsock loopback address.
const VMADDR_CID_LOCAL: c_uint = 1;
25
/// Vsock equivalent of binding on port 0: binds to a random port.
///
/// After binding with this port, the port actually chosen can be read back with
/// `VsockSocket::local_port`.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
28
29 // The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
30 // from linux/vm_sockets.h.
31 const PADDING: usize = size_of::<sockaddr>()
32 - size_of::<sa_family_t>()
33 - size_of::<c_ushort>()
34 - (2 * size_of::<c_uint>());
35
36 #[repr(C)]
37 #[derive(Default)]
38 struct sockaddr_vm {
39 svm_family: sa_family_t,
40 svm_reserved1: c_ushort,
41 svm_port: c_uint,
42 svm_cid: c_uint,
43 svm_zero: [c_uchar; PADDING],
44 }
45
/// Error returned when a value cannot be converted into a vsock `SocketAddr`,
/// e.g. a string that is not of the form `"vsock:<cid>:<port>"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AddrParseError;

impl fmt::Display for AddrParseError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "failed to parse vsock address")
    }
}

// Implementing `Error` lets callers box this into `Box<dyn std::error::Error>` and propagate
// it with `?` from functions that return generic error types.
impl std::error::Error for AddrParseError {}
54
/// Context identifier (CID) of a vsock endpoint; the vsock counterpart of an IP address.
///
/// Well-known CIDs have named variants. Every other value is carried in `VsockCid::Cid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum VsockCid {
    /// Vsock equivalent of `INADDR_ANY`: denotes the context id of the current endpoint.
    Any,
    /// The bare-metal machine that acts as the hypervisor.
    Hypervisor,
    /// The loopback address.
    Local,
    /// The parent machine. For nested VMs this need not be the hypervisor.
    Host,
    /// Any other, explicitly assigned CID.
    Cid(c_uint),
}
69
70 impl fmt::Display for VsockCid {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result71 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
72 match &self {
73 VsockCid::Any => write!(fmt, "Any"),
74 VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
75 VsockCid::Local => write!(fmt, "Local"),
76 VsockCid::Host => write!(fmt, "Host"),
77 VsockCid::Cid(c) => write!(fmt, "'{}'", c),
78 }
79 }
80 }
81
82 impl From<c_uint> for VsockCid {
from(c: c_uint) -> Self83 fn from(c: c_uint) -> Self {
84 match c {
85 VMADDR_CID_ANY => VsockCid::Any,
86 VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
87 VMADDR_CID_LOCAL => VsockCid::Local,
88 VMADDR_CID_HOST => VsockCid::Host,
89 _ => VsockCid::Cid(c),
90 }
91 }
92 }
93
94 impl FromStr for VsockCid {
95 type Err = ParseIntError;
96
from_str(s: &str) -> Result<Self, Self::Err>97 fn from_str(s: &str) -> Result<Self, Self::Err> {
98 let c: c_uint = s.parse()?;
99 Ok(c.into())
100 }
101 }
102
103 impl Into<c_uint> for VsockCid {
into(self) -> c_uint104 fn into(self) -> c_uint {
105 match self {
106 VsockCid::Any => VMADDR_CID_ANY,
107 VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
108 VsockCid::Local => VMADDR_CID_LOCAL,
109 VsockCid::Host => VMADDR_CID_HOST,
110 VsockCid::Cid(c) => c,
111 }
112 }
113 }
114
115 /// An address associated with a virtual socket.
116 #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
117 pub struct SocketAddr {
118 pub cid: VsockCid,
119 pub port: c_uint,
120 }
121
122 pub trait ToSocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>123 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
124 }
125
126 impl ToSocketAddr for SocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>127 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
128 Ok(*self)
129 }
130 }
131
132 impl ToSocketAddr for str {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>133 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
134 self.parse()
135 }
136 }
137
138 impl ToSocketAddr for (VsockCid, c_uint) {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>139 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
140 let (cid, port) = *self;
141 Ok(SocketAddr { cid, port })
142 }
143 }
144
145 impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>146 fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
147 (**self).to_socket_addr()
148 }
149 }
150
151 impl FromStr for SocketAddr {
152 type Err = AddrParseError;
153
154 /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
155 /// "vsock:cid:port".
from_str(s: &str) -> Result<SocketAddr, AddrParseError>156 fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
157 let components: Vec<&str> = s.split(':').collect();
158 if components.len() != 3 || components[0] != "vsock" {
159 return Err(AddrParseError);
160 }
161
162 Ok(SocketAddr {
163 cid: components[1].parse().map_err(|_| AddrParseError)?,
164 port: components[2].parse().map_err(|_| AddrParseError)?,
165 })
166 }
167 }
168
169 impl fmt::Display for SocketAddr {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result170 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
171 write!(fmt, "{}:{}", self.cid, self.port)
172 }
173 }
174
175 /// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
176 /// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()>177 unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
178 let flags = libc::fcntl(fd, F_GETFL, 0);
179 if flags < 0 {
180 return Err(io::Error::last_os_error());
181 }
182
183 let flags = if nonblocking {
184 flags | O_NONBLOCK
185 } else {
186 flags & !O_NONBLOCK
187 };
188
189 let ret = libc::fcntl(fd, F_SETFL, flags);
190 if ret < 0 {
191 return Err(io::Error::last_os_error());
192 }
193
194 Ok(())
195 }
196
/// A virtual socket that owns its file descriptor.
///
/// Prefer `VsockStream` or `VsockListener`. Use this type only when socket options or socket
/// state must be changed or queried before calling `listen` or `connect`.
#[derive(Debug)]
pub struct VsockSocket {
    /// The owned socket descriptor.
    fd: RawFd,
}
206
207 impl VsockSocket {
new() -> io::Result<Self>208 pub fn new() -> io::Result<Self> {
209 let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
210 if fd < 0 {
211 Err(io::Error::last_os_error())
212 } else {
213 Ok(VsockSocket { fd })
214 }
215 }
216
bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()>217 pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
218 let sockaddr = addr
219 .to_socket_addr()
220 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
221
222 // The compiler should optimize this out since these are both compile-time constants.
223 assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
224
225 let mut svm: sockaddr_vm = Default::default();
226 svm.svm_family = AF_VSOCK;
227 svm.svm_cid = sockaddr.cid.into();
228 svm.svm_port = sockaddr.port;
229
230 // Safe because this doesn't modify any memory and we check the return value.
231 let ret = unsafe {
232 libc::bind(
233 self.fd,
234 &svm as *const sockaddr_vm as *const sockaddr,
235 size_of::<sockaddr_vm>() as socklen_t,
236 )
237 };
238 if ret < 0 {
239 let bind_err = io::Error::last_os_error();
240 Err(bind_err)
241 } else {
242 Ok(())
243 }
244 }
245
connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream>246 pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
247 let sockaddr = addr
248 .to_socket_addr()
249 .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
250
251 let mut svm: sockaddr_vm = Default::default();
252 svm.svm_family = AF_VSOCK;
253 svm.svm_cid = sockaddr.cid.into();
254 svm.svm_port = sockaddr.port;
255
256 // Safe because this just connects a vsock socket, and the return value is checked.
257 let ret = unsafe {
258 libc::connect(
259 self.fd,
260 &svm as *const sockaddr_vm as *const sockaddr,
261 size_of::<sockaddr_vm>() as socklen_t,
262 )
263 };
264 if ret < 0 {
265 let connect_err = io::Error::last_os_error();
266 Err(connect_err)
267 } else {
268 Ok(VsockStream { sock: self })
269 }
270 }
271
listen(self) -> io::Result<VsockListener>272 pub fn listen(self) -> io::Result<VsockListener> {
273 // Safe because this doesn't modify any memory and we check the return value.
274 let ret = unsafe { libc::listen(self.fd, 1) };
275 if ret < 0 {
276 let listen_err = io::Error::last_os_error();
277 return Err(listen_err);
278 }
279 Ok(VsockListener { sock: self })
280 }
281
282 /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u32>283 pub fn local_port(&self) -> io::Result<u32> {
284 let mut svm: sockaddr_vm = Default::default();
285
286 // Safe because we give a valid pointer for addrlen and check the length.
287 let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
288 let ret = unsafe {
289 // Get the socket address that was actually bound.
290 libc::getsockname(
291 self.fd,
292 &mut svm as *mut sockaddr_vm as *mut sockaddr,
293 &mut addrlen as *mut socklen_t,
294 )
295 };
296 if ret < 0 {
297 let getsockname_err = io::Error::last_os_error();
298 Err(getsockname_err)
299 } else {
300 // If this doesn't match, it's not safe to get the port out of the sockaddr.
301 assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
302
303 Ok(svm.svm_port)
304 }
305 }
306
try_clone(&self) -> io::Result<Self>307 pub fn try_clone(&self) -> io::Result<Self> {
308 // Safe because this doesn't modify any memory and we check the return value.
309 let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
310 if dup_fd < 0 {
311 Err(io::Error::last_os_error())
312 } else {
313 Ok(Self { fd: dup_fd })
314 }
315 }
316
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>317 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
318 // Safe because the fd is valid and owned by this stream.
319 unsafe { set_nonblocking(self.fd, nonblocking) }
320 }
321 }
322
323 impl IntoRawFd for VsockSocket {
into_raw_fd(self) -> RawFd324 fn into_raw_fd(self) -> RawFd {
325 let fd = self.fd;
326 mem::forget(self);
327 fd
328 }
329 }
330
331 impl AsRawFd for VsockSocket {
as_raw_fd(&self) -> RawFd332 fn as_raw_fd(&self) -> RawFd {
333 self.fd
334 }
335 }
336
337 impl Drop for VsockSocket {
drop(&mut self)338 fn drop(&mut self) {
339 // Safe because this doesn't modify any memory and we are the only
340 // owner of the file descriptor.
341 unsafe { libc::close(self.fd) };
342 }
343 }
344
345 /// A virtual stream socket.
346 #[derive(Debug)]
347 pub struct VsockStream {
348 sock: VsockSocket,
349 }
350
351 impl VsockStream {
connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream>352 pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
353 let sock = VsockSocket::new()?;
354 sock.connect(addr)
355 }
356
357 /// Returns the port that this stream is bound to.
local_port(&self) -> io::Result<u32>358 pub fn local_port(&self) -> io::Result<u32> {
359 self.sock.local_port()
360 }
361
try_clone(&self) -> io::Result<VsockStream>362 pub fn try_clone(&self) -> io::Result<VsockStream> {
363 self.sock.try_clone().map(|f| VsockStream { sock: f })
364 }
365
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>366 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
367 self.sock.set_nonblocking(nonblocking)
368 }
369 }
370
371 impl io::Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>372 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
373 // Safe because this will only modify the contents of |buf| and we check the return value.
374 let ret = unsafe {
375 libc::read(
376 self.sock.as_raw_fd(),
377 buf as *mut [u8] as *mut c_void,
378 buf.len() as size_t,
379 )
380 };
381 if ret < 0 {
382 return Err(io::Error::last_os_error());
383 }
384
385 Ok(ret as usize)
386 }
387 }
388
389 impl io::Write for VsockStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>390 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
391 // Safe because this doesn't modify any memory and we check the return value.
392 let ret = unsafe {
393 libc::write(
394 self.sock.as_raw_fd(),
395 buf as *const [u8] as *const c_void,
396 buf.len() as size_t,
397 )
398 };
399 if ret < 0 {
400 return Err(io::Error::last_os_error());
401 }
402
403 Ok(ret as usize)
404 }
405
flush(&mut self) -> io::Result<()>406 fn flush(&mut self) -> io::Result<()> {
407 // No buffered data so nothing to do.
408 Ok(())
409 }
410 }
411
412 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd413 fn as_raw_fd(&self) -> RawFd {
414 self.sock.as_raw_fd()
415 }
416 }
417
418 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd419 fn into_raw_fd(self) -> RawFd {
420 self.sock.into_raw_fd()
421 }
422 }
423
424 /// Represents a virtual socket server.
425 #[derive(Debug)]
426 pub struct VsockListener {
427 sock: VsockSocket,
428 }
429
430 impl VsockListener {
431 /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
432 /// endpoint.
bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener>433 pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
434 let mut sock = VsockSocket::new()?;
435 sock.bind(addr)?;
436 sock.listen()
437 }
438
439 /// Returns the port that this listener is bound to.
local_port(&self) -> io::Result<u32>440 pub fn local_port(&self) -> io::Result<u32> {
441 self.sock.local_port()
442 }
443
444 /// Accepts a new incoming connection on this listener. Blocks the calling thread until a
445 /// new connection is established. When established, returns the corresponding `VsockStream`
446 /// and the remote peer's address.
accept(&self) -> io::Result<(VsockStream, SocketAddr)>447 pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
448 let mut svm: sockaddr_vm = Default::default();
449
450 // Safe because this will only modify |svm| and we check the return value.
451 let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
452 let fd = unsafe {
453 libc::accept4(
454 self.sock.as_raw_fd(),
455 &mut svm as *mut sockaddr_vm as *mut sockaddr,
456 &mut socklen as *mut socklen_t,
457 libc::SOCK_CLOEXEC,
458 )
459 };
460 if fd < 0 {
461 return Err(io::Error::last_os_error());
462 }
463
464 if svm.svm_family != AF_VSOCK {
465 return Err(io::Error::new(
466 io::ErrorKind::InvalidData,
467 format!("unexpected address family: {}", svm.svm_family),
468 ));
469 }
470
471 Ok((
472 VsockStream {
473 sock: VsockSocket { fd },
474 },
475 SocketAddr {
476 cid: svm.svm_cid.into(),
477 port: svm.svm_port,
478 },
479 ))
480 }
481
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>482 pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
483 self.sock.set_nonblocking(nonblocking)
484 }
485 }
486
487 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd488 fn as_raw_fd(&self) -> RawFd {
489 self.sock.as_raw_fd()
490 }
491 }
492