• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2017 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Used to send and receive messages with file descriptors on sockets that accept control messages
6 //! (e.g. Unix domain sockets).
7 
8 use std::fs::File;
9 use std::mem::size_of;
10 use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
11 use std::os::unix::net::{UnixDatagram, UnixStream};
12 use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
13 
14 use libc::{
15     c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
16 };
17 
18 use data_model::VolatileSlice;
19 
20 use crate::net::UnixSeqpacket;
21 use crate::{Error, Result};
22 
// Each of the following macros computes the same value as its C counterpart. They are macros
// rather than functions because they are used to size statically allocated arrays.
25 
// Rounds `$len` up to the next multiple of `size_of::<c_long>()`, the alignment used between
// control messages. Written with only `/` and `*` so it stays usable in constant expressions.
macro_rules! CMSG_ALIGN {
    ($len:expr) => {
        (($len) + size_of::<c_long>() - 1) / size_of::<c_long>() * size_of::<c_long>()
    };
}
31 
// Total number of bytes a control message with `$len` payload bytes occupies in a control
// buffer: the header plus the payload padded out by `CMSG_ALIGN!`.
macro_rules! CMSG_SPACE {
    ($len:expr) => {
        CMSG_ALIGN!($len) + size_of::<cmsghdr>()
    };
}
37 
// Value to store in `cmsg_len` for a control message carrying `$len` payload bytes: the header
// size plus the unpadded payload length.
macro_rules! CMSG_LEN {
    ($len:expr) => {
        ($len) + size_of::<cmsghdr>()
    };
}
43 
44 // This function (macro in the C version) is not used in any compile time constant slots, so is just
45 // an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
46 // module supports.
47 #[allow(non_snake_case)]
48 #[inline(always)]
CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd49 fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
50     // Essentially returns a pointer to just past the header.
51     cmsg_buffer.wrapping_offset(1) as *mut RawFd
52 }
53 
54 // This function is like CMSG_NEXT, but safer because it reads only from references, although it
55 // does some pointer arithmetic on cmsg_ptr.
56 #[allow(clippy::cast_ptr_alignment)]
get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr57 fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
58     let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
59     if next_cmsg
60         .wrapping_offset(1)
61         .wrapping_sub(msghdr.msg_control as usize) as usize
62         > msghdr.msg_controllen
63     {
64         null_mut()
65     } else {
66         next_cmsg
67     }
68 }
69 
70 const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
71 
72 enum CmsgBuffer {
73     Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
74     Heap(Box<[cmsghdr]>),
75 }
76 
77 impl CmsgBuffer {
with_capacity(capacity: usize) -> CmsgBuffer78     fn with_capacity(capacity: usize) -> CmsgBuffer {
79         let cap_in_cmsghdr_units =
80             (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
81         if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
82             CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
83         } else {
84             CmsgBuffer::Heap(
85                 vec![
86                     cmsghdr {
87                         cmsg_len: 0,
88                         cmsg_level: 0,
89                         cmsg_type: 0,
90                     };
91                     cap_in_cmsghdr_units
92                 ]
93                 .into_boxed_slice(),
94             )
95         }
96     }
97 
as_mut_ptr(&mut self) -> *mut cmsghdr98     fn as_mut_ptr(&mut self) -> *mut cmsghdr {
99         match self {
100             CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
101             CmsgBuffer::Heap(a) => a.as_mut_ptr(),
102         }
103     }
104 }
105 
raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: D, out_fds: &[RawFd]) -> Result<usize>106 fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: D, out_fds: &[RawFd]) -> Result<usize> {
107     let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
108     let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
109 
110     let mut iovec = iovec {
111         iov_base: out_data.as_ptr() as *mut c_void,
112         iov_len: out_data.size(),
113     };
114 
115     let mut msg = msghdr {
116         msg_name: null_mut(),
117         msg_namelen: 0,
118         msg_iov: &mut iovec as *mut iovec,
119         msg_iovlen: 1,
120         msg_control: null_mut(),
121         msg_controllen: 0,
122         msg_flags: 0,
123     };
124 
125     if !out_fds.is_empty() {
126         let cmsg = cmsghdr {
127             cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
128             cmsg_level: SOL_SOCKET,
129             cmsg_type: SCM_RIGHTS,
130         };
131         unsafe {
132             // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
133             write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
134             // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
135             // file descriptors.
136             copy_nonoverlapping(
137                 out_fds.as_ptr(),
138                 CMSG_DATA(cmsg_buffer.as_mut_ptr()),
139                 out_fds.len(),
140             );
141         }
142 
143         msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
144         msg.msg_controllen = cmsg_capacity;
145     }
146 
147     // Safe because the msghdr was properly constructed from valid (or null) pointers of the
148     // indicated length and we check the return value.
149     let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
150 
151     if write_count == -1 {
152         Err(Error::last())
153     } else {
154         Ok(write_count as usize)
155     }
156 }
157 
raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)>158 fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
159     let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
160     let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
161 
162     let mut iovec = iovec {
163         iov_base: in_data.as_mut_ptr() as *mut c_void,
164         iov_len: in_data.len(),
165     };
166 
167     let mut msg = msghdr {
168         msg_name: null_mut(),
169         msg_namelen: 0,
170         msg_iov: &mut iovec as *mut iovec,
171         msg_iovlen: 1,
172         msg_control: null_mut(),
173         msg_controllen: 0,
174         msg_flags: 0,
175     };
176 
177     if !in_fds.is_empty() {
178         msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
179         msg.msg_controllen = cmsg_capacity;
180     }
181 
182     // Safe because the msghdr was properly constructed from valid (or null) pointers of the
183     // indicated length and we check the return value.
184     let total_read = unsafe { recvmsg(fd, &mut msg, 0) };
185 
186     if total_read == -1 {
187         return Err(Error::last());
188     }
189 
190     if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() {
191         return Ok((0, 0));
192     }
193 
194     let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
195     let mut in_fds_count = 0;
196     while !cmsg_ptr.is_null() {
197         // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
198         // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
199         let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
200 
201         if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
202             let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>();
203             unsafe {
204                 copy_nonoverlapping(
205                     CMSG_DATA(cmsg_ptr),
206                     in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
207                     fd_count,
208                 );
209             }
210             in_fds_count += fd_count;
211         }
212 
213         cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
214     }
215 
216     Ok((total_read as usize, in_fds_count))
217 }
218 
219 /// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
220 /// `recvmsg`.
221 pub trait ScmSocket {
222     /// Gets the file descriptor of this socket.
socket_fd(&self) -> RawFd223     fn socket_fd(&self) -> RawFd;
224 
225     /// Sends the given data and file descriptor over the socket.
226     ///
227     /// On success, returns the number of bytes sent.
228     ///
229     /// # Arguments
230     ///
231     /// * `buf` - A buffer of data to send on the `socket`.
232     /// * `fd` - A file descriptors to be sent.
send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize>233     fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
234         self.send_with_fds(buf, &[fd])
235     }
236 
237     /// Sends the given data and file descriptors over the socket.
238     ///
239     /// On success, returns the number of bytes sent.
240     ///
241     /// # Arguments
242     ///
243     /// * `buf` - A buffer of data to send on the `socket`.
244     /// * `fds` - A list of file descriptors to be sent.
send_with_fds<D: IntoIovec>(&self, buf: D, fd: &[RawFd]) -> Result<usize>245     fn send_with_fds<D: IntoIovec>(&self, buf: D, fd: &[RawFd]) -> Result<usize> {
246         raw_sendmsg(self.socket_fd(), buf, fd)
247     }
248 
249     /// Receives data and potentially a file descriptor from the socket.
250     ///
251     /// On success, returns the number of bytes and an optional file descriptor.
252     ///
253     /// # Arguments
254     ///
255     /// * `buf` - A buffer to receive data from the socket.vm
recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)>256     fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
257         let mut fd = [0];
258         let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?;
259         let file = if fd_count == 0 {
260             None
261         } else {
262             // Safe because the first fd from recv_with_fds is owned by us and valid because this
263             // branch was taken.
264             Some(unsafe { File::from_raw_fd(fd[0]) })
265         };
266         Ok((read_count, file))
267     }
268 
269     /// Receives data and file descriptors from the socket.
270     ///
271     /// On success, returns the number of bytes and file descriptors received as a tuple
272     /// `(bytes count, files count)`.
273     ///
274     /// # Arguments
275     ///
276     /// * `buf` - A buffer to receive data from the socket.
277     /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
278     ///           number of valid file descriptors is indicated by the second element of the
279     ///           returned tuple. The caller owns these file descriptors, but they will not be
280     ///           closed on drop like a `File`-like type would be. It is recommended that each valid
281     ///           file descriptor gets wrapped in a drop type that closes it after this returns.
recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)>282     fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> {
283         raw_recvmsg(self.socket_fd(), buf, fds)
284     }
285 }
286 
287 impl ScmSocket for UnixDatagram {
socket_fd(&self) -> RawFd288     fn socket_fd(&self) -> RawFd {
289         self.as_raw_fd()
290     }
291 }
292 
293 impl ScmSocket for UnixStream {
socket_fd(&self) -> RawFd294     fn socket_fd(&self) -> RawFd {
295         self.as_raw_fd()
296     }
297 }
298 impl ScmSocket for UnixSeqpacket {
socket_fd(&self) -> RawFd299     fn socket_fd(&self) -> RawFd {
300         self.as_raw_fd()
301     }
302 }
303 
/// A contiguous region of memory that can be described to a syscall as a single `iovec`.
///
/// The region must stay valid for as long as the implementing value is alive.
///
/// # Safety
///
/// Callers pass the pointer and length straight to the kernel. Implementors must therefore
/// guarantee that `as_ptr` points to at least `size` readable bytes for the lifetime of `self`.
pub unsafe trait IntoIovec {
    /// Base address of the region.
    fn as_ptr(&self) -> *const c_void;

