moves filter_credentials to call side

This commit is contained in:
Fabian Kaczmarczyck
2021-01-12 07:01:25 +01:00
parent 4cee0c4c65
commit 27a7108328
2 changed files with 49 additions and 115 deletions

View File

@@ -761,11 +761,23 @@ where
vec![], vec![],
) )
} else { } else {
let mut stored_credentials = let mut iter_result = Ok(());
self.persistent_store.filter_credentials(&rp_id, !has_uv)?; let iter = self.persistent_store.iter_credentials(&mut iter_result)?;
stored_credentials.sort_unstable_by_key(|c| c.1); let mut stored_credentials: Vec<(usize, u64)> = iter
let mut stored_credentials: Vec<usize> = .filter_map(|(key, credential)| {
stored_credentials.into_iter().map(|c| c.0).collect(); if credential.rp_id == rp_id && (has_uv || credential.is_discoverable()) {
Some((key, credential.creation_order))
} else {
None
}
})
.collect();
iter_result?;
stored_credentials.sort_unstable_by_key(|&(_key, order)| order);
let mut stored_credentials: Vec<usize> = stored_credentials
.into_iter()
.map(|(key, _order)| key)
.collect();
let credential = stored_credentials let credential = stored_credentials
.pop() .pop()
.map(|key| self.persistent_store.get_credential(key)) .map(|key| self.persistent_store.get_credential(key))
@@ -1252,17 +1264,14 @@ mod test {
ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID); ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID);
assert!(make_credential_response.is_ok()); assert!(make_credential_response.is_ok());
let credential_key = ctap_state let mut iter_result = Ok(());
let iter = ctap_state
.persistent_store .persistent_store
.filter_credentials("example.com", false) .iter_credentials(&mut iter_result)
.unwrap()
.pop()
.unwrap()
.0;
let stored_credential = ctap_state
.persistent_store
.get_credential(credential_key)
.unwrap(); .unwrap();
// There is only 1 credential, so last is good enough.
let (_, stored_credential) = iter.last().unwrap();
iter_result.unwrap();
let credential_id = stored_credential.credential_id; let credential_id = stored_credential.credential_id;
assert_eq!(stored_credential.cred_protect_policy, Some(test_policy)); assert_eq!(stored_credential.cred_protect_policy, Some(test_policy));
@@ -1282,17 +1291,14 @@ mod test {
ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID); ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID);
assert!(make_credential_response.is_ok()); assert!(make_credential_response.is_ok());
let credential_key = ctap_state let mut iter_result = Ok(());
let iter = ctap_state
.persistent_store .persistent_store
.filter_credentials("example.com", false) .iter_credentials(&mut iter_result)
.unwrap()
.pop()
.unwrap()
.0;
let stored_credential = ctap_state
.persistent_store
.get_credential(credential_key)
.unwrap(); .unwrap();
// There is only 1 credential, so last is good enough.
let (_, stored_credential) = iter.last().unwrap();
iter_result.unwrap();
let credential_id = stored_credential.credential_id; let credential_id = stored_credential.credential_id;
assert_eq!(stored_credential.cred_protect_policy, Some(test_policy)); assert_eq!(stored_credential.cred_protect_policy, Some(test_policy));

View File

