• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use core::ops::{Deref, DerefMut};
6 
7 use alloc::format;
8 use rand::RngCore;
9 
10 use super::*;
11 use crate::{
12     client::{
13         test_utils::{
14             test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
15             TEST_PROTOCOL_VERSION,
16         },
17         MlsError,
18     },
19     client_builder::test_utils::{TestClientBuilder, TestClientConfig},
20     crypto::test_utils::test_cipher_suite_provider,
21     extension::ExtensionType,
22     identity::test_utils::get_test_signing_identity,
23     key_package::{KeyPackageGeneration, KeyPackageGenerator},
24     mls_rules::{CommitOptions, DefaultMlsRules},
25     tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
26 };
27 
28 use crate::extension::RequiredCapabilitiesExt;
29 
30 #[cfg(not(feature = "by_ref_proposal"))]
31 use crate::crypto::HpkePublicKey;
32 
/// Group identifier used by the test group helpers in this module.
pub const TEST_GROUP: &[u8] = b"group";
34 
35 #[derive(Clone)]
36 pub(crate) struct TestGroup {
37     pub group: Group<TestClientConfig>,
38 }
39 
40 impl TestGroup {
41     #[cfg(feature = "external_client")]
42     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
propose(&mut self, proposal: Proposal) -> MlsMessage43     pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
44         self.group.proposal_message(proposal, vec![]).await.unwrap()
45     }
46 
47     #[cfg(feature = "by_ref_proposal")]
48     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_proposal(&mut self) -> Proposal49     pub(crate) async fn update_proposal(&mut self) -> Proposal {
50         self.group.update_proposal(None, None).await.unwrap()
51     }
52 
53     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join_with_custom_config<F>( &mut self, name: &str, custom_kp: bool, mut config: F, ) -> Result<(TestGroup, MlsMessage), MlsError> where F: FnMut(&mut TestClientConfig),54     pub(crate) async fn join_with_custom_config<F>(
55         &mut self,
56         name: &str,
57         custom_kp: bool,
58         mut config: F,
59     ) -> Result<(TestGroup, MlsMessage), MlsError>
60     where
61         F: FnMut(&mut TestClientConfig),
62     {
63         let (mut new_client, new_key_package) = if custom_kp {
64             test_client_with_key_pkg_custom(
65                 self.group.protocol_version(),
66                 self.group.cipher_suite(),
67                 name,
68                 &mut config,
69             )
70             .await
71         } else {
72             test_client_with_key_pkg(
73                 self.group.protocol_version(),
74                 self.group.cipher_suite(),
75                 name,
76             )
77             .await
78         };
79 
80         // Add new member to the group
81         let CommitOutput {
82             welcome_messages,
83             ratchet_tree,
84             commit_message,
85             ..
86         } = self
87             .group
88             .commit_builder()
89             .add_member(new_key_package)
90             .unwrap()
91             .build()
92             .await
93             .unwrap();
94 
95         // Apply the commit to the original group
96         self.group.apply_pending_commit().await.unwrap();
97 
98         config(&mut new_client.config);
99 
100         // Group from new member's perspective
101         let (new_group, _) = Group::join(
102             &welcome_messages[0],
103             ratchet_tree,
104             new_client.config.clone(),
105             new_client.signer.clone().unwrap(),
106         )
107         .await?;
108 
109         let new_test_group = TestGroup { group: new_group };
110 
111         Ok((new_test_group, commit_message))
112     }
113 
114     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
join(&mut self, name: &str) -> (TestGroup, MlsMessage)115     pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
116         self.join_with_custom_config(name, false, |_| ())
117             .await
118             .unwrap()
119     }
120 
121     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_pending_commit( &mut self, ) -> Result<CommitMessageDescription, MlsError>122     pub(crate) async fn process_pending_commit(
123         &mut self,
124     ) -> Result<CommitMessageDescription, MlsError> {
125         self.group.apply_pending_commit().await
126     }
127 
128     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_message( &mut self, message: MlsMessage, ) -> Result<ReceivedMessage, MlsError>129     pub(crate) async fn process_message(
130         &mut self,
131         message: MlsMessage,
132     ) -> Result<ReceivedMessage, MlsError> {
133         self.group.process_incoming_message(message).await
134     }
135 
136     #[cfg(feature = "private_message")]
137     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_plaintext(&mut self, content: Content) -> MlsMessage138     pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
139         let auth_content = AuthenticatedContent::new_signed(
140             &self.group.cipher_suite_provider,
141             &self.group.state.context,
142             Sender::Member(*self.group.private_tree.self_index),
143             content,
144             &self.group.signer,
145             WireFormat::PublicMessage,
146             Vec::new(),
147         )
148         .await
149         .unwrap();
150 
151         self.group.format_for_wire(auth_content).await.unwrap()
152     }
153 }
154 
155 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext156 pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
157     let cs = test_cipher_suite_provider(cipher_suite);
158 
159     GroupContext {
160         protocol_version: TEST_PROTOCOL_VERSION,
161         cipher_suite,
162         group_id: TEST_GROUP.to_vec(),
163         epoch,
164         tree_hash: cs.hash(&[1, 2, 3]).await.unwrap(),
165         confirmed_transcript_hash: cs.hash(&[3, 2, 1]).await.unwrap().into(),
166         extensions: ExtensionList::from(vec![]),
167     }
168 }
169 
/// Builds a [`GroupContext`] for tests with a caller-supplied `group_id`.
///
/// The tree hash, confirmed transcript hash and extension list are all empty;
/// the protocol version is [`TEST_PROTOCOL_VERSION`].
#[cfg(feature = "prior_epoch")]
pub(crate) fn get_test_group_context_with_id(
    group_id: Vec<u8>,
    epoch: u64,
    cipher_suite: CipherSuite,
) -> GroupContext {
    let tree_hash = Vec::new();
    let confirmed_transcript_hash = ConfirmedTranscriptHash::from(Vec::new());
    let extensions = ExtensionList::from(Vec::new());

    GroupContext {
        protocol_version: TEST_PROTOCOL_VERSION,
        cipher_suite,
        group_id,
        epoch,
        tree_hash,
        confirmed_transcript_hash,
        extensions,
    }
}
186 
group_extensions() -> ExtensionList187 pub(crate) fn group_extensions() -> ExtensionList {
188     let required_capabilities = RequiredCapabilitiesExt::default();
189 
190     let mut extensions = ExtensionList::new();
191     extensions.set_from(required_capabilities).unwrap();
192     extensions
193 }
194 
lifetime() -> Lifetime195 pub(crate) fn lifetime() -> Lifetime {
196     Lifetime::years(1).unwrap()
197 }
198 
199 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_member( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, identifier: &[u8], ) -> (KeyPackageGeneration, SignatureSecretKey)200 pub(crate) async fn test_member(
201     protocol_version: ProtocolVersion,
202     cipher_suite: CipherSuite,
203     identifier: &[u8],
204 ) -> (KeyPackageGeneration, SignatureSecretKey) {
205     let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
206 
207     let key_package_generator = KeyPackageGenerator {
208         protocol_version,
209         cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
210         signing_identity: &signing_identity,
211         signing_key: &signing_key,
212     };
213 
214     let key_package = key_package_generator
215         .generate(
216             lifetime(),
217             get_test_capabilities(),
218             ExtensionList::default(),
219             ExtensionList::default(),
220         )
221         .await
222         .unwrap();
223 
224     (key_package, signing_key)
225 }
226 
227 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, extension_types: Vec<ExtensionType>, leaf_extensions: Option<ExtensionList>, commit_options: Option<CommitOptions>, ) -> TestGroup228 pub(crate) async fn test_group_custom(
229     protocol_version: ProtocolVersion,
230     cipher_suite: CipherSuite,
231     extension_types: Vec<ExtensionType>,
232     leaf_extensions: Option<ExtensionList>,
233     commit_options: Option<CommitOptions>,
234 ) -> TestGroup {
235     let leaf_extensions = leaf_extensions.unwrap_or_default();
236     let commit_options = commit_options.unwrap_or_default();
237 
238     let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
239 
240     let group = TestClientBuilder::new_for_test()
241         .leaf_node_extensions(leaf_extensions)
242         .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
243         .extension_types(extension_types)
244         .protocol_versions(ProtocolVersion::all())
245         .used_protocol_version(protocol_version)
246         .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
247         .build()
248         .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
249         .await
250         .unwrap();
251 
252     TestGroup { group }
253 }
254 
255 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, ) -> TestGroup256 pub(crate) async fn test_group(
257     protocol_version: ProtocolVersion,
258     cipher_suite: CipherSuite,
259 ) -> TestGroup {
260     test_group_custom(
261         protocol_version,
262         cipher_suite,
263         Default::default(),
264         None,
265         None,
266     )
267     .await
268 }
269 
270 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_group_custom_config<F>( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, custom: F, ) -> TestGroup where F: FnOnce(TestClientBuilder) -> TestClientBuilder,271 pub(crate) async fn test_group_custom_config<F>(
272     protocol_version: ProtocolVersion,
273     cipher_suite: CipherSuite,
274     custom: F,
275 ) -> TestGroup
276 where
277     F: FnOnce(TestClientBuilder) -> TestClientBuilder,
278 {
279     let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
280 
281     let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
282 
283     let group = custom(client_builder)
284         .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
285         .build()
286         .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
287         .await
288         .unwrap();
289 
290     TestGroup { group }
291 }
292 
293 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_n_member_group( protocol_version: ProtocolVersion, cipher_suite: CipherSuite, num_members: usize, ) -> Vec<TestGroup>294 pub(crate) async fn test_n_member_group(
295     protocol_version: ProtocolVersion,
296     cipher_suite: CipherSuite,
297     num_members: usize,
298 ) -> Vec<TestGroup> {
299     let group = test_group(protocol_version, cipher_suite).await;
300 
301     let mut groups = vec![group];
302 
303     for i in 1..num_members {
304         let (new_group, commit) = groups.get_mut(0).unwrap().join(&format!("name {i}")).await;
305         process_commit(&mut groups, commit, 0).await;
306         groups.push(new_group);
307     }
308 
309     groups
310 }
311 
312 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32)313 pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
314     for g in groups
315         .iter_mut()
316         .filter(|g| g.group.current_member_index() != excluded)
317     {
318         g.process_message(commit.clone()).await.unwrap();
319     }
320 }
321 
get_test_25519_key(key_byte: u8) -> HpkePublicKey322 pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
323     vec![key_byte; 32].into()
324 }
325 
326 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
get_test_groups_with_features( n: usize, extensions: ExtensionList, leaf_extensions: ExtensionList, ) -> Vec<Group<TestClientConfig>>327 pub(crate) async fn get_test_groups_with_features(
328     n: usize,
329     extensions: ExtensionList,
330     leaf_extensions: ExtensionList,
331 ) -> Vec<Group<TestClientConfig>> {
332     let mut clients = Vec::new();
333 
334     for i in 0..n {
335         let (identity, secret_key) =
336             get_test_signing_identity(TEST_CIPHER_SUITE, format!("member{i}").as_bytes()).await;
337 
338         clients.push(
339             TestClientBuilder::new_for_test()
340                 .extension_type(999.into())
341                 .leaf_node_extensions(leaf_extensions.clone())
342                 .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
343                 .build(),
344         );
345     }
346 
347     let group = clients[0]
348         .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
349         .await
350         .unwrap();
351 
352     let mut groups = vec![group];
353 
354     for client in clients.iter().skip(1) {
355         let key_package = client.generate_key_package_message().await.unwrap();
356 
357         let commit_output = groups[0]
358             .commit_builder()
359             .add_member(key_package)
360             .unwrap()
361             .build()
362             .await
363             .unwrap();
364 
365         groups[0].apply_pending_commit().await.unwrap();
366 
367         for group in groups.iter_mut().skip(1) {
368             group
369                 .process_incoming_message(commit_output.commit_message.clone())
370                 .await
371                 .unwrap();
372         }
373 
374         groups.push(
375             client
376                 .join_group(None, &commit_output.welcome_messages[0])
377                 .await
378                 .unwrap()
379                 .0,
380         );
381     }
382 
383     groups
384 }
385 
random_bytes(count: usize) -> Vec<u8>386 pub fn random_bytes(count: usize) -> Vec<u8> {
387     let mut buf = vec![0; count];
388     rand::thread_rng().fill_bytes(&mut buf);
389     buf
390 }
391 
392 pub(crate) struct GroupWithoutKeySchedule {
393     inner: Group<TestClientConfig>,
394     pub secrets: Option<(TreeKemPrivate, PathSecret)>,
395     pub provisional_public_state: Option<ProvisionalState>,
396 }
397 
398 impl Deref for GroupWithoutKeySchedule {
399     type Target = Group<TestClientConfig>;
400 
401     #[cfg_attr(coverage_nightly, coverage(off))]
deref(&self) -> &Self::Target402     fn deref(&self) -> &Self::Target {
403         &self.inner
404     }
405 }
406 
407 impl DerefMut for GroupWithoutKeySchedule {
408     #[cfg_attr(coverage_nightly, coverage(off))]
deref_mut(&mut self) -> &mut Self::Target409     fn deref_mut(&mut self) -> &mut Self::Target {
410         &mut self.inner
411     }
412 }
413 
#[cfg(feature = "rfc_compliant")]
impl GroupWithoutKeySchedule {
    /// Wraps a fresh single-member test group for cipher suite `cs`, created
    /// with [`TEST_PROTOCOL_VERSION`], with nothing captured yet.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn new(cs: CipherSuite) -> Self {
        let TestGroup { group } = test_group(TEST_PROTOCOL_VERSION, cs).await;

