• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{
6     alloc::Layout,
7     cmp::min,
8     convert::TryFrom,
9     io,
10     mem::{align_of, size_of},
11     os::unix::io::RawFd,
12 };
13 
14 use anyhow::anyhow;
15 use sys_util::LayoutAllocation;
16 
17 // Allocates a buffer to hold a `libc::cmsghdr` with `cap` bytes of data.
18 //
19 // Returns the `LayoutAllocation` for the buffer as well as the size of the allocation, which is
20 // guaranteed to be at least `size_of::<libc::cmsghdr>() + cap` bytes.
allocate_cmsg_buffer(cap: u32) -> anyhow::Result<(LayoutAllocation, usize)>21 pub fn allocate_cmsg_buffer(cap: u32) -> anyhow::Result<(LayoutAllocation, usize)> {
22     // Not sure why this is unsafe.
23     let cmsg_cap = usize::try_from(unsafe { libc::CMSG_SPACE(cap) })
24         .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
25     let alloc = Layout::from_size_align(cmsg_cap, align_of::<libc::cmsghdr>())
26         .map(LayoutAllocation::zeroed)
27         .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
28 
29     Ok((alloc, cmsg_cap))
30 }
31 
32 // Adds a control message with the file descriptors from `fds` to `msg`.
33 // Note: this doesn't append but expects no cmsg already set and puts `fds` as a
34 // single `cmsg` inside passed `msg`
35 //
36 // Returns the `LayoutAllocation` backing the control message.
add_fds_to_message( msg: &mut libc::msghdr, fds: &[RawFd], ) -> anyhow::Result<LayoutAllocation>37 pub fn add_fds_to_message(
38     msg: &mut libc::msghdr,
39     fds: &[RawFd],
40 ) -> anyhow::Result<LayoutAllocation> {
41     let fd_len = fds
42         .len()
43         .checked_mul(size_of::<RawFd>())
44         .and_then(|l| u32::try_from(l).ok())
45         .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;
46 
47     let (cmsg_buffer, cmsg_cap) = allocate_cmsg_buffer(fd_len)?;
48 
49     if !msg.msg_control.is_null() {
50         anyhow::bail!("msg already contains cmsg");
51     }
52 
53     msg.msg_control = cmsg_buffer.as_ptr();
54     msg.msg_controllen = cmsg_cap;
55 
56     unsafe {
57         // Safety:
58         // * CMSG_FIRSTHDR will either return a null pointer or a pointer to `msg.msg_control`.
59         // * `msg.msg_control` is properly aligned because `cmsg_buffer` is properly aligned.
60         // * The buffer is zeroed, which is a valid bit-pattern for `libc::cmsghdr`.
61         // * The reference does not escape this function.
62         let cmsg = libc::CMSG_FIRSTHDR(msg).as_mut().unwrap();
63         cmsg.cmsg_len = libc::CMSG_LEN(fd_len) as libc::size_t;
64         cmsg.cmsg_level = libc::SOL_SOCKET;
65         cmsg.cmsg_type = libc::SCM_RIGHTS;
66 
67         // Safety: `libc::CMSG_DATA(cmsg)` and `fds` are valid for `fd_len` bytes of memory.
68         libc::memcpy(
69             libc::CMSG_DATA(cmsg).cast(),
70             fds.as_ptr().cast(),
71             fd_len as usize,
72         );
73     }
74 
75     Ok(cmsg_buffer)
76 }
77 
78 // Copies file descriptors from the control message in `msg` into `fds`.
79 //
80 // Returns the number of file descriptors that were copied from `msg`.
take_fds_from_message(msg: &libc::msghdr, fds: &mut [RawFd]) -> anyhow::Result<usize>81 pub fn take_fds_from_message(msg: &libc::msghdr, fds: &mut [RawFd]) -> anyhow::Result<usize> {
82     let cap = fds
83         .len()
84         .checked_mul(size_of::<RawFd>())
85         .ok_or_else(|| anyhow!(io::Error::from(io::ErrorKind::InvalidInput)))?;
86 
87     let mut rem = cap;
88     let mut fd_pos = 0;
89     unsafe {
90         let mut cmsg = libc::CMSG_FIRSTHDR(msg);
91 
92         // Safety:
93         // * CMSG_FIRSTHDR will either return a null pointer or a pointer to `msg.msg_control`.
94         // * `msg.msg_control` is properly aligned because it was allocated by `allocate_cmsg_buffer`.
95         // * The buffer was zero-initialized, which is a valid bit-pattern for `libc::cmsghdr`.
96         // * The reference does not escape this function.
97         while let Some(current) = cmsg.as_ref() {
98             if current.cmsg_level != libc::SOL_SOCKET || current.cmsg_type != libc::SCM_RIGHTS {
99                 cmsg = libc::CMSG_NXTHDR(msg, cmsg);
100                 continue;
101             }
102 
103             let data_len = min(current.cmsg_len - libc::CMSG_LEN(0) as usize, rem);
104 
105             // Safety: `fds` and `libc::CMSG_DATA(cmsg)` are valid for `data_len` bytes of memory.
106             libc::memcpy(
107                 fds[fd_pos..].as_mut_ptr().cast(),
108                 libc::CMSG_DATA(cmsg).cast(),
109                 data_len,
110             );
111             rem -= data_len;
112             fd_pos += data_len / size_of::<RawFd>();
113             if rem == 0 {
114                 break;
115             }
116 
117             cmsg = libc::CMSG_NXTHDR(msg, cmsg);
118         }
119     }
120 
121     Ok((cap - rem) / size_of::<RawFd>())
122 }
123 
#[cfg(test)]
mod tests {
    use std::ptr;

