1 // Copyright 2022 The Chromium OS Authors. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3
//! Unix-specific code that keeps the rest of the crate platform independent.
5
6 use std::any::Any;
7 use std::fs::File;
8 use std::io::ErrorKind;
9 use std::io::IoSlice;
10 use std::io::IoSliceMut;
11 use std::os::fd::OwnedFd;
12 use std::os::unix::net::UnixListener;
13 use std::os::unix::net::UnixStream;
14 use std::path::Path;
15 use std::path::PathBuf;
16
17 use base::AsRawDescriptor;
18 use base::RawDescriptor;
19 use base::SafeDescriptor;
20 use base::ScmSocket;
21
22 use crate::connection::Listener;
23 use crate::frontend_server::FrontendServer;
24 use crate::message::FrontendReq;
25 use crate::message::MAX_ATTACHED_FD_ENTRIES;
26 use crate::Connection;
27 use crate::Error;
28 use crate::Frontend;
29 use crate::Result;
30
31 /// Alias to enable platform independent code.
32 pub type SystemListener = UnixListener;
33
34 /// Alias to enable platform independent code.
35 pub type SystemStream = UnixStream;
36
37 pub use SocketPlatformConnection as PlatformConnection;
38
39 /// Unix domain socket listener for accepting incoming connections.
40 pub struct SocketListener {
41 fd: SystemListener,
42 drop_path: Option<Box<dyn Any>>,
43 }
44
45 impl SocketListener {
46 /// Create a unix domain socket listener.
47 ///
48 /// # Return:
49 /// * - the new SocketListener object on success.
50 /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>51 pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
52 if unlink {
53 let _ = std::fs::remove_file(&path);
54 }
55 let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
56
57 struct DropPath {
58 path: PathBuf,
59 }
60
61 impl Drop for DropPath {
62 fn drop(&mut self) {
63 let _ = std::fs::remove_file(&self.path);
64 }
65 }
66
67 Ok(SocketListener {
68 fd,
69 drop_path: Some(Box::new(DropPath {
70 path: path.as_ref().to_owned(),
71 })),
72 })
73 }
74
75 /// Take and return the resources that the parent process needs to keep alive as long as the
76 /// child process lives, in case of incoming fork.
take_resources_for_parent(&mut self) -> Option<Box<dyn Any>>77 pub fn take_resources_for_parent(&mut self) -> Option<Box<dyn Any>> {
78 self.drop_path.take()
79 }
80 }
81
82 impl Listener for SocketListener {
83 /// Accept an incoming connection.
84 ///
85 /// # Return:
86 /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
87 /// * - None: no incoming connection available.
88 /// * - SocketError: errors from accept().
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>89 fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>> {
90 loop {
91 match self.fd.accept() {
92 Ok((stream, _addr)) => {
93 return Ok(Some(Connection::from(stream)));
94 }
95 Err(e) => {
96 match e.kind() {
97 // No incoming connection available.
98 ErrorKind::WouldBlock => return Ok(None),
99 // New connection closed by peer.
100 ErrorKind::ConnectionAborted => return Ok(None),
101 // Interrupted by signals, retry
102 ErrorKind::Interrupted => continue,
103 _ => return Err(Error::SocketError(e)),
104 }
105 }
106 }
107 }
108 }
109
110 /// Change blocking status on the listener.
111 ///
112 /// # Return:
113 /// * - () on success.
114 /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>115 fn set_nonblocking(&self, block: bool) -> Result<()> {
116 self.fd.set_nonblocking(block).map_err(Error::SocketError)
117 }
118 }
119
120 impl AsRawDescriptor for SocketListener {
as_raw_descriptor(&self) -> RawDescriptor121 fn as_raw_descriptor(&self) -> RawDescriptor {
122 self.fd.as_raw_descriptor()
123 }
124 }
125
126 /// Unix domain socket based vhost-user connection.
127 pub struct SocketPlatformConnection {
128 sock: ScmSocket<SystemStream>,
129 }
130
131 // TODO: Switch to TryFrom to avoid the unwrap.
132 impl From<SystemStream> for SocketPlatformConnection {
from(sock: SystemStream) -> Self133 fn from(sock: SystemStream) -> Self {
134 Self {
135 sock: sock.try_into().unwrap(),
136 }
137 }
138 }
139
/// Consume the first `count` bytes of a list of byte slices, in place.
///
/// Leading slices that are consumed entirely are removed from `bufs`, including empty slices
/// reached while skipping. The first remaining slice is then trimmed by whatever is left of
/// `count`. If `count` exceeds the total length, `bufs` simply becomes empty.
///
/// Equivalent to the nightly-only `IoSlice::advance_slices`, but for `&[u8]`.
fn advance_slices(bufs: &mut &mut [&[u8]], count: usize) {
    // Bytes still to skip after whole slices have been dropped.
    let mut remaining = count;
    // Number of leading slices that are fully consumed.
    let mut consumed = 0;
    for slice in bufs.iter() {
        match remaining.checked_sub(slice.len()) {
            Some(rest) => {
                remaining = rest;
                consumed += 1;
            }
            None => break,
        }
    }

    *bufs = &mut std::mem::take(bufs)[consumed..];
    if let Some(first) = bufs.first_mut() {
        let head: &[u8] = *first;
        *first = &head[remaining..];
    }
}
158
159 impl SocketPlatformConnection {
160 /// Create a new stream by connecting to server at `str`.
161 ///
162 /// # Return:
163 /// * - the new SocketPlatformConnection object on success.
164 /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>165 pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
166 let sock = SystemStream::connect(path).map_err(Error::SocketConnect)?;
167 Ok(Self::from(sock))
168 }
169
170 /// Sends all bytes from scatter-gather vectors with optional attached file descriptors. Will
171 /// loop until all data has been transfered.
172 ///
173 /// # TODO
174 /// This function takes a slice of `&[u8]` instead of `IoSlice` because the internal
175 /// cursor needs to be moved by `advance_slices()`.
176 /// Once `IoSlice::advance_slices()` becomes stable, this should be updated.
177 /// <https://github.com/rust-lang/rust/issues/62726>.
send_iovec_all( &self, mut iovs: &mut [&[u8]], mut fds: Option<&[RawDescriptor]>, ) -> Result<()>178 fn send_iovec_all(
179 &self,
180 mut iovs: &mut [&[u8]],
181 mut fds: Option<&[RawDescriptor]>,
182 ) -> Result<()> {
183 // Guarantee that `iovs` becomes empty if it doesn't contain any data.
184 advance_slices(&mut iovs, 0);
185
186 while !iovs.is_empty() {
187 let iovec: Vec<_> = iovs.iter_mut().map(|i| IoSlice::new(i)).collect();
188 match self.sock.send_vectored_with_fds(&iovec, fds.unwrap_or(&[])) {
189 Ok(n) => {
190 fds = None;
191 advance_slices(&mut iovs, n);
192 }
193 Err(e) => match e.kind() {
194 ErrorKind::WouldBlock | ErrorKind::Interrupted => {}
195 _ => return Err(Error::SocketError(e)),
196 },
197 }
198 }
199 Ok(())
200 }
201
202 /// Sends a single message over the socket with optional attached file descriptors.
203 ///
204 /// - `hdr`: vhost message header
205 /// - `body`: vhost message body (may be empty to send a header-only message)
206 /// - `payload`: additional bytes to append to `body` (may be empty)
send_message( &self, hdr: &[u8], body: &[u8], payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>207 pub fn send_message(
208 &self,
209 hdr: &[u8],
210 body: &[u8],
211 payload: &[u8],
212 fds: Option<&[RawDescriptor]>,
213 ) -> Result<()> {
214 let mut iobufs = [hdr, body, payload];
215 self.send_iovec_all(&mut iobufs, fds)
216 }
217
218 /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
219 /// file.
220 ///
221 /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
222 /// tricky to pass file descriptors through such a communication channel. Let's assume that a
223 /// sender sending a message with some file descriptors attached. To successfully receive those
224 /// attached file descriptors, the receiver must obey following rules:
225 /// 1) file descriptors are attached to a message.
226 /// 2) message(packet) boundaries must be respected on the receive side.
227 /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
228 /// attached file descriptors will get lost.
229 /// Note that this function wraps received file descriptors as `File`.
230 ///
231 /// # Return:
232 /// * - (number of bytes received, [received files]) on success
233 /// * - Disconnect: the connection is closed.
234 /// * - SocketRetry: temporary error caused by signals or short of resources.
235 /// * - SocketBroken: the underline socket is broken.
236 /// * - SocketError: other socket related errors.
recv_into_bufs( &self, bufs: &mut [IoSliceMut], allow_fd: bool, ) -> Result<(usize, Option<Vec<File>>)>237 pub fn recv_into_bufs(
238 &self,
239 bufs: &mut [IoSliceMut],
240 allow_fd: bool,
241 ) -> Result<(usize, Option<Vec<File>>)> {
242 let max_fds = if allow_fd { MAX_ATTACHED_FD_ENTRIES } else { 0 };
243 let (bytes, fds) = self.sock.recv_vectored_with_fds(bufs, max_fds)?;
244
245 // 0-bytes indicates that the connection is closed.
246 if bytes == 0 {
247 return Err(Error::Disconnect);
248 }
249
250 let files = if fds.is_empty() {
251 None
252 } else {
253 Some(fds.into_iter().map(File::from).collect())
254 };
255
256 Ok((bytes, files))
257 }
258 }
259
260 impl AsRawDescriptor for SocketPlatformConnection {
as_raw_descriptor(&self) -> RawDescriptor261 fn as_raw_descriptor(&self) -> RawDescriptor {
262 self.sock.as_raw_descriptor()
263 }
264 }
265
266 impl AsMut<SystemStream> for SocketPlatformConnection {
as_mut(&mut self) -> &mut SystemStream267 fn as_mut(&mut self) -> &mut SystemStream {
268 self.sock.inner_mut()
269 }
270 }
271
272 /// Convert a `SafeDescriptor` to a `UnixStream`.
273 ///
274 /// # Safety
275 ///
276 /// `file` must represent a unix domain socket.
to_system_stream(fd: SafeDescriptor) -> Result<SystemStream>277 pub unsafe fn to_system_stream(fd: SafeDescriptor) -> Result<SystemStream> {
278 Ok(fd.into())
279 }
280
281 impl<S: Frontend> AsRawDescriptor for FrontendServer<S> {
282 /// Used for polling.
as_raw_descriptor(&self) -> RawDescriptor283 fn as_raw_descriptor(&self) -> RawDescriptor {
284 self.sub_sock.as_raw_descriptor()
285 }
286 }
287
288 impl<S: Frontend> FrontendServer<S> {
289 /// Create a `FrontendServer` that uses a Unix stream internally.
290 ///
291 /// The returned `SafeDescriptor` is the client side of the stream and should be sent to the
292 /// backend using [BackendClient::set_slave_request_fd()].
293 ///
294 /// [BackendClient::set_slave_request_fd()]: struct.BackendClient.html#method.set_slave_request_fd
with_stream(backend: S) -> Result<(Self, SafeDescriptor)>295 pub fn with_stream(backend: S) -> Result<(Self, SafeDescriptor)> {
296 let (tx, rx) = SystemStream::pair()?;
297 Ok((
298 Self::new(backend, rx)?,
299 SafeDescriptor::from(OwnedFd::from(tx)),
300 ))
301 }
302 }
303
#[cfg(test)]
pub(crate) mod tests {
    use tempfile::Builder;
    use tempfile::TempDir;