    /// Length of the region in bytes.
    fn size(&self) -> usize;
}
316 
317 // Safe because this slice can not have another mutable reference and it's pointer and size are
318 // guaranteed to be valid.
319 unsafe impl<'a> IntoIovec for &'a [u8] {
320     // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480
321     #[allow(clippy::useless_asref)]
as_ptr(&self) -> *const c_void322     fn as_ptr(&self) -> *const c_void {
323         self.as_ref().as_ptr() as *const c_void
324     }
325 
size(&self) -> usize326     fn size(&self) -> usize {
327         self.len()
328     }
329 }
330 
331 // Safe because volatile slices are only ever accessed with other volatile interfaces and the
332 // pointer and size are guaranteed to be accurate.
333 unsafe impl<'a> IntoIovec for VolatileSlice<'a> {
as_ptr(&self) -> *const c_void334     fn as_ptr(&self) -> *const c_void {
335         self.as_ptr() as *const c_void
336     }
337 
size(&self) -> usize338     fn size(&self) -> usize {
339         self.size() as usize
340     }
341 }
342 
#[cfg(test)]
mod tests {
    use super::*;

    use std::io::Write;
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;

    use libc::cmsghdr;

    use crate::EventFd;

    /// Writes `value` to `file` as 8 native-endian bytes.
    ///
    /// Uses `write_all`, so a failed or short write fails the test instead of being silently
    /// ignored.
    fn write_u64(file: &mut File, value: u64) {
        file.write_all(&value.to_ne_bytes())
            .expect("failed to write to sent fd");
    }

