• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use alloc::vec;
6 use alloc::vec::Vec;
7 use core::fmt::Debug;
8 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
9 use mls_rs_core::{crypto::SignatureSecretKey, error::IntoAnyError};
10 
11 use crate::{
12     cipher_suite::CipherSuite,
13     client::MlsError,
14     client_config::ClientConfig,
15     extension::RatchetTreeExt,
16     identity::SigningIdentity,
17     protocol_version::ProtocolVersion,
18     signer::Signable,
19     tree_kem::{
20         kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
21     },
22     ExtensionList, MlsRules,
23 };
24 
25 #[cfg(all(not(mls_build_async), feature = "rayon"))]
26 use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
27 
28 use crate::tree_kem::leaf_node::LeafNode;
29 
30 #[cfg(not(feature = "private_message"))]
31 use crate::WireFormat;
32 
33 #[cfg(feature = "psk")]
34 use crate::{
35     group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
36     psk::ExternalPskId,
37 };
38 
39 use super::{
40     confirmation_tag::ConfirmationTag,
41     framing::{Content, MlsMessage, MlsMessagePayload, Sender},
42     key_schedule::{KeySchedule, WelcomeSecret},
43     message_hash::MessageHash,
44     message_processor::{path_update_required, MessageProcessor},
45     message_signature::AuthenticatedContent,
46     mls_rules::CommitDirection,
47     proposal::{Proposal, ProposalOrRef},
48     ConfirmedTranscriptHash, EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo,
49     Welcome,
50 };
51 
52 #[cfg(not(feature = "by_ref_proposal"))]
53 use super::proposal_cache::prepare_commit;
54 
55 #[cfg(feature = "custom_proposal")]
56 use super::proposal::CustomProposal;
57 
58 #[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
59 #[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
60 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
61 pub(crate) struct Commit {
62     pub proposals: Vec<ProposalOrRef>,
63     pub path: Option<UpdatePath>,
64 }
65 
66 #[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
67 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68 pub(super) struct CommitGeneration {
69     pub content: AuthenticatedContent,
70     pub pending_private_tree: TreeKemPrivate,
71     pub pending_commit_secret: PathSecret,
72     pub commit_message_hash: MessageHash,
73 }
74 
75 #[cfg_attr(
76     all(feature = "ffi", not(test)),
77     safer_ffi_gen::ffi_type(clone, opaque)
78 )]
79 #[derive(Clone, Debug)]
80 #[non_exhaustive]
81 /// Result of MLS commit operation using
82 /// [`Group::commit`](crate::group::Group::commit) or
83 /// [`CommitBuilder::build`](CommitBuilder::build).
84 pub struct CommitOutput {
85     /// Commit message to send to other group members.
86     pub commit_message: MlsMessage,
87     /// Welcome messages to send to new group members. If the commit does not add members,
88     /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
89     /// set to true, then this list contains a single message sent to all members. Else, the list
90     /// contains one message for each added member. Recipients of each message can be identified using
91     /// [`MlsMessage::key_package_reference`] of their key packages and
92     /// [`MlsMessage::welcome_key_package_references`].
93     pub welcome_messages: Vec<MlsMessage>,
94     /// Ratchet tree that can be sent out of band if
95     /// `ratchet_tree_extension` is not used according to
96     /// [`MlsRules::commit_options`].
97     pub ratchet_tree: Option<ExportedTree<'static>>,
98     /// A group info that can be provided to new members in order to enable external commit
99     /// functionality. This value is set if [`MlsRules::commit_options`] returns
100     /// `allow_external_commit` set to true.
101     pub external_commit_group_info: Option<MlsMessage>,
102     /// Proposals that were received in the prior epoch but not included in the following commit.
103     #[cfg(feature = "by_ref_proposal")]
104     pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
105 }
106 
107 #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
108 impl CommitOutput {
109     /// Commit message to send to other group members.
110     #[cfg(feature = "ffi")]
commit_message(&self) -> &MlsMessage111     pub fn commit_message(&self) -> &MlsMessage {
112         &self.commit_message
113     }
114 
115     /// Welcome message to send to new group members.
116     #[cfg(feature = "ffi")]
welcome_messages(&self) -> &[MlsMessage]117     pub fn welcome_messages(&self) -> &[MlsMessage] {
118         &self.welcome_messages
119     }
120 
121     /// Ratchet tree that can be sent out of band if
122     /// `ratchet_tree_extension` is not used according to
123     /// [`MlsRules::commit_options`].
124     #[cfg(feature = "ffi")]
ratchet_tree(&self) -> Option<&ExportedTree<'static>>125     pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
126         self.ratchet_tree.as_ref()
127     }
128 
129     /// A group info that can be provided to new members in order to enable external commit
130     /// functionality. This value is set if [`MlsRules::commit_options`] returns
131     /// `allow_external_commit` set to true.
132     #[cfg(feature = "ffi")]
external_commit_group_info(&self) -> Option<&MlsMessage>133     pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
134         self.external_commit_group_info.as_ref()
135     }
136 
137     /// Proposals that were received in the prior epoch but not included in the following commit.
138     #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>]139     pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
140         &self.unused_proposals
141     }
142 }
143 
144 /// Build a commit with multiple proposals by-value.
145 ///
146 /// Proposals within a commit can be by-value or by-reference.
147 /// Proposals received during the current epoch will be added to the resulting
148 /// commit by-reference automatically so long as they pass the rules defined
149 /// in the current
150 /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
151 pub struct CommitBuilder<'a, C>
152 where
153     C: ClientConfig + Clone,
154 {
155     group: &'a mut Group<C>,
156     pub(super) proposals: Vec<Proposal>,
157     authenticated_data: Vec<u8>,
158     group_info_extensions: ExtensionList,
159     new_signer: Option<SignatureSecretKey>,
160     new_signing_identity: Option<SigningIdentity>,
161 }
162 
163 impl<'a, C> CommitBuilder<'a, C>
164 where
165     C: ClientConfig + Clone,
166 {
167     /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
168     /// the current commit that is being built.
add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError>169     pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
170         let proposal = self.group.add_proposal(key_package)?;
171         self.proposals.push(proposal);
172         Ok(self)
173     }
174 
175     /// Set group info extensions that will be inserted into the resulting
176     /// [welcome messages](CommitOutput::welcome_messages) for new members.
177     ///
178     /// Group info extensions that are transmitted as part of a welcome message
179     /// are encrypted along with other private values.
180     ///
181     /// These extensions can be retrieved as part of
182     /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
183     /// by joining the group via
184     /// [`Client::join_group`](crate::Client::join_group).
set_group_info_ext(self, extensions: ExtensionList) -> Self185     pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
186         Self {
187             group_info_extensions: extensions,
188             ..self
189         }
190     }
191 
192     /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
193     /// the current commit that is being built.
remove_member(mut self, index: u32) -> Result<Self, MlsError>194     pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
195         let proposal = self.group.remove_proposal(index)?;
196         self.proposals.push(proposal);
197         Ok(self)
198     }
199 
200     /// Insert a
201     /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
202     /// into the current commit that is being built.
set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError>203     pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
204         let proposal = self.group.group_context_extensions_proposal(extensions);
205         self.proposals.push(proposal);
206         Ok(self)
207     }
208 
209     /// Insert a
210     /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
211     /// an external PSK into the current commit that is being built.
212     #[cfg(feature = "psk")]
add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError>213     pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
214         let key_id = JustPreSharedKeyID::External(psk_id);
215         let proposal = self.group.psk_proposal(key_id)?;
216         self.proposals.push(proposal);
217         Ok(self)
218     }
219 
220     /// Insert a
221     /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
222     /// a resumption PSK into the current commit that is being built.
223     #[cfg(feature = "psk")]
add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError>224     pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
225         let psk_id = ResumptionPsk {
226             psk_epoch,
227             usage: ResumptionPSKUsage::Application,
228             psk_group_id: PskGroupId(self.group.group_id().to_vec()),
229         };
230 
231         let key_id = JustPreSharedKeyID::Resumption(psk_id);
232         let proposal = self.group.psk_proposal(key_id)?;
233         self.proposals.push(proposal);
234         Ok(self)
235     }
236 
237     /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
238     /// the current commit that is being built.
reinit( mut self, group_id: Option<Vec<u8>>, version: ProtocolVersion, cipher_suite: CipherSuite, extensions: ExtensionList, ) -> Result<Self, MlsError>239     pub fn reinit(
240         mut self,
241         group_id: Option<Vec<u8>>,
242         version: ProtocolVersion,
243         cipher_suite: CipherSuite,
244         extensions: ExtensionList,
245     ) -> Result<Self, MlsError> {
246         let proposal = self
247             .group
248             .reinit_proposal(group_id, version, cipher_suite, extensions)?;
249 
250         self.proposals.push(proposal);
251         Ok(self)
252     }
253 
254     /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
255     /// the current commit that is being built.
256     #[cfg(feature = "custom_proposal")]
custom_proposal(mut self, proposal: CustomProposal) -> Self257     pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
258         self.proposals.push(Proposal::Custom(proposal));
259         self
260     }
261 
262     /// Insert a proposal that was previously constructed such as when a
263     /// proposal is returned from
264     /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
raw_proposal(mut self, proposal: Proposal) -> Self265     pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
266         self.proposals.push(proposal);
267         self
268     }
269 
270     /// Insert proposals that were previously constructed such as when a
271     /// proposal is returned from
272     /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self273     pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
274         self.proposals.append(&mut proposals);
275         self
276     }
277 
278     /// Add additional authenticated data to the commit.
279     ///
280     /// # Warning
281     ///
282     /// The data provided here is always sent unencrypted.
authenticated_data(self, authenticated_data: Vec<u8>) -> Self283     pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
284         Self {
285             authenticated_data,
286             ..self
287         }
288     }
289 
290     /// Change the committer's signing identity as part of making this commit.
291     /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
292     /// in use by the group considers the credential inside this signing_identity
293     /// [valid](crate::IdentityProvider::validate_member)
294     /// and results in the same
295     /// [identity](crate::IdentityProvider::identity)
296     /// being used.
set_new_signing_identity( self, signer: SignatureSecretKey, signing_identity: SigningIdentity, ) -> Self297     pub fn set_new_signing_identity(
298         self,
299         signer: SignatureSecretKey,
300         signing_identity: SigningIdentity,
301     ) -> Self {
302         Self {
303             new_signer: Some(signer),
304             new_signing_identity: Some(signing_identity),
305             ..self
306         }
307     }
308 
309     /// Finalize the commit to send.
310     ///
311     /// # Errors
312     ///
313     /// This function will return an error if any of the proposals provided
314     /// are not contextually valid according to the rules defined by the
315     /// MLS RFC, or if they do not pass the custom rules defined by the current
316     /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
317     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
build(self) -> Result<CommitOutput, MlsError>318     pub async fn build(self) -> Result<CommitOutput, MlsError> {
319         self.group
320             .commit_internal(
321                 self.proposals,
322                 None,
323                 self.authenticated_data,
324                 self.group_info_extensions,
325                 self.new_signer,
326                 self.new_signing_identity,
327             )
328             .await
329     }
330 }
331 
332 impl<C> Group<C>
333 where
334     C: ClientConfig + Clone,
335 {
336     /// Perform a commit of received proposals.
337     ///
338     /// This function is the equivalent of [`Group::commit_builder`] immediately
339     /// followed by [`CommitBuilder::build`]. Any received proposals since the
340     /// last commit will be included in the resulting message by-reference.
341     ///
342     /// Data provided in the `authenticated_data` field will be placed into
343     /// the resulting commit message unencrypted.
344     ///
345     /// # Pending Commits
346     ///
347     /// When a commit is created, it is not applied immediately in order to
348     /// allow for the resolution of conflicts when multiple members of a group
349     /// attempt to make commits at the same time. For example, a central relay
350     /// can be used to decide which commit should be accepted by the group by
351     /// determining a consistent view of commit packet order for all clients.
352     ///
353     /// Pending commits are stored internally as part of the group's state
354     /// so they do not need to be tracked outside of this library. Any commit
355     /// message that is processed before calling [Group::apply_pending_commit]
356     /// will clear the currently pending commit.
357     ///
358     /// # Empty Commits
359     ///
360     /// Sending a commit that contains no proposals is a valid operation
361     /// within the MLS protocol. It is useful for providing stronger forward
362     /// secrecy and post-compromise security, especially for long running
363     /// groups when group membership does not change often.
364     ///
365     /// # Path Updates
366     ///
367     /// Path updates provide forward secrecy and post-compromise security
368     /// within the MLS protocol.
369     /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
370     /// controls the ability of a group to send a commit without a path update.
371     /// An update path will automatically be sent if there are no proposals
372     /// in the commit, or if any proposal other than
373     /// [`Add`](crate::group::proposal::Proposal::Add),
374     /// [`Psk`](crate::group::proposal::Proposal::Psk),
375     /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
376     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError>377     pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
378         self.commit_internal(
379             vec![],
380             None,
381             authenticated_data,
382             Default::default(),
383             None,
384             None,
385         )
386         .await
387     }
388 
389     /// Create a new commit builder that can include proposals
390     /// by-value.
commit_builder(&mut self) -> CommitBuilder<C>391     pub fn commit_builder(&mut self) -> CommitBuilder<C> {
392         CommitBuilder {
393             group: self,
394             proposals: Default::default(),
395             authenticated_data: Default::default(),
396             group_info_extensions: Default::default(),
397             new_signer: Default::default(),
398             new_signing_identity: Default::default(),
399         }
400     }
401 
402     /// Returns commit and optional [`MlsMessage`] containing a welcome message
403     /// for newly added members.
404     #[allow(clippy::too_many_arguments)]
405     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
commit_internal( &mut self, proposals: Vec<Proposal>, external_leaf: Option<&LeafNode>, authenticated_data: Vec<u8>, mut welcome_group_info_extensions: ExtensionList, new_signer: Option<SignatureSecretKey>, new_signing_identity: Option<SigningIdentity>, ) -> Result<CommitOutput, MlsError>406     pub(super) async fn commit_internal(
407         &mut self,
408         proposals: Vec<Proposal>,
409         external_leaf: Option<&LeafNode>,
410         authenticated_data: Vec<u8>,
411         mut welcome_group_info_extensions: ExtensionList,
412         new_signer: Option<SignatureSecretKey>,
413         new_signing_identity: Option<SigningIdentity>,
414     ) -> Result<CommitOutput, MlsError> {
415         if self.pending_commit.is_some() {
416             return Err(MlsError::ExistingPendingCommit);
417         }
418 
419         if self.state.pending_reinit.is_some() {
420             return Err(MlsError::GroupUsedAfterReInit);
421         }
422 
423         let mls_rules = self.config.mls_rules();
424 
425         let is_external = external_leaf.is_some();
426 
427         // Construct an initial Commit object with the proposals field populated from Proposals
428         // received during the current epoch, and an empty path field. Add passed in proposals
429         // by value
430         let sender = if is_external {
431             Sender::NewMemberCommit
432         } else {
433             Sender::Member(*self.private_tree.self_index)
434         };
435 
436         let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer);
437         let old_signer = &self.signer;
438 
439         #[cfg(feature = "std")]
440         let time = Some(crate::time::MlsTime::now());
441 
442         #[cfg(not(feature = "std"))]
443         let time = None;
444 
445         #[cfg(feature = "by_ref_proposal")]
446         let proposals = self.state.proposals.prepare_commit(sender, proposals);
447 
448         #[cfg(not(feature = "by_ref_proposal"))]
449         let proposals = prepare_commit(sender, proposals);
450 
451         let mut provisional_state = self
452             .state
453             .apply_resolved(
454                 sender,
455                 proposals,
456                 external_leaf,
457                 &self.config.identity_provider(),
458                 &self.cipher_suite_provider,
459                 &self.config.secret_store(),
460                 &mls_rules,
461                 time,
462                 CommitDirection::Send,
463             )
464             .await?;
465 
466         let (mut provisional_private_tree, _) =
467             self.provisional_private_tree(&provisional_state)?;
468 
469         if is_external {
470             provisional_private_tree.self_index = provisional_state
471                 .external_init_index
472                 .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
473 
474             self.private_tree.self_index = provisional_private_tree.self_index;
475         }
476 
477         let mut provisional_group_context = provisional_state.group_context;
478 
479         // Decide whether to populate the path field: If the path field is required based on the
480         // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
481         // sender MAY omit the path field at its discretion.
482         let commit_options = mls_rules
483             .commit_options(
484                 &provisional_state.public_tree.roster(),
485                 &provisional_group_context.extensions,
486                 &provisional_state.applied_proposals,
487             )
488             .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
489 
490         let perform_path_update = commit_options.path_required
491             || path_update_required(&provisional_state.applied_proposals);
492 
493         let (update_path, path_secrets, commit_secret) = if perform_path_update {
494             // If populating the path field: Create an UpdatePath using the new tree. Any new
495             // member (from an add proposal) MUST be excluded from the resolution during the
496             // computation of the UpdatePath. The GroupContext for this operation uses the
497             // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
498             // GroupContext object. The leaf_key_package for this UpdatePath must have a
499             // parent_hash extension.
500             let encap_gen = TreeKem::new(
501                 &mut provisional_state.public_tree,
502                 &mut provisional_private_tree,
503             )
504             .encap(
505                 &mut provisional_group_context,
506                 &provisional_state.indexes_of_added_kpkgs,
507                 new_signer_ref,
508                 self.config.leaf_properties(),
509                 new_signing_identity,
510                 &self.cipher_suite_provider,
511                 #[cfg(test)]
512                 &self.commit_modifiers,
513             )
514             .await?;
515 
516             (
517                 Some(encap_gen.update_path),
518                 Some(encap_gen.path_secrets),
519                 encap_gen.commit_secret,
520             )
521         } else {
522             // Update the tree hash, since it was not updated by encap.
523             provisional_state
524                 .public_tree
525                 .update_hashes(
526                     &[provisional_private_tree.self_index],
527                     &self.cipher_suite_provider,
528                 )
529                 .await?;
530 
531             provisional_group_context.tree_hash = provisional_state
532                 .public_tree
533                 .tree_hash(&self.cipher_suite_provider)
534                 .await?;
535 
536             (None, None, PathSecret::empty(&self.cipher_suite_provider))
537         };
538 
539         #[cfg(feature = "psk")]
540         let (psk_secret, psks) = self
541             .get_psk(&provisional_state.applied_proposals.psks)
542             .await?;
543 
544         #[cfg(not(feature = "psk"))]
545         let psk_secret = self.get_psk();
546 
547         let added_key_pkgs: Vec<_> = provisional_state
548             .applied_proposals
549             .additions
550             .iter()
551             .map(|info| info.proposal.key_package.clone())
552             .collect();
553 
554         let commit = Commit {
555             proposals: provisional_state.applied_proposals.into_proposals_or_refs(),
556             path: update_path,
557         };
558 
559         let mut auth_content = AuthenticatedContent::new_signed(
560             &self.cipher_suite_provider,
561             self.context(),
562             sender,
563             Content::Commit(alloc::boxed::Box::new(commit)),
564             old_signer,
565             #[cfg(feature = "private_message")]
566             self.encryption_options()?.control_wire_format(sender),
567             #[cfg(not(feature = "private_message"))]
568             WireFormat::PublicMessage,
569             authenticated_data,
570         )
571         .await?;
572 
573         // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
574         // compute the confirmation_tag value in the MlsPlaintext.
575         let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
576             self.cipher_suite_provider(),
577             &self.state.interim_transcript_hash,
578             &auth_content,
579         )
580         .await?;
581 
582         provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash;
583 
584         let key_schedule_result = KeySchedule::from_key_schedule(
585             &self.key_schedule,
586             &commit_secret,
587             &provisional_group_context,
588             #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
589             self.state.public_tree.total_leaf_count(),
590             &psk_secret,
591             &self.cipher_suite_provider,
592         )
593         .await?;
594 
595         let confirmation_tag = ConfirmationTag::create(
596             &key_schedule_result.confirmation_key,
597             &provisional_group_context.confirmed_transcript_hash,
598             &self.cipher_suite_provider,
599         )
600         .await?;
601 
602         auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
603 
604         let ratchet_tree_ext = commit_options
605             .ratchet_tree_extension
606             .then(|| RatchetTreeExt {
607                 tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
608             });
609 
610         // Generate external commit group info if required by commit_options
611         let external_commit_group_info = match commit_options.allow_external_commit {
612             true => {
613                 let mut extensions = ExtensionList::new();
614 
615                 extensions.set_from({
616                     key_schedule_result
617                         .key_schedule
618                         .get_external_key_pair_ext(&self.cipher_suite_provider)
619                         .await?
620                 })?;
621 
622                 if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
623                     extensions.set_from(ratchet_tree_ext.clone())?;
624                 }
625 
626                 let info = self
627                     .make_group_info(
628                         &provisional_group_context,
629                         extensions,
630                         &confirmation_tag,
631                         new_signer_ref,
632                     )
633                     .await?;
634 
635                 let msg =
636                     MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
637 
638                 Some(msg)
639             }
640             false => None,
641         };
642 
643         // Build the group info that will be placed into the welcome messages.
644         // Add the ratchet tree extension if necessary
645         if let Some(ratchet_tree_ext) = ratchet_tree_ext {
646             welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
647         }
648 
649         let welcome_group_info = self
650             .make_group_info(
651                 &provisional_group_context,
652                 welcome_group_info_extensions,
653                 &confirmation_tag,
654                 new_signer_ref,
655             )
656             .await?;
657 
658         // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
659         // the new epoch
660         let welcome_secret = WelcomeSecret::from_joiner_secret(
661             &self.cipher_suite_provider,
662             &key_schedule_result.joiner_secret,
663             &psk_secret,
664         )
665         .await?;
666 
667         let encrypted_group_info = welcome_secret
668             .encrypt(&welcome_group_info.mls_encode_to_vec()?)
669             .await?;
670 
671         // Encrypt path secrets and joiner secret to new members
672         let path_secrets = path_secrets.as_ref();
673 
674         #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
675         let encrypted_path_secrets: Vec<_> = added_key_pkgs
676             .into_par_iter()
677             .zip(provisional_state.indexes_of_added_kpkgs)
678             .map(|(key_package, leaf_index)| {
679                 self.encrypt_group_secrets(
680                     &key_package,
681                     leaf_index,
682                     &key_schedule_result.joiner_secret,
683                     path_secrets,
684                     #[cfg(feature = "psk")]
685                     psks.clone(),
686                     &encrypted_group_info,
687                 )
688             })
689             .try_collect()?;
690 
691         #[cfg(any(mls_build_async, not(feature = "rayon")))]
692         let encrypted_path_secrets = {
693             let mut secrets = Vec::new();
694 
695             for (key_package, leaf_index) in added_key_pkgs
696                 .into_iter()
697                 .zip(provisional_state.indexes_of_added_kpkgs)
698             {
699                 secrets.push(
700                     self.encrypt_group_secrets(
701                         &key_package,
702                         leaf_index,
703                         &key_schedule_result.joiner_secret,
704                         path_secrets,
705                         #[cfg(feature = "psk")]
706                         psks.clone(),
707                         &encrypted_group_info,
708                     )
709                     .await?,
710                 );
711             }
712 
713             secrets
714         };
715 
716         let welcome_messages =
717             if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
718                 vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
719             } else {
720                 encrypted_path_secrets
721                     .into_iter()
722                     .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
723                     .collect()
724             };
725 
726         let commit_message = self.format_for_wire(auth_content.clone()).await?;
727 
728         let pending_commit = CommitGeneration {
729             content: auth_content,
730             pending_private_tree: provisional_private_tree,
731             pending_commit_secret: commit_secret,
732             commit_message_hash: MessageHash::compute(&self.cipher_suite_provider, &commit_message)
733                 .await?,
734         };
735 
736         self.pending_commit = Some(pending_commit);
737 
738         let ratchet_tree = (!commit_options.ratchet_tree_extension)
739             .then(|| ExportedTree::new(provisional_state.public_tree.nodes));
740 
741         if let Some(signer) = new_signer {
742             self.signer = signer;
743         }
744 
745         Ok(CommitOutput {
746             commit_message,
747             welcome_messages,
748             ratchet_tree,
749             external_commit_group_info,
750             #[cfg(feature = "by_ref_proposal")]
751             unused_proposals: provisional_state.unused_proposals,
752         })
753     }
754 
755     // Construct a GroupInfo reflecting the new state
756     // Group ID, epoch, tree, and confirmed transcript hash from the new state
757     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
make_group_info( &self, group_context: &GroupContext, extensions: ExtensionList, confirmation_tag: &ConfirmationTag, signer: &SignatureSecretKey, ) -> Result<GroupInfo, MlsError>758     async fn make_group_info(
759         &self,
760         group_context: &GroupContext,
761         extensions: ExtensionList,
762         confirmation_tag: &ConfirmationTag,
763         signer: &SignatureSecretKey,
764     ) -> Result<GroupInfo, MlsError> {
765         let mut group_info = GroupInfo {
766             group_context: group_context.clone(),
767             extensions,
768             confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
769             signer: LeafIndex(self.current_member_index()),
770             signature: vec![],
771         };
772 
773         group_info.grease(self.cipher_suite_provider())?;
774 
775         // Sign the GroupInfo using the member's private signing key
776         group_info
777             .sign(&self.cipher_suite_provider, signer, &())
778             .await?;
779 
780         Ok(group_info)
781     }
782 
make_welcome_message( &self, secrets: Vec<EncryptedGroupSecrets>, encrypted_group_info: Vec<u8>, ) -> MlsMessage783     fn make_welcome_message(
784         &self,
785         secrets: Vec<EncryptedGroupSecrets>,
786         encrypted_group_info: Vec<u8>,
787     ) -> MlsMessage {
788         MlsMessage::new(
789             self.context().protocol_version,
790             MlsMessagePayload::Welcome(Welcome {
791                 cipher_suite: self.context().cipher_suite,
792                 secrets,
793                 encrypted_group_info,
794             }),
795         )
796     }
797 }
798 
#[cfg(test)]
pub(crate) mod test_utils {
    use alloc::vec::Vec;

