• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::{err::*, *};
2 
/// Slice containing the UDP header & payload.
///
/// The wrapped slice always starts with the UDP header. The constructors
/// (`from_slice` and `from_slice_lax`) only succeed when the slice holds at
/// least [`UdpHeader::LEN`] bytes, and the header accessors rely on that.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UdpSlice<'a> {
    /// UDP header immediately followed by the UDP payload.
    slice: &'a [u8],
}
8 
9 impl<'a> UdpSlice<'a> {
10     /// Decode length from UDP header and restrict slice to the length
11     /// of the header including the payload.
12     ///
13     /// Note that this method fall backs to the length of the slice
14     /// in the case the length field in the UDP header is set to zero.
from_slice(slice: &'a [u8]) -> Result<UdpSlice<'a>, LenError>15     pub fn from_slice(slice: &'a [u8]) -> Result<UdpSlice<'a>, LenError> {
16         // slice header
17         let header = UdpHeaderSlice::from_slice(slice)?;
18 
19         // validate the length of the slice
20         let len: usize = header.length().into();
21         if slice.len() < len {
22             return Err(LenError {
23                 required_len: len,
24                 len: slice.len(),
25                 len_source: LenSource::Slice,
26                 layer: Layer::UdpPayload,
27                 layer_start_offset: 0,
28             });
29         }
30 
31         // fallback to the slice length in case length is set to 0
32         if len == 0 {
33             Ok(UdpSlice { slice })
34         } else {
35             // validate the length
36             if len < UdpHeader::LEN {
37                 // TODO: Should this replaced with a custom error?
38                 Err(LenError {
39                     required_len: UdpHeader::LEN,
40                     len,
41                     len_source: LenSource::UdpHeaderLen,
42                     layer: Layer::UdpHeader,
43                     layer_start_offset: 0,
44                 })
45             } else {
46                 Ok(UdpSlice {
47                     // SAFETY: Safe as slice.len() was validated before to
48                     // be at least as big as "len".
49                     slice: unsafe { core::slice::from_raw_parts(slice.as_ptr(), len) },
50                 })
51             }
52         }
53     }
54 
55     /// Try decoding length from UDP header and restrict slice to the length
56     /// of the header including the payload if possible. If not the slice length
57     /// is used as a fallback value.
58     ///
59     /// Note that this method fall also backs to the length of the slice
60     /// in the case the length field in the UDP header is set to zero or smaller
61     /// then the minimum header length.
from_slice_lax(slice: &'a [u8]) -> Result<UdpSlice<'a>, LenError>62     pub fn from_slice_lax(slice: &'a [u8]) -> Result<UdpSlice<'a>, LenError> {
63         // slice header
64         let header = UdpHeaderSlice::from_slice(slice)?;
65 
66         // validate the length of the slice and fallback to the slice
67         // length if the slice is smaller then expected or zero.
68         let len: usize = header.length().into();
69         if slice.len() < len || len < UdpHeader::LEN {
70             Ok(UdpSlice { slice })
71         } else {
72             Ok(UdpSlice {
73                 // SAFETY: Safe as slice.len() was validated before to
74                 // be at least as big as "len".
75                 slice: unsafe { core::slice::from_raw_parts(slice.as_ptr(), len) },
76             })
77         }
78     }
79 
80     /// Return the slice containing the UDP header & payload.
81     #[inline]
slice(&self) -> &'a [u8]82     pub fn slice(&self) -> &'a [u8] {
83         self.slice
84     }
85 
86     /// Return the slice containing the UDP header.
87     #[inline]
header_slice(&self) -> &'a [u8]88     pub fn header_slice(&self) -> &'a [u8] {
89         unsafe {
90             // SAFETY: Safe as the slice length was verified
91             // to be at least UdpHeader::LEN by "from_slice".
92             core::slice::from_raw_parts(self.slice.as_ptr(), UdpHeader::LEN)
93         }
94     }
95 
96     /// Returns the slice containing the UDP payload.
97     #[inline]
payload(&self) -> &'a [u8]98     pub fn payload(&self) -> &'a [u8] {
99         unsafe {
100             // SAFETY: Safe as the slice length was verified
101             // to be at least UdpHeader::LEN by "from_slice".
102             core::slice::from_raw_parts(
103                 self.slice.as_ptr().add(UdpHeader::LEN),
104                 self.slice.len() - UdpHeader::LEN,
105             )
106         }
107     }
108 
109     /// Value that was used to determine the length of the payload.
110     #[inline]
payload_len_source(&self) -> LenSource111     pub fn payload_len_source(&self) -> LenSource {
112         if usize::from(self.length()) == self.slice.len() {
113             LenSource::UdpHeaderLen
114         } else {
115             LenSource::Slice
116         }
117     }
118 
119     /// Reads the "udp source port" in the UDP header.
120     #[inline]
source_port(&self) -> u16121     pub fn source_port(&self) -> u16 {
122         // SAFETY:
123         // Safe as the contructor checks that the slice has
124         // at least the length of UdpHeader::LEN (8).
125         unsafe { get_unchecked_be_u16(self.slice.as_ptr()) }
126     }
127 
128     /// Reads the "udp destination port" in the UDP header.
129     #[inline]
destination_port(&self) -> u16130     pub fn destination_port(&self) -> u16 {
131         // SAFETY:
132         // Safe as the contructor checks that the slice has
133         // at least the length of UdpHeader::LEN (8).
134         unsafe { get_unchecked_be_u16(self.slice.as_ptr().add(2)) }
135     }
136 
137     /// Reads the "length" field in the UDP header.
138     #[inline]
length(&self) -> u16139     pub fn length(&self) -> u16 {
140         // SAFETY:
141         // Safe as the contructor checks that the slice has
142         // at least the length of UdpHeader::LEN (8).
143         unsafe { get_unchecked_be_u16(self.slice.as_ptr().add(4)) }
144     }
145 
146     /// Reads the "checksum" from the slice.
147     #[inline]
checksum(&self) -> u16148     pub fn checksum(&self) -> u16 {
149         // SAFETY:
150         // Safe as the contructor checks that the slice has
151         // at least the length of UdpHeader::LEN (8).
152         unsafe { get_unchecked_be_u16(self.slice.as_ptr().add(6)) }
153     }
154 
155     /// Length of the UDP header (equal to [`crate::UdpHeader::LEN`]).
156     #[inline]
header_len(&self) -> usize157     pub const fn header_len(&self) -> usize {
158         UdpHeader::LEN
159     }
160 
161     /// Length of the UDP header in an [`u16`] (equal to [`crate::UdpHeader::LEN_U16`]).
162     #[inline]
header_len_u16(&self) -> u16163     pub const fn header_len_u16(&self) -> u16 {
164         UdpHeader::LEN_U16
165     }
166 
167     /// Decode all the fields of the UDP header and copy the results
168     /// to a UdpHeader struct.
169     #[inline]
to_header(&self) -> UdpHeader170     pub fn to_header(&self) -> UdpHeader {
171         UdpHeader {
172             source_port: self.source_port(),
173             destination_port: self.destination_port(),
174             length: self.length(),
175             checksum: self.checksum(),
176         }
177     }
178 }
179 
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_gens::*;
    use alloc::{format, vec::Vec};
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn debug_clone_eq(
            udp_base in udp_any()
        ) {
            let payload: [u8;4] = [1,2,3,4];
            let mut data = Vec::with_capacity(
                udp_base.header_len() +
                payload.len()
            );
            let mut udp = udp_base.clone();
            udp.length = (UdpHeader::LEN + payload.len()) as u16;
            data.extend_from_slice(&udp.to_bytes());
            data.extend_from_slice(&payload);

            // decode packet
            let slice = UdpSlice::from_slice(&data).unwrap();

            // check debug output
            prop_assert_eq!(
                format!("{:?}", slice),
                format!(
                    "UdpSlice {{ slice: {:?} }}",
                    &data[..]
                )
            );
            prop_assert_eq!(slice.clone(), slice);
        }
    }

    proptest! {
        #[test]
        fn getters(
            udp_base in udp_any()
        ) {
            let udp = {
                let mut udp = udp_base.clone();
                udp.length = UdpHeader::LEN as u16;
                udp
            };
            let data = {
                let mut data = Vec::with_capacity(
                    udp.header_len()
                );
                data.extend_from_slice(&udp.to_bytes());
                data
            };

            // normal decode
            {
                let slice = UdpSlice::from_slice(&data).unwrap();
                assert_eq!(slice.slice(), &data);
                assert_eq!(slice.header_slice(), &data);
                assert_eq!(slice.payload(), &[]);
                assert_eq!(slice.payload_len_source(), LenSource::UdpHeaderLen);
                assert_eq!(slice.source_port(), udp.source_port);
                assert_eq!(slice.destination_port(), udp.destination_port);
                assert_eq!(slice.length(), udp.length);
                assert_eq!(slice.checksum(), udp.checksum);
                assert_eq!(slice.to_header(), udp);
            }
        }
    }

    proptest! {
        #[test]
        fn from_slice(
            udp_base in udp_any()
        ) {
            let payload: [u8;4] = [1,2,3,4];
            let udp = {
                let mut udp = udp_base.clone();
                udp.length = (UdpHeader::LEN + payload.len()) as u16;
                udp
            };
            let data = {
                let mut data = Vec::with_capacity(
                    udp.header_len() +
                    payload.len()
                );
                data.extend_from_slice(&udp.to_bytes());
                data.extend_from_slice(&payload);
                data
            };

            // normal decode
            {
                let slice = UdpSlice::from_slice(&data).unwrap();
                assert_eq!(udp, slice.to_header());
                assert_eq!(payload, slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::UdpHeaderLen);
            }

            // decode a payload smaller than the given slice
            {
                let mut mod_data = data.clone();
                let reduced_len = (UdpHeader::LEN + payload.len() - 1) as u16;
                // inject the reduced length
                {
                    let rl_be = reduced_len.to_be_bytes();
                    mod_data[4] = rl_be[0];
                    mod_data[5] = rl_be[1];
                }

                let slice = UdpSlice::from_slice(&mod_data).unwrap();
                // compare against the originally serialized header (not
                // against the decoded header itself) so all fields are checked
                assert_eq!(
                    slice.to_header(),
                    {
                        let mut expected = udp.clone();
                        expected.length = reduced_len;
                        expected
                    }
                );
                assert_eq!(&payload[..payload.len() - 1], slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::UdpHeaderLen);
            }

            // if length is zero the length given by the slice should be used
            {
                // inject zero as length
                let mut mod_data = data.clone();
                mod_data[4] = 0;
                mod_data[5] = 0;

                let slice = UdpSlice::from_slice(&mod_data).unwrap();

                assert_eq!(slice.source_port(), udp_base.source_port);
                assert_eq!(slice.destination_port(), udp_base.destination_port);
                assert_eq!(slice.checksum(), udp_base.checksum);
                assert_eq!(slice.length(), 0);
                assert_eq!(&payload, slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::Slice);
            }

            // too little data to even decode the header
            for len in 0..UdpHeader::LEN {
                assert_eq!(
                    UdpSlice::from_slice(&data[..len]).unwrap_err(),
                    LenError {
                        required_len: UdpHeader::LEN,
                        len,
                        len_source: LenSource::Slice,
                        layer: Layer::UdpHeader,
                        layer_start_offset: 0,
                    }
                );
            }

            // slice length smaller than the length described in the header
            assert_eq!(
                UdpSlice::from_slice(&data[..data.len() - 1]).unwrap_err(),
                LenError {
                    required_len: data.len(),
                    len: data.len() - 1,
                    len_source: LenSource::Slice,
                    layer: Layer::UdpPayload,
                    layer_start_offset: 0,
                }
            );

            // length in header smaller than the header itself
            {
                let mut mod_data = data.clone();
                // inject the reduced length
                {
                    let len_be = ((UdpHeader::LEN - 1) as u16).to_be_bytes();
                    mod_data[4] = len_be[0];
                    mod_data[5] = len_be[1];
                }
                assert_eq!(
                    UdpSlice::from_slice(&mod_data).unwrap_err(),
                    LenError {
                        required_len: UdpHeader::LEN,
                        len: UdpHeader::LEN - 1,
                        len_source: LenSource::UdpHeaderLen,
                        layer: Layer::UdpHeader,
                        layer_start_offset: 0
                    }
                );
            }
        }
    }

    proptest! {
        #[test]
        fn from_slice_lax(
            udp_base in udp_any()
        ) {
            let payload: [u8;4] = [1,2,3,4];
            let udp = {
                let mut udp = udp_base.clone();
                udp.length = (UdpHeader::LEN + payload.len()) as u16;
                udp
            };
            let data = {
                let mut data = Vec::with_capacity(
                    udp.header_len() +
                    payload.len()
                );
                data.extend_from_slice(&udp.to_bytes());
                data.extend_from_slice(&payload);
                data
            };

            // normal decode
            {
                let slice = UdpSlice::from_slice_lax(&data).unwrap();
                assert_eq!(udp, slice.to_header());
                assert_eq!(payload, slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::UdpHeaderLen);
            }

            // decode a payload smaller than the given slice
            {
                let mut mod_data = data.clone();
                let reduced_len = (UdpHeader::LEN + payload.len() - 1) as u16;
                // inject the reduced length
                {
                    let rl_be = reduced_len.to_be_bytes();
                    mod_data[4] = rl_be[0];
                    mod_data[5] = rl_be[1];
                }

                let slice = UdpSlice::from_slice_lax(&mod_data).unwrap();
                // compare against the originally serialized header (not
                // against the decoded header itself) so all fields are checked
                assert_eq!(
                    slice.to_header(),
                    {
                        let mut expected = udp.clone();
                        expected.length = reduced_len;
                        expected
                    }
                );
                assert_eq!(&payload[..payload.len() - 1], slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::UdpHeaderLen);
            }

            // if the length is zero or smaller than the header, the length
            // given by the slice should be used
            for len in 0..UdpHeader::LEN_U16 {
                // inject a length smaller than the header (including zero)
                let mut mod_data = data.clone();
                mod_data[4] = len.to_be_bytes()[0];
                mod_data[5] = len.to_be_bytes()[1];

                let slice = UdpSlice::from_slice_lax(&mod_data).unwrap();

                assert_eq!(slice.source_port(), udp_base.source_port);
                assert_eq!(slice.destination_port(), udp_base.destination_port);
                assert_eq!(slice.checksum(), udp_base.checksum);
                assert_eq!(slice.length(), len);
                assert_eq!(&payload, slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::Slice);
            }

            // too little data to even decode the header
            for len in 0..UdpHeader::LEN {
                assert_eq!(
                    UdpSlice::from_slice_lax(&data[..len]).unwrap_err(),
                    LenError {
                        required_len: UdpHeader::LEN,
                        len,
                        len_source: LenSource::Slice,
                        layer: Layer::UdpHeader,
                        layer_start_offset: 0,
                    }
                );
            }

            // slice length smaller than the length described in the header
            {
                let slice = UdpSlice::from_slice_lax(&data[..data.len() - 1]).unwrap();
                assert_eq!(udp, slice.to_header());
                assert_eq!(&payload[..payload.len() - 1], slice.payload());
                assert_eq!(slice.payload_len_source(), LenSource::Slice);
            }
        }
    }

    proptest! {
        #[test]
        fn header_len(
            udp in udp_any()
        ) {
            let mut udp = udp;
            udp.length = UdpHeader::LEN_U16;
            let bytes = udp.to_bytes();
            let slice = UdpSlice::from_slice(&bytes).unwrap();
            assert_eq!(UdpHeader::LEN, slice.header_len());
        }
    }

    proptest! {
        #[test]
        fn header_len_u16(
            udp in udp_any()
        ) {
            let mut udp = udp;
            udp.length = UdpHeader::LEN_U16;
            let bytes = udp.to_bytes();
            let slice = UdpSlice::from_slice(&bytes).unwrap();
            assert_eq!(UdpHeader::LEN_U16, slice.header_len_u16());
        }
    }
}
486