1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use core::ops::{Deref, DerefMut};
6
7 use alloc::format;
8 use rand::RngCore;
9
10 use super::*;
11 use crate::{
12 client::{
13 test_utils::{
14 test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
15 TEST_PROTOCOL_VERSION,
16 },
17 MlsError,
18 },
19 client_builder::test_utils::{TestClientBuilder, TestClientConfig},
20 crypto::test_utils::test_cipher_suite_provider,
21 extension::ExtensionType,
22 identity::test_utils::get_test_signing_identity,
23 key_package::{KeyPackageGeneration, KeyPackageGenerator},
24 mls_rules::{CommitOptions, DefaultMlsRules},
25 tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
26 };
27
28 use crate::extension::RequiredCapabilitiesExt;
29
30 #[cfg(not(feature = "by_ref_proposal"))]
31 use crate::crypto::HpkePublicKey;
32
33 pub const TEST_GROUP: &[u8] = b"group";
34
35 #[derive(Clone)]
36 pub(crate) struct TestGroup {
37 pub group: Group<TestClientConfig>,
38 }
39
40 impl TestGroup {
41 #[cfg(feature = "external_client")]
42 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
propose(&mut self, proposal: Proposal) -> MlsMessage43 pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
44 self.group.proposal_message(proposal, vec![]).await.unwrap()
45 }
46
47 #[cfg(feature = "by_ref_proposal")]
48 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_proposal(&mut self) -> Proposal49 pub(crate) async fn update_proposal(&mut self) -> Proposal {
50 self.group.update_proposal(None, None).await.unwrap()
51 }
52
53 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_with_custom_config<F>( &mut self, name: &str, custom_kp: bool, mut config: F, ) -> Result<(TestGroup, MlsMessage), MlsError> where F: FnMut(&mut TestClientConfig),54 pub(crate) async fn join_with_custom_config<F>(
55 &mut self,
56 name: &str,
57 custom_kp: bool,
58 mut config: F,
59 ) -> Result<(TestGroup, MlsMessage), MlsError>
60 where
61 F: FnMut(&mut TestClientConfig),
62 {
63 let (mut new_client, new_key_package) = if custom_kp {
64 test_client_with_key_pkg_custom(
65 self.group.protocol_version(),
66 self.group.cipher_suite(),
67 name,
68 &mut config,
69 )
70 .await
71 } else {
72 test_client_with_key_pkg(
73 self.group.protocol_version(),
74 self.group.cipher_suite(),
75 name,
76 )
77 .await
78 };
79
80 // Add new member to the group
81 let CommitOutput {
82 welcome_messages,
83 ratchet_tree,
84 commit_message,
85 ..
86 } = self
87 .group
88 .commit_builder()
89 .add_member(new_key_package)
90 .unwrap()
91 .build()
92 .await
93 .unwrap();
94
95 // Apply the commit to the original group
96 self.group.apply_pending_commit().await.unwrap();
97
98 config(&mut new_client.config);
99
100 // Group from new member's perspective
101 let (new_group, _) = Group::join(
102 &welcome_messages[0],
103 ratchet_tree,
104 new_client.config.clone(),
105 new_client.signer.clone().unwrap(),
106 )
107 .await?;
108
109 let new_test_group = TestGroup { group: new_group };
110
111 Ok((new_test_group, commit_message))
112 }
113
114 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join(&mut self, name: &str) -> (TestGroup, MlsMessage)115 pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
116 self.join_with_custom_config(name, false, |_| ())
117 .await
118 .unwrap()
119 }
120
121 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_pending_commit( &mut self, ) -> Result<CommitMessageDescription, MlsError>122 pub(crate) async fn process_pending_commit(
123 &mut self,
124 ) -> Result<CommitMessageDescription, MlsError> {
125 self.group.apply_pending_commit().await
126 }
127
128 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_message( &mut self, message: MlsMessage, ) -> Result<ReceivedMessage, MlsError>129 pub(crate) async fn process_message(
130 &mut self,
131 message: MlsMessage,
132 ) -> Result<ReceivedMessage, MlsError> {
133 self.group.process_incoming_message(message).await
134 }
135
136 #[cfg(feature = "private_message")]
137 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_plaintext(&mut self, content: Content) -> MlsMessage138 pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
139 let auth_content = AuthenticatedContent::new_signed(
140 &self.group.cipher_suite_provider,
141 &self.group.state.context,
142 Sender::Member(*self.group.private_tree.self_index),
143 content,
144 &self.group.signer,
145 WireFormat::PublicMessage,
146 Vec::new(),
147 )
148 .await
149 .unwrap();
150
151 self.group.format_for_wire(auth_content).await.unwrap()
152 }
153 }
154
155 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext156 pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
157 let cs = test_cipher_suite_provider(cipher_suite);
158
159 GroupContext {
160 protocol_version: TEST_PROTOCOL_VERSION,
161 cipher_suite,
162 group_id: TEST_GROUP.to_vec(),
163 epoch,
164 tree_hash: cs.hash(&[1, 2, 3]).await.unwrap(),
165 confirmed_transcript_hash: cs.hash(&[3, 2, 1]).await.unwrap().into(),
166 extensions: ExtensionList::from(vec![]),
167 }
168 }
169
/// Builds a `GroupContext` for an arbitrary `group_id` at `epoch`, with an
/// empty tree hash, an empty confirmed transcript hash, and no extensions.
#[cfg(feature = "prior_epoch")]
pub(crate) fn get_test_group_context_with_id(
    group_id: Vec<u8>,
    epoch: u64,
    cipher_suite: CipherSuite,
) -> GroupContext {
    let empty_tree_hash = Vec::new();
    let empty_transcript_hash = ConfirmedTranscriptHash::from(Vec::new());

    GroupContext {
        protocol_version: TEST_PROTOCOL_VERSION,
        cipher_suite,
        group_id,
        epoch,
        tree_hash: empty_tree_hash,
        confirmed_transcript_hash: empty_transcript_hash,
        extensions: ExtensionList::from(vec![]),
    }
}
186
group_extensions() -> ExtensionList187 pub(crate) fn group_extensions() -> ExtensionList {
188 let required_capabilities = RequiredCapabilitiesExt::default();
189
190 let mut extensions = ExtensionList::new();
191 extensions.set_from(required_capabilities).unwrap();
192 extensions
193 }
194
lifetime() -> Lifetime195 pub(crate) fn lifetime() -> Lifetime {
196 Lifetime::years(1).unwrap()
197 }
198
199 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_member( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, identifier: &[u8], ) -> (KeyPackageGeneration, SignatureSecretKey)200 pub(crate) async fn test_member(
201 protocol_version: ProtocolVersion,
202 cipher_suite: CipherSuite,
203 identifier: &[u8],
204 ) -> (KeyPackageGeneration, SignatureSecretKey) {
205 let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
206
207 let key_package_generator = KeyPackageGenerator {
208 protocol_version,
209 cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
210 signing_identity: &signing_identity,
211 signing_key: &signing_key,
212 };
213
214 let key_package = key_package_generator
215 .generate(
216 lifetime(),
217 get_test_capabilities(),
218 ExtensionList::default(),
219 ExtensionList::default(),
220 )
221 .await
222 .unwrap();
223
224 (key_package, signing_key)
225 }
226
227 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, extension_types: Vec<ExtensionType>, leaf_extensions: Option<ExtensionList>, commit_options: Option<CommitOptions>, ) -> TestGroup228 pub(crate) async fn test_group_custom(
229 protocol_version: ProtocolVersion,
230 cipher_suite: CipherSuite,
231 extension_types: Vec<ExtensionType>,
232 leaf_extensions: Option<ExtensionList>,
233 commit_options: Option<CommitOptions>,
234 ) -> TestGroup {
235 let leaf_extensions = leaf_extensions.unwrap_or_default();
236 let commit_options = commit_options.unwrap_or_default();
237
238 let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
239
240 let group = TestClientBuilder::new_for_test()
241 .leaf_node_extensions(leaf_extensions)
242 .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
243 .extension_types(extension_types)
244 .protocol_versions(ProtocolVersion::all())
245 .used_protocol_version(protocol_version)
246 .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
247 .build()
248 .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
249 .await
250 .unwrap();
251
252 TestGroup { group }
253 }
254
255 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, ) -> TestGroup256 pub(crate) async fn test_group(
257 protocol_version: ProtocolVersion,
258 cipher_suite: CipherSuite,
259 ) -> TestGroup {
260 test_group_custom(
261 protocol_version,
262 cipher_suite,
263 Default::default(),
264 None,
265 None,
266 )
267 .await
268 }
269
270 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom_config<F>( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, custom: F, ) -> TestGroup where F: FnOnce(TestClientBuilder) -> TestClientBuilder,271 pub(crate) async fn test_group_custom_config<F>(
272 protocol_version: ProtocolVersion,
273 cipher_suite: CipherSuite,
274 custom: F,
275 ) -> TestGroup
276 where
277 F: FnOnce(TestClientBuilder) -> TestClientBuilder,
278 {
279 let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
280
281 let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
282
283 let group = custom(client_builder)
284 .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
285 .build()
286 .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
287 .await
288 .unwrap();
289
290 TestGroup { group }
291 }
292
293 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_n_member_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, num_members: usize, ) -> Vec<TestGroup>294 pub(crate) async fn test_n_member_group(
295 protocol_version: ProtocolVersion,
296 cipher_suite: CipherSuite,
297 num_members: usize,
298 ) -> Vec<TestGroup> {
299 let group = test_group(protocol_version, cipher_suite).await;
300
301 let mut groups = vec![group];
302
303 for i in 1..num_members {
304 let (new_group, commit) = groups.get_mut(0).unwrap().join(&format!("name {i}")).await;
305 process_commit(&mut groups, commit, 0).await;
306 groups.push(new_group);
307 }
308
309 groups
310 }
311
312 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32)313 pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
314 for g in groups
315 .iter_mut()
316 .filter(|g| g.group.current_member_index() != excluded)
317 {
318 g.process_message(commit.clone()).await.unwrap();
319 }
320 }
321
get_test_25519_key(key_byte: u8) -> HpkePublicKey322 pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
323 vec![key_byte; 32].into()
324 }
325
326 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_groups_with_features( n: usize, extensions: ExtensionList, leaf_extensions: ExtensionList, ) -> Vec<Group<TestClientConfig>>327 pub(crate) async fn get_test_groups_with_features(
328 n: usize,
329 extensions: ExtensionList,
330 leaf_extensions: ExtensionList,
331 ) -> Vec<Group<TestClientConfig>> {
332 let mut clients = Vec::new();
333
334 for i in 0..n {
335 let (identity, secret_key) =
336 get_test_signing_identity(TEST_CIPHER_SUITE, format!("member{i}").as_bytes()).await;
337
338 clients.push(
339 TestClientBuilder::new_for_test()
340 .extension_type(999.into())
341 .leaf_node_extensions(leaf_extensions.clone())
342 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
343 .build(),
344 );
345 }
346
347 let group = clients[0]
348 .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
349 .await
350 .unwrap();
351
352 let mut groups = vec![group];
353
354 for client in clients.iter().skip(1) {
355 let key_package = client.generate_key_package_message().await.unwrap();
356
357 let commit_output = groups[0]
358 .commit_builder()
359 .add_member(key_package)
360 .unwrap()
361 .build()
362 .await
363 .unwrap();
364
365 groups[0].apply_pending_commit().await.unwrap();
366
367 for group in groups.iter_mut().skip(1) {
368 group
369 .process_incoming_message(commit_output.commit_message.clone())
370 .await
371 .unwrap();
372 }
373
374 groups.push(
375 client
376 .join_group(None, &commit_output.welcome_messages[0])
377 .await
378 .unwrap()
379 .0,
380 );
381 }
382
383 groups
384 }
385
random_bytes(count: usize) -> Vec<u8>386 pub fn random_bytes(count: usize) -> Vec<u8> {
387 let mut buf = vec![0; count];
388 rand::thread_rng().fill_bytes(&mut buf);
389 buf
390 }
391
392 pub(crate) struct GroupWithoutKeySchedule {
393 inner: Group<TestClientConfig>,
394 pub secrets: Option<(TreeKemPrivate, PathSecret)>,
395 pub provisional_public_state: Option<ProvisionalState>,
396 }
397
398 impl Deref for GroupWithoutKeySchedule {
399 type Target = Group<TestClientConfig>;
400
401 #[cfg_attr(coverage_nightly, coverage(off))]
deref(&self) -> &Self::Target402 fn deref(&self) -> &Self::Target {
403 &self.inner
404 }
405 }
406
407 impl DerefMut for GroupWithoutKeySchedule {
408 #[cfg_attr(coverage_nightly, coverage(off))]
deref_mut(&mut self) -> &mut Self::Target409 fn deref_mut(&mut self) -> &mut Self::Target {
410 &mut self.inner
411 }
412 }
413
414 #[cfg(feature = "rfc_compliant")]
415 impl GroupWithoutKeySchedule {
416 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
new(cs: CipherSuite) -> Self417 pub async fn new(cs: CipherSuite) -> Self {
418 Self {
419 inner: test_group(TEST_PROTOCOL_VERSION, cs).await.group,
420 secrets: None,
421 provisional_public_state: None,
422 }
423 }
424 }
425
// `MessageProcessor` for the test wrapper. Every hook delegates to the wrapped
// group (`self.inner`) except `update_key_schedule`, which only records what
// it is given.
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
#[cfg_attr(
    all(not(target_arch = "wasm32"), mls_build_async),
    maybe_async::must_be_async
)]
impl MessageProcessor for GroupWithoutKeySchedule {
    // Associated types are copied from the wrapped group's implementation so
    // the delegating methods below type-check unchanged.
    type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
    type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
    type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
    type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
    type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;

