From 27a7108328fcea9bba7bfe1b1411167e9a7551ef Mon Sep 17 00:00:00 2001 From: Fabian Kaczmarczyck Date: Tue, 12 Jan 2021 07:01:25 +0100 Subject: [PATCH] moves filter_credentials to call side --- src/ctap/mod.rs | 52 +++++++++++--------- src/ctap/storage.rs | 112 ++++++++------------------------------------ 2 files changed, 49 insertions(+), 115 deletions(-) diff --git a/src/ctap/mod.rs b/src/ctap/mod.rs index 5072a47..2047500 100644 --- a/src/ctap/mod.rs +++ b/src/ctap/mod.rs @@ -761,11 +761,23 @@ where vec![], ) } else { - let mut stored_credentials = - self.persistent_store.filter_credentials(&rp_id, !has_uv)?; - stored_credentials.sort_unstable_by_key(|c| c.1); - let mut stored_credentials: Vec = - stored_credentials.into_iter().map(|c| c.0).collect(); + let mut iter_result = Ok(()); + let iter = self.persistent_store.iter_credentials(&mut iter_result)?; + let mut stored_credentials: Vec<(usize, u64)> = iter + .filter_map(|(key, credential)| { + if credential.rp_id == rp_id && (has_uv || credential.is_discoverable()) { + Some((key, credential.creation_order)) + } else { + None + } + }) + .collect(); + iter_result?; + stored_credentials.sort_unstable_by_key(|&(_key, order)| order); + let mut stored_credentials: Vec = stored_credentials + .into_iter() + .map(|(key, _order)| key) + .collect(); let credential = stored_credentials .pop() .map(|key| self.persistent_store.get_credential(key)) @@ -1252,17 +1264,14 @@ mod test { ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID); assert!(make_credential_response.is_ok()); - let credential_key = ctap_state + let mut iter_result = Ok(()); + let iter = ctap_state .persistent_store - .filter_credentials("example.com", false) - .unwrap() - .pop() - .unwrap() - .0; - let stored_credential = ctap_state - .persistent_store - .get_credential(credential_key) + .iter_credentials(&mut iter_result) .unwrap(); + // There is only 1 credential, so last is good enough. + let (_, stored_credential) = iter.last().unwrap(); + iter_result.unwrap(); let credential_id = stored_credential.credential_id; assert_eq!(stored_credential.cred_protect_policy, Some(test_policy)); @@ -1282,17 +1291,14 @@ mod test { ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID); assert!(make_credential_response.is_ok()); - let credential_key = ctap_state + let mut iter_result = Ok(()); + let iter = ctap_state .persistent_store - .filter_credentials("example.com", false) - .unwrap() - .pop() - .unwrap() - .0; - let stored_credential = ctap_state - .persistent_store - .get_credential(credential_key) + .iter_credentials(&mut iter_result) .unwrap(); + // There is only 1 credential, so last is good enough. + let (_, stored_credential) = iter.last().unwrap(); + iter_result.unwrap(); let credential_id = stored_credential.credential_id; assert_eq!(stored_credential.cred_protect_policy, Some(test_policy)); diff --git a/src/ctap/storage.rs b/src/ctap/storage.rs index 343751a..8df987d 100644 --- a/src/ctap/storage.rs +++ b/src/ctap/storage.rs @@ -171,7 +171,7 @@ impl PersistentStore { let credential = match self.find_credential_item(credential_id) { Err(Ctap2StatusCode::CTAP2_ERR_NO_CREDENTIALS) => return Ok(None), Err(e) => return Err(e), - Ok(credential) => credential.1, + Ok((_key, credential)) => credential, }; let is_protected = credential.cred_protect_policy == Some(CredentialProtectionPolicy::UserVerificationRequired); @@ -261,31 +261,6 @@ impl PersistentStore { Ok(self.store.insert(key, &value)?) } - /// Returns the list of matching credentials. - /// - /// Does not return credentials that are not discoverable if `check_cred_protect` is set. - pub fn filter_credentials( - &self, - rp_id: &str, - check_cred_protect: bool, - ) -> Result, Ctap2StatusCode> { - let mut iter_result = Ok(()); - let iter = self.iter_credentials(&mut iter_result)?; - let result = iter - .filter_map(|(key, credential)| { - if credential.rp_id == rp_id - && (!check_cred_protect || credential.is_discoverable()) - { - Some((key, credential.creation_order)) - } else { - None - } - }) - .collect(); - iter_result?; - Ok(result) - } - /// Returns the number of credentials. pub fn count_credentials(&self) -> Result { let mut iter_result = Ok(()); @@ -811,7 +786,8 @@ mod test { // These should have different IDs. let credential_source0 = create_credential_source(&mut rng, "example.com", vec![0x00]); let credential_source1 = create_credential_source(&mut rng, "example.com", vec![0x00]); - let expected_credential = credential_source1.clone(); + let credential_id0 = credential_source0.credential_id.clone(); + let credential_id1 = credential_source1.credential_id.clone(); assert!(persistent_store .store_credential(credential_source0) @@ -820,13 +796,14 @@ mod test { .store_credential(credential_source1) .is_ok()); assert_eq!(persistent_store.count_credentials().unwrap(), 1); - let filtered_credentials = persistent_store - .filter_credentials("example.com", false) - .unwrap(); - let retrieved_credential_source = persistent_store - .get_credential(filtered_credentials[0].0) - .unwrap(); - assert_eq!(retrieved_credential_source, expected_credential); + assert!(persistent_store + .find_credential("example.com", &credential_id0, false) + .unwrap() + .is_none()); + assert!(persistent_store + .find_credential("example.com", &credential_id1, false) + .unwrap() + .is_some()); let mut persistent_store = PersistentStore::new(&mut rng); for i in 0..MAX_SUPPORTED_RESIDENTIAL_KEYS { @@ -851,70 +828,21 @@ mod test { } #[test] - fn test_filter_get_credentials() { + fn test_get_credential() { let mut rng = ThreadRng256 {}; let mut persistent_store = PersistentStore::new(&mut rng); - assert_eq!(persistent_store.count_credentials().unwrap(), 0); let credential_source0 = create_credential_source(&mut rng, "example.com", vec![0x00]); let credential_source1 = create_credential_source(&mut rng, "example.com", vec![0x01]); let credential_source2 = create_credential_source(&mut rng, "another.example.com", vec![0x02]); - let id0 = credential_source0.credential_id.clone(); - let id1 = credential_source1.credential_id.clone(); - assert!(persistent_store - .store_credential(credential_source0) - .is_ok()); - assert!(persistent_store - .store_credential(credential_source1) - .is_ok()); - assert!(persistent_store - .store_credential(credential_source2) - .is_ok()); - - let filtered_credentials = persistent_store - .filter_credentials("example.com", false) - .unwrap(); - assert_eq!(filtered_credentials.len(), 2); - let retrieved_credential0 = persistent_store - .get_credential(filtered_credentials[0].0) - .unwrap(); - let retrieved_credential1 = persistent_store - .get_credential(filtered_credentials[1].0) - .unwrap(); - assert!( - (retrieved_credential0.credential_id == id0 - && retrieved_credential1.credential_id == id1) - || (retrieved_credential1.credential_id == id0 - && retrieved_credential0.credential_id == id1) - ); - } - - #[test] - fn test_filter_with_cred_protect() { - let mut rng = ThreadRng256 {}; - let mut persistent_store = PersistentStore::new(&mut rng); - assert_eq!(persistent_store.count_credentials().unwrap(), 0); - let private_key = crypto::ecdsa::SecKey::gensk(&mut rng); - let credential = PublicKeyCredentialSource { - key_type: PublicKeyCredentialType::PublicKey, - credential_id: rng.gen_uniform_u8x32().to_vec(), - private_key, - rp_id: String::from("example.com"), - user_handle: vec![0x00], - user_display_name: None, - cred_protect_policy: Some( - CredentialProtectionPolicy::UserVerificationOptionalWithCredentialIdList, - ), - creation_order: 0, - user_name: None, - user_icon: None, - }; - assert!(persistent_store.store_credential(credential).is_ok()); - - let no_credential = persistent_store - .filter_credentials("example.com", true) - .unwrap(); - assert_eq!(no_credential, vec![]); + let credential_sources = vec![credential_source0, credential_source1, credential_source2]; + for credential_source in credential_sources.into_iter() { + let cred_id = credential_source.credential_id.clone(); + assert!(persistent_store.store_credential(credential_source).is_ok()); + let (key, _) = persistent_store.find_credential_item(&cred_id).unwrap(); + let cred = persistent_store.get_credential(key).unwrap(); + assert_eq!(&cred_id, &cred.credential_id); + } } #[test]