• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![allow(missing_docs)]
2 // Copyright 2023 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 use crate::d2d_connection_context_v1::D2DConnectionContextV1;
17 use crypto_provider::CryptoProvider;
18 use rand::{rngs::StdRng, SeedableRng as _};
19 use std::{collections::HashSet, mem};
20 use ukey2_rs::{
21     CompletedHandshake, HandshakeImplementation, StateMachine, Ukey2Client, Ukey2ClientStage1,
22     Ukey2Server, Ukey2ServerStage1, Ukey2ServerStage2,
23 };
24 
/// Error returned when the result of a handshake is requested from a handshake context that
/// cannot provide it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeError {
    /// The handshake has not reached its completed state. The initiator also reports this when
    /// the negotiated next protocol is not one it can build a connection context for.
    HandshakeNotComplete,
}

impl std::fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::HandshakeNotComplete => f.write_str("handshake is not complete"),
        }
    }
}

impl std::error::Error for HandshakeError {}
29 
/// Error returned by
/// [`D2DHandshakeContext::handle_handshake_message`] when an incoming handshake message cannot
/// be processed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandleMessageError {
    /// The supplied message was not applicable for the current state
    InvalidState,
    /// Handling the message produced an error that should be sent to the other party.
    /// The payload is the encoded alert message to transmit.
    ErrorMessage(Vec<u8>),
    /// The message was malformed.
    BadMessage,
}

impl std::fmt::Display for HandleMessageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidState => {
                f.write_str("message is not applicable in the current handshake state")
            }
            Self::ErrorMessage(alert) => write!(
                f,
                "handshake message was rejected; a {}-byte alert should be sent to the peer",
                alert.len()
            ),
            Self::BadMessage => f.write_str("malformed handshake message"),
        }
    }
}

