diff --git a/Makefile b/Makefile index 2fa4c78..2a94999 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,7 @@ SRCS = src/se050_i2c_hal.c \ src/se050_rng.c \ src/se050_x25519.c \ src/se050_x25519_sw.c \ + src/se050_chacha20_poly1305.c \ src/se050_scp03.c \ src/se050_scp03_keys.c @@ -30,6 +31,7 @@ TEST_SE050 = test_scp03_se050 TEST_X25519 = test_x25519_ecdh TEST_KEY_ROTATION = test_key_rotation TEST_X25519_SW = test_x25519_sw +TEST_CHACHA20 = test_chacha20_poly1305 # Target library LIB = libse050_wireguard.a @@ -54,7 +56,7 @@ else endif # Default target -all: $(LIB) $(TEST_SCP03) $(TEST_HARDWARE) $(TEST_SE050) $(TEST_X25519) $(TEST_X25519_SW) +all: $(LIB) $(TEST_SCP03) $(TEST_HARDWARE) $(TEST_SE050) $(TEST_X25519) $(TEST_X25519_SW) $(TEST_CHACHA20) # Create build directory build: @@ -94,6 +96,11 @@ $(TEST_X25519_SW): src/se050_x25519_sw.c @mkdir -p build $(CC) $(CFLAGS) -DX25519_SW_TEST -o build/$@ $< +# ChaCha20-Poly1305 test +$(TEST_CHACHA20): src/se050_chacha20_poly1305.c + @mkdir -p build + $(CC) $(CFLAGS) -DCHACHA20_POLY1305_TEST -o build/$@ $< + # Compile source files src/%.o: src/%.c $(CC) $(CFLAGS) -c $< -o $@ @@ -116,6 +123,9 @@ test: all @echo "Running Software X25519 tests..." ./build/$(TEST_X25519_SW) @echo "" + @echo "Running ChaCha20-Poly1305 tests..." + ./build/$(TEST_CHACHA20) + @echo "" @echo "Note: To run SE050 hardware tests, use:" @echo " make SE050_CHIP=SE050C1 test_se050" diff --git a/include/se050_chacha20_poly1305.h b/include/se050_chacha20_poly1305.h new file mode 100644 index 0000000..1b2b9bd --- /dev/null +++ b/include/se050_chacha20_poly1305.h @@ -0,0 +1,206 @@ +/** + * @file se050_chacha20_poly1305.h + * @brief ChaCha20-Poly1305 AEAD Implementation + * + * Software implementation for WireGuard protocol. + * Based on RFC 7539 (ChaCha20-Poly1305, since obsoleted by RFC 8439); this is + * the AEAD construction used by the WireGuard protocol. 
+ * + * NOTE(review): the two bare #include directives below carry no header name. + * The declarations use uint8_t/uint32_t and size_t, so presumably they were + * <stdint.h> and <stddef.h> - confirm against the original file. + * + * License: MIT (Clean-room implementation) + */ + +#ifndef SE050_CHACHA20_POLY1305_H +#define SE050_CHACHA20_POLY1305_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* ============================================================================ + * Constants (all sizes are in bytes) + * ============================================================================ */ + +#define CHACHA20_KEY_SIZE 32 +#define CHACHA20_NONCE_SIZE 12 +#define CHACHA20_BLOCK_SIZE 64 + +#define POLY1305_KEY_SIZE 32 +#define POLY1305_TAG_SIZE 16 + +#define CHACHA20_POLY1305_AEAD_KEY_SIZE 32 +#define CHACHA20_POLY1305_AEAD_NONCE_SIZE 12 +#define CHACHA20_POLY1305_TAG_SIZE 16 + +/* WireGuard-specific constants */ +#define WG_KEY_SIZE 32 +#define WG_NONCE_SIZE 12 + +/* ============================================================================ + * Type Definitions + * ============================================================================ */ + +/** + * @brief ChaCha20-Poly1305 AEAD context + * + * Holds a copy of the 32-byte AEAD key; wipe it with + * se050_chacha20_poly1305_zeroize() when done. + */ +typedef struct { + uint8_t key[CHACHA20_KEY_SIZE]; +} se050_chacha20_poly1305_ctx_t; + +/* ============================================================================ + * ChaCha20 Core Functions + * ============================================================================ */ + +/** + * @brief ChaCha20 quarter round + * + * Updates the four state words in place. 
+ * + * @param a Pointer to a + * @param b Pointer to b + * @param c Pointer to c + * @param d Pointer to d + */ +void se050_chacha20_quarter_round(uint32_t *a, uint32_t *b, uint32_t *c, uint32_t *d); + +/** + * @brief ChaCha20 block function + * + * @param output Output buffer (64 bytes) + * @param key Key (32 bytes) + * @param counter Block counter (32-bit value) + * @param nonce Nonce (12 bytes) + */ +void se050_chacha20_block(uint8_t output[64], const uint8_t key[32], + uint32_t counter, const uint8_t nonce[12]); + +/** + * @brief ChaCha20 encrypt/decrypt + * + * XORs the input with the keystream; the same call encrypts and decrypts. + * + * @param output Output buffer (len bytes) + * @param input Input buffer (len bytes) + * @param len Length of input and output in bytes + * @param key Key (32 bytes) + * 
@param counter Initial block counter, incremented once per 64-byte block + * @param nonce Nonce (12 bytes) + */ +void se050_chacha20(uint8_t *output, const uint8_t *input, size_t len, + const uint8_t key[32], uint32_t counter, const uint8_t nonce[12]); + +/* ============================================================================ + * Poly1305 Core Functions + * ============================================================================ */ + +/** + * @brief Poly1305 MAC generation (one-shot) + * + * @param mac Output MAC (16 bytes) + * @param key Key (32 bytes) + * @param data Data to authenticate + * @param len Data length in bytes + */ +void se050_poly1305_mac(uint8_t mac[16], const uint8_t key[32], + const uint8_t *data, size_t len); + +/* ============================================================================ + * ChaCha20-Poly1305 AEAD + * ============================================================================ */ + +/** + * @brief Initialize ChaCha20-Poly1305 context + * + * Copies the key into the context. 
+ * + * @param ctx Context to initialize + * @param key Key (32 bytes) + * @return 0 on success, -1 if ctx or key is NULL + */ +int se050_chacha20_poly1305_init(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t key[CHACHA20_KEY_SIZE]); + +/** + * @brief ChaCha20-Poly1305 AEAD encryption + * + * @param ctx Context + * @param nonce Nonce (12 bytes) + * @param plaintext Plaintext data + * @param plaintext_len Plaintext length in bytes + * @param aad Additional authenticated data + * @param aad_len AAD length in bytes + * @param ciphertext Output ciphertext (plaintext_len bytes) + * @param tag Output authentication tag (16 bytes) + * @return 0 on success, -1 on error + */ +int se050_chacha20_poly1305_encrypt(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t nonce[CHACHA20_NONCE_SIZE], + const uint8_t *plaintext, size_t plaintext_len, + const uint8_t *aad, size_t aad_len, + uint8_t *ciphertext, uint8_t tag[POLY1305_TAG_SIZE]); + +/** + * @brief ChaCha20-Poly1305 AEAD decryption + * + * @param ctx Context + * @param nonce Nonce (12 bytes) + * @param ciphertext Ciphertext data + * @param 
ciphertext_len Ciphertext length in bytes + * @param aad Additional authenticated data + * @param aad_len AAD length in bytes + * @param tag Authentication tag (16 bytes) + * @param plaintext Output plaintext (ciphertext_len bytes) + * @return 0 on success, -1 on error (invalid argument or tag mismatch) + */ +int se050_chacha20_poly1305_decrypt(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t nonce[CHACHA20_NONCE_SIZE], + const uint8_t *ciphertext, size_t ciphertext_len, + const uint8_t *aad, size_t aad_len, + const uint8_t tag[POLY1305_TAG_SIZE], + uint8_t *plaintext); + +/** + * @brief WireGuard-specific encrypt + * + * Convenience wrapper around se050_chacha20_poly1305_encrypt() with: + * - 32-byte key + * - 12-byte nonce + * - No AAD + * + * @param key Key (32 bytes) + * @param nonce Nonce (12 bytes) + * @param plaintext Plaintext + * @param len Plaintext length in bytes + * @param ciphertext Output ciphertext (len bytes) + * @param tag Output tag (16 bytes) + * @return 0 on success, -1 on error + */ +int se050_wireguard_encrypt(const uint8_t key[WG_KEY_SIZE], + const uint8_t nonce[WG_NONCE_SIZE], + const uint8_t *plaintext, size_t len, + uint8_t *ciphertext, uint8_t tag[POLY1305_TAG_SIZE]); + +/** + * @brief WireGuard-specific decrypt + * + * Convenience wrapper around se050_chacha20_poly1305_decrypt() with no AAD. + * + * @param key Key (32 bytes) + * @param nonce Nonce (12 bytes) + * @param ciphertext Ciphertext + * @param len Ciphertext length in bytes + * @param tag Tag (16 bytes) + * @param plaintext Output plaintext (len bytes) + * @return 0 on success, -1 on error + */ +int se050_wireguard_decrypt(const uint8_t key[WG_KEY_SIZE], + const uint8_t nonce[WG_NONCE_SIZE], + const uint8_t *ciphertext, size_t len, + const uint8_t tag[POLY1305_TAG_SIZE], + uint8_t *plaintext); + +/** + * @brief Securely zeroize context + * + * Wipes the key stored in the context; a NULL ctx is ignored. 
+ * + * @param ctx Context to zeroize + */ +void se050_chacha20_poly1305_zeroize(se050_chacha20_poly1305_ctx_t *ctx); + +#ifdef __cplusplus +} +#endif + +#endif /* SE050_CHACHA20_POLY1305_H */ diff --git a/src/se050_chacha20_poly1305.c b/src/se050_chacha20_poly1305.c new file mode 100644 index 0000000..3bb6b64 --- /dev/null +++ b/src/se050_chacha20_poly1305.c @@ -0,0 +1,696 
@@ +/** + * @file se050_chacha20_poly1305.c + * @brief ChaCha20-Poly1305 AEAD Implementation + * Based on RFC 7539 and RFC 8434 (WireGuard) + * License: MIT (Clean-room implementation) + */ + +#include "se050_chacha20_poly1305.h" +#include "se050_crypto_utils.h" +#include + +/* ESP32 detection */ +#if defined(ESP_PLATFORM) || defined(__XTENSA__) || defined(__riscv) +#define SE050_CHACHA20_ESP32 1 +#else +#define SE050_CHACHA20_ESP32 0 +#endif + +/* ============================================================================ + * ChaCha20 Implementation + * ============================================================================ */ + +/** + * @brief ChaCha20 quarter round + */ +void se050_chacha20_quarter_round(uint32_t *a, uint32_t *b, uint32_t *c, uint32_t *d) +{ + *a += *b; *d ^= *a; *d <<= 16; *d |= (*d >> 16); + *c += *d; *b ^= *c; *b <<= 12; *b |= (*b >> 20); + *a += *b; *d ^= *a; *d <<= 8; *d |= (*d >> 24); + *c += *d; *b ^= *c; *b <<= 7; *b |= (*b >> 25); +} + +/** + * @brief ChaCha20 block function + */ +void se050_chacha20_block(uint8_t output[64], const uint8_t key[32], + uint32_t counter, const uint8_t nonce[12]) +{ + /* Constants "expand 32-byte k" */ + uint32_t state[16] = { + 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + + /* Load key (8 words) */ + state[4] = (uint32_t)key[0] | ((uint32_t)key[1] << 8) | + ((uint32_t)key[2] << 16) | ((uint32_t)key[3] << 24); + state[5] = (uint32_t)key[4] | ((uint32_t)key[5] << 8) | + ((uint32_t)key[6] << 16) | ((uint32_t)key[7] << 24); + state[6] = (uint32_t)key[8] | ((uint32_t)key[9] << 8) | + ((uint32_t)key[10] << 16) | ((uint32_t)key[11] << 24); + state[7] = (uint32_t)key[12] | ((uint32_t)key[13] << 8) | + ((uint32_t)key[14] << 16) | ((uint32_t)key[15] << 24); + state[8] = (uint32_t)key[16] | ((uint32_t)key[17] << 8) | + ((uint32_t)key[18] << 16) | ((uint32_t)key[19] << 24); + state[9] = (uint32_t)key[20] | ((uint32_t)key[21] << 8) | + ((uint32_t)key[22] << 16) | 
((uint32_t)key[23] << 24); + state[10] = (uint32_t)key[24] | ((uint32_t)key[25] << 8) | + ((uint32_t)key[26] << 16) | ((uint32_t)key[27] << 24); + state[11] = (uint32_t)key[28] | ((uint32_t)key[29] << 8) | + ((uint32_t)key[30] << 16) | ((uint32_t)key[31] << 24); + + /* Load counter and nonce */ + state[12] = counter; + state[13] = (uint32_t)nonce[0] | ((uint32_t)nonce[1] << 8) | + ((uint32_t)nonce[2] << 16) | ((uint32_t)nonce[3] << 24); + state[14] = (uint32_t)nonce[4] | ((uint32_t)nonce[5] << 8) | + ((uint32_t)nonce[6] << 16) | ((uint32_t)nonce[7] << 24); + state[15] = (uint32_t)nonce[8] | ((uint32_t)nonce[9] << 8) | + ((uint32_t)nonce[10] << 16) | ((uint32_t)nonce[11] << 24); + + /* Save initial state */ + uint32_t initial[16]; + memcpy(initial, state, sizeof(initial)); + + /* 20 rounds (10 double rounds) */ + for (int i = 0; i < 10; i++) { + /* Column rounds */ + se050_chacha20_quarter_round(&state[0], &state[4], &state[8], &state[12]); + se050_chacha20_quarter_round(&state[1], &state[5], &state[9], &state[13]); + se050_chacha20_quarter_round(&state[2], &state[6], &state[10], &state[14]); + se050_chacha20_quarter_round(&state[3], &state[7], &state[11], &state[15]); + /* Diagonal rounds */ + se050_chacha20_quarter_round(&state[0], &state[5], &state[10], &state[15]); + se050_chacha20_quarter_round(&state[1], &state[6], &state[11], &state[12]); + se050_chacha20_quarter_round(&state[2], &state[7], &state[8], &state[13]); + se050_chacha20_quarter_round(&state[3], &state[4], &state[9], &state[14]); + } + + /* Add initial state */ + for (int i = 0; i < 16; i++) { + state[i] += initial[i]; + } + + /* Output little-endian */ + for (int i = 0; i < 16; i++) { + output[i*4] = (uint8_t)(state[i]); + output[i*4+1] = (uint8_t)(state[i] >> 8); + output[i*4+2] = (uint8_t)(state[i] >> 16); + output[i*4+3] = (uint8_t)(state[i] >> 24); + } +} + +/** + * @brief ChaCha20 encrypt/decrypt + */ +void se050_chacha20(uint8_t *output, const uint8_t *input, size_t len, + const uint8_t 
key[32], uint32_t counter, const uint8_t nonce[12]) +{ + uint8_t block[64]; + size_t i; + + for (i = 0; i < len; i += 64) { + se050_chacha20_block(block, key, counter++, nonce); + size_t chunk = (i + 64 <= len) ? 64 : (len - i); + for (size_t j = 0; j < chunk; j++) { + output[i + j] = input[i + j] ^ block[j]; + } + } +} + +/* ============================================================================ + * Poly1305 Implementation + * ============================================================================ */ + +#if SE050_CHACHA20_ESP32 +/* ESP32 32-bit optimized Poly1305 */ + +typedef struct { + uint32_t r[4]; + uint32_t h[4]; + uint32_t s[2]; + uint8_t buf[16]; + size_t left; +} poly1305_state_t; + +static void poly1305_init(poly1305_state_t *st, const uint8_t key[32]) +{ + uint32_t r0 = (uint32_t)key[0] | ((uint32_t)key[1] << 8) | + ((uint32_t)key[2] << 16) | ((uint32_t)key[3] << 24); + uint32_t r1 = (uint32_t)key[4] | ((uint32_t)key[5] << 8) | + ((uint32_t)key[6] << 16) | ((uint32_t)key[7] << 24); + uint32_t r2 = (uint32_t)key[8] | ((uint32_t)key[9] << 8) | + ((uint32_t)key[10] << 16) | ((uint32_t)key[11] << 24); + uint32_t r3 = (uint32_t)key[12] | ((uint32_t)key[13] << 8) | + ((uint32_t)key[14] << 16) | ((uint32_t)key[15] << 24); + + uint32_t s0 = (uint32_t)key[16] | ((uint32_t)key[17] << 8) | + ((uint32_t)key[18] << 16) | ((uint32_t)key[19] << 24); + uint32_t s1 = (uint32_t)key[20] | ((uint32_t)key[21] << 8) | + ((uint32_t)key[22] << 16) | ((uint32_t)key[23] << 24); + + st->r[0] = r0 & 0x3ffffff; + st->r[1] = ((r0 >> 26) | (r1 << 8)) & 0x3ffff03; + st->r[2] = ((r1 >> 18) | (r2 << 16)) & 0x3ffc0ff; + st->r[3] = ((r2 >> 10) | (r3 << 24)) & 0x3f03fff; + st->r[4] = (r3 >> 2) & 0x00fffff; + + st->s[0] = s0; + st->s[1] = s1; + + for (int i = 0; i < 4; i++) st->h[i] = 0; + st->left = 0; +} + +static void poly1305_update(poly1305_state_t *st, const uint8_t *data, size_t len) +{ + if (st->left) { + size_t needed = 16 - st->left; + if (len < needed) { + 
memcpy(st->buf + st->left, data, len); + st->left += len; + return; + } + memcpy(st->buf + st->left, data, needed); + data += needed; + len -= needed; + + uint32_t hibit = 0x01000000; + uint32_t d0 = st->buf[0] | (st->buf[1] << 8) | (st->buf[2] << 16) | ((st->buf[3] | hibit) << 24); + uint32_t d1 = (st->buf[4] | (st->buf[5] << 8) | (st->buf[6] << 16) | (st->buf[7] << 24)) & 0x3ffff03; + uint32_t d2 = (st->buf[8] | (st->buf[9] << 8) | (st->buf[10] << 16) | (st->buf[11] << 24)) & 0x3ffc0ff; + uint32_t d3 = (st->buf[12] | (st->buf[13] << 8) | (st->buf[14] << 16) | (st->buf[15] << 24)) & 0x3f03fff; + + uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; + uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; + + uint64_t t0 = h0 + d0; + uint64_t t1 = h1 + d1; + uint64_t t2 = h2 + d2; + uint64_t t3 = h3 + d3; + + t1 += (t0 >> 26); t0 &= 0x3ffffff; + t2 += (t1 >> 22); t1 &= 0x3ffffff; + t3 += (t2 >> 22); t2 &= 0x3ffffff; + t0 += (t3 >> 22) * 5; t3 &= 0x3ffffff; + + h0 = t0; h1 = t1; h2 = t2; h3 = t3; + + uint64_t c0 = h0, c1 = h1, c2 = h2, c3 = h3; + c0 *= r0; c1 *= r0; c2 *= r0; c3 *= r0; + c0 *= r1; c1 *= r1; c2 *= r1; c3 *= r1; + c0 *= r2; c1 *= r2; c2 *= r2; c3 *= r2; + c0 *= r3; c1 *= r3; c2 *= r3; c3 *= r3; + c0 *= r4; c1 *= r4; c2 *= r4; c3 *= r4; + + st->h[0] = (uint32_t)(h0 & 0x3ffffff); + st->h[1] = (uint32_t)(h1 & 0x3ffffff); + st->h[2] = (uint32_t)(h2 & 0x3ffffff); + st->h[3] = (uint32_t)(h3 & 0x3ffffff); + st->h[4] = 0; + st->left = 0; + } + + while (len >= 16) { + uint32_t hibit = 0x01000000; + uint32_t d0 = data[0] | (data[1] << 8) | (data[2] << 16) | ((data[3] | hibit) << 24); + uint32_t d1 = (data[4] | (data[5] << 8) | (data[6] << 16) | (data[7] << 24)) & 0x3ffff03; + uint32_t d2 = (data[8] | (data[9] << 8) | (data[10] << 16) | (data[11] << 24)) & 0x3ffc0ff; + uint32_t d3 = (data[12] | (data[13] << 8) | (data[14] << 16) | (data[15] << 24)) & 0x3f03fff; + + uint64_t r0 = st->r[0], r1 = 
st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; + uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; + + h0 += d0; h1 += d1; h2 += d2; h3 += d3; + + uint64_t t0 = h0, t1 = h1, t2 = h2, t3 = h3; + t1 += (t0 >> 26); t0 &= 0x3ffffff; + t2 += (t1 >> 22); t1 &= 0x3ffffff; + t3 += (t2 >> 22); t2 &= 0x3ffffff; + t0 += (t3 >> 22) * 5; t3 &= 0x3ffffff; + + h0 = t0; h1 = t1; h2 = t2; h3 = t3; + + st->h[0] = (uint32_t)(h0 & 0x3ffffff); + st->h[1] = (uint32_t)(h1 & 0x3ffffff); + st->h[2] = (uint32_t)(h2 & 0x3ffffff); + st->h[3] = (uint32_t)(h3 & 0x3ffffff); + st->h[4] = 0; + + data += 16; + len -= 16; + } + + if (len) { + memcpy(st->buf + st->left, data, len); + st->left += len; + } +} + +static void poly1305_final(poly1305_state_t *st, uint8_t mac[16]) +{ + if (st->left) { + uint32_t hibit = 0x01000000; + uint32_t d0 = st->buf[0] | (st->buf[1] << 8) | (st->buf[2] << 16) | ((st->buf[3] | hibit) << 24); + uint32_t d1 = (st->buf[4] | (st->buf[5] << 8) | (st->buf[6] << 16) | (st->buf[7] << 24)) & 0x3ffff03; + uint32_t d2 = (st->buf[8] | (st->buf[9] << 8) | (st->buf[10] << 16) | (st->buf[11] << 24)) & 0x3ffc0ff; + uint32_t d3 = (st->buf[12] | (st->buf[13] << 8) | (st->buf[14] << 16) | (st->buf[15] << 24)) & 0x3f03fff; + + uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3]; + h0 += d0; h1 += d1; h2 += d2; h3 += d3; + + st->h[0] = (uint32_t)(h0 & 0x3ffffff); + st->h[1] = (uint32_t)(h1 & 0x3ffffff); + st->h[2] = (uint32_t)(h2 & 0x3ffffff); + st->h[3] = (uint32_t)(h3 & 0x3ffffff); + } + + uint32_t c = st->h[4] + 5; + uint32_t mask = -(c >> 26); + c &= 0x3ffffff; + + uint32_t h0 = st->h[0] + (c & ~mask); + uint32_t h1 = st->h[1] + ((c >> 26) & ~mask); + uint32_t h2 = st->h[2] + ((c >> 52) & ~mask); + uint32_t h3 = st->h[3]; + + uint32_t s0 = st->s[0]; + uint32_t s1 = st->s[1]; + + uint64_t mac0 = (uint64_t)h0 + s0; + uint64_t mac1 = (uint64_t)h1 + s1 + (mac0 >> 32); + + mac[0] = (uint8_t)mac0; + mac[1] = (uint8_t)(mac0 >> 8); + 
mac[2] = (uint8_t)(mac0 >> 16); + mac[3] = (uint8_t)(mac0 >> 24); + mac[4] = (uint8_t)mac1; + mac[5] = (uint8_t)(mac1 >> 8); + mac[6] = (uint8_t)(mac1 >> 16); + mac[7] = (uint8_t)(mac1 >> 24); + mac[8] = 0; mac[9] = 0; mac[10] = 0; mac[11] = 0; + mac[12] = 0; mac[13] = 0; mac[14] = 0; mac[15] = 0; +} + +#else +/* Standard 64-bit Poly1305 */ + +typedef struct { + uint64_t r[5]; + uint64_t h[5]; + uint64_t s[2]; + uint8_t buf[16]; + size_t left; +} poly1305_state_t; + +static void poly1305_init(poly1305_state_t *st, const uint8_t key[32]) +{ + st->r[0] = ((uint64_t)key[0] | ((uint64_t)key[1] << 8) | + ((uint64_t)key[2] << 16) | ((uint64_t)key[3] << 24)) & 0x3ffffff; + st->r[1] = ((uint64_t)key[4] | ((uint64_t)key[5] << 8) | + ((uint64_t)key[6] << 16) | ((uint64_t)key[7] << 24) | + ((uint64_t)key[8] << 32) | ((uint64_t)key[9] << 40)) & 0x3ffff03; + st->r[2] = ((uint64_t)key[10] | ((uint64_t)key[11] << 8) | + ((uint64_t)key[12] << 16) | ((uint64_t)key[13] << 24) | + ((uint64_t)key[14] << 32) | ((uint64_t)key[15] << 40)) & 0x3ffc0ff; + st->r[3] = ((uint64_t)key[16] | ((uint64_t)key[17] << 8) | + ((uint64_t)key[18] << 16) | ((uint64_t)key[19] << 24) | + ((uint64_t)key[20] << 32) | ((uint64_t)key[21] << 40)) & 0x3f03fff; + st->r[4] = ((uint64_t)key[22] | ((uint64_t)key[23] << 8) | + ((uint64_t)key[24] << 16) | ((uint64_t)key[25] << 24) | + ((uint64_t)key[26] << 32) | ((uint64_t)key[27] << 40)) & 0x00fffff; + + st->s[0] = ((uint64_t)key[28] | ((uint64_t)key[29] << 8) | + ((uint64_t)key[30] << 16) | ((uint64_t)key[31] << 24)); + st->s[1] = ((uint64_t)key[32] | ((uint64_t)key[33] << 8) | + ((uint64_t)key[34] << 16) | ((uint64_t)key[35] << 24)); + + for (int i = 0; i < 5; i++) st->h[i] = 0; + st->left = 0; +} + +static void poly1305_update(poly1305_state_t *st, const uint8_t *data, size_t len) +{ + if (st->left) { + size_t needed = 16 - st->left; + if (len < needed) { + memcpy(st->buf + st->left, data, len); + st->left += len; + return; + } + memcpy(st->buf + st->left, data, 
needed); + data += needed; + len -= needed; + + uint64_t hibit = ((uint64_t)1) << 40; + st->h[0] += (uint64_t)st->buf[0] | ((uint64_t)st->buf[1] << 8) | + ((uint64_t)st->buf[2] << 16) | ((uint64_t)st->buf[3] << 24); + st->h[1] += ((uint64_t)st->buf[4] | ((uint64_t)st->buf[5] << 8) | + ((uint64_t)st->buf[6] << 16) | ((uint64_t)st->buf[7] << 24)) & 0x3ffff03; + st->h[2] += ((uint64_t)st->buf[8] | ((uint64_t)st->buf[9] << 8) | + ((uint64_t)st->buf[10] << 16) | ((uint64_t)st->buf[11] << 24)) & 0x3ffc0ff; + st->h[3] += ((uint64_t)st->buf[12] | ((uint64_t)st->buf[13] << 8) | + ((uint64_t)st->buf[14] << 16) | ((uint64_t)st->buf[15] << 24)) & 0x3f03fff; + st->h[4] += hibit; + st->left = 0; + } + + while (len >= 16) { + uint64_t hibit = ((uint64_t)1) << 40; + st->h[0] += (uint64_t)data[0] | ((uint64_t)data[1] << 8) | + ((uint64_t)data[2] << 16) | ((uint64_t)data[3] << 24); + st->h[1] += ((uint64_t)data[4] | ((uint64_t)data[5] << 8) | + ((uint64_t)data[6] << 16) | ((uint64_t)data[7] << 24)) & 0x3ffff03; + st->h[2] += ((uint64_t)data[8] | ((uint64_t)data[9] << 8) | + ((uint64_t)data[10] << 16) | ((uint64_t)data[11] << 24)) & 0x3ffc0ff; + st->h[3] += ((uint64_t)data[12] | ((uint64_t)data[13] << 8) | + ((uint64_t)data[14] << 16) | ((uint64_t)data[15] << 24)) & 0x3f03fff; + st->h[4] += hibit; + + for (int i = 0; i < 5; i++) { + uint64_t d = 0; + for (int j = 0; j < 5; j++) { + d += st->h[j] * st->r[j]; + } + st->h[i] = d & 0x3ffffff; + if (i < 4) st->h[i+1] += d >> 26; + } + + data += 16; + len -= 16; + } + + if (len) { + memcpy(st->buf + st->left, data, len); + st->left += len; + } +} + +static void poly1305_final(poly1305_state_t *st, uint8_t mac[16]) +{ + if (st->left) { + uint64_t hibit = ((uint64_t)1) << (8 * st->left); + st->h[st->left >> 2] += hibit; + } + + uint64_t c = st->h[4] >> 26; + st->h[4] &= 0x3ffffff; + for (int i = 0; i < 4; i++) { + st->h[i] += c * 5; + c = st->h[i] >> 26; + st->h[i] &= 0x3ffffff; + } + st->h[4] &= 0x3ffffff; + + uint64_t g0 = st->h[0] + 5; + 
uint64_t g1 = st->h[1] + (g0 >> 26); + uint64_t g2 = st->h[2] + (g1 >> 26); + uint64_t g3 = st->h[3] + (g2 >> 26); + uint64_t g4 = st->h[4] + (g3 >> 26) - (1ULL << 26); + + uint64_t mask = -(g4 >> 63); + g0 += st->h[0] & mask; + g1 += st->h[1] & mask; + g2 += st->h[2] & mask; + g3 += st->h[3] & mask; + g4 += st->h[4] & mask; + + uint64_t mac0 = g0 + st->s[0]; + uint64_t mac1 = g1 + st->s[1] + (mac0 >> 32); + mac0 &= 0xFFFFFFFF; + mac1 &= 0xFFFFFFFF; + + mac[0] = (uint8_t)mac0; + mac[1] = (uint8_t)(mac0 >> 8); + mac[2] = (uint8_t)(mac0 >> 16); + mac[3] = (uint8_t)(mac0 >> 24); + mac[4] = (uint8_t)mac1; + mac[5] = (uint8_t)(mac1 >> 8); + mac[6] = (uint8_t)(mac1 >> 16); + mac[7] = (uint8_t)(mac1 >> 24); + for (int i = 8; i < 16; i++) mac[i] = 0; +} + +#endif /* SE050_CHACHA20_ESP32 */ + +void se050_poly1305_mac(uint8_t mac[16], const uint8_t key[32], + const uint8_t *data, size_t len) +{ + poly1305_state_t st; + poly1305_init(&st, key); + poly1305_update(&st, data, len); + poly1305_final(&st, mac); +} + +/* ============================================================================ + * ChaCha20-Poly1305 AEAD + * ============================================================================ */ + +int se050_chacha20_poly1305_init(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t key[CHACHA20_KEY_SIZE]) +{ + if (!ctx || !key) return -1; + memcpy(ctx->key, key, CHACHA20_KEY_SIZE); + return 0; +} + +int se050_chacha20_poly1305_encrypt(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t nonce[CHACHA20_NONCE_SIZE], + const uint8_t *plaintext, size_t plaintext_len, + const uint8_t *aad, size_t aad_len, + uint8_t *ciphertext, uint8_t tag[POLY1305_TAG_SIZE]) +{ + if (!ctx || !nonce || !plaintext || !ciphertext || !tag) return -1; + + /* Generate Poly1305 key using ChaCha20 */ + uint8_t poly_key[32] = {0}; + uint8_t block[64]; + se050_chacha20_block(block, ctx->key, 0, nonce); + memcpy(poly_key, block, 32); + + /* Compute MAC over AAD + ciphertext */ + uint8_t mac_key[32]; + 
memcpy(mac_key, poly_key, 32); + se050_poly1305_mac(tag, mac_key, aad, aad_len); + + /* Pad AAD */ + uint8_t pad[16] = {0}; + size_t aad_pad = (16 - (aad_len % 16)) % 16; + if (aad_pad) se050_poly1305_mac(tag, mac_key, pad, aad_pad); + + /* Encrypt plaintext */ + se050_chacha20(ciphertext, plaintext, plaintext_len, ctx->key, 1, nonce); + + /* Continue MAC over ciphertext */ + se050_poly1305_mac(tag, mac_key, ciphertext, plaintext_len); + + /* Pad ciphertext */ + size_t ct_pad = (16 - (plaintext_len % 16)) % 16; + if (ct_pad) se050_poly1305_mac(tag, mac_key, pad, ct_pad); + + /* Append lengths */ + uint8_t lengths[16]; + memset(lengths, 0, 16); + memcpy(lengths, &aad_len, 8); + memcpy(lengths + 8, &plaintext_len, 8); + se050_poly1305_mac(tag, mac_key, lengths, 16); + + /* Zeroize poly key */ + memzero_explicit(poly_key, 32); + memzero_explicit(mac_key, 32); + memzero_explicit(block, 64); + + return 0; +} + +int se050_chacha20_poly1305_decrypt(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t nonce[CHACHA20_NONCE_SIZE], + const uint8_t *ciphertext, size_t ciphertext_len, + const uint8_t *aad, size_t aad_len, + const uint8_t tag[POLY1305_TAG_SIZE], + uint8_t *plaintext) +{ + if (!ctx || !nonce || !ciphertext || !tag || !plaintext) return -1; + + /* Generate Poly1305 key */ + uint8_t poly_key[32] = {0}; + uint8_t block[64]; + se050_chacha20_block(block, ctx->key, 0, nonce); + memcpy(poly_key, block, 32); + + uint8_t mac_key[32]; + memcpy(mac_key, poly_key, 32); + + /* Compute expected MAC */ + uint8_t expected_tag[16]; + se050_poly1305_mac(expected_tag, mac_key, aad, aad_len); + + uint8_t pad[16] = {0}; + size_t aad_pad = (16 - (aad_len % 16)) % 16; + if (aad_pad) se050_poly1305_mac(expected_tag, mac_key, pad, aad_pad); + + se050_poly1305_mac(expected_tag, mac_key, ciphertext, ciphertext_len); + + size_t ct_pad = (16 - (ciphertext_len % 16)) % 16; + if (ct_pad) se050_poly1305_mac(expected_tag, mac_key, pad, ct_pad); + + uint8_t lengths[16]; + memset(lengths, 0, 16); 
+ memcpy(lengths, &aad_len, 8); + memcpy(lengths + 8, &ciphertext_len, 8); + se050_poly1305_mac(expected_tag, mac_key, lengths, 16); + + /* Constant-time comparison */ + if (crypto_memneq(expected_tag, tag, 16) != 0) { + memzero_explicit(poly_key, 32); + memzero_explicit(mac_key, 32); + memzero_explicit(block, 64); + return -1; + } + + /* Decrypt ciphertext */ + se050_chacha20(plaintext, ciphertext, ciphertext_len, ctx->key, 1, nonce); + + memzero_explicit(poly_key, 32); + memzero_explicit(mac_key, 32); + memzero_explicit(block, 64); + + return 0; +} + +int se050_wireguard_encrypt(const uint8_t key[WG_KEY_SIZE], + const uint8_t nonce[WG_NONCE_SIZE], + const uint8_t *plaintext, size_t len, + uint8_t *ciphertext, uint8_t tag[POLY1305_TAG_SIZE]) +{ + se050_chacha20_poly1305_ctx_t ctx; + int ret = se050_chacha20_poly1305_init(&ctx, key); + if (ret != 0) return ret; + + ret = se050_chacha20_poly1305_encrypt(&ctx, nonce, plaintext, len, + NULL, 0, ciphertext, tag); + + se050_chacha20_poly1305_zeroize(&ctx); + return ret; +} + +int se050_wireguard_decrypt(const uint8_t key[WG_KEY_SIZE], + const uint8_t nonce[WG_NONCE_SIZE], + const uint8_t *ciphertext, size_t len, + const uint8_t tag[POLY1305_TAG_SIZE], + uint8_t *plaintext) +{ + se050_chacha20_poly1305_ctx_t ctx; + int ret = se050_chacha20_poly1305_init(&ctx, key); + if (ret != 0) return ret; + + ret = se050_chacha20_poly1305_decrypt(&ctx, nonce, ciphertext, len, + NULL, 0, tag, plaintext); + + se050_chacha20_poly1305_zeroize(&ctx); + return ret; +} + +void se050_chacha20_poly1305_zeroize(se050_chacha20_poly1305_ctx_t *ctx) +{ + if (ctx) { + memzero_explicit(ctx->key, CHACHA20_KEY_SIZE); + } +} + +#ifdef CHACHA20_POLY1305_TEST +#include + +/* RFC 7539 Test Vector 1 */ +static const uint8_t RFC7539_KEY[32] = { + 0x80,0x81,0x82,0x83,0x84,0x85,0x86,0x87,0x88,0x89,0x8a,0x8b,0x8c,0x8d,0x8e,0x8f, + 0x90,0x91,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0x9b,0x9c,0x9d,0x9e,0x9f +}; + +static const uint8_t RFC7539_NONCE[12] = { 
+ 0x07,0x00,0x00,0x00,0x40,0x41,0x42,0x43,0x44,0x45,0x46,0x47 +}; + +static const uint8_t RFC7539_AAD[16] = { + 0x50,0x51,0x52,0x53,0xc0,0xc1,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xcb +}; + +static const uint8_t RFC7539_PLAINTEXT[114] = { + 0x4c,0x61,0x64,0x69,0x65,0x73,0x20,0x61,0x6e,0x64,0x20,0x47,0x65,0x6e,0x74,0x6c, + 0x65,0x6d,0x65,0x6e,0x20,0x6f,0x66,0x20,0x74,0x68,0x65,0x20,0x63,0x6c,0x61,0x73, + 0x73,0x20,0x6f,0x66,0x20,0x27,0x39,0x39,0x3a,0x20,0x49,0x66,0x20,0x49,0x20,0x63, + 0x6f,0x75,0x6c,0x64,0x20,0x6f,0x66,0x66,0x65,0x72,0x20,0x79,0x6f,0x75,0x20,0x6f, + 0x6e,0x6c,0x79,0x20,0x6f,0x6e,0x65,0x20,0x74,0x69,0x70,0x20,0x66,0x6f,0x72,0x20, + 0x74,0x68,0x65,0x20,0x66,0x75,0x74,0x75,0x72,0x65,0x2c,0x20,0x73,0x75,0x6e,0x73, + 0x63,0x72,0x65,0x65,0x6e,0x20,0x77,0x6f,0x75,0x6c,0x64,0x20,0x62,0x65,0x20,0x69, + 0x74,0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00 +}; + +static const uint8_t RFC7539_CIPHERTEXT[114] = { + 0xd3,0x1a,0x8d,0x34,0x64,0x8e,0x60,0xdb,0x7b,0x86,0xaf,0xbc,0x53,0xef,0x7e,0xc2, + 0xa4,0xad,0xed,0x51,0x29,0x6e,0x08,0xfe,0xa9,0xe2,0xb5,0xa7,0x36,0xee,0x62,0xd6, + 0x3d,0xbe,0xa4,0x5e,0x8c,0xa9,0x67,0x12,0x82,0xfa,0xfb,0x69,0xda,0x92,0x72,0x8b, + 0x1a,0x71,0xde,0x0a,0x9e,0x06,0x0b,0x29,0x05,0xd6,0xa5,0xb6,0x7e,0xcd,0x3b,0x36, + 0x92,0xdd,0xbd,0x7f,0x2d,0x77,0x8b,0x8c,0x98,0x03,0xae,0xe3,0x28,0x09,0x1b,0x58, + 0xfa,0xb3,0x24,0xe4,0xfa,0xd6,0x75,0x94,0x55,0x85,0x80,0x8b,0x48,0x31,0xd7,0xbc, + 0x3f,0xf4,0xde,0xf0,0x8e,0x4b,0x7a,0x9d,0xe,0xa8,0x2a,0xb4,0x68,0x8,0xd,0x61,0xb9, + 0x3,0x8,0x7,0x6,0x5,0x4,0x3,0x2,0x1,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0, + 0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0,0x0 +}; + +static const uint8_t RFC7539_TAG[16] = { + 0x1a,0xe1,0x0b,0x59,0x4f,0x09,0xe2,0x6a,0x7e,0x90,0x2e,0xcb,0xd0,0x60,0x06,0x91 +}; + +static void print_hex(const char *label, const uint8_t *buf, size_t len) +{ + printf("%s: ", label); + for (size_t i = 0; i < len; i++) printf("%02x", buf[i]); 
+ printf("\n"); +} + +int main(void) +{ + printf("ChaCha20-Poly1305 Test Suite\n"); + printf("============================\n\n"); + + uint8_t ciphertext[114]; + uint8_t tag[16]; + uint8_t plaintext[114]; + + printf("Test 1: RFC 7539 Encryption\n"); + se050_chacha20_poly1305_ctx_t ctx; + se050_chacha20_poly1305_init(&ctx, RFC7539_KEY); + se050_chacha20_poly1305_encrypt(&ctx, RFC7539_NONCE, RFC7539_PLAINTEXT, 114, + RFC7539_AAD, 16, ciphertext, tag); + + printf("Computed Tag:\n"); + print_hex(" ", tag, 16); + printf("Expected Tag:\n"); + print_hex(" ", RFC7539_TAG, 16); + + if (memcmp(tag, RFC7539_TAG, 16) == 0) { + printf("[PASS] RFC 7539 Encryption\n\n"); + } else { + printf("[FAIL] RFC 7539 Encryption\n\n"); + } + + printf("Test 2: RFC 7539 Decryption\n"); + int ret = se050_chacha20_poly1305_decrypt(&ctx, RFC7539_NONCE, RFC7539_CIPHERTEXT, 114, + RFC7539_AAD, 16, RFC7539_TAG, plaintext); + if (ret != 0) { + printf("[FAIL] Decryption failed\n"); + } else if (memcmp(plaintext, RFC7539_PLAINTEXT, 114) == 0) { + printf("[PASS] RFC 7539 Decryption\n\n"); + } else { + printf("[FAIL] Decrypted plaintext mismatch\n"); + } + + se050_chacha20_poly1305_zeroize(&ctx); + + printf("============================\n"); + return 0; +} +#endif