• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2022 The Chromium OS Authors. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Unix-specific code that keeps the rest of the crate platform independent.
5 
6 use std::any::Any;
7 use std::fs::File;
8 use std::io::ErrorKind;
9 use std::io::IoSlice;
10 use std::io::IoSliceMut;
11 use std::os::fd::OwnedFd;
12 use std::os::unix::net::UnixListener;
13 use std::os::unix::net::UnixStream;
14 use std::path::Path;
15 use std::path::PathBuf;
16 
17 use base::AsRawDescriptor;
18 use base::RawDescriptor;
19 use base::SafeDescriptor;
20 use base::ScmSocket;
21 
22 use crate::connection::Listener;
23 use crate::frontend_server::FrontendServer;
24 use crate::message::FrontendReq;
25 use crate::message::MAX_ATTACHED_FD_ENTRIES;
26 use crate::Connection;
27 use crate::Error;
28 use crate::Frontend;
29 use crate::Result;
30 
31 /// Alias to enable platform independent code.
32 pub type SystemListener = UnixListener;
33 
34 /// Alias to enable platform independent code.
35 pub type SystemStream = UnixStream;
36 
37 pub use SocketPlatformConnection as PlatformConnection;
38 
39 /// Unix domain socket listener for accepting incoming connections.
40 pub struct SocketListener {
41     fd: SystemListener,
42     drop_path: Option<Box<dyn Any>>,
43 }
44 
45 impl SocketListener {
46     /// Create a unix domain socket listener.
47     ///
48     /// # Return:
49     /// * - the new SocketListener object on success.
50     /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>51     pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
52         if unlink {
53             let _ = std::fs::remove_file(&path);
54         }
55         let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
56 
57         struct DropPath {
58             path: PathBuf,
59         }
60 
61         impl Drop for DropPath {
62             fn drop(&mut self) {
63                 let _ = std::fs::remove_file(&self.path);
64             }
65         }
66 
67         Ok(SocketListener {
68             fd,
69             drop_path: Some(Box::new(DropPath {
70                 path: path.as_ref().to_owned(),
71             })),
72         })
73     }
74 
75     /// Take and return the resources that the parent process needs to keep alive as long as the
76     /// child process lives, in case of incoming fork.
take_resources_for_parent(&mut self) -> Option<Box<dyn Any>>77     pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
78         self.drop_path.take()
79     }
80 }
81 
82 impl Listener for SocketListener {
83     /// Accept an incoming connection.
84     ///
85     /// # Return:
86     /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
87     /// * - None: no incoming connection available.
88     /// * - SocketError: errors from accept().
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>89     fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>> {
90         loop {
91             match self.fd.accept() {
92                 Ok((stream, _addr)) => {
93                     return Ok(Some(Connection::from(stream)));
94                 }
95                 Err(e) => {
96                     match e.kind() {
97                         // No incoming connection available.
98                         ErrorKind::WouldBlock => return Ok(None),
99                         // New connection closed by peer.
100                         ErrorKind::ConnectionAborted => return Ok(None),
101                         // Interrupted by signals, retry
102                         ErrorKind::Interrupted => continue,
103                         _ => return Err(Error::SocketError(e)),
104                     }
105                 }
106             }
107         }
108     }
109 
110     /// Change blocking status on the listener.
111     ///
112     /// # Return:
113     /// * - () on success.
114     /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>115     fn set_nonblocking(&self, block: bool) -> Result<()> {
116         self.fd.set_nonblocking(block).map_err(Error::SocketError)
117     }
118 }
119 
120 impl AsRawDescriptor for SocketListener {
as_raw_descriptor(&self) -> RawDescriptor121     fn as_raw_descriptor(&self) -> RawDescriptor {
122         self.fd.as_raw_descriptor()
123     }
124 }
125 
126 /// Unix domain socket based vhost-user connection.
127 pub struct SocketPlatformConnection {
128     sock: ScmSocket<SystemStream>,
129 }
130 
131 // TODO: Switch to TryFrom to avoid the unwrap.
132 impl From<SystemStream> for SocketPlatformConnection {
from(sock: SystemStream) -> Self133     fn from(sock: SystemStream) -> Self {
134         Self {
135             sock: sock.try_into().unwrap(),
136         }
137     }
138 }
139 
/// Advances the cursor of a list of byte slices by `count` bytes.
///
/// Slices that are fully consumed are removed from the front of `bufs`. Leading empty slices
/// are also dropped, including when `count` is 0. The first remaining slice is trimmed by the
/// leftover byte count. If `count` exceeds the total length, `bufs` simply ends up empty.
///
/// This mirrors the nightly `IoSlice::advance_slices`, but operates on `&[u8]`.
fn advance_slices(bufs: &mut &mut [&[u8]], mut count: usize) {
    let mut skipped = 0;
    while let Some(first) = bufs.get(skipped) {
        if count < first.len() {
            break;
        }
        count -= first.len();
        skipped += 1;
    }

    *bufs = &mut std::mem::take(bufs)[skipped..];
    if let Some(head) = bufs.first_mut() {
        *head = &head[count..];
    }
}
158 
159 impl SocketPlatformConnection {
160     /// Create a new stream by connecting to server at `str`.
161     ///
162     /// # Return:
163     /// * - the new SocketPlatformConnection object on success.
164     /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>165     pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
166         let sock = SystemStream::connect(path).map_err(Error::SocketConnect)?;
167         Ok(Self::from(sock))
168     }
169 
170     /// Sends all bytes from scatter-gather vectors with optional attached file descriptors. Will
171     /// loop until all data has been transfered.
172     ///
173     /// # TODO
174     /// This function takes a slice of `&[u8]` instead of `IoSlice` because the internal
175     /// cursor needs to be moved by `advance_slices()`.
176     /// Once `IoSlice::advance_slices()` becomes stable, this should be updated.
177     /// <https://github.com/rust-lang/rust/issues/62726>.
send_iovec_all( &self, mut iovs: &mut [&[u8]], mut fds: Option<&[RawDescriptor]>, ) -> Result<()>178     fn send_iovec_all(
179         &self,
180         mut iovs: &mut [&[u8]],
181         mut fds: Option<&[RawDescriptor]>,
182     ) -> Result<()> {
183         // Guarantee that `iovs` becomes empty if it doesn't contain any data.
184         advance_slices(&mut iovs, 0);
185 
186         while !iovs.is_empty() {
187             let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
188             match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
189                 Ok(n) => {
190                     fds = None;
191                     advance_slices(&mut iovs, n);
192                 }
193                 Err(e) => match e.kind() {
194                     ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
195                     _ => return Err(Error::SocketError(e)),
196                 },
197             }
198         }
199         Ok(())
200     }
201 
202     /// Sends a single message over the socket with optional attached file descriptors.
203     ///
204     /// - `hdr`: vhost message header
205     /// - `body`: vhost message body (may be empty to send a header-only message)
206     /// - `payload`: additional bytes to append to `body` (may be empty)
send_message( &self, hdr: &[u8], body: &[u8], payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>207     pub fn send_message(
208         &self,
209         hdr: &[u8],
210         body: &[u8],
211         payload: &[u8],
212         fds: Option<&[RawDescriptor]>,
213     ) -> Result<()> {
214         let mut iobufs = [hdr, body, payload];
215         self.send_iovec_all(&mut iobufs, fds)
216     }
217 
218     /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
219     /// file.
220     ///
221     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
222     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
223     /// sender sending a message with some file descriptors attached. To successfully receive those
224     /// attached file descriptors, the receiver must obey following rules:
225     ///   1) file descriptors are attached to a message.
226     ///   2) message(packet) boundaries must be respected on the receive side.
227     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
228     /// attached file descriptors will get lost.
229     /// Note that this function wraps received file descriptors as `File`.
230     ///
231     /// # Return:
232     /// * - (number of bytes received, [received files]) on success
233     /// * - Disconnect: the connection is closed.
234     /// * - SocketRetry: temporary error caused by signals or short of resources.
235     /// * - SocketBroken: the underline socket is broken.
236     /// * - SocketError: other socket related errors.
recv_into_bufs( &self, bufs: &mut [IoSliceMut], allow_fd: bool, ) -> Result<(usize, Option<Vec<File>>)>237     pub fn recv_into_bufs(
238         &self,
239         bufs: &mut [IoSliceMut],
240         allow_fd: bool,
241     ) -> Result<(usize, Option<Vec<File>>)> {
242         let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
243         let (bytes, fds) = self.sock.recv_vectored_with_fds(bufs, max_fds)?;
244 
245         // 0-bytes indicates that the connection is closed.
246         if bytes == 0 {
247             return Err(Error::Disconnect);
248         }
249 
250         let files = if fds.is_empty() {
251             None
252         } else {
253             Some(fds.into_iter().map(File::from).collect())
254         };
255 
256         Ok((bytes, files))
257     }
258 }
259 
260 impl AsRawDescriptor for SocketPlatformConnection {
as_raw_descriptor(&self) -> RawDescriptor261     fn as_raw_descriptor(&self) -> RawDescriptor {
262         self.sock.as_raw_descriptor()
263     }
264 }
265 
266 impl AsMut<SystemStream> for SocketPlatformConnection {
as_mut(&mut self) -> &mut SystemStream267     fn as_mut(&mut self) -> &mut SystemStream {
268         self.sock.inner_mut()
269     }
270 }
271 
272 /// Convert a `SafeDescriptor` to a `UnixStream`.
273 ///
274 /// # Safety
275 ///
276 /// `file` must represent a unix domain socket.
to_system_stream(fd: SafeDescriptor) -> Result<SystemStream>277 pub unsafe fn to_system_stream(fd: SafeDescriptor) -> Result<SystemStream> {
278     Ok(fd.into())
279 }
280 
281 impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
282     /// Used for polling.
as_raw_descriptor(&self) -> RawDescriptor283     fn as_raw_descriptor(&self) -> RawDescriptor {
284         self.sub_sock.as_raw_descriptor()
285     }
286 }
287 
288 impl<S: Frontend> FrontendServer<S> {
289     /// Create a `FrontendServer` that uses a Unix stream internally.
290     ///
291     /// The returned `SafeDescriptor` is the client side of the stream and should be sent to the
292     /// backend using [BackendClient::set_slave_request_fd()].
293     ///
294     /// [BackendClient::set_slave_request_fd()]: struct.BackendClient.html#method.set_slave_request_fd
with_stream(backend: S) -> Result<(Self, SafeDescriptor)>295     pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
296         let (tx, rx) = SystemStream::pair()?;
297         Ok((
298             Self::new(backend, rx)?,
299             SafeDescriptor::from(OwnedFd::from(tx)),
300         ))
301     }
302 }
303 
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::backend_server::Backend;
    use crate::backend_server::BackendServer;
    use crate::connection::Listener;
    use crate::message::FrontendReq;
    use crate::Connection;

    /// Creates a fresh temporary directory (prefixed `/tmp/vhost_test`) for test sockets.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Returns the path `<dir>/sock` used for the listening socket in these tests.
    fn sock_path(dir: &TempDir) -> PathBuf {
        dir.path().join("sock")
    }

    pub(crate) fn create_pair() -> (BackendClient, Connection<FrontendReq>) {
        let dir = temp_dir();
        let path = sock_path(&dir);
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let backend_client = BackendClient::connect(path).unwrap();
        let server_connection = listener.accept().unwrap().unwrap();
        (backend_client, server_connection)
    }

    pub(crate) fn create_connection_pair() -> (Connection<FrontendReq>, Connection<FrontendReq>) {
        let dir = temp_dir();
        let path = sock_path(&dir);
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let client_connection = Connection::<FrontendReq>::connect(path).unwrap();
        let server_connection = listener.accept().unwrap().unwrap();
        (client_connection, server_connection)
    }

    pub(crate) fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
    where
        S: Backend,
    {
        let dir = temp_dir();
        let path = sock_path(&dir);
        let mut listener = SocketListener::new(&path, true).unwrap();
        let backend_client = BackendClient::connect(&path).unwrap();
        let connection = listener.accept().unwrap().unwrap();
        let req_handler = BackendServer::new(connection, backend);
        (backend_client, req_handler)
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let path = sock_path(&dir);
        let listener = SocketListener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let path = sock_path(&dir);
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let path = sock_path(&dir);

        {
            // Keep the listener alive (a bare `let _ =` would drop it immediately, which removes
            // the socket file and makes the next bind succeed).
            let _listener = SocketListener::new(&path, true).unwrap();
            // Binding again without unlinking must fail while the socket file exists.
            assert!(SocketListener::new(&path, false).is_err());
        }
        // Dropping the listener removed the socket file, so there is nothing to connect to.
        assert!(BackendClient::connect(&path).is_err());

        let mut listener = SocketListener::new(&path, true).unwrap();
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let _backend_client = BackendClient::connect(&path).unwrap();
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());

        // A zero-byte advance drops leading empty slices only.
        let empty: [u8; 0] = [];
        let data = [4u8; 3];
        let mut bufs = &mut [&empty[..], &data[..], &empty[..]][..];
        advance_slices(&mut bufs, 0);
        assert_eq!(bufs.len(), 2);
        assert_eq!(bufs[0], &data[..]);

        // Advancing by the remaining length consumes everything, trailing empties included.
        advance_slices(&mut bufs, 3);
        assert!(bufs.is_empty());
    }
}
409