diff --git a/src/ctap/mod.rs b/src/ctap/mod.rs index 4e93210..5072a47 100644 --- a/src/ctap/mod.rs +++ b/src/ctap/mod.rs @@ -142,7 +142,7 @@ struct AssertionInput { struct AssertionState { assertion_input: AssertionInput, // Sorted by ascending order of creation, so the last element is the most recent one. - next_credentials: Vec, + next_credential_keys: Vec, } enum StatefulCommand { @@ -606,7 +606,7 @@ where // and returns the correct Get(Next)Assertion response. fn assertion_response( &mut self, - credential: PublicKeyCredentialSource, + mut credential: PublicKeyCredentialSource, assertion_input: AssertionInput, number_of_credentials: Option, ) -> Result { @@ -642,6 +642,12 @@ where key_id: credential.credential_id, transports: None, // You can set USB as a hint here. }; + // Remove user identifiable information without uv. + if !has_uv { + credential.user_name = None; + credential.user_display_name = None; + credential.user_icon = None; + } let user = if !credential.user_handle.is_empty() { Some(PublicKeyCredentialUserEntity { user_id: credential.user_handle, @@ -749,26 +755,23 @@ where } let rp_id_hash = Sha256::hash(rp_id.as_bytes()); - let mut applicable_credentials = if let Some(allow_list) = allow_list { - if let Some(credential) = - self.get_any_credential_from_allow_list(allow_list, &rp_id, &rp_id_hash, has_uv)? - { - vec![credential] - } else { - vec![] - } + let (credential, next_credential_keys) = if let Some(allow_list) = allow_list { + ( + self.get_any_credential_from_allow_list(allow_list, &rp_id, &rp_id_hash, has_uv)?, + vec![], + ) } else { - self.persistent_store.filter_credential(&rp_id, !has_uv)? + let mut stored_credentials = + self.persistent_store.filter_credentials(&rp_id, !has_uv)?; + stored_credentials.sort_unstable_by_key(|c| c.1); + let mut stored_credentials: Vec = + stored_credentials.into_iter().map(|c| c.0).collect(); + let credential = stored_credentials + .pop() + .map(|key| self.persistent_store.get_credential(key)) + .transpose()?; + (credential, stored_credentials) }; - // Remove user identifiable information without uv. - if !has_uv { - for credential in &mut applicable_credentials { - credential.user_name = None; - credential.user_display_name = None; - credential.user_icon = None; - } - } - applicable_credentials.sort_unstable_by_key(|c| c.creation_order); // This check comes before CTAP2_ERR_NO_CREDENTIALS in CTAP 2.0. // For CTAP 2.1, it was moved to a later protocol step. @@ -776,9 +779,7 @@ where (self.check_user_presence)(cid)?; } - let credential = applicable_credentials - .pop() - .ok_or(Ctap2StatusCode::CTAP2_ERR_NO_CREDENTIALS)?; + let credential = credential.ok_or(Ctap2StatusCode::CTAP2_ERR_NO_CREDENTIALS)?; self.increment_global_signature_counter()?; @@ -788,15 +789,15 @@ where hmac_secret_input, has_uv, }; - let number_of_credentials = if applicable_credentials.is_empty() { + let number_of_credentials = if next_credential_keys.is_empty() { None } else { - let number_of_credentials = Some(applicable_credentials.len() + 1); + let number_of_credentials = Some(next_credential_keys.len() + 1); self.stateful_command_permission = TimedPermission::granted(now, STATEFUL_COMMAND_TIMEOUT_DURATION); self.stateful_command_type = Some(StatefulCommand::GetAssertion(AssertionState { assertion_input: assertion_input.clone(), - next_credentials: applicable_credentials, + next_credential_keys, })); number_of_credentials }; @@ -812,10 +813,11 @@ where if let Some(StatefulCommand::GetAssertion(assertion_state)) = &mut self.stateful_command_type { - let credential = assertion_state - .next_credentials + let credential_key = assertion_state + .next_credential_keys .pop() .ok_or(Ctap2StatusCode::CTAP2_ERR_NOT_ALLOWED)?; + let credential = self.persistent_store.get_credential(credential_key)?; (assertion_state.assertion_input.clone(), credential) } else { return Err(Ctap2StatusCode::CTAP2_ERR_NOT_ALLOWED); @@ -1250,11 +1252,16 @@ mod test { ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID); assert!(make_credential_response.is_ok()); - let stored_credential = ctap_state + let credential_key = ctap_state .persistent_store - .filter_credential("example.com", false) + .filter_credentials("example.com", false) .unwrap() .pop() + .unwrap() + .0; + let stored_credential = ctap_state + .persistent_store + .get_credential(credential_key) .unwrap(); let credential_id = stored_credential.credential_id; assert_eq!(stored_credential.cred_protect_policy, Some(test_policy)); @@ -1275,11 +1282,16 @@ mod test { ctap_state.process_make_credential(make_credential_params, DUMMY_CHANNEL_ID); assert!(make_credential_response.is_ok()); - let stored_credential = ctap_state + let credential_key = ctap_state .persistent_store - .filter_credential("example.com", false) + .filter_credentials("example.com", false) .unwrap() .pop() + .unwrap() + .0; + let stored_credential = ctap_state + .persistent_store + .get_credential(credential_key) .unwrap(); let credential_id = stored_credential.credential_id; assert_eq!(stored_credential.cred_protect_policy, Some(test_policy)); diff --git a/src/ctap/storage.rs b/src/ctap/storage.rs index 76c2fd6..343751a 100644 --- a/src/ctap/storage.rs +++ b/src/ctap/storage.rs @@ -117,6 +117,24 @@ impl PersistentStore { Ok(()) } + /// Returns the credential at the given key. + /// + /// # Errors + /// + /// Returns `CTAP2_ERR_VENDOR_INTERNAL_ERROR` if the key does not hold a valid credential. + pub fn get_credential(&self, key: usize) -> Result { + let min_key = key::CREDENTIALS.start; + if key < min_key || key >= min_key + MAX_SUPPORTED_RESIDENTIAL_KEYS { + return Err(Ctap2StatusCode::CTAP2_ERR_VENDOR_INTERNAL_ERROR); + } + let credential_entry = self + .store + .find(key)? + .ok_or(Ctap2StatusCode::CTAP2_ERR_VENDOR_INTERNAL_ERROR)?; + deserialize_credential(&credential_entry) + .ok_or(Ctap2StatusCode::CTAP2_ERR_VENDOR_INTERNAL_ERROR) + } + /// Finds the key and value for a given credential ID. /// /// # Errors @@ -246,22 +264,23 @@ impl PersistentStore { /// Returns the list of matching credentials. /// /// Does not return credentials that are not discoverable if `check_cred_protect` is set. - pub fn filter_credential( + pub fn filter_credentials( &self, rp_id: &str, check_cred_protect: bool, - ) -> Result, Ctap2StatusCode> { + ) -> Result, Ctap2StatusCode> { let mut iter_result = Ok(()); let iter = self.iter_credentials(&mut iter_result)?; let result = iter - .filter_map(|(_, credential)| { - if credential.rp_id == rp_id { - Some(credential) + .filter_map(|(key, credential)| { + if credential.rp_id == rp_id + && (!check_cred_protect || credential.is_discoverable()) + { + Some((key, credential.creation_order)) } else { None } }) - .filter(|cred| !check_cred_protect || cred.is_discoverable()) .collect(); iter_result?; Ok(result) @@ -801,12 +820,13 @@ mod test { .store_credential(credential_source1) .is_ok()); assert_eq!(persistent_store.count_credentials().unwrap(), 1); - assert_eq!( - &persistent_store - .filter_credential("example.com", false) - .unwrap(), - &[expected_credential] - ); + let filtered_credentials = persistent_store + .filter_credentials("example.com", false) + .unwrap(); + let retrieved_credential_source = persistent_store + .get_credential(filtered_credentials[0].0) + .unwrap(); + assert_eq!(retrieved_credential_source, expected_credential); let mut persistent_store = PersistentStore::new(&mut rng); for i in 0..MAX_SUPPORTED_RESIDENTIAL_KEYS { @@ -831,7 +851,7 @@ mod test { } #[test] - fn test_filter() { + fn test_filter_get_credentials() { let mut rng = ThreadRng256 {}; let mut persistent_store = PersistentStore::new(&mut rng); assert_eq!(persistent_store.count_credentials().unwrap(), 0); @@ -852,14 +872,20 @@ mod test { .is_ok()); let filtered_credentials = persistent_store - .filter_credential("example.com", false) + .filter_credentials("example.com", false) .unwrap(); assert_eq!(filtered_credentials.len(), 2); + let retrieved_credential0 = persistent_store + .get_credential(filtered_credentials[0].0) + .unwrap(); + let retrieved_credential1 = persistent_store + .get_credential(filtered_credentials[1].0) + .unwrap(); assert!( - (filtered_credentials[0].credential_id == id0 - && filtered_credentials[1].credential_id == id1) - || (filtered_credentials[1].credential_id == id0 - && filtered_credentials[0].credential_id == id1) + (retrieved_credential0.credential_id == id0 + && retrieved_credential1.credential_id == id1) + || (retrieved_credential1.credential_id == id0 + && retrieved_credential0.credential_id == id1) ); } @@ -886,7 +912,7 @@ mod test { assert!(persistent_store.store_credential(credential).is_ok()); let no_credential = persistent_store - .filter_credential("example.com", true) + .filter_credentials("example.com", true) .unwrap(); assert_eq!(no_credential, vec![]); }