• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! Structs for Unix Domain Socket listener and endpoint.
5 
6 use std::fs::File;
7 use std::io::{ErrorKind, IoSlice, IoSliceMut};
8 use std::marker::PhantomData;
9 use std::path::{Path, PathBuf};
10 
11 use base::{AsRawDescriptor, FromRawDescriptor, RawDescriptor, ScmSocket};
12 
13 use super::{Error, Result};
14 use crate::connection::{Endpoint as EndpointTrait, Listener as ListenerTrait, Req};
15 use crate::message::*;
16 use crate::{SystemListener, SystemStream};
17 
18 /// Unix domain socket listener for accepting incoming connections.
19 pub struct Listener {
20     fd: SystemListener,
21     path: PathBuf,
22 }
23 
24 impl Listener {
25     /// Create a unix domain socket listener.
26     ///
27     /// # Return:
28     /// * - the new Listener object on success.
29     /// * - SocketError: failed to create listener socket.
new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self>30     pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
31         if unlink {
32             let _ = std::fs::remove_file(&path);
33         }
34         let fd = SystemListener::bind(&path).map_err(Error::SocketError)?;
35         Ok(Listener {
36             fd,
37             path: path.as_ref().to_owned(),
38         })
39     }
40 }
41 
42 impl ListenerTrait for Listener {
43     type Connection = SystemStream;
44 
45     /// Accept an incoming connection.
46     ///
47     /// # Return:
48     /// * - Some(SystemListener): new SystemListener object if new incoming connection is available.
49     /// * - None: no incoming connection available.
50     /// * - SocketError: errors from accept().
accept(&mut self) -> Result<Option<Self::Connection>>51     fn accept(&mut self) -> Result<Option<Self::Connection>> {
52         loop {
53             match self.fd.accept() {
54                 Ok((stream, _addr)) => return Ok(Some(stream)),
55                 Err(e) => {
56                     match e.kind() {
57                         // No incoming connection available.
58                         ErrorKind::WouldBlock => return Ok(None),
59                         // New connection closed by peer.
60                         ErrorKind::ConnectionAborted => return Ok(None),
61                         // Interrupted by signals, retry
62                         ErrorKind::Interrupted => continue,
63                         _ => return Err(Error::SocketError(e)),
64                     }
65                 }
66             }
67         }
68     }
69 
70     /// Change blocking status on the listener.
71     ///
72     /// # Return:
73     /// * - () on success.
74     /// * - SocketError: failure from set_nonblocking().
set_nonblocking(&self, block: bool) -> Result<()>75     fn set_nonblocking(&self, block: bool) -> Result<()> {
76         self.fd.set_nonblocking(block).map_err(Error::SocketError)
77     }
78 }
79 
80 impl AsRawDescriptor for Listener {
as_raw_descriptor(&self) -> RawDescriptor81     fn as_raw_descriptor(&self) -> RawDescriptor {
82         self.fd.as_raw_descriptor()
83     }
84 }
85 
86 impl Drop for Listener {
drop(&mut self)87     fn drop(&mut self) {
88         let _ = std::fs::remove_file(&self.path);
89     }
90 }
91 
92 /// Unix domain socket endpoint for vhost-user connection.
93 pub struct Endpoint<R: Req> {
94     sock: SystemStream,
95     _r: PhantomData<R>,
96 }
97 
98 impl<R: Req> From<SystemStream> for Endpoint<R> {
from(sock: SystemStream) -> Self99     fn from(sock: SystemStream) -> Self {
100         Self {
101             sock,
102             _r: PhantomData,
103         }
104     }
105 }
106 
107 impl<R: Req> EndpointTrait<R> for Endpoint<R> {
108     type Listener = Listener;
109 
110     /// Create an endpoint from a stream object.
from_connection( sock: <<Self as EndpointTrait<R>>::Listener as ListenerTrait>::Connection, ) -> Self111     fn from_connection(
112         sock: <<Self as EndpointTrait<R>>::Listener as ListenerTrait>::Connection,
113     ) -> Self {
114         Self {
115             sock,
116             _r: PhantomData,
117         }
118     }
119 
120     /// Create a new stream by connecting to server at `str`.
121     ///
122     /// # Return:
123     /// * - the new Endpoint object on success.
124     /// * - SocketConnect: failed to connect to peer.
connect<P: AsRef<Path>>(path: P) -> Result<Self>125     fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
126         let sock = SystemStream::connect(path).map_err(Error::SocketConnect)?;
127         Ok(Self::from(sock))
128     }
129 
130     /// Sends bytes from scatter-gather vectors over the socket with optional attached file
131     /// descriptors.
132     ///
133     /// # Return:
134     /// * - number of bytes sent on success
135     /// * - SocketRetry: temporary error caused by signals or short of resources.
136     /// * - SocketBroken: the underline socket is broken.
137     /// * - SocketError: other socket related errors.
send_iovec(&mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>) -> Result<usize>138     fn send_iovec(&mut self, iovs: &[IoSlice], fds: Option<&[RawDescriptor]>) -> Result<usize> {
139         let rfds = match fds {
140             Some(rfds) => rfds,
141             _ => &[],
142         };
143         self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into)
144     }
145 
146     /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
147     /// file.
148     ///
149     /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
150     /// tricky to pass file descriptors through such a communication channel. Let's assume that a
151     /// sender sending a message with some file descriptors attached. To successfully receive those
152     /// attached file descriptors, the receiver must obey following rules:
153     ///   1) file descriptors are attached to a message.
154     ///   2) message(packet) boundaries must be respected on the receive side.
155     /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
156     /// attached file descriptors will get lost.
157     /// Note that this function wraps received file descriptors as `File`.
158     ///
159     /// # Return:
160     /// * - (number of bytes received, [received files]) on success
161     /// * - Disconnect: the connection is closed.
162     /// * - SocketRetry: temporary error caused by signals or short of resources.
163     /// * - SocketBroken: the underline socket is broken.
164     /// * - SocketError: other socket related errors.
recv_into_bufs( &mut self, bufs: &mut [IoSliceMut], allow_fd: bool, ) -> Result<(usize, Option<Vec<File>>)>165     fn recv_into_bufs(
166         &mut self,
167         bufs: &mut [IoSliceMut],
168         allow_fd: bool,
169     ) -> Result<(usize, Option<Vec<File>>)> {
170         let mut fd_array = if allow_fd {
171             vec![0; MAX_ATTACHED_FD_ENTRIES]
172         } else {
173             vec![]
174         };
175         let mut iovs: Vec<_> = bufs.iter_mut().map(|s| IoSliceMut::new(s)).collect();
176         let (bytes, fds) = self.sock.recv_iovecs_with_fds(&mut iovs, &mut fd_array)?;
177 
178         // 0-bytes indicates that the connection is closed.
179         if bytes == 0 {
180             return Err(Error::Disconnect);
181         }
182 
183         let files = match fds {
184             0 => None,
185             n => {
186                 let files = fd_array
187                     .iter()
188                     .take(n)
189                     .map(|fd| {
190                         // Safe because we have the ownership of `fd`.
191                         unsafe { File::from_raw_descriptor(*fd as RawDescriptor) }
192                     })
193                     .collect();
194                 Some(files)
195             }
196         };
197 
198         Ok((bytes, files))
199     }
200 }
201 
202 impl<T: Req> AsRawDescriptor for Endpoint<T> {
as_raw_descriptor(&self) -> RawDescriptor203     fn as_raw_descriptor(&self) -> RawDescriptor {
204         self.sock.as_raw_descriptor()
205     }
206 }
207 
208 impl<T: Req> AsMut<SystemStream> for Endpoint<T> {
as_mut(&mut self) -> &mut SystemStream209     fn as_mut(&mut self) -> &mut SystemStream {
210         &mut self.sock
211     }
212 }
213 
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Seek, SeekFrom, Write};
    use std::mem;
    use tempfile::{tempfile, Builder, TempDir};

    use crate::connection::EndpointExt;

    /// Creates a fresh temporary directory to hold the test socket file.
    fn temp_dir() -> TempDir {
        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    #[test]
    fn create_listener() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();

        assert!(listener.as_raw_descriptor() > 0);
    }

    #[test]
    fn accept_connection() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }

    #[test]
    fn send_data() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from(sock);

        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let mut len = master.send_slice(IoSlice::new(&buf1[..]), None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..bytes]);

        len = master.send_slice(IoSlice::new(&buf1[..]), None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
    }

    #[test]
    fn send_fd() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from(sock);

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let len = master
            .send_slice(IoSlice::new(&buf1[..]), Some(&[fd.as_raw_descriptor()]))
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 1);
            let mut file = &files[0];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }

        // Following communication pattern should work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header) with fds, data(body)
        let len = master
            .send_slice(
                IoSlice::new(&buf1[..]),
                Some(&[
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                ]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 3);
            let mut file = &files[1];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should not work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header), data(body) with fds
        let len = master
            .send_slice(
                IoSlice::new(&buf1[..]),
                Some(&[
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                ]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let buf4 = slave.recv_data(2).unwrap();
        assert_eq!(buf4.len(), 2);
        assert_eq!(&buf1[..2], &buf4[..]);
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should work:
        // Sending side: data, data with fds
        // Receiving side: data, data with fds
        let len = master.send_slice(IoSlice::new(&buf1[..]), None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                IoSlice::new(&buf1[..]),
                Some(&[
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                ]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(files.is_none());

        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(files.is_some());
        let files = files.unwrap();
        {
            assert_eq!(files.len(), 3);
            let mut file = &files[1];
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(files.is_none());

        // Following communication pattern should not work:
        // Sending side: data1, data2 with fds
        // Receiving side: data + partial of data2, left of data2 with fds
        let len = master.send_slice(IoSlice::new(&buf1[..]), None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                IoSlice::new(&buf1[..]),
                Some(&[
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                ]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let v = slave.recv_data(5).unwrap();
        assert_eq!(v.len(), 5);

        let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 3);
        assert!(files.is_none());

        // If the target fd array is too small, extra file descriptors will get lost.
        let len = master
            .send_slice(
                IoSlice::new(&buf1[..]),
                Some(&[
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                    fd.as_raw_descriptor(),
                ]),
            )
            .unwrap();
        assert_eq!(len, 4);

        let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert!(files.is_some());
    }

    #[test]
    fn send_recv() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let mut listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from(sock);

        let mut hdr1 =
            VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
        hdr1.set_need_reply(true);
        let features1 = 0x1u64;
        master.send_message(&hdr1, &features1, None).unwrap();

        // Receive the body into a plain byte buffer and decode it afterwards. This matches the
        // native in-memory layout of a `u64` without reinterpreting `&mut u64` as a byte slice
        // through `unsafe`.
        let mut body = [0u8; mem::size_of::<u64>()];
        let (hdr2, bytes, files) = slave.recv_body_into_buf(&mut body[..]).unwrap();
        let features2 = u64::from_ne_bytes(body);
        assert_eq!(hdr1, hdr2);
        assert_eq!(bytes, 8);
        assert_eq!(features1, features2);
        assert!(files.is_none());

        master.send_header(&hdr1, None).unwrap();
        let (hdr2, files) = slave.recv_header().unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_none());
    }
}
487