// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use crate::backends::rust_legacy::{mask_bits, types, ToIdent, ToUpperCamelCase}; use crate::{analyzer, ast}; use quote::{format_ident, quote}; /// A single bit-field value. struct BitField { value: proc_macro2::TokenStream, // An expression which produces a value. field_type: types::Integer, // The type of the value. shift: usize, // A bit-shift to apply to `value`. } pub struct FieldSerializer<'a> { scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, chunk: Vec, code: Vec, shift: usize, } impl<'a> FieldSerializer<'a> { pub fn new( scope: &'a analyzer::Scope<'a>, schema: &'a analyzer::Schema, endianness: ast::EndiannessValue, packet_name: &'a str, span: &'a proc_macro2::Ident, ) -> FieldSerializer<'a> { FieldSerializer { scope, schema, endianness, packet_name, span, chunk: Vec::new(), code: Vec::new(), shift: 0, } } pub fn add(&mut self, field: &ast::Field) { match &field.desc { _ if field.cond.is_some() => self.add_optional_field(field), _ if self.scope.is_bitfield(field) => self.add_bit_field(field), ast::FieldDesc::Array { id, width, .. } => self.add_array_field( id, *width, self.schema.padded_size(field.key), self.scope.get_type_declaration(field), ), ast::FieldDesc::Typedef { id, type_id } => { self.add_typedef_field(id, type_id); } ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. } => { self.add_payload_field(); } // Padding field handled in serialization of associated array field. ast::FieldDesc::Padding { .. } => (), _ => todo!("Cannot yet serialize {field:?}"), } } fn add_optional_field(&mut self, field: &ast::Field) { self.code.push(match &field.desc { ast::FieldDesc::Scalar { id, width } => { let name = id; let id = id.to_ident(); let backing_type = types::Integer::new(*width); let write = types::put_uint(self.endianness, "e!(*#id), *width, self.span); let range_check = (backing_type.width > *width).then(|| { let packet_name = &self.packet_name; let max_value = mask_bits(*width, "u64"); quote! { if *#id > #max_value { return Err(EncodeError::InvalidScalarValue { packet: #packet_name, field: #name, value: *#id as u64, maximum_value: #max_value as u64, }) } } }); quote! { if let Some(#id) = &self.#id { #range_check #write } } } ast::FieldDesc::Typedef { id, type_id } => match &self.scope.typedef[type_id].desc { ast::DeclDesc::Enum { width, .. } => { let id = id.to_ident(); let backing_type = types::Integer::new(*width); let write = types::put_uint( self.endianness, "e!(#backing_type::from(#id)), *width, self.span, ); quote! { if let Some(#id) = &self.#id { #write } } } ast::DeclDesc::Struct { .. } => { let id = id.to_ident(); let span = self.span; quote! { if let Some(#id) = &self.#id { #id.write_to(#span)?; } } } _ => unreachable!(), }, _ => unreachable!(), }) } fn add_bit_field(&mut self, field: &ast::Field) { let width = self.schema.field_size(field.key).static_().unwrap(); let shift = self.shift; match &field.desc { ast::FieldDesc::Flag { optional_field_id, set_value, .. } => { let optional_field_id = optional_field_id.to_ident(); let cond_value_present = syn::parse_str::(&format!("{}", set_value)).unwrap(); let cond_value_absent = syn::parse_str::(&format!("{}", 1 - set_value)).unwrap(); self.chunk.push(BitField { value: quote! { if self.#optional_field_id.is_some() { #cond_value_present } else { #cond_value_absent } }, field_type: types::Integer::new(1), shift, }); } ast::FieldDesc::Scalar { id, width } => { let field_name = id.to_ident(); let field_type = types::Integer::new(*width); if field_type.width > *width { let packet_name = &self.packet_name; let max_value = mask_bits(*width, "u64"); self.code.push(quote! { if self.#field_name > #max_value { return Err(EncodeError::InvalidScalarValue { packet: #packet_name, field: #id, value: self.#field_name as u64, maximum_value: #max_value, }) } }); } self.chunk.push(BitField { value: quote!(self.#field_name), field_type, shift }); } ast::FieldDesc::FixedEnum { enum_id, tag_id, .. } => { let field_type = types::Integer::new(width); let enum_id = enum_id.to_ident(); let tag_id = format_ident!("{}", tag_id.to_upper_camel_case()); self.chunk.push(BitField { value: quote!(#field_type::from(#enum_id::#tag_id)), field_type, shift, }); } ast::FieldDesc::FixedScalar { value, .. } => { let field_type = types::Integer::new(width); let value = proc_macro2::Literal::usize_unsuffixed(*value); self.chunk.push(BitField { value: quote!(#value), field_type, shift }); } ast::FieldDesc::Typedef { id, .. } => { let field_name = id.to_ident(); let field_type = types::Integer::new(width); self.chunk.push(BitField { value: quote!(#field_type::from(self.#field_name)), field_type, shift, }); } ast::FieldDesc::Reserved { .. } => { // Nothing to do here. } ast::FieldDesc::Size { field_id, width, .. } => { let packet_name = &self.packet_name; let max_value = mask_bits(*width, "usize"); let decl = self.scope.typedef.get(self.packet_name).unwrap(); let value_field = self .scope .iter_fields(decl) .find(|field| match &field.desc { ast::FieldDesc::Payload { .. } => field_id == "_payload_", ast::FieldDesc::Body { .. } => field_id == "_body_", _ => field.id() == Some(field_id), }) .unwrap(); let field_name = field_id.to_ident(); let field_type = types::Integer::new(*width); // TODO: size modifier let value_field_decl = self.scope.get_type_declaration(value_field); let field_size_name = format_ident!("{field_id}_size"); let array_size = match (&value_field.desc, value_field_decl.map(|decl| &decl.desc)) { (ast::FieldDesc::Payload { size_modifier: Some(size_modifier) }, _) => { let size_modifier = proc_macro2::Literal::usize_unsuffixed( size_modifier .parse::() .expect("failed to parse the size modifier"), ); if let ast::DeclDesc::Packet { .. } = &decl.desc { quote! { (self.child.get_total_size() + #size_modifier) } } else { quote! { (self.payload.len() + #size_modifier) } } } (ast::FieldDesc::Payload { .. } | ast::FieldDesc::Body { .. }, _) => { if let ast::DeclDesc::Packet { .. } = &decl.desc { quote! { self.child.get_total_size() } } else { quote! { self.payload.len() } } } (ast::FieldDesc::Array { width: Some(width), .. }, _) | (ast::FieldDesc::Array { .. }, Some(ast::DeclDesc::Enum { width, .. })) => { let byte_width = syn::Index::from(width / 8); if byte_width.index == 1 { quote! { self.#field_name.len() } } else { quote! { (self.#field_name.len() * #byte_width) } } } (ast::FieldDesc::Array { .. }, _) => { self.code.push(quote! { let #field_size_name = self.#field_name .iter() .map(|elem| elem.get_size()) .sum::(); }); quote! { #field_size_name } } _ => panic!("Unexpected size field: {field:?}"), }; self.code.push(quote! { if #array_size > #max_value { return Err(EncodeError::SizeOverflow { packet: #packet_name, field: #field_id, size: #array_size, maximum_size: #max_value, }) } }); self.chunk.push(BitField { value: quote!(#array_size as #field_type), field_type, shift, }); } ast::FieldDesc::Count { field_id, width, .. } => { let field_name = field_id.to_ident(); let field_type = types::Integer::new(*width); if field_type.width > *width { let packet_name = &self.packet_name; let max_value = mask_bits(*width, "usize"); self.code.push(quote! { if self.#field_name.len() > #max_value { return Err(EncodeError::CountOverflow { packet: #packet_name, field: #field_id, count: self.#field_name.len(), maximum_count: #max_value, }) } }); } self.chunk.push(BitField { value: quote!(self.#field_name.len() as #field_type), field_type, shift, }); } _ => todo!("{field:?}"), } self.shift += width; if self.shift % 8 == 0 { self.pack_bit_fields() } } fn pack_bit_fields(&mut self) { assert_eq!(self.shift % 8, 0); let chunk_type = types::Integer::new(self.shift); let values = self .chunk .drain(..) .map(|BitField { mut value, field_type, shift }| { if field_type.width != chunk_type.width { // We will be combining values with `|`, so we // need to cast them first. value = quote! { (#value as #chunk_type) }; } if shift > 0 { let op = quote!(<<); let shift = proc_macro2::Literal::usize_unsuffixed(shift); value = quote! { (#value #op #shift) }; } value }) .collect::>(); match values.as_slice() { [] => { let span = format_ident!("{}", self.span); let count = syn::Index::from(self.shift / 8); self.code.push(quote! { #span.put_bytes(0, #count); }); } [value] => { let put = types::put_uint(self.endianness, value, self.shift, self.span); self.code.push(quote! { #put; }); } _ => { let put = types::put_uint(self.endianness, "e!(value), self.shift, self.span); self.code.push(quote! { let value = #(#values)|*; #put; }); } } self.shift = 0; } fn add_array_field( &mut self, id: &str, width: Option, padding_size: Option, decl: Option<&ast::Decl>, ) { let span = format_ident!("{}", self.span); let serialize = match width { Some(width) => { let value = quote!(*elem); types::put_uint(self.endianness, &value, width, self.span) } None => { if let Some(ast::DeclDesc::Enum { width, .. }) = decl.map(|decl| &decl.desc) { let element_type = types::Integer::new(*width); types::put_uint( self.endianness, "e!(#element_type::from(elem)), *width, self.span, ) } else { quote! { elem.write_to(#span)? } } } }; let packet_name = self.packet_name; let name = id; let id = id.to_ident(); if let Some(padding_size) = padding_size { let padding_octets = padding_size / 8; let element_width = match &width { Some(width) => Some(*width), None => self.schema.decl_size(decl.unwrap().key).static_(), }; let array_size = match element_width { Some(element_width) => { let element_size = proc_macro2::Literal::usize_unsuffixed(element_width / 8); quote! { self.#id.len() * #element_size } } _ => { quote! { self.#id.iter().fold(0, |size, elem| size + elem.get_size()) } } }; self.code.push(quote! { let array_size = #array_size; if array_size > #padding_octets { return Err(EncodeError::SizeOverflow { packet: #packet_name, field: #name, size: array_size, maximum_size: #padding_octets, }) } for elem in &self.#id { #serialize; } #span.put_bytes(0, #padding_octets - array_size); }); } else { self.code.push(quote! { for elem in &self.#id { #serialize; } }); } } fn add_typedef_field(&mut self, id: &str, type_id: &str) { assert_eq!(self.shift, 0, "Typedef field does not start on an octet boundary"); let decl = self.scope.typedef[type_id]; if let ast::DeclDesc::Struct { parent_id: Some(_), .. } = &decl.desc { panic!("Derived struct used in typedef field"); } let id = id.to_ident(); let span = format_ident!("{}", self.span); self.code.push(match &decl.desc { ast::DeclDesc::Checksum { .. } => todo!(), ast::DeclDesc::CustomField { width: Some(width), .. } => { let backing_type = types::Integer::new(*width); let put_uint = types::put_uint( self.endianness, "e! { #backing_type::from(self.#id) }, *width, self.span, ); quote! { #put_uint; } } ast::DeclDesc::Struct { .. } => quote! { self.#id.write_to(#span)?; }, _ => unreachable!(), }); } fn add_payload_field(&mut self) { if self.shift != 0 && self.endianness == ast::EndiannessValue::BigEndian { panic!("Payload field does not start on an octet boundary"); } let decl = self.scope.typedef[self.packet_name]; let is_packet = matches!(&decl.desc, ast::DeclDesc::Packet { .. }); let child_ids = self .scope .iter_children(decl) .map(|child| child.id().unwrap().to_ident()) .collect::>(); let span = format_ident!("{}", self.span); if self.shift == 0 { if is_packet { let packet_data_child = format_ident!("{}DataChild", self.packet_name); self.code.push(quote! { match &self.child { #(#packet_data_child::#child_ids(child) => child.write_to(#span)?,)* #packet_data_child::Payload(payload) => #span.put_slice(payload), #packet_data_child::None => {}, } }) } else { self.code.push(quote! { #span.put_slice(&self.payload); }); } } else { todo!("Shifted payloads"); } } } impl quote::ToTokens for FieldSerializer<'_> { fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { let code = &self.code; tokens.extend(quote! { #(#code)* }); } }