// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::convert::TryFrom;
use std::convert::TryInto;
use std::ops::Deref;

/// Errors reported while parsing data out of a byte buffer.
#[derive(Debug)]
pub enum ParseError {
    /// Not produced by the code in this file (presumably used by generated
    /// parsers — TODO confirm).
    InvalidEnumValue,
    /// Not produced by the code in this file (presumably used by generated
    /// parsers — TODO confirm).
    DivisionFailure,
    /// A bit-offset computation overflowed `usize`, or a decoded value did not
    /// fit into the type requested from [`SizedBitSlice::try_parse`].
    ArithmeticOverflow,
    /// A sub-view or byte read reached past the end of the available data.
    OutOfBoundsAccess,
    /// Returned by [`SizedBitSlice::try_parse`] for fields wider than 64 bits.
    MisalignedPayload,
}

/// A bounded, bit-granular view into a byte buffer.
///
/// Bit `n` of the buffer is bit `n % 8` (counted from the least significant
/// bit) of byte `n / 8`; the view covers bits `start_bit_offset..end_bit_offset`.
#[derive(Clone, Copy, Debug)]
pub struct BitSlice<'a> {
    // note: the offsets are ENTIRELY UNRELATED to the size of this struct,
    // so indexing needs to be checked to avoid panics
    backing: &'a [u8],

    // invariant: end_bit_offset >= start_bit_offset, so subtraction will NEVER wrap
    start_bit_offset: usize,
    end_bit_offset: usize,
}

/// A [`BitSlice`] whose end offset is meaningful as a length:
/// [`SizedBitSlice::try_parse`] decodes exactly the bits between its start and
/// end. Dereferences to the underlying [`BitSlice`].
#[derive(Clone, Copy, Debug)]
pub struct SizedBitSlice<'a>(BitSlice<'a>);

impl<'a> BitSlice<'a> {
    /// Returns a view that starts `offset` bits later and keeps the same end.
    ///
    /// # Errors
    /// [`ParseError::OutOfBoundsAccess`] if `offset` exceeds the number of bits
    /// in this view; [`ParseError::ArithmeticOverflow`] if the new start offset
    /// overflows `usize`.
    pub fn offset(&self, offset: usize) -> Result<BitSlice<'a>, ParseError> {
        if self.end_bit_offset - self.start_bit_offset < offset {
            return Err(ParseError::OutOfBoundsAccess);
        }
        Ok(Self {
            backing: self.backing,
            start_bit_offset: self
                .start_bit_offset
                .checked_add(offset)
                .ok_or(ParseError::ArithmeticOverflow)?,
            end_bit_offset: self.end_bit_offset,
        })
    }

    /// Returns a view of the first `len` bits of this view.
    ///
    /// # Errors
    /// [`ParseError::OutOfBoundsAccess`] if `len` exceeds the number of bits in
    /// this view; [`ParseError::ArithmeticOverflow`] if the new end offset
    /// overflows `usize`.
    pub fn slice(&self, len: usize) -> Result<SizedBitSlice<'a>, ParseError> {
        if self.end_bit_offset - self.start_bit_offset < len {
            return Err(ParseError::OutOfBoundsAccess);
        }
        Ok(SizedBitSlice(Self {
            backing: self.backing,
            start_bit_offset: self.start_bit_offset,
            end_bit_offset: self
                .start_bit_offset
                .checked_add(len)
                .ok_or(ParseError::ArithmeticOverflow)?,
        }))
    }

    /// Bounds-checked read of byte `index` of the backing buffer.
    fn byte_at(&self, index: usize) -> Result<u8, ParseError> {
        self.backing.get(index).ok_or(ParseError::OutOfBoundsAccess).copied()
    }
}