    use super::*;

    // Backing storage for a hand-built control buffer. A bare `[u8; N]` only has alignment 1,
    // but the code under test creates `&libc::cmsghdr` references into the buffer, so it has to
    // be aligned for `libc::cmsghdr`.
    #[repr(C, align(8))]
    struct AlignedCmsg([u8; 64]);

    // NOTE: the expected byte patterns in these tests encode `cmsg_len` as 8 little-endian bytes,
    // i.e. they assume a 64-bit little-endian target.

    // Round-trips four descriptors through `add_fds_to_message` and `take_fds_from_message`.
    #[test]
    #[cfg(not(target_arch = "arm"))]
    fn test_add_fds_to_message() {
        let buf = [0xEAu8, 0xDD, 0xAA, 0xCC];
        let mut iov = libc::iovec {
            iov_base: buf.as_ptr() as *const libc::c_void as *mut libc::c_void,
            iov_len: buf.len() as libc::size_t,
        };

        let fds = [0xDE, 0xAD, 0xBE, 0xEF];
        let mut msg = libc::msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut iov,
            msg_iovlen: 1,
            msg_flags: 0,
            msg_control: ptr::null_mut(),
            msg_controllen: 0,
        };

        let cmsg_buffer = add_fds_to_message(&mut msg, &fds[..]).unwrap();
        // Header: cmsg_len = 32, cmsg_level = 1, cmsg_type = 1; followed by the four descriptors.
        let expected_cmsg = [
            32u8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0xDE, 0, 0, 0, 0xAD, 0, 0, 0, 0xBE,
            0, 0, 0, 0xEF, 0, 0, 0,
        ];
        // 9999 exceeds the allocation size; presumably `as_slice` clamps the length to the
        // allocation -- TODO confirm against `sys_util::LayoutAllocation`.
        assert_eq!(unsafe { cmsg_buffer.as_slice::<u8>(9999) }, &expected_cmsg);
        assert_eq!(msg.msg_controllen, unsafe {
            cmsg_buffer.as_slice::<u8>(9999).len()
        });
        assert_eq!(msg.msg_control, cmsg_buffer.as_ptr());

        let mut extracted_fds = [0x0i32; 4];

        assert_eq!(
            4,
            take_fds_from_message(&msg, &mut extracted_fds[..]).unwrap()
        );

        assert_eq!(extracted_fds, fds);
    }

    // Extracts descriptors from two consecutive SCM_RIGHTS messages into a larger output slice.
    #[test]
    #[cfg(not(target_arch = "arm"))]
    fn test_take_fds_from_message() {
        assert!(std::mem::align_of::<AlignedCmsg>() >= std::mem::align_of::<libc::cmsghdr>());

        let buf = [0xEAu8, 0xDD, 0xAA, 0xCC];
        let mut iov = libc::iovec {
            iov_base: buf.as_ptr() as *const libc::c_void as *mut libc::c_void,
            iov_len: buf.len() as libc::size_t,
        };

        // Two identical 32-byte control messages, each carrying four descriptors.
        let mut cmsg = AlignedCmsg([
            32u8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0xDE, 0, 0, 0, 0xAD, 0, 0, 0, 0xBE,
            0, 0, 0, 0xEF, 0, 0, 0, 32u8, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0xDE, 0, 0,
            0, 0xAD, 0, 0, 0, 0xBE, 0, 0, 0, 0xEF, 0, 0, 0,
        ]);

        let msg = libc::msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut iov,
            msg_iovlen: 1,
            msg_flags: 0,
            msg_control: cmsg.0.as_mut_ptr() as *mut libc::c_void,
            msg_controllen: cmsg.0.len(),
        };

        // The output has room for nine descriptors but only eight are present, so the last
        // slot must stay untouched.
        let mut extracted_fds = [0x0i32; 9];
        assert_eq!(take_fds_from_message(&msg, &mut extracted_fds).unwrap(), 8);
        assert_eq!(
            extracted_fds,
            [0xDE, 0xAD, 0xBE, 0xEF, 0xDE, 0xAD, 0xBE, 0xEF, 0x00]
        );
    }
}
204