• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
//! Support for virtual sockets (`AF_VSOCK`).
6 use std::fmt;
7 use std::io;
8 use std::mem::{self, size_of};
9 use std::num::ParseIntError;
10 use std::os::raw::{c_uchar, c_uint, c_ushort};
11 use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
12 use std::result;
13 use std::str::FromStr;
14 
15 use libc::{
16     self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
17     VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
18 };
19 
20 // The domain for vsock sockets.
21 const AF_VSOCK: sa_family_t = 40;
22 
// Vsock loopback address. Defined here rather than imported from `libc` with the other
// VMADDR_CID_* constants.
const VMADDR_CID_LOCAL: c_uint = 1;

/// Vsock equivalent of binding on port 0. Binds to a random port.
pub const VMADDR_PORT_ANY: c_uint = c_uint::MAX;
28 
29 // The number of bytes of padding to be added to the sockaddr_vm struct.  Taken directly
30 // from linux/vm_sockets.h.
31 const PADDING: usize = size_of::<sockaddr>()
32     - size_of::<sa_family_t>()
33     - size_of::<c_ushort>()
34     - (2 * size_of::<c_uint>());
35 
36 #[repr(C)]
37 #[derive(Default)]
38 struct sockaddr_vm {
39     svm_family: sa_family_t,
40     svm_reserved1: c_ushort,
41     svm_port: c_uint,
42     svm_cid: c_uint,
43     svm_zero: [c_uchar; PADDING],
44 }
45 
/// Error returned when a value cannot be converted into a vsock `SocketAddr`
/// (for example a string that is not of the form `"vsock:cid:port"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AddrParseError;

impl fmt::Display for AddrParseError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        write!(fmt, "failed to parse vsock address")
    }
}

