1 /*
2 * Copyright 2019 fsyncd, Berlin, Germany.
 * Additional material Copyright the Rust project and its contributors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 //! Virtio socket support for Rust.
19
20 use libc::*;
21 use nix::ioctl_read_bad;
22 use std::ffi::c_void;
23 use std::fs::File;
24 use std::io::{Error, ErrorKind, Read, Result, Write};
25 use std::mem::{self, size_of};
26 use std::net::Shutdown;
27 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
28 use std::time::Duration;
29
30 pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
31 pub use nix::sys::socket::{SockAddr, VsockAddr};
32
new_socket() -> libc::c_int33 fn new_socket() -> libc::c_int {
34 unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) }
35 }
36
37 /// An iterator that infinitely accepts connections on a VsockListener.
38 #[derive(Debug)]
39 pub struct Incoming<'a> {
40 listener: &'a VsockListener,
41 }
42
43 impl<'a> Iterator for Incoming<'a> {
44 type Item = Result<VsockStream>;
45
next(&mut self) -> Option<Result<VsockStream>>46 fn next(&mut self) -> Option<Result<VsockStream>> {
47 Some(self.listener.accept().map(|p| p.0))
48 }
49 }
50
51 /// A virtio socket server, listening for connections.
52 #[derive(Debug, Clone)]
53 pub struct VsockListener {
54 socket: RawFd,
55 }
56
57 impl VsockListener {
58 /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &SockAddr) -> Result<VsockListener>59 pub fn bind(addr: &SockAddr) -> Result<VsockListener> {
60 let mut vsock_addr = if let SockAddr::Vsock(addr) = addr {
61 addr.0
62 } else {
63 return Err(Error::new(
64 ErrorKind::Other,
65 "requires a virtio socket address",
66 ));
67 };
68
69 let socket = new_socket();
70 if socket < 0 {
71 return Err(Error::last_os_error());
72 }
73
74 let res = unsafe {
75 bind(
76 socket,
77 &mut vsock_addr as *mut _ as *mut sockaddr,
78 size_of::<sockaddr_vm>() as socklen_t,
79 )
80 };
81 if res < 0 {
82 return Err(Error::last_os_error());
83 }
84
85 // rust stdlib uses a 128 connection backlog
86 let res = unsafe { listen(socket, 128) };
87 if res < 0 {
88 return Err(Error::last_os_error());
89 }
90
91 Ok(Self { socket })
92 }
93
94 /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>95 pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
96 Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port)))
97 }
98
99 /// The local socket address of the listener.
local_addr(&self) -> Result<SockAddr>100 pub fn local_addr(&self) -> Result<SockAddr> {
101 let mut vsock_addr = sockaddr_vm {
102 svm_family: AF_VSOCK as sa_family_t,
103 svm_reserved1: 0,
104 svm_port: 0,
105 svm_cid: 0,
106 svm_zero: [0u8; 4],
107 };
108 let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
109 if unsafe {
110 getsockname(
111 self.socket,
112 &mut vsock_addr as *mut _ as *mut sockaddr,
113 &mut vsock_addr_len,
114 )
115 } < 0
116 {
117 Err(Error::last_os_error())
118 } else {
119 Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
120 }
121 }
122
123 /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>124 pub fn try_clone(&self) -> Result<Self> {
125 Ok(self.clone())
126 }
127
128 /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, SockAddr)>129 pub fn accept(&self) -> Result<(VsockStream, SockAddr)> {
130 let mut vsock_addr = sockaddr_vm {
131 svm_family: AF_VSOCK as sa_family_t,
132 svm_reserved1: 0,
133 svm_port: 0,
134 svm_cid: 0,
135 svm_zero: [0u8; 4],
136 };
137 let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
138 let socket = unsafe {
139 accept4(
140 self.socket,
141 &mut vsock_addr as *mut _ as *mut sockaddr,
142 &mut vsock_addr_len,
143 SOCK_CLOEXEC,
144 )
145 };
146 if socket < 0 {
147 Err(Error::last_os_error())
148 } else {
149 Ok((
150 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
151 SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)),
152 ))
153 }
154 }
155
156 /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming157 pub fn incoming(&self) -> Incoming {
158 Incoming { listener: self }
159 }
160
161 /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>162 pub fn take_error(&self) -> Result<Option<Error>> {
163 let mut error: i32 = 0;
164 let mut error_len: socklen_t = 0;
165 if unsafe {
166 getsockopt(
167 self.socket,
168 SOL_SOCKET,
169 SO_ERROR,
170 &mut error as *mut _ as *mut c_void,
171 &mut error_len,
172 )
173 } < 0
174 {
175 Err(Error::last_os_error())
176 } else {
177 Ok(if error == 0 {
178 None
179 } else {
180 Some(Error::from_raw_os_error(error))
181 })
182 }
183 }
184
185 /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>186 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
187 let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
188 if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
189 Err(Error::last_os_error())
190 } else {
191 Ok(())
192 }
193 }
194 }
195
196 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd197 fn as_raw_fd(&self) -> RawFd {
198 self.socket
199 }
200 }
201
202 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self203 unsafe fn from_raw_fd(socket: RawFd) -> Self {
204 Self { socket }
205 }
206 }
207
208 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd209 fn into_raw_fd(self) -> RawFd {
210 let fd = self.socket;
211 mem::forget(self);
212 fd
213 }
214 }
215
216 impl Drop for VsockListener {
drop(&mut self)217 fn drop(&mut self) {
218 unsafe { close(self.socket) };
219 }
220 }
221
222 /// A virtio stream between a local and a remote socket.
223 #[derive(Debug, Clone)]
224 pub struct VsockStream {
225 socket: RawFd,
226 }
227
228 impl VsockStream {
229 /// Open a connection to a remote host.
connect(addr: &SockAddr) -> Result<Self>230 pub fn connect(addr: &SockAddr) -> Result<Self> {
231 let vsock_addr = if let SockAddr::Vsock(addr) = addr {
232 addr.0
233 } else {
234 return Err(Error::new(
235 ErrorKind::Other,
236 "requires a virtio socket address",
237 ));
238 };
239
240 let sock = new_socket();
241 if sock < 0 {
242 return Err(Error::last_os_error());
243 }
244 if unsafe {
245 connect(
246 sock,
247 &vsock_addr as *const _ as *const sockaddr,
248 size_of::<sockaddr_vm>() as socklen_t,
249 )
250 } < 0
251 {
252 Err(Error::last_os_error())
253 } else {
254 Ok(unsafe { VsockStream::from_raw_fd(sock) })
255 }
256 }
257
258 /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>259 pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
260 Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port)))
261 }
262
263 /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<SockAddr>264 pub fn peer_addr(&self) -> Result<SockAddr> {
265 let mut vsock_addr = sockaddr_vm {
266 svm_family: AF_VSOCK as sa_family_t,
267 svm_reserved1: 0,
268 svm_port: 0,
269 svm_cid: 0,
270 svm_zero: [0u8; 4],
271 };
272 let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
273 if unsafe {
274 getpeername(
275 self.socket,
276 &mut vsock_addr as *mut _ as *mut sockaddr,
277 &mut vsock_addr_len,
278 )
279 } < 0
280 {
281 Err(Error::last_os_error())
282 } else {
283 Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
284 }
285 }
286
287 /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<SockAddr>288 pub fn local_addr(&self) -> Result<SockAddr> {
289 let mut vsock_addr = sockaddr_vm {
290 svm_family: AF_VSOCK as sa_family_t,
291 svm_reserved1: 0,
292 svm_port: 0,
293 svm_cid: 0,
294 svm_zero: [0u8; 4],
295 };
296 let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
297 if unsafe {
298 getsockname(
299 self.socket,
300 &mut vsock_addr as *mut _ as *mut sockaddr,
301 &mut vsock_addr_len,
302 )
303 } < 0
304 {
305 Err(Error::last_os_error())
306 } else {
307 Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
308 }
309 }
310
311 /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>312 pub fn shutdown(&self, how: Shutdown) -> Result<()> {
313 let how = match how {
314 Shutdown::Write => SHUT_WR,
315 Shutdown::Read => SHUT_RD,
316 Shutdown::Both => SHUT_RDWR,
317 };
318 if unsafe { shutdown(self.socket, how) } < 0 {
319 Err(Error::last_os_error())
320 } else {
321 Ok(())
322 }
323 }
324
325 /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>326 pub fn try_clone(&self) -> Result<Self> {
327 Ok(self.clone())
328 }
329
330 /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>331 pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
332 let timeout = Self::timeval_from_duration(dur)?;
333 if unsafe {
334 setsockopt(
335 self.socket,
336 SOL_SOCKET,
337 SO_SNDTIMEO,
338 &timeout as *const _ as *const c_void,
339 size_of::<timeval>() as socklen_t,
340 )
341 } < 0
342 {
343 Err(Error::last_os_error())
344 } else {
345 Ok(())
346 }
347 }
348
349 /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>350 pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
351 let timeout = Self::timeval_from_duration(dur)?;
352 if unsafe {
353 setsockopt(
354 self.socket,
355 SOL_SOCKET,
356 SO_RCVTIMEO,
357 &timeout as *const _ as *const c_void,
358 size_of::<timeval>() as socklen_t,
359 )
360 } < 0
361 {
362 Err(Error::last_os_error())
363 } else {
364 Ok(())
365 }
366 }
367
368 /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>369 pub fn take_error(&self) -> Result<Option<Error>> {
370 let mut error: i32 = 0;
371 let mut error_len: socklen_t = 0;
372 if unsafe {
373 getsockopt(
374 self.socket,
375 SOL_SOCKET,
376 SO_ERROR,
377 &mut error as *mut _ as *mut c_void,
378 &mut error_len,
379 )
380 } < 0
381 {
382 Err(Error::last_os_error())
383 } else {
384 Ok(if error == 0 {
385 None
386 } else {
387 Some(Error::from_raw_os_error(error))
388 })
389 }
390 }
391
392 /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>393 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
394 let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
395 if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
396 Err(Error::last_os_error())
397 } else {
398 Ok(())
399 }
400 }
401
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>402 fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
403 match dur {
404 Some(dur) => {
405 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
406 return Err(Error::new(
407 ErrorKind::InvalidInput,
408 "cannot set a zero duration timeout",
409 ));
410 }
411
412 // https://github.com/rust-lang/libc/issues/1848
413 #[cfg_attr(target_env = "musl", allow(deprecated))]
414 let secs = if dur.as_secs() > time_t::max_value() as u64 {
415 time_t::max_value()
416 } else {
417 dur.as_secs() as time_t
418 };
419 let mut timeout = timeval {
420 tv_sec: secs,
421 tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
422 };
423 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
424 timeout.tv_usec = 1;
425 }
426 Ok(timeout)
427 }
428 None => Ok(timeval {
429 tv_sec: 0,
430 tv_usec: 0,
431 }),
432 }
433 }
434 }
435
436 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>437 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
438 <&Self>::read(&mut &*self, buf)
439 }
440 }
441
442 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>443 fn write(&mut self, buf: &[u8]) -> Result<usize> {
444 <&Self>::write(&mut &*self, buf)
445 }
446
flush(&mut self) -> Result<()>447 fn flush(&mut self) -> Result<()> {
448 Ok(())
449 }
450 }
451
452 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>453 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
454 let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
455 if ret < 0 {
456 Err(Error::last_os_error())
457 } else {
458 Ok(ret as usize)
459 }
460 }
461 }
462
463 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>464 fn write(&mut self, buf: &[u8]) -> Result<usize> {
465 let ret = unsafe {
466 send(
467 self.socket,
468 buf.as_ptr() as *const c_void,
469 buf.len(),
470 MSG_NOSIGNAL,
471 )
472 };
473 if ret < 0 {
474 Err(Error::last_os_error())
475 } else {
476 Ok(ret as usize)
477 }
478 }
479
flush(&mut self) -> Result<()>480 fn flush(&mut self) -> Result<()> {
481 Ok(())
482 }
483 }
484
485 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd486 fn as_raw_fd(&self) -> RawFd {
487 self.socket
488 }
489 }
490
491 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self492 unsafe fn from_raw_fd(socket: RawFd) -> Self {
493 Self { socket }
494 }
495 }
496
497 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd498 fn into_raw_fd(self) -> RawFd {
499 let fd = self.socket;
500 mem::forget(self);
501 fd
502 }
503 }
504
505 impl Drop for VsockStream {
drop(&mut self)506 fn drop(&mut self) {
507 unsafe { close(self.socket) };
508 }
509 }
510
511 const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
512 ioctl_read_bad!(
513 vm_sockets_get_local_cid,
514 IOCTL_VM_SOCKETS_GET_LOCAL_CID,
515 u32
516 );
517
518 /// Gets the CID of the local machine.
519 ///
520 /// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
521 /// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
get_local_cid() -> Result<u32>522 pub fn get_local_cid() -> Result<u32> {
523 let f = File::open("/dev/vsock")?;
524 let mut cid = 0;
525 // SAFETY: the kernel only modifies the given u32 integer.
526 unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut cid) }?;
527 Ok(cid)
528 }
529