1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::alloc::Layout;
6 use std::mem::MaybeUninit;
7 use std::os::unix::io::AsRawFd;
8 use std::str;
9
10 use libc::EINVAL;
11 use log::error;
12 use zerocopy::FromBytes;
13 use zerocopy::Immutable;
14 use zerocopy::IntoBytes;
15 use zerocopy::KnownLayout;
16
17 use super::errno_result;
18 use super::getpid;
19 use super::Error;
20 use super::RawDescriptor;
21 use super::Result;
22 use crate::alloc::LayoutAllocation;
23 use crate::descriptor::AsRawDescriptor;
24 use crate::descriptor::FromRawDescriptor;
25 use crate::descriptor::SafeDescriptor;
26
// Debug-print helper for this module.
//
// Debug output is suppressed by default: every invocation expands to nothing, so the arguments
// are neither evaluated nor type-checked. To enable the messages, replace the macro arm below
// with:
//     ($($args:tt)+) => (println!($($args)*))
macro_rules! debug_pr {
    ($($args:tt)+) => {};
}
32
33 const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
34 const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
35 const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
36 const NLATTR_ALIGN_TO: usize = 4;
37
/// Netlink message header (the `struct nlmsghdr` every message on the socket starts with).
///
/// `#[repr(C)]` fixes the in-memory layout so the header can be read from / written into raw
/// socket buffers; the field order must therefore not change.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
struct NlMsgHdr {
    /// Length in bytes of the whole message, this header included.
    pub nlmsg_len: u32,
    /// Message type (e.g. `GENL_ID_CTRL` for generic netlink controller messages).
    pub nlmsg_type: u16,
    /// Message flags (e.g. `NLM_F_REQUEST`).
    pub nlmsg_flags: u16,
    /// Sequence number.
    pub nlmsg_seq: u32,
    /// Sender port id; requests built in this file set it to `getpid()`.
    pub nlmsg_pid: u32,
}
47
/// Netlink attribute header, can be used by netlink consumer.
///
/// The attribute payload follows this header directly. The next attribute starts at `len`
/// rounded up to a 4-byte boundary. `#[repr(C)]` layout; do not reorder fields.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub struct NlAttr {
    /// Attribute length in bytes: this header plus the payload, without alignment padding.
    pub len: u16,
    /// Attribute type (e.g. `CTRL_ATTR_FAMILY_NAME`).
    pub _type: u16,
}
55
/// Generic netlink header struct, can be used by netlink consumer.
///
/// Placed right after the netlink message header in generic netlink messages.
/// `#[repr(C)]` layout; do not reorder fields.
#[repr(C)]
#[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
pub struct GenlMsgHdr {
    /// Command (e.g. `CTRL_CMD_GETFAMILY`).
    pub cmd: u8,
    /// Command version.
    pub version: u8,
    /// Reserved; left zero by the request builder in this file.
    pub reserved: u16,
}
/// One netlink message as yielded by [`NetlinkMessageIter`]: the decoded header fields and a
/// borrowed view of the bytes that follow the header.
pub struct NetlinkMessage<'a> {
    /// `nlmsg_type` from the message header.
    pub _type: u16,
    /// `nlmsg_flags` from the message header.
    pub flags: u16,
    /// `nlmsg_seq` from the message header.
    pub seq: u32,
    /// `nlmsg_pid` from the message header.
    pub pid: u32,
    /// Message payload: everything after the header, up to `nlmsg_len`.
    pub data: &'a [u8],
}
72
/// A netlink attribute as yielded by [`NetlinkGenericDataIter`]: its header fields plus a
/// borrowed view of its payload.
pub struct NlAttrWithData<'a> {
    /// Attribute length from the header (header plus payload, without padding).
    pub len: u16,
    /// Attribute type from the header.
    pub _type: u16,
    /// Attribute payload: the bytes between the attribute header and `len`.
    pub data: &'a [u8],
}
78
/// Iterator over a run of `struct NlAttr` attributes received from a netlink socket.
///
/// Iteration ends at the end of the buffer or at the first attribute that does not fit in
/// the remaining bytes.
pub struct NetlinkGenericDataIter<'a> {
    // Attribute bytes not consumed yet; must be properly aligned for NlAttr.
    data: &'a [u8],
}
84
85 impl<'a> Iterator for NetlinkGenericDataIter<'a> {
86 type Item = NlAttrWithData<'a>;
87
next(&mut self) -> Option<Self::Item>88 fn next(&mut self) -> Option<Self::Item> {
89 let (nl_hdr, _) = NlAttr::read_from_prefix(self.data).ok()?;
90 let nl_data_len = nl_hdr.len as usize;
91 let data = self.data.get(NLA_HDRLEN..nl_data_len)?;
92
93 // Get next NlAttr
94 let next_hdr = nl_data_len.next_multiple_of(NLATTR_ALIGN_TO);
95 self.data = self.data.get(next_hdr..).unwrap_or(&[]);
96
97 Some(NlAttrWithData {
98 _type: nl_hdr._type,
99 len: nl_hdr.len,
100 data,
101 })
102 }
103 }
104
/// Iterator over the `struct nlmsghdr`-framed messages received from a netlink socket.
///
/// Iteration ends at the end of the buffer or at the first message that does not fit in the
/// remaining bytes.
pub struct NetlinkMessageIter<'a> {
    // Message bytes not consumed yet; must be properly aligned for nlmsghdr.
    data: &'a [u8],
}
110
111 impl<'a> Iterator for NetlinkMessageIter<'a> {
112 type Item = NetlinkMessage<'a>;
113
next(&mut self) -> Option<Self::Item>114 fn next(&mut self) -> Option<Self::Item> {
115 let (hdr, _) = NlMsgHdr::read_from_prefix(self.data).ok()?;
116 let msg_len = hdr.nlmsg_len as usize;
117 let data = self.data.get(NLMSGHDR_SIZE..msg_len)?;
118
119 // NLMSG_NEXT
120 let next_hdr = msg_len.next_multiple_of(std::mem::align_of::<NlMsgHdr>());
121 self.data = self.data.get(next_hdr..).unwrap_or(&[]);
122
123 Some(NetlinkMessage {
124 _type: hdr.nlmsg_type,
125 flags: hdr.nlmsg_flags,
126 seq: hdr.nlmsg_seq,
127 pid: hdr.nlmsg_pid,
128 data,
129 })
130 }
131 }
132
133 /// Safe wrapper for `NETLINK_GENERIC` netlink sockets.
134 pub struct NetlinkGenericSocket {
135 sock: SafeDescriptor,
136 }
137
138 impl AsRawDescriptor for NetlinkGenericSocket {
as_raw_descriptor(&self) -> RawDescriptor139 fn as_raw_descriptor(&self) -> RawDescriptor {
140 self.sock.as_raw_descriptor()
141 }
142 }
143
144 impl NetlinkGenericSocket {
145 /// Create and bind a new `NETLINK_GENERIC` socket.
new(nl_groups: u32) -> Result<Self>146 pub fn new(nl_groups: u32) -> Result<Self> {
147 // SAFETY:
148 // Safe because we check the return value and convert the raw fd into a SafeDescriptor.
149 let sock = unsafe {
150 let fd = libc::socket(
151 libc::AF_NETLINK,
152 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
153 libc::NETLINK_GENERIC,
154 );
155 if fd < 0 {
156 return errno_result();
157 }
158
159 SafeDescriptor::from_raw_descriptor(fd)
160 };
161
162 // SAFETY:
163 // This MaybeUninit dance is needed because sockaddr_nl has a private padding field and
164 // doesn't implement Default. Safe because all 0s is valid data for sockaddr_nl.
165 let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
166 sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
167 sa.nl_groups = nl_groups;
168
169 // SAFETY:
170 // Safe because we pass a descriptor that we own and valid pointer/size for sockaddr.
171 unsafe {
172 let res = libc::bind(
173 sock.as_raw_fd(),
174 &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
175 std::mem::size_of_val(&sa) as libc::socklen_t,
176 );
177 if res < 0 {
178 return errno_result();
179 }
180 }
181
182 Ok(NetlinkGenericSocket { sock })
183 }
184
185 /// Receive messages from the netlink socket.
recv(&self) -> Result<NetlinkGenericRead>186 pub fn recv(&self) -> Result<NetlinkGenericRead> {
187 let buf_size = 8192; // TODO(dverkamp): make this configurable?
188
189 // Create a buffer with sufficient alignment for nlmsghdr.
190 let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
191 .map_err(|_| Error::new(EINVAL))?;
192 let allocation = LayoutAllocation::uninitialized(layout);
193
194 // SAFETY:
195 // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
196 let bytes_read = unsafe {
197 let res = libc::recv(self.sock.as_raw_fd(), allocation.as_ptr(), buf_size, 0);
198 if res < 0 {
199 return errno_result();
200 }
201 res as usize
202 };
203
204 Ok(NetlinkGenericRead {
205 allocation,
206 len: bytes_read,
207 })
208 }
209
family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead>210 pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
211 let buf_size = 1024;
212 debug_pr!(
213 "preparing query for family name {}, len {}",
214 family_name,
215 family_name.len()
216 );
217
218 // Create a buffer with sufficient alignment for nlmsghdr.
219 let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
220 .map_err(|_| Error::new(EINVAL))
221 .unwrap();
222 let mut allocation = LayoutAllocation::zeroed(layout);
223
224 // SAFETY:
225 // Safe because the data in allocation was initialized up to `buf_size` and is
226 // sufficiently aligned.
227 let data = unsafe { allocation.as_mut_slice(buf_size) };
228
229 // Prepare the netlink message header
230 let (hdr, genl_hdr) = NlMsgHdr::mut_from_prefix(data).expect("failed to unwrap");
231 hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
232 hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
233 hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
234 hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
235 hdr.nlmsg_pid = getpid() as u32;
236
237 // Prepare generic netlink message header
238 let (genl_hdr, nlattr) =
239 GenlMsgHdr::mut_from_prefix(genl_hdr).expect("unable to get GenlMsgHdr from slice");
240 genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
241 genl_hdr.version = 0x1;
242
243 // Netlink attributes
244 let (nl_attr, payload) =
245 NlAttr::mut_from_prefix(nlattr).expect("unable to get NlAttr from slice");
246 nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
247 nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
248
249 // Fill the message payload with the family name
250 payload[..family_name.len()].copy_from_slice(family_name.as_bytes());
251
252 let len = NLMSGHDR_SIZE + GENL_HDRLEN + NLA_HDRLEN + family_name.len() + 1;
253
254 // SAFETY:
255 // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
256 unsafe {
257 let res = libc::send(self.sock.as_raw_fd(), allocation.as_ptr(), len, 0);
258 if res < 0 {
259 error!("failed to send get_family_cmd");
260 return errno_result();
261 }
262 };
263
264 // Return the answer
265 match self.recv() {
266 Ok(msg) => Ok(msg),
267 Err(e) => {
268 error!("recv get_family returned with error {}", e);
269 Err(e)
270 }
271 }
272 }
273 }
274
parse_ctrl_group_name_and_id( nested_nl_attr_data: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>275 fn parse_ctrl_group_name_and_id(
276 nested_nl_attr_data: NetlinkGenericDataIter,
277 group_name: &str,
278 ) -> Option<u32> {
279 let mut mcast_group_id: Option<u32> = None;
280
281 for nested_nl_attr in nested_nl_attr_data {
282 debug_pr!(
283 "\t\tmcast_grp: nlattr type {}, len {}",
284 nested_nl_attr._type,
285 nested_nl_attr.len
286 );
287
288 if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
289 mcast_group_id = Some(u32::from_ne_bytes(nested_nl_attr.data.try_into().unwrap()));
290 debug_pr!("\t\t mcast group_id {}", mcast_group_id?);
291 }
292
293 if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
294 debug_pr!(
295 "\t\t mcast group name {}",
296 strip_padding(&nested_nl_attr.data)
297 );
298
299 // If the group name match and the group_id was set in previous iteration, return,
300 // valid for group_name, group_id
301 if group_name.eq(strip_padding(nested_nl_attr.data)) && mcast_group_id.is_some() {
302 debug_pr!(
303 "\t\t Got what we were looking for group_id = {} for {}",
304 mcast_group_id?,
305 group_name
306 );
307
308 return mcast_group_id;
309 }
310 }
311 }
312
313 None
314 }
315
316 /// Parse CTRL_ATTR_MCAST_GROUPS data in order to get multicast group id
317 ///
318 /// On success, returns group_id for a given `group_name`
319 ///
320 /// # Arguments
321 ///
322 /// * `nl_attr_area`
323 ///
324 /// Nested attributes area (CTRL_ATTR_MCAST_GROUPS data), where nl_attr's corresponding to
325 /// specific groups are embed
326 ///
327 /// * `group_name`
328 ///
329 /// String with group_name for which we are looking group_id
330 ///
331 /// the CTRL_ATTR_MCAST_GROUPS data has nested attributes. Each of nested attribute is per
332 /// multicast group attributes, which have another nested attributes: CTRL_ATTR_MCAST_GRP_NAME and
333 /// CTRL_ATTR_MCAST_GRP_ID. Need to parse all of them to get mcast group id for a given group_name..
334 ///
335 /// Illustrated layout:
336 /// CTRL_ATTR_MCAST_GROUPS:
337 /// GR1 (nl_attr._type = 1):
338 /// CTRL_ATTR_MCAST_GRP_ID,
339 /// CTRL_ATTR_MCAST_GRP_NAME,
340 /// GR2 (nl_attr._type = 2):
341 /// CTRL_ATTR_MCAST_GRP_ID,
342 /// CTRL_ATTR_MCAST_GRP_NAME,
343 /// ..
344 ///
345 /// Unfortunately kernel implementation uses `nla_nest_start_noflag` for that
346 /// purpose, which means that it never marked their nest attributes with NLA_F_NESTED flag.
347 /// Therefore all this nesting stages need to be deduced based on specific nl_attr type.
parse_ctrl_mcast_group_id( nl_attr_area: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>348 fn parse_ctrl_mcast_group_id(
349 nl_attr_area: NetlinkGenericDataIter,
350 group_name: &str,
351 ) -> Option<u32> {
352 // There may be multiple nested multicast groups, go through all of them.
353 // Each of nested group, has other nested nlattr:
354 // CTRL_ATTR_MCAST_GRP_ID
355 // CTRL_ATTR_MCAST_GRP_NAME
356 //
357 // which are further proceed by parse_ctrl_group_name_and_id
358 for nested_gr_nl_attr in nl_attr_area {
359 debug_pr!(
360 "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
361 nested_gr_nl_attr._type,
362 nested_gr_nl_attr.len
363 );
364
365 let netlink_nested_attr = NetlinkGenericDataIter {
366 data: nested_gr_nl_attr.data,
367 };
368
369 if let Some(mcast_group_id) = parse_ctrl_group_name_and_id(netlink_nested_attr, group_name)
370 {
371 return Some(mcast_group_id);
372 }
373 }
374
375 None
376 }
377
/// Returns the bytes of `b` that precede the first NUL byte as a `&str`. The NUL and
/// everything after it (padding) are dropped. This is similar to
/// `CStr::from_bytes_with_nul`, but trailing bytes after the NUL are tolerated.
///
/// # Panics
///
/// Panics if `b` contains no NUL byte, or if the bytes before it are not valid UTF-8.
fn strip_padding(b: &[u8]) -> &str {
    let Some(nul_index) = b.iter().position(|&byte| byte == 0) else {
        panic!("`b` doesn't contain any nul bytes");
    };
    let (text, _nul_and_padding) = b.split_at(nul_index);
    str::from_utf8(text).unwrap()
}
389
390 pub struct NetlinkGenericRead {
391 allocation: LayoutAllocation,
392 len: usize,
393 }
394
395 impl NetlinkGenericRead {
iter(&self) -> NetlinkMessageIter396 pub fn iter(&self) -> NetlinkMessageIter {
397 // SAFETY:
398 // Safe because the data in allocation was initialized up to `self.len` by `recv()` and is
399 // sufficiently aligned.
400 let data = unsafe { &self.allocation.as_slice(self.len) };
401 NetlinkMessageIter { data }
402 }
403
404 /// Parse NetlinkGeneric response in order to get multicast group id
405 ///
406 /// On success, returns group_id for a given `group_name`
407 ///
408 /// # Arguments
409 ///
410 /// * `group_name` - String with group_name for which we are looking group_id
411 ///
412 /// Response from family_name_query (CTRL_CMD_GETFAMILY) is a netlink message with multiple
413 /// attributes encapsulated (some of them are nested). An example response layout is
414 /// illustrated below:
415 ///
416 /// {
417 /// CTRL_ATTR_FAMILY_NAME
418 /// CTRL_ATTR_FAMILY_ID
419 /// CTRL_ATTR_VERSION
420 /// ...
421 /// CTRL_ATTR_MCAST_GROUPS {
422 /// GR1 (nl_attr._type = 1) {
423 /// CTRL_ATTR_MCAST_GRP_ID *we need parse this attr to obtain group id used for
424 /// the group mask
425 /// CTRL_ATTR_MCAST_GRP_NAME *group_name that we need to match with
426 /// }
427 /// GR2 (nl_attr._type = 2) {
428 /// CTRL_ATTR_MCAST_GRP_ID
429 /// CTRL_ATTR_MCAST_GRP_NAME
430 /// }
431 /// ...
432 /// }
433 /// }
get_multicast_group_id(&self, group_name: String) -> Option<u32>434 pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
435 for netlink_msg in self.iter() {
436 debug_pr!(
437 "received type: {}, flags {}, pid {}, data {:?}",
438 netlink_msg._type,
439 netlink_msg.flags,
440 netlink_msg.pid,
441 netlink_msg.data
442 );
443
444 if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
445 error!("Received not a generic netlink controller msg");
446 return None;
447 }
448
449 let netlink_data = NetlinkGenericDataIter {
450 data: &netlink_msg.data[GENL_HDRLEN..],
451 };
452 for nl_attr in netlink_data {
453 debug_pr!("nl_attr type {}, len {}", nl_attr._type, nl_attr.len);
454
455 if nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16 {
456 let netlink_nested_attr = NetlinkGenericDataIter { data: nl_attr.data };
457
458 if let Some(mcast_group_id) =
459 parse_ctrl_mcast_group_id(netlink_nested_attr, &group_name)
460 {
461 return Some(mcast_group_id);
462 }
463 }
464 }
465 }
466 None
467 }
468 }
469