• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 fsyncd, Berlin, Germany.
3  * Additional material Copyright the Rust project and its contributors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 //! Virtio socket support for Rust.
19 
20 use libc::{
21     accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK,
22     FIONBIO, SOCK_CLOEXEC,
23 };
24 use nix::{
25     ioctl_read_bad,
26     sys::socket::{
27         self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
28         sockopt::{ReceiveTimeout, SendTimeout, SocketError},
29         AddressFamily, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
30     },
31     unistd::close,
32 };
33 use std::fs::File;
34 use std::io::{Error, ErrorKind, Read, Result, Write};
35 use std::mem::{self, size_of};
36 use std::net::Shutdown;
37 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
38 use std::time::Duration;
39 
40 pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
41 pub use nix::sys::socket::{SockaddrLike, VsockAddr};
42 
new_socket() -> Result<RawFd>43 fn new_socket() -> Result<RawFd> {
44     Ok(socket(
45         AddressFamily::Vsock,
46         SockType::Stream,
47         SockFlag::SOCK_CLOEXEC,
48         None,
49     )?)
50 }
51 
/// An iterator that infinitely accepts connections on a VsockListener.
///
/// Values of this type are handed out by `VsockListener::incoming` and borrow
/// the listener for the lifetime `'a`.
#[derive(Debug)]
pub struct Incoming<'a> {
    /// Listener whose `accept` method produces each item of the iterator.
    listener: &'a VsockListener,
}
57 
58 impl<'a> Iterator for Incoming<'a> {
59     type Item = Result<VsockStream>;
60 
next(&mut self) -> Option<Result<VsockStream>>61     fn next(&mut self) -> Option<Result<VsockStream>> {
62         Some(self.listener.accept().map(|p| p.0))
63     }
64 }
65 
/// A virtio socket server, listening for connections.
///
/// The listener owns its socket descriptor and closes it when dropped, so
/// every value must hold a descriptor of its own. `Clone` is therefore
/// implemented by duplicating the descriptor rather than copying the integer:
/// a bitwise copy would make two values close the same descriptor.
#[derive(Debug)]
pub struct VsockListener {
    /// Raw descriptor of the listening socket, owned by this value.
    socket: RawFd,
}