// Implementing `Error` lets callers box this as `Box<dyn std::error::Error>` and propagate it
// with `?` alongside other error types.
impl std::error::Error for AddrParseError {}
54 
/// A vsock context id (CID): the vsock counterpart of an IP address.
///
/// The reserved CIDs get named variants; any other value is carried in `Cid`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum VsockCid {
    /// Wildcard CID, like `INADDR_ANY`: refers to the current endpoint's own context id.
    Any,
    /// The bare-metal machine acting as the hypervisor.
    Hypervisor,
    /// The loopback CID.
    Local,
    /// The parent machine, which need not be the hypervisor when VMs are nested.
    Host,
    /// Any other, explicitly assigned CID.
    Cid(c_uint),
}
69 
70 impl fmt::Display for VsockCid {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result71     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
72         match &self {
73             VsockCid::Any => write!(fmt, "Any"),
74             VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
75             VsockCid::Local => write!(fmt, "Local"),
76             VsockCid::Host => write!(fmt, "Host"),
77             VsockCid::Cid(c) => write!(fmt, "'{}'", c),
78         }
79     }
80 }
81 
82 impl From<c_uint> for VsockCid {
from(c: c_uint) -> Self83     fn from(c: c_uint) -> Self {
84         match c {
85             VMADDR_CID_ANY => VsockCid::Any,
86             VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
87             VMADDR_CID_LOCAL => VsockCid::Local,
88             VMADDR_CID_HOST => VsockCid::Host,
89             _ => VsockCid::Cid(c),
90         }
91     }
92 }
93 
94 impl FromStr for VsockCid {
95     type Err = ParseIntError;
96 
from_str(s: &str) -> Result<Self, Self::Err>97     fn from_str(s: &str) -> Result<Self, Self::Err> {
98         let c: c_uint = s.parse()?;
99         Ok(c.into())
100     }
101 }
102 
103 impl Into<c_uint> for VsockCid {
into(self) -> c_uint104     fn into(self) -> c_uint {
105         match self {
106             VsockCid::Any => VMADDR_CID_ANY,
107             VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
108             VsockCid::Local => VMADDR_CID_LOCAL,
109             VsockCid::Host => VMADDR_CID_HOST,
110             VsockCid::Cid(c) => c,
111         }
112     }
113 }
114 
115 /// An address associated with a virtual socket.
116 #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
117 pub struct SocketAddr {
118     pub cid: VsockCid,
119     pub port: c_uint,
120 }
121 
122 pub trait ToSocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>123     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
124 }
125 
126 impl ToSocketAddr for SocketAddr {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>127     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
128         Ok(*self)
129     }
130 }
131 
132 impl ToSocketAddr for str {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>133     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
134         self.parse()
135     }
136 }
137 
138 impl ToSocketAddr for (VsockCid, c_uint) {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>139     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
140         let (cid, port) = *self;
141         Ok(SocketAddr { cid, port })
142     }
143 }
144 
145 impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>146     fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
147         (**self).to_socket_addr()
148     }
149 }
150 
151 impl FromStr for SocketAddr {
152     type Err = AddrParseError;
153 
154     /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
155     /// "vsock:cid:port".
from_str(s: &str) -> Result<SocketAddr, AddrParseError>156     fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
157         let components: Vec<&str> = s.split(':').collect();
158         if components.len() != 3 || components[0] != "vsock" {
159             return Err(AddrParseError);
160         }
161 
162         Ok(SocketAddr {
163             cid: components[1].parse().map_err(|_| AddrParseError)?,
164             port: components[2].parse().map_err(|_| AddrParseError)?,
165         })
166     }
167 }
168 
169 impl fmt::Display for SocketAddr {
fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result170     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
171         write!(fmt, "{}:{}", self.cid, self.port)
172     }
173 }
174 
175 /// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
176 /// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()>177 unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
178     let flags = libc::fcntl(fd, F_GETFL, 0);
179     if flags < 0 {
180         return Err(io::Error::last_os_error());
181     }
182 
183     let flags = if nonblocking {
184         flags | O_NONBLOCK
185     } else {
186         flags & !O_NONBLOCK
187     };
188 
189     let ret = libc::fcntl(fd, F_SETFL, flags);
190     if ret < 0 {
191         return Err(io::Error::last_os_error());
192     }
193 
194     Ok(())
195 }
196 
/// A virtual socket.
///
/// Prefer `VsockStream` or `VsockListener`. Use this type directly only when socket options
/// must be changed, or the socket state queried, before `listen` or `connect` is called.
#[derive(Debug)]
pub struct VsockSocket {
    /// Owned socket descriptor; closed when this value is dropped.
    fd: RawFd,
}
206 
207 impl VsockSocket {
new() -> io::Result<Self>208     pub fn new() -> io::Result<Self> {
209         let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
210         if fd < 0 {
211             Err(io::Error::last_os_error())
212         } else {
213             Ok(VsockSocket { fd })
214         }
215     }
216 
bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()>217     pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
218         let sockaddr = addr
219             .to_socket_addr()
220             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
221 
222         // The compiler should optimize this out since these are both compile-time constants.
223         assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
224 
225         let mut svm: sockaddr_vm = Default::default();
226         svm.svm_family = AF_VSOCK;
227         svm.svm_cid = sockaddr.cid.into();
228         svm.svm_port = sockaddr.port;
229 
230         // Safe because this doesn't modify any memory and we check the return value.
231         let ret = unsafe {
232             libc::bind(
233                 self.fd,
234                 &svm as *const sockaddr_vm as *const sockaddr,
235                 size_of::<sockaddr_vm>() as socklen_t,
236             )
237         };
238         if ret < 0 {
239             let bind_err = io::Error::last_os_error();
240             Err(bind_err)
241         } else {
242             Ok(())
243         }
244     }
245 
connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream>246     pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
247         let sockaddr = addr
248             .to_socket_addr()
249             .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
250 
251         let mut svm: sockaddr_vm = Default::default();
252         svm.svm_family = AF_VSOCK;
253         svm.svm_cid = sockaddr.cid.into();
254         svm.svm_port = sockaddr.port;
255 
256         // Safe because this just connects a vsock socket, and the return value is checked.
257         let ret = unsafe {
258             libc::connect(
259                 self.fd,
260                 &svm as *const sockaddr_vm as *const sockaddr,
261                 size_of::<sockaddr_vm>() as socklen_t,
262             )
263         };
264         if ret < 0 {
265             let connect_err = io::Error::last_os_error();
266             Err(connect_err)
267         } else {
268             Ok(VsockStream { sock: self })
269         }
270     }
271 
listen(self) -> io::Result<VsockListener>272     pub fn listen(self) -> io::Result<VsockListener> {
273         // Safe because this doesn't modify any memory and we check the return value.
274         let ret = unsafe { libc::listen(self.fd, 1) };
275         if ret < 0 {
276             let listen_err = io::Error::last_os_error();
277             return Err(listen_err);
278         }
279         Ok(VsockListener { sock: self })
280     }
281 
282     /// Returns the port that this socket is bound to. This can only succeed after bind is called.
local_port(&self) -> io::Result<u32>283     pub fn local_port(&self) -> io::Result<u32> {
284         let mut svm: sockaddr_vm = Default::default();
285 
286         // Safe because we give a valid pointer for addrlen and check the length.
287         let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
288         let ret = unsafe {
289             // Get the socket address that was actually bound.
290             libc::getsockname(
291                 self.fd,
292                 &mut svm as *mut sockaddr_vm as *mut sockaddr,
293                 &mut addrlen as *mut socklen_t,
294             )
295         };
296         if ret < 0 {
297             let getsockname_err = io::Error::last_os_error();
298             Err(getsockname_err)
299         } else {
300             // If this doesn't match, it's not safe to get the port out of the sockaddr.
301             assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
302 
303             Ok(svm.svm_port)
304         }
305     }
306 
try_clone(&self) -> io::Result<Self>307     pub fn try_clone(&self) -> io::Result<Self> {
308         // Safe because this doesn't modify any memory and we check the return value.
309         let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
310         if dup_fd < 0 {
311             Err(io::Error::last_os_error())
312         } else {
313             Ok(Self { fd: dup_fd })
314         }
315     }
316 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>317     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
318         // Safe because the fd is valid and owned by this stream.
319         unsafe { set_nonblocking(self.fd, nonblocking) }
320     }
321 }
322 
323 impl IntoRawFd for VsockSocket {
into_raw_fd(self) -> RawFd324     fn into_raw_fd(self) -> RawFd {
325         let fd = self.fd;
326         mem::forget(self);
327         fd
328     }
329 }
330 
331 impl AsRawFd for VsockSocket {
as_raw_fd(&self) -> RawFd332     fn as_raw_fd(&self) -> RawFd {
333         self.fd
334     }
335 }
336 
337 impl Drop for VsockSocket {
drop(&mut self)338     fn drop(&mut self) {
339         // Safe because this doesn't modify any memory and we are the only
340         // owner of the file descriptor.
341         unsafe { libc::close(self.fd) };
342     }
343 }
344 
345 /// A virtual stream socket.
346 #[derive(Debug)]
347 pub struct VsockStream {
348     sock: VsockSocket,
349 }
350 
351 impl VsockStream {
connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream>352     pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
353         let sock = VsockSocket::new()?;
354         sock.connect(addr)
355     }
356 
357     /// Returns the port that this stream is bound to.
local_port(&self) -> io::Result<u32>358     pub fn local_port(&self) -> io::Result<u32> {
359         self.sock.local_port()
360     }
361 
try_clone(&self) -> io::Result<VsockStream>362     pub fn try_clone(&self) -> io::Result<VsockStream> {
363         self.sock.try_clone().map(|f| VsockStream { sock: f })
364     }
365 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>366     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
367         self.sock.set_nonblocking(nonblocking)
368     }
369 }
370 
371 impl io::Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>372     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
373         // Safe because this will only modify the contents of |buf| and we check the return value.
374         let ret = unsafe {
375             libc::read(
376                 self.sock.as_raw_fd(),
377                 buf as *mut [u8] as *mut c_void,
378                 buf.len() as size_t,
379             )
380         };
381         if ret < 0 {
382             return Err(io::Error::last_os_error());
383         }
384 
385         Ok(ret as usize)
386     }
387 }
388 
389 impl io::Write for VsockStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>390     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
391         // Safe because this doesn't modify any memory and we check the return value.
392         let ret = unsafe {
393             libc::write(
394                 self.sock.as_raw_fd(),
395                 buf as *const [u8] as *const c_void,
396                 buf.len() as size_t,
397             )
398         };
399         if ret < 0 {
400             return Err(io::Error::last_os_error());
401         }
402 
403         Ok(ret as usize)
404     }
405 
flush(&mut self) -> io::Result<()>406     fn flush(&mut self) -> io::Result<()> {
407         // No buffered data so nothing to do.
408         Ok(())
409     }
410 }
411 
412 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd413     fn as_raw_fd(&self) -> RawFd {
414         self.sock.as_raw_fd()
415     }
416 }
417 
418 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd419     fn into_raw_fd(self) -> RawFd {
420         self.sock.into_raw_fd()
421     }
422 }
423 
424 /// Represents a virtual socket server.
425 #[derive(Debug)]
426 pub struct VsockListener {
427     sock: VsockSocket,
428 }
429 
430 impl VsockListener {
431     /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
432     /// endpoint.
bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener>433     pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
434         let mut sock = VsockSocket::new()?;
435         sock.bind(addr)?;
436         sock.listen()
437     }
438 
439     /// Returns the port that this listener is bound to.
local_port(&self) -> io::Result<u32>440     pub fn local_port(&self) -> io::Result<u32> {
441         self.sock.local_port()
442     }
443 
444     /// Accepts a new incoming connection on this listener.  Blocks the calling thread until a
445     /// new connection is established.  When established, returns the corresponding `VsockStream`
446     /// and the remote peer's address.
accept(&self) -> io::Result<(VsockStream, SocketAddr)>447     pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
448         let mut svm: sockaddr_vm = Default::default();
449 
450         // Safe because this will only modify |svm| and we check the return value.
451         let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
452         let fd = unsafe {
453             libc::accept4(
454                 self.sock.as_raw_fd(),
455                 &mut svm as *mut sockaddr_vm as *mut sockaddr,
456                 &mut socklen as *mut socklen_t,
457                 libc::SOCK_CLOEXEC,
458             )
459         };
460         if fd < 0 {
461             return Err(io::Error::last_os_error());
462         }
463 
464         if svm.svm_family != AF_VSOCK {
465             return Err(io::Error::new(
466                 io::ErrorKind::InvalidData,
467                 format!("unexpected address family: {}", svm.svm_family),
468             ));
469         }
470 
471         Ok((
472             VsockStream {
473                 sock: VsockSocket { fd },
474             },
475             SocketAddr {
476                 cid: svm.svm_cid.into(),
477                 port: svm.svm_port,
478             },
479         ))
480     }
481 
set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()>482     pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
483         self.sock.set_nonblocking(nonblocking)
484     }
485 }
486 
487 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd488     fn as_raw_fd(&self) -> RawFd {
489         self.sock.as_raw_fd()
490     }
491 }
492