• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright 2019 fsyncd, Berlin, Germany.
 * Additional material Copyright the Rust project and its contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 
//! Virtio socket (`AF_VSOCK`) support for Rust.
19 
20 use std::io::{Error, ErrorKind, Read, Result, Write};
21 use std::mem::{self, size_of};
22 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
23 
24 use libc::*;
25 use std::ffi::c_void;
26 use std::net::Shutdown;
27 use std::time::Duration;
28 
29 pub use nix::sys::socket::{SockAddr, VsockAddr};
30 
new_socket() -> libc::c_int31 fn new_socket() -> libc::c_int {
32     unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) }
33 }
34 
35 /// An iterator that infinitely accepts connections on a VsockListener.
36 #[derive(Debug)]
37 pub struct Incoming<'a> {
38     listener: &'a VsockListener,
39 }
40 
41 impl<'a> Iterator for Incoming<'a> {
42     type Item = Result<VsockStream>;
43 
next(&mut self) -> Option<Result<VsockStream>>44     fn next(&mut self) -> Option<Result<VsockStream>> {
45         Some(self.listener.accept().map(|p| p.0))
46     }
47 }
48 
49 /// A virtio socket server, listening for connections.
50 #[derive(Debug, Clone)]
51 pub struct VsockListener {
52     socket: RawFd,
53 }
54 
55 impl VsockListener {
56     /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &SockAddr) -> Result<VsockListener>57     pub fn bind(addr: &SockAddr) -> Result<VsockListener> {
58         let mut vsock_addr = if let SockAddr::Vsock(addr) = addr {
59             addr.0
60         } else {
61             return Err(Error::new(
62                 ErrorKind::Other,
63                 "requires a virtio socket address",
64             ));
65         };
66 
67         let socket = new_socket();
68         if socket < 0 {
69             return Err(Error::last_os_error());
70         }
71 
72         let res = unsafe {
73             bind(
74                 socket,
75                 &mut vsock_addr as *mut _ as *mut sockaddr,
76                 size_of::<sockaddr_vm>() as socklen_t,
77             )
78         };
79         if res < 0 {
80             return Err(Error::last_os_error());
81         }
82 
83         // rust stdlib uses a 128 connection backlog
84         let res = unsafe { listen(socket, 128) };
85         if res < 0 {
86             return Err(Error::last_os_error());
87         }
88 
89         Ok(Self { socket })
90     }
91 
92     /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>93     pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
94         Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port)))
95     }
96 
97     /// The local socket address of the listener.
local_addr(&self) -> Result<SockAddr>98     pub fn local_addr(&self) -> Result<SockAddr> {
99         let mut vsock_addr = sockaddr_vm {
100             svm_family: AF_VSOCK as sa_family_t,
101             svm_reserved1: 0,
102             svm_port: 0,
103             svm_cid: 0,
104             svm_zero: [0u8; 4],
105         };
106         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
107         if unsafe {
108             getsockname(
109                 self.socket,
110                 &mut vsock_addr as *mut _ as *mut sockaddr,
111                 &mut vsock_addr_len,
112             )
113         } < 0
114         {
115             Err(Error::last_os_error())
116         } else {
117             Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
118         }
119     }
120 
121     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>122     pub fn try_clone(&self) -> Result<Self> {
123         Ok(self.clone())
124     }
125 
126     /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, SockAddr)>127     pub fn accept(&self) -> Result<(VsockStream, SockAddr)> {
128         let mut vsock_addr = sockaddr_vm {
129             svm_family: AF_VSOCK as sa_family_t,
130             svm_reserved1: 0,
131             svm_port: 0,
132             svm_cid: 0,
133             svm_zero: [0u8; 4],
134         };
135         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
136         let socket = unsafe {
137             accept(
138                 self.socket,
139                 &mut vsock_addr as *mut _ as *mut sockaddr,
140                 &mut vsock_addr_len,
141             )
142         };
143         if socket < 0 {
144             Err(Error::last_os_error())
145         } else {
146             Ok((
147                 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
148                 SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)),
149             ))
150         }
151     }
152 
153     /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming154     pub fn incoming(&self) -> Incoming {
155         Incoming { listener: self }
156     }
157 
158     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>159     pub fn take_error(&self) -> Result<Option<Error>> {
160         let mut error: i32 = 0;
161         let mut error_len: socklen_t = 0;
162         if unsafe {
163             getsockopt(
164                 self.socket,
165                 SOL_SOCKET,
166                 SO_ERROR,
167                 &mut error as *mut _ as *mut c_void,
168                 &mut error_len,
169             )
170         } < 0
171         {
172             Err(Error::last_os_error())
173         } else {
174             Ok(if error == 0 {
175                 None
176             } else {
177                 Some(Error::from_raw_os_error(error))
178             })
179         }
180     }
181 
182     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>183     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
184         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
185         if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
186             Err(Error::last_os_error())
187         } else {
188             Ok(())
189         }
190     }
191 }
192 
193 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd194     fn as_raw_fd(&self) -> RawFd {
195         self.socket
196     }
197 }
198 
199 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self200     unsafe fn from_raw_fd(socket: RawFd) -> Self {
201         Self { socket }
202     }
203 }
204 
205 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd206     fn into_raw_fd(self) -> RawFd {
207         let fd = self.socket;
208         mem::forget(self);
209         fd
210     }
211 }
212 
213 impl Drop for VsockListener {
drop(&mut self)214     fn drop(&mut self) {
215         unsafe { close(self.socket) };
216     }
217 }
218 
219 /// A virtio stream between a local and a remote socket.
220 #[derive(Debug, Clone)]
221 pub struct VsockStream {
222     socket: RawFd,
223 }
224 
225 impl VsockStream {
226     /// Open a connection to a remote host.
connect(addr: &SockAddr) -> Result<Self>227     pub fn connect(addr: &SockAddr) -> Result<Self> {
228         let vsock_addr = if let SockAddr::Vsock(addr) = addr {
229             addr.0
230         } else {
231             return Err(Error::new(
232                 ErrorKind::Other,
233                 "requires a virtio socket address",
234             ));
235         };
236 
237         let sock = new_socket();
238         if sock < 0 {
239             return Err(Error::last_os_error());
240         }
241         if unsafe {
242             connect(
243                 sock,
244                 &vsock_addr as *const _ as *const sockaddr,
245                 size_of::<sockaddr_vm>() as socklen_t,
246             )
247         } < 0
248         {
249             Err(Error::last_os_error())
250         } else {
251             Ok(unsafe { VsockStream::from_raw_fd(sock) })
252         }
253     }
254 
255     /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>256     pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
257         Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port)))
258     }
259 
260     /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<SockAddr>261     pub fn peer_addr(&self) -> Result<SockAddr> {
262         let mut vsock_addr = sockaddr_vm {
263             svm_family: AF_VSOCK as sa_family_t,
264             svm_reserved1: 0,
265             svm_port: 0,
266             svm_cid: 0,
267             svm_zero: [0u8; 4],
268         };
269         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
270         if unsafe {
271             getpeername(
272                 self.socket,
273                 &mut vsock_addr as *mut _ as *mut sockaddr,
274                 &mut vsock_addr_len,
275             )
276         } < 0
277         {
278             Err(Error::last_os_error())
279         } else {
280             Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
281         }
282     }
283 
284     /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<SockAddr>285     pub fn local_addr(&self) -> Result<SockAddr> {
286         let mut vsock_addr = sockaddr_vm {
287             svm_family: AF_VSOCK as sa_family_t,
288             svm_reserved1: 0,
289             svm_port: 0,
290             svm_cid: 0,
291             svm_zero: [0u8; 4],
292         };
293         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
294         if unsafe {
295             getsockname(
296                 self.socket,
297                 &mut vsock_addr as *mut _ as *mut sockaddr,
298                 &mut vsock_addr_len,
299             )
300         } < 0
301         {
302             Err(Error::last_os_error())
303         } else {
304             Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
305         }
306     }
307 
308     /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>309     pub fn shutdown(&self, how: Shutdown) -> Result<()> {
310         let how = match how {
311             Shutdown::Write => SHUT_WR,
312             Shutdown::Read => SHUT_RD,
313             Shutdown::Both => SHUT_RDWR,
314         };
315         if unsafe { shutdown(self.socket, how) } < 0 {
316             Err(Error::last_os_error())
317         } else {
318             Ok(())
319         }
320     }
321 
322     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>323     pub fn try_clone(&self) -> Result<Self> {
324         Ok(self.clone())
325     }
326 
327     /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>328     pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
329         let timeout = Self::timeval_from_duration(dur)?;
330         if unsafe {
331             setsockopt(
332                 self.socket,
333                 SOL_SOCKET,
334                 SO_SNDTIMEO,
335                 &timeout as *const _ as *const c_void,
336                 size_of::<timeval>() as socklen_t,
337             )
338         } < 0
339         {
340             Err(Error::last_os_error())
341         } else {
342             Ok(())
343         }
344     }
345 
346     /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>347     pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
348         let timeout = Self::timeval_from_duration(dur)?;
349         if unsafe {
350             setsockopt(
351                 self.socket,
352                 SOL_SOCKET,
353                 SO_RCVTIMEO,
354                 &timeout as *const _ as *const c_void,
355                 size_of::<timeval>() as socklen_t,
356             )
357         } < 0
358         {
359             Err(Error::last_os_error())
360         } else {
361             Ok(())
362         }
363     }
364 
365     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>366     pub fn take_error(&self) -> Result<Option<Error>> {
367         let mut error: i32 = 0;
368         let mut error_len: socklen_t = 0;
369         if unsafe {
370             getsockopt(
371                 self.socket,
372                 SOL_SOCKET,
373                 SO_ERROR,
374                 &mut error as *mut _ as *mut c_void,
375                 &mut error_len,
376             )
377         } < 0
378         {
379             Err(Error::last_os_error())
380         } else {
381             Ok(if error == 0 {
382                 None
383             } else {
384                 Some(Error::from_raw_os_error(error))
385             })
386         }
387     }
388 
389     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>390     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
391         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
392         if unsafe { ioctl(self.socket, FIONBIO, &mut nonblocking) } < 0 {
393             Err(Error::last_os_error())
394         } else {
395             Ok(())
396         }
397     }
398 
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>399     fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
400         match dur {
401             Some(dur) => {
402                 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
403                     return Err(Error::new(
404                         ErrorKind::InvalidInput,
405                         "cannot set a zero duration timeout",
406                     ));
407                 }
408 
409                 // https://github.com/rust-lang/libc/issues/1848
410                 #[cfg_attr(target_env = "musl", allow(deprecated))]
411                 let secs = if dur.as_secs() > time_t::max_value() as u64 {
412                     time_t::max_value()
413                 } else {
414                     dur.as_secs() as time_t
415                 };
416                 let mut timeout = timeval {
417                     tv_sec: secs,
418                     tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
419                 };
420                 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
421                     timeout.tv_usec = 1;
422                 }
423                 Ok(timeout)
424             }
425             None => Ok(timeval {
426                 tv_sec: 0,
427                 tv_usec: 0,
428             }),
429         }
430     }
431 }
432 
433 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>434     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
435         <&Self>::read(&mut &*self, buf)
436     }
437 }
438 
439 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>440     fn write(&mut self, buf: &[u8]) -> Result<usize> {
441         <&Self>::write(&mut &*self, buf)
442     }
443 
flush(&mut self) -> Result<()>444     fn flush(&mut self) -> Result<()> {
445         Ok(())
446     }
447 }
448 
449 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>450     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
451         let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
452         if ret < 0 {
453             Err(Error::last_os_error())
454         } else {
455             Ok(ret as usize)
456         }
457     }
458 }
459 
460 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>461     fn write(&mut self, buf: &[u8]) -> Result<usize> {
462         let ret = unsafe {
463             send(
464                 self.socket,
465                 buf.as_ptr() as *const c_void,
466                 buf.len(),
467                 MSG_NOSIGNAL,
468             )
469         };
470         if ret < 0 {
471             Err(Error::last_os_error())
472         } else {
473             Ok(ret as usize)
474         }
475     }
476 
flush(&mut self) -> Result<()>477     fn flush(&mut self) -> Result<()> {
478         Ok(())
479     }
480 }
481 
482 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd483     fn as_raw_fd(&self) -> RawFd {
484         self.socket
485     }
486 }
487 
488 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self489     unsafe fn from_raw_fd(socket: RawFd) -> Self {
490         Self { socket }
491     }
492 }
493 
494 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd495     fn into_raw_fd(self) -> RawFd {
496         let fd = self.socket;
497         mem::forget(self);
498         fd
499     }
500 }
501 
502 impl Drop for VsockStream {
drop(&mut self)503     fn drop(&mut self) {
504         unsafe { close(self.socket) };
505     }
506 }
507