1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 use std::collections::HashMap;
16
17 use jni::objects::JClass;
18 use jni::sys::{jboolean, jbyteArray, jint, jlong, JNI_TRUE};
19 use jni::JNIEnv;
20 use lazy_static::lazy_static;
21 use rand::Rng;
22 use rand_chacha::rand_core::SeedableRng;
23 use rand_chacha::ChaCha20Rng;
24 use spin::Mutex;
25
26 use ukey2_connections::{
27 D2DConnectionContextV1, D2DHandshakeContext, DecodeError, DeserializeError, HandleMessageError,
28 HandshakeError, HandshakeImplementation, InitiatorD2DHandshakeContext,
29 ServerD2DHandshakeContext,
30 };
31
32 cfg_if::cfg_if! {
33 if #[cfg(feature = "rustcrypto")] {
34 use crypto_provider_rustcrypto::RustCrypto as CryptoProvider;
35 } else {
36 use crypto_provider_openssl::Openssl as CryptoProvider;
37 }
38 }
39 // Handle management
40
41 type D2DBox = Box<dyn D2DHandshakeContext>;
42 type ConnectionBox = Box<D2DConnectionContextV1>;
43
44 lazy_static! {
45 static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
46 static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
47 Mutex::new(HashMap::new());
48 static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
49 }
50
generate_handle() -> u6451 fn generate_handle() -> u64 {
52 RNG.lock().gen()
53 }
54
insert_handshake_handle(item: D2DBox) -> u6455 pub(crate) fn insert_handshake_handle(item: D2DBox) -> u64 {
56 let handle = generate_handle();
57 HANDLE_MAPPING.lock().insert(handle, item);
58 handle
59 }
60
insert_conn_handle(item: ConnectionBox) -> u6461 pub(crate) fn insert_conn_handle(item: ConnectionBox) -> u64 {
62 let handle = generate_handle();
63 CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
64 handle
65 }
66
67 #[derive(Debug)]
68 enum JniError {
69 BadHandle,
70 DecodeError(DecodeError),
71 HandleMessageError(HandleMessageError),
72 HandshakeError(HandshakeError),
73 }
74
75 // D2DHandshakeContext
76 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_is_1handshake_1complete( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jboolean77 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_is_1handshake_1complete(
78 mut env: JNIEnv,
79 _: JClass,
80 context_handle: jlong,
81 ) -> jboolean {
82 let mut is_complete = false;
83 if let Some(ctx) = HANDLE_MAPPING.lock().get(&(context_handle as u64)) {
84 is_complete = ctx.is_handshake_complete();
85 } else {
86 env.throw_new(
87 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
88 "",
89 )
90 .expect("failed to find error class");
91 }
92 is_complete as jboolean
93 }
94
95 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_create_1context( _: JNIEnv, _: JClass, is_client: jboolean, ) -> jlong96 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_create_1context(
97 _: JNIEnv,
98 _: JClass,
99 is_client: jboolean,
100 ) -> jlong {
101 if is_client == JNI_TRUE {
102 let client_obj = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
103 HandshakeImplementation::PublicKeyInProtobuf,
104 ));
105 insert_handshake_handle(client_obj) as jlong
106 } else {
107 let server_obj = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
108 HandshakeImplementation::PublicKeyInProtobuf,
109 ));
110 insert_handshake_handle(server_obj) as jlong
111 }
112 }
113
114 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1next_1handshake_1message( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jbyteArray115 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1next_1handshake_1message(
116 mut env: JNIEnv,
117 _: JClass,
118 context_handle: jlong,
119 ) -> jbyteArray {
120 let empty_arr = env.new_byte_array(0).unwrap();
121 let next_message = if let Some(ctx) = HANDLE_MAPPING.lock().get(&(context_handle as u64)) {
122 ctx.get_next_handshake_message()
123 } else {
124 env.throw_new(
125 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
126 "",
127 )
128 .expect("failed to find error class");
129 None
130 };
131 // TODO error handling
132 if let Some(message) = next_message {
133 env.byte_array_from_slice(message.as_slice()).unwrap()
134 } else {
135 empty_arr
136 }
137 }
138
139 #[no_mangle]
140 #[allow(clippy::not_unsafe_ptr_arg_deref)]
141 /// Safety: We know the message pointer is safe as it is coming directly from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_parse_1handshake_1message( mut env: JNIEnv, _: JClass, context_handle: jlong, message: jbyteArray, )142 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_parse_1handshake_1message(
143 mut env: JNIEnv,
144 _: JClass,
145 context_handle: jlong,
146 message: jbyteArray,
147 ) {
148 let rust_buffer = env
149 .convert_byte_array(message)
150 .unwrap();
151 let result = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
152 ctx.handle_handshake_message(rust_buffer.as_slice())
153 .map_err(JniError::HandleMessageError)
154 } else {
155 env.throw_new(
156 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
157 "",
158 )
159 .expect("failed to find error class");
160 Err(JniError::BadHandle)
161 };
162 if let Err(e) = result {
163 if !env.exception_check().unwrap() {
164 env.throw_new(
165 "com/google/security/cryptauth/lib/securegcm/HandshakeException",
166 match e {
167 JniError::BadHandle => "Bad handle",
168 JniError::DecodeError(_) => "Unable to decode message",
169 JniError::HandleMessageError(_) => "Unable to handle message",
170 JniError::HandshakeError(_) => "Handshake incomplete",
171 },
172 )
173 .expect("failed to find error class");
174 }
175 }
176 }
177
178 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1verification_1string( mut env: JNIEnv, _: JClass, context_handle: jlong, length: jint, ) -> jbyteArray179 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1verification_1string(
180 mut env: JNIEnv,
181 _: JClass,
182 context_handle: jlong,
183 length: jint,
184 ) -> jbyteArray {
185 let empty_array = env.new_byte_array(0).unwrap();
186 let result = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
187 ctx.to_completed_handshake()
188 .map_err(|_| JniError::HandshakeError(HandshakeError::HandshakeNotComplete))
189 .map(|h| {
190 h.auth_string::<CryptoProvider>()
191 .derive_vec(length as usize)
192 .unwrap()
193 })
194 } else {
195 env.throw_new(
196 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
197 "",
198 )
199 .expect("failed to find error class");
200 Err(JniError::BadHandle)
201 };
202 if let Err(e) = result {
203 if !env.exception_check().unwrap() {
204 env.throw_new(
205 "com/google/security/cryptauth/lib/securegcm/HandshakeException",
206 match e {
207 JniError::BadHandle => "Bad handle",
208 JniError::DecodeError(_) => "Unable to decode message",
209 JniError::HandleMessageError(_) => "Unable to handle message",
210 JniError::HandshakeError(_) => "Handshake incomplete",
211 },
212 )
213 .expect("failed to find error class");
214 }
215 empty_array
216 } else {
217 let ret_vec = result.unwrap();
218 env.byte_array_from_slice(&ret_vec).unwrap()
219 }
220 }
221
222 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_to_1connection_1context( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jlong223 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_to_1connection_1context(
224 mut env: JNIEnv,
225 _: JClass,
226 context_handle: jlong,
227 ) -> jlong {
228 let conn_context = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
229 ctx.to_connection_context()
230 .map_err(JniError::HandshakeError)
231 } else {
232 Err(JniError::BadHandle)
233 };
234 if let Err(error) = conn_context {
235 env.throw_new(
236 "com/google/security/cryptauth/lib/securegcm/HandshakeException",
237 match error {
238 JniError::BadHandle => "Bad context handle",
239 JniError::HandshakeError(_) => "Handshake not complete",
240 JniError::DecodeError(_) | JniError::HandleMessageError(_) => "Unknown exception",
241 },
242 )
243 .expect("failed to find error class");
244 return -1;
245 } else {
246 HANDLE_MAPPING.lock().remove(&(context_handle as u64));
247 }
248 insert_conn_handle(Box::new(conn_context.unwrap())) as jlong
249 }
250
251 // D2DConnectionContextV1
252 #[no_mangle]
253 #[allow(clippy::not_unsafe_ptr_arg_deref)]
254 /// Safety: We know the payload and associated_data pointers are safe as they are coming directly
255 /// from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_encode_1message_1to_1peer( mut env: JNIEnv, _: JClass, context_handle: jlong, payload: jbyteArray, associated_data: jbyteArray, ) -> jbyteArray256 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_encode_1message_1to_1peer(
257 mut env: JNIEnv,
258 _: JClass,
259 context_handle: jlong,
260 payload: jbyteArray,
261 associated_data: jbyteArray,
262 ) -> jbyteArray {
263 // We create the empty array here so we don't run into issues requesting a new byte array from
264 // the JNI env while an exception is being thrown.
265 let empty_array = env.new_byte_array(0).unwrap();
266 let result = if let Some(ctx) = CONNECTION_HANDLE_MAPPING
267 .lock()
268 .get_mut(&(context_handle as u64))
269 {
270 Ok(ctx.encode_message_to_peer::<CryptoProvider, _>(
271 env.convert_byte_array(payload)
272 .unwrap()
273 .as_slice(),
274 if associated_data.is_null() {
275 None
276 } else {
277 Some(
278 env.convert_byte_array(associated_data)
279 .unwrap(),
280 )
281 },
282 ))
283 } else {
284 Err(JniError::BadHandle)
285 };
286 if let Ok(ret_vec) = result {
287 env.byte_array_from_slice(ret_vec.as_slice())
288 .expect("unable to create jByteArray")
289 } else {
290 env.throw_new(
291 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
292 "",
293 )
294 .expect("failed to find error class");
295 empty_array
296 }
297 }
298
299 #[no_mangle]
300 #[allow(clippy::not_unsafe_ptr_arg_deref)]
301 /// Safety: We know the message and associated_data pointers are safe as they are coming directly
302 /// from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_decode_1message_1from_1peer( mut env: JNIEnv, _: JClass, context_handle: jlong, message: jbyteArray, associated_data: jbyteArray, ) -> jbyteArray303 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_decode_1message_1from_1peer(
304 mut env: JNIEnv,
305 _: JClass,
306 context_handle: jlong,
307 message: jbyteArray,
308 associated_data: jbyteArray,
309 ) -> jbyteArray {
310 let empty_array = env.new_byte_array(0).unwrap();
311 let result = if let Some(ctx) = CONNECTION_HANDLE_MAPPING
312 .lock()
313 .get_mut(&(context_handle as u64))
314 {
315 ctx.decode_message_from_peer::<CryptoProvider, _>(
316 env.convert_byte_array(message)
317 .unwrap()
318 .as_slice(),
319 if associated_data.is_null() {
320 None
321 } else {
322 Some(
323 env.convert_byte_array(associated_data)
324 .unwrap(),
325 )
326 },
327 )
328 .map_err(JniError::DecodeError)
329 } else {
330 Err(JniError::BadHandle)
331 };
332 if let Ok(message) = result {
333 env.byte_array_from_slice(message.as_slice())
334 .expect("unable to create jByteArray")
335 } else {
336 env.throw_new(
337 "com/google/security/cryptauth/lib/securegcm/CryptoException",
338 match result.unwrap_err() {
339 JniError::BadHandle => "Bad context handle",
340 JniError::DecodeError(e) => match e {
341 DecodeError::BadData => "Bad data",
342 DecodeError::BadSequenceNumber => "Bad sequence number",
343 },
344 // None of these should ever occur in this case.
345 JniError::HandleMessageError(_) | JniError::HandshakeError(_) => "Unknown error",
346 },
347 )
348 .expect("failed to find exception class");
349 empty_array
350 }
351 }
352
353 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1encoding( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jint354 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1encoding(
355 mut env: JNIEnv,
356 _: JClass,
357 context_handle: jlong,
358 ) -> jint {
359 if let Some(ctx) = CONNECTION_HANDLE_MAPPING
360 .lock()
361 .get(&(context_handle as u64))
362 {
363 ctx.get_sequence_number_for_encoding() as jint
364 } else {
365 env.throw_new(
366 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
367 "",
368 )
369 .expect("failed to find error class");
370 -1 as jint
371 }
372 }
373
374 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1decoding( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jint375 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1decoding(
376 mut env: JNIEnv,
377 _: JClass,
378 context_handle: jlong,
379 ) -> jint {
380 if let Some(ctx) = CONNECTION_HANDLE_MAPPING
381 .lock()
382 .get(&(context_handle as u64))
383 {
384 ctx.get_sequence_number_for_decoding() as jint
385 } else {
386 env.throw_new(
387 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
388 "",
389 )
390 .expect("failed to find error class");
391 -1 as jint
392 }
393 }
394
395 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_save_1session( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jbyteArray396 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_save_1session(
397 mut env: JNIEnv,
398 _: JClass,
399 context_handle: jlong,
400 ) -> jbyteArray {
401 let empty_array = env.new_byte_array(0).unwrap();
402 if let Some(ctx) = CONNECTION_HANDLE_MAPPING
403 .lock()
404 .get(&(context_handle as u64))
405 {
406 env.byte_array_from_slice(ctx.save_session().as_slice())
407 .expect("unable to save session")
408 } else {
409 env.throw_new(
410 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
411 "",
412 )
413 .expect("failed to find error class");
414 empty_array
415 }
416 }
417
418 #[no_mangle]
419 #[allow(clippy::not_unsafe_ptr_arg_deref)]
420 /// Safety: We know the session_info pointer is safe because it is coming directly from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_from_1saved_1session( mut env: JNIEnv, _: JClass, session_info: jbyteArray, ) -> jlong421 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_from_1saved_1session(
422 mut env: JNIEnv,
423 _: JClass,
424 session_info: jbyteArray,
425 ) -> jlong {
426 let session_info_rust = env
427 .convert_byte_array(session_info)
428 .expect("bad session_info data");
429 let ctx =
430 D2DConnectionContextV1::from_saved_session::<CryptoProvider>(session_info_rust.as_slice());
431 if ctx.is_err() {
432 env.throw_new(
433 "com/google/security/cryptauth/lib/securegcm/SessionRestoreException",
434 match ctx.err().unwrap() {
435 DeserializeError::BadDataLength => "DeserializeError: bad session_info length",
436 DeserializeError::BadProtocolVersion => "DeserializeError: bad protocol version",
437 DeserializeError::BadData => "DeserializeError: bad data",
438 },
439 )
440 .expect("failed to find exception class");
441 return -1;
442 }
443 let final_ctx = ctx.ok().unwrap();
444 let conn_context_final = Box::new(final_ctx);
445 insert_conn_handle(conn_context_final) as jlong
446 }
447
448 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1session_1unique( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jbyteArray449 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1session_1unique(
450 mut env: JNIEnv,
451 _: JClass,
452 context_handle: jlong,
453 ) -> jbyteArray {
454 let empty_array = env.new_byte_array(0).unwrap();
455 if let Some(ctx) = CONNECTION_HANDLE_MAPPING
456 .lock()
457 .get(&(context_handle as u64))
458 {
459 env.byte_array_from_slice(ctx.get_session_unique::<CryptoProvider>().as_slice())
460 .expect("unable to get unique session id")
461 } else {
462 env.throw_new(
463 "com/google/security/cryptauth/lib/securegcm/BadHandleException",
464 "",
465 )
466 .expect("failed to find error class");
467 empty_array
468 }
469 }
470