impl<'a> Deref for SizedBitSlice<'a> {
    type Target = BitSlice<'a>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<'a> From<SizedBitSlice<'a>> for BitSlice<'a> {
    /// Forgets that the view is sized.
    fn from(x: SizedBitSlice<'a>) -> Self {
        *x
    }
}

impl<'a, 'b> From<&'b [u8]> for SizedBitSlice<'a>
where
    'b: 'a,
{
    /// Creates a view covering every bit of `backing`.
    fn from(backing: &'b [u8]) -> Self {
        // `'b: 'a`, so the `&'b [u8]` coerces to the `&'a [u8]` the view stores.
        Self(BitSlice { backing, start_bit_offset: 0, end_bit_offset: backing.len() * 8 })
    }
}

impl<'a> SizedBitSlice<'a> {
    /// Decodes the bits of this view as an unsigned integer and converts it to `T`.
    ///
    /// Bits are read least-significant first: the first bit of the view becomes
    /// bit 0 of the result. The field does not need to be byte-aligned.
    ///
    /// # Errors
    /// - [`ParseError::MisalignedPayload`] if the view is wider than 64 bits.
    /// - [`ParseError::OutOfBoundsAccess`] if the view extends past the backing
    ///   buffer.
    /// - [`ParseError::ArithmeticOverflow`] if the decoded value does not fit in `T`.
    pub fn try_parse<T: TryFrom<u64>>(&self) -> Result<T, ParseError> {
        if self.end_bit_offset < self.start_bit_offset {
            return Err(ParseError::OutOfBoundsAccess);
        }
        let size_in_bits = self.end_bit_offset - self.start_bit_offset;

        // Only fields that fit into a u64 are decoded here; anything wider is rejected.
        if size_in_bits > 64 {
            return Err(ParseError::MisalignedPayload);
        }

        let mut accumulator = 0u64;

        // where we are in our accumulation
        let mut curr_byte_index = self.start_bit_offset / 8;
        let mut curr_bit_offset = self.start_bit_offset % 8;
        let mut remaining_bits = size_in_bits;

        while remaining_bits > 0 {
            // Number of result bits already filled in, i.e. where this chunk lands.
            let consumed = size_in_bits - remaining_bits;
            let byte = u64::from(self.byte_at(curr_byte_index)? >> curr_bit_offset);

            if curr_bit_offset + remaining_bits <= 8 {
                // Last byte: keep only the bits that still belong to this field.
                accumulator |= (byte & ((1u64 << remaining_bits) - 1)) << consumed;
                break;
            }

            // Not the last byte: every bit from `curr_bit_offset` up to the MSB is ours.
            accumulator |= byte << consumed;
            remaining_bits -= 8 - curr_bit_offset;
            curr_bit_offset = 0;
            curr_byte_index += 1;
        }

        T::try_from(accumulator).map_err(|_| ParseError::ArithmeticOverflow)
    }

