• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![allow(missing_docs)]
2 // Copyright 2023 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 use std::fmt::Formatter;
17 
18 use bytes::BufMut;
19 use rand::SeedableRng as _;
20 
21 use crypto_provider::{hkdf::Hkdf, hmac::Hmac, sha2::Sha256, CryptoProvider};
22 use ukey2_proto::protobuf::Message as _;
23 use ukey2_proto::ukey2_all_proto::{
24     device_to_device_messages::DeviceToDeviceMessage,
25     securegcm::{GcmMetadata, Type},
26     securemessage::{EncScheme, Header, HeaderAndBody, SecureMessage, SigScheme},
27 };
28 use ukey2_rs::CompletedHandshake;
29 
30 use crate::{crypto_utils, java_utils};
31 
32 /// Version of the D2D protocol implementation (the connection encryption part after the UKEY2
33 /// handshake). V0 is a half-duplex communication, with the key and sequence number shared between
34 /// both sides, and V1 is a full-duplex communication, with separate keys and sequence numbers
35 /// for encoding and decoding.
36 ///
37 /// Only V1 is implemented by this library. This value is also the first byte of a saved session.
38 const PROTOCOL_VERSION: u8 = 1;
39 /// Number of bytes in an AES-256 key.
40 pub(crate) const AES_256_KEY_SIZE: usize = 32;
41 /// HKDF salt used by `derive_aes256_key`: SHA-256 of "SecureMessage".
42 const ENCRYPTION_SALT: [u8; 32] = [
43     0xbf, 0x9d, 0x2a, 0x53, 0xc6, 0x36, 0x16, 0xd7, 0x5d, 0xb0, 0xa7, 0x16, 0x5b, 0x91, 0xc1, 0xef,
44     0x73, 0xe5, 0x37, 0xf2, 0x42, 0x74, 0x05, 0xfa, 0x23, 0x61, 0x0a, 0x4b, 0xe6, 0x57, 0x64, 0x2e,
45 ];
46 
47 /// Salt for Sha256 for [`get_session_unique`][D2DConnectionContextV1::get_session_unique].
48 /// SHA-256 of "D2D"
49 const SESSION_UNIQUE_SALT: [u8; 32] = [
50     0x82, 0xAA, 0x55, 0xA0, 0xD3, 0x97, 0xF8, 0x83, 0x46, 0xCA, 0x1C, 0xEE, 0x8D, 0x39, 0x09, 0xB9,
51     0x5F, 0x13, 0xFA, 0x7D, 0xEB, 0x1D, 0x4A, 0xB3, 0x83, 0x76, 0xB8, 0x25, 0x6D, 0xA8, 0x55, 0x10,
52 ];
53 
/// 16-byte initialization vector carried in the `SecureMessage` header and passed to the AES-CBC
/// decryption routine.
54 pub(crate) type AesCbcIv = [u8; 16];
/// 32-byte key; used both for the per-direction base keys and for the keys derived from them.
55 pub type Aes256Key = [u8; 32];
56 
/// HKDF `purpose` (info) string selecting the initiator-to-responder key in `encryption_key`.
57 const HKDF_INFO_KEY_INITIATOR: &[u8; 6] = b"client";
/// HKDF `purpose` (info) string selecting the responder-to-initiator key in `encryption_key`.
58 const HKDF_INFO_KEY_RESPONDER: &[u8; 6] = b"server";
/// Input whose SHA-256 digest is used as the HKDF salt in `encryption_key`.
59 const HKDF_SALT_ENCRYPT_KEY: &[u8] = b"D2D";
60 
61 // Static utilities for dealing with AES keys
62 /// Returns `None` if the requested size > 255 * 512 bytes.
encryption_key<const N: usize, C: CryptoProvider>( next_protocol_key: &[u8], purpose: &[u8], ) -> Option<[u8; N]>63 fn encryption_key<const N: usize, C: CryptoProvider>(
64     next_protocol_key: &[u8],
65     purpose: &[u8],
66 ) -> Option<[u8; N]> {
67     let mut buf = [0u8; N];
68     let result = &C::Sha256::sha256(HKDF_SALT_ENCRYPT_KEY);
69     let hkdf = C::HkdfSha256::new(Some(result), next_protocol_key);
70     hkdf.expand(purpose, &mut buf).ok().map(|_| buf)
71 }
72 
/// In-memory form of a `DeviceToDeviceMessage` protobuf: the payload bytes together with the
/// sequence number they are sent under.
73 struct RustDeviceToDeviceMessage {
    /// Value of the protobuf `sequence_number` field.
74     sequence_num: i32,
    /// Value of the protobuf `message` field (the application payload).
75     message: Vec<u8>,
76 }
77 
78 // Static utility functions for dealing with DeviceToDeviceMessage.
create_device_to_device_message(msg: RustDeviceToDeviceMessage) -> Vec<u8>79 fn create_device_to_device_message(msg: RustDeviceToDeviceMessage) -> Vec<u8> {
80     let d2d_message = DeviceToDeviceMessage {
81         message: Some(msg.message),
82         sequence_number: Some(msg.sequence_num),
83         ..Default::default()
84     };
85     d2d_message.write_to_bytes().unwrap()
86 }
87 
unwrap_device_to_device_message( message: &[u8], ) -> Result<RustDeviceToDeviceMessage, DecodeError>88 fn unwrap_device_to_device_message(
89     message: &[u8],
90 ) -> Result<RustDeviceToDeviceMessage, DecodeError> {
91     let result =
92         DeviceToDeviceMessage::parse_from_bytes(message).map_err(|_| DecodeError::BadData)?;
93     let (msg, seq_num) = result
94         .message
95         .zip(result.sequence_number)
96         .ok_or(DecodeError::BadData)?;
97     Ok(RustDeviceToDeviceMessage {
98         sequence_num: seq_num,
99         message: msg,
100     })
101 }
102 
derive_aes256_key<C: CryptoProvider>(initial_key: &[u8], purpose: &[u8]) -> Aes256Key103 fn derive_aes256_key<C: CryptoProvider>(initial_key: &[u8], purpose: &[u8]) -> Aes256Key {
104     let mut buf = [0u8; AES_256_KEY_SIZE];
105     let hkdf = C::HkdfSha256::new(Some(&ENCRYPTION_SALT), initial_key);
106     hkdf.expand(purpose, &mut buf).unwrap();
107     buf
108 }
109 
110 /// Implementation of the UKEY2 connection protocol, also known as version 1 of the D2D protocol. In
111 /// this version, communication is fully duplex, as separate keys and sequence numbers are used for
112 /// encoding and decoding.
// NOTE(review): the derived `Debug` formats every field, including the raw and derived key
// material below. Avoid logging values of this type; consider a redacting `Debug` impl.
113 #[derive(Debug)]
114 pub struct D2DConnectionContextV1<R = rand::rngs::StdRng>
115 where
116     R: rand::Rng + rand::SeedableRng + rand::CryptoRng,
117 {
    /// Sequence number of the last message successfully decoded.
118     decode_sequence_num: i32,
    /// Sequence number of the last message encoded.
119     encode_sequence_num: i32,
    /// Base key for outgoing messages; written by `save_session` and used to derive
    /// `encryption_key` and `signing_key`.
120     encode_key: Aes256Key,
    /// Base key for incoming messages; written by `save_session` and used to derive
    /// `decryption_key` and `verify_key`.
121     decode_key: Aes256Key,
    /// Key derived from `encode_key` with purpose "ENC:2"; encrypts outgoing messages.
122     encryption_key: Aes256Key,
    /// Key derived from `decode_key` with purpose "ENC:2"; decrypts incoming messages.
123     decryption_key: Aes256Key,
    /// Key derived from `encode_key` with purpose "SIG:1"; HMAC key for outgoing messages.
124     signing_key: Aes256Key,
    /// Key derived from `decode_key` with purpose "SIG:1"; HMAC key for verifying incoming messages.
125     verify_key: Aes256Key,
    /// Random source handed to `crypto_utils::encrypt` when encoding.
    /// NOTE(review): presumably used to generate the IV — confirm in `crypto_utils`.
126     rng: R,
127 }
128 
/// Error type for [`decode_message_from_peer`][D2DConnectionContextV1::decode_message_from_peer].
#[derive(Debug)]
pub enum DecodeError {
    /// The data input being decoded does not match the expected input format.
    BadData,
    /// The sequence number of the incoming message does not match the expected number. This means
    /// messages have been lost, received out of order, or duplicates have been received.
    BadSequenceNumber,
}

