1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::{
6 alloc::Layout,
7 cmp::min,
8 convert::TryFrom,
9 io,
10 mem::{align_of, size_of},
11 os::unix::io::RawFd,
12 };
13
14 use anyhow::anyhow;
15 use sys_util::LayoutAllocation;
16
17 // Allocates a buffer to hold a `libc::cmsghdr` with `cap` bytes of data.
18 //
19 // Returns the `LayoutAllocation` for the buffer as well as the size of the allocation, which is
20 // guaranteed to be at least `size_of::<libc::cmsghdr>() + cap` bytes.
allocate_cmsg_buffer(cap: u32) -> anyhow::Result<(LayoutAllocation, usize)>21 pub fn allocate_cmsg_buffer(cap: u32) -> anyhow::Result<(LayoutAllocation, usize)> {
22 // Not sure why this is unsafe.
23 let cmsg_cap = usize::try_from(unsafe { libc::CMSG_SPACE(cap) })
24 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
25 let alloc = Layout::from_size_align(cmsg_cap, align_of::<libc::cmsghdr>())
26 .map(LayoutAllocation::zeroed)
27 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
28
29 Ok((alloc, cmsg_cap))
30 }
31
32 // Adds a control message with the file descriptors from `fds` to `msg`.
33 // Note: this doesn't append but expects no cmsg already set and puts `fds` as a
34 // single `cmsg` inside passed `msg`
35 //
36 // Returns the `LayoutAllocation` backing the control message.
add_fds_to_message( msg: &mut libc::msghdr, fds: &[RawFd], ) -> anyhow::Result<LayoutAllocation>37 pub fn add_fds_to_message(
38 msg: &mut libc::msghdr,
39 fds: &[RawFd],
40 ) -> anyhow::Result<LayoutAllocation> {
41 let fd_len = fds
42 .len()
43 .checked_mul(size_of::<RawFd>())
44 .and_then(|l| u32::try_from(l).ok())
45 .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;
46
47 let (cmsg_buffer, cmsg_cap) = allocate_cmsg_buffer(fd_len)?;
48
49 if !msg.msg_control.is_null() {
50 anyhow::bail!("msg already contains cmsg");
51 }
52
53 msg.msg_control = cmsg_buffer.as_ptr();
54 msg.msg_controllen = cmsg_cap;
55
56 unsafe {
57 // Safety:
58 // * CMSG_FIRSTHDR will either return a null pointer or a pointer to `msg.msg_control`.
59 // * `msg.msg_control` is properly aligned because `cmsg_buffer` is properly aligned.
60 // * The buffer is zeroed, which is a valid bit-pattern for `libc::cmsghdr`.
61 // * The reference does not escape this function.
62 let cmsg = libc::CMSG_FIRSTHDR(msg).as_mut().unwrap();
63 cmsg.cmsg_len = libc::CMSG_LEN(fd_len) as libc::size_t;
64 cmsg.cmsg_level = libc::SOL_SOCKET;
65 cmsg.cmsg_type = libc::SCM_RIGHTS;
66
67 // Safety: `libc::CMSG_DATA(cmsg)` and `fds` are valid for `fd_len` bytes of memory.
68 libc::memcpy(
69 libc::CMSG_DATA(cmsg).cast(),
70 fds.as_ptr().cast(),
71 fd_len as usize,
72 );
73 }
74
75 Ok(cmsg_buffer)
76 }
77
78 // Copies file descriptors from the control message in `msg` into `fds`.
79 //
80 // Returns the number of file descriptors that were copied from `msg`.
take_fds_from_message(msg: &libc::msghdr, fds: &mut [RawFd]) -> anyhow::Result<usize>81 pub fn take_fds_from_message(msg: &libc::msghdr, fds: &mut [RawFd]) -> anyhow::Result<usize> {
82 let cap = fds
83 .len()
84 .checked_mul(size_of::<RawFd>())
85 .ok_or_else(|| anyhow!(io::Error::from(io::ErrorKind::InvalidInput)))?;
86
87 let mut rem = cap;
88 let mut fd_pos = 0;
89 unsafe {
90 let mut cmsg = libc::CMSG_FIRSTHDR(msg);
91
92 // Safety:
93 // * CMSG_FIRSTHDR will either return a null pointer or a pointer to `msg.msg_control`.
94 // * `msg.msg_control` is properly aligned because it was allocated by `allocate_cmsg_buffer`.
95 // * The buffer was zero-initialized, which is a valid bit-pattern for `libc::cmsghdr`.
96 // * The reference does not escape this function.
97 while let Some(current) = cmsg.as_ref() {
98 if current.cmsg_level != libc::SOL_SOCKET || current.cmsg_type != libc::SCM_RIGHTS {
99 cmsg = libc::CMSG_NXTHDR(msg, cmsg);
100 continue;
101 }
102
103 let data_len = min(current.cmsg_len - libc::CMSG_LEN(0) as usize, rem);
104
105 // Safety: `fds` and `libc::CMSG_DATA(cmsg)` are valid for `data_len` bytes of memory.
106 libc::memcpy(
107 fds[fd_pos..].as_mut_ptr().cast(),
108 libc::CMSG_DATA(cmsg).cast(),
109 data_len,
110 );
111 rem -= data_len;
112 fd_pos += data_len / size_of::<RawFd>();
113 if rem == 0 {
114 break;
115 }
116
117 cmsg = libc::CMSG_NXTHDR(msg, cmsg);
118 }
119 }
120
121 Ok((cap - rem) / size_of::<RawFd>())
122 }
123
// The control-message bytes below hard-code a 64-bit little-endian `libc::cmsghdr` (8-byte
// `cmsg_len`, 16-byte header), so only build these tests where that layout applies.
#[cfg(all(test, target_pointer_width = "64", target_endian = "little"))]
mod tests {
    use std::ptr;

