• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use mls_rs::client_builder::Preferences;
6 use mls_rs::group::{ReceivedMessage, StateUpdate};
7 use mls_rs::{CipherSuite, ExtensionList, Group, MlsMessage, ProtocolVersion};
8 
9 use crate::test_client::{generate_client, TestClientConfig};
10 
11 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
12 pub struct TestCase {
13     pub cipher_suite: u16,
14 
15     pub external_psks: Vec<TestExternalPsk>,
16     #[serde(with = "hex::serde")]
17     pub key_package: Vec<u8>,
18     #[serde(with = "hex::serde")]
19     pub signature_priv: Vec<u8>,
20     #[serde(with = "hex::serde")]
21     pub encryption_priv: Vec<u8>,
22     #[serde(with = "hex::serde")]
23     pub init_priv: Vec<u8>,
24 
25     #[serde(with = "hex::serde")]
26     pub welcome: Vec<u8>,
27     pub ratchet_tree: Option<TestRatchetTree>,
28     #[serde(with = "hex::serde")]
29     pub initial_epoch_authenticator: Vec<u8>,
30 
31     pub epochs: Vec<TestEpoch>,
32 }
33 
34 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
35 pub struct TestExternalPsk {
36     #[serde(with = "hex::serde")]
37     pub psk_id: Vec<u8>,
38     #[serde(with = "hex::serde")]
39     pub psk: Vec<u8>,
40 }
41 
42 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
43 pub struct TestEpoch {
44     pub proposals: Vec<TestMlsMessage>,
45     #[serde(with = "hex::serde")]
46     pub commit: Vec<u8>,
47     #[serde(with = "hex::serde")]
48     pub epoch_authenticator: Vec<u8>,
49 }
50 
51 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
52 pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec<u8>);
53 
54 #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
55 pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec<u8>);
56 
57 impl TestEpoch {
new( proposals: Vec<MlsMessage>, commit: &MlsMessage, epoch_authenticator: Vec<u8>, ) -> Self58     pub fn new(
59         proposals: Vec<MlsMessage>,
60         commit: &MlsMessage,
61         epoch_authenticator: Vec<u8>,
62     ) -> Self {
63         let proposals = proposals
64             .into_iter()
65             .map(|p| TestMlsMessage(p.to_bytes().unwrap()))
66             .collect();
67 
68         Self {
69             proposals,
70             commit: commit.to_bytes().unwrap(),
71             epoch_authenticator,
72         }
73     }
74 }
75 
76 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_groups( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, num_participants: usize, preferences: Preferences, ) -> Vec<Group<TestClientConfig>>77 pub async fn get_test_groups(
78     protocol_version: ProtocolVersion,
79     cipher_suite: CipherSuite,
80     num_participants: usize,
81     preferences: Preferences,
82 ) -> Vec<Group<TestClientConfig>> {
83     // Create the group with Alice as the group initiator
84     let creator = generate_client(cipher_suite, b"alice".to_vec(), preferences.clone());
85 
86     let mut creator_group = creator
87         .client
88         .create_group_with_id(
89             protocol_version,
90             cipher_suite,
91             b"group".to_vec(),
92             creator.identity,
93             ExtensionList::default(),
94         )
95         .await
96         .unwrap();
97 
98     // Generate random clients that will be members of the group
99     let receiver_clients = (0..num_participants - 1)
100         .map(|i| {
101             generate_client(
102                 cipher_suite,
103                 format!("bob{i}").into_bytes(),
104                 preferences.clone(),
105             )
106         })
107         .collect::<Vec<_>>();
108 
109     let mut receiver_keys = Vec::new();
110 
111     for client in &receiver_clients {
112         let keys = client
113             .client
114             .generate_key_package_message(protocol_version, cipher_suite, client.identity.clone())
115             .await
116             .unwrap();
117 
118         receiver_keys.push(keys);
119     }
120 
121     // Add the generated clients to the group the creator made
122     let mut commit_builder = creator_group.commit_builder();
123 
124     for key in &receiver_keys {
125         commit_builder = commit_builder.add_member(key.clone()).unwrap();
126     }
127 
128     let welcome = commit_builder.build().await.unwrap().welcome_message;
129 
130     // Creator can confirm the commit was processed by the server
131     #[cfg(feature = "state_update")]
132     {
133         let commit_description = creator_group.apply_pending_commit().await.unwrap();
134 
135         assert!(commit_description.state_update.is_active());
136         assert_eq!(commit_description.state_update.new_epoch(), 1);
137     }
138 
139     #[cfg(not(feature = "state_update"))]
140     creator_group.apply_pending_commit().await.unwrap();
141 
142     for client in &receiver_clients {
143         let res = creator_group
144             .member_with_identity(client.identity.credential.as_basic().unwrap().identifier())
145             .await;
146 
147         assert!(res.is_ok());
148     }
149 
150     #[cfg(feature = "state_update")]
151     assert!(commit_description
152         .state_update
153         .roster_update()
154         .removed()
155         .is_empty());
156 
157     // Export the tree for receivers
158     let tree_data = creator_group.export_tree().unwrap();
159 
160     // All the receivers will be able to join the group
161     let mut receiver_groups = Vec::new();
162 
163     for client in &receiver_clients {
164         let test_client = client
165             .client
166             .join_group(Some(&tree_data), welcome.clone().unwrap())
167             .await
168             .unwrap()
169             .0;
170 
171         receiver_groups.push(test_client);
172     }
173 
174     for one_receiver in &receiver_groups {
175         assert!(Group::equal_group_state(&creator_group, one_receiver));
176     }
177 
178     receiver_groups.insert(0, creator_group);
179 
180     receiver_groups
181 }
182 
183 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
all_process_commit_with_update( groups: &mut [Group<TestClientConfig>], commit: &MlsMessage, sender: usize, ) -> Vec<StateUpdate>184 pub async fn all_process_commit_with_update(
185     groups: &mut [Group<TestClientConfig>],
186     commit: &MlsMessage,
187     sender: usize,
188 ) -> Vec<StateUpdate> {
189     let mut state_updates = Vec::new();
190 
191     for g in groups {
192         let state_update = if sender != g.current_member_index() as usize {
193             let processed_msg = g.process_incoming_message(commit.clone()).await.unwrap();
194 
195             match processed_msg {
196                 ReceivedMessage::Commit(update) => update.state_update,
197                 _ => panic!("Expected commit, got {processed_msg:?}"),
198             }
199         } else {
200             g.apply_pending_commit().await.unwrap().state_update
201         };
202 
203         state_updates.push(state_update);
204     }
205 
206     state_updates
207 }
208 
209 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
all_process_message( groups: &mut [Group<TestClientConfig>], message: &MlsMessage, sender: usize, is_commit: bool, )210 pub async fn all_process_message(
211     groups: &mut [Group<TestClientConfig>],
212     message: &MlsMessage,
213     sender: usize,
214     is_commit: bool,
215 ) {
216     for group in groups {
217         if sender != group.current_member_index() as usize {
218             group
219                 .process_incoming_message(message.clone())
220                 .await
221                 .unwrap();
222         } else if is_commit {
223             group.apply_pending_commit().await.unwrap();
224         }
225     }
226 }
227 
228 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
add_random_members( first_id: usize, num_added: usize, committer: usize, groups: &mut Vec<Group<TestClientConfig>>, test_case: Option<&mut TestCase>, )229 pub async fn add_random_members(
230     first_id: usize,
231     num_added: usize,
232     committer: usize,
233     groups: &mut Vec<Group<TestClientConfig>>,
234     test_case: Option<&mut TestCase>,
235 ) {
236     let cipher_suite = groups[committer].cipher_suite();
237     let committer_index = groups[committer].current_member_index() as usize;
238 
239     let mut key_packages = Vec::new();
240     let mut new_clients = Vec::new();
241 
242     for i in 0..num_added {
243         let id = first_id + i;
244         let new_client = generate_client(
245             cipher_suite,
246             format!("dave-{id}").into(),
247             Preferences::default(),
248         );
249 
250         let key_package = new_client
251             .client
252             .generate_key_package_message(
253                 ProtocolVersion::MLS_10,
254                 cipher_suite,
255                 new_client.identity.clone(),
256             )
257             .await
258             .unwrap();
259 
260         key_packages.push(key_package);
261         new_clients.push(new_client);
262     }
263 
264     let committer_group = &mut groups[committer];
265     let mut commit = committer_group.commit_builder();
266 
267     for key_package in key_packages {
268         commit = commit.add_member(key_package).unwrap();
269     }
270 
271     let commit_output = commit.build().await.unwrap();
272 
273     all_process_message(groups, &commit_output.commit_message, committer_index, true).await;
274 
275     let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
276     let epoch = TestEpoch::new(vec![], &commit_output.commit_message, auth);
277 
278     if let Some(tc) = test_case {
279         tc.epochs.push(epoch)
280     };
281 
282     let tree_data = groups[committer].export_tree().unwrap();
283 
284     let mut new_groups = Vec::new();
285 
286     for client in &new_clients {
287         let tree_data = tree_data.clone();
288         let commit = commit_output.welcome_message.clone().unwrap();
289 
290         let client = client
291             .client
292             .join_group(Some(&tree_data.clone()), commit)
293             .await
294             .unwrap()
295             .0;
296 
297         new_groups.push(client);
298     }
299 
300     groups.append(&mut new_groups);
301 }
302 
303 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
remove_members( removed_members: Vec<usize>, committer: usize, groups: &mut Vec<Group<TestClientConfig>>, test_case: Option<&mut TestCase>, )304 pub async fn remove_members(
305     removed_members: Vec<usize>,
306     committer: usize,
307     groups: &mut Vec<Group<TestClientConfig>>,
308     test_case: Option<&mut TestCase>,
309 ) {
310     let remove_indexes = removed_members
311         .iter()
312         .map(|removed| groups[*removed].current_member_index())
313         .collect::<Vec<u32>>();
314 
315     let mut commit_builder = groups[committer].commit_builder();
316 
317     for index in remove_indexes {
318         commit_builder = commit_builder.remove_member(index).unwrap();
319     }
320 
321     let commit = commit_builder.build().await.unwrap().commit_message;
322     let committer_index = groups[committer].current_member_index() as usize;
323     all_process_message(groups, &commit, committer_index, true).await;
324 
325     let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
326     let epoch = TestEpoch::new(vec![], &commit, auth);
327 
328     if let Some(tc) = test_case {
329         tc.epochs.push(epoch)
330     };
331 
332     let mut index = 0;
333 
334     groups.retain(|_| {
335         index += 1;
336         !(removed_members.contains(&(index - 1)))
337     });
338 }
339