1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::fmt::Debug;
16
17 use log::error;
18
19 use crypto_provider::CryptoProvider;
20 use ukey2_proto::protobuf::Message;
21 use ukey2_proto::ukey2_all_proto::ukey;
22
23 use crate::proto_adapter::{IntoAdapter, MessageType, ToWrappedMessage as _};
24 use crate::ukey2_handshake::ClientFinishedError;
25 use crate::ukey2_handshake::{
26 ClientInit, ClientInitError, Ukey2Client, Ukey2ClientStage1, Ukey2Server, Ukey2ServerStage1,
27 Ukey2ServerStage2,
28 };
29
30 /// An alert type and message to be sent to the other party.
31 #[derive(Debug, PartialEq, Eq)]
32 pub struct SendAlert {
33 alert_type: ukey::ukey2alert::AlertType,
34 msg: Option<String>,
35 }
36
37 impl SendAlert {
from(alert_type: ukey::ukey2alert::AlertType, msg: Option<String>) -> Self38 pub(crate) fn from(alert_type: ukey::ukey2alert::AlertType, msg: Option<String>) -> Self {
39 Self { alert_type, msg }
40 }
41
42 /// Convert this `SendAlert` into serialized bytes of the `Ukey2Alert` protobuf message.
into_wrapped_alert_msg(self) -> Vec<u8>43 pub fn into_wrapped_alert_msg(self) -> Vec<u8> {
44 let alert_message = ukey::Ukey2Alert {
45 type_: Some(self.alert_type.into()),
46 error_message: self.msg,
47 ..Default::default()
48 };
49 #[allow(clippy::expect_used)]
50 alert_message.to_wrapped_msg().write_to_bytes().expect("Writing to proto should succeed")
51 }
52 }
53
54 /// Generic trait for implementation of a state machine. Each state in this machine has two possible
55 /// transitions – Success and Failure.
56 ///
57 /// On Success, the machine will transition to the next state, represented by the associated type
58 /// `Success`.
59 ///
60 /// On Failure, a [`SendAlert`] message is returned indicating the failure, and there no further
61 /// transitions will be possible on this state machine.
62 ///
63 /// ### State transitions
64 ///
65 /// Here are the states both parties of the handshake goes through, with the Failure transitions
66 /// omitted to keep the documentation simple.
67 ///
68 /// ```text
69 /// Ukey2ClientStage1 Ukey2ServerStage1
70 /// |
71 /// | -------[msg: ClientInit]-----> |
72 /// |
73 /// Ukey2ServerStage2
74 /// |
75 /// | <------[msg: ServerInit]------ |
76 /// |
77 /// Ukey2Client
78 /// |
79 /// | -----[msg: ClientFinished]---> |
80 /// |
81 /// Ukey2Server
82 /// ```
83 ///
84 pub trait StateMachine {
85 /// The type produced by each successful state transition
86 type Success;
87
88 /// Advance to the next state in the relevant half (client/server) of the protocol.
advance_state<R: rand::Rng + rand::CryptoRng>( self, rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>89 fn advance_state<R: rand::Rng + rand::CryptoRng>(
90 self,
91 rng: &mut R,
92 message_bytes: &[u8],
93 ) -> Result<Self::Success, SendAlert>;
94 }
95
96 impl<C: CryptoProvider> StateMachine for Ukey2ClientStage1<C> {
97 type Success = Ukey2Client;
98
advance_state<R: rand::Rng + rand::CryptoRng>( self, _rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>99 fn advance_state<R: rand::Rng + rand::CryptoRng>(
100 self,
101 _rng: &mut R,
102 message_bytes: &[u8],
103 ) -> Result<Self::Success, SendAlert> {
104 let (message_data, message_type) = decode_wrapper_msg_and_type(message_bytes)?;
105
106 match message_type {
107 // Client should not be receiving ClientInit/ClientFinish
108 MessageType::ClientInit | MessageType::ClientFinish => Err(SendAlert::from(
109 ukey::ukey2alert::AlertType::INCORRECT_MESSAGE,
110 Some("wrong message".to_string()),
111 )),
112 MessageType::ServerInit => {
113 let message = decode_msg_contents::<_, ukey::Ukey2ServerInit>(message_data)?;
114 self.handle_server_init(message, message_bytes.to_vec()).map_err(|_| {
115 SendAlert::from(
116 ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA,
117 Some("bad message_data".to_string()),
118 )
119 })
120 }
121 }
122 }
123 }
124
125 impl<C: CryptoProvider> StateMachine for Ukey2ServerStage1<C> {
126 type Success = Ukey2ServerStage2<C>;
127
advance_state<R: rand::Rng + rand::CryptoRng>( self, rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>128 fn advance_state<R: rand::Rng + rand::CryptoRng>(
129 self,
130 rng: &mut R,
131 message_bytes: &[u8],
132 ) -> Result<Self::Success, SendAlert> {
133 let (message_data, message_type) = decode_wrapper_msg_and_type(message_bytes)?;
134 match message_type {
135 MessageType::ClientInit => {
136 let message: ClientInit =
137 decode_msg_contents::<_, ukey::Ukey2ClientInit>(message_data)?;
138 self.handle_client_init(rng, message, message_bytes.to_vec()).map_err(|e| {
139 SendAlert::from(
140 match e {
141 ClientInitError::BadVersion => ukey::ukey2alert::AlertType::BAD_VERSION,
142 ClientInitError::BadHandshakeCipher => {
143 ukey::ukey2alert::AlertType::BAD_HANDSHAKE_CIPHER
144 }
145 ClientInitError::BadNextProtocol => {
146 ukey::ukey2alert::AlertType::BAD_NEXT_PROTOCOL
147 }
148 },
149 None,
150 )
151 })
152 }
153 MessageType::ClientFinish | MessageType::ServerInit => Err(SendAlert::from(
154 ukey::ukey2alert::AlertType::INCORRECT_MESSAGE,
155 Some("wrong message".to_string()),
156 )),
157 }
158 }
159 }
160
161 impl<C: CryptoProvider> StateMachine for Ukey2ServerStage2<C> {
162 type Success = Ukey2Server;
163
advance_state<R: rand::Rng + rand::CryptoRng>( self, _rng: &mut R, message_bytes: &[u8], ) -> Result<Self::Success, SendAlert>164 fn advance_state<R: rand::Rng + rand::CryptoRng>(
165 self,
166 _rng: &mut R,
167 message_bytes: &[u8],
168 ) -> Result<Self::Success, SendAlert> {
169 let (message_data, message_type) = decode_wrapper_msg_and_type(message_bytes)?;
170 match message_type {
171 MessageType::ClientFinish => {
172 let message = decode_msg_contents::<_, ukey::Ukey2ClientFinished>(message_data)?;
173 self.handle_client_finished_msg(message, message_bytes).map_err(|e| match e {
174 ClientFinishedError::BadEd25519Key => SendAlert::from(
175 ukey::ukey2alert::AlertType::BAD_PUBLIC_KEY,
176 "Bad ED25519 Key".to_string().into(),
177 ),
178 ClientFinishedError::BadP256Key => SendAlert::from(
179 ukey::ukey2alert::AlertType::BAD_PUBLIC_KEY,
180 "Bad P256 Key".to_string().into(),
181 ),
182 ClientFinishedError::UnknownCommitment => SendAlert::from(
183 ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA,
184 "Unknown commitment".to_string().into(),
185 ),
186 ClientFinishedError::BadKeyExchange => SendAlert::from(
187 ukey::ukey2alert::AlertType::INTERNAL_ERROR,
188 "Key exchange error".to_string().into(),
189 ),
190 })
191 }
192 MessageType::ClientInit | MessageType::ServerInit => Err(SendAlert::from(
193 ukey::ukey2alert::AlertType::INCORRECT_MESSAGE,
194 "wrong message".to_string().into(),
195 )),
196 }
197 }
198 }
199
200 /// Extract the message field and message type from a Ukey2Message
decode_wrapper_msg_and_type(bytes: &[u8]) -> Result<(Vec<u8>, MessageType), SendAlert>201 fn decode_wrapper_msg_and_type(bytes: &[u8]) -> Result<(Vec<u8>, MessageType), SendAlert> {
202 ukey::Ukey2Message::parse_from_bytes(bytes)
203 .map_err(|_| {
204 error!("Unable to marshal into Ukey2Message");
205 SendAlert::from(
206 ukey::ukey2alert::AlertType::BAD_MESSAGE,
207 Some("Bad message data".to_string()),
208 )
209 })
210 .and_then(|message| {
211 let message_data = message.message_data();
212 if message_data.is_empty() {
213 return Err(SendAlert::from(ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA, None));
214 }
215 let message_type = message.message_type();
216 if message_type == ukey::ukey2message::Type::UNKNOWN_DO_NOT_USE {
217 return Err(SendAlert::from(ukey::ukey2alert::AlertType::BAD_MESSAGE_TYPE, None));
218 }
219 message_type
220 .into_adapter()
221 .map_err(|e| {
222 error!("Unknown UKEY2 Message Type");
223 SendAlert::from(e, Some("bad message type".to_string()))
224 })
225 .map(|message_type| (message_data.to_vec(), message_type))
226 })
227 }
228
229 /// Extract a specific message type from message data in a Ukey2Messaage
230 ///
231 /// See [decode_wrapper_msg_and_type] for getting the message data.
decode_msg_contents<A, M: Message + Default + IntoAdapter<A>>( message_data: Vec<u8>, ) -> Result<A, SendAlert>232 fn decode_msg_contents<A, M: Message + Default + IntoAdapter<A>>(
233 message_data: Vec<u8>,
234 ) -> Result<A, SendAlert> {
235 M::parse_from_bytes(message_data.as_slice())
236 .map_err(|_| {
237 error!(
238 "Unable to unmarshal message, check frame of the message you were trying to send"
239 );
240 SendAlert::from(
241 ukey::ukey2alert::AlertType::BAD_MESSAGE_DATA,
242 Some("frame error".to_string()),
243 )
244 })?
245 .into_adapter()
246 .map_err(|t| SendAlert::from(t, Some("failed to translate proto".to_string())))
247 }
248