• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::collections::HashMap;
16 
17 use jni::objects::JClass;
18 use jni::sys::{jboolean, jbyteArray, jint, jlong, JNI_TRUE};
19 use jni::JNIEnv;
20 use lazy_static::lazy_static;
21 use rand::Rng;
22 use rand_chacha::rand_core::SeedableRng;
23 use rand_chacha::ChaCha20Rng;
24 use spin::Mutex;
25 
26 use ukey2_connections::{
27     D2DConnectionContextV1, D2DHandshakeContext, DecodeError, DeserializeError, HandleMessageError,
28     HandshakeError, HandshakeImplementation, InitiatorD2DHandshakeContext,
29     ServerD2DHandshakeContext,
30 };
31 
32 cfg_if::cfg_if! {
33     if #[cfg(feature = "rustcrypto")] {
34         use crypto_provider_rustcrypto::RustCrypto as CryptoProvider;
35     } else {
36         use crypto_provider_openssl::Openssl as CryptoProvider;
37     }
38 }
39 // Handle management
40 
41 type D2DBox = Box<dyn D2DHandshakeContext>;
42 type ConnectionBox = Box<D2DConnectionContextV1>;
43 
44 lazy_static! {
45     static ref HANDLE_MAPPING: Mutex<HashMap<u64, D2DBox>> = Mutex::new(HashMap::new());
46     static ref CONNECTION_HANDLE_MAPPING: Mutex<HashMap<u64, ConnectionBox>> =
47         Mutex::new(HashMap::new());
48     static ref RNG: Mutex<ChaCha20Rng> = Mutex::new(ChaCha20Rng::from_entropy());
49 }
50 
generate_handle() -> u6451 fn generate_handle() -> u64 {
52     RNG.lock().gen()
53 }
54 
insert_handshake_handle(item: D2DBox) -> u6455 pub(crate) fn insert_handshake_handle(item: D2DBox) -> u64 {
56     let handle = generate_handle();
57     HANDLE_MAPPING.lock().insert(handle, item);
58     handle
59 }
60 
insert_conn_handle(item: ConnectionBox) -> u6461 pub(crate) fn insert_conn_handle(item: ConnectionBox) -> u64 {
62     let handle = generate_handle();
63     CONNECTION_HANDLE_MAPPING.lock().insert(handle, item);
64     handle
65 }
66 
67 #[derive(Debug)]
68 enum JniError {
69     BadHandle,
70     DecodeError(DecodeError),
71     HandleMessageError(HandleMessageError),
72     HandshakeError(HandshakeError),
73 }
74 
75 // D2DHandshakeContext
76 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_is_1handshake_1complete( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jboolean77 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_is_1handshake_1complete(
78     mut env: JNIEnv,
79     _: JClass,
80     context_handle: jlong,
81 ) -> jboolean {
82     let mut is_complete = false;
83     if let Some(ctx) = HANDLE_MAPPING.lock().get(&(context_handle as u64)) {
84         is_complete = ctx.is_handshake_complete();
85     } else {
86         env.throw_new(
87             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
88             "",
89         )
90         .expect("failed to find error class");
91     }
92     is_complete as jboolean
93 }
94 
95 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_create_1context( _: JNIEnv, _: JClass, is_client: jboolean, ) -> jlong96 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_create_1context(
97     _: JNIEnv,
98     _: JClass,
99     is_client: jboolean,
100 ) -> jlong {
101     if is_client == JNI_TRUE {
102         let client_obj = Box::new(InitiatorD2DHandshakeContext::<CryptoProvider>::new(
103             HandshakeImplementation::PublicKeyInProtobuf,
104         ));
105         insert_handshake_handle(client_obj) as jlong
106     } else {
107         let server_obj = Box::new(ServerD2DHandshakeContext::<CryptoProvider>::new(
108             HandshakeImplementation::PublicKeyInProtobuf,
109         ));
110         insert_handshake_handle(server_obj) as jlong
111     }
112 }
113 
114 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1next_1handshake_1message( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jbyteArray115 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1next_1handshake_1message(
116     mut env: JNIEnv,
117     _: JClass,
118     context_handle: jlong,
119 ) -> jbyteArray {
120     let empty_arr = env.new_byte_array(0).unwrap();
121     let next_message = if let Some(ctx) = HANDLE_MAPPING.lock().get(&(context_handle as u64)) {
122         ctx.get_next_handshake_message()
123     } else {
124         env.throw_new(
125             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
126             "",
127         )
128         .expect("failed to find error class");
129         None
130     };
131     // TODO error handling
132     if let Some(message) = next_message {
133         env.byte_array_from_slice(message.as_slice()).unwrap()
134     } else {
135         empty_arr
136     }
137 }
138 
139 #[no_mangle]
140 #[allow(clippy::not_unsafe_ptr_arg_deref)]
141 /// Safety: We know the message pointer is safe as it is coming directly from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_parse_1handshake_1message( mut env: JNIEnv, _: JClass, context_handle: jlong, message: jbyteArray, )142 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_parse_1handshake_1message(
143     mut env: JNIEnv,
144     _: JClass,
145     context_handle: jlong,
146     message: jbyteArray,
147 ) {
148     let rust_buffer = env
149         .convert_byte_array(message)
150         .unwrap();
151     let result = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
152         ctx.handle_handshake_message(rust_buffer.as_slice())
153             .map_err(JniError::HandleMessageError)
154     } else {
155         env.throw_new(
156             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
157             "",
158         )
159         .expect("failed to find error class");
160         Err(JniError::BadHandle)
161     };
162     if let Err(e) = result {
163         if !env.exception_check().unwrap() {
164             env.throw_new(
165                 "com/google/security/cryptauth/lib/securegcm/HandshakeException",
166                 match e {
167                     JniError::BadHandle => "Bad handle",
168                     JniError::DecodeError(_) => "Unable to decode message",
169                     JniError::HandleMessageError(_) => "Unable to handle message",
170                     JniError::HandshakeError(_) => "Handshake incomplete",
171                 },
172             )
173             .expect("failed to find error class");
174         }
175     }
176 }
177 
178 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1verification_1string( mut env: JNIEnv, _: JClass, context_handle: jlong, length: jint, ) -> jbyteArray179 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_get_1verification_1string(
180     mut env: JNIEnv,
181     _: JClass,
182     context_handle: jlong,
183     length: jint,
184 ) -> jbyteArray {
185     let empty_array = env.new_byte_array(0).unwrap();
186     let result = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
187         ctx.to_completed_handshake()
188             .map_err(|_| JniError::HandshakeError(HandshakeError::HandshakeNotComplete))
189             .map(|h| {
190                 h.auth_string::<CryptoProvider>()
191                     .derive_vec(length as usize)
192                     .unwrap()
193             })
194     } else {
195         env.throw_new(
196             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
197             "",
198         )
199         .expect("failed to find error class");
200         Err(JniError::BadHandle)
201     };
202     if let Err(e) = result {
203         if !env.exception_check().unwrap() {
204             env.throw_new(
205                 "com/google/security/cryptauth/lib/securegcm/HandshakeException",
206                 match e {
207                     JniError::BadHandle => "Bad handle",
208                     JniError::DecodeError(_) => "Unable to decode message",
209                     JniError::HandleMessageError(_) => "Unable to handle message",
210                     JniError::HandshakeError(_) => "Handshake incomplete",
211                 },
212             )
213             .expect("failed to find error class");
214         }
215         empty_array
216     } else {
217         let ret_vec = result.unwrap();
218         env.byte_array_from_slice(&ret_vec).unwrap()
219     }
220 }
221 
222 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_to_1connection_1context( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jlong223 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DHandshakeContext_to_1connection_1context(
224     mut env: JNIEnv,
225     _: JClass,
226     context_handle: jlong,
227 ) -> jlong {
228     let conn_context = if let Some(ctx) = HANDLE_MAPPING.lock().get_mut(&(context_handle as u64)) {
229         ctx.to_connection_context()
230             .map_err(JniError::HandshakeError)
231     } else {
232         Err(JniError::BadHandle)
233     };
234     if let Err(error) = conn_context {
235         env.throw_new(
236             "com/google/security/cryptauth/lib/securegcm/HandshakeException",
237             match error {
238                 JniError::BadHandle => "Bad context handle",
239                 JniError::HandshakeError(_) => "Handshake not complete",
240                 JniError::DecodeError(_) | JniError::HandleMessageError(_) => "Unknown exception",
241             },
242         )
243         .expect("failed to find error class");
244         return -1;
245     } else {
246         HANDLE_MAPPING.lock().remove(&(context_handle as u64));
247     }
248     insert_conn_handle(Box::new(conn_context.unwrap())) as jlong
249 }
250 
251 // D2DConnectionContextV1
252 #[no_mangle]
253 #[allow(clippy::not_unsafe_ptr_arg_deref)]
254 /// Safety: We know the payload and associated_data pointers are safe as they are coming directly
255 /// from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_encode_1message_1to_1peer( mut env: JNIEnv, _: JClass, context_handle: jlong, payload: jbyteArray, associated_data: jbyteArray, ) -> jbyteArray256 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_encode_1message_1to_1peer(
257     mut env: JNIEnv,
258     _: JClass,
259     context_handle: jlong,
260     payload: jbyteArray,
261     associated_data: jbyteArray,
262 ) -> jbyteArray {
263     // We create the empty array here so we don't run into issues requesting a new byte array from
264     // the JNI env while an exception is being thrown.
265     let empty_array = env.new_byte_array(0).unwrap();
266     let result = if let Some(ctx) = CONNECTION_HANDLE_MAPPING
267         .lock()
268         .get_mut(&(context_handle as u64))
269     {
270         Ok(ctx.encode_message_to_peer::<CryptoProvider, _>(
271             env.convert_byte_array(payload)
272                 .unwrap()
273                 .as_slice(),
274             if associated_data.is_null() {
275                 None
276             } else {
277                 Some(
278                     env.convert_byte_array(associated_data)
279                         .unwrap(),
280                 )
281             },
282         ))
283     } else {
284         Err(JniError::BadHandle)
285     };
286     if let Ok(ret_vec) = result {
287         env.byte_array_from_slice(ret_vec.as_slice())
288             .expect("unable to create jByteArray")
289     } else {
290         env.throw_new(
291             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
292             "",
293         )
294         .expect("failed to find error class");
295         empty_array
296     }
297 }
298 
299 #[no_mangle]
300 #[allow(clippy::not_unsafe_ptr_arg_deref)]
301 /// Safety: We know the message and associated_data pointers are safe as they are coming directly
302 /// from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_decode_1message_1from_1peer( mut env: JNIEnv, _: JClass, context_handle: jlong, message: jbyteArray, associated_data: jbyteArray, ) -> jbyteArray303 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_decode_1message_1from_1peer(
304     mut env: JNIEnv,
305     _: JClass,
306     context_handle: jlong,
307     message: jbyteArray,
308     associated_data: jbyteArray,
309 ) -> jbyteArray {
310     let empty_array = env.new_byte_array(0).unwrap();
311     let result = if let Some(ctx) = CONNECTION_HANDLE_MAPPING
312         .lock()
313         .get_mut(&(context_handle as u64))
314     {
315         ctx.decode_message_from_peer::<CryptoProvider, _>(
316             env.convert_byte_array(message)
317                 .unwrap()
318                 .as_slice(),
319             if associated_data.is_null() {
320                 None
321             } else {
322                 Some(
323                     env.convert_byte_array(associated_data)
324                         .unwrap(),
325                 )
326             },
327         )
328         .map_err(JniError::DecodeError)
329     } else {
330         Err(JniError::BadHandle)
331     };
332     if let Ok(message) = result {
333         env.byte_array_from_slice(message.as_slice())
334             .expect("unable to create jByteArray")
335     } else {
336         env.throw_new(
337             "com/google/security/cryptauth/lib/securegcm/CryptoException",
338             match result.unwrap_err() {
339                 JniError::BadHandle => "Bad context handle",
340                 JniError::DecodeError(e) => match e {
341                     DecodeError::BadData => "Bad data",
342                     DecodeError::BadSequenceNumber => "Bad sequence number",
343                 },
344                 // None of these should ever occur in this case.
345                 JniError::HandleMessageError(_) | JniError::HandshakeError(_) => "Unknown error",
346             },
347         )
348         .expect("failed to find exception class");
349         empty_array
350     }
351 }
352 
353 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1encoding( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jint354 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1encoding(
355     mut env: JNIEnv,
356     _: JClass,
357     context_handle: jlong,
358 ) -> jint {
359     if let Some(ctx) = CONNECTION_HANDLE_MAPPING
360         .lock()
361         .get(&(context_handle as u64))
362     {
363         ctx.get_sequence_number_for_encoding() as jint
364     } else {
365         env.throw_new(
366             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
367             "",
368         )
369         .expect("failed to find error class");
370         -1 as jint
371     }
372 }
373 
374 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1decoding( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jint375 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1sequence_1number_1for_1decoding(
376     mut env: JNIEnv,
377     _: JClass,
378     context_handle: jlong,
379 ) -> jint {
380     if let Some(ctx) = CONNECTION_HANDLE_MAPPING
381         .lock()
382         .get(&(context_handle as u64))
383     {
384         ctx.get_sequence_number_for_decoding() as jint
385     } else {
386         env.throw_new(
387             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
388             "",
389         )
390         .expect("failed to find error class");
391         -1 as jint
392     }
393 }
394 
395 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_save_1session( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jbyteArray396 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_save_1session(
397     mut env: JNIEnv,
398     _: JClass,
399     context_handle: jlong,
400 ) -> jbyteArray {
401     let empty_array = env.new_byte_array(0).unwrap();
402     if let Some(ctx) = CONNECTION_HANDLE_MAPPING
403         .lock()
404         .get(&(context_handle as u64))
405     {
406         env.byte_array_from_slice(ctx.save_session().as_slice())
407             .expect("unable to save session")
408     } else {
409         env.throw_new(
410             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
411             "",
412         )
413         .expect("failed to find error class");
414         empty_array
415     }
416 }
417 
418 #[no_mangle]
419 #[allow(clippy::not_unsafe_ptr_arg_deref)]
420 /// Safety: We know the session_info pointer is safe because it is coming directly from the JVM.
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_from_1saved_1session( mut env: JNIEnv, _: JClass, session_info: jbyteArray, ) -> jlong421 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_from_1saved_1session(
422     mut env: JNIEnv,
423     _: JClass,
424     session_info: jbyteArray,
425 ) -> jlong {
426     let session_info_rust = env
427         .convert_byte_array(session_info)
428         .expect("bad session_info data");
429     let ctx =
430         D2DConnectionContextV1::from_saved_session::<CryptoProvider>(session_info_rust.as_slice());
431     if ctx.is_err() {
432         env.throw_new(
433             "com/google/security/cryptauth/lib/securegcm/SessionRestoreException",
434             match ctx.err().unwrap() {
435                 DeserializeError::BadDataLength => "DeserializeError: bad session_info length",
436                 DeserializeError::BadProtocolVersion => "DeserializeError: bad protocol version",
437                 DeserializeError::BadData => "DeserializeError: bad data",
438             },
439         )
440         .expect("failed to find exception class");
441         return -1;
442     }
443     let final_ctx = ctx.ok().unwrap();
444     let conn_context_final = Box::new(final_ctx);
445     insert_conn_handle(conn_context_final) as jlong
446 }
447 
448 #[no_mangle]
Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1session_1unique( mut env: JNIEnv, _: JClass, context_handle: jlong, ) -> jbyteArray449 pub extern "system" fn Java_com_google_security_cryptauth_lib_securegcm_D2DConnectionContextV1_get_1session_1unique(
450     mut env: JNIEnv,
451     _: JClass,
452     context_handle: jlong,
453 ) -> jbyteArray {
454     let empty_array = env.new_byte_array(0).unwrap();
455     if let Some(ctx) = CONNECTION_HANDLE_MAPPING
456         .lock()
457         .get(&(context_handle as u64))
458     {
459         env.byte_array_from_slice(ctx.get_session_unique::<CryptoProvider>().as_slice())
460             .expect("unable to get unique session id")
461     } else {
462         env.throw_new(
463             "com/google/security/cryptauth/lib/securegcm/BadHandleException",
464             "",
465         )
466         .expect("failed to find error class");
467         empty_array
468     }
469 }
470