diff --git a/src/se050_wireguard.c b/src/se050_wireguard.c index b5ec111..42cabb3 100644 --- a/src/se050_wireguard.c +++ b/src/se050_wireguard.c @@ -61,54 +61,46 @@ static bool constant_time_eq(const uint8_t *a, const uint8_t *b, size_t len) return result == 0; } -/* HKDF expand for WireGuard - simplified for 64 bytes output */ -static void wg_hkdf_expand(const uint8_t *prk, size_t prk_len, - uint8_t *out1, uint8_t *out2) +/* HKDF for WireGuard - always uses 32-byte PRK + * T(1) = HMAC(PRK, 0x01) + * T(2) = HMAC(PRK, T(1) || 0x02) + * T(3) = HMAC(PRK, T(2) || 0x03) + */ +static void wg_hkdf_2(const uint8_t *prk, + uint8_t *out1, uint8_t *out2) { - /* WireGuard uses a simplified HKDF: - * T(1) = HMAC(PRK, 0x01) - * T(2) = HMAC(PRK, T(1) || 0x02) - */ - /* T(1) = HMAC(PRK, 0x01) */ uint8_t c1 = 0x01; - se050_hmac_blake2s(out1, prk, prk_len, &c1, 1); + se050_hmac_blake2s(out1, prk, 32, &c1, 1); /* T(2) = HMAC(PRK, T(1) || 0x02) */ uint8_t t2_input[33]; memcpy(t2_input, out1, 32); t2_input[32] = 0x02; - se050_hmac_blake2s(out2, prk, prk_len, t2_input, 33); + se050_hmac_blake2s(out2, prk, 32, t2_input, 33); memzero_explicit(t2_input, 33); } -/* Compute HKDF-1 (two outputs) - WireGuard style */ -static void wg_hkdf_1(const uint8_t *prk, size_t prk_len, - uint8_t *out1, uint8_t *out2) -{ - wg_hkdf_expand(prk, prk_len, out1, out2); -} - -/* Compute HKDF-3 (three outputs) - WireGuard style */ -static void wg_hkdf_3(const uint8_t *prk, size_t prk_len, +/* HKDF-3 (three outputs) - WireGuard style */ +static void wg_hkdf_3(const uint8_t *prk, uint8_t *out1, uint8_t *out2, uint8_t *out3) { /* T(1) = HMAC(PRK, 0x01) */ uint8_t c1 = 0x01; - se050_hmac_blake2s(out1, prk, prk_len, &c1, 1); + se050_hmac_blake2s(out1, prk, 32, &c1, 1); /* T(2) = HMAC(PRK, T(1) || 0x02) */ uint8_t t2_input[33]; memcpy(t2_input, out1, 32); t2_input[32] = 0x02; - se050_hmac_blake2s(out2, prk, prk_len, t2_input, 33); + se050_hmac_blake2s(out2, prk, 32, t2_input, 33); /* T(3) = HMAC(PRK, T(2) || 0x03) */ uint8_t t3_input[33]; memcpy(t3_input, out2, 32); t3_input[32] = 0x03; - se050_hmac_blake2s(out3, prk, prk_len, t3_input, 33); + se050_hmac_blake2s(out3, prk, 32, t3_input, 33); memzero_explicit(t2_input, 33); memzero_explicit(t3_input, 33); @@ -180,10 +172,9 @@ int se050_wireguard_derive_keys(se050_wireguard_session_t *session, } /* Derive sending and receiving keys using HKDF - * WireGuard uses: HKDF(shared_secret, "WireGuard v1 zx2c4 IPsec v1") - * But simplified to just use shared_secret as PRK + * WireGuard uses simplified HKDF with 32-byte PRK */ - wg_hkdf_1(shared_secret, 32, session->sending_key, session->receiving_key); + wg_hkdf_2(shared_secret, session->sending_key, session->receiving_key); /* Reset nonces */ session->sending_nonce = 0; @@ -314,7 +305,6 @@ int se050_wireguard_decrypt_packet(se050_wireguard_session_t *session, memcpy(nonce_buf + 4, packet + 8, 8); size_t ciphertext_len = packet_len - 16 - 16; /* Total - header - tag */ - *plaintext_len = ciphertext_len; uint8_t tag[16]; memcpy(tag, packet + 16 + ciphertext_len, 16); @@ -337,7 +327,8 @@ int se050_wireguard_decrypt_packet(se050_wireguard_session_t *session, return -1; } - /* Update plaintext length and nonce */ + /* Update plaintext length and nonce only on success */ + *plaintext_len = ciphertext_len; session->receiving_nonce = nonce; session->packets_received++;