impl Clone for VsockListener {
    /// Duplicate the underlying descriptor so the copy owns an independent handle.
    ///
    /// # Panics
    ///
    /// Panics if the operating system cannot duplicate the descriptor (for
    /// example when the process has run out of file descriptors). Use
    /// `try_clone` to handle that case as an error instead.
    fn clone(&self) -> Self {
        // SAFETY: `self.socket` stays open for as long as `self` is alive, and
        // the `ManuallyDrop` wrapper guarantees this temporary view never
        // closes the descriptor it merely borrows.
        let borrowed = mem::ManuallyDrop::new(unsafe { File::from_raw_fd(self.socket) });
        let duplicate = borrowed
            .try_clone()
            .expect("failed to duplicate vsock listener descriptor");
        Self {
            socket: duplicate.into_raw_fd(),
        }
    }
}
71 
72 impl VsockListener {
73     /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &impl SockaddrLike) -> Result<Self>74     pub fn bind(addr: &impl SockaddrLike) -> Result<Self> {
75         if addr.family() != Some(AddressFamily::Vsock) {
76             return Err(Error::new(
77                 ErrorKind::Other,
78                 "requires a virtio socket address",
79             ));
80         }
81 
82         let socket = new_socket()?;
83 
84         bind(socket, addr)?;
85 
86         // rust stdlib uses a 128 connection backlog
87         listen(socket, 128)?;
88 
89         Ok(Self { socket })
90     }
91 
92     /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>93     pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
94         Self::bind(&VsockAddr::new(cid, port))
95     }
96 
97     /// The local socket address of the listener.
local_addr(&self) -> Result<VsockAddr>98     pub fn local_addr(&self) -> Result<VsockAddr> {
99         Ok(getsockname(self.socket)?)
100     }
101 
102     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>103     pub fn try_clone(&self) -> Result<Self> {
104         Ok(self.clone())
105     }
106 
107     /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, VsockAddr)>108     pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
109         let mut vsock_addr = sockaddr_vm {
110             svm_family: AF_VSOCK as sa_family_t,
111             svm_reserved1: 0,
112             svm_port: 0,
113             svm_cid: 0,
114             svm_zero: [0u8; 4],
115         };
116         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
117         let socket = unsafe {
118             accept4(
119                 self.socket,
120                 &mut vsock_addr as *mut _ as *mut sockaddr,
121                 &mut vsock_addr_len,
122                 SOCK_CLOEXEC,
123             )
124         };
125         if socket < 0 {
126             Err(Error::last_os_error())
127         } else {
128             Ok((
129                 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
130                 VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
131             ))
132         }
133     }
134 
135     /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming136     pub fn incoming(&self) -> Incoming {
137         Incoming { listener: self }
138     }
139 
140     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>141     pub fn take_error(&self) -> Result<Option<Error>> {
142         let error = SocketError.get(self.socket)?;
143         Ok(if error == 0 {
144             None
145         } else {
146             Some(Error::from_raw_os_error(error))
147         })
148     }
149 
150     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>151     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
152         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
153         if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
154             Err(Error::last_os_error())
155         } else {
156             Ok(())
157         }
158     }
159 }
160 
161 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd162     fn as_raw_fd(&self) -> RawFd {
163         self.socket
164     }
165 }
166 
167 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self168     unsafe fn from_raw_fd(socket: RawFd) -> Self {
169         Self { socket }
170     }
171 }
172 
173 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd174     fn into_raw_fd(self) -> RawFd {
175         let fd = self.socket;
176         mem::forget(self);
177         fd
178     }
179 }
180 
181 impl Drop for VsockListener {
drop(&mut self)182     fn drop(&mut self) {
183         let _ = close(self.socket);
184     }
185 }
186 
/// A virtio stream between a local and a remote socket.
///
/// The stream owns its socket descriptor and closes it when dropped, so every
/// value must hold a descriptor of its own. `Clone` is therefore implemented
/// by duplicating the descriptor rather than copying the integer: a bitwise
/// copy would make two values close the same descriptor.
#[derive(Debug)]
pub struct VsockStream {
    /// Raw descriptor of the connected socket, owned by this value.
    socket: RawFd,
}