    /// Number of bits covered by this view.
    pub fn get_size_in_bits(&self) -> usize {
        self.end_bit_offset - self.start_bit_offset
    }
}

/// A packet view parsed from a buffer borrowed for `'a`.
///
/// Implementations are not part of this file (presumably generated — TODO confirm).
pub trait Packet<'a>
where
    Self: Sized,
{
    type Parent;
    type Owned;
    type Builder;
    fn try_parse_from_buffer(buf: impl Into<SizedBitSlice<'a>>) -> Result<Self, ParseError>;
    fn try_parse(parent: Self::Parent) -> Result<Self, ParseError>;
    fn to_owned_packet(&self) -> Self::Owned;
}

/// A packet that owns its backing buffer.
pub trait OwnedPacket
where
    Self: Sized,
{
    // Enable GAT when 1.65 is available in AOSP
    // type View<'a> where Self : 'a;
    fn try_parse(buf: Box<[u8]>) -> Result<Self, ParseError>;
    // fn view<'a>(&'a self) -> Self::View<'a>;
}

/// A serializable packet builder tied to the owned packet type it produces.
pub trait Builder: Serializable {
    type OwnedPacket: OwnedPacket;
}

/// Errors reported while serializing.
#[derive(Debug)]
pub enum SerializeError {
    NegativePadding,
    IntegerConversionFailure,
    /// Returned by the byte serializer when a value has bits set at or above
    /// the requested field width.
    ValueTooLarge,
    AlignmentError,
}

/// A sink that accepts bit fields one at a time.
///
/// `num_bits` is the field width; `gen_contents` lazily produces the value,
/// which implementations may or may not evaluate.
pub trait BitWriter {
    fn write_bits<T: Into<u64>>(
        &mut self,
        num_bits: usize,
        gen_contents: impl FnOnce() -> Result<T, SerializeError>,
    ) -> Result<(), SerializeError>;
}

/// A value that can be written as a sequence of bit fields.
pub trait Serializable {
    /// Emits every field of `self` into `writer`.
    fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError>;

    /// Total number of bits `serialize` would write.
    ///
    /// Field values are not generated, so value errors (such as
    /// [`SerializeError::ValueTooLarge`]) are not detected here.
    fn size_in_bits(&self) -> Result<usize, SerializeError> {
        let mut sizer = Sizer::new();
        self.serialize(&mut sizer)?;
        Ok(sizer.size())
    }

    /// Appends the serialized bytes to `vec`, zero-padding a trailing partial byte.
    ///
    /// On error, any whole bytes already produced remain appended to `vec`.
    fn write(&self, vec: &mut Vec<u8>) -> Result<(), SerializeError> {
        let mut serializer = Serializer::new(vec);
        self.serialize(&mut serializer)?;
        serializer.flush();
        Ok(())
    }

    /// Serializes into a freshly allocated byte vector.
    fn to_vec(&self) -> Result<Vec<u8>, SerializeError> {
        let mut out = vec![];
        self.write(&mut out)?;
        Ok(out)
    }
}

/// A [`BitWriter`] that only totals field widths.
struct Sizer {
    size: usize,
}

impl Sizer {
    fn new() -> Self {
        Self { size: 0 }
    }

    fn size(self) -> usize {
        self.size
    }
}

impl BitWriter for Sizer {
    fn write_bits<T: Into<u64>>(
        &mut self,
        num_bits: usize,
        _gen_contents: impl FnOnce() -> Result<T, SerializeError>,
    ) -> Result<(), SerializeError> {
        // Only the declared width matters for sizing; the value generator is
        // intentionally never invoked.
        self.size += num_bits;
        Ok(())
    }
}

/// A [`BitWriter`] that packs fields LSB-first into `buf`.
///
/// `curr_byte` holds the pending, not-yet-complete byte, and `curr_bit_offset`
/// is how many of its low bits are already filled.
struct Serializer<'a> {
    buf: &'a mut Vec<u8>,
    curr_byte: u8,
    curr_bit_offset: u8,
}

impl<'a> Serializer<'a> {
    fn new(buf: &'a mut Vec<u8>) -> Self {
        Self { buf, curr_byte: 0, curr_bit_offset: 0 }
    }

    /// Emits the pending partial byte, if any.
    fn flush(self) {
        if self.curr_bit_offset > 0 {
            // Partial byte remaining. Bits are packed LSB-first, so the pending bits
            // already occupy the low `curr_bit_offset` positions and the high bits are
            // zero padding. Shifting them toward the MSB would misplace them for
            // `SizedBitSlice::try_parse`, which also reads LSB-first.
            self.buf.push(self.curr_byte);
        }
    }
}

impl<'a> BitWriter for Serializer<'a> {
    fn write_bits<T: Into<u64>>(
        &mut self,
        num_bits: usize,
        gen_contents: impl FnOnce() -> Result<T, SerializeError>,
    ) -> Result<(), SerializeError> {
        let val: u64 = gen_contents()?.into();

        // Any u64 fits in a field of 64+ bits; narrower fields must not have bits
        // set at or above `num_bits`.
        if num_bits < 64 && val >= 1 << num_bits {
            return Err(SerializeError::ValueTooLarge);
        }

        let mut remaining_val = val;
        let mut remaining_bits = num_bits;
        while remaining_bits > 0 {
            let remaining_bits_in_curr_byte = (8 - self.curr_bit_offset) as usize;
            if remaining_bits < remaining_bits_in_curr_byte {
                // The rest of the value fits in the pending byte without completing it.
                self.curr_byte |= (remaining_val as u8) << self.curr_bit_offset;
                self.curr_bit_offset += remaining_bits as u8;
                break;
            }

            // Fill the pending byte up to its MSB and emit it.
            let val_for_this_byte =
                (remaining_val & ((1 << remaining_bits_in_curr_byte) - 1)) as u8;
            self.buf.push(self.curr_byte | (val_for_this_byte << self.curr_bit_offset));

            // clear pending byte
            self.curr_bit_offset = 0;
            self.curr_byte = 0;

            // update what's remaining
            remaining_val >>= remaining_bits_in_curr_byte;
            remaining_bits -= remaining_bits_in_curr_byte;
        }

        Ok(())
    }
}