    use super::*;
    use crate::backend_client::BackendClient;
    use crate::backend_server::Backend;
    use crate::backend_server::BackendServer;
    use crate::connection::Listener;
    use crate::message::FrontendReq;
    use crate::Connection;

    /// Create a fresh temporary directory (deleted when the returned `TempDir` is dropped)
    /// to hold test sockets.
    pub(crate) fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Connect a `BackendClient` to a fresh listener and return it with the accepted server
    /// side of the connection.
    pub(crate) fn create_pair() -> (BackendClient, Connection<FrontendReq>) {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let backend_client = BackendClient::connect(path).unwrap();
        let server_connection = listener.accept().unwrap().unwrap();
        (backend_client, server_connection)
    }

    /// Return both ends of a freshly established connection.
    pub(crate) fn create_connection_pair() -> (Connection<FrontendReq>, Connection<FrontendReq>) {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let client_connection = Connection::<FrontendReq>::connect(path).unwrap();
        let server_connection = listener.accept().unwrap().unwrap();
        (client_connection, server_connection)
    }

    /// Connect a `BackendClient` and wrap the accepted server side in a `BackendServer`
    /// driving `backend`.
    pub(crate) fn create_client_server_pair<S>(backend: S) -> (BackendClient, BackendServer<S>)
    where
        S: Backend,
    {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        let backend_client = BackendClient::connect(&path).unwrap();
        let connection = listener.accept().unwrap().unwrap();
        let req_handler = BackendServer::new(connection, backend);
        (backend_client, req_handler)
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = SocketListener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = SocketListener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn test_create_failure() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");

        // A dropped listener removes its socket file, so there is nothing left to connect to.
        drop(SocketListener::new(&path, true).unwrap());
        assert!(BackendClient::connect(&path).is_err());

        // While a listener owns the path, binding again without unlinking must fail.
        let mut listener = SocketListener::new(&path, true).unwrap();
        assert!(SocketListener::new(&path, false).is_err());
        listener.set_nonblocking(true).unwrap();

        let _backend_client = BackendClient::connect(&path).unwrap();
        let _server_connection = listener.accept().unwrap().unwrap();
    }

    #[test]
    fn test_advance_slices() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSlice.html#method.advance_slices
        let buf1 = [1; 8];
        let buf2 = [2; 16];
        let buf3 = [3; 8];
        let mut bufs = &mut [&buf1[..], &buf2[..], &buf3[..]][..];
        advance_slices(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }
}
409