1 #![allow(missing_docs)] 2 // Copyright 2023 Google LLC 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use crate::d2d_connection_context_v1::D2DConnectionContextV1; 17 use crypto_provider::CryptoProvider; 18 use rand::{rngs::StdRng, SeedableRng as _}; 19 use std::{collections::HashSet, mem}; 20 use ukey2_rs::{ 21 CompletedHandshake, HandshakeImplementation, StateMachine, Ukey2Client, Ukey2ClientStage1, 22 Ukey2Server, Ukey2ServerStage1, Ukey2ServerStage2, 23 }; 24 25 #[derive(Debug)] 26 pub enum HandshakeError { 27 HandshakeNotComplete, 28 } 29 30 #[derive(Debug)] 31 pub enum HandleMessageError { 32 /// The supplied message was not applicable for the current state 33 InvalidState, 34 /// Handling the message produced an error that should be sent to the other party 35 ErrorMessage(Vec<u8>), 36 /// Bad message 37 BadMessage, 38 } 39 40 /// Implements UKEY2 and produces a [`D2DConnectionContextV1`]. 41 /// This class should be kept compatible with the Java and C++ implementations in 42 /// <https://github.com/google/ukey2>. 43 /// 44 /// For usage examples, see `ukey2_shell`. This file contains a shell exercising 45 /// both the initiator and responder handshake roles. 46 pub trait D2DHandshakeContext<R = rand::rngs::StdRng>: Send 47 where 48 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 49 { 50 /// Tells the caller whether the handshake has completed or not. If the handshake is complete, 51 /// the caller may call [`to_connection_context`][Self::to_connection_context] to obtain a 52 /// connection context. 53 /// 54 /// Returns true if the handshake is complete, false otherwise. is_handshake_complete(&self) -> bool55 fn is_handshake_complete(&self) -> bool; 56 57 /// Constructs the next message that should be sent in the handshake. 58 /// 59 /// Returns the next message or `None` if the handshake is over. get_next_handshake_message(&self) -> Option<Vec<u8>>60 fn get_next_handshake_message(&self) -> Option<Vec<u8>>; 61 62 /// Parses a handshake message and advances the internal state of the context. 63 /// 64 /// * `handshakeMessage` - message received from the remote end in the handshake handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>65 fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>; 66 67 /// Creates a [`D2DConnectionContextV1`] using the results of the handshake. May only be called 68 /// if [`is_handshake_complete`][Self::is_handshake_complete] returns true. Before trusting the 69 /// connection, callers should check that `to_completed_handshake().auth_string()` matches on 70 /// the client and server sides first. See the documentation for 71 /// [`to_completed_handshake`][Self::to_completed_handshake]. to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>72 fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>; 73 74 /// Returns the [`CompletedHandshake`] using the results from this handshake context. May only 75 /// be called if [`is_handshake_complete`][Self::is_handshake_complete] returns true. 76 /// Callers should verify that the authentication strings from 77 /// `to_completed_handshake().auth_string()` matches on the server and client sides before 78 /// trying to create a connection context. This authentication string verification needs to be 79 /// done out-of-band, either by displaying the string to the user, or verified by some other 80 /// secure means. to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>81 fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>; 82 } 83 84 enum InitiatorState<C: CryptoProvider> { 85 Stage1(Ukey2ClientStage1<C>), 86 Complete(Ukey2Client), 87 /// If the initiator enters into an invalid state, e.g. by receiving invalid input. 88 /// Also a momentary placeholder while swapping out states. 89 Invalid, 90 } 91 92 /// Implementation of [`D2DHandshakeContext`] for the initiator (a.k.a the client). 93 pub struct InitiatorD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng> 94 where 95 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 96 { 97 state: InitiatorState<C>, 98 rng: R, 99 } 100 101 impl<C: CryptoProvider> InitiatorD2DHandshakeContext<C, rand::rngs::StdRng> { new(handshake_impl: HandshakeImplementation) -> Self102 pub fn new(handshake_impl: HandshakeImplementation) -> Self { 103 Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy()) 104 } 105 } 106 107 impl<C: CryptoProvider, R> InitiatorD2DHandshakeContext<C, R> 108 where 109 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 110 { 111 // Used for testing / fuzzing only. 112 #[doc(hidden)] new_impl(handshake_impl: HandshakeImplementation, mut rng: R) -> Self113 pub fn new_impl(handshake_impl: HandshakeImplementation, mut rng: R) -> Self { 114 let client = Ukey2ClientStage1::from( 115 &mut rng, 116 D2DConnectionContextV1::<StdRng>::NEXT_PROTOCOL_IDENTIFIER.to_owned(), 117 handshake_impl, 118 ); 119 Self { 120 state: InitiatorState::Stage1(client), 121 rng, 122 } 123 } 124 } 125 126 impl<C: CryptoProvider, R> D2DHandshakeContext<R> for InitiatorD2DHandshakeContext<C, R> 127 where 128 R: rand::RngCore + rand::CryptoRng + rand::SeedableRng + Send, 129 { is_handshake_complete(&self) -> bool130 fn is_handshake_complete(&self) -> bool { 131 match self.state { 132 InitiatorState::Stage1(_) => false, 133 InitiatorState::Complete(_) => true, 134 InitiatorState::Invalid => false, 135 } 136 } 137 get_next_handshake_message(&self) -> Option<Vec<u8>>138 fn get_next_handshake_message(&self) -> Option<Vec<u8>> { 139 let next_msg = match &self.state { 140 InitiatorState::Stage1(c) => Some(c.client_init_msg().to_vec()), 141 InitiatorState::Complete(c) => Some(c.client_finished_msg().to_vec()), 142 InitiatorState::Invalid => None, 143 }?; 144 Some(next_msg) 145 } 146 handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>147 fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> { 148 match mem::replace(&mut self.state, InitiatorState::Invalid) { 149 InitiatorState::Stage1(c) => { 150 let client = c 151 .advance_state(&mut self.rng, message) 152 .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?; 153 self.state = InitiatorState::Complete(client); 154 Ok(()) 155 } 156 InitiatorState::Complete(_) | InitiatorState::Invalid => { 157 // already in invalid state, so leave it as is 158 Err(HandleMessageError::InvalidState) 159 } 160 } 161 } 162 to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>163 fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> { 164 match &self.state { 165 InitiatorState::Stage1(_) | InitiatorState::Invalid => { 166 Err(HandshakeError::HandshakeNotComplete) 167 } 168 InitiatorState::Complete(c) => Ok(c.completed_handshake()), 169 } 170 } 171 to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>172 fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> { 173 // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng 174 // should never fail. https://rust-random.github.io/book/guide-err.html 175 let rng = R::from_rng(&mut self.rng).unwrap(); 176 self.to_completed_handshake() 177 .and_then(|h| match h.next_protocol.as_ref() { 178 D2DConnectionContextV1::<R>::NEXT_PROTOCOL_IDENTIFIER => Ok( 179 D2DConnectionContextV1::from_initiator_handshake::<C>(h, rng), 180 ), 181 _ => Err(HandshakeError::HandshakeNotComplete), 182 }) 183 } 184 } 185 186 enum ServerState<C: CryptoProvider> { 187 Stage1(Ukey2ServerStage1<C>), 188 Stage2(Ukey2ServerStage2<C>), 189 Complete(Ukey2Server), 190 /// If the initiator enters into an invalid state, e.g. by receiving invalid input. 191 /// Also a momentary placeholder while swapping out states. 192 Invalid, 193 } 194 195 /// Implementation of [`D2DHandshakeContext`] for the server. 196 pub struct ServerD2DHandshakeContext<C: CryptoProvider, R = rand::rngs::StdRng> 197 where 198 R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send, 199 { 200 state: ServerState<C>, 201 rng: R, 202 } 203 204 impl<C: CryptoProvider> ServerD2DHandshakeContext<C, rand::rngs::StdRng> { new(handshake_impl: HandshakeImplementation) -> Self205 pub fn new(handshake_impl: HandshakeImplementation) -> Self { 206 Self::new_impl(handshake_impl, rand::rngs::StdRng::from_entropy()) 207 } 208 } 209 210 impl<C: CryptoProvider, R> ServerD2DHandshakeContext<C, R> 211 where 212 R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send, 213 { 214 // Used for testing / fuzzing only. 215 #[doc(hidden)] new_impl(handshake_impl: HandshakeImplementation, rng: R) -> Self216 pub fn new_impl(handshake_impl: HandshakeImplementation, rng: R) -> Self { 217 Self { 218 state: ServerState::Stage1(Ukey2ServerStage1::from( 219 HashSet::from([ 220 D2DConnectionContextV1::<rand::rngs::StdRng>::NEXT_PROTOCOL_IDENTIFIER 221 .to_owned(), 222 ]), 223 handshake_impl, 224 )), 225 rng, 226 } 227 } 228 } 229 230 impl<C, R> D2DHandshakeContext<R> for ServerD2DHandshakeContext<C, R> 231 where 232 C: CryptoProvider, 233 R: rand::Rng + rand::SeedableRng + rand::CryptoRng + Send, 234 { is_handshake_complete(&self) -> bool235 fn is_handshake_complete(&self) -> bool { 236 match &self.state { 237 ServerState::Complete(_) => true, 238 ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => false, 239 } 240 } 241 get_next_handshake_message(&self) -> Option<Vec<u8>>242 fn get_next_handshake_message(&self) -> Option<Vec<u8>> { 243 let next_msg = match &self.state { 244 ServerState::Stage1(_) => None, 245 ServerState::Stage2(s) => Some(s.server_init_msg().to_vec()), 246 ServerState::Complete(_) => None, 247 ServerState::Invalid => None, 248 }?; 249 Some(next_msg) 250 } 251 handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError>252 fn handle_handshake_message(&mut self, message: &[u8]) -> Result<(), HandleMessageError> { 253 match mem::replace(&mut self.state, ServerState::Invalid) { 254 ServerState::Stage1(s) => { 255 let server2 = s 256 .advance_state(&mut self.rng, message) 257 .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?; 258 self.state = ServerState::Stage2(server2); 259 Ok(()) 260 } 261 ServerState::Stage2(s) => { 262 let server = s 263 .advance_state(&mut self.rng, message) 264 .map_err(|a| HandleMessageError::ErrorMessage(a.into_wrapped_alert_msg()))?; 265 self.state = ServerState::Complete(server); 266 Ok(()) 267 } 268 ServerState::Complete(_) | ServerState::Invalid => { 269 Err(HandleMessageError::InvalidState) 270 } 271 } 272 } 273 to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError>274 fn to_completed_handshake(&self) -> Result<&CompletedHandshake, HandshakeError> { 275 match &self.state { 276 ServerState::Stage1(_) | ServerState::Stage2(_) | ServerState::Invalid => { 277 Err(HandshakeError::HandshakeNotComplete) 278 } 279 ServerState::Complete(s) => Ok(s.completed_handshake()), 280 } 281 } 282 to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError>283 fn to_connection_context(&mut self) -> Result<D2DConnectionContextV1<R>, HandshakeError> { 284 // Since self.rng is expected to be a seeded PRNG, not an OsRng directly, from_rng 285 // should never fail. https://rust-random.github.io/book/guide-err.html 286 let rng = R::from_rng(&mut self.rng).unwrap(); 287 self.to_completed_handshake() 288 .map(|h| match h.next_protocol.as_ref() { 289 D2DConnectionContextV1::<R>::NEXT_PROTOCOL_IDENTIFIER => { 290 D2DConnectionContextV1::from_responder_handshake::<C>(h, rng) 291 } 292 _ => { 293 // This should never happen because ukey2_handshake should set next_protocol to 294 // one of the values we passed in Ukey2ServerStage1::from, which doesn't contain 295 // any other value. 296 panic!("Unknown next protocol: {}", h.next_protocol); 297 } 298 }) 299 } 300 } 301