• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2022, The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 //! Helper wrapper around RKPD interface.
16 
17 use android_security_rkp_aidl::aidl::android::security::rkp::{
18     IGetKeyCallback::BnGetKeyCallback, IGetKeyCallback::ErrorCode::ErrorCode as GetKeyErrorCode,
19     IGetKeyCallback::IGetKeyCallback, IGetRegistrationCallback::BnGetRegistrationCallback,
20     IGetRegistrationCallback::IGetRegistrationCallback, IRegistration::IRegistration,
21     IRemoteProvisioning::IRemoteProvisioning,
22     IStoreUpgradedKeyCallback::BnStoreUpgradedKeyCallback,
23     IStoreUpgradedKeyCallback::IStoreUpgradedKeyCallback,
24     RemotelyProvisionedKey::RemotelyProvisionedKey,
25 };
26 use anyhow::{Context, Result};
27 use binder::{BinderFeatures, Interface, StatusCode, Strong};
28 use message_macro::source_location_msg;
29 use std::sync::Mutex;
30 use std::time::Duration;
31 use tokio::sync::oneshot;
32 use tokio::time::timeout;
33 
// Calls that leave keystore usually block with no deadline and rely on the watchdog to flag
// deadlocks. RKPD is different: it is mainline updatable and its calls may wait on the network
// for certificates, so we time out defensively instead.
static RKPD_TIMEOUT: Duration = Duration::from_millis(10_000);
38 
tokio_rt() -> tokio::runtime::Runtime39 fn tokio_rt() -> tokio::runtime::Runtime {
40     tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap()
41 }
42 
43 /// Errors occurred during the interaction with RKPD.
44 #[derive(Debug, Clone, Copy, thiserror::Error, PartialEq, Eq)]
45 pub enum Error {
46     /// An RKPD request gets cancelled.
47     #[error("An RKPD request gets cancelled")]
48     RequestCancelled,
49 
50     /// Failed to get registration.
51     #[error("Failed to get registration")]
52     GetRegistrationFailed,
53 
54     /// Failed to get key.
55     #[error("Failed to get key: {0:?}")]
56     GetKeyFailed(GetKeyErrorCode),
57 
58     /// Failed to store upgraded key.
59     #[error("Failed to store upgraded key")]
60     StoreUpgradedKeyFailed,
61 
62     /// Retryable timeout when waiting for a callback.
63     #[error("Retryable timeout when waiting for a callback")]
64     RetryableTimeout,
65 
66     /// Timeout when waiting for a callback.
67     #[error("Timeout when waiting for a callback")]
68     Timeout,
69 
70     /// Wraps a Binder status code.
71     #[error("Binder transaction error {0:?}")]
72     BinderTransaction(StatusCode),
73 }
74 
75 impl From<StatusCode> for Error {
from(s: StatusCode) -> Self76     fn from(s: StatusCode) -> Self {
77         Self::BinderTransaction(s)
78     }
79 }
80 
81 /// Thread-safe channel for sending a value once and only once. If a value has
82 /// already been send, subsequent calls to send will noop.
83 struct SafeSender<T> {
84     inner: Mutex<Option<oneshot::Sender<T>>>,
85 }
86 
87 impl<T> SafeSender<T> {
new(sender: oneshot::Sender<T>) -> Self88     fn new(sender: oneshot::Sender<T>) -> Self {
89         Self { inner: Mutex::new(Some(sender)) }
90     }
91 
send(&self, value: T)92     fn send(&self, value: T) {
93         if let Some(inner) = self.inner.lock().unwrap().take() {
94             // It's possible for the corresponding receiver to time out and be dropped. In this
95             // case send() will fail. This error is not actionable though, so only log the error.
96             if inner.send(value).is_err() {
97                 log::error!("SafeSender::send() failed");
98             }
99         }
100     }
101 }
102 
103 struct GetRegistrationCallback {
104     registration_tx: SafeSender<Result<binder::Strong<dyn IRegistration>>>,
105 }
106 
107 impl GetRegistrationCallback {
new_native_binder( registration_tx: oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>, ) -> Strong<dyn IGetRegistrationCallback>108     pub fn new_native_binder(
109         registration_tx: oneshot::Sender<Result<binder::Strong<dyn IRegistration>>>,
110     ) -> Strong<dyn IGetRegistrationCallback> {
111         let result: Self =
112             GetRegistrationCallback { registration_tx: SafeSender::new(registration_tx) };
113         BnGetRegistrationCallback::new_binder(result, BinderFeatures::default())
114     }
115 }
116 
117 impl Interface for GetRegistrationCallback {}
118 
119 impl IGetRegistrationCallback for GetRegistrationCallback {
onSuccess(&self, registration: &Strong<dyn IRegistration>) -> binder::Result<()>120     fn onSuccess(&self, registration: &Strong<dyn IRegistration>) -> binder::Result<()> {
121         self.registration_tx.send(Ok(registration.clone()));
122         Ok(())
123     }
onCancel(&self) -> binder::Result<()>124     fn onCancel(&self) -> binder::Result<()> {
125         log::warn!("IGetRegistrationCallback cancelled");
126         self.registration_tx.send(
127             Err(Error::RequestCancelled)
128                 .context(source_location_msg!("GetRegistrationCallback cancelled.")),
129         );
130         Ok(())
131     }
onError(&self, description: &str) -> binder::Result<()>132     fn onError(&self, description: &str) -> binder::Result<()> {
133         log::error!("IGetRegistrationCallback failed: '{description}'");
134         self.registration_tx.send(
135             Err(Error::GetRegistrationFailed)
136                 .context(source_location_msg!("GetRegistrationCallback failed: {:?}", description)),
137         );
138         Ok(())
139     }
140 }
141 
142 /// Make a new connection to a IRegistration service.
get_rkpd_registration(rpc_name: &str) -> Result<binder::Strong<dyn IRegistration>>143 async fn get_rkpd_registration(rpc_name: &str) -> Result<binder::Strong<dyn IRegistration>> {
144     let remote_provisioning: Strong<dyn IRemoteProvisioning> =
145         binder::get_interface("remote_provisioning")
146             .map_err(Error::from)
147             .context(source_location_msg!("Trying to connect to IRemoteProvisioning service."))?;
148 
149     let (tx, rx) = oneshot::channel();
150     let cb = GetRegistrationCallback::new_native_binder(tx);
151 
152     remote_provisioning
153         .getRegistration(rpc_name, &cb)
154         .context(source_location_msg!("Trying to get registration."))?;
155 
156     match timeout(RKPD_TIMEOUT, rx).await {
157         Err(e) => Err(Error::Timeout).context(source_location_msg!("Waiting for RKPD: {:?}", e)),
158         Ok(v) => v.unwrap(),
159     }
160 }
161 
162 struct GetKeyCallback {
163     key_tx: SafeSender<Result<RemotelyProvisionedKey>>,
164 }
165 
166 impl GetKeyCallback {
new_native_binder( key_tx: oneshot::Sender<Result<RemotelyProvisionedKey>>, ) -> Strong<dyn IGetKeyCallback>167     pub fn new_native_binder(
168         key_tx: oneshot::Sender<Result<RemotelyProvisionedKey>>,
169     ) -> Strong<dyn IGetKeyCallback> {
170         let result: Self = GetKeyCallback { key_tx: SafeSender::new(key_tx) };
171         BnGetKeyCallback::new_binder(result, BinderFeatures::default())
172     }
173 }
174 
175 impl Interface for GetKeyCallback {}
176 
177 impl IGetKeyCallback for GetKeyCallback {
onSuccess(&self, key: &RemotelyProvisionedKey) -> binder::Result<()>178     fn onSuccess(&self, key: &RemotelyProvisionedKey) -> binder::Result<()> {
179         self.key_tx.send(Ok(RemotelyProvisionedKey {
180             keyBlob: key.keyBlob.clone(),
181             encodedCertChain: key.encodedCertChain.clone(),
182         }));
183         Ok(())
184     }
onCancel(&self) -> binder::Result<()>185     fn onCancel(&self) -> binder::Result<()> {
186         log::warn!("IGetKeyCallback cancelled");
187         self.key_tx.send(
188             Err(Error::RequestCancelled).context(source_location_msg!("GetKeyCallback cancelled.")),
189         );
190         Ok(())
191     }
onError(&self, error: GetKeyErrorCode, description: &str) -> binder::Result<()>192     fn onError(&self, error: GetKeyErrorCode, description: &str) -> binder::Result<()> {
193         log::error!("IGetKeyCallback failed: {description}");
194         self.key_tx.send(Err(Error::GetKeyFailed(error)).context(source_location_msg!(
195             "GetKeyCallback failed: {:?} {:?}",
196             error,
197             description
198         )));
199         Ok(())
200     }
201 }
202 
get_rkpd_attestation_key_from_registration_async( registration: &Strong<dyn IRegistration>, caller_uid: u32, ) -> Result<RemotelyProvisionedKey>203 async fn get_rkpd_attestation_key_from_registration_async(
204     registration: &Strong<dyn IRegistration>,
205     caller_uid: u32,
206 ) -> Result<RemotelyProvisionedKey> {
207     let (tx, rx) = oneshot::channel();
208     let cb = GetKeyCallback::new_native_binder(tx);
209 
210     registration
211         .getKey(caller_uid.try_into().unwrap(), &cb)
212         .context(source_location_msg!("Trying to get key."))?;
213 
214     match timeout(RKPD_TIMEOUT, rx).await {
215         Err(e) => {
216             // Make a best effort attempt to cancel the timed out request.
217             if let Err(e) = registration.cancelGetKey(&cb) {
218                 log::error!("IRegistration::cancelGetKey failed: {:?}", e);
219             }
220             Err(Error::RetryableTimeout)
221                 .context(source_location_msg!("Waiting for RKPD key timed out: {:?}", e))
222         }
223         Ok(v) => v.unwrap(),
224     }
225 }
226 
get_rkpd_attestation_key_async( rpc_name: &str, caller_uid: u32, ) -> Result<RemotelyProvisionedKey>227 async fn get_rkpd_attestation_key_async(
228     rpc_name: &str,
229     caller_uid: u32,
230 ) -> Result<RemotelyProvisionedKey> {
231     let registration = get_rkpd_registration(rpc_name)
232         .await
233         .context(source_location_msg!("Trying to get to IRegistration service."))?;
234     get_rkpd_attestation_key_from_registration_async(&registration, caller_uid).await
235 }
236 
237 struct StoreUpgradedKeyCallback {
238     completer: SafeSender<Result<()>>,
239 }
240 
241 impl StoreUpgradedKeyCallback {
new_native_binder( completer: oneshot::Sender<Result<()>>, ) -> Strong<dyn IStoreUpgradedKeyCallback>242     pub fn new_native_binder(
243         completer: oneshot::Sender<Result<()>>,
244     ) -> Strong<dyn IStoreUpgradedKeyCallback> {
245         let result: Self = StoreUpgradedKeyCallback { completer: SafeSender::new(completer) };
246         BnStoreUpgradedKeyCallback::new_binder(result, BinderFeatures::default())
247     }
248 }
249 
250 impl Interface for StoreUpgradedKeyCallback {}
251 
252 impl IStoreUpgradedKeyCallback for StoreUpgradedKeyCallback {
onSuccess(&self) -> binder::Result<()>253     fn onSuccess(&self) -> binder::Result<()> {
254         self.completer.send(Ok(()));
255         Ok(())
256     }
257 
onError(&self, error: &str) -> binder::Result<()>258     fn onError(&self, error: &str) -> binder::Result<()> {
259         log::error!("IStoreUpgradedKeyCallback failed: {error}");
260         self.completer.send(
261             Err(Error::StoreUpgradedKeyFailed)
262                 .context(source_location_msg!("Failed to store upgraded key: {:?}", error)),
263         );
264         Ok(())
265     }
266 }
267 
store_rkpd_attestation_key_with_registration_async( registration: &Strong<dyn IRegistration>, key_blob: &[u8], upgraded_blob: &[u8], ) -> Result<()>268 async fn store_rkpd_attestation_key_with_registration_async(
269     registration: &Strong<dyn IRegistration>,
270     key_blob: &[u8],
271     upgraded_blob: &[u8],
272 ) -> Result<()> {
273     let (tx, rx) = oneshot::channel();
274     let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
275 
276     registration
277         .storeUpgradedKeyAsync(key_blob, upgraded_blob, &cb)
278         .context(source_location_msg!("Failed to store upgraded blob with RKPD."))?;
279 
280     match timeout(RKPD_TIMEOUT, rx).await {
281         Err(e) => Err(Error::Timeout)
282             .context(source_location_msg!("Waiting for RKPD to complete storing key: {:?}", e)),
283         Ok(v) => v.unwrap(),
284     }
285 }
286 
store_rkpd_attestation_key_async( rpc_name: &str, key_blob: &[u8], upgraded_blob: &[u8], ) -> Result<()>287 async fn store_rkpd_attestation_key_async(
288     rpc_name: &str,
289     key_blob: &[u8],
290     upgraded_blob: &[u8],
291 ) -> Result<()> {
292     let registration = get_rkpd_registration(rpc_name)
293         .await
294         .context(source_location_msg!("Trying to get to IRegistration service."))?;
295     store_rkpd_attestation_key_with_registration_async(&registration, key_blob, upgraded_blob).await
296 }
297 
298 /// Get attestation key from RKPD.
get_rkpd_attestation_key(rpc_name: &str, caller_uid: u32) -> Result<RemotelyProvisionedKey>299 pub fn get_rkpd_attestation_key(rpc_name: &str, caller_uid: u32) -> Result<RemotelyProvisionedKey> {
300     tokio_rt().block_on(get_rkpd_attestation_key_async(rpc_name, caller_uid))
301 }
302 
303 /// Store attestation key in RKPD.
store_rkpd_attestation_key( rpc_name: &str, key_blob: &[u8], upgraded_blob: &[u8], ) -> Result<()>304 pub fn store_rkpd_attestation_key(
305     rpc_name: &str,
306     key_blob: &[u8],
307     upgraded_blob: &[u8],
308 ) -> Result<()> {
309     tokio_rt().block_on(store_rkpd_attestation_key_async(rpc_name, key_blob, upgraded_blob))
310 }
311 
#[cfg(test)]
mod tests {
    use super::*;
    use android_security_rkp_aidl::aidl::android::security::rkp::IRegistration::BnRegistration;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    const DEFAULT_RPC_SERVICE_NAME: &str =
        "android.hardware.security.keymint.IRemotelyProvisionedComponent/default";

