• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{alloc::Layout, mem::MaybeUninit, os::unix::io::AsRawFd, str};
6 
7 use data_model::DataInit;
8 use libc::EINVAL;
9 
10 use sys_util_core::LayoutAllocation;
11 
12 use super::{
13     errno_result, getpid, AsRawDescriptor, Error, FromRawDescriptor, RawDescriptor, Result,
14     SafeDescriptor,
15 };
16 
// Debug tracing hook taking `println!`-style arguments.
//
// By default the macro expands to nothing, so debug output is suppressed and the arguments are
// never evaluated. To enable debug output replace the rule below with:
// ($($args:tt)+) => (println!($($args)*));
macro_rules! debug_pr {
    ($($fmt_args:tt)+) => {};
}
22 
23 const NLMSGHDR_SIZE: usize = std::mem::size_of::<NlMsgHdr>();
24 const GENL_HDRLEN: usize = std::mem::size_of::<GenlMsgHdr>();
25 const NLA_HDRLEN: usize = std::mem::size_of::<NlAttr>();
26 const NLATTR_ALIGN_TO: usize = 4;
27 
28 // Custom nlmsghdr struct that can be declared DataInit.
29 #[repr(C)]
30 #[derive(Copy, Clone)]
31 struct NlMsgHdr {
32     pub nlmsg_len: u32,
33     pub nlmsg_type: u16,
34     pub nlmsg_flags: u16,
35     pub nlmsg_seq: u32,
36     pub nlmsg_pid: u32,
37 }
38 unsafe impl DataInit for NlMsgHdr {}
39 
40 /// Netlink attribute struct, can be used by netlink consumer
41 #[repr(C)]
42 #[derive(Copy, Clone)]
43 pub struct NlAttr {
44     pub len: u16,
45     pub _type: u16,
46 }
47 unsafe impl DataInit for NlAttr {}
48 
49 /// Generic netlink header struct, can be used by netlink consumer
50 #[repr(C)]
51 #[derive(Copy, Clone)]
52 pub struct GenlMsgHdr {
53     pub cmd: u8,
54     pub version: u8,
55     pub reserved: u16,
56 }
57 unsafe impl DataInit for GenlMsgHdr {}
58 
/// One netlink message: the fields of its header plus a borrowed view of its payload.
pub struct NetlinkMessage<'a> {
    /// Message type taken from the message header.
    pub _type: u16,
    /// Flags taken from the message header.
    pub flags: u16,
    /// Sequence number taken from the message header.
    pub seq: u32,
    /// Port id taken from the message header.
    pub pid: u32,
    /// Payload bytes that follow the message header.
    pub data: &'a [u8],
}
67 
/// A netlink attribute header together with a borrowed view of the attribute's payload.
pub struct NlAttrWithData<'a> {
    /// Attribute length as stored in the attribute header (header included).
    pub len: u16,
    /// Attribute type as stored in the attribute header.
    pub _type: u16,
    /// Payload bytes that follow the attribute header.
    pub data: &'a [u8],
}
73 
nlattr_align(offset: usize) -> usize74 fn nlattr_align(offset: usize) -> usize {
75     return (offset + NLATTR_ALIGN_TO - 1) & !(NLATTR_ALIGN_TO - 1);
76 }
77 
78 /// Iterator over `struct NlAttr` as received from a netlink socket.
79 pub struct NetlinkGenericDataIter<'a> {
80     // `data` must be properly aligned for NlAttr.
81     data: &'a [u8],
82 }
83 
84 impl<'a> Iterator for NetlinkGenericDataIter<'a> {
85     type Item = NlAttrWithData<'a>;
86 
next(&mut self) -> Option<Self::Item>87     fn next(&mut self) -> Option<Self::Item> {
88         if self.data.len() < NLA_HDRLEN {
89             return None;
90         }
91         let nl_hdr = NlAttr::from_slice(&self.data[..NLA_HDRLEN])?;
92 
93         // Make sure NlAtrr fits
94         let nl_data_len = nl_hdr.len as usize;
95         if nl_data_len < NLA_HDRLEN || nl_data_len > self.data.len() {
96             return None;
97         }
98 
99         // Get data related to processed NlAttr
100         let data_start = NLA_HDRLEN;
101         let data = &self.data[data_start..nl_data_len];
102 
103         // Get next NlAttr
104         let next_hdr = nlattr_align(nl_data_len);
105         if next_hdr >= self.data.len() {
106             self.data = &[];
107         } else {
108             self.data = &self.data[next_hdr..];
109         }
110 
111         Some(NlAttrWithData {
112             _type: nl_hdr._type,
113             len: nl_hdr.len,
114             data,
115         })
116     }
117 }
118 
119 /// Iterator over `struct nlmsghdr` as received from a netlink socket.
120 pub struct NetlinkMessageIter<'a> {
121     // `data` must be properly aligned for nlmsghdr.
122     data: &'a [u8],
123 }
124 
125 impl<'a> Iterator for NetlinkMessageIter<'a> {
126     type Item = NetlinkMessage<'a>;
127 
next(&mut self) -> Option<Self::Item>128     fn next(&mut self) -> Option<Self::Item> {
129         if self.data.len() < NLMSGHDR_SIZE {
130             return None;
131         }
132         let hdr = NlMsgHdr::from_slice(&self.data[..NLMSGHDR_SIZE])?;
133 
134         // NLMSG_OK
135         let msg_len = hdr.nlmsg_len as usize;
136         if msg_len < NLMSGHDR_SIZE || msg_len > self.data.len() {
137             return None;
138         }
139 
140         // NLMSG_DATA
141         let data_start = NLMSGHDR_SIZE;
142         let data = &self.data[data_start..msg_len];
143 
144         // NLMSG_NEXT
145         let align_to = std::mem::align_of::<NlMsgHdr>();
146         let next_hdr = (msg_len + align_to - 1) & !(align_to - 1);
147         if next_hdr >= self.data.len() {
148             self.data = &[];
149         } else {
150             self.data = &self.data[next_hdr..];
151         }
152 
153         Some(NetlinkMessage {
154             _type: hdr.nlmsg_type,
155             flags: hdr.nlmsg_flags,
156             seq: hdr.nlmsg_seq,
157             pid: hdr.nlmsg_pid,
158             data,
159         })
160     }
161 }
162 
163 /// Safe wrapper for `NETLINK_GENERIC` netlink sockets.
164 pub struct NetlinkGenericSocket {
165     sock: SafeDescriptor,
166 }
167 
168 impl AsRawDescriptor for NetlinkGenericSocket {
as_raw_descriptor(&self) -> RawDescriptor169     fn as_raw_descriptor(&self) -> RawDescriptor {
170         self.sock.as_raw_descriptor()
171     }
172 }
173 
174 impl NetlinkGenericSocket {
175     /// Create and bind a new `NETLINK_GENERIC` socket.
new(nl_groups: u32) -> Result<Self>176     pub fn new(nl_groups: u32) -> Result<Self> {
177         // Safe because we check the return value and convert the raw fd into a SafeDescriptor.
178         let sock = unsafe {
179             let fd = libc::socket(
180                 libc::AF_NETLINK,
181                 libc::SOCK_RAW | libc::SOCK_CLOEXEC,
182                 libc::NETLINK_GENERIC,
183             );
184             if fd < 0 {
185                 return errno_result();
186             }
187 
188             SafeDescriptor::from_raw_descriptor(fd)
189         };
190 
191         // This MaybeUninit dance is needed because sockaddr_nl has a private padding field and
192         // doesn't implement Default. Safe because all 0s is valid data for sockaddr_nl.
193         let mut sa = unsafe { MaybeUninit::<libc::sockaddr_nl>::zeroed().assume_init() };
194         sa.nl_family = libc::AF_NETLINK as libc::sa_family_t;
195         sa.nl_groups = nl_groups;
196 
197         // Safe because we pass a descriptor that we own and valid pointer/size for sockaddr.
198         unsafe {
199             let res = libc::bind(
200                 sock.as_raw_fd(),
201                 &sa as *const libc::sockaddr_nl as *const libc::sockaddr,
202                 std::mem::size_of_val(&sa) as libc::socklen_t,
203             );
204             if res < 0 {
205                 return errno_result();
206             }
207         }
208 
209         Ok(NetlinkGenericSocket { sock })
210     }
211 
212     /// Receive messages from the netlink socket.
recv(&self) -> Result<NetlinkGenericRead>213     pub fn recv(&self) -> Result<NetlinkGenericRead> {
214         let buf_size = 8192; // TODO(dverkamp): make this configurable?
215 
216         // Create a buffer with sufficient alignment for nlmsghdr.
217         let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
218             .map_err(|_| Error::new(EINVAL))?;
219         let allocation = LayoutAllocation::uninitialized(layout);
220 
221         // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
222         let bytes_read = unsafe {
223             let res = libc::recv(
224                 self.sock.as_raw_fd(),
225                 allocation.as_ptr() as *mut libc::c_void,
226                 buf_size,
227                 0,
228             );
229             if res < 0 {
230                 return errno_result();
231             }
232             res as usize
233         };
234 
235         Ok(NetlinkGenericRead {
236             allocation,
237             len: bytes_read,
238         })
239     }
240 
family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead>241     pub fn family_name_query(&self, family_name: String) -> Result<NetlinkGenericRead> {
242         let buf_size = 1024;
243         debug_pr!(
244             "preparing query for family name {}, len {}",
245             family_name,
246             family_name.len()
247         );
248 
249         // Create a buffer with sufficient alignment for nlmsghdr.
250         let layout = Layout::from_size_align(buf_size, std::mem::align_of::<NlMsgHdr>())
251             .map_err(|_| Error::new(EINVAL))
252             .unwrap();
253         let mut allocation = LayoutAllocation::zeroed(layout);
254 
255         // Safe because the data in allocation was initialized up to `buf_size` and is
256         // sufficiently aligned.
257         let data = unsafe { allocation.as_mut_slice(buf_size) };
258 
259         // Prepare the netlink message header
260         let mut hdr =
261             &mut NlMsgHdr::from_mut_slice(&mut data[..NLMSGHDR_SIZE]).expect("failed to unwrap");
262         hdr.nlmsg_len = NLMSGHDR_SIZE as u32 + GENL_HDRLEN as u32;
263         hdr.nlmsg_len += NLA_HDRLEN as u32 + family_name.len() as u32 + 1;
264         hdr.nlmsg_flags = libc::NLM_F_REQUEST as u16;
265         hdr.nlmsg_type = libc::GENL_ID_CTRL as u16;
266         hdr.nlmsg_pid = getpid() as u32;
267 
268         // Prepare generic netlink message header
269         let genl_hdr_end = NLMSGHDR_SIZE + GENL_HDRLEN;
270         let genl_hdr = GenlMsgHdr::from_mut_slice(&mut data[NLMSGHDR_SIZE..genl_hdr_end])
271             .expect("unable to get GenlMsgHdr from slice");
272         genl_hdr.cmd = libc::CTRL_CMD_GETFAMILY as u8;
273         genl_hdr.version = 0x1;
274 
275         // Netlink attributes
276         let nlattr_start = genl_hdr_end;
277         let nlattr_end = nlattr_start + NLA_HDRLEN;
278         let nl_attr = NlAttr::from_mut_slice(&mut data[nlattr_start..nlattr_end])
279             .expect("unable to get NlAttr from slice");
280         nl_attr._type = libc::CTRL_ATTR_FAMILY_NAME as u16;
281         nl_attr.len = family_name.len() as u16 + 1 + NLA_HDRLEN as u16;
282 
283         // Fill the message payload with the family name
284         let payload_start = nlattr_end;
285         let payload_end = payload_start + family_name.len();
286         data[payload_start..payload_end].copy_from_slice(family_name.as_bytes());
287 
288         // Safe because we pass a valid, owned socket fd and a valid pointer/size for the buffer.
289         let _bytes_read = unsafe {
290             let res = libc::send(
291                 self.sock.as_raw_fd(),
292                 allocation.as_ptr() as *mut libc::c_void,
293                 payload_end + 1,
294                 0,
295             );
296             if res < 0 {
297                 error!("failed to send get_family_cmd");
298                 return errno_result();
299             }
300         };
301 
302         // Return the answer
303         match self.recv() {
304             Ok(msg) => return Ok(msg),
305             Err(e) => {
306                 error!("recv get_family returned with error {}", e);
307                 return Err(e);
308             }
309         };
310     }
311 }
312 
parse_ctrl_group_name_and_id( nested_nl_attr_data: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>313 fn parse_ctrl_group_name_and_id(
314     nested_nl_attr_data: NetlinkGenericDataIter,
315     group_name: &str,
316 ) -> Option<u32> {
317     let mut mcast_group_id: Option<u32> = None;
318 
319     for nested_nl_attr in nested_nl_attr_data {
320         debug_pr!(
321             "\t\tmcast_grp: nlattr type {}, len {}",
322             nested_nl_attr._type,
323             nested_nl_attr.len
324         );
325 
326         if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_ID as u16 {
327             mcast_group_id = Some(u32::from_ne_bytes(nested_nl_attr.data.try_into().unwrap()));
328             debug_pr!("\t\t mcast group_id {}", mcast_group_id?);
329         }
330 
331         if nested_nl_attr._type == libc::CTRL_ATTR_MCAST_GRP_NAME as u16 {
332             debug_pr!(
333                 "\t\t mcast group name {}",
334                 strip_padding(&nested_nl_attr.data)
335             );
336 
337             // If the group name match and the group_id was set in previous iteration, return,
338             // valid for group_name, group_id
339             if group_name.eq(strip_padding(nested_nl_attr.data)) && mcast_group_id.is_some() {
340                 debug_pr!(
341                     "\t\t Got what we were looking for group_id = {} for {}",
342                     mcast_group_id?,
343                     group_name
344                 );
345 
346                 return mcast_group_id;
347             }
348         }
349     }
350 
351     None
352 }
353 
354 /// Parse CTRL_ATTR_MCAST_GROUPS data in order to get multicast group id
355 ///
356 /// On success, returns group_id for a given `group_name`
357 ///
358 /// # Arguments
359 ///
360 /// * `nl_attr_area`    - Nested attributes area (CTRL_ATTR_MCAST_GROUPS data), where nl_attr's
361 ///                       corresponding to specific groups are embed
362 /// * `group_name`      - String with group_name for which we are looking group_id
363 ///
364 /// the CTRL_ATTR_MCAST_GROUPS data has nested attributes. Each of nested attribute is per
365 /// multicast group attributes, which have another nested attributes: CTRL_ATTR_MCAST_GRP_NAME and
366 /// CTRL_ATTR_MCAST_GRP_ID. Need to parse all of them to get mcast group id for a given group_name..
367 ///
368 /// Illustrated layout:
369 /// CTRL_ATTR_MCAST_GROUPS:
370 ///   GR1 (nl_attr._type = 1):
371 ///       CTRL_ATTR_MCAST_GRP_ID,
372 ///       CTRL_ATTR_MCAST_GRP_NAME,
373 ///   GR2 (nl_attr._type = 2):
374 ///       CTRL_ATTR_MCAST_GRP_ID,
375 ///       CTRL_ATTR_MCAST_GRP_NAME,
376 ///   ..
377 ///
378 /// Unfortunately kernel implementation uses `nla_nest_start_noflag` for that
379 /// purpose, which means that it never marked their nest attributes with NLA_F_NESTED flag.
380 /// Therefore all this nesting stages need to be deduced based on specific nl_attr type.
parse_ctrl_mcast_group_id( nl_attr_area: NetlinkGenericDataIter, group_name: &str, ) -> Option<u32>381 fn parse_ctrl_mcast_group_id(
382     nl_attr_area: NetlinkGenericDataIter,
383     group_name: &str,
384 ) -> Option<u32> {
385     // There may be multiple nested multicast groups, go through all of them.
386     // Each of nested group, has other nested nlattr:
387     //  CTRL_ATTR_MCAST_GRP_ID
388     //  CTRL_ATTR_MCAST_GRP_NAME
389     //
390     //  which are further proceed by parse_ctrl_group_name_and_id
391     for nested_gr_nl_attr in nl_attr_area {
392         debug_pr!(
393             "\tmcast_groups: nlattr type(gr_nr) {}, len {}",
394             nested_gr_nl_attr._type,
395             nested_gr_nl_attr.len
396         );
397 
398         let netlink_nested_attr = NetlinkGenericDataIter {
399             data: nested_gr_nl_attr.data,
400         };
401 
402         if let Some(mcast_group_id) = parse_ctrl_group_name_and_id(netlink_nested_attr, group_name)
403         {
404             return Some(mcast_group_id);
405         }
406     }
407 
408     None
409 }
410 
// Returns the bytes of `b` that precede its first '\0' byte as a `&str`, similar to
// `CStr::from_bytes_with_nul` except that everything from the first '\0' onwards is dropped.
//
// Panics if `b` contains no '\0' byte, or if the bytes before it are not valid UTF-8.
fn strip_padding(b: &[u8]) -> &str {
    // It would be nice if we could use memchr here but that's locked behind an unstable gate.
    let nul_index = match b.iter().position(|&byte| byte == 0) {
        Some(index) => index,
        None => panic!("`b` doesn't contain any nul bytes"),
    };

    let (text, _padding) = b.split_at(nul_index);
    str::from_utf8(text).unwrap()
}
422 
423 pub struct NetlinkGenericRead {
424     allocation: LayoutAllocation,
425     len: usize,
426 }
427 
428 impl NetlinkGenericRead {
iter(&self) -> NetlinkMessageIter429     pub fn iter(&self) -> NetlinkMessageIter {
430         // Safe because the data in allocation was initialized up to `self.len` by `recv()` and is
431         // sufficiently aligned.
432         let data = unsafe { &self.allocation.as_slice(self.len) };
433         NetlinkMessageIter { data }
434     }
435 
436     /// Parse NetlinkGeneric response in order to get multicast group id
437     ///
438     /// On success, returns group_id for a given `group_name`
439     ///
440     /// # Arguments
441     ///
442     /// * `group_name` - String with group_name for which we are looking group_id
443     ///
444     /// Response from family_name_query (CTRL_CMD_GETFAMILY) is a netlink message with multiple
445     /// attributes encapsulated (some of them are nested). An example response layout is
446     /// illustrated below:
447     ///
448     ///  {
449     ///    CTRL_ATTR_FAMILY_NAME
450     ///    CTRL_ATTR_FAMILY_ID
451     ///    CTRL_ATTR_VERSION
452     ///    ...
453     ///    CTRL_ATTR_MCAST_GROUPS {
454     ///      GR1 (nl_attr._type = 1) {
455     ///          CTRL_ATTR_MCAST_GRP_ID    *we need parse this attr to obtain group id used for
456     ///                                     the group mask
457     ///          CTRL_ATTR_MCAST_GRP_NAME  *group_name that we need to match with
458     ///      }
459     ///      GR2 (nl_attr._type = 2) {
460     ///          CTRL_ATTR_MCAST_GRP_ID
461     ///          CTRL_ATTR_MCAST_GRP_NAME
462     ///      }
463     ///      ...
464     ///     }
465     ///   }
466     ///
get_multicast_group_id(&self, group_name: String) -> Option<u32>467     pub fn get_multicast_group_id(&self, group_name: String) -> Option<u32> {
468         for netlink_msg in self.iter() {
469             debug_pr!(
470                 "received type: {}, flags {}, pid {}, data {:?}",
471                 netlink_msg._type,
472                 netlink_msg.flags,
473                 netlink_msg.pid,
474                 netlink_msg.data
475             );
476 
477             if netlink_msg._type != libc::GENL_ID_CTRL as u16 {
478                 error!("Received not a generic netlink controller msg");
479                 return None;
480             }
481 
482             let netlink_data = NetlinkGenericDataIter {
483                 data: &netlink_msg.data[GENL_HDRLEN..],
484             };
485             for nl_attr in netlink_data {
486                 debug_pr!("nl_attr type {}, len {}", nl_attr._type, nl_attr.len);
487 
488                 if nl_attr._type == libc::CTRL_ATTR_MCAST_GROUPS as u16 {
489                     let netlink_nested_attr = NetlinkGenericDataIter { data: nl_attr.data };
490 
491                     if let Some(mcast_group_id) =
492                         parse_ctrl_mcast_group_id(netlink_nested_attr, &group_name)
493                     {
494                         return Some(mcast_group_id);
495                     }
496                 }
497             }
498         }
499         None
500     }
501 }
502