1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::collections::HashMap;
16 use std::ffi::CString;
17 use std::ptr::null_mut;
18
19 use lazy_static::lazy_static;
20 use rand::Rng;
21 use rand_chacha::rand_core::SeedableRng;
22 use rand_chacha::ChaCha20Rng;
23 use spin::Mutex;
24
25 use ukey2_connections::{
26 D2DConnectionContextV1, D2DHandshakeContext, HandshakeImplementation,
27 InitiatorD2DHandshakeContext, ServerD2DHandshakeContext,
28 };
29
30 cfg_if::cfg_if! {
31 if #[cfg(feature = "rustcrypto")] {
32 use crypto_provider_rustcrypto::RustCrypto as CryptoProvider;
33 } else {
34 use crypto_provider_openssl::Openssl as CryptoProvider;
35 }
36 }
/// A raw byte buffer (pointer plus length) exchanged with C callers.
///
/// Functions in this file that cannot produce data return this struct with a null `ptr` and
/// `len` set to `usize::MAX`.
// NOTE(review): judging by the name and the docs on `rust_dealloc_ffi_byte_array`, instances
// returned from this file are presumably Rust-allocated and released through that function —
// confirm against the C callers.
#[repr(C)]
pub struct RustFFIByteArray {
    /// Address of the first byte, or null when no data is available.
    ptr: *mut u8,
    /// Number of bytes at `ptr`.
    len: usize,
}
42
/// A raw byte buffer (pointer plus length) passed in by C callers.
// NOTE(review): the name suggests the memory is allocated and owned by the C side — confirm
// against the C callers.
#[repr(C)]
pub struct CFFIByteArray {
    /// Address of the first byte.
    ptr: *mut u8,
    /// Number of bytes at `ptr`.
    len: usize,
}
48
/// Heap-allocated, type-erased handshake context as stored in `HANDLE_MAPPING`.
type D2DBox = Box<dyn D2DHandshakeContext>;
/// Heap-allocated connection context as stored in `CONNECTION_HANDLE_MAPPING`.
type ConnectionBox = Box<D2DConnectionContextV1>;
51
// Process-wide state shared by every FFI entry point in this file. Each item sits behind a
// `spin::Mutex` and is initialized lazily on first use.
lazy_static! {
    // Handshake contexts, keyed by the handle produced in `insert_gen_handle`.
    static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
    // Connection contexts, keyed by the handle produced in `insert_conn_gen_handle`.
    static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
        Mutex::new(HashMap::new());
    // Entropy-seeded ChaCha20 generator that `generate_handle` draws handle values from.
    static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
}
58
generate_handle() -> u6459 fn generate_handle() -> u64 {
60 RNG.lock().gen()
61 }
62
insert_gen_handle(item: D2DBox) -> u6463 fn insert_gen_handle(item: D2DBox) -> u64 {
64 let handle = generate_handle();
65 HANDLE_MAPPING.lock().insert(handle, item);
66 handle
67 }
68
insert_conn_gen_handle(item: ConnectionBox) -> u6469 fn insert_conn_gen_handle(item: ConnectionBox) -> u64 {
70 let handle = generate_handle();
71 CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
72 handle
73 }
74
75 // Utilities
76 /// This function deallocates FFIByteArray instances allocated from Rust only.
77 /// NOTE: Any FFIByteArray instances deallocated by this function will no longer be in a guaranteed
78 /// usable state.
79 ///
80 /// # Safety
81 /// The array must have been allocated by a Rust function with the Rust allocator, e.g.
82 /// [get_next_handshake_message].
83 #[no_mangle]
rust_dealloc_ffi_byte_array(arr: RustFFIByteArray)84 pub unsafe extern "C" fn rust_dealloc_ffi_byte_array(arr: RustFFIByteArray) {
85 if !arr.ptr.is_null() {
86 let _ = Vec::from_raw_parts(arr.ptr, arr.len, arr.len);
87 }
88 }
89
90 // Common functions
91 #[no_mangle]
is_handshake_complete(handle: u64) -> bool92 pub extern "C" fn is_handshake_complete(handle: u64) -> bool {
93 HANDLE_MAPPING
94 .lock()
95 .get(&handle)
96 .map_or(false, |ctx| ctx.is_handshake_complete())
97 }
98
99 #[no_mangle]
get_next_handshake_message(handle: u64) -> RustFFIByteArray100 pub extern "C" fn get_next_handshake_message(handle: u64) -> RustFFIByteArray {
101 // TODO: error handling
102 let opt_msg = HANDLE_MAPPING
103 .lock()
104 .get(&handle)
105 .and_then(|c| c.get_next_handshake_message());
106 if let Some(msg) = opt_msg {
107 let ret_len = msg.len();
108 let data: CString = unsafe { CString::from_vec_unchecked(msg) };
109 RustFFIByteArray {
110 ptr: data.into_raw() as *mut u8,
111 len: ret_len,
112 }
113 } else {
114 RustFFIByteArray {
115 ptr: null_mut(),
116 len: usize::MAX,
117 }
118 }
119 }
120
121 /// # Safety
122 /// We treat msg as data, so we should never have an issue trying to execute it.
123 #[no_mangle]
parse_handshake_message( handle: u64, arr: CFFIByteArray, ) -> RustFFIByteArray124 pub unsafe extern "C" fn parse_handshake_message(
125 handle: u64,
126 arr: CFFIByteArray,
127 ) -> RustFFIByteArray {
128 let msg = Vec::<u8>::from_raw_parts(arr.ptr, arr.len, arr.len);
129 // TODO error handling
130 let result = HANDLE_MAPPING
131 .lock()
132 .get_mut(&handle)
133 .unwrap()
134 .handle_handshake_message(msg.as_slice());
135 if let Err(error) = result {
136 log::error!("{:?}", error);
137 }
138 RustFFIByteArray {
139 ptr: null_mut(),
140 len: usize::MAX,
141 }
142 }
143
144 #[no_mangle]
get_verification_string(handle: u64, length: usize) -> RustFFIByteArray145 pub extern "C" fn get_verification_string(handle: u64, length: usize) -> RustFFIByteArray {
146 HANDLE_MAPPING
147 .lock()
148 .get(&handle)
149 .map(|h| {
150 let auth_vec = h
151 .to_completed_handshake()
152 .unwrap()
153 .auth_string::<CryptoProvider>()
154 .derive_vec(length)
155 .unwrap();
156 let vec_len = auth_vec.len();
157 RustFFIByteArray {
158 ptr: auth_vec.leak().as_mut_ptr(),
159 len: vec_len,
160 }
161 })
162 .unwrap()
163 }
164
165 #[no_mangle]
to_connection_context(handle: u64) -> u64166 pub extern "C" fn to_connection_context(handle: u64) -> u64 {
167 // TODO: error handling
168 let ctx = HANDLE_MAPPING
169 .lock()
170 .remove(&handle)
171 .map(move |mut ctx| {
172 let result = Box::new(ctx.to_connection_context().unwrap());
173 drop(ctx);
174 result
175 })
176 .unwrap();
177 insert_conn_gen_handle(ctx)
178 }
179
180 // Responder-specific functions
181 #[no_mangle]
responder_new() -> u64182 pub extern "C" fn responder_new() -> u64 {
183 let ctx = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
184 HandshakeImplementation::PublicKeyInProtobuf,
185 ));
186 insert_gen_handle(ctx)
187 }
188
189 // Initiator-specific functions
190
191 /// # Safety
192 /// We treat next_protocol as data, not as executable memory.
193 #[no_mangle]
initiator_new() -> u64194 pub extern "C" fn initiator_new() -> u64 {
195 let ctx = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
196 HandshakeImplementation::PublicKeyInProtobuf,
197 ));
198 insert_gen_handle(ctx)
199 }
200
201 // Connection Context
202
203 /// # Safety
204 /// We treat msg and associated_data as data, not as executable memory.
205 /// associated_data and msg are slices so Rust won't try to do anything weird with allocation.
206 #[no_mangle]
encode_message_to_peer( handle: u64, msg: CFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray207 pub unsafe extern "C" fn encode_message_to_peer(
208 handle: u64,
209 msg: CFFIByteArray,
210 associated_data: CFFIByteArray,
211 ) -> RustFFIByteArray {
212 if msg.len == 0 {
213 return RustFFIByteArray {
214 ptr: null_mut(),
215 len: usize::MAX,
216 };
217 }
218 let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
219 let associated_data = if !associated_data.ptr.is_null() {
220 Some(std::slice::from_raw_parts(
221 associated_data.ptr,
222 associated_data.len,
223 ))
224 } else {
225 None
226 };
227 let ret = CONNECTION_HANDLE_MAPPING
228 .lock()
229 .get_mut(&handle)
230 .map(|c| c.encode_message_to_peer::<CryptoProvider, _>(msg, associated_data));
231 if let Some(msg) = ret {
232 let len = msg.len();
233 RustFFIByteArray {
234 ptr: msg.leak().as_mut_ptr(),
235 len,
236 }
237 } else {
238 log::error!("Was unable to find handle!");
239 RustFFIByteArray {
240 ptr: null_mut(),
241 len: usize::MAX,
242 }
243 }
244 }
245
246 /// # Safety
247 /// We treat msg as data, not as executable memory.
248 #[no_mangle]
decode_message_from_peer( handle: u64, msg: RustFFIByteArray, associated_data: CFFIByteArray, ) -> RustFFIByteArray249 pub unsafe extern "C" fn decode_message_from_peer(
250 handle: u64,
251 msg: RustFFIByteArray,
252 associated_data: CFFIByteArray,
253 ) -> RustFFIByteArray {
254 if msg.len == 0 {
255 return RustFFIByteArray {
256 ptr: null_mut(),
257 len: usize::MAX,
258 };
259 }
260 let msg = std::slice::from_raw_parts(msg.ptr, msg.len);
261 let associated_data = if !associated_data.ptr.is_null() {
262 Some(std::slice::from_raw_parts(
263 associated_data.ptr,
264 associated_data.len,
265 ))
266 } else {
267 None
268 };
269 let ret: Result<Vec<u8>, ukey2_connections::DecodeError> = CONNECTION_HANDLE_MAPPING
270 .lock()
271 .get_mut(&handle)
272 .unwrap()
273 .decode_message_from_peer::<CryptoProvider, _>(msg, associated_data);
274 if let Ok(decoded) = ret {
275 let len = decoded.len();
276 RustFFIByteArray {
277 ptr: decoded.leak().as_mut_ptr(),
278 len,
279 }
280 } else {
281 RustFFIByteArray {
282 ptr: null_mut(),
283 len: usize::MAX,
284 }
285 }
286 }
287
288 #[no_mangle]
get_session_unique(handle: u64) -> RustFFIByteArray289 pub extern "C" fn get_session_unique(handle: u64) -> RustFFIByteArray {
290 let session_unique_bytes = CONNECTION_HANDLE_MAPPING
291 .lock()
292 .get(&handle)
293 .unwrap()
294 .get_session_unique::<CryptoProvider>();
295 let handle_size = session_unique_bytes.len();
296 RustFFIByteArray {
297 ptr: session_unique_bytes.leak().as_mut_ptr(),
298 len: handle_size,
299 }
300 }
301
302 #[no_mangle]
get_sequence_number_for_encoding(handle: u64) -> i32303 pub extern "C" fn get_sequence_number_for_encoding(handle: u64) -> i32 {
304 CONNECTION_HANDLE_MAPPING
305 .lock()
306 .get(&handle)
307 .unwrap()
308 .get_sequence_number_for_encoding()
309 }
310
311 #[no_mangle]
get_sequence_number_for_decoding(handle: u64) -> i32312 pub extern "C" fn get_sequence_number_for_decoding(handle: u64) -> i32 {
313 CONNECTION_HANDLE_MAPPING
314 .lock()
315 .get(&handle)
316 .unwrap()
317 .get_sequence_number_for_decoding()
318 }
319
320 #[no_mangle]
save_session(handle: u64) -> RustFFIByteArray321 pub extern "C" fn save_session(handle: u64) -> RustFFIByteArray {
322 let key = CONNECTION_HANDLE_MAPPING
323 .lock()
324 .get(&handle)
325 .unwrap()
326 .save_session();
327 let handle_size = key.len();
328 RustFFIByteArray {
329 ptr: key.leak().as_mut_ptr(),
330 len: handle_size,
331 }
332 }
333
/// Outcome code reported to C callers, represented as an `i32`.
#[repr(i32)]
#[derive(Debug)]
pub enum Status {
    /// The operation succeeded (discriminant 0).
    Good,
    /// The operation failed (discriminant 1).
    Error,
}
340
/// Result of `from_saved_session`: a connection handle together with a status code.
#[repr(C)]
pub struct CD2DRestoreConnectionContextV1Result {
    /// Handle of the restored connection context; `u64::MAX` when `status` is `Status::Error`.
    handle: u64,
    /// Whether the session was restored successfully.
    status: Status,
}
346
347 /// # Safety
348 /// We error out if the length is incorrect (too large or too small) for restoring a session.
349 #[no_mangle]
from_saved_session( arr: CFFIByteArray, ) -> CD2DRestoreConnectionContextV1Result350 pub unsafe extern "C" fn from_saved_session(
351 arr: CFFIByteArray,
352 ) -> CD2DRestoreConnectionContextV1Result {
353 let saved_session = std::slice::from_raw_parts(arr.ptr, arr.len);
354 let ctx = D2DConnectionContextV1::from_saved_session::<CryptoProvider>(saved_session);
355 if let Ok(conn_ctx) = ctx {
356 let final_ctx = Box::new(conn_ctx);
357 CD2DRestoreConnectionContextV1Result {
358 handle: insert_conn_gen_handle(final_ctx),
359 status: Status::Good,
360 }
361 } else {
362 log::error!(
363 "failed to restore session with error {:?}",
364 ctx.unwrap_err()
365 );
366 CD2DRestoreConnectionContextV1Result {
367 handle: u64::MAX,
368 status: Status::Error,
369 }
370 }
371 }
372