    /// State shared by a [`MockRegistration`].
    struct MockRegistrationValues {
        /// Key returned to every `getKey` caller.
        key: RemotelyProvisionedKey,
        /// Optional delay applied before a callback fires, used to provoke caller timeouts.
        latency: Option<Duration>,
        /// Callback threads spawned by the mock; joined when the mock is dropped.
        thread_join_handles: Vec<std::thread::JoinHandle<()>>,
    }

    /// In-process `IRegistration` that answers requests from background threads.
    struct MockRegistration(Arc<Mutex<MockRegistrationValues>>);

    impl MockRegistration {
        pub fn new_native_binder(
            key: &RemotelyProvisionedKey,
            latency: Option<Duration>,
        ) -> Strong<dyn IRegistration> {
            let result = Self(Arc::new(Mutex::new(MockRegistrationValues {
                key: RemotelyProvisionedKey {
                    keyBlob: key.keyBlob.clone(),
                    encodedCertChain: key.encodedCertChain.clone(),
                },
                latency,
                thread_join_handles: Vec::new(),
            })));
            BnRegistration::new_binder(result, BinderFeatures::default())
        }

        /// Runs `callback` on a new thread after the configured latency. The thread is recorded
        /// so that it is joined when the mock is dropped.
        fn spawn_callback(
            values: &mut MockRegistrationValues,
            callback: impl FnOnce() + Send + 'static,
        ) {
            let latency = values.latency;
            // Need a separate thread to trigger timeout in the caller.
            let join_handle = std::thread::spawn(move || {
                if let Some(duration) = latency {
                    std::thread::sleep(duration);
                }
                callback();
            });
            values.thread_join_handles.push(join_handle);
        }
    }

