diff --git a/tools/vendor_hid_test.py b/tools/vendor_hid_test.py new file mode 100644 index 0000000..616a4e6 --- /dev/null +++ b/tools/vendor_hid_test.py @@ -0,0 +1,269 @@ +"""These tests verify the functionality of the VendorHID interface.""" +from typing import Dict, Iterable +import hid +import time +import unittest + +_OPENSK_VID = 0x1915 +_OPENSK_PID = 0x521F +_FIDO_USAGE_PAGE = 0xF1D0 +_VENDOR_USAGE_PAGE = 0xFF00 +_PACKETS = 4 +_PACKET_SIZE = 64 +_SEND_DATA_SIZE = _PACKET_SIZE + 1 +_BROADCAST_CID = bytes([0xFF, 0xFF, 0xFF, 0xFF]) + + +def sleep(): + # TODO(liamjm): remove this sleep once it is not necessary. + time.sleep(.01) + pass + + +def ping_data_size(packets): + return 57 + 59 * (packets - 1) + + +def get_byte(): + """Return a unique byte value per request.""" + get_byte.byte_val += 1 + return get_byte.byte_val + + +get_byte.byte_val = 0x10 + + +class HidDevice(object): + """Class that helps interact with HID devices.""" + + def __init__(self, device): + self.device = device + self.dev = None + self.cid = None + self.rx_packets = [] + self.create_and_init() + + def __del__(self): + if self.dev: + self.dev.close() + + def reset(self) -> None: + self.rx_packets = [] + + def create_and_init(self) -> None: + self.dev = hid.Device(path=self.device['path']) + # Nonce is all zeros, because we don't care. + init_packet = [0] + list(_BROADCAST_CID) + [0x86, 0x00, 0x08] + [0x00] * 57 + assert len(init_packet) == _SEND_DATA_SIZE, ( + f'Expected packet to be {_SEND_DATA_SIZE} but was {len(init_packet)}') + self.dev.write(bytes(init_packet)) + self.cid = self.dev.read(_PACKET_SIZE, 2000)[15:19] + sleep() + + def ping_init(self, packets=1, byte=0x88) -> int: + size = ping_data_size(packets) + ping_packet = [0] + list(self.cid) + [0x81, size // 256, size % 256 + ] + [byte] * 57 + assert len(ping_packet) == _SEND_DATA_SIZE, ( + f'Expected packet to be {_SEND_DATA_SIZE} but was {len(ping_packet)}') + + r = self.dev.write(bytes(ping_packet)) + sleep() + return r + + def ping_continue(self, num, byte=0x88) -> int: + continue_packet = [0] + list(self.cid) + [num] + [byte] * 59 + assert len(continue_packet) == _SEND_DATA_SIZE, ( + f'Expected packet to be {_SEND_DATA_SIZE} ' + 'but was {len(continue_packet)}') + r = self.dev.write(bytes(continue_packet)) + sleep() + return r + + def read_and_print(self) -> int: + d = self.dev.read(_PACKET_SIZE, 2000) + self.rx_packets.append(d) + sleep() + return len(d) + + def get_received_data(self) -> bytes: + """This combines the data from the received packets, to match the ping +packets sent.""" + d = b'' + d += self.rx_packets.pop(0)[7:] + for p in self.rx_packets: + d += p[5:] + # And clear the packets received to ensure consistency. + self.rx_packets = [] + return d + + +def get_devices(usage_page) -> Iterable[Dict]: + for device in hid.enumerate(_OPENSK_VID, _OPENSK_PID): + if device['usage_page'] == usage_page: + yield device + + +class HidInterfaces(unittest.TestCase): + """Tests for the Vendor and FIDO HID interfaces.""" + + @classmethod + def setUpClass(cls): + cls.fido_hid = cls.get_device(_FIDO_USAGE_PAGE) + cls.vendor_hid = cls.get_device(_VENDOR_USAGE_PAGE) + + @classmethod + def get_device(cls, usage_page) -> HidDevice: + devices = list(get_devices(usage_page)) + if len(devices) != 1: + raise Exception(f'Found {len(devices)} devices') + return HidDevice(devices[0]) + + def setUp(self) -> None: + super().setUp() + # Ensure the rx_packets are empty + self.fido_hid.reset() + self.vendor_hid.reset() + + def assertReceivedDataMatches(self, device: HidDevice, byte): + expected = bytes([byte] * ping_data_size(_PACKETS)) + self.assertEqual(len(device.rx_packets), _PACKETS) + self.assertEqual(device.get_received_data(), expected) + + def test_00_init(self): + self.assertNotEqual(self.vendor_hid, None) + self.assertNotEqual(self.vendor_hid.device, None) + self.assertNotEqual(self.vendor_hid.dev, None) + self.assertNotEqual(self.vendor_hid.cid, None) + + self.assertNotEqual(self.fido_hid, None) + self.assertNotEqual(self.fido_hid.device, None) + self.assertNotEqual(self.fido_hid.dev, None) + self.assertNotEqual(self.fido_hid.cid, None) + + def test_01_cid(self): + self.assertNotEqual(self.vendor_hid.cid, _BROADCAST_CID) + + def _test_ping(self, dev: HidDevice, byte): + r = dev.ping_init(_PACKETS, byte) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS - 1): + r = dev.ping_continue(i, byte) + self.assertEqual(r, _SEND_DATA_SIZE) + for _ in range(_PACKETS): + r = dev.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + self.assertReceivedDataMatches(dev, byte) + + def test_02_fido_ping(self): + self._test_ping(self.fido_hid, byte=0x11) + + def test_03_vendor_ping(self): + self._test_ping(self.vendor_hid, byte=0x22) + + def _test_send_and_receive(self, a: HidDevice): + byte_a = get_byte() + r = a.ping_init(packets=_PACKETS, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS - 1): + r = a.ping_continue(i, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS): + r = a.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + self.assertReceivedDataMatches(a, byte_a) + + def test_04_send_and_receive_fido(self): + self._test_send_and_receive(self.fido_hid) + + def test_05_send_and_receive_vendor(self): + self._test_send_and_receive(self.vendor_hid) + + def _test_interleaved_send_and_receive(self, a: HidDevice, b: HidDevice): + byte_a = get_byte() + byte_b = get_byte() + r = a.ping_init(packets=_PACKETS, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + r = b.ping_init(packets=_PACKETS, byte=byte_b) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS - 1): + r = a.ping_continue(i, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + r = b.ping_continue(i, byte=byte_b) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS): + r = a.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + r = b.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + self.assertReceivedDataMatches(a, byte_a) + self.assertReceivedDataMatches(b, byte_b) + + def test_06_interleaved_send_and_receive_fido_first(self): + self._test_interleaved_send_and_receive(self.fido_hid, self.vendor_hid) + + def test_07_interleaved_send_and_receive_vendor_first(self): + self._test_interleaved_send_and_receive(self.vendor_hid, self.fido_hid) + + def _test_interleaved_send_and_batch_receive(self, a: HidDevice, + b: HidDevice): + byte_a = get_byte() + byte_b = get_byte() + r = a.ping_init(packets=_PACKETS, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + r = b.ping_init(packets=_PACKETS, byte=byte_b) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS - 1): + r = a.ping_continue(i, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + r = b.ping_continue(i, byte=byte_b) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS): + r = a.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + for i in range(_PACKETS): + r = b.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + self.assertReceivedDataMatches(a, byte_a) + self.assertReceivedDataMatches(b, byte_b) + + def test_08_interleaved_send_and_batch_receive_fido_first(self): + self._test_interleaved_send_and_batch_receive(self.fido_hid, + self.vendor_hid) + + def test_09_interleaved_send_and_batch_receive_vendor_first(self): + self._test_interleaved_send_and_batch_receive(self.vendor_hid, + self.fido_hid) + + def _test_batch_send_and_interleaved_receive(self, a: HidDevice, + b: HidDevice): + byte_a = get_byte() + byte_b = get_byte() + r = a.ping_init(packets=_PACKETS, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS - 1): + r = a.ping_continue(i, byte=byte_a) + self.assertEqual(r, _SEND_DATA_SIZE) + r = b.ping_init(packets=_PACKETS, byte=byte_b) + for i in range(_PACKETS - 1): + r = b.ping_continue(i, byte=byte_b) + self.assertEqual(r, _SEND_DATA_SIZE) + for i in range(_PACKETS): + r = a.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + r = b.read_and_print() + self.assertEqual(r, _PACKET_SIZE) + self.assertReceivedDataMatches(a, byte_a) + self.assertReceivedDataMatches(b, byte_b) + + def test_10_batch_send_and_interleaved_receive_fido_first(self): + self._test_batch_send_and_interleaved_receive(self.fido_hid, + self.vendor_hid) + + def test_11_batch_send_and_interleaved_receive_vendor_first(self): + self._test_batch_send_and_interleaved_receive(self.vendor_hid, + self.fido_hid) + + +if __name__ == '__main__': + unittest.main()