        Self {
            inner: group,
            secrets: None,
            provisional_public_state: None,
        }
    }
}
425 
426 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
427 #[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
428 #[cfg_attr(
429     all(not(target_arch = "wasm32"), mls_build_async),
430     maybe_async::must_be_async
431 )]
432 impl MessageProcessor for GroupWithoutKeySchedule {
433     type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
434     type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
435     type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
436     type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
437     type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;
438 
group_state(&self) -> &GroupState439     fn group_state(&self) -> &GroupState {
440         self.inner.group_state()
441     }
442 
443     #[cfg_attr(coverage_nightly, coverage(off))]
group_state_mut(&mut self) -> &mut GroupState444     fn group_state_mut(&mut self) -> &mut GroupState {
445         self.inner.group_state_mut()
446     }
447 
mls_rules(&self) -> Self::MlsRules448     fn mls_rules(&self) -> Self::MlsRules {
449         self.inner.mls_rules()
450     }
451 
identity_provider(&self) -> Self::IdentityProvider452     fn identity_provider(&self) -> Self::IdentityProvider {
453         self.inner.identity_provider()
454     }
455 
cipher_suite_provider(&self) -> &Self::CipherSuiteProvider456     fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
457         self.inner.cipher_suite_provider()
458     }
459 
psk_storage(&self) -> Self::PreSharedKeyStorage460     fn psk_storage(&self) -> Self::PreSharedKeyStorage {
461         self.inner.psk_storage()
462     }
463 
can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool464     fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
465         self.inner.can_continue_processing(provisional_state)
466     }
467 
468     #[cfg(feature = "private_message")]
469     #[cfg_attr(coverage_nightly, coverage(off))]
min_epoch_available(&self) -> Option<u64>470     fn min_epoch_available(&self) -> Option<u64> {
471         self.inner.min_epoch_available()
472     }
473 
apply_update_path( &mut self, sender: LeafIndex, update_path: &ValidatedUpdatePath, provisional_state: &mut ProvisionalState, ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError>474     async fn apply_update_path(
475         &mut self,
476         sender: LeafIndex,
477         update_path: &ValidatedUpdatePath,
478         provisional_state: &mut ProvisionalState,
479     ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
480         self.inner
481             .apply_update_path(sender, update_path, provisional_state)
482             .await
483     }
484 
485     #[cfg(feature = "private_message")]
486     #[cfg_attr(coverage_nightly, coverage(off))]
process_ciphertext( &mut self, cipher_text: &PrivateMessage, ) -> Result<EventOrContent<Self::OutputType>, MlsError>487     async fn process_ciphertext(
488         &mut self,
489         cipher_text: &PrivateMessage,
490     ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
491         self.inner.process_ciphertext(cipher_text).await
492     }
493 
494     #[cfg_attr(coverage_nightly, coverage(off))]
verify_plaintext_authentication( &self, message: PublicMessage, ) -> Result<EventOrContent<Self::OutputType>, MlsError>495     async fn verify_plaintext_authentication(
496         &self,
497         message: PublicMessage,
498     ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
499         self.inner.verify_plaintext_authentication(message).await
500     }
501 
update_key_schedule( &mut self, secrets: Option<(TreeKemPrivate, PathSecret)>, _interim_transcript_hash: InterimTranscriptHash, _confirmation_tag: &ConfirmationTag, provisional_public_state: ProvisionalState, ) -> Result<(), MlsError>502     async fn update_key_schedule(
503         &mut self,
504         secrets: Option<(TreeKemPrivate, PathSecret)>,
505         _interim_transcript_hash: InterimTranscriptHash,
506         _confirmation_tag: &ConfirmationTag,
507         provisional_public_state: ProvisionalState,
508     ) -> Result<(), MlsError> {
509         self.provisional_public_state = Some(provisional_public_state);
510         self.secrets = secrets;
511         Ok(())
512     }
513 }
514