    impl Drop for MockRegistration {
        fn drop(&mut self) {
            let mut values = self.0.lock().unwrap();
            for handle in values.thread_join_handles.drain(..) {
                // These are test threads. So, no need to worry too much about error handling.
                handle.join().unwrap();
            }
        }
    }

    impl Interface for MockRegistration {}

    impl IRegistration for MockRegistration {
        fn getKey(&self, _: i32, cb: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
            let mut values = self.0.lock().unwrap();
            let key = RemotelyProvisionedKey {
                keyBlob: values.key.keyBlob.clone(),
                encodedCertChain: values.key.encodedCertChain.clone(),
            };
            let get_key_cb = cb.clone();
            Self::spawn_callback(&mut values, move || get_key_cb.onSuccess(&key).unwrap());
            Ok(())
        }

        fn cancelGetKey(&self, _: &Strong<dyn IGetKeyCallback>) -> binder::Result<()> {
            Ok(())
        }

        fn storeUpgradedKeyAsync(
            &self,
            _: &[u8],
            _: &[u8],
            cb: &Strong<dyn IStoreUpgradedKeyCallback>,
        ) -> binder::Result<()> {
            // We are primarily concerned with timing out correctly. Storing the key in this mock
            // registration isn't particularly interesting, so skip that part.
            let mut values = self.0.lock().unwrap();
            let store_cb = cb.clone();
            // Track the thread like `getKey` does so it is joined when the mock is dropped.
            Self::spawn_callback(&mut values, move || store_cb.onSuccess().unwrap());
            Ok(())
        }
    }