    #[test]
    fn buffer_len() {
        assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>());
        assert_eq!(
            CMSG_SPACE!(1 * size_of::<RawFd>()),
            size_of::<cmsghdr>() + size_of::<c_long>()
        );
        if size_of::<RawFd>() == 4 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>()
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
        } else if size_of::<RawFd>() == 8 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 3
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 4
            );
        }
    }

    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let write_count = s1
            .send_with_fds([1u8, 1, 2, 21, 34, 55].as_ref(), &[])
            .expect("failed to send data");

        assert_eq!(write_count, 6);

        let mut buf = [0; 6];
        let mut files = [0; 1];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut buf[..], &mut files)
            .expect("failed to recv data");

        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new().expect("failed to create eventfd");
        let write_count = s1
            .send_with_fd([].as_ref(), evt.as_raw_fd())
            .expect("failed to send fd");

        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");

        let mut file = file_opt.unwrap();

        assert_eq!(read_count, 0);
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), evt.as_raw_fd());

        write_u64(&mut file, 1203);

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new().expect("failed to create eventfd");
        let write_count = s1
            .send_with_fds([237].as_ref(), &[evt.as_raw_fd()])
            .expect("failed to send fd");

        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut buf, &mut files)
            .expect("failed to recv fd");

        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], evt.as_raw_fd());

        let mut file = unsafe { File::from_raw_fd(files[0]) };

        write_u64(&mut file, 1203);

        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }
}
474