1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::collections::HashMap;
16 use std::ffi::CString;
17 use std::ptr::null_mut;
18 
19 use lazy_static::lazy_static;
20 use rand::Rng;
21 use rand_chacha::rand_core::SeedableRng;
22 use rand_chacha::ChaCha20Rng;
23 use spin::Mutex;
24 
25 use ukey2_connections::{
26     D2DConnectionContextV1, D2DHandshakeContext, HandshakeImplementation,
27     InitiatorD2DHandshakeContext, ServerD2DHandshakeContext,
28 };
29 
30 cfg_if::cfg_if! {
31     if #[cfg(feature = "rustcrypto")] {
32         use crypto_provider_rustcrypto::RustCrypto as CryptoProvider;
33     } else {
34         use crypto_provider_openssl::Openssl as CryptoProvider;
35     }
36 }
/// A byte buffer passed from Rust to the C caller.
///
/// When `ptr` is non-null the buffer was allocated by Rust and should be released with
/// [`rust_dealloc_ffi_byte_array`]. Functions in this module report "no data" or failure with a
/// null `ptr` and `len == usize::MAX`.
///
/// NOTE(review): `decode_message_from_peer` also accepts this type for its caller-owned input
/// buffer, which it only reads.
#[repr(C)]
pub struct RustFFIByteArray {
    /// First byte of the buffer, or null for the sentinel value.
    ptr: *mut u8,
    /// Number of bytes at `ptr` (`usize::MAX` for the sentinel value).
    len: usize,
}
42 
/// A byte buffer supplied by the C caller as an input argument.
///
/// Some functions treat a null `ptr` as "absent" (for example optional associated data). The
/// ownership and validity requirements are specific to each function; see its `# Safety` section.
#[repr(C)]
pub struct CFFIByteArray {
    /// First byte of the caller's buffer (may be null where a function allows it).
    ptr: *mut u8,
    /// Number of bytes readable at `ptr`.
    len: usize,
}
48 
/// Owned, type-erased handshake context; both `responder_new` and `initiator_new` store one.
type D2DBox = Box<dyn D2DHandshakeContext>;
/// Owned connection context, created by `to_connection_context` or `from_saved_session`.
type ConnectionBox = Box<D2DConnectionContextV1>;
51 
// Global registries backing the handle-based C API. Each one is guarded by a `spin::Mutex`.
lazy_static! {
    // Live handshake contexts, keyed by the opaque handle handed to C.
    static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
    // Live connection contexts, keyed by the opaque handle handed to C.
    static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
        Mutex::new(HashMap::new());
    // Generator for handle values, seeded once via `ChaCha20Rng::from_entropy`.
    static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
}
58 
generate_handle() -> u6459 fn generate_handle() -> u64 {
60     RNG.lock().gen()
61 }
62 
insert_gen_handle(item: D2DBox) -> u6463 fn insert_gen_handle(item: D2DBox) -> u64 {
64     let handle = generate_handle();
65     HANDLE_MAPPING.lock().insert(handle, item);
66     handle
67 }
68 
insert_conn_gen_handle(item: ConnectionBox) -> u6469 fn insert_conn_gen_handle(item: ConnectionBox) -> u64 {
70     let handle = generate_handle();
71     CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
72     handle
73 }
74 
75 // Utilities
76 /// This function deallocates FFIByteArray instances allocated from Rust only.
77 /// NOTE: Any FFIByteArray instances deallocated by this function will no longer be in a guaranteed
78 /// usable state.
79 ///
80 /// # Safety
81 /// The array must have been allocated by a Rust function with the Rust allocator, e.g.
82 /// [get_next_handshake_message].
83 #[no_mangle]
rust_dealloc_ffi_byte_array(arr: RustFFIByteArray)84 pub unsafe extern "C" fn rust_dealloc_ffi_byte_array(arr: RustFFIByteArray) {
85     if !arr.ptr.is_null() {
86         let _ = Vec::from_raw_parts(arr.ptr, arr.len, arr.len);
87     }
88 }
89 
90 // Common functions
91 #[no_mangle]
is_handshake_complete(handle: u64) -> bool92 pub extern "C" fn is_handshake_complete(handle: u64) -> bool {
93     HANDLE_MAPPING
94         .lock()
95         .get(&handle)
96         .map_or(false, |ctx| ctx.is_handshake_complete())
97 }
98 
99 #[no_mangle]
get_next_handshake_message(handle: u64) -> RustFFIByteArray100 pub extern "C" fn get_next_handshake_message(handle: u64) -> RustFFIByteArray {
101     // TODO: error handling
102     let opt_msg = HANDLE_MAPPING
103         .lock()
104         .get(&handle)
105         .and_then(|c| c.get_next_handshake_message());
106     if let Some(msg) = opt_msg {
107         let ret_len = msg.len();
108         let data: CString = unsafe { CString::from_vec_unchecked(msg) };
109         RustFFIByteArray {
110             ptr: data.into_raw() as *mut u8,
111             len: ret_len,
112         }
113     } else {
114         RustFFIByteArray {
115             ptr: null_mut(),
116             len: usize::MAX,
117         }
118     }
119 }
120 
121 /// # Safety
122 /// We treat msg as data, so we should never have an issue trying to execute it.
123 #[no_mangle]
parse_handshake_message( handle: u64, arr: CFFIByteArray, ) -> RustFFIByteArray124 pub unsafe extern "C" fn parse_handshake_message(
125     handle: u64,
126     arr: CFFIByteArray,
127 ) -> RustFFIByteArray {
128     let msg = Vec::<u8>::from_raw_parts(arr.ptr, arr.len, arr.len);
129     // TODO error handling
130     let result = HANDLE_MAPPING
131         .lock()
132         .get_mut(&handle)
133         .unwrap()
134         .handle_handshake_message(msg.as_slice());
135     if let Err(error) = result {
136         log::error!("{:?}", error);
137     }
138     RustFFIByteArray {
139         ptr: null_mut(),
140         len: usize::MAX,
141     }
142 }
143 
144 #[no_mangle]
get_verification_string(handle: u64, length: usize) -> RustFFIByteArray145 pub extern "C" fn get_verification_string(handle: u64, length: usize) -> RustFFIByteArray {
146     HANDLE_MAPPING
147         .lock()
148         .get(&handle)
149         .map(|h| {
150             let auth_vec = h
151                 .to_completed_handshake()
152                 .unwrap()
153                 .auth_string::<CryptoProvider>()
154                 .derive_vec(length)
155                 .unwrap();
156             let vec_len = auth_vec.len();
157             RustFFIByteArray {
158                 ptr: auth_vec.leak().as_mut_ptr(),
159                 len: vec_len,
160             }
161         })
162         .unwrap()
163 }
164 
165 #[no_mangle]
to_connection_context(handle: u64) -> u64166 pub extern "C" fn to_connection_context(handle: u64) -> u64 {
167     // TODO: error handling
168     let ctx = HANDLE_MAPPING
169         .lock()
170         .remove(&handle)
171         .map(move |mut ctx| {
172             let result = Box::new(ctx.to_connection_context().unwrap());
173             drop(ctx);
174             result
175         })
176         .unwrap();
177     insert_conn_gen_handle(ctx)
178 }
179 
180 // Responder-specific functions
181 #[no_mangle]
responder_new() -> u64182 pub extern "C" fn responder_new() -> u64 {
183     let ctx = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
184         HandshakeImplementation::PublicKeyInProtobuf,
185     ));
186     insert_gen_handle(ctx)
187 }
188 
189 // Initiator-specific functions
190 
191 /// # Safety
192 /// We treat next_protocol as data, not as executable memory.
193 #[no_mangle]
initiator_new() -> u64194 pub extern "C" fn initiator_new() -> u64 {
195     let ctx = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
196         HandshakeImplementation::PublicKeyInProtobuf,
197     ));
198     insert_gen_handle(ctx)
199 }
200 
201 // Connection Context
202 
203 /// # Safety
204 /// We treat msg and associated_data as data, not as executable memory.
205 /// associated_data and msg are slices so Rust won't try to do anything weird with allocation.
206 #[no_mangle]
encode_message_to_peer( handle: u64, msg: CFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray207 pub unsafe extern "C" fn encode_message_to_peer(
208     handle: u64,
209     msg: CFFIByteArray,
210     associated_data: CFFIByteArray,
211 ) -> RustFFIByteArray {
212     if msg.len == 0 {
213         return RustFFIByteArray {
214             ptr: null_mut(),
215             len: usize::MAX,
216         };
217     }
218     let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
219     let associated_data = if !associated_data.ptr.is_null() {
220         Some(std::slice::from_raw_parts(
221             associated_data.ptr,
222             associated_data.len,
223         ))
224     } else {
225         None
226     };
227     let ret = CONNECTION_HANDLE_MAPPING
228         .lock()
229         .get_mut(&handle)
230         .map(|c| c.encode_message_to_peer::<CryptoProvider, _>(msg, associated_data));
231     if let Some(msg) = ret {
232         let len = msg.len();
233         RustFFIByteArray {
234             ptr: msg.leak().as_mut_ptr(),
235             len,
236         }
237     } else {
238         log::error!("Was unable to find handle!");
239         RustFFIByteArray {
240             ptr: null_mut(),
241             len: usize::MAX,
242         }
243     }
244 }
245 
246 /// # Safety
247 /// We treat msg as data, not as executable memory.
248 #[no_mangle]
decode_message_from_peer( handle: u64, msg: RustFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray249 pub unsafe extern "C" fn decode_message_from_peer(
250     handle: u64,
251     msg: RustFFIByteArray,
252     associated_data: CFFIByteArray,
253 ) -> RustFFIByteArray {
254     if msg.len == 0 {
255         return RustFFIByteArray {
256             ptr: null_mut(),
257             len: usize::MAX,
258         };
259     }
260     let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
261     let associated_data = if !associated_data.ptr.is_null() {
262         Some(std::slice::from_raw_parts(
263             associated_data.ptr,
264             associated_data.len,
265         ))
266     } else {
267         None
268     };
269     let ret: Result<Vec<u8>, ukey2_connections::DecodeError> = CONNECTION_HANDLE_MAPPING
270         .lock()
271         .get_mut(&handle)
272         .unwrap()
273         .decode_message_from_peer::<CryptoProvider, _>(msg, associated_data);
274     if let Ok(decoded) = ret {
275         let len = decoded.len();
276         RustFFIByteArray {
277             ptr: decoded.leak().as_mut_ptr(),
278             len,
279         }
280     } else {
281         RustFFIByteArray {
282             ptr: null_mut(),
283             len: usize::MAX,
284         }
285     }
286 }
287 
288 #[no_mangle]
get_session_unique(handle: u64) -> RustFFIByteArray289 pub extern "C" fn get_session_unique(handle: u64) -> RustFFIByteArray {
290     let session_unique_bytes = CONNECTION_HANDLE_MAPPING
291         .lock()
292         .get(&handle)
293         .unwrap()
294         .get_session_unique::<CryptoProvider>();
295     let handle_size = session_unique_bytes.len();
296     RustFFIByteArray {
297         ptr: session_unique_bytes.leak().as_mut_ptr(),
298         len: handle_size,
299     }
300 }
301 
302 #[no_mangle]
get_sequence_number_for_encoding(handle: u64) -> i32303 pub extern "C" fn get_sequence_number_for_encoding(handle: u64) -> i32 {
304     CONNECTION_HANDLE_MAPPING
305         .lock()
306         .get(&handle)
307         .unwrap()
308         .get_sequence_number_for_encoding()
309 }
310 
311 #[no_mangle]
get_sequence_number_for_decoding(handle: u64) -> i32312 pub extern "C" fn get_sequence_number_for_decoding(handle: u64) -> i32 {
313     CONNECTION_HANDLE_MAPPING
314         .lock()
315         .get(&handle)
316         .unwrap()
317         .get_sequence_number_for_decoding()
318 }
319 
320 #[no_mangle]
save_session(handle: u64) -> RustFFIByteArray321 pub extern "C" fn save_session(handle: u64) -> RustFFIByteArray {
322     let key = CONNECTION_HANDLE_MAPPING
323         .lock()
324         .get(&handle)
325         .unwrap()
326         .save_session();
327     let handle_size = key.len();
328     RustFFIByteArray {
329         ptr: key.leak().as_mut_ptr(),
330         len: handle_size,
331     }
332 }
333 
/// Outcome code returned to C as part of [`CD2DRestoreConnectionContextV1Result`].
///
/// `#[repr(i32)]` fixes the representation. The explicit discriminants spell out the values C
/// observes and are identical to the implicit ones (0 and 1).
#[repr(i32)]
#[derive(Debug)]
pub enum Status {
    /// The operation succeeded.
    Good = 0,
    /// The operation failed.
    Error = 1,
}
340 
/// Result of [`from_saved_session`], returned to C by value.
///
/// On success `status` is `Status::Good` and `handle` identifies the restored connection context.
/// On failure `status` is `Status::Error` and `handle` is `u64::MAX`.
#[repr(C)]
pub struct CD2DRestoreConnectionContextV1Result {
    /// Connection handle, meaningful only when `status` is `Status::Good`.
    handle: u64,
    /// Whether the restore succeeded.
    status: Status,
}
346 
347 /// # Safety
348 /// We error out if the length is incorrect (too large or too small) for restoring a session.
349 #[no_mangle]
from_saved_session( arr: CFFIByteArray, ) -> CD2DRestoreConnectionContextV1Result350 pub unsafe extern "C" fn from_saved_session(
351     arr: CFFIByteArray,
352 ) -> CD2DRestoreConnectionContextV1Result {
353     let saved_session = std::slice::from_raw_parts(arr.ptr, arr.len);
354     let ctx = D2DConnectionContextV1::from_saved_session::<CryptoProvider>(saved_session);
355     if let Ok(conn_ctx) = ctx {
356         let final_ctx = Box::new(conn_ctx);
357         CD2DRestoreConnectionContextV1Result {
358             handle: insert_conn_gen_handle(final_ctx),
359             status: Status::Good,
360         }
361     } else {
362         log::error!(
363             "failed to restore session with error {:?}",
364             ctx.unwrap_err()
365         );
366         CD2DRestoreConnectionContextV1Result {
367             handle: u64::MAX,
368             status: Status::Error,
369         }
370     }
371 }
372