/// Error type for [`from_saved_session`][D2DConnectionContextV1::from_saved_session].
#[derive(Debug, PartialEq, Eq)]
pub enum DeserializeError {
    /// The input data is not a valid protobuf message and cannot be deserialized.
    BadData,
    /// The data length for the input data or some of its fields do not match the required length.
    BadDataLength,
    /// The protocol version indicated in the input data is not expected by this implementation.
    BadProtocolVersion,
}

impl std::fmt::Display for DecodeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            DecodeError::BadData => write!(f, "DecodeError: BadData"),
            DecodeError::BadSequenceNumber => write!(f, "DecodeError: Bad sequence number"),
        }
    }
}

// Implementing `std::error::Error` lets callers use `?` with `Box<dyn Error>` and other
// error-handling utilities that require the standard trait.
impl std::error::Error for DecodeError {}

impl std::fmt::Display for DeserializeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            DeserializeError::BadData => write!(f, "DeserializeError: BadData"),
            DeserializeError::BadDataLength => write!(f, "DeserializeError: BadDataLength"),
            DeserializeError::BadProtocolVersion => {
                write!(f, "DeserializeError: BadProtocolVersion")
            }
        }
    }
}

impl std::error::Error for DeserializeError {}
158 
159 impl D2DConnectionContextV1<rand::rngs::StdRng> {
from_saved_session<C: CryptoProvider>(session: &[u8]) -> Result<Self, DeserializeError>160     pub fn from_saved_session<C: CryptoProvider>(session: &[u8]) -> Result<Self, DeserializeError> {
161         Self::from_saved_session_with_rng::<C>(session, rand::rngs::StdRng::from_entropy())
162     }
163 }
164 
165 impl<R> D2DConnectionContextV1<R>
166 where
167     R: rand::Rng + rand::SeedableRng + rand::CryptoRng,
168 {
/// Names the cipher suite this context uses: AES-256-CBC encryption with HMAC-SHA256 signatures.
/// NOTE(review): presumably the value negotiated as the UKEY2 "next protocol" — confirm against
/// the handshake code.
169     pub(crate) const NEXT_PROTOCOL_IDENTIFIER: &'static str = "AES_256_CBC-HMAC_SHA256";
170 
new<C: CryptoProvider>( decode_sequence_num: i32, encode_sequence_num: i32, encode_key: Aes256Key, decode_key: Aes256Key, rng: R, ) -> Self171     pub fn new<C: CryptoProvider>(
172         decode_sequence_num: i32,
173         encode_sequence_num: i32,
174         encode_key: Aes256Key,
175         decode_key: Aes256Key,
176         rng: R,
177     ) -> Self {
178         let encryption_key = derive_aes256_key::<C>(&encode_key, b"ENC:2");
179         let decryption_key = derive_aes256_key::<C>(&decode_key, b"ENC:2");
180         let signing_key = derive_aes256_key::<C>(&encode_key, b"SIG:1");
181         let verify_key = derive_aes256_key::<C>(&decode_key, b"SIG:1");
182         D2DConnectionContextV1 {
183             decode_sequence_num,
184             encode_sequence_num,
185             encode_key,
186             decode_key,
187             encryption_key,
188             decryption_key,
189             signing_key,
190             verify_key,
191             rng,
192         }
193     }
194 
from_initiator_handshake<C: CryptoProvider>( handshake: &CompletedHandshake, rng: R, ) -> Self195     pub(crate) fn from_initiator_handshake<C: CryptoProvider>(
196         handshake: &CompletedHandshake,
197         rng: R,
198     ) -> Self {
199         let next_protocol_secret = handshake
200             .next_protocol_secret::<C>()
201             .derive_array::<AES_256_KEY_SIZE>()
202             .unwrap();
203         D2DConnectionContextV1::new::<C>(
204             0,
205             0,
206             encryption_key::<32, C>(&next_protocol_secret, HKDF_INFO_KEY_INITIATOR).unwrap(),
207             encryption_key::<32, C>(&next_protocol_secret, HKDF_INFO_KEY_RESPONDER).unwrap(),
208             rng,
209         )
210     }
211 
from_responder_handshake<C: CryptoProvider>( handshake: &CompletedHandshake, rng: R, ) -> Self212     pub(crate) fn from_responder_handshake<C: CryptoProvider>(
213         handshake: &CompletedHandshake,
214         rng: R,
215     ) -> Self {
216         let next_protocol_secret = handshake
217             .next_protocol_secret::<C>()
218             .derive_array::<AES_256_KEY_SIZE>()
219             .unwrap();
220         D2DConnectionContextV1::new::<C>(
221             0,
222             0,
223             encryption_key::<32, C>(&next_protocol_secret, HKDF_INFO_KEY_RESPONDER).unwrap(),
224             encryption_key::<32, C>(&next_protocol_secret, HKDF_INFO_KEY_INITIATOR).unwrap(),
225             rng,
226         )
227     }
228 
229     /// Creates a saved session that can later be used for resumption. The session data may be
230     /// persisted, but it must be stored in a secure location.
231     ///
232     /// Returns the serialized saved session, suitable for resumption using
233     /// [`from_saved_session`][Self::from_saved_session].
234     ///
235     /// Structure of saved session is:
236     ///
237     /// ```text
238     /// +---------------------------------------------------------------------------+
239     /// | 1 Byte  |      4 Bytes      |      4 Bytes      |  32 Bytes  |  32 Bytes  |
240     /// +---------------------------------------------------------------------------+
241     /// | Version | encode seq number | decode seq number | encode key | decode key |
242     /// +---------------------------------------------------------------------------+
243     /// ```
244     ///
245     /// The sequence numbers are represented in big-endian.
save_session(&self) -> Vec<u8>246     pub fn save_session(&self) -> Vec<u8> {
247         let mut ret: Vec<u8> = vec![];
248         ret.push(PROTOCOL_VERSION);
249         ret.put_i32(self.encode_sequence_num);
250         ret.put_i32(self.decode_sequence_num);
251         ret.extend_from_slice(self.encode_key.as_slice());
252         ret.extend_from_slice(self.decode_key.as_slice());
253         ret
254     }
255 
from_saved_session_with_rng<C: CryptoProvider>( session: &[u8], rng: R, ) -> Result<Self, DeserializeError>256     pub(crate) fn from_saved_session_with_rng<C: CryptoProvider>(
257         session: &[u8],
258         rng: R,
259     ) -> Result<Self, DeserializeError> {
260         if session.len() != 73 {
261             return Err(DeserializeError::BadDataLength);
262         }
263         let (rem, _) = nom::bytes::complete::tag(PROTOCOL_VERSION.to_be_bytes())(session)
264             .map_err(|_: nom::Err<nom::error::Error<_>>| DeserializeError::BadProtocolVersion)?;
265 
266         let (_, (encode_sequence_num, decode_sequence_num, encode_key, decode_key)) =
267             nom::combinator::all_consuming(nom::sequence::tuple::<_, _, nom::error::Error<_>, _>(
268                 (
269                     nom::number::complete::be_i32,
270                     nom::number::complete::be_i32,
271                     nom::combinator::map_res(
272                         nom::bytes::complete::take(32_usize),
273                         TryInto::<Aes256Key>::try_into,
274                     ),
275                     nom::combinator::map_res(
276                         nom::bytes::complete::take(32_usize),
277                         TryInto::<Aes256Key>::try_into,
278                     ),
279                 ),
280             ))(rem)
281             // This should always succeed since all of the parsers above are valid over the entire
282             // [u8] space, and we already checked the length at the start.
283             .expect("Saved session parsing should succeed");
284         Ok(Self::new::<C>(
285             encode_sequence_num,
286             decode_sequence_num,
287             encode_key,
288             decode_key,
289             rng,
290         ))
291     }
292 
293     /// Once initiator and responder have exchanged public keys, use this method to encrypt and
294     /// sign a payload. Both initiator and responder devices can use this message.
295     ///
296     /// * `payload` - The payload that should be encrypted.
297     /// * `associated_data` - Optional data that is not included in the payload but is included in
298     ///       the calculation of the signature for this message. Note that the *size* (length in
299     ///       bytes) of the associated data will be sent in the *UNENCRYPTED* header information,
300     ///       even if you are using encryption.
encode_message_to_peer<C: CryptoProvider, A: AsRef<[u8]>>( &mut self, payload: &[u8], associated_data: Option<A>, ) -> Vec<u8>301     pub fn encode_message_to_peer<C: CryptoProvider, A: AsRef<[u8]>>(
302         &mut self,
303         payload: &[u8],
304         associated_data: Option<A>,
305     ) -> Vec<u8> {
306         self.increment_encode_sequence_number();
307         let message = create_device_to_device_message(RustDeviceToDeviceMessage {
308             message: payload.to_vec(),
309             sequence_num: self.get_sequence_number_for_encoding(),
310         });
311         let (ciphertext, iv) = crypto_utils::encrypt::<_, C::AesCbcPkcs7Padded>(
312             &self.encryption_key,
313             message.as_slice(),
314             &mut self.rng,
315         );
316         let metadata = GcmMetadata {
317             type_: Some(Type::DEVICE_TO_DEVICE_MESSAGE.into()),
318             // As specified in
319             // google3/third_party/ukey2/src/main/java/com/google/security/cryptauth/lib/securegcm/SecureGcmConstants.java
320             version: Some(1),
321             ..Default::default()
322         };
323         let header = Header {
324             signature_scheme: Some(SigScheme::HMAC_SHA256.into()),
325             encryption_scheme: Some(EncScheme::AES_256_CBC.into()),
326             iv: Some(iv.to_vec()),
327             public_metadata: Some(metadata.write_to_bytes().unwrap()),
328             associated_data_length: associated_data.as_ref().map(|d| d.as_ref().len() as u32),
329             ..Default::default()
330         };
331         let header_and_body = HeaderAndBody {
332             header: Some(header).into(),
333             body: Some(ciphertext),
334             ..Default::default()
335         };
336         let header_and_body_bytes = header_and_body.write_to_bytes().unwrap();
337 
338         // add sha256 MAC
339         let mut hmac = C::HmacSha256::new_from_slice(&self.signing_key).unwrap();
340         hmac.update(header_and_body_bytes.as_slice());
341         if let Some(associated_data_vec) = associated_data.as_ref() {
342             hmac.update(associated_data_vec.as_ref())
343         }
344         let result_mac = hmac.finalize().to_vec();
345 
346         let secure_message = SecureMessage {
347             header_and_body: Some(header_and_body_bytes),
348             signature: Some(result_mac),
349             ..Default::default()
350         };
351         secure_message.write_to_bytes().unwrap()
352     }
353 
354     /// Once `InitiatorHello` and `ResponderHello` (and payload) are exchanged, use this method to
355     /// decrypt and verify a message received from the other device. Both initiator and responder
356     /// devices can use this message.
357     ///
358     /// * `message` - the message that should be encrypted.
359     /// * `associated_data` - Optional associated data that must match what the sender provided. See
360     ///       the documentation on [`encode_message_to_peer`][Self::encode_message_to_peer].
decode_message_from_peer<C: CryptoProvider, A: AsRef<[u8]>>( &mut self, payload: &[u8], associated_data: Option<A>, ) -> Result<Vec<u8>, DecodeError>361     pub fn decode_message_from_peer<C: CryptoProvider, A: AsRef<[u8]>>(
362         &mut self,
363         payload: &[u8],
364         associated_data: Option<A>,
365     ) -> Result<Vec<u8>, DecodeError> {
366         // first confirm that the payload MAC matches the header_and_body
367         let message = SecureMessage::parse_from_bytes(payload).map_err(|_| DecodeError::BadData)?;
368         let payload_mac: [u8; 32] = message
369             .signature
370             .and_then(|signature| signature.try_into().ok())
371             .ok_or(DecodeError::BadData)?;
372         let payload = message.header_and_body.ok_or(DecodeError::BadData)?;
373         let mut hmac = C::HmacSha256::new_from_slice(&self.verify_key).unwrap();
374         hmac.update(&payload);
375         if let Some(associated_data) = associated_data.as_ref() {
376             hmac.update(associated_data.as_ref())
377         }
378         hmac.verify(payload_mac).map_err(|_| DecodeError::BadData)?;
379         let payload =
380             HeaderAndBody::parse_from_bytes(&payload).map_err(|_| DecodeError::BadData)?;
381         let associated_data_len = payload
382             .header
383             .as_ref()
384             .and_then(|header| header.associated_data_length);
385         if associated_data_len != associated_data.map(|ad| ad.as_ref().len() as u32) {
386             return Err(DecodeError::BadData);
387         }
388         let iv: AesCbcIv = payload
389             .header
390             .as_ref()
391             .and_then(|header| header.iv().try_into().ok())
392             .ok_or(DecodeError::BadData)?;
393         let decrypted = crypto_utils::decrypt::<C::AesCbcPkcs7Padded>(
394             &self.decryption_key,
395             &payload.body.unwrap_or_default(),
396             &iv,
397         )
398         .map_err(|_| DecodeError::BadData)?;
399         let d2d_message = unwrap_device_to_device_message(decrypted.as_slice())?;
400         if d2d_message.sequence_num != self.get_sequence_number_for_decoding() + 1 {
401             return Err(DecodeError::BadSequenceNumber);
402         }
403         self.increment_decode_sequence_number();
404         Ok(d2d_message.message)
405     }
406 
increment_encode_sequence_number(&mut self)407     fn increment_encode_sequence_number(&mut self) {
408         self.encode_sequence_num += 1;
409     }
410 
increment_decode_sequence_number(&mut self)411     fn increment_decode_sequence_number(&mut self) {
412         self.decode_sequence_num += 1;
413     }
414 
415     /// Returns the last sequence number used to encode a message.
get_sequence_number_for_encoding(&self) -> i32416     pub fn get_sequence_number_for_encoding(&self) -> i32 {
417         self.encode_sequence_num
418     }
419 
420     /// Returns the last sequence number used to decode a message.
get_sequence_number_for_decoding(&self) -> i32421     pub fn get_sequence_number_for_decoding(&self) -> i32 {
422         self.decode_sequence_num
423     }
424 
425     /// Returns a cryptographic digest (SHA256) of the session keys prepended by the SHA256 hash
426     /// of the ASCII string "D2D". Since the server and client share the same session keys, the
427     /// resulting session unique is also the same.
get_session_unique<C: CryptoProvider>(&self) -> Vec<u8>428     pub fn get_session_unique<C: CryptoProvider>(&self) -> Vec<u8> {
429         let encode_key_hash = java_utils::hash_code(self.encode_key.as_slice());
430         let decode_key_hash = java_utils::hash_code(self.decode_key.as_slice());
431         let first_key_bytes = if encode_key_hash < decode_key_hash {
432             self.encode_key.as_slice()
433         } else {
434             self.decode_key.as_slice()
435         };
436         let second_key_bytes = if first_key_bytes == self.encode_key.as_slice() {
437             self.decode_key.as_slice()
438         } else {
439             self.encode_key.as_slice()
440         };
441         C::Sha256::sha256(&[&SESSION_UNIQUE_SALT, first_key_bytes, second_key_bytes].concat())
442             .to_vec()
443     }
444 }
445