    use super::*;

    /// Control-message storage with the alignment `libc::cmsghdr` needs. A bare `[u8; N]` is
    /// only 1-byte aligned, and `take_fds_from_message` reads `cmsghdr`s from the buffer in place.
    #[repr(C, align(8))]
    struct AlignedCmsg([u8; 64]);

    /// Two control messages, each with `cmsg_len` 32, level 1 and type 1 (the `SOL_SOCKET` /
    /// `SCM_RIGHTS` values these tests expect), carrying the descriptors 0xDE, 0xAD, 0xBE, 0xEF.
    const TWO_FD_CMSGS: [u8; 64] = [
        // First header: cmsg_len, cmsg_level, cmsg_type.
        32, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
        // First payload.
        0xDE, 0, 0, 0, 0xAD, 0, 0, 0, 0xBE, 0, 0, 0, 0xEF, 0, 0, 0,
        // Second header.
        32, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
        // Second payload.
        0xDE, 0, 0, 0, 0xAD, 0, 0, 0, 0xBE, 0, 0, 0, 0xEF, 0, 0, 0,
    ];

    /// Builds a `msghdr` whose control buffer is all of `control` and extracts fds into `fds`.
    fn take_from(control: &mut AlignedCmsg, fds: &mut [RawFd]) -> usize {
        let buf = [0xEAu8, 0xDD, 0xAA, 0xCC];
        let mut iov = libc::iovec {
            iov_base: buf.as_ptr() as *const libc::c_void as *mut libc::c_void,
            iov_len: buf.len() as libc::size_t,
        };

        let msg = libc::msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut iov,
            msg_iovlen: 1,
            msg_flags: 0,
            msg_control: control.0.as_mut_ptr() as *mut libc::c_void,
            msg_controllen: control.0.len(),
        };

        take_fds_from_message(&msg, fds).unwrap()
    }

    #[test]
    fn test_add_fds_to_message() {
        let buf = [0xEAu8, 0xDD, 0xAA, 0xCC];
        let mut iov = libc::iovec {
            iov_base: buf.as_ptr() as *const libc::c_void as *mut libc::c_void,
            iov_len: buf.len() as libc::size_t,
        };

        let fds = [0xDE, 0xAD, 0xBE, 0xEF];
        let mut msg = libc::msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut iov,
            msg_iovlen: 1,
            msg_flags: 0,
            msg_control: ptr::null_mut(),
            msg_controllen: 0,
        };

        let cmsg_buffer = add_fds_to_message(&mut msg, &fds[..]).unwrap();
        // A single control message: the first 32 bytes of `TWO_FD_CMSGS`.
        assert_eq!(
            unsafe { cmsg_buffer.as_slice::<u8>(9999) },
            &TWO_FD_CMSGS[..32]
        );
        assert_eq!(msg.msg_controllen, unsafe {
            cmsg_buffer.as_slice::<u8>(9999).len()
        });
        assert_eq!(msg.msg_control, cmsg_buffer.as_ptr());

        let mut extracted_fds = [0x0i32; 4];

        assert_eq!(
            4,
            take_fds_from_message(&msg, &mut extracted_fds[..]).unwrap()
        );

        assert_eq!(extracted_fds, fds);
    }

    #[test]
    fn test_take_fds_from_message() {
        let mut control = AlignedCmsg(TWO_FD_CMSGS);
        let mut extracted_fds = [0x0i32; 9];

        assert_eq!(take_from(&mut control, &mut extracted_fds), 8);
        assert_eq!(
            extracted_fds,
            [0xDE, 0xAD, 0xBE, 0xEF, 0xDE, 0xAD, 0xBE, 0xEF, 0x00]
        );
    }

    #[test]
    fn test_take_fds_into_short_buffer() {
        let mut control = AlignedCmsg(TWO_FD_CMSGS);
        let mut extracted_fds = [0x0i32; 3];

        // Copying stops as soon as `fds` is full.
        assert_eq!(take_from(&mut control, &mut extracted_fds), 3);
        assert_eq!(extracted_fds, [0xDE, 0xAD, 0xBE]);
    }
}
204