    // Delegates to the wrapped group.
    fn group_state(&self) -> &GroupState {
        self.inner.group_state()
    }

    // Delegates to the wrapped group.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn group_state_mut(&mut self) -> &mut GroupState {
        self.inner.group_state_mut()
    }

    // Delegates to the wrapped group.
    fn mls_rules(&self) -> Self::MlsRules {
        self.inner.mls_rules()
    }

    // Delegates to the wrapped group.
    fn identity_provider(&self) -> Self::IdentityProvider {
        self.inner.identity_provider()
    }

    // Delegates to the wrapped group.
    fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
        self.inner.cipher_suite_provider()
    }

    // Delegates to the wrapped group.
    fn psk_storage(&self) -> Self::PreSharedKeyStorage {
        self.inner.psk_storage()
    }

    // Delegates to the wrapped group.
    fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
        self.inner.can_continue_processing(provisional_state)
    }

    // Delegates to the wrapped group.
    #[cfg(feature = "private_message")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn min_epoch_available(&self) -> Option<u64> {
        self.inner.min_epoch_available()
    }

    // Delegates to the wrapped group. The returned secrets are what later
    // reach `update_key_schedule` (presumably - confirm against the trait's
    // default commit-processing flow).
    async fn apply_update_path(
        &mut self,
        sender: LeafIndex,
        update_path: &ValidatedUpdatePath,
        provisional_state: &mut ProvisionalState,
    ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
        self.inner
            .apply_update_path(sender, update_path, provisional_state)
            .await
    }

    // Delegates to the wrapped group.
    #[cfg(feature = "private_message")]
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn process_ciphertext(
        &mut self,
        cipher_text: &PrivateMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
        self.inner.process_ciphertext(cipher_text).await
    }

    // Delegates to the wrapped group.
    #[cfg_attr(coverage_nightly, coverage(off))]
    async fn verify_plaintext_authentication(
        &self,
        message: PublicMessage,
    ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
        self.inner.verify_plaintext_authentication(message).await
    }

    // Intentionally does NOT delegate: instead of updating the wrapped group,
    // it stores `provisional_public_state` and `secrets` in this wrapper's
    // public fields for tests to inspect. It ignores the interim transcript
    // hash and confirmation tag, and always returns `Ok(())`.
    async fn update_key_schedule(
        &mut self,
        secrets: Option<(TreeKemPrivate, PathSecret)>,
        _interim_transcript_hash: InterimTranscriptHash,
        _confirmation_tag: &ConfirmationTag,
        provisional_public_state: ProvisionalState,
    ) -> Result<(), MlsError> {
        self.provisional_public_state = Some(provisional_public_state);
        self.secrets = secrets;
        Ok(())
    }
}
514