// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use mls_rs::client_builder::Preferences; use mls_rs::group::{ReceivedMessage, StateUpdate}; use mls_rs::{CipherSuite, ExtensionList, Group, MlsMessage, ProtocolVersion}; use crate::test_client::{generate_client, TestClientConfig}; #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] pub struct TestCase { pub cipher_suite: u16, pub external_psks: Vec, #[serde(with = "hex::serde")] pub key_package: Vec, #[serde(with = "hex::serde")] pub signature_priv: Vec, #[serde(with = "hex::serde")] pub encryption_priv: Vec, #[serde(with = "hex::serde")] pub init_priv: Vec, #[serde(with = "hex::serde")] pub welcome: Vec, pub ratchet_tree: Option, #[serde(with = "hex::serde")] pub initial_epoch_authenticator: Vec, pub epochs: Vec, } #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] pub struct TestExternalPsk { #[serde(with = "hex::serde")] pub psk_id: Vec, #[serde(with = "hex::serde")] pub psk: Vec, } #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] pub struct TestEpoch { pub proposals: Vec, #[serde(with = "hex::serde")] pub commit: Vec, #[serde(with = "hex::serde")] pub epoch_authenticator: Vec, } #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec); #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)] pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec); impl TestEpoch { pub fn new( proposals: Vec, commit: &MlsMessage, epoch_authenticator: Vec, ) -> Self { let proposals = proposals .into_iter() .map(|p| TestMlsMessage(p.to_bytes().unwrap())) .collect(); Self { proposals, commit: commit.to_bytes().unwrap(), epoch_authenticator, } } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn get_test_groups( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, num_participants: usize, preferences: Preferences, ) -> Vec> { // Create the group with Alice as the group initiator let creator = generate_client(cipher_suite, b"alice".to_vec(), preferences.clone()); let mut creator_group = creator .client .create_group_with_id( protocol_version, cipher_suite, b"group".to_vec(), creator.identity, ExtensionList::default(), ) .await .unwrap(); // Generate random clients that will be members of the group let receiver_clients = (0..num_participants - 1) .map(|i| { generate_client( cipher_suite, format!("bob{i}").into_bytes(), preferences.clone(), ) }) .collect::>(); let mut receiver_keys = Vec::new(); for client in &receiver_clients { let keys = client .client .generate_key_package_message(protocol_version, cipher_suite, client.identity.clone()) .await .unwrap(); receiver_keys.push(keys); } // Add the generated clients to the group the creator made let mut commit_builder = creator_group.commit_builder(); for key in &receiver_keys { commit_builder = commit_builder.add_member(key.clone()).unwrap(); } let welcome = commit_builder.build().await.unwrap().welcome_message; // Creator can confirm the commit was processed by the server #[cfg(feature = "state_update")] { let commit_description = creator_group.apply_pending_commit().await.unwrap(); assert!(commit_description.state_update.is_active()); assert_eq!(commit_description.state_update.new_epoch(), 1); } #[cfg(not(feature = "state_update"))] creator_group.apply_pending_commit().await.unwrap(); for client in &receiver_clients { let res = creator_group .member_with_identity(client.identity.credential.as_basic().unwrap().identifier()) .await; assert!(res.is_ok()); } #[cfg(feature = "state_update")] assert!(commit_description .state_update .roster_update() .removed() .is_empty()); // Export the tree for receivers let tree_data = creator_group.export_tree().unwrap(); // All the receivers will be able to join the group let mut receiver_groups = Vec::new(); for client in &receiver_clients { let test_client = client .client .join_group(Some(&tree_data), welcome.clone().unwrap()) .await .unwrap() .0; receiver_groups.push(test_client); } for one_receiver in &receiver_groups { assert!(Group::equal_group_state(&creator_group, one_receiver)); } receiver_groups.insert(0, creator_group); receiver_groups } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn all_process_commit_with_update( groups: &mut [Group], commit: &MlsMessage, sender: usize, ) -> Vec { let mut state_updates = Vec::new(); for g in groups { let state_update = if sender != g.current_member_index() as usize { let processed_msg = g.process_incoming_message(commit.clone()).await.unwrap(); match processed_msg { ReceivedMessage::Commit(update) => update.state_update, _ => panic!("Expected commit, got {processed_msg:?}"), } } else { g.apply_pending_commit().await.unwrap().state_update }; state_updates.push(state_update); } state_updates } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn all_process_message( groups: &mut [Group], message: &MlsMessage, sender: usize, is_commit: bool, ) { for group in groups { if sender != group.current_member_index() as usize { group .process_incoming_message(message.clone()) .await .unwrap(); } else if is_commit { group.apply_pending_commit().await.unwrap(); } } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn add_random_members( first_id: usize, num_added: usize, committer: usize, groups: &mut Vec>, test_case: Option<&mut TestCase>, ) { let cipher_suite = groups[committer].cipher_suite(); let committer_index = groups[committer].current_member_index() as usize; let mut key_packages = Vec::new(); let mut new_clients = Vec::new(); for i in 0..num_added { let id = first_id + i; let new_client = generate_client( cipher_suite, format!("dave-{id}").into(), Preferences::default(), ); let key_package = new_client .client .generate_key_package_message( ProtocolVersion::MLS_10, cipher_suite, new_client.identity.clone(), ) .await .unwrap(); key_packages.push(key_package); new_clients.push(new_client); } let committer_group = &mut groups[committer]; let mut commit = committer_group.commit_builder(); for key_package in key_packages { commit = commit.add_member(key_package).unwrap(); } let commit_output = commit.build().await.unwrap(); all_process_message(groups, &commit_output.commit_message, committer_index, true).await; let auth = groups[committer].epoch_authenticator().unwrap().to_vec(); let epoch = TestEpoch::new(vec![], &commit_output.commit_message, auth); if let Some(tc) = test_case { tc.epochs.push(epoch) }; let tree_data = groups[committer].export_tree().unwrap(); let mut new_groups = Vec::new(); for client in &new_clients { let tree_data = tree_data.clone(); let commit = commit_output.welcome_message.clone().unwrap(); let client = client .client .join_group(Some(&tree_data.clone()), commit) .await .unwrap() .0; new_groups.push(client); } groups.append(&mut new_groups); } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn remove_members( removed_members: Vec, committer: usize, groups: &mut Vec>, test_case: Option<&mut TestCase>, ) { let remove_indexes = removed_members .iter() .map(|removed| groups[*removed].current_member_index()) .collect::>(); let mut commit_builder = groups[committer].commit_builder(); for index in remove_indexes { commit_builder = commit_builder.remove_member(index).unwrap(); } let commit = commit_builder.build().await.unwrap().commit_message; let committer_index = groups[committer].current_member_index() as usize; all_process_message(groups, &commit, committer_index, true).await; let auth = groups[committer].epoch_authenticator().unwrap().to_vec(); let epoch = TestEpoch::new(vec![], &commit, auth); if let Some(tc) = test_case { tc.epochs.push(epoch) }; let mut index = 0; groups.retain(|_| { index += 1; !(removed_members.contains(&(index - 1))) }); }