    fn get_mock_registration(
        key: &RemotelyProvisionedKey,
        latency: Option<Duration>,
    ) -> Result<binder::Strong<dyn IRegistration>> {
        let (tx, rx) = oneshot::channel();
        let cb = GetRegistrationCallback::new_native_binder(tx);
        let mock_registration = MockRegistration::new_native_binder(key, latency);

        assert!(cb.onSuccess(&mock_registration).is_ok());
        tokio_rt().block_on(rx).unwrap()
    }

    // Using the same key ID makes test cases race with each other. So, we use separate key IDs for
    // different test cases.
    fn get_next_key_id() -> u32 {
        static ID: AtomicU32 = AtomicU32::new(0);
        ID.fetch_add(1, Ordering::Relaxed)
    }

    #[test]
    fn test_get_registration_cb_success() {
        let key: RemotelyProvisionedKey = Default::default();
        let registration = get_mock_registration(&key, /*latency=*/ None);
        assert!(registration.is_ok());
    }

    #[test]
    fn test_get_registration_cb_cancel() {
        let (tx, rx) = oneshot::channel();
        let cb = GetRegistrationCallback::new_native_binder(tx);
        assert!(cb.onCancel().is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::RequestCancelled);
    }

    #[test]
    fn test_get_registration_cb_error() {
        let (tx, rx) = oneshot::channel();
        let cb = GetRegistrationCallback::new_native_binder(tx);
        assert!(cb.onError("error").is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::GetRegistrationFailed);
    }

