• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2019 fsyncd, Berlin, Germany.
3  * Additional material Copyright the Rust project and its contributors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 //! Virtio socket support for Rust.
19 
20 use libc::*;
21 use nix::ioctl_read_bad;
22 use std::ffi::c_void;
23 use std::fs::File;
24 use std::io::{Error, ErrorKind, Read, Result, Write};
25 use std::mem::{self, size_of};
26 use std::net::Shutdown;
27 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
28 use std::time::Duration;
29 
30 pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
31 pub use nix::sys::socket::{SockAddr, VsockAddr};
32 
new_socket() -> libc::c_int33 fn new_socket() -> libc::c_int {
34     unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) }
35 }
36 
37 /// An iterator that infinitely accepts connections on a VsockListener.
38 #[derive(Debug)]
39 pub struct Incoming<'a> {
40     listener: &'a VsockListener,
41 }
42 
43 impl<'a> Iterator for Incoming<'a> {
44     type Item = Result<VsockStream>;
45 
next(&mut self) -> Option<Result<VsockStream>>46     fn next(&mut self) -> Option<Result<VsockStream>> {
47         Some(self.listener.accept().map(|p| p.0))
48     }
49 }
50 
51 /// A virtio socket server, listening for connections.
52 #[derive(Debug, Clone)]
53 pub struct VsockListener {
54     socket: RawFd,
55 }
56 
57 impl VsockListener {
58     /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &SockAddr) -> Result<VsockListener>59     pub fn bind(addr: &SockAddr) -> Result<VsockListener> {
60         let mut vsock_addr = if let SockAddr::Vsock(addr) = addr {
61             addr.0
62         } else {
63             return Err(Error::new(
64                 ErrorKind::Other,
65                 "requires a virtio socket address",
66             ));
67         };
68 
69         let socket = new_socket();
70         if socket < 0 {
71             return Err(Error::last_os_error());
72         }
73 
74         let res = unsafe {
75             bind(
76                 socket,
77                 &mut vsock_addr as *mut _ as *mut sockaddr,
78                 size_of::<sockaddr_vm>() as socklen_t,
79             )
80         };
81         if res < 0 {
82             return Err(Error::last_os_error());
83         }
84 
85         // rust stdlib uses a 128 connection backlog
86         let res = unsafe { listen(socket, 128) };
87         if res < 0 {
88             return Err(Error::last_os_error());
89         }
90 
91         Ok(Self { socket })
92     }
93 
94     /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>95     pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
96         Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port)))
97     }
98 
99     /// The local socket address of the listener.
local_addr(&self) -> Result<SockAddr>100     pub fn local_addr(&self) -> Result<SockAddr> {
101         let mut vsock_addr = sockaddr_vm {
102             svm_family: AF_VSOCK as sa_family_t,
103             svm_reserved1: 0,
104             svm_port: 0,
105             svm_cid: 0,
106             svm_zero: [0u8; 4],
107         };
108         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
109         if unsafe {
110             getsockname(
111                 self.socket,
112                 &mut vsock_addr as *mut _ as *mut sockaddr,
113                 &mut vsock_addr_len,
114             )
115         } < 0
116         {
117             Err(Error::last_os_error())
118         } else {
119             Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
120         }
121     }
122 
123     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>124     pub fn try_clone(&self) -> Result<Self> {
125         Ok(self.clone())
126     }
127 
128     /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, SockAddr)>129     pub fn accept(&self) -> Result<(VsockStream, SockAddr)> {
130         let mut vsock_addr = sockaddr_vm {
131             svm_family: AF_VSOCK as sa_family_t,
132             svm_reserved1: 0,
133             svm_port: 0,
134             svm_cid: 0,
135             svm_zero: [0u8; 4],
136         };
137         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
138         let socket = unsafe {
139             accept4(
140                 self.socket,
141                 &mut vsock_addr as *mut _ as *mut sockaddr,
142                 &mut vsock_addr_len,
143                 SOCK_CLOEXEC,
144             )
145         };
146         if socket < 0 {
147             Err(Error::last_os_error())
148         } else {
149             Ok((
150                 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
151                 SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)),
152             ))
153         }
154     }
155 
156     /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming157     pub fn incoming(&self) -> Incoming {
158         Incoming { listener: self }
159     }
160 
161     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>162     pub fn take_error(&self) -> Result<Option<Error>> {
163         let mut error: i32 = 0;
164         let mut error_len: socklen_t = 0;
165         if unsafe {
166             getsockopt(
167                 self.socket,
168                 SOL_SOCKET,
169                 SO_ERROR,
170                 &mut error as *mut _ as *mut c_void,
171                 &mut error_len,
172             )
173         } < 0
174         {
175             Err(Error::last_os_error())
176         } else {
177             Ok(if error == 0 {
178                 None
179             } else {
180                 Some(Error::from_raw_os_error(error))
181             })
182         }
183     }
184 
185     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>186     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
187         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
188         if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
189             Err(Error::last_os_error())
190         } else {
191             Ok(())
192         }
193     }
194 }
195 
196 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd197     fn as_raw_fd(&self) -> RawFd {
198         self.socket
199     }
200 }
201 
202 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self203     unsafe fn from_raw_fd(socket: RawFd) -> Self {
204         Self { socket }
205     }
206 }
207 
208 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd209     fn into_raw_fd(self) -> RawFd {
210         let fd = self.socket;
211         mem::forget(self);
212         fd
213     }
214 }
215 
216 impl Drop for VsockListener {
drop(&mut self)217     fn drop(&mut self) {
218         unsafe { close(self.socket) };
219     }
220 }
221 
222 /// A virtio stream between a local and a remote socket.
223 #[derive(Debug, Clone)]
224 pub struct VsockStream {
225     socket: RawFd,
226 }
227 
228 impl VsockStream {
229     /// Open a connection to a remote host.
connect(addr: &SockAddr) -> Result<Self>230     pub fn connect(addr: &SockAddr) -> Result<Self> {
231         let vsock_addr = if let SockAddr::Vsock(addr) = addr {
232             addr.0
233         } else {
234             return Err(Error::new(
235                 ErrorKind::Other,
236                 "requires a virtio socket address",
237             ));
238         };
239 
240         let sock = new_socket();
241         if sock < 0 {
242             return Err(Error::last_os_error());
243         }
244         if unsafe {
245             connect(
246                 sock,
247                 &vsock_addr as *const _ as *const sockaddr,
248                 size_of::<sockaddr_vm>() as socklen_t,
249             )
250         } < 0
251         {
252             Err(Error::last_os_error())
253         } else {
254             Ok(unsafe { VsockStream::from_raw_fd(sock) })
255         }
256     }
257 
258     /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>259     pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
260         Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port)))
261     }
262 
263     /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<SockAddr>264     pub fn peer_addr(&self) -> Result<SockAddr> {
265         let mut vsock_addr = sockaddr_vm {
266             svm_family: AF_VSOCK as sa_family_t,
267             svm_reserved1: 0,
268             svm_port: 0,
269             svm_cid: 0,
270             svm_zero: [0u8; 4],
271         };
272         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
273         if unsafe {
274             getpeername(
275                 self.socket,
276                 &mut vsock_addr as *mut _ as *mut sockaddr,
277                 &mut vsock_addr_len,
278             )
279         } < 0
280         {
281             Err(Error::last_os_error())
282         } else {
283             Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
284         }
285     }
286 
287     /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<SockAddr>288     pub fn local_addr(&self) -> Result<SockAddr> {
289         let mut vsock_addr = sockaddr_vm {
290             svm_family: AF_VSOCK as sa_family_t,
291             svm_reserved1: 0,
292             svm_port: 0,
293             svm_cid: 0,
294             svm_zero: [0u8; 4],
295         };
296         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
297         if unsafe {
298             getsockname(
299                 self.socket,
300                 &mut vsock_addr as *mut _ as *mut sockaddr,
301                 &mut vsock_addr_len,
302             )
303         } < 0
304         {
305             Err(Error::last_os_error())
306         } else {
307             Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
308         }
309     }
310 
311     /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>312     pub fn shutdown(&self, how: Shutdown) -> Result<()> {
313         let how = match how {
314             Shutdown::Write => SHUT_WR,
315             Shutdown::Read => SHUT_RD,
316             Shutdown::Both => SHUT_RDWR,
317         };
318         if unsafe { shutdown(self.socket, how) } < 0 {
319             Err(Error::last_os_error())
320         } else {
321             Ok(())
322         }
323     }
324 
325     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>326     pub fn try_clone(&self) -> Result<Self> {
327         Ok(self.clone())
328     }
329 
330     /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>331     pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
332         let timeout = Self::timeval_from_duration(dur)?;
333         if unsafe {
334             setsockopt(
335                 self.socket,
336                 SOL_SOCKET,
337                 SO_SNDTIMEO,
338                 &timeout as *const _ as *const c_void,
339                 size_of::<timeval>() as socklen_t,
340             )
341         } < 0
342         {
343             Err(Error::last_os_error())
344         } else {
345             Ok(())
346         }
347     }
348 
349     /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>350     pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
351         let timeout = Self::timeval_from_duration(dur)?;
352         if unsafe {
353             setsockopt(
354                 self.socket,
355                 SOL_SOCKET,
356                 SO_RCVTIMEO,
357                 &timeout as *const _ as *const c_void,
358                 size_of::<timeval>() as socklen_t,
359             )
360         } < 0
361         {
362             Err(Error::last_os_error())
363         } else {
364             Ok(())
365         }
366     }
367 
368     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>369     pub fn take_error(&self) -> Result<Option<Error>> {
370         let mut error: i32 = 0;
371         let mut error_len: socklen_t = 0;
372         if unsafe {
373             getsockopt(
374                 self.socket,
375                 SOL_SOCKET,
376                 SO_ERROR,
377                 &mut error as *mut _ as *mut c_void,
378                 &mut error_len,
379             )
380         } < 0
381         {
382             Err(Error::last_os_error())
383         } else {
384             Ok(if error == 0 {
385                 None
386             } else {
387                 Some(Error::from_raw_os_error(error))
388             })
389         }
390     }
391 
392     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>393     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
394         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
395         if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
396             Err(Error::last_os_error())
397         } else {
398             Ok(())
399         }
400     }
401 
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>402     fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
403         match dur {
404             Some(dur) => {
405                 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
406                     return Err(Error::new(
407                         ErrorKind::InvalidInput,
408                         "cannot set a zero duration timeout",
409                     ));
410                 }
411 
412                 // https://github.com/rust-lang/libc/issues/1848
413                 #[cfg_attr(target_env = "musl", allow(deprecated))]
414                 let secs = if dur.as_secs() > time_t::max_value() as u64 {
415                     time_t::max_value()
416                 } else {
417                     dur.as_secs() as time_t
418                 };
419                 let mut timeout = timeval {
420                     tv_sec: secs,
421                     tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
422                 };
423                 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
424                     timeout.tv_usec = 1;
425                 }
426                 Ok(timeout)
427             }
428             None => Ok(timeval {
429                 tv_sec: 0,
430                 tv_usec: 0,
431             }),
432         }
433     }
434 }
435 
436 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>437     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
438         <&Self>::read(&mut &*self, buf)
439     }
440 }
441 
442 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>443     fn write(&mut self, buf: &[u8]) -> Result<usize> {
444         <&Self>::write(&mut &*self, buf)
445     }
446 
flush(&mut self) -> Result<()>447     fn flush(&mut self) -> Result<()> {
448         Ok(())
449     }
450 }
451 
452 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>453     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
454         let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
455         if ret < 0 {
456             Err(Error::last_os_error())
457         } else {
458             Ok(ret as usize)
459         }
460     }
461 }
462 
463 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>464     fn write(&mut self, buf: &[u8]) -> Result<usize> {
465         let ret = unsafe {
466             send(
467                 self.socket,
468                 buf.as_ptr() as *const c_void,
469                 buf.len(),
470                 MSG_NOSIGNAL,
471             )
472         };
473         if ret < 0 {
474             Err(Error::last_os_error())
475         } else {
476             Ok(ret as usize)
477         }
478     }
479 
flush(&mut self) -> Result<()>480     fn flush(&mut self) -> Result<()> {
481         Ok(())
482     }
483 }
484 
485 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd486     fn as_raw_fd(&self) -> RawFd {
487         self.socket
488     }
489 }
490 
491 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self492     unsafe fn from_raw_fd(socket: RawFd) -> Self {
493         Self { socket }
494     }
495 }
496 
497 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd498     fn into_raw_fd(self) -> RawFd {
499         let fd = self.socket;
500         mem::forget(self);
501         fd
502     }
503 }
504 
505 impl Drop for VsockStream {
drop(&mut self)506     fn drop(&mut self) {
507         unsafe { close(self.socket) };
508     }
509 }
510 
511 const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
512 ioctl_read_bad!(
513     vm_sockets_get_local_cid,
514     IOCTL_VM_SOCKETS_GET_LOCAL_CID,
515     u32
516 );
517 
518 /// Gets the CID of the local machine.
519 ///
520 /// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
521 /// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
get_local_cid() -> Result<u32>522 pub fn get_local_cid() -> Result<u32> {
523     let f = File::open("/dev/vsock")?;
524     let mut cid = 0;
525     // SAFETY: the kernel only modifies the given u32 integer.
526     unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut cid) }?;
527     Ok(cid)
528 }
529