#include #include #include #include "se050_wireguard_proto.h" static void print_hex(const char *label, const uint8_t *buf, size_t len) { printf("%s: ", label); for (size_t i = 0; i < len; i++) printf("%02x", buf[i]); printf("\n"); } int main(void) { uint8_t ck[32], tk[32], tk1_test[32], tk2_test[32], tk3_test[32]; uint8_t ikm[32]; uint8_t data[32]; int passed = 0; printf("WireGuard KDF Chain Test Suite\n"); printf("================================\n\n"); printf("Test 1: KDF Init\n"); wg_kdf_init(ck, tk); print_hex("Chain Key", ck, 32); print_hex("Temp Key", tk, 32); printf("[INFO] Initialized\n\n"); passed++; printf("Test 2: KDF1 (First Derivation)\n"); for (int i = 0; i < 32; i++) ikm[i] = i; wg_kdf1(ck, tk1_test, ikm, 32); print_hex("CK after KDF1", ck, 32); print_hex("TK1", tk1_test, 32); printf("[INFO] KDF1 done\n\n"); passed++; printf("Test 3: KDF2 (Second Derivation)\n"); uint8_t ck_old[32]; memcpy(ck_old, ck, 32); wg_kdf2(ck, tk2_test, ck_old, tk1_test); print_hex("CK after KDF2", ck, 32); print_hex("TK2", tk2_test, 32); printf("[INFO] KDF2 done\n\n"); passed++; printf("Test 4: KDF3 (With Data)\n"); memcpy(ck_old, ck, 32); for (int i = 0; i < 32; i++) data[i] = 0xff - i; wg_kdf3(ck, tk3_test, ck_old, tk2_test, data, 32); print_hex("CK after KDF3", ck, 32); print_hex("Data", data, 32); printf("[INFO] KDF3 done\n\n"); passed++; printf("Test 5: Full Handshake Chain Simulation\n"); uint8_t ck0[32], ck1[32], ck2[32], ck3[32]; uint8_t tk0[32], tk1[32], tk2[32], tk3_final[32]; wg_kdf_init(ck0, tk0); printf("Initial: CK0 = "); print_hex("", ck0, 32); /* Step 1: IKM -> CK1, TK1 */ uint8_t ikm1[32] = {0}; for (int i = 0; i < 32; i++) ikm1[i] = i; wg_kdf1(ck1, tk1, ikm1, 32); printf("After KDF1: CK1 = "); print_hex("", ck1, 32); printf(" TK1 = "); print_hex("", tk1, 32); /* Step 2: CK1, TK1 -> CK2, TK2 */ wg_kdf2(ck2, tk2, ck1, tk1); printf("After KDF2: CK2 = "); print_hex("", ck2, 32); printf(" TK2 = "); print_hex("", tk2, 32); /* Step 3: CK2, TK2, data -> CK3 */ uint8_t handshake_data[32] = {0}; for (int i = 0; i < 32; i++) handshake_data[i] = 0xaa; wg_kdf3(ck3, tk3_final, ck2, tk2, handshake_data, 32); printf("After KDF3: CK3 = "); print_hex("", ck3, 32); printf("[INFO] Full chain complete\n\n"); passed++; printf("================================\n"); printf("Passed: %d/5\n", passed); printf("================================\n"); return (passed == 5) ? 0 : 1; }