// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote}; use crate::backends::intermediate::{ ComputedOffset, ComputedOffsetId, ComputedValue, ComputedValueId, }; /// This trait is implemented on computed quantities (offsets and values) that can be retrieved via a function call pub trait Declarable { fn get_name(&self) -> String; fn get_ident(&self) -> Ident { format_ident!("try_get_{}", self.get_name()) } fn call_fn(&self) -> TokenStream { let fn_name = self.get_ident(); quote! { self.#fn_name()? } } fn declare_fn(&self, body: TokenStream) -> TokenStream { let fn_name = self.get_ident(); quote! { #[inline] fn #fn_name(&self) -> Result { #body } } } } impl Declarable for ComputedValueId<'_> { fn get_name(&self) -> String { match self { ComputedValueId::FieldSize(field) => format!("{field}_size"), ComputedValueId::FieldElementSize(field) => format!("{field}_element_size"), ComputedValueId::FieldCount(field) => format!("{field}_count"), ComputedValueId::Custom(i) => format!("custom_value_{i}"), } } } impl Declarable for ComputedOffsetId<'_> { fn get_name(&self) -> String { match self { ComputedOffsetId::HeaderStart => "header_start_offset".to_string(), ComputedOffsetId::PacketEnd => "packet_end_offset".to_string(), ComputedOffsetId::FieldOffset(field) => format!("{field}_offset"), ComputedOffsetId::FieldEndOffset(field) => format!("{field}_end_offset"), ComputedOffsetId::Custom(i) => format!("custom_offset_{i}"), ComputedOffsetId::TrailerStart => "trailer_start_offset".to_string(), } } } /// This trait is implemented on computed expressions that are computed on-demand (i.e. not via a function call) pub trait Computable { fn compute(&self) -> TokenStream; } impl Computable for ComputedValue<'_> { fn compute(&self) -> TokenStream { match self { ComputedValue::Constant(k) => quote! { Ok(#k) }, ComputedValue::CountStructsUpToSize { base_id, size, struct_type } => { let base_offset = base_id.call_fn(); let size = size.call_fn(); let struct_type = format_ident!("{struct_type}View"); quote! { let mut cnt = 0; let mut view = self.buf.offset(#base_offset)?; let mut remaining_size = #size; while remaining_size > 0 { let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?; if next_struct_size > remaining_size { return Err(ParseError::OutOfBoundsAccess); } remaining_size -= next_struct_size; view = view.offset(next_struct_size * 8)?; cnt += 1; } Ok(cnt) } } ComputedValue::SizeOfNStructs { base_id, n, struct_type } => { let base_offset = base_id.call_fn(); let n = n.call_fn(); let struct_type = format_ident!("{struct_type}View"); quote! { let mut view = self.buf.offset(#base_offset)?; let mut size = 0; for _ in 0..#n { let next_struct_size = #struct_type::try_parse(view)?.try_get_size()?; size += next_struct_size; view = view.offset(next_struct_size * 8)?; } Ok(size) } } ComputedValue::Product(x, y) => { let x = x.call_fn(); let y = y.call_fn(); quote! { #x.checked_mul(#y).ok_or(ParseError::ArithmeticOverflow) } } ComputedValue::Divide(x, y) => { let x = x.call_fn(); let y = y.call_fn(); quote! { if #y == 0 || #x % #y != 0 { return Err(ParseError::DivisionFailure) } Ok(#x / #y) } } ComputedValue::Difference(x, y) => { let x = x.call_fn(); let y = y.call_fn(); quote! { let bit_difference = #x.checked_sub(#y).ok_or(ParseError::ArithmeticOverflow)?; if bit_difference % 8 != 0 { return Err(ParseError::DivisionFailure); } Ok(bit_difference / 8) } } ComputedValue::ValueAt { offset, width } => { let offset = offset.call_fn(); quote! { self.buf.offset(#offset)?.slice(#width)?.try_parse() } } } } } impl Computable for ComputedOffset<'_> { fn compute(&self) -> TokenStream { match self { ComputedOffset::ConstantPlusOffsetInBits(base_id, offset) => { let base_id = base_id.call_fn(); quote! { #base_id.checked_add_signed(#offset as isize).ok_or(ParseError::ArithmeticOverflow) } } ComputedOffset::SumWithOctets(x, y) => { let x = x.call_fn(); let y = y.call_fn(); quote! { #x.checked_add(#y.checked_mul(8).ok_or(ParseError::ArithmeticOverflow)?) .ok_or(ParseError::ArithmeticOverflow) } } ComputedOffset::Alias(alias) => { let alias = alias.call_fn(); quote! { Ok(#alias) } } } } }