// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::collections::HashMap; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use crate::{ ast, backends::{ intermediate::{ComputedValue, ComputedValueId, PacketOrStruct, Schema}, rust_no_allocation::utils::get_integer_type, }, }; fn standardize_child(id: &str) -> &str { match id { "_body_" | "_payload_" => "_child_", _ => id, } } pub fn generate_packet_serializer( id: &str, parent_id: Option<&str>, fields: &[ast::Field], schema: &Schema, curr_schema: &PacketOrStruct, children: &HashMap<&str, Vec<&str>>, ) -> TokenStream { let id_ident = format_ident!("{id}Builder"); let builder_fields = fields .iter() .filter_map(|field| { match &field.desc { ast::FieldDesc::Padding { .. } | ast::FieldDesc::Flag { .. } | ast::FieldDesc::Reserved { .. } | ast::FieldDesc::FixedScalar { .. } | ast::FieldDesc::FixedEnum { .. } | ast::FieldDesc::ElementSize { .. } | ast::FieldDesc::Count { .. } | ast::FieldDesc::Size { .. } => { // no-op, no getter generated for this type None } ast::FieldDesc::Group { .. } => unreachable!(), ast::FieldDesc::Checksum { .. } => { unimplemented!("checksums not yet supported with this backend") } ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => { let type_ident = format_ident!("{id}Child"); Some(("_child_", quote! { #type_ident })) } ast::FieldDesc::Array { id, width, type_id, .. } => { let element_type = if let Some(width) = width { get_integer_type(*width) } else if let Some(type_id) = type_id { if schema.enums.contains_key(type_id.as_str()) { format_ident!("{type_id}") } else { format_ident!("{type_id}Builder") } } else { unreachable!(); }; Some((id.as_str(), quote! { Box<[#element_type]> })) } ast::FieldDesc::Scalar { id, width } => { let id_type = get_integer_type(*width); Some((id.as_str(), quote! { #id_type })) } ast::FieldDesc::Typedef { id, type_id } => { let type_ident = if schema.enums.contains_key(type_id.as_str()) { format_ident!("{type_id}") } else { format_ident!("{type_id}Builder") }; Some((id.as_str(), quote! { #type_ident })) } } }) .map(|(id, typ)| { let id_ident = format_ident!("{id}"); quote! { pub #id_ident: #typ } }); let mut has_child = false; let serializer = fields.iter().map(|field| { match &field.desc { ast::FieldDesc::Checksum { .. } | ast::FieldDesc::Group { .. } | ast::FieldDesc::Flag { .. } => unimplemented!(), ast::FieldDesc::Padding { size, .. } => { quote! { if (most_recent_array_size_in_bits > #size * 8) { return Err(SerializeError::NegativePadding); } writer.write_bits((#size * 8 - most_recent_array_size_in_bits) as usize, || Ok(0u64))?; } }, ast::FieldDesc::Size { field_id, width } => { let field_id = standardize_child(field_id); let field_ident = format_ident!("{field_id}"); // if the element-size is fixed, we can directly multiply if let Some(ComputedValue::Constant(element_width)) = curr_schema.computed_values.get(&ComputedValueId::FieldElementSize(field_id)) { return quote! { writer.write_bits( #width, || u64::try_from(self.#field_ident.len() * #element_width).or(Err(SerializeError::IntegerConversionFailure)) )?; } } // if the field is "countable", loop over it to sum up the size if curr_schema.computed_values.contains_key(&ComputedValueId::FieldCount(field_id)) { return quote! { writer.write_bits(#width, || { let size_in_bits = self.#field_ident.iter().map(|elem| elem.size_in_bits()).fold(Ok(0), |total, next| { let total: u64 = total?; let next = u64::try_from(next?).or(Err(SerializeError::IntegerConversionFailure))?; total.checked_add(next).ok_or(SerializeError::IntegerConversionFailure) })?; if size_in_bits % 8 != 0 { return Err(SerializeError::AlignmentError); } Ok(size_in_bits / 8) })?; } } // otherwise, try to get the size directly quote! { writer.write_bits(#width, || { let size_in_bits: u64 = self.#field_ident.size_in_bits()?.try_into().or(Err(SerializeError::IntegerConversionFailure))?; if size_in_bits % 8 != 0 { return Err(SerializeError::AlignmentError); } Ok(size_in_bits / 8) })?; } } ast::FieldDesc::Count { field_id, width } => { let field_ident = format_ident!("{field_id}"); quote! { writer.write_bits(#width, || u64::try_from(self.#field_ident.len()).or(Err(SerializeError::IntegerConversionFailure)))?; } } ast::FieldDesc::ElementSize { field_id, width } => { // TODO(aryarahul) - add validation for elementsize against all the other elements let field_ident = format_ident!("{field_id}"); quote! { let get_element_size = || Ok(if let Some(field) = self.#field_ident.get(0) { let size_in_bits = field.size_in_bits()?; if size_in_bits % 8 == 0 { (size_in_bits / 8) as u64 } else { return Err(SerializeError::AlignmentError); } } else { 0 }); writer.write_bits(#width, || get_element_size() )?; } } ast::FieldDesc::Reserved { width, .. } => { quote!{ writer.write_bits(#width, || Ok(0u64))?; } } ast::FieldDesc::Scalar { width, id } => { let field_ident = format_ident!("{id}"); quote! { writer.write_bits(#width, || Ok(self.#field_ident))?; } } ast::FieldDesc::FixedScalar { width, value } => { let width = quote! { #width }; let value = { let value = *value as u64; quote! { #value } }; quote!{ writer.write_bits(#width, || Ok(#value))?; } } ast::FieldDesc::FixedEnum { enum_id, tag_id } => { let width = { let width = schema.enums[enum_id.as_str()].width; quote! { #width } }; let value = { let enum_ident = format_ident!("{}", enum_id); let tag_ident = format_ident!("{tag_id}"); quote! { #enum_ident::#tag_ident.value() } }; quote!{ writer.write_bits(#width, || Ok(#value))?; } } ast::FieldDesc::Body | ast::FieldDesc::Payload { .. } => { has_child = true; quote! { self._child_.serialize(writer)?; } } ast::FieldDesc::Array { width, id, .. } => { let id_ident = format_ident!("{id}"); if let Some(width) = width { quote! { for elem in self.#id_ident.iter() { writer.write_bits(#width, || Ok(*elem))?; } let most_recent_array_size_in_bits = #width * self.#id_ident.len(); } } else { quote! { let mut most_recent_array_size_in_bits = 0; for elem in self.#id_ident.iter() { most_recent_array_size_in_bits += elem.size_in_bits()?; elem.serialize(writer)?; } } } } ast::FieldDesc::Typedef { id, .. } => { let id_ident = format_ident!("{id}"); quote! { self.#id_ident.serialize(writer)?; } } } }).collect::>(); let variant_names = children.get(id).into_iter().flatten().collect::>(); let variants = variant_names.iter().map(|name| { let name_ident = format_ident!("{name}"); let variant_ident = format_ident!("{name}Builder"); quote! { #name_ident(#variant_ident) } }); let variant_serializers = variant_names.iter().map(|name| { let name_ident = format_ident!("{name}"); quote! { Self::#name_ident(x) => { x.serialize(writer)?; } } }); let children_enum = if has_child { let enum_ident = format_ident!("{id}Child"); quote! { #[derive(Debug, Clone, PartialEq, Eq)] pub enum #enum_ident { RawData(Box<[u8]>), #(#variants),* } impl Serializable for #enum_ident { fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> { match self { Self::RawData(data) => { for byte in data.iter() { writer.write_bits(8, || Ok(*byte as u64))?; } }, #(#variant_serializers),* } Ok(()) } } } } else { quote! {} }; let parent_type_converter = if let Some(parent_id) = parent_id { let parent_enum_ident = format_ident!("{parent_id}Child"); let variant_ident = format_ident!("{id}"); Some(quote! { impl From<#id_ident> for #parent_enum_ident { fn from(x: #id_ident) -> Self { Self::#variant_ident(x) } } }) } else { None }; let owned_packet_ident = format_ident!("Owned{id}View"); quote! { #[derive(Debug, Clone, PartialEq, Eq)] pub struct #id_ident { #(#builder_fields),* } impl Builder for #id_ident { type OwnedPacket = #owned_packet_ident; } impl Serializable for #id_ident { fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError> { #(#serializer)* Ok(()) } } #parent_type_converter #children_enum } }