//! Support for deriving the `Sequence` trait on structs for the purposes of //! decoding/encoding ASN.1 `SEQUENCE` types as mapped to struct fields. mod field; use crate::{default_lifetime, TypeAttrs}; use field::SequenceField; use proc_macro2::TokenStream; use quote::quote; use syn::{DeriveInput, GenericParam, Generics, Ident, LifetimeParam}; /// Derive the `Sequence` trait for a struct pub(crate) struct DeriveSequence { /// Name of the sequence struct. ident: Ident, /// Generics of the struct. generics: Generics, /// Fields of the struct. fields: Vec, } impl DeriveSequence { /// Parse [`DeriveInput`]. pub fn new(input: DeriveInput) -> syn::Result { let data = match input.data { syn::Data::Struct(data) => data, _ => abort!( input.ident, "can't derive `Sequence` on this type: only `struct` types are allowed", ), }; let type_attrs = TypeAttrs::parse(&input.attrs)?; let fields = data .fields .iter() .map(|field| SequenceField::new(field, &type_attrs)) .collect::>()?; Ok(Self { ident: input.ident, generics: input.generics.clone(), fields, }) } /// Lower the derived output into a [`TokenStream`]. pub fn to_tokens(&self) -> TokenStream { let ident = &self.ident; let mut generics = self.generics.clone(); // Use the first lifetime parameter as lifetime for Decode/Encode lifetime // if none found, add one. let lifetime = generics .lifetimes() .next() .map(|lt| lt.lifetime.clone()) .unwrap_or_else(|| { let lt = default_lifetime(); generics .params .insert(0, GenericParam::Lifetime(LifetimeParam::new(lt.clone()))); lt }); // We may or may not have inserted a lifetime. let (_, ty_generics, where_clause) = self.generics.split_for_impl(); let (impl_generics, _, _) = generics.split_for_impl(); let mut decode_body = Vec::new(); let mut decode_result = Vec::new(); let mut encoded_lengths = Vec::new(); let mut encode_fields = Vec::new(); for field in &self.fields { decode_body.push(field.to_decode_tokens()); decode_result.push(&field.ident); let field = field.to_encode_tokens(); encoded_lengths.push(quote!(#field.encoded_len()?)); encode_fields.push(quote!(#field.encode(writer)?;)); } quote! { impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause { fn decode_value>( reader: &mut R, header: ::der::Header, ) -> ::der::Result { use ::der::{Decode as _, DecodeValue as _, Reader as _}; reader.read_nested(header.length, |reader| { #(#decode_body)* Ok(Self { #(#decode_result),* }) }) } } impl #impl_generics ::der::EncodeValue for #ident #ty_generics #where_clause { fn value_len(&self) -> ::der::Result<::der::Length> { use ::der::Encode as _; [ #(#encoded_lengths),* ] .into_iter() .try_fold(::der::Length::ZERO, |acc, len| acc + len) } fn encode_value(&self, writer: &mut impl ::der::Writer) -> ::der::Result<()> { use ::der::Encode as _; #(#encode_fields)* Ok(()) } } impl #impl_generics ::der::Sequence<#lifetime> for #ident #ty_generics #where_clause {} } } } #[cfg(test)] mod tests { use super::DeriveSequence; use crate::{Asn1Type, TagMode}; use syn::parse_quote; /// X.509 SPKI `AlgorithmIdentifier`. #[test] fn algorithm_identifier_example() { let input = parse_quote! { #[derive(Sequence)] pub struct AlgorithmIdentifier<'a> { pub algorithm: ObjectIdentifier, pub parameters: Option>, } }; let ir = DeriveSequence::new(input).unwrap(); assert_eq!(ir.ident, "AlgorithmIdentifier"); assert_eq!( ir.generics.lifetimes().next().unwrap().lifetime.to_string(), "'a" ); assert_eq!(ir.fields.len(), 2); let algorithm_field = &ir.fields[0]; assert_eq!(algorithm_field.ident, "algorithm"); assert_eq!(algorithm_field.attrs.asn1_type, None); assert_eq!(algorithm_field.attrs.context_specific, None); assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); let parameters_field = &ir.fields[1]; assert_eq!(parameters_field.ident, "parameters"); assert_eq!(parameters_field.attrs.asn1_type, None); assert_eq!(parameters_field.attrs.context_specific, None); assert_eq!(parameters_field.attrs.tag_mode, TagMode::Explicit); } /// X.509 `SubjectPublicKeyInfo`. #[test] fn spki_example() { let input = parse_quote! { #[derive(Sequence)] pub struct SubjectPublicKeyInfo<'a> { pub algorithm: AlgorithmIdentifier<'a>, #[asn1(type = "BIT STRING")] pub subject_public_key: &'a [u8], } }; let ir = DeriveSequence::new(input).unwrap(); assert_eq!(ir.ident, "SubjectPublicKeyInfo"); assert_eq!( ir.generics.lifetimes().next().unwrap().lifetime.to_string(), "'a" ); assert_eq!(ir.fields.len(), 2); let algorithm_field = &ir.fields[0]; assert_eq!(algorithm_field.ident, "algorithm"); assert_eq!(algorithm_field.attrs.asn1_type, None); assert_eq!(algorithm_field.attrs.context_specific, None); assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); let subject_public_key_field = &ir.fields[1]; assert_eq!(subject_public_key_field.ident, "subject_public_key"); assert_eq!( subject_public_key_field.attrs.asn1_type, Some(Asn1Type::BitString) ); assert_eq!(subject_public_key_field.attrs.context_specific, None); assert_eq!(subject_public_key_field.attrs.tag_mode, TagMode::Explicit); } /// PKCS#8v2 `OneAsymmetricKey`. /// /// ```text /// OneAsymmetricKey ::= SEQUENCE { /// version Version, /// privateKeyAlgorithm PrivateKeyAlgorithmIdentifier, /// privateKey PrivateKey, /// attributes [0] Attributes OPTIONAL, /// ..., /// [[2: publicKey [1] PublicKey OPTIONAL ]], /// ... /// } /// /// Version ::= INTEGER { v1(0), v2(1) } (v1, ..., v2) /// /// PrivateKeyAlgorithmIdentifier ::= AlgorithmIdentifier /// /// PrivateKey ::= OCTET STRING /// /// Attributes ::= SET OF Attribute /// /// PublicKey ::= BIT STRING /// ``` #[test] fn pkcs8_example() { let input = parse_quote! { #[derive(Sequence)] pub struct OneAsymmetricKey<'a> { pub version: u8, pub private_key_algorithm: AlgorithmIdentifier<'a>, #[asn1(type = "OCTET STRING")] pub private_key: &'a [u8], #[asn1(context_specific = "0", extensible = "true", optional = "true")] pub attributes: Option, 1>>, #[asn1( context_specific = "1", extensible = "true", optional = "true", type = "BIT STRING" )] pub public_key: Option<&'a [u8]>, } }; let ir = DeriveSequence::new(input).unwrap(); assert_eq!(ir.ident, "OneAsymmetricKey"); assert_eq!( ir.generics.lifetimes().next().unwrap().lifetime.to_string(), "'a" ); assert_eq!(ir.fields.len(), 5); let version_field = &ir.fields[0]; assert_eq!(version_field.ident, "version"); assert_eq!(version_field.attrs.asn1_type, None); assert_eq!(version_field.attrs.context_specific, None); assert_eq!(version_field.attrs.extensible, false); assert_eq!(version_field.attrs.optional, false); assert_eq!(version_field.attrs.tag_mode, TagMode::Explicit); let algorithm_field = &ir.fields[1]; assert_eq!(algorithm_field.ident, "private_key_algorithm"); assert_eq!(algorithm_field.attrs.asn1_type, None); assert_eq!(algorithm_field.attrs.context_specific, None); assert_eq!(algorithm_field.attrs.extensible, false); assert_eq!(algorithm_field.attrs.optional, false); assert_eq!(algorithm_field.attrs.tag_mode, TagMode::Explicit); let private_key_field = &ir.fields[2]; assert_eq!(private_key_field.ident, "private_key"); assert_eq!( private_key_field.attrs.asn1_type, Some(Asn1Type::OctetString) ); assert_eq!(private_key_field.attrs.context_specific, None); assert_eq!(private_key_field.attrs.extensible, false); assert_eq!(private_key_field.attrs.optional, false); assert_eq!(private_key_field.attrs.tag_mode, TagMode::Explicit); let attributes_field = &ir.fields[3]; assert_eq!(attributes_field.ident, "attributes"); assert_eq!(attributes_field.attrs.asn1_type, None); assert_eq!( attributes_field.attrs.context_specific, Some("0".parse().unwrap()) ); assert_eq!(attributes_field.attrs.extensible, true); assert_eq!(attributes_field.attrs.optional, true); assert_eq!(attributes_field.attrs.tag_mode, TagMode::Explicit); let public_key_field = &ir.fields[4]; assert_eq!(public_key_field.ident, "public_key"); assert_eq!(public_key_field.attrs.asn1_type, Some(Asn1Type::BitString)); assert_eq!( public_key_field.attrs.context_specific, Some("1".parse().unwrap()) ); assert_eq!(public_key_field.attrs.extensible, true); assert_eq!(public_key_field.attrs.optional, true); assert_eq!(public_key_field.attrs.tag_mode, TagMode::Explicit); } /// `IMPLICIT` tagged example #[test] fn implicit_example() { let input = parse_quote! { #[asn1(tag_mode = "IMPLICIT")] pub struct ImplicitSequence<'a> { #[asn1(context_specific = "0", type = "BIT STRING")] bit_string: BitString<'a>, #[asn1(context_specific = "1", type = "GeneralizedTime")] time: GeneralizedTime, #[asn1(context_specific = "2", type = "UTF8String")] utf8_string: String, } }; let ir = DeriveSequence::new(input).unwrap(); assert_eq!(ir.ident, "ImplicitSequence"); assert_eq!( ir.generics.lifetimes().next().unwrap().lifetime.to_string(), "'a" ); assert_eq!(ir.fields.len(), 3); let bit_string = &ir.fields[0]; assert_eq!(bit_string.ident, "bit_string"); assert_eq!(bit_string.attrs.asn1_type, Some(Asn1Type::BitString)); assert_eq!( bit_string.attrs.context_specific, Some("0".parse().unwrap()) ); assert_eq!(bit_string.attrs.tag_mode, TagMode::Implicit); let time = &ir.fields[1]; assert_eq!(time.ident, "time"); assert_eq!(time.attrs.asn1_type, Some(Asn1Type::GeneralizedTime)); assert_eq!(time.attrs.context_specific, Some("1".parse().unwrap())); assert_eq!(time.attrs.tag_mode, TagMode::Implicit); let utf8_string = &ir.fields[2]; assert_eq!(utf8_string.ident, "utf8_string"); assert_eq!(utf8_string.attrs.asn1_type, Some(Asn1Type::Utf8String)); assert_eq!( utf8_string.attrs.context_specific, Some("2".parse().unwrap()) ); assert_eq!(utf8_string.attrs.tag_mode, TagMode::Implicit); } }