• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::fmt::Debug;
16 
17 use log::error;
18 
19 use crypto_provider::CryptoProvider;
20 use ukey2_proto::protobuf::Message;
21 use ukey2_proto::ukey2_all_proto::ukey;
22 
23 use crate::proto_adapter::{IntoAdapter, MessageType, ToWrappedMessage as _};
24 use crate::ukey2_handshake::ClientFinishedError;
25 use crate::ukey2_handshake::{
26     ClientInit, ClientInitError, Ukey2Client, Ukey2ClientStage1, Ukey2Server, Ukey2ServerStage1,
27     Ukey2ServerStage2,
28 };
29 
30 /// An alert type and message to be sent to the other party.
31 #[derive(Debug, PartialEq, Eq)]
32 pub struct SendAlert {
33     alert_type: ukey::ukey2alert::AlertType,
34     msg: Option<String>,
35 }
36 
37 impl SendAlert {
from(alert_type: ukey::ukey2alert::AlertType, msg: Option<String>) -> Self38     pub(crate) fn from(alert_type: ukey::ukey2alert::AlertType, msg: Option<String>) -> Self {
39         Self { alert_type, msg }
40     }
41 
42     /// Convert this `SendAlert` into serialized bytes of the `Ukey2Alert` protobuf message.
into_wrapped_alert_msg(self) -> Vec<u8>43     pub fn into_wrapped_alert_msg(self) -> Vec<u8> {
44         let alert_message = ukey::Ukey2Alert {
45             type_: Some(self.alert_type.into()),
46             error_message: self.msg,
47             ..Default::default()
48         };
49         #[allow(clippy::expect_used)]
50         alert_message.to_wrapped_msg().write_to_bytes().expect("Writing to proto should succeed")
51     }
52 }
53 
54 /// Generic trait for implementation of a state machine. Each state in this machine has two possible
55 /// transitions – Success and Failure.
56 ///
57 /// On Success, the machine will transition to the next state, represented by the associated type
58 /// `Success`.
59 ///
60 /// On Failure, a [`SendAlert`] message is returned indicating the failure, and there no further
61 /// transitions will be possible on this state machine.
62 ///
63 /// ### State transitions
64 ///
65 /// Here are the states both parties of the handshake goes through, with the Failure transitions
66 /// omitted to keep the documentation simple.
67 ///
68 /// ```text
69 ///          Ukey2ClientStage1               Ukey2ServerStage1
70 ///                 |
71 ///                 | -------[msg: ClientInit]-----> |
72 ///                                                  |
73 ///                                           Ukey2ServerStage2
74 ///                                                  |
75 ///                 | <------[msg: ServerInit]------ |
76 ///                 |
77 ///              Ukey2Client
78 ///                 |
79 ///                 | -----[msg: ClientFinished]---> |
80 ///                                                  |
81 ///                                              Ukey2Server
82 /// ```
83 ///
84 pub trait StateMachine {
85     /// The type produced by each successful state transition
86     type Success;
87 
88     /// Advance to the next state in the relevant half (client/server) of the protocol.
advance_state<R: rand::Rng + rand::CryptoRng>( self, rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>89     fn advance_state<R: rand::Rng + rand::CryptoRng>(
90         self,
91         rng: &mut R,
92         message_bytes: &[u8],
93     ) -> Result<Self::Success, SendAlert>;
94 }
95 
96 impl<C: CryptoProvider> StateMachine for Ukey2ClientStage1<C> {
97     type Success = Ukey2Client;
98 
advance_state<R: rand::Rng + rand::CryptoRng>( self, _rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>99     fn advance_state<R: rand::Rng + rand::CryptoRng>(
100         self,
101         _rng: &mut R,
102         message_bytes: &[u8],
103     ) -> Result<Self::Success, SendAlert> {
104         let (message_data, message_type) = decode_wrapper_msg_and_type(message_bytes)?;
105 
106         match message_type {
107             // Client should not be receiving ClientInit/ClientFinish
108             MessageType::ClientInit | MessageType::ClientFinish => Err(SendAlert::from(
109                 ukey::ukey2alert::AlertType::INCORRECT_MESSAGE,
110                 Some("wrong message".to_string()),
111             )),
112             MessageType::ServerInit => {
113                 let message = decode_msg_contents::<_, ukey::Ukey2ServerInit>(message_data)?;
114                 self.handle_server_init(message, message_bytes.to_vec()).map_err(|_| {
115                     SendAlert::from(
116                         ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA,
117                         Some("bad message_data".to_string()),
118                     )
119                 })
120             }
121         }
122     }
123 }
124 
125 impl<C: CryptoProvider> StateMachine for Ukey2ServerStage1<C> {
126     type Success = Ukey2ServerStage2<C>;
127 
advance_state<R: rand::Rng + rand::CryptoRng>( self, rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>128     fn advance_state<R: rand::Rng + rand::CryptoRng>(
129         self,
130         rng: &mut R,
131         message_bytes: &[u8],
132     ) -> Result<Self::Success, SendAlert> {
133         let (message_data, message_type) = decode_wrapper_msg_and_type(message_bytes)?;
134         match message_type {
135             MessageType::ClientInit => {
136                 let message: ClientInit =
137                     decode_msg_contents::<_, ukey::Ukey2ClientInit>(message_data)?;
138                 self.handle_client_init(rng, message, message_bytes.to_vec()).map_err(|e| {
139                     SendAlert::from(
140                         match e {
141                             ClientInitError::BadVersion => ukey::ukey2alert::AlertType::BAD_VERSION,
142                             ClientInitError::BadHandshakeCipher => {
143                                 ukey::ukey2alert::AlertType::BAD_HANDSHAKE_CIPHER
144                             }
145                             ClientInitError::BadNextProtocol => {
146                                 ukey::ukey2alert::AlertType::BAD_NEXT_PROTOCOL
147                             }
148                         },
149                         None,
150                     )
151                 })
152             }
153             MessageType::ClientFinish | MessageType::ServerInit => Err(SendAlert::from(
154                 ukey::ukey2alert::AlertType::INCORRECT_MESSAGE,
155                 Some("wrong message".to_string()),
156             )),
157         }
158     }
159 }
160 
161 impl<C: CryptoProvider> StateMachine for Ukey2ServerStage2<C> {
162     type Success = Ukey2Server;
163 
advance_state<R: rand::Rng + rand::CryptoRng>( self, _rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>164     fn advance_state<R: rand::Rng + rand::CryptoRng>(
165         self,
166         _rng: &mut R,
167         message_bytes: &[u8],
168     ) -> Result<Self::Success, SendAlert> {
169         let (message_data, message_type) = decode_wrapper_msg_and_type(message_bytes)?;
170         match message_type {
171             MessageType::ClientFinish => {
172                 let message = decode_msg_contents::<_, ukey::Ukey2ClientFinished>(message_data)?;
173                 self.handle_client_finished_msg(message, message_bytes).map_err(|e| match e {
174                     ClientFinishedError::BadEd25519Key => SendAlert::from(
175                         ukey::ukey2alert::AlertType::BAD_PUBLIC_KEY,
176                         "Bad ED25519 Key".to_string().into(),
177                     ),
178                     ClientFinishedError::BadP256Key => SendAlert::from(
179                         ukey::ukey2alert::AlertType::BAD_PUBLIC_KEY,
180                         "Bad P256 Key".to_string().into(),
181                     ),
182                     ClientFinishedError::UnknownCommitment => SendAlert::from(
183                         ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA,
184                         "Unknown commitment".to_string().into(),
185                     ),
186                     ClientFinishedError::BadKeyExchange => SendAlert::from(
187                         ukey::ukey2alert::AlertType::INTERNAL_ERROR,
188                         "Key exchange error".to_string().into(),
189                     ),
190                 })
191             }
192             MessageType::ClientInit | MessageType::ServerInit => Err(SendAlert::from(
193                 ukey::ukey2alert::AlertType::INCORRECT_MESSAGE,
194                 "wrong message".to_string().into(),
195             )),
196         }
197     }
198 }
199 
200 /// Extract the message field and message type from a Ukey2Message
decode_wrapper_msg_and_type(bytes: &[u8]) -> Result<(Vec<u8>, MessageType), SendAlert>201 fn decode_wrapper_msg_and_type(bytes: &[u8]) -> Result<(Vec<u8>, MessageType), SendAlert> {
202     ukey::Ukey2Message::parse_from_bytes(bytes)
203         .map_err(|_| {
204             error!("Unable to marshal into Ukey2Message");
205             SendAlert::from(
206                 ukey::ukey2alert::AlertType::BAD_MESSAGE,
207                 Some("Bad message data".to_string()),
208             )
209         })
210         .and_then(|message| {
211             let message_data = message.message_data();
212             if message_data.is_empty() {
213                 return Err(SendAlert::from(ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA, None));
214             }
215             let message_type = message.message_type();
216             if message_type == ukey::ukey2message::Type::UNKNOWN_DO_NOT_USE {
217                 return Err(SendAlert::from(ukey::ukey2alert::AlertType::BAD_MESSAGE_TYPE, None));
218             }
219             message_type
220                 .into_adapter()
221                 .map_err(|e| {
222                     error!("Unknown UKEY2 Message Type");
223                     SendAlert::from(e, Some("bad message type".to_string()))
224                 })
225                 .map(|message_type| (message_data.to_vec(), message_type))
226         })
227 }
228 
229 /// Extract a specific message type from message data in a Ukey2Messaage
230 ///
231 /// See [decode_wrapper_msg_and_type] for getting the message data.
decode_msg_contents<A, M: Message + Default + IntoAdapter<A>>( message_data: Vec<u8>, ) -> Result<A, SendAlert>232 fn decode_msg_contents<A, M: Message + Default + IntoAdapter<A>>(
233     message_data: Vec<u8>,
234 ) -> Result<A, SendAlert> {
235     M::parse_from_bytes(message_data.as_slice())
236         .map_err(|_| {
237             error!(
238                 "Unable to unmarshal message, check frame of the message you were trying to send"
239             );
240             SendAlert::from(
241                 ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA,
242                 Some("frame error".to_string()),
243             )
244         })?
245         .into_adapter()
246         .map_err(|t| SendAlert::from(t, Some("failed to translate proto".to_string())))
247 }
248