1 /*
2 * Copyright 2019 fsyncd, Berlin, Germany.
3 * Additional material Copyright the Rust project and it's contributors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 //! Virtio socket support for Rust.
19
20 use libc::{
21 accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK,
22 FIONBIO, SOCK_CLOEXEC,
23 };
24 use nix::{
25 ioctl_read_bad,
26 sys::socket::{
27 self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
28 sockopt::{ReceiveTimeout, SendTimeout, SocketError},
29 AddressFamily, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
30 },
31 unistd::close,
32 };
33 use std::fs::File;
34 use std::io::{Error, ErrorKind, Read, Result, Write};
35 use std::mem::{self, size_of};
36 use std::net::Shutdown;
37 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
38 use std::time::Duration;
39
40 pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
41 pub use nix::sys::socket::{SockaddrLike, VsockAddr};
42
new_socket() -> Result<RawFd>43 fn new_socket() -> Result<RawFd> {
44 Ok(socket(
45 AddressFamily::Vsock,
46 SockType::Stream,
47 SockFlag::SOCK_CLOEXEC,
48 None,
49 )?)
50 }
51
52 /// An iterator that infinitely accepts connections on a VsockListener.
53 #[derive(Debug)]
54 pub struct Incoming<'a> {
55 listener: &'a VsockListener,
56 }
57
58 impl<'a> Iterator for Incoming<'a> {
59 type Item = Result<VsockStream>;
60
next(&mut self) -> Option<Result<VsockStream>>61 fn next(&mut self) -> Option<Result<VsockStream>> {
62 Some(self.listener.accept().map(|p| p.0))
63 }
64 }
65
66 /// A virtio socket server, listening for connections.
67 #[derive(Debug, Clone)]
68 pub struct VsockListener {
69 socket: RawFd,
70 }
71
72 impl VsockListener {
73 /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &impl SockaddrLike) -> Result<Self>74 pub fn bind(addr: &impl SockaddrLike) -> Result<Self> {
75 if addr.family() != Some(AddressFamily::Vsock) {
76 return Err(Error::new(
77 ErrorKind::Other,
78 "requires a virtio socket address",
79 ));
80 }
81
82 let socket = new_socket()?;
83
84 bind(socket, addr)?;
85
86 // rust stdlib uses a 128 connection backlog
87 listen(socket, 128)?;
88
89 Ok(Self { socket })
90 }
91
92 /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>93 pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
94 Self::bind(&VsockAddr::new(cid, port))
95 }
96
97 /// The local socket address of the listener.
local_addr(&self) -> Result<VsockAddr>98 pub fn local_addr(&self) -> Result<VsockAddr> {
99 Ok(getsockname(self.socket)?)
100 }
101
102 /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>103 pub fn try_clone(&self) -> Result<Self> {
104 Ok(self.clone())
105 }
106
107 /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, VsockAddr)>108 pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
109 let mut vsock_addr = sockaddr_vm {
110 svm_family: AF_VSOCK as sa_family_t,
111 svm_reserved1: 0,
112 svm_port: 0,
113 svm_cid: 0,
114 svm_zero: [0u8; 4],
115 };
116 let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
117 let socket = unsafe {
118 accept4(
119 self.socket,
120 &mut vsock_addr as *mut _ as *mut sockaddr,
121 &mut vsock_addr_len,
122 SOCK_CLOEXEC,
123 )
124 };
125 if socket < 0 {
126 Err(Error::last_os_error())
127 } else {
128 Ok((
129 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
130 VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
131 ))
132 }
133 }
134
135 /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming136 pub fn incoming(&self) -> Incoming {
137 Incoming { listener: self }
138 }
139
140 /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>141 pub fn take_error(&self) -> Result<Option<Error>> {
142 let error = SocketError.get(self.socket)?;
143 Ok(if error == 0 {
144 None
145 } else {
146 Some(Error::from_raw_os_error(error))
147 })
148 }
149
150 /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>151 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
152 let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
153 if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
154 Err(Error::last_os_error())
155 } else {
156 Ok(())
157 }
158 }
159 }
160
161 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd162 fn as_raw_fd(&self) -> RawFd {
163 self.socket
164 }
165 }
166
167 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self168 unsafe fn from_raw_fd(socket: RawFd) -> Self {
169 Self { socket }
170 }
171 }
172
173 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd174 fn into_raw_fd(self) -> RawFd {
175 let fd = self.socket;
176 mem::forget(self);
177 fd
178 }
179 }
180
181 impl Drop for VsockListener {
drop(&mut self)182 fn drop(&mut self) {
183 let _ = close(self.socket);
184 }
185 }
186
187 /// A virtio stream between a local and a remote socket.
188 #[derive(Debug, Clone)]
189 pub struct VsockStream {
190 socket: RawFd,
191 }
192
193 impl VsockStream {
194 /// Open a connection to a remote host.
connect(addr: &impl SockaddrLike) -> Result<Self>195 pub fn connect(addr: &impl SockaddrLike) -> Result<Self> {
196 if addr.family() != Some(AddressFamily::Vsock) {
197 return Err(Error::new(
198 ErrorKind::Other,
199 "requires a virtio socket address",
200 ));
201 }
202
203 let sock = new_socket()?;
204 connect(sock, addr)?;
205 Ok(unsafe { Self::from_raw_fd(sock) })
206 }
207
208 /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>209 pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
210 Self::connect(&VsockAddr::new(cid, port))
211 }
212
213 /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<VsockAddr>214 pub fn peer_addr(&self) -> Result<VsockAddr> {
215 Ok(getpeername(self.socket)?)
216 }
217
218 /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<VsockAddr>219 pub fn local_addr(&self) -> Result<VsockAddr> {
220 Ok(getsockname(self.socket)?)
221 }
222
223 /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>224 pub fn shutdown(&self, how: Shutdown) -> Result<()> {
225 let how = match how {
226 Shutdown::Write => socket::Shutdown::Write,
227 Shutdown::Read => socket::Shutdown::Read,
228 Shutdown::Both => socket::Shutdown::Both,
229 };
230 Ok(shutdown(self.socket, how)?)
231 }
232
233 /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>234 pub fn try_clone(&self) -> Result<Self> {
235 Ok(self.clone())
236 }
237
238 /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>239 pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
240 let timeout = Self::timeval_from_duration(dur)?.into();
241 Ok(SendTimeout.set(self.socket, &timeout)?)
242 }
243
244 /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>245 pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
246 let timeout = Self::timeval_from_duration(dur)?.into();
247 Ok(ReceiveTimeout.set(self.socket, &timeout)?)
248 }
249
250 /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>251 pub fn take_error(&self) -> Result<Option<Error>> {
252 let error = SocketError.get(self.socket)?;
253 Ok(if error == 0 {
254 None
255 } else {
256 Some(Error::from_raw_os_error(error))
257 })
258 }
259
260 /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>261 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
262 let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
263 if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
264 Err(Error::last_os_error())
265 } else {
266 Ok(())
267 }
268 }
269
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>270 fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
271 match dur {
272 Some(dur) => {
273 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
274 return Err(Error::new(
275 ErrorKind::InvalidInput,
276 "cannot set a zero duration timeout",
277 ));
278 }
279
280 // https://github.com/rust-lang/libc/issues/1848
281 #[cfg_attr(target_env = "musl", allow(deprecated))]
282 let secs = if dur.as_secs() > libc::time_t::max_value() as u64 {
283 libc::time_t::max_value()
284 } else {
285 dur.as_secs() as libc::time_t
286 };
287 let mut timeout = timeval {
288 tv_sec: secs,
289 tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
290 };
291 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
292 timeout.tv_usec = 1;
293 }
294 Ok(timeout)
295 }
296 None => Ok(timeval {
297 tv_sec: 0,
298 tv_usec: 0,
299 }),
300 }
301 }
302 }
303
304 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>305 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
306 <&Self>::read(&mut &*self, buf)
307 }
308 }
309
310 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>311 fn write(&mut self, buf: &[u8]) -> Result<usize> {
312 <&Self>::write(&mut &*self, buf)
313 }
314
flush(&mut self) -> Result<()>315 fn flush(&mut self) -> Result<()> {
316 Ok(())
317 }
318 }
319
320 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>321 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
322 Ok(recv(self.socket, buf, MsgFlags::empty())?)
323 }
324 }
325
326 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>327 fn write(&mut self, buf: &[u8]) -> Result<usize> {
328 Ok(send(self.socket, buf, MsgFlags::MSG_NOSIGNAL)?)
329 }
330
flush(&mut self) -> Result<()>331 fn flush(&mut self) -> Result<()> {
332 Ok(())
333 }
334 }
335
336 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd337 fn as_raw_fd(&self) -> RawFd {
338 self.socket
339 }
340 }
341
342 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self343 unsafe fn from_raw_fd(socket: RawFd) -> Self {
344 Self { socket }
345 }
346 }
347
348 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd349 fn into_raw_fd(self) -> RawFd {
350 let fd = self.socket;
351 mem::forget(self);
352 fd
353 }
354 }
355
356 impl Drop for VsockStream {
drop(&mut self)357 fn drop(&mut self) {
358 let _ = close(self.socket);
359 }
360 }
361
/// Request number for the "get local CID" vsock ioctl, issued against `/dev/vsock`.
/// NOTE(review): expected to equal Linux `IOCTL_VM_SOCKETS_GET_LOCAL_CID`
/// (`_IO(7, 0xb9)` in `linux/vm_sockets.h`) — confirm against the kernel headers.
const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
// Generates `unsafe fn vm_sockets_get_local_cid(fd, *mut u32)`. Per the nix docs, the
// "_bad" variant passes the request number verbatim instead of encoding it with
// `_IOR`; the kernel writes the CID into the pointed-to `u32`.
ioctl_read_bad!(
    vm_sockets_get_local_cid,
    IOCTL_VM_SOCKETS_GET_LOCAL_CID,
    u32
);
368
369 /// Gets the CID of the local machine.
370 ///
371 /// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
372 /// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
get_local_cid() -> Result<u32>373 pub fn get_local_cid() -> Result<u32> {
374 let f = File::open("/dev/vsock")?;
375 let mut cid = 0;
376 // SAFETY: the kernel only modifies the given u32 integer.
377 unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut cid) }?;
378 Ok(cid)
379 }
380