    #[test]
    fn test_get_key_cb_success() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let (tx, rx) = oneshot::channel();
        let cb = GetKeyCallback::new_native_binder(tx);
        assert!(cb.onSuccess(&mock_key).is_ok());

        let key = tokio_rt().block_on(rx).unwrap().unwrap();
        assert_eq!(key, mock_key);
    }

    #[test]
    fn test_get_key_cb_cancel() {
        let (tx, rx) = oneshot::channel();
        let cb = GetKeyCallback::new_native_binder(tx);
        assert!(cb.onCancel().is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::RequestCancelled);
    }

    #[test]
    fn test_get_key_cb_error() {
        for get_key_error in GetKeyErrorCode::enum_values() {
            let (tx, rx) = oneshot::channel();
            let cb = GetKeyCallback::new_native_binder(tx);
            assert!(cb.onError(get_key_error, "error").is_ok());

            let result = tokio_rt().block_on(rx).unwrap();
            assert_eq!(
                result.unwrap_err().downcast::<Error>().unwrap(),
                Error::GetKeyFailed(get_key_error),
            );
        }
    }

    #[test]
    fn test_store_upgraded_cb_success() {
        let (tx, rx) = oneshot::channel();
        let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
        assert!(cb.onSuccess().is_ok());

        tokio_rt().block_on(rx).unwrap().unwrap();
    }

    #[test]
    fn test_store_upgraded_key_cb_error() {
        let (tx, rx) = oneshot::channel();
        let cb = StoreUpgradedKeyCallback::new_native_binder(tx);
        assert!(cb.onError("oh no! it failed").is_ok());

        let result = tokio_rt().block_on(rx).unwrap();
        assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::StoreUpgradedKeyFailed);
    }

    #[test]
    fn test_get_mock_key_success() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();

        let key = tokio_rt()
            .block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0))
            .unwrap();
        assert_eq!(key, mock_key);
    }

    #[test]
    fn test_get_mock_key_timeout() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let latency = RKPD_TIMEOUT + Duration::from_secs(1);
        let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();

        let result =
            tokio_rt().block_on(get_rkpd_attestation_key_from_registration_async(&registration, 0));
        assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::RetryableTimeout);
    }

    #[test]
    fn test_store_mock_key_success() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let registration = get_mock_registration(&mock_key, /*latency=*/ None).unwrap();
        tokio_rt()
            .block_on(store_rkpd_attestation_key_with_registration_async(&registration, &[], &[]))
            .unwrap();
    }

    #[test]
    fn test_store_mock_key_timeout() {
        let mock_key =
            RemotelyProvisionedKey { keyBlob: vec![1, 2, 3], encodedCertChain: vec![4, 5, 6] };
        let latency = RKPD_TIMEOUT + Duration::from_secs(1);
        let registration = get_mock_registration(&mock_key, Some(latency)).unwrap();

        let result = tokio_rt().block_on(store_rkpd_attestation_key_with_registration_async(
            &registration,
            &[],
            &[],
        ));
        assert_eq!(result.unwrap_err().downcast::<Error>().unwrap(), Error::Timeout);
    }

    #[test]
    fn test_get_rkpd_attestation_key() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();
        let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
        assert!(!key.keyBlob.is_empty());
        assert!(!key.encodedCertChain.is_empty());
    }

    #[test]
    fn test_get_rkpd_attestation_key_same_caller() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();

        // Multiple calls should return the same key.
        let first_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
        let second_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();

        assert_eq!(first_key.keyBlob, second_key.keyBlob);
        assert_eq!(first_key.encodedCertChain, second_key.encodedCertChain);
    }

    #[test]
    fn test_get_rkpd_attestation_key_different_caller() {
        binder::ProcessState::start_thread_pool();
        let first_key_id = get_next_key_id();
        let second_key_id = get_next_key_id();

        // Different callers should be getting different keys.
        let first_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, first_key_id).unwrap();
        let second_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, second_key_id).unwrap();

        assert_ne!(first_key.keyBlob, second_key.keyBlob);
        assert_ne!(first_key.encodedCertChain, second_key.encodedCertChain);
    }

    // Couple of things to note:
    // 1. This test must never run with UID of keystore. Otherwise, it can mess up keys stored by
    //    keystore.
    // 2. Storing and reading the stored key is prone to race condition. So, we only do this in one
    //    test case.
    #[test]
    fn test_store_rkpd_attestation_key() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();
        let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
        let new_blob: [u8; 8] = rand::random();

        assert!(
            store_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, &key.keyBlob, &new_blob).is_ok()
        );

        let new_key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();

        // Restore original key so that we don't leave RKPD with invalid blobs.
        assert!(
            store_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, &new_blob, &key.keyBlob).is_ok()
        );
        assert_eq!(new_key.keyBlob, new_blob);
    }

    #[test]
    fn test_stress_get_rkpd_attestation_key() {
        binder::ProcessState::start_thread_pool();
        let key_id = get_next_key_id();
        let mut threads = vec![];
        const NTHREADS: u32 = 10;
        const NCALLS: u32 = 1000;

        for _ in 0..NTHREADS {
            threads.push(std::thread::spawn(move || {
                for _ in 0..NCALLS {
                    let key = get_rkpd_attestation_key(DEFAULT_RPC_SERVICE_NAME, key_id).unwrap();
                    assert!(!key.keyBlob.is_empty());
                    assert!(!key.encodedCertChain.is_empty());
                }
            }));
        }

        for t in threads {
            assert!(t.join().is_ok());
        }
    }
}
641