    use crate::{
        crypto::SignatureSecretKey,
        tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
    };

    /// Test hooks for tampering with a commit while it is being generated.
    ///
    /// Every hook is a plain `fn` pointer, which keeps the struct `Copy`. Where each
    /// hook is invoked is decided by the commit-generation code, not here.
    #[derive(Copy, Clone, Debug)]
    pub struct CommitModifiers {
        /// Receives a mutable leaf node and a signature secret key. It may return a
        /// replacement key; how that key is used is up to the caller (confirm there).
        pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
        /// Receives the `TreeKemPublic` tree for in-place mutation.
        pub modify_tree: fn(&mut TreeKemPublic),
        /// Maps a list of update-path nodes to the list that is actually used.
        pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
    }

    impl Default for CommitModifiers {
        /// Returns hooks that change nothing. The leaf hook returns `None`, the tree
        /// hook does nothing, and the path hook hands back its input unchanged.
        fn default() -> Self {
            fn untouched_leaf(
                _leaf: &mut LeafNode,
                _key: &SignatureSecretKey,
            ) -> Option<SignatureSecretKey> {
                None
            }

            fn untouched_tree(_tree: &mut TreeKemPublic) {}

            fn untouched_path(nodes: Vec<UpdatePathNode>) -> Vec<UpdatePathNode> {
                nodes
            }

            Self {
                modify_leaf: untouched_leaf,
                modify_tree: untouched_tree,
                modify_path: untouched_path,
            }
        }
    }
}
825 
826 #[cfg(test)]
827 mod tests {
828     use alloc::boxed::Box;
829 
830     use mls_rs_core::{
831         error::IntoAnyError,
832         extension::ExtensionType,
833         identity::{CredentialType, IdentityProvider},
834         time::MlsTime,
835     };
836 
837     use crate::{
838         crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
839         group::{mls_rules::DefaultMlsRules, test_utils::test_group_custom},
840         mls_rules::CommitOptions,
841         Client,
842     };
843 
844     #[cfg(feature = "by_ref_proposal")]
845     use crate::extension::ExternalSendersExt;
846 
847     use crate::{
848         client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
849         client_builder::{
850             test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
851             WithIdentityProvider,
852         },
853         client_config::ClientConfig,
854         extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
855         group::{
856             proposal::ProposalType,
857             test_utils::{test_group_custom_config, test_n_member_group},
858         },
859         identity::test_utils::get_test_signing_identity,
860         identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
861         key_package::test_utils::test_key_package_message,
862     };
863 
864     use crate::extension::RequiredCapabilitiesExt;
865 
866     #[cfg(feature = "psk")]
867     use crate::{
868         group::proposal::PreSharedKeyProposal,
869         psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
870     };
871 
872     use super::*;
873 
874     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
test_commit_builder_group() -> Group<TestClientConfig>875     async fn test_commit_builder_group() -> Group<TestClientConfig> {
876         test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
877             b.custom_proposal_type(ProposalType::from(42))
878                 .extension_type(TEST_EXTENSION_TYPE.into())
879         })
880         .await
881         .group
882     }
883 
assert_commit_builder_output<C: ClientConfig>( group: Group<C>, mut commit_output: CommitOutput, expected: Vec<Proposal>, welcome_count: usize, )884     fn assert_commit_builder_output<C: ClientConfig>(
885         group: Group<C>,
886         mut commit_output: CommitOutput,
887         expected: Vec<Proposal>,
888         welcome_count: usize,
889     ) {
890         let plaintext = commit_output.commit_message.into_plaintext().unwrap();
891 
892         let commit_data = match plaintext.content.content {
893             Content::Commit(commit) => commit,
894             #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
895             _ => panic!("Found non-commit data"),
896         };
897 
898         assert_eq!(commit_data.proposals.len(), expected.len());
899 
900         commit_data.proposals.into_iter().for_each(|proposal| {
901             let proposal = match proposal {
902                 ProposalOrRef::Proposal(p) => p,
903                 #[cfg(feature = "by_ref_proposal")]
904                 ProposalOrRef::Reference(_) => panic!("found proposal reference"),
905             };
906 
907             #[cfg(feature = "psk")]
908             if let Some(psk_id) = match proposal.as_ref() {
909                 Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
910                 _ => None,
911             } {
912                 let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
913 
914                 assert!(found)
915             } else {
916                 assert!(expected.contains(&proposal));
917             }
918 
919             #[cfg(not(feature = "psk"))]
920             assert!(expected.contains(&proposal));
921         });
922 
923         if welcome_count > 0 {
924             let welcome_msg = commit_output.welcome_messages.pop().unwrap();
925 
926             assert_eq!(welcome_msg.version, group.state.context.protocol_version);
927 
928             let welcome_msg = welcome_msg.into_welcome().unwrap();
929 
930             assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
931             assert_eq!(welcome_msg.secrets.len(), welcome_count);
932         } else {
933             assert!(commit_output.welcome_messages.is_empty());
934         }
935     }
936 
937     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_add()938     async fn test_commit_builder_add() {
939         let mut group = test_commit_builder_group().await;
940 
941         let test_key_package =
942             test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
943 
944         let commit_output = group
945             .commit_builder()
946             .add_member(test_key_package.clone())
947             .unwrap()
948             .build()
949             .await
950             .unwrap();
951 
952         let expected_add = group.add_proposal(test_key_package).unwrap();
953 
954         assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
955     }
956 
957     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_add_with_ext()958     async fn test_commit_builder_add_with_ext() {
959         let mut group = test_commit_builder_group().await;
960 
961         let (bob_client, bob_key_package) =
962             test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
963 
964         let ext = TestExtension { foo: 42 };
965         let mut extension_list = ExtensionList::default();
966         extension_list.set_from(ext.clone()).unwrap();
967 
968         let welcome_message = group
969             .commit_builder()
970             .add_member(bob_key_package)
971             .unwrap()
972             .set_group_info_ext(extension_list)
973             .build()
974             .await
975             .unwrap()
976             .welcome_messages
977             .remove(0);
978 
979         let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
980 
981         assert_eq!(
982             context
983                 .group_info_extensions
984                 .get_as::<TestExtension>()
985                 .unwrap()
986                 .unwrap(),
987             ext
988         );
989     }
990 
991     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_remove()992     async fn test_commit_builder_remove() {
993         let mut group = test_commit_builder_group().await;
994         let test_key_package =
995             test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
996 
997         group
998             .commit_builder()
999             .add_member(test_key_package)
1000             .unwrap()
1001             .build()
1002             .await
1003             .unwrap();
1004 
1005         group.apply_pending_commit().await.unwrap();
1006 
1007         let commit_output = group
1008             .commit_builder()
1009             .remove_member(1)
1010             .unwrap()
1011             .build()
1012             .await
1013             .unwrap();
1014 
1015         let expected_remove = group.remove_proposal(1).unwrap();
1016 
1017         assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
1018     }
1019 
1020     #[cfg(feature = "psk")]
1021     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_psk()1022     async fn test_commit_builder_psk() {
1023         let mut group = test_commit_builder_group().await;
1024         let test_psk = ExternalPskId::new(vec![1]);
1025 
1026         group
1027             .config
1028             .secret_store()
1029             .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
1030 
1031         let commit_output = group
1032             .commit_builder()
1033             .add_external_psk(test_psk.clone())
1034             .unwrap()
1035             .build()
1036             .await
1037             .unwrap();
1038 
1039         let key_id = JustPreSharedKeyID::External(test_psk);
1040         let expected_psk = group.psk_proposal(key_id).unwrap();
1041 
1042         assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
1043     }
1044 
1045     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_group_context_ext()1046     async fn test_commit_builder_group_context_ext() {
1047         let mut group = test_commit_builder_group().await;
1048         let mut test_ext = ExtensionList::default();
1049         test_ext
1050             .set_from(RequiredCapabilitiesExt::default())
1051             .unwrap();
1052 
1053         let commit_output = group
1054             .commit_builder()
1055             .set_group_context_ext(test_ext.clone())
1056             .unwrap()
1057             .build()
1058             .await
1059             .unwrap();
1060 
1061         let expected_ext = group.group_context_extensions_proposal(test_ext);
1062 
1063         assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
1064     }
1065 
1066     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_reinit()1067     async fn test_commit_builder_reinit() {
1068         let mut group = test_commit_builder_group().await;
1069         let test_group_id = "foo".as_bytes().to_vec();
1070         let test_cipher_suite = TEST_CIPHER_SUITE;
1071         let test_protocol_version = TEST_PROTOCOL_VERSION;
1072         let mut test_ext = ExtensionList::default();
1073 
1074         test_ext
1075             .set_from(RequiredCapabilitiesExt::default())
1076             .unwrap();
1077 
1078         let commit_output = group
1079             .commit_builder()
1080             .reinit(
1081                 Some(test_group_id.clone()),
1082                 test_protocol_version,
1083                 test_cipher_suite,
1084                 test_ext.clone(),
1085             )
1086             .unwrap()
1087             .build()
1088             .await
1089             .unwrap();
1090 
1091         let expected_reinit = group
1092             .reinit_proposal(
1093                 Some(test_group_id),
1094                 test_protocol_version,
1095                 test_cipher_suite,
1096                 test_ext,
1097             )
1098             .unwrap();
1099 
1100         assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
1101     }
1102 
1103     #[cfg(feature = "custom_proposal")]
1104     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_custom_proposal()1105     async fn test_commit_builder_custom_proposal() {
1106         let mut group = test_commit_builder_group().await;
1107 
1108         let proposal = CustomProposal::new(42.into(), vec![0, 1]);
1109 
1110         let commit_output = group
1111             .commit_builder()
1112             .custom_proposal(proposal.clone())
1113             .build()
1114             .await
1115             .unwrap();
1116 
1117         assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
1118     }
1119 
1120     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_chaining()1121     async fn test_commit_builder_chaining() {
1122         let mut group = test_commit_builder_group().await;
1123         let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
1124         let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
1125 
1126         let expected_adds = vec![
1127             group.add_proposal(kp1.clone()).unwrap(),
1128             group.add_proposal(kp2.clone()).unwrap(),
1129         ];
1130 
1131         let commit_output = group
1132             .commit_builder()
1133             .add_member(kp1)
1134             .unwrap()
1135             .add_member(kp2)
1136             .unwrap()
1137             .build()
1138             .await
1139             .unwrap();
1140 
1141         assert_commit_builder_output(group, commit_output, expected_adds, 2);
1142     }
1143 
1144     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_empty_commit()1145     async fn test_commit_builder_empty_commit() {
1146         let mut group = test_commit_builder_group().await;
1147 
1148         let commit_output = group.commit_builder().build().await.unwrap();
1149 
1150         assert_commit_builder_output(group, commit_output, vec![], 0);
1151     }
1152 
1153     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_authenticated_data()1154     async fn test_commit_builder_authenticated_data() {
1155         let mut group = test_commit_builder_group().await;
1156         let test_data = "test".as_bytes().to_vec();
1157 
1158         let commit_output = group
1159             .commit_builder()
1160             .authenticated_data(test_data.clone())
1161             .build()
1162             .await
1163             .unwrap();
1164 
1165         assert_eq!(
1166             commit_output
1167                 .commit_message
1168                 .into_plaintext()
1169                 .unwrap()
1170                 .content
1171                 .authenticated_data,
1172             test_data
1173         );
1174     }
1175 
1176     #[cfg(feature = "by_ref_proposal")]
1177     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_commit_builder_multiple_welcome_messages()1178     async fn test_commit_builder_multiple_welcome_messages() {
1179         let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
1180             let options = CommitOptions::new().with_single_welcome_message(false);
1181             b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
1182         })
1183         .await;
1184 
1185         let (alice, alice_kp) =
1186             test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
1187 
1188         let (bob, bob_kp) =
1189             test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
1190 
1191         group
1192             .group
1193             .propose_add(alice_kp.clone(), vec![])
1194             .await
1195             .unwrap();
1196 
1197         group
1198             .group
1199             .propose_add(bob_kp.clone(), vec![])
1200             .await
1201             .unwrap();
1202 
1203         let output = group.group.commit(Vec::new()).await.unwrap();
1204         let welcomes = output.welcome_messages;
1205 
1206         let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
1207 
1208         for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
1209             let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
1210 
1211             let welcome = welcomes
1212                 .iter()
1213                 .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
1214                 .unwrap();
1215 
1216             client.join_group(None, welcome).await.unwrap();
1217 
1218             assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
1219         }
1220     }
1221 
1222     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_can_change_credential()1223     async fn commit_can_change_credential() {
1224         let cs = TEST_CIPHER_SUITE;
1225         let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
1226         let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
1227 
1228         let commit_output = groups[0]
1229             .group
1230             .commit_builder()
1231             .set_new_signing_identity(secret_key, identity.clone())
1232             .build()
1233             .await
1234             .unwrap();
1235 
1236         // Check that the credential was updated by in the committer's state.
1237         groups[0].process_pending_commit().await.unwrap();
1238         let new_member = groups[0].group.roster().member_with_index(0).unwrap();
1239 
1240         assert_eq!(
1241             new_member.signing_identity.credential,
1242             get_test_basic_credential(b"member".to_vec())
1243         );
1244 
1245         assert_eq!(
1246             new_member.signing_identity.signature_key,
1247             identity.signature_key
1248         );
1249 
1250         // Check that the credential was updated in another member's state.
1251         groups[1]
1252             .process_message(commit_output.commit_message)
1253             .await
1254             .unwrap();
1255 
1256         let new_member = groups[1].group.roster().member_with_index(0).unwrap();
1257 
1258         assert_eq!(
1259             new_member.signing_identity.credential,
1260             get_test_basic_credential(b"member".to_vec())
1261         );
1262 
1263         assert_eq!(
1264             new_member.signing_identity.signature_key,
1265             identity.signature_key
1266         );
1267     }
1268 
1269     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_includes_tree_if_no_ratchet_tree_ext()1270     async fn commit_includes_tree_if_no_ratchet_tree_ext() {
1271         let mut group = test_group_custom(
1272             TEST_PROTOCOL_VERSION,
1273             TEST_CIPHER_SUITE,
1274             Default::default(),
1275             None,
1276             Some(CommitOptions::new().with_ratchet_tree_extension(false)),
1277         )
1278         .await
1279         .group;
1280 
1281         let commit = group.commit(vec![]).await.unwrap();
1282 
1283         group.apply_pending_commit().await.unwrap();
1284 
1285         let new_tree = group.export_tree();
1286 
1287         assert_eq!(new_tree, commit.ratchet_tree.unwrap())
1288     }
1289 
1290     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_does_not_include_tree_if_ratchet_tree_ext()1291     async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
1292         let mut group = test_group_custom(
1293             TEST_PROTOCOL_VERSION,
1294             TEST_CIPHER_SUITE,
1295             Default::default(),
1296             None,
1297             Some(CommitOptions::new().with_ratchet_tree_extension(true)),
1298         )
1299         .await
1300         .group;
1301 
1302         let commit = group.commit(vec![]).await.unwrap();
1303 
1304         assert!(commit.ratchet_tree.is_none());
1305     }
1306 
1307     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_includes_external_commit_group_info_if_requested()1308     async fn commit_includes_external_commit_group_info_if_requested() {
1309         let mut group = test_group_custom(
1310             TEST_PROTOCOL_VERSION,
1311             TEST_CIPHER_SUITE,
1312             Default::default(),
1313             None,
1314             Some(
1315                 CommitOptions::new()
1316                     .with_allow_external_commit(true)
1317                     .with_ratchet_tree_extension(false),
1318             ),
1319         )
1320         .await
1321         .group;
1322 
1323         let commit = group.commit(vec![]).await.unwrap();
1324 
1325         let info = commit
1326             .external_commit_group_info
1327             .unwrap()
1328             .into_group_info()
1329             .unwrap();
1330 
1331         assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1332         assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1333     }
1334 
1335     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_includes_external_commit_and_tree_if_requested()1336     async fn commit_includes_external_commit_and_tree_if_requested() {
1337         let mut group = test_group_custom(
1338             TEST_PROTOCOL_VERSION,
1339             TEST_CIPHER_SUITE,
1340             Default::default(),
1341             None,
1342             Some(
1343                 CommitOptions::new()
1344                     .with_allow_external_commit(true)
1345                     .with_ratchet_tree_extension(true),
1346             ),
1347         )
1348         .await
1349         .group;
1350 
1351         let commit = group.commit(vec![]).await.unwrap();
1352 
1353         let info = commit
1354             .external_commit_group_info
1355             .unwrap()
1356             .into_group_info()
1357             .unwrap();
1358 
1359         assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
1360         assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
1361     }
1362 
1363     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
commit_does_not_include_external_commit_group_info_if_not_requested()1364     async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
1365         let mut group = test_group_custom(
1366             TEST_PROTOCOL_VERSION,
1367             TEST_CIPHER_SUITE,
1368             Default::default(),
1369             None,
1370             Some(CommitOptions::new().with_allow_external_commit(false)),
1371         )
1372         .await
1373         .group;
1374 
1375         let commit = group.commit(vec![]).await.unwrap();
1376 
1377         assert!(commit.external_commit_group_info.is_none());
1378     }
1379 
1380     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
member_identity_is_validated_against_new_extensions()1381     async fn member_identity_is_validated_against_new_extensions() {
1382         let alice = client_with_test_extension(b"alice").await;
1383         let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
1384 
1385         let bob = client_with_test_extension(b"bob").await;
1386         let bob_kp = bob.generate_key_package_message().await.unwrap();
1387 
1388         let mut extension_list = ExtensionList::new();
1389         let extension = TestExtension { foo: b'a' };
1390         extension_list.set_from(extension).unwrap();
1391 
1392         let res = alice
1393             .commit_builder()
1394             .add_member(bob_kp)
1395             .unwrap()
1396             .set_group_context_ext(extension_list.clone())
1397             .unwrap()
1398             .build()
1399             .await;
1400 
1401         assert!(res.is_err());
1402 
1403         let alex = client_with_test_extension(b"alex").await;
1404 
1405         alice
1406             .commit_builder()
1407             .add_member(alex.generate_key_package_message().await.unwrap())
1408             .unwrap()
1409             .set_group_context_ext(extension_list.clone())
1410             .unwrap()
1411             .build()
1412             .await
1413             .unwrap();
1414     }
1415 
1416     #[cfg(feature = "by_ref_proposal")]
1417     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
server_identity_is_validated_against_new_extensions()1418     async fn server_identity_is_validated_against_new_extensions() {
1419         let alice = client_with_test_extension(b"alice").await;
1420         let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
1421 
1422         let mut extension_list = ExtensionList::new();
1423         let extension = TestExtension { foo: b'a' };
1424         extension_list.set_from(extension).unwrap();
1425 
1426         let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
1427 
1428         let mut alex_extensions = extension_list.clone();
1429 
1430         alex_extensions
1431             .set_from(ExternalSendersExt {
1432                 allowed_senders: vec![alex_server],
1433             })
1434             .unwrap();
1435 
1436         let res = alice
1437             .commit_builder()
1438             .set_group_context_ext(alex_extensions)
1439             .unwrap()
1440             .build()
1441             .await;
1442 
1443         assert!(res.is_err());
1444 
1445         let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
1446 
1447         let mut bob_extensions = extension_list;
1448 
1449         bob_extensions
1450             .set_from(ExternalSendersExt {
1451                 allowed_senders: vec![bob_server],
1452             })
1453             .unwrap();
1454 
1455         alice
1456             .commit_builder()
1457             .set_group_context_ext(bob_extensions)
1458             .unwrap()
1459             .build()
1460             .await
1461             .unwrap();
1462     }
1463 
1464     #[derive(Debug, Clone)]
1465     struct IdentityProviderWithExtension(BasicIdentityProvider);
1466 
1467     #[derive(Clone, Debug)]
1468     #[cfg_attr(feature = "std", derive(thiserror::Error))]
1469     #[cfg_attr(feature = "std", error("test error"))]
1470     struct IdentityProviderWithExtensionError {}
1471 
1472     impl IntoAnyError for IdentityProviderWithExtensionError {
1473         #[cfg(feature = "std")]
into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self>1474         fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
1475             Ok(self.into())
1476         }
1477     }
1478 
1479     impl IdentityProviderWithExtension {
1480         // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
1481         // is not set.
1482         #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
starts_with_foo( &self, identity: &SigningIdentity, _timestamp: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> bool1483         async fn starts_with_foo(
1484             &self,
1485             identity: &SigningIdentity,
1486             _timestamp: Option<MlsTime>,
1487             extensions: Option<&ExtensionList>,
1488         ) -> bool {
1489             if let Some(extensions) = extensions {
1490                 if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
1491                     self.identity(identity, extensions).await.unwrap()[0] == ext.foo
1492                 } else {
1493                     true
1494                 }
1495             } else {
1496                 true
1497             }
1498         }
1499     }
1500 
1501     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
1502     #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
1503     impl IdentityProvider for IdentityProviderWithExtension {
1504         type Error = IdentityProviderWithExtensionError;
1505 
validate_member( &self, identity: &SigningIdentity, timestamp: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error>1506         async fn validate_member(
1507             &self,
1508             identity: &SigningIdentity,
1509             timestamp: Option<MlsTime>,
1510             extensions: Option<&ExtensionList>,
1511         ) -> Result<(), Self::Error> {
1512             self.starts_with_foo(identity, timestamp, extensions)
1513                 .await
1514                 .then_some(())
1515                 .ok_or(IdentityProviderWithExtensionError {})
1516         }
1517 
validate_external_sender( &self, identity: &SigningIdentity, timestamp: Option<MlsTime>, extensions: Option<&ExtensionList>, ) -> Result<(), Self::Error>1518         async fn validate_external_sender(
1519             &self,
1520             identity: &SigningIdentity,
1521             timestamp: Option<MlsTime>,
1522             extensions: Option<&ExtensionList>,
1523         ) -> Result<(), Self::Error> {
1524             (!self.starts_with_foo(identity, timestamp, extensions).await)
1525                 .then_some(())
1526                 .ok_or(IdentityProviderWithExtensionError {})
1527         }
1528 
identity( &self, signing_identity: &SigningIdentity, extensions: &ExtensionList, ) -> Result<Vec<u8>, Self::Error>1529         async fn identity(
1530             &self,
1531             signing_identity: &SigningIdentity,
1532             extensions: &ExtensionList,
1533         ) -> Result<Vec<u8>, Self::Error> {
1534             self.0
1535                 .identity(signing_identity, extensions)
1536                 .await
1537                 .map_err(|_| IdentityProviderWithExtensionError {})
1538         }
1539 
valid_successor( &self, _predecessor: &SigningIdentity, _successor: &SigningIdentity, _extensions: &ExtensionList, ) -> Result<bool, Self::Error>1540         async fn valid_successor(
1541             &self,
1542             _predecessor: &SigningIdentity,
1543             _successor: &SigningIdentity,
1544             _extensions: &ExtensionList,
1545         ) -> Result<bool, Self::Error> {
1546             Ok(true)
1547         }
1548 
supported_types(&self) -> Vec<CredentialType>1549         fn supported_types(&self) -> Vec<CredentialType> {
1550             self.0.supported_types()
1551         }
1552     }
1553 
1554     type ExtensionClientConfig = WithIdentityProvider<
1555         IdentityProviderWithExtension,
1556         WithCryptoProvider<TestCryptoProvider, BaseConfig>,
1557     >;
1558 
1559     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig>1560     async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
1561         let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
1562 
1563         ClientBuilder::new()
1564             .crypto_provider(TestCryptoProvider::new())
1565             .extension_types(vec![TEST_EXTENSION_TYPE.into()])
1566             .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
1567             .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
1568             .build()
1569     }
1570 }
1571