impl Clone for VsockStream {
    /// Duplicate the underlying descriptor so the copy owns an independent handle.
    ///
    /// # Panics
    ///
    /// Panics if the operating system cannot duplicate the descriptor (for
    /// example when the process has run out of file descriptors). Use
    /// `try_clone` to handle that case as an error instead.
    fn clone(&self) -> Self {
        // SAFETY: `self.socket` stays open for as long as `self` is alive, and
        // the `ManuallyDrop` wrapper guarantees this temporary view never
        // closes the descriptor it merely borrows.
        let borrowed = mem::ManuallyDrop::new(unsafe { File::from_raw_fd(self.socket) });
        let duplicate = borrowed
            .try_clone()
            .expect("failed to duplicate vsock stream descriptor");
        Self {
            socket: duplicate.into_raw_fd(),
        }
    }
}
192 
193 impl VsockStream {
194     /// Open a connection to a remote host.
connect(addr: &impl SockaddrLike) -> Result<Self>195     pub fn connect(addr: &impl SockaddrLike) -> Result<Self> {
196         if addr.family() != Some(AddressFamily::Vsock) {
197             return Err(Error::new(
198                 ErrorKind::Other,
199                 "requires a virtio socket address",
200             ));
201         }
202 
203         let sock = new_socket()?;
204         connect(sock, addr)?;
205         Ok(unsafe { Self::from_raw_fd(sock) })
206     }
207 
208     /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>209     pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
210         Self::connect(&VsockAddr::new(cid, port))
211     }
212 
213     /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<VsockAddr>214     pub fn peer_addr(&self) -> Result<VsockAddr> {
215         Ok(getpeername(self.socket)?)
216     }
217 
218     /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<VsockAddr>219     pub fn local_addr(&self) -> Result<VsockAddr> {
220         Ok(getsockname(self.socket)?)
221     }
222 
223     /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>224     pub fn shutdown(&self, how: Shutdown) -> Result<()> {
225         let how = match how {
226             Shutdown::Write => socket::Shutdown::Write,
227             Shutdown::Read => socket::Shutdown::Read,
228             Shutdown::Both => socket::Shutdown::Both,
229         };
230         Ok(shutdown(self.socket, how)?)
231     }
232 
233     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>234     pub fn try_clone(&self) -> Result<Self> {
235         Ok(self.clone())
236     }
237 
238     /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>239     pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
240         let timeout = Self::timeval_from_duration(dur)?.into();
241         Ok(SendTimeout.set(self.socket, &timeout)?)
242     }
243 
244     /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>245     pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
246         let timeout = Self::timeval_from_duration(dur)?.into();
247         Ok(ReceiveTimeout.set(self.socket, &timeout)?)
248     }
249 
250     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>251     pub fn take_error(&self) -> Result<Option<Error>> {
252         let error = SocketError.get(self.socket)?;
253         Ok(if error == 0 {
254             None
255         } else {
256             Some(Error::from_raw_os_error(error))
257         })
258     }
259 
260     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>261     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
262         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
263         if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
264             Err(Error::last_os_error())
265         } else {
266             Ok(())
267         }
268     }
269 
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>270     fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
271         match dur {
272             Some(dur) => {
273                 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
274                     return Err(Error::new(
275                         ErrorKind::InvalidInput,
276                         "cannot set a zero duration timeout",
277                     ));
278                 }
279 
280                 // https://github.com/rust-lang/libc/issues/1848
281                 #[cfg_attr(target_env = "musl", allow(deprecated))]
282                 let secs = if dur.as_secs() > libc::time_t::max_value() as u64 {
283                     libc::time_t::max_value()
284                 } else {
285                     dur.as_secs() as libc::time_t
286                 };
287                 let mut timeout = timeval {
288                     tv_sec: secs,
289                     tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
290                 };
291                 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
292                     timeout.tv_usec = 1;
293                 }
294                 Ok(timeout)
295             }
296             None => Ok(timeval {
297                 tv_sec: 0,
298                 tv_usec: 0,
299             }),
300         }
301     }
302 }
303 
304 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>305     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
306         <&Self>::read(&mut &*self, buf)
307     }
308 }
309 
310 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>311     fn write(&mut self, buf: &[u8]) -> Result<usize> {
312         <&Self>::write(&mut &*self, buf)
313     }
314 
flush(&mut self) -> Result<()>315     fn flush(&mut self) -> Result<()> {
316         Ok(())
317     }
318 }
319 
320 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>321     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
322         Ok(recv(self.socket, buf, MsgFlags::empty())?)
323     }
324 }
325 
326 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>327     fn write(&mut self, buf: &[u8]) -> Result<usize> {
328         Ok(send(self.socket, buf, MsgFlags::MSG_NOSIGNAL)?)
329     }
330 
flush(&mut self) -> Result<()>331     fn flush(&mut self) -> Result<()> {
332         Ok(())
333     }
334 }
335 
336 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd337     fn as_raw_fd(&self) -> RawFd {
338         self.socket
339     }
340 }
341 
342 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self343     unsafe fn from_raw_fd(socket: RawFd) -> Self {
344         Self { socket }
345     }
346 }
347 
348 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd349     fn into_raw_fd(self) -> RawFd {
350         let fd = self.socket;
351         mem::forget(self);
352         fd
353     }
354 }
355 
356 impl Drop for VsockStream {
drop(&mut self)357     fn drop(&mut self) {
358         let _ = close(self.socket);
359     }
360 }
361 
// Request number of the ioctl that reports the local context ID.
// NOTE(review): 0x7b9 corresponds to `_IO(7, 0xb9)` as declared in the Linux
// `vm_sockets.h` UAPI header — confirm against the targeted kernel headers.
const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
// Generates `unsafe fn vm_sockets_get_local_cid(fd, *mut u32)`, which issues
// the request above and stores the result through the pointer.
ioctl_read_bad!(
    vm_sockets_get_local_cid,
    IOCTL_VM_SOCKETS_GET_LOCAL_CID,
    u32
);
368 
369 /// Gets the CID of the local machine.
370 ///
371 /// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
372 /// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
get_local_cid() -> Result<u32>373 pub fn get_local_cid() -> Result<u32> {
374     let f = File::open("/dev/vsock")?;
375     let mut cid = 0;
376     // SAFETY: the kernel only modifies the given u32 integer.
377     unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut cid) }?;
378     Ok(cid)
379 }
380