• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::alloc::Layout;
6 use std::mem::MaybeUninit;
7 use std::os::unix::io::AsRawFd;
8 use std::str;
9 
10 use libc::EINVAL;
11 use log::error;
12 use zerocopy::FromBytes;
13 use zerocopy::Immutable;
14 use zerocopy::IntoBytes;
15 use zerocopy::KnownLayout;
16 
17 use super::errno_result;
18 use super::getpid;
19 use super::Error;
20 use super::RawDescriptor;
21 use super::Result;
22 use crate::alloc::LayoutAllocation;
23 use crate::descriptor::AsRawDescriptor;
24 use crate::descriptor::FromRawDescriptor;
25 use crate::descriptor::SafeDescriptor;
26 
macro_rules! debug_pr {
    // Debug output is suppressed by default: this arm discards all of its argument tokens, so
    // they are neither evaluated nor type-checked. To enable the output, replace this arm with:
    //     ($($args:tt)+) => (println!($($args)+));
    ($($args:tt)+) => {};
}
32 
33 const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
34 const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
35 const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
36 const NLATTR_ALIGN_TO: usize = 4;
37 
38 #[repr(C)]
39 #[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
40 struct NlMsgHdr {
41     pub nlmsg_len: u32,
42     pub nlmsg_type: u16,
43     pub nlmsg_flags: u16,
44     pub nlmsg_seq: u32,
45     pub nlmsg_pid: u32,
46 }
47 
48 /// Netlink attribute struct, can be used by netlink consumer
49 #[repr(C)]
50 #[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
51 pub struct NlAttr {
52     pub len: u16,
53     pub _type: u16,
54 }
55 
56 /// Generic netlink header struct, can be used by netlink consumer
57 #[repr(C)]
58 #[derive(Copy, Clone, FromBytes, Immutable, IntoBytes, KnownLayout)]
59 pub struct GenlMsgHdr {
60     pub cmd: u8,
61     pub version: u8,
62     pub reserved: u16,
63 }
/// A single netlink message: the decoded header fields plus the payload that follows the header.
#[derive(Clone, Copy, Debug)]
pub struct NetlinkMessage<'a> {
    /// `nlmsg_type` from the message header.
    pub _type: u16,
    /// `nlmsg_flags` from the message header.
    pub flags: u16,
    /// `nlmsg_seq` from the message header.
    pub seq: u32,
    /// `nlmsg_pid` from the message header.
    pub pid: u32,
    /// Message payload: the bytes after the header, up to `nlmsg_len` (header not included).
    pub data: &'a [u8],
}
72 
/// A netlink attribute as produced by `NetlinkGenericDataIter`: header fields plus payload.
#[derive(Clone, Copy, Debug)]
pub struct NlAttrWithData<'a> {
    /// `len` from the attribute header (header included, padding excluded).
    pub len: u16,
    /// `_type` from the attribute header.
    pub _type: u16,
    /// Attribute payload: the bytes after the header, up to `len`.
    pub data: &'a [u8],
}
78 
/// Iterator over `struct NlAttr` as received from a netlink socket.
///
/// Each item is one attribute (header plus payload). Iteration ends at the first attribute
/// whose header or declared length does not fit in the remaining bytes.
#[derive(Clone, Debug)]
pub struct NetlinkGenericDataIter<'a> {
    // Remaining, not yet decoded attribute bytes. Headers are copied out with
    // `read_from_prefix`, so no particular alignment of `data` is required.
    data: &'a [u8],
}
84 
85 impl<'a> Iterator for NetlinkGenericDataIter<'a> {
86     type Item = NlAttrWithData<'a>;
87 
next(&mut self) -> Option<Self::Item>88     fn next(&mut self) -> Option<Self::Item> {
89         let (nl_hdr, _) = NlAttr::read_from_prefix(self.data).ok()?;
90         let nl_data_len = nl_hdr.len as usize;
91         let data = self.data.get(NLA_HDRLEN..nl_data_len)?;
92 
93         // Get next NlAttr
94         let next_hdr = nl_data_len.next_multiple_of(NLATTR_ALIGN_TO);
95         self.data = self.data.get(next_hdr..).unwrap_or(&[]);
96 
97         Some(NlAttrWithData {
98             _type: nl_hdr._type,
99             len: nl_hdr.len,
100             data,
101         })
102     }
103 }
104 
/// Iterator over `struct nlmsghdr` as received from a netlink socket.
///
/// Each item is one netlink message. Iteration ends at the first message whose header or
/// declared length does not fit in the remaining bytes.
#[derive(Clone, Debug)]
pub struct NetlinkMessageIter<'a> {
    // Remaining, not yet decoded message bytes. Headers are copied out with
    // `read_from_prefix`, so no particular alignment of `data` is required.
    data: &'a [u8],
}
110 
111 impl<'a> Iterator for NetlinkMessageIter<'a> {
112     type Item = NetlinkMessage<'a>;
113 
next(&mut self) -> Option<Self::Item>114     fn next(&mut self) -> Option<Self::Item> {
115         let (hdr, _) = NlMsgHdr::read_from_prefix(self.data).ok()?;
116         let msg_len = hdr.nlmsg_len as usize;
117         let data = self.data.get(NLMSGHDR_SIZE..msg_len)?;
118 
119         // NLMSG_NEXT
120         let next_hdr = msg_len.next_multiple_of(std::mem::align_of::<NlMsgHdr>());
121         self.data = self.data.get(next_hdr..).unwrap_or(&[]);
122 
123         Some(NetlinkMessage {
124             _type: hdr.nlmsg_type,
125             flags: hdr.nlmsg_flags,
126             seq: hdr.nlmsg_seq,
127             pid: hdr.nlmsg_pid,
128             data,
129         })
130     }
131 }
132 
133 /// Safe wrapper for `NETLINK_GENERIC` netlink sockets.
134 pub struct NetlinkGenericSocket {
135     sock: SafeDescriptor,
136 }
137 
138 impl AsRawDescriptor for NetlinkGenericSocket {
as_raw_descriptor(&self) -> RawDescriptor139     fn as_raw_descriptor(&self) -> RawDescriptor {
140         self.sock.as_raw_descriptor()
141     }
142 }
143 
144 impl NetlinkGenericSocket {
145     /// Create and bind a new `NETLINK_GENERIC` socket.
new(nl_groups: u32) -> Result<Self>146     pub fn new(nl_groups: u32) -> Result<Self> {
147         // SAFETY:
148         // Safe because we check the return value and convert the raw fd into a SafeDescriptor.
149         let sock = unsafe {
150             let fd = libc::socket(
151                 libc::AF_NETLINK,
152                 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
153                 libc::NETLINK_GENERIC,
154             );
155             if fd < 0 {
156                 return errno_result();
157             }
158 
159             SafeDescriptor::from_raw_descriptor(fd)
160         };
161 
162         // SAFETY:
163         // This MaybeUninit dance is needed because sockaddr_nl has a private padding field and
164         // doesn't implement Default. Safe because all 0s is valid data for sockaddr_nl.
165         let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
166         sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
167         sa.nl_groups = nl_groups;
168 
169         // SAFETY:
170         // Safe because we pass a descriptor that we own and valid pointer/size for sockaddr.
171         unsafe {
172             let res = libc::bind(
173                 sock.as_raw_fd(),
174                 &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
175                 std::mem::size_of_val(&sa) as libc::socklen_t,
176             );
177             if res < 0 {
178                 return errno_result();
179             }
180         }
181 
182         Ok(NetlinkGenericSocket { sock })
183     }
184 
185     /// Receive messages from the netlink socket.
recv(&self) -> Result<NetlinkGenericRead>186     pub fn recv(&self) -> Result<NetlinkGenericRead> {
187         let buf_size = 8192; // TODO(dverkamp): make this configurable?
188 
189         // Create a buffer with sufficient alignment for nlmsghdr.
190         let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
191             .map_err(|_| Error::new(EINVAL))?;
192         let allocation = LayoutAllocation::uninitialized(layout);
193 
194         // SAFETY:
195         // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
196         let bytes_read = unsafe {
197             let res = libc::recv(self.sock.as_raw_fd(), allocation.as_ptr(), buf_size, 0);
198             if res < 0 {
199                 return errno_result();
200             }
201             res as usize
202         };
203 
204         Ok(NetlinkGenericRead {
205             allocation,
206             len: bytes_read,
207         })
208     }
209 
family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead>210     pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
211         let buf_size = 1024;
212         debug_pr!(
213             "preparing query for family name {}, len {}",
214             family_name,
215             family_name.len()
216         );
217 
218         // Create a buffer with sufficient alignment for nlmsghdr.
219         let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
220             .map_err(|_| Error::new(EINVAL))
221             .unwrap();
222         let mut allocation = LayoutAllocation::zeroed(layout);
223 
224         // SAFETY:
225         // Safe because the data in allocation was initialized up to `buf_size` and is
226         // sufficiently aligned.
227         let data = unsafe { allocation.as_mut_slice(buf_size) };
228 
229         // Prepare the netlink message header
230         let (hdr, genl_hdr) = NlMsgHdr::mut_from_prefix(data).expect("failed to unwrap");
231         hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
232         hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
233         hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
234         hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
235         hdr.nlmsg_pid = getpid() as u32;
236 
237         // Prepare generic netlink message header
238         let (genl_hdr, nlattr) =
239             GenlMsgHdr::mut_from_prefix(genl_hdr).expect("unable to get GenlMsgHdr from slice");
240         genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
241         genl_hdr.version = 0x1;
242 
243         // Netlink attributes
244         let (nl_attr, payload) =
245             NlAttr::mut_from_prefix(nlattr).expect("unable to get NlAttr from slice");
246         nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
247         nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
248 
249         // Fill the message payload with the family name
250         payload[..family_name.len()].copy_from_slice(family_name.as_bytes());
251 
252         let len = NLMSGHDR_SIZE + GENL_HDRLEN + NLA_HDRLEN + family_name.len() + 1;
253 
254         // SAFETY:
255         // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
256         unsafe {
257             let res = libc::send(self.sock.as_raw_fd(), allocation.as_ptr(), len, 0);
258             if res < 0 {
259                 error!("failed to send get_family_cmd");
260                 return errno_result();
261             }
262         };
263 
264         // Return the answer
265         match self.recv() {
266             Ok(msg) => Ok(msg),
267             Err(e) => {
268                 error!("recv get_family returned with error {}", e);
269                 Err(e)
270             }
271         }
272     }
273 }
274 
parse_ctrl_group_name_and_id( nested_nl_attr_data: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>275 fn parse_ctrl_group_name_and_id(
276     nested_nl_attr_data: NetlinkGenericDataIter,
277     group_name: &str,
278 ) -> Option<u32> {
279     let mut mcast_group_id: Option<u32> = None;
280 
281     for nested_nl_attr in nested_nl_attr_data {
282         debug_pr!(
283             "\t\tmcast_grp: nlattr type {}, len {}",
284             nested_nl_attr._type,
285             nested_nl_attr.len
286         );
287 
288         if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
289             mcast_group_id = Some(u32::from_ne_bytes(nested_nl_attr.data.try_into().unwrap()));
290             debug_pr!("\t\t mcast group_id {}", mcast_group_id?);
291         }
292 
293         if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
294             debug_pr!(
295                 "\t\t mcast group name {}",
296                 strip_padding(&nested_nl_attr.data)
297             );
298 
299             // If the group name match and the group_id was set in previous iteration, return,
300             // valid for group_name, group_id
301             if group_name.eq(strip_padding(nested_nl_attr.data)) && mcast_group_id.is_some() {
302                 debug_pr!(
303                     "\t\t Got what we were looking for group_id = {} for {}",
304                     mcast_group_id?,
305                     group_name
306                 );
307 
308                 return mcast_group_id;
309             }
310         }
311     }
312 
313     None
314 }
315 
316 /// Parse CTRL_ATTR_MCAST_GROUPS data in order to get multicast group id
317 ///
318 /// On success, returns group_id for a given `group_name`
319 ///
320 /// # Arguments
321 ///
322 /// * `nl_attr_area`
323 ///
324 ///     Nested attributes area (CTRL_ATTR_MCAST_GROUPS data), where nl_attr's corresponding to
325 ///     specific groups are embed
326 ///
327 /// * `group_name`
328 ///
329 ///     String with group_name for which we are looking group_id
330 ///
331 /// the CTRL_ATTR_MCAST_GROUPS data has nested attributes. Each of nested attribute is per
332 /// multicast group attributes, which have another nested attributes: CTRL_ATTR_MCAST_GRP_NAME and
333 /// CTRL_ATTR_MCAST_GRP_ID. Need to parse all of them to get mcast group id for a given group_name..
334 ///
335 /// Illustrated layout:
336 /// CTRL_ATTR_MCAST_GROUPS:
337 ///   GR1 (nl_attr._type = 1):
338 ///       CTRL_ATTR_MCAST_GRP_ID,
339 ///       CTRL_ATTR_MCAST_GRP_NAME,
340 ///   GR2 (nl_attr._type = 2):
341 ///       CTRL_ATTR_MCAST_GRP_ID,
342 ///       CTRL_ATTR_MCAST_GRP_NAME,
343 ///   ..
344 ///
345 /// Unfortunately kernel implementation uses `nla_nest_start_noflag` for that
346 /// purpose, which means that it never marked their nest attributes with NLA_F_NESTED flag.
347 /// Therefore all this nesting stages need to be deduced based on specific nl_attr type.
parse_ctrl_mcast_group_id( nl_attr_area: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>348 fn parse_ctrl_mcast_group_id(
349     nl_attr_area: NetlinkGenericDataIter,
350     group_name: &str,
351 ) -> Option<u32> {
352     // There may be multiple nested multicast groups, go through all of them.
353     // Each of nested group, has other nested nlattr:
354     //  CTRL_ATTR_MCAST_GRP_ID
355     //  CTRL_ATTR_MCAST_GRP_NAME
356     //
357     //  which are further proceed by parse_ctrl_group_name_and_id
358     for nested_gr_nl_attr in nl_attr_area {
359         debug_pr!(
360             "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
361             nested_gr_nl_attr._type,
362             nested_gr_nl_attr.len
363         );
364 
365         let netlink_nested_attr = NetlinkGenericDataIter {
366             data: nested_gr_nl_attr.data,
367         };
368 
369         if let Some(mcast_group_id) = parse_ctrl_group_name_and_id(netlink_nested_attr, group_name)
370         {
371             return Some(mcast_group_id);
372         }
373     }
374 
375     None
376 }
377 
// Treats `b` as a nul-terminated string and returns the text before the first '\0' byte; the
// terminator and anything after it (e.g. attribute padding) are dropped. Similar in spirit to
// `CStr::from_bytes_until_nul`, but yields `&str`.
// Panics if `b` has no '\0' byte or if the bytes before it are not valid UTF-8.
fn strip_padding(b: &[u8]) -> &str {
    let Some(terminator) = b.iter().position(|&byte| byte == b'\0') else {
        panic!("`b` doesn't contain any nul bytes");
    };
    let (text, _nul_and_padding) = b.split_at(terminator);
    str::from_utf8(text).unwrap()
}
389 
390 pub struct NetlinkGenericRead {
391     allocation: LayoutAllocation,
392     len: usize,
393 }
394 
395 impl NetlinkGenericRead {
iter(&self) -> NetlinkMessageIter396     pub fn iter(&self) -> NetlinkMessageIter {
397         // SAFETY:
398         // Safe because the data in allocation was initialized up to `self.len` by `recv()` and is
399         // sufficiently aligned.
400         let data = unsafe { &self.allocation.as_slice(self.len) };
401         NetlinkMessageIter { data }
402     }
403 
404     /// Parse NetlinkGeneric response in order to get multicast group id
405     ///
406     /// On success, returns group_id for a given `group_name`
407     ///
408     /// # Arguments
409     ///
410     /// * `group_name` - String with group_name for which we are looking group_id
411     ///
412     /// Response from family_name_query (CTRL_CMD_GETFAMILY) is a netlink message with multiple
413     /// attributes encapsulated (some of them are nested). An example response layout is
414     /// illustrated below:
415     ///
416     ///  {
417     ///    CTRL_ATTR_FAMILY_NAME
418     ///    CTRL_ATTR_FAMILY_ID
419     ///    CTRL_ATTR_VERSION
420     ///    ...
421     ///    CTRL_ATTR_MCAST_GROUPS {
422     ///      GR1 (nl_attr._type = 1) {
423     ///          CTRL_ATTR_MCAST_GRP_ID    *we need parse this attr to obtain group id used for
424     ///                                     the group mask
425     ///          CTRL_ATTR_MCAST_GRP_NAME  *group_name that we need to match with
426     ///      }
427     ///      GR2 (nl_attr._type = 2) {
428     ///          CTRL_ATTR_MCAST_GRP_ID
429     ///          CTRL_ATTR_MCAST_GRP_NAME
430     ///      }
431     ///      ...
432     ///     }
433     ///   }
get_multicast_group_id(&self, group_name: String) -> Option<u32>434     pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
435         for netlink_msg in self.iter() {
436             debug_pr!(
437                 "received type: {}, flags {}, pid {}, data {:?}",
438                 netlink_msg._type,
439                 netlink_msg.flags,
440                 netlink_msg.pid,
441                 netlink_msg.data
442             );
443 
444             if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
445                 error!("Received not a generic netlink controller msg");
446                 return None;
447             }
448 
449             let netlink_data = NetlinkGenericDataIter {
450                 data: &netlink_msg.data[GENL_HDRLEN..],
451             };
452             for nl_attr in netlink_data {
453                 debug_pr!("nl_attr type {}, len {}", nl_attr._type, nl_attr.len);
454 
455                 if nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16 {
456                     let netlink_nested_attr = NetlinkGenericDataIter { data: nl_attr.data };
457 
458                     if let Some(mcast_group_id) =
459                         parse_ctrl_mcast_group_id(netlink_nested_attr, &group_name)
460                     {
461                         return Some(mcast_group_id);
462                     }
463                 }
464             }
465         }
466         None
467     }
468 }
469