impl std::error::Error for HandleMessageError {}
39 
40 /// Implements UKEY2 and produces a [`D2DConnectionContextV1`].
41 /// This class should be kept compatible with the Java and C++ implementations in
42 /// <https://github.com/google/ukey2>.
43 ///
44 /// For usage examples, see `ukey2_shell`. This file contains a shell exercising
45 /// both the initiator and responder handshake roles.
46 pub trait D2DHandshakeContext<R = rand::rngs::StdRng>: Send
47 where
48     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
49 {
50     /// Tells the caller whether the handshake has completed or not. If the handshake is complete,
51     /// the caller may call [`to_connection_context`][Self::to_connection_context] to obtain a
52     /// connection context.
53     ///
54     /// Returns true if the handshake is complete, false otherwise.
is_handshake_complete(&self) -> bool55     fn is_handshake_complete(&self) -> bool;
56 
57     /// Constructs the next message that should be sent in the handshake.
58     ///
59     /// Returns the next message or `None` if the handshake is over.
get_next_handshake_message(&self) -> Option<Vec<u8>>60     fn get_next_handshake_message(&self) -> Option<Vec<u8>>;
61 
62     /// Parses a handshake message and advances the internal state of the context.
63     ///
64     /// * `handshakeMessage` - message received from the remote end in the handshake
handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>65     fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>;
66 
67     /// Creates a [`D2DConnectionContextV1`] using the results of the handshake. May only be called
68     /// if [`is_handshake_complete`][Self::is_handshake_complete] returns true. Before trusting the
69     /// connection, callers should check that `to_completed_handshake().auth_string()` matches on
70     /// the client and server sides first. See the documentation for
71     /// [`to_completed_handshake`][Self::to_completed_handshake].
to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>72     fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>;
73 
74     /// Returns the [`CompletedHandshake`] using the results from this handshake context. May only
75     /// be called if [`is_handshake_complete`][Self::is_handshake_complete] returns true.
76     /// Callers should verify that the authentication strings from
77     /// `to_completed_handshake().auth_string()` matches on the server and client sides before
78     /// trying to create a connection context. This authentication string verification needs to be
79     /// done out-of-band, either by displaying the string to the user, or verified by some other
80     /// secure means.
to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>81     fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>;
82 }
83 
84 enum InitiatorState<C: CryptoProvider> {
85     Stage1(Ukey2ClientStage1<C>),
86     Complete(Ukey2Client),
87     /// If the initiator enters into an invalid state, e.g. by receiving invalid input.
88     /// Also a momentary placeholder while swapping out states.
89     Invalid,
90 }
91 
92 /// Implementation of [`D2DHandshakeContext`] for the initiator (a.k.a the client).
93 pub struct InitiatorD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng>
94 where
95     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
96 {
97     state: InitiatorState<C>,
98     rng: R,
99 }
100 
101 impl<C: CryptoProvider> InitiatorD2DHandshakeContext<C, rand::rngs::StdRng> {
new(handshake_impl: HandshakeImplementation) -> Self102     pub fn new(handshake_impl: HandshakeImplementation) -> Self {
103         Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy())
104     }
105 }
106 
107 impl<C: CryptoProvider, R> InitiatorD2DHandshakeContext<C, R>
108 where
109     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
110 {
111     // Used for testing / fuzzing only.
112     #[doc(hidden)]
new_impl(handshake_impl: HandshakeImplementation, mut rng: R) -> Self113     pub fn new_impl(handshake_impl: HandshakeImplementation, mut rng: R) -> Self {
114         let client = Ukey2ClientStage1::from(
115             &mut rng,
116             D2DConnectionContextV1::<StdRng>::NEXT_PROTOCOL_IDENTIFIER.to_owned(),
117             handshake_impl,
118         );
119         Self {
120             state: InitiatorState::Stage1(client),
121             rng,
122         }
123     }
124 }
125 
126 impl<C: CryptoProvider, R> D2DHandshakeContext<R> for InitiatorD2DHandshakeContext<C, R>
127 where
128     R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send,
129 {
is_handshake_complete(&self) -> bool130     fn is_handshake_complete(&self) -> bool {
131         match self.state {
132             InitiatorState::Stage1(_) => false,
133             InitiatorState::Complete(_) => true,
134             InitiatorState::Invalid => false,
135         }
136     }
137 
get_next_handshake_message(&self) -> Option<Vec<u8>>138     fn get_next_handshake_message(&self) -> Option<Vec<u8>> {
139         let next_msg = match &self.state {
140             InitiatorState::Stage1(c) => Some(c.client_init_msg().to_vec()),
141             InitiatorState::Complete(c) => Some(c.client_finished_msg().to_vec()),
142             InitiatorState::Invalid => None,
143         }?;
144         Some(next_msg)
145     }
146 
handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>147     fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> {
148         match mem::replace(&mut self.state, InitiatorState::Invalid) {
149             InitiatorState::Stage1(c) => {
150                 let client = c
151                     .advance_state(&mut self.rng, message)
152                     .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?;
153                 self.state = InitiatorState::Complete(client);
154                 Ok(())
155             }
156             InitiatorState::Complete(_) | InitiatorState::Invalid => {
157                 // already in invalid state, so leave it as is
158                 Err(HandleMessageError::InvalidState)
159             }
160         }
161     }
162 
to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>163     fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> {
164         match &self.state {
165             InitiatorState::Stage1(_) | InitiatorState::Invalid => {
166                 Err(HandshakeError::HandshakeNotComplete)
167             }
168             InitiatorState::Complete(c) => Ok(c.completed_handshake()),
169         }
170     }
171 
to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>172     fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> {
173         // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng
174         // should never fail. https://rust-random.github.io/book/guide-err.html
175         let rng = R::from_rng(&mut self.rng).unwrap();
176         self.to_completed_handshake()
177             .and_then(|h| match h.next_protocol.as_ref() {
178                 D2DConnectionContextV1::<R>::NEXT_PROTOCOL_IDENTIFIER => Ok(
179                     D2DConnectionContextV1::from_initiator_handshake::<C>(h, rng),
180                 ),
181                 _ => Err(HandshakeError::HandshakeNotComplete),
182             })
183     }
184 }
185 
186 enum ServerState<C: CryptoProvider> {
187     Stage1(Ukey2ServerStage1<C>),
188     Stage2(Ukey2ServerStage2<C>),
189     Complete(Ukey2Server),
190     /// If the initiator enters into an invalid state, e.g. by receiving invalid input.
191     /// Also a momentary placeholder while swapping out states.
192     Invalid,
193 }
194 
195 /// Implementation of [`D2DHandshakeContext`] for the server.
196 pub struct ServerD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng>
197 where
198     R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send,
199 {
200     state: ServerState<C>,
201     rng: R,
202 }
203 
204 impl<C: CryptoProvider> ServerD2DHandshakeContext<C, rand::rngs::StdRng> {
new(handshake_impl: HandshakeImplementation) -> Self205     pub fn new(handshake_impl: HandshakeImplementation) -> Self {
206         Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy())
207     }
208 }
209 
210 impl<C: CryptoProvider, R> ServerD2DHandshakeContext<C, R>
211 where
212     R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send,
213 {
214     // Used for testing / fuzzing only.
215     #[doc(hidden)]
new_impl(handshake_impl: HandshakeImplementation, rng: R) -> Self216     pub fn new_impl(handshake_impl: HandshakeImplementation, rng: R) -> Self {
217         Self {
218             state: ServerState::Stage1(Ukey2ServerStage1::from(
219                 HashSet::from([
220                     D2DConnectionContextV1::<rand::rngs::StdRng>::NEXT_PROTOCOL_IDENTIFIER
221                         .to_owned(),
222                 ]),
223                 handshake_impl,
224             )),
225             rng,
226         }
227     }
228 }
229 
230 impl<C, R> D2DHandshakeContext<R> for ServerD2DHandshakeContext<C, R>
231 where
232     C: CryptoProvider,
233     R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send,
234 {
is_handshake_complete(&self) -> bool235     fn is_handshake_complete(&self) -> bool {
236         match &self.state {
237             ServerState::Complete(_) => true,
238             ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => false,
239         }
240     }
241 
get_next_handshake_message(&self) -> Option<Vec<u8>>242     fn get_next_handshake_message(&self) -> Option<Vec<u8>> {
243         let next_msg = match &self.state {
244             ServerState::Stage1(_) => None,
245             ServerState::Stage2(s) => Some(s.server_init_msg().to_vec()),
246             ServerState::Complete(_) => None,
247             ServerState::Invalid => None,
248         }?;
249         Some(next_msg)
250     }
251 
handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>252     fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> {
253         match mem::replace(&mut self.state, ServerState::Invalid) {
254             ServerState::Stage1(s) => {
255                 let server2 = s
256                     .advance_state(&mut self.rng, message)
257                     .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?;
258                 self.state = ServerState::Stage2(server2);
259                 Ok(())
260             }
261             ServerState::Stage2(s) => {
262                 let server = s
263                     .advance_state(&mut self.rng, message)
264                     .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?;
265                 self.state = ServerState::Complete(server);
266                 Ok(())
267             }
268             ServerState::Complete(_) | ServerState::Invalid => {
269                 Err(HandleMessageError::InvalidState)
270             }
271         }
272     }
273 
to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>274     fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> {
275         match &self.state {
276             ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => {
277                 Err(HandshakeError::HandshakeNotComplete)
278             }
279             ServerState::Complete(s) => Ok(s.completed_handshake()),
280         }
281     }
282 
to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>283     fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> {
284         // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng
285         // should never fail. https://rust-random.github.io/book/guide-err.html
286         let rng = R::from_rng(&mut self.rng).unwrap();
287         self.to_completed_handshake()
288             .map(|h| match h.next_protocol.as_ref() {
289                 D2DConnectionContextV1::<R>::NEXT_PROTOCOL_IDENTIFIER => {
290                     D2DConnectionContextV1::from_responder_handshake::<C>(h, rng)
291                 }
292                 _ => {
293                     // This should never happen because ukey2_handshake should set next_protocol to
294                     // one of the values we passed in Ukey2ServerStage1::from, which doesn't contain
295                     // any other value.
296                     panic!("Unknown next protocol: {}", h.next_protocol);
297                 }
298             })
299     }
300 }
301