@@ -171,7 +171,7 @@ impl PersistentStore {
let credential = match self.find_credential_item(credential_id) { let credential = match self.find_credential_item(credential_id) {
Err(Ctap2StatusCode::CTAP2_ERR_NO_CREDENTIALS) => return Ok(None), Err(Ctap2StatusCode::CTAP2_ERR_NO_CREDENTIALS) => return Ok(None),
Err(e) => return Err(e), Err(e) => return Err(e),
Ok(credential) => credential.1, Ok((_key, credential)) => credential,
}; };
let is_protected = credential.cred_protect_policy let is_protected = credential.cred_protect_policy
== Some(CredentialProtectionPolicy::UserVerificationRequired); == Some(CredentialProtectionPolicy::UserVerificationRequired);
@@ -261,31 +261,6 @@ impl PersistentStore {
Ok(self.store.insert(key, &value)?) Ok(self.store.insert(key, &value)?)
} }
/// Returns the list of matching credentials.
///
/// Does not return credentials that are not discoverable if `check_cred_protect` is set.
pub fn filter_credentials(
&self,
rp_id: &str,
check_cred_protect: bool,
) -> Result<Vec<(usize, u64)>, Ctap2StatusCode> {
let mut iter_result = Ok(());
let iter = self.iter_credentials(&mut iter_result)?;
let result = iter
.filter_map(|(key, credential)| {
if credential.rp_id == rp_id
&& (!check_cred_protect || credential.is_discoverable())
{
Some((key, credential.creation_order))
} else {
None
}
})
.collect();
iter_result?;
Ok(result)
}
/// Returns the number of credentials. /// Returns the number of credentials.
pub fn count_credentials(&self) -> Result<usize, Ctap2StatusCode> { pub fn count_credentials(&self) -> Result<usize, Ctap2StatusCode> {
let mut iter_result = Ok(()); let mut iter_result = Ok(());
@@ -811,7 +786,8 @@ mod test {
// These should have different IDs. // These should have different IDs.
let credential_source0 = create_credential_source(&mut rng, "example.com", vec![0x00]); let credential_source0 = create_credential_source(&mut rng, "example.com", vec![0x00]);
let credential_source1 = create_credential_source(&mut rng, "example.com", vec![0x00]); let credential_source1 = create_credential_source(&mut rng, "example.com", vec![0x00]);
let expected_credential = credential_source1.clone(); let credential_id0 = credential_source0.credential_id.clone();
let credential_id1 = credential_source1.credential_id.clone();
assert!(persistent_store assert!(persistent_store
.store_credential(credential_source0) .store_credential(credential_source0)
@@ -820,13 +796,14 @@ mod test {
.store_credential(credential_source1) .store_credential(credential_source1)
.is_ok()); .is_ok());
assert_eq!(persistent_store.count_credentials().unwrap(), 1); assert_eq!(persistent_store.count_credentials().unwrap(), 1);
let filtered_credentials = persistent_store assert!(persistent_store
.filter_credentials("example.com", false) .find_credential("example.com", &credential_id0, false)
.unwrap(); .unwrap()
let retrieved_credential_source = persistent_store .is_none());
.get_credential(filtered_credentials[0].0) assert!(persistent_store
.unwrap(); .find_credential("example.com", &credential_id1, false)
assert_eq!(retrieved_credential_source, expected_credential); .unwrap()
.is_some());
let mut persistent_store = PersistentStore::new(&mut rng); let mut persistent_store = PersistentStore::new(&mut rng);
for i in 0..MAX_SUPPORTED_RESIDENTIAL_KEYS { for i in 0..MAX_SUPPORTED_RESIDENTIAL_KEYS {
@@ -851,70 +828,21 @@ mod test {
} }
#[test] #[test]
fn test_filter_get_credentials() { fn test_get_credential() {
let mut rng = ThreadRng256 {}; let mut rng = ThreadRng256 {};
let mut persistent_store = PersistentStore::new(&mut rng); let mut persistent_store = PersistentStore::new(&mut rng);
assert_eq!(persistent_store.count_credentials().unwrap(), 0);
let credential_source0 = create_credential_source(&mut rng, "example.com", vec![0x00]); let credential_source0 = create_credential_source(&mut rng, "example.com", vec![0x00]);
let credential_source1 = create_credential_source(&mut rng, "example.com", vec![0x01]); let credential_source1 = create_credential_source(&mut rng, "example.com", vec![0x01]);
let credential_source2 = let credential_source2 =
create_credential_source(&mut rng, "another.example.com", vec![0x02]); create_credential_source(&mut rng, "another.example.com", vec![0x02]);
let id0 = credential_source0.credential_id.clone(); let credential_sources = vec![credential_source0, credential_source1, credential_source2];
let id1 = credential_source1.credential_id.clone(); for credential_source in credential_sources.into_iter() {
assert!(persistent_store let cred_id = credential_source.credential_id.clone();
.store_credential(credential_source0) assert!(persistent_store.store_credential(credential_source).is_ok());
.is_ok()); let (key, _) = persistent_store.find_credential_item(&cred_id).unwrap();
assert!(persistent_store let cred = persistent_store.get_credential(key).unwrap();
.store_credential(credential_source1) assert_eq!(&cred_id, &cred.credential_id);
.is_ok()); }
assert!(persistent_store
.store_credential(credential_source2)
.is_ok());
let filtered_credentials = persistent_store
.filter_credentials("example.com", false)
.unwrap();
assert_eq!(filtered_credentials.len(), 2);
let retrieved_credential0 = persistent_store
.get_credential(filtered_credentials[0].0)
.unwrap();
let retrieved_credential1 = persistent_store
.get_credential(filtered_credentials[1].0)
.unwrap();
assert!(
(retrieved_credential0.credential_id == id0
&& retrieved_credential1.credential_id == id1)
|| (retrieved_credential1.credential_id == id0
&& retrieved_credential0.credential_id == id1)
);
}
#[test]
fn test_filter_with_cred_protect() {
let mut rng = ThreadRng256 {};
let mut persistent_store = PersistentStore::new(&mut rng);
assert_eq!(persistent_store.count_credentials().unwrap(), 0);
let private_key = crypto::ecdsa::SecKey::gensk(&mut rng);
let credential = PublicKeyCredentialSource {
key_type: PublicKeyCredentialType::PublicKey,
credential_id: rng.gen_uniform_u8x32().to_vec(),
private_key,
rp_id: String::from("example.com"),
user_handle: vec![0x00],
user_display_name: None,
cred_protect_policy: Some(
CredentialProtectionPolicy::UserVerificationOptionalWithCredentialIdList,
),
creation_order: 0,
user_name: None,
user_icon: None,
};
assert!(persistent_store.store_credential(credential).is_ok());
let no_credential = persistent_store
.filter_credentials("example.com", true)
.unwrap();
assert_eq!(no_credential, vec![]);
} }
#[test] #[test]