// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use core::{ fmt::{self, Debug}, ops::Deref, }; use crate::error::{AnyError, IntoAnyError}; use alloc::vec::Vec; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; mod list; pub use list::*; /// Wrapper type representing an extension identifier along with default values /// defined by the MLS RFC. #[derive( Debug, PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode, )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[repr(transparent)] pub struct ExtensionType(u16); #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] impl ExtensionType { pub const APPLICATION_ID: ExtensionType = ExtensionType(1); pub const RATCHET_TREE: ExtensionType = ExtensionType(2); pub const REQUIRED_CAPABILITIES: ExtensionType = ExtensionType(3); pub const EXTERNAL_PUB: ExtensionType = ExtensionType(4); pub const EXTERNAL_SENDERS: ExtensionType = ExtensionType(5); /// Default extension types defined /// in [RFC 9420](https://www.rfc-editor.org/rfc/rfc9420.html#name-leaf-node-contents) pub const DEFAULT: &'static [ExtensionType] = &[ ExtensionType::APPLICATION_ID, ExtensionType::RATCHET_TREE, ExtensionType::REQUIRED_CAPABILITIES, ExtensionType::EXTERNAL_PUB, ExtensionType::EXTERNAL_SENDERS, ]; /// Extension type from a raw value pub const fn new(raw_value: u16) -> Self { ExtensionType(raw_value) } /// Raw numerical wrapped value. pub const fn raw_value(&self) -> u16 { self.0 } /// Determines if this extension type is required to be implemented /// by the MLS RFC. pub const fn is_default(&self) -> bool { self.0 <= 5 } } impl From for ExtensionType { fn from(value: u16) -> Self { ExtensionType(value) } } impl Deref for ExtensionType { type Target = u16; fn deref(&self) -> &Self::Target { &self.0 } } #[derive(Debug)] #[cfg_attr(feature = "std", derive(thiserror::Error))] pub enum ExtensionError { #[cfg_attr(feature = "std", error(transparent))] SerializationError(AnyError), #[cfg_attr(feature = "std", error(transparent))] DeserializationError(AnyError), #[cfg_attr(feature = "std", error("incorrect extension type: {0:?}"))] IncorrectType(ExtensionType), } impl IntoAnyError for ExtensionError { #[cfg(feature = "std")] fn into_dyn_error(self) -> Result, Self> { Ok(self.into()) } } #[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr( all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(clone, opaque) )] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] /// An MLS protocol [extension](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-extensions). /// /// Extensions are used as customization points in various parts of the /// MLS protocol and are inserted into an [ExtensionList](self::ExtensionList). pub struct Extension { /// Extension type of this extension pub extension_type: ExtensionType, /// Data held within this extension #[mls_codec(with = "mls_rs_codec::byte_vec")] #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))] pub extension_data: Vec, } impl Debug for Extension { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Extension") .field("extension_type", &self.extension_type) .field( "extension_data", &crate::debug::pretty_bytes(&self.extension_data), ) .finish() } } #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)] impl Extension { /// Create an extension with specified type and data properties. pub fn new(extension_type: ExtensionType, extension_data: Vec) -> Extension { Extension { extension_type, extension_data, } } /// Extension type of this extension #[cfg(feature = "ffi")] pub fn extension_type(&self) -> ExtensionType { self.extension_type } /// Data held within this extension #[cfg(feature = "ffi")] pub fn extension_data(&self) -> &[u8] { &self.extension_data } } /// Trait used to convert a type to and from an [Extension] pub trait MlsExtension: Sized { /// Error type of the underlying serializer that can convert this type into a `Vec`. type SerializationError: IntoAnyError; /// Error type of the underlying deserializer that can convert a `Vec` into this type. type DeserializationError: IntoAnyError; /// Extension type value that this type represents. fn extension_type() -> ExtensionType; /// Convert this type to opaque bytes. fn to_bytes(&self) -> Result, Self::SerializationError>; /// Create this type from opaque bytes. fn from_bytes(data: &[u8]) -> Result; /// Convert this type into an [Extension]. fn into_extension(self) -> Result { Ok(Extension::new( Self::extension_type(), self.to_bytes() .map_err(|e| ExtensionError::SerializationError(e.into_any_error()))?, )) } /// Create this type from an [Extension]. fn from_extension(ext: &Extension) -> Result { if ext.extension_type != Self::extension_type() { return Err(ExtensionError::IncorrectType(ext.extension_type)); } Self::from_bytes(&ext.extension_data) .map_err(|e| ExtensionError::DeserializationError(e.into_any_error())) } } /// Convenience trait for custom extension types that use /// [mls_rs_codec] as an underlying serialization mechanism pub trait MlsCodecExtension: MlsSize + MlsEncode + MlsDecode { fn extension_type() -> ExtensionType; } impl MlsExtension for T where T: MlsCodecExtension, { type SerializationError = mls_rs_codec::Error; type DeserializationError = mls_rs_codec::Error; fn extension_type() -> ExtensionType { ::extension_type() } fn to_bytes(&self) -> Result, Self::SerializationError> { self.mls_encode_to_vec() } fn from_bytes(data: &[u8]) -> Result { Self::mls_decode(&mut &*data) } } #[cfg(test)] mod tests { use core::convert::Infallible; use alloc::vec; use alloc::vec::Vec; use assert_matches::assert_matches; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use super::{Extension, ExtensionError, ExtensionType, MlsCodecExtension, MlsExtension}; struct TestExtension; #[derive(Debug, MlsSize, MlsEncode, MlsDecode)] struct AnotherTestExtension; impl MlsExtension for TestExtension { type SerializationError = Infallible; type DeserializationError = Infallible; fn extension_type() -> super::ExtensionType { ExtensionType(42) } fn to_bytes(&self) -> Result, Self::SerializationError> { Ok(vec![0]) } fn from_bytes(_data: &[u8]) -> Result { Ok(TestExtension) } } impl MlsCodecExtension for AnotherTestExtension { fn extension_type() -> ExtensionType { ExtensionType(43) } } #[test] fn into_extension() { assert_eq!( TestExtension.into_extension().unwrap(), Extension::new(42.into(), vec![0]) ) } #[test] fn incorrect_type_is_discovered() { let ext = Extension::new(42.into(), vec![0]); assert_matches!(AnotherTestExtension::from_extension(&ext), Err(ExtensionError::IncorrectType(found)) if found == 42.into()); } }