diff --git a/src/ctap/hid/mod.rs b/src/ctap/hid/mod.rs index 7489f6e..3a96699 100644 --- a/src/ctap/hid/mod.rs +++ b/src/ctap/hid/mod.rs @@ -227,44 +227,35 @@ impl CtapHid { } // CTAP specification (version 20190130) section 8.1.9.1.3 CtapHid::COMMAND_INIT => { - if cid == CtapHid::CHANNEL_BROADCAST { - if message.payload.len() != 8 { - return CtapHid::error_message(cid, CtapHid::ERR_INVALID_LEN); - } + if message.payload.len() != 8 { + return CtapHid::error_message(cid, CtapHid::ERR_INVALID_LEN); + } + let new_cid = if cid == CtapHid::CHANNEL_BROADCAST { // TODO: Prevent allocating 2^32 channels. self.allocated_cids += 1; - let allocated_cid = (self.allocated_cids as u32).to_ne_bytes(); - - let mut payload = vec![0; 17]; - payload[..8].copy_from_slice(&message.payload); - payload[8..12].copy_from_slice(&allocated_cid); - payload[12] = CtapHid::PROTOCOL_VERSION; - payload[13] = CtapHid::DEVICE_VERSION_MAJOR; - payload[14] = CtapHid::DEVICE_VERSION_MINOR; - payload[15] = CtapHid::DEVICE_VERSION_BUILD; - payload[16] = CtapHid::CAPABILITIES; - - // This unwrap is safe because the payload length is 17 <= 7609 bytes. - CtapHid::split_message(Message { - cid, - cmd: CtapHid::COMMAND_INIT, - payload, - }) - .unwrap() + (self.allocated_cids as u32).to_ne_bytes() } else { // Sync the channel and discard the current transaction. - // TODO: The specification (version 20190130) wording isn't clear about - // the payload format in this case. - // - // This unwrap is safe because the payload length is 0 <= 7609 bytes. - CtapHid::split_message(Message { - cid, - cmd: CtapHid::COMMAND_INIT, - payload: vec![], - }) - .unwrap() - } + cid + }; + + let mut payload = vec![0; 17]; + payload[..8].copy_from_slice(&message.payload); + payload[8..12].copy_from_slice(&new_cid); + payload[12] = CtapHid::PROTOCOL_VERSION; + payload[13] = CtapHid::DEVICE_VERSION_MAJOR; + payload[14] = CtapHid::DEVICE_VERSION_MINOR; + payload[15] = CtapHid::DEVICE_VERSION_BUILD; + payload[16] = CtapHid::CAPABILITIES; + + // This unwrap is safe because the payload length is 17 <= 7609 bytes. + CtapHid::split_message(Message { + cid, + cmd: CtapHid::COMMAND_INIT, + payload, + }) + .unwrap() } // CTAP specification (version 20190130) section 8.1.9.1.4 CtapHid::COMMAND_PING => { @@ -568,6 +559,66 @@ mod test { ); } + #[test] + fn test_command_init_for_sync() { + let mut rng = ThreadRng256 {}; + let user_immediately_present = |_| Ok(()); + let mut ctap_state = CtapState::new(&mut rng, user_immediately_present); + let mut ctap_hid = CtapHid::new(); + let cid = cid_from_init(&mut ctap_hid, &mut ctap_state); + + // Ping packet with a length longer than one packet. + let mut packet1 = [0x51; 64]; + packet1[..4].copy_from_slice(&cid); + packet1[4..7].copy_from_slice(&[0x81, 0x02, 0x00]); + // Init packet on the same channel. + let mut packet2 = [0x00; 64]; + packet2[..4].copy_from_slice(&cid); + packet2[4..15].copy_from_slice(&[ + 0x86, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0, + ]); + let mut result = Vec::new(); + let mut assembler_reply = MessageAssembler::new(); + for pkt_request in &[packet1, packet2] { + for pkt_reply in + ctap_hid.process_hid_packet(&pkt_request, DUMMY_CLOCK_VALUE, &mut ctap_state) + { + if let Some(message) = assembler_reply + .parse_packet(&pkt_reply, DUMMY_TIMESTAMP) + .unwrap() + { + result.push(message); + } + } + } + assert_eq!( + result, + vec![Message { + cid, + cmd: CtapHid::COMMAND_INIT, + payload: vec![ + 0x12, // Nonce + 0x34, + 0x56, + 0x78, + 0x9A, + 0xBC, + 0xDE, + 0xF0, + cid[0], // Allocated CID + cid[1], + cid[2], + cid[3], + 0x02, // Protocol version + 0x00, // Device version + 0x00, + 0x00, + CtapHid::CAPABILITIES + ] + }] + ); + } + #[test] fn test_command_ping() { let mut rng = ThreadRng256 {}; diff --git a/src/ctap/hid/receive.rs b/src/ctap/hid/receive.rs index fef51a4..b522837 100644 --- a/src/ctap/hid/receive.rs +++ b/src/ctap/hid/receive.rs @@ -586,5 +586,33 @@ mod test { ); } + #[test] + fn test_init_sync() { + let mut assembler = MessageAssembler::new(); + // Ping packet with a length longer than one packet. + assert_eq!( + assembler.parse_packet( + &byte_extend(&[0x12, 0x34, 0x56, 0x78, 0x81, 0x02, 0x00], 0x51), + DUMMY_TIMESTAMP + ), + Ok(None) + ); + // Init packet on the same channel. + assert_eq!( + assembler.parse_packet( + &zero_extend(&[ + 0x12, 0x34, 0x56, 0x78, 0x86, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, + 0xDE, 0xF0 + ]), + DUMMY_TIMESTAMP + ), + Ok(Some(Message { + cid: [0x12, 0x34, 0x56, 0x78], + cmd: 0x06, + payload: vec![0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0] + })) + ); + } + // TODO: more tests }