diff --git a/src/se050_chacha20_poly1305.c b/src/se050_chacha20_poly1305.c index 35f137a..5c5fb9d 100644 --- a/src/se050_chacha20_poly1305.c +++ b/src/se050_chacha20_poly1305.c @@ -1,575 +1,385 @@ /** * @file se050_chacha20_poly1305.c * @brief ChaCha20-Poly1305 AEAD Implementation - * Based on RFC 7539 and RFC 8434 (WireGuard) + * + * Based on RFC 8439 (ChaCha20-Poly1305) and the WireGuard protocol spec. * License: MIT (Clean-room implementation) + * + * Verified against RFC 8439 Section 2.8.2 test vector. + * + * Design notes + * ------------ + * Poly1305 state uses the 5×26-bit limb representation throughout. + * The same poly1305_* implementation is used for both ESP32 and 64-bit + * targets; 64-bit intermediates (uint64_t) keep carry arithmetic simple + * and correct on both platforms. */ #include "se050_chacha20_poly1305.h" #include "se050_crypto_utils.h" #include -/* ESP32 detection */ -#if defined(ESP_PLATFORM) || defined(__XTENSA__) || defined(__riscv) -#define SE050_CHACHA20_ESP32 1 -#else -#define SE050_CHACHA20_ESP32 0 -#endif - /* ============================================================================ - * ChaCha20 Implementation + * ChaCha20 * ============================================================================ */ -/** - * @brief ChaCha20 quarter round - */ -void se050_chacha20_quarter_round(uint32_t *a, uint32_t *b, uint32_t *c, uint32_t *d) +static inline uint32_t rotl32(uint32_t x, unsigned n) { - *a += *b; *d ^= *a; *d <<= 16; *d |= (*d >> 16); - *c += *d; *b ^= *c; *b <<= 12; *b |= (*b >> 20); - *a += *b; *d ^= *a; *d <<= 8; *d |= (*d >> 24); - *c += *d; *b ^= *c; *b <<= 7; *b |= (*b >> 25); + return (x << n) | (x >> (32u - n)); +} + +#define QR(a, b, c, d) \ + a += b; d ^= a; d = rotl32(d, 16); \ + c += d; b ^= c; b = rotl32(b, 12); \ + a += b; d ^= a; d = rotl32(d, 8); \ + c += d; b ^= c; b = rotl32(b, 7); + +static inline uint32_t load32_le(const uint8_t *p) +{ + return (uint32_t)p[0] + | ((uint32_t)p[1] << 8) + | ((uint32_t)p[2] << 16) + | ((uint32_t)p[3] << 24); } /** - * @brief ChaCha20 block function + * @brief ChaCha20 block function (RFC 8439 §2.1) + * @param output 64-byte keystream block + * @param key 32-byte key + * @param counter 32-bit block counter (little-endian in the state) + * @param nonce 12-byte nonce */ void se050_chacha20_block(uint8_t output[64], const uint8_t key[32], uint32_t counter, const uint8_t nonce[12]) { - /* Constants "expand 32-byte k" */ - uint32_t state[16] = { - 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + uint32_t s[16] = { + /* "expand 32-byte k" */ + 0x61707865u, 0x3320646eu, 0x79622d32u, 0x6b206574u, + /* key words 0-7 */ + load32_le(key + 0), load32_le(key + 4), + load32_le(key + 8), load32_le(key + 12), + load32_le(key + 16), load32_le(key + 20), + load32_le(key + 24), load32_le(key + 28), + /* counter, nonce words 0-2 */ + counter, + load32_le(nonce + 0), load32_le(nonce + 4), load32_le(nonce + 8) }; - - /* Load key (8 words) */ - state[4] = (uint32_t)key[0] | ((uint32_t)key[1] << 8) | - ((uint32_t)key[2] << 16) | ((uint32_t)key[3] << 24); - state[5] = (uint32_t)key[4] | ((uint32_t)key[5] << 8) | - ((uint32_t)key[6] << 16) | ((uint32_t)key[7] << 24); - state[6] = (uint32_t)key[8] | ((uint32_t)key[9] << 8) | - ((uint32_t)key[10] << 16) | ((uint32_t)key[11] << 24); - state[7] = (uint32_t)key[12] | ((uint32_t)key[13] << 8) | - ((uint32_t)key[14] << 16) | ((uint32_t)key[15] << 24); - state[8] = (uint32_t)key[16] | ((uint32_t)key[17] << 8) | - ((uint32_t)key[18] << 16) | ((uint32_t)key[19] << 24); - state[9] = (uint32_t)key[20] | ((uint32_t)key[21] << 8) | - ((uint32_t)key[22] << 16) | ((uint32_t)key[23] << 24); - state[10] = (uint32_t)key[24] | ((uint32_t)key[25] << 8) | - ((uint32_t)key[26] << 16) | ((uint32_t)key[27] << 24); - state[11] = (uint32_t)key[28] | ((uint32_t)key[29] << 8) | - ((uint32_t)key[30] << 16) | ((uint32_t)key[31] << 24); - - /* Load counter and nonce */ - state[12] = counter; - state[13] = (uint32_t)nonce[0] | ((uint32_t)nonce[1] << 8) | - ((uint32_t)nonce[2] << 16) | ((uint32_t)nonce[3] << 24); - state[14] = (uint32_t)nonce[4] | ((uint32_t)nonce[5] << 8) | - ((uint32_t)nonce[6] << 16) | ((uint32_t)nonce[7] << 24); - state[15] = (uint32_t)nonce[8] | ((uint32_t)nonce[9] << 8) | - ((uint32_t)nonce[10] << 16) | ((uint32_t)nonce[11] << 24); - - /* Save initial state */ - uint32_t initial[16]; - memcpy(initial, state, sizeof(initial)); - - /* 20 rounds (10 double rounds) */ + uint32_t w[16]; + memcpy(w, s, sizeof(s)); + for (int i = 0; i < 10; i++) { - /* Column rounds */ - se050_chacha20_quarter_round(&state[0], &state[4], &state[8], &state[12]); - se050_chacha20_quarter_round(&state[1], &state[5], &state[9], &state[13]); - se050_chacha20_quarter_round(&state[2], &state[6], &state[10], &state[14]); - se050_chacha20_quarter_round(&state[3], &state[7], &state[11], &state[15]); - /* Diagonal rounds */ - se050_chacha20_quarter_round(&state[0], &state[5], &state[10], &state[15]); - se050_chacha20_quarter_round(&state[1], &state[6], &state[11], &state[12]); - se050_chacha20_quarter_round(&state[2], &state[7], &state[8], &state[13]); - se050_chacha20_quarter_round(&state[3], &state[4], &state[9], &state[14]); + /* column rounds */ + QR(w[ 0], w[ 4], w[ 8], w[12]); + QR(w[ 1], w[ 5], w[ 9], w[13]); + QR(w[ 2], w[ 6], w[10], w[14]); + QR(w[ 3], w[ 7], w[11], w[15]); + /* diagonal rounds */ + QR(w[ 0], w[ 5], w[10], w[15]); + QR(w[ 1], w[ 6], w[11], w[12]); + QR(w[ 2], w[ 7], w[ 8], w[13]); + QR(w[ 3], w[ 4], w[ 9], w[14]); } - - /* Add initial state */ + for (int i = 0; i < 16; i++) { - state[i] += initial[i]; - } - - /* Output little-endian */ - for (int i = 0; i < 16; i++) { - output[i*4] = (uint8_t)(state[i]); - output[i*4+1] = (uint8_t)(state[i] >> 8); - output[i*4+2] = (uint8_t)(state[i] >> 16); - output[i*4+3] = (uint8_t)(state[i] >> 24); + uint32_t v = w[i] + s[i]; + output[i*4+0] = (uint8_t)(v); + output[i*4+1] = (uint8_t)(v >> 8); + output[i*4+2] = (uint8_t)(v >> 16); + output[i*4+3] = (uint8_t)(v >> 24); } } /** - * @brief ChaCha20 encrypt/decrypt + * @brief ChaCha20 stream cipher (encrypt or decrypt; symmetric) + * @param counter Starting block counter (use 1 for AEAD payload) */ void se050_chacha20(uint8_t *output, const uint8_t *input, size_t len, - const uint8_t key[32], uint32_t counter, const uint8_t nonce[12]) + const uint8_t key[32], uint32_t counter, + const uint8_t nonce[12]) { uint8_t block[64]; - size_t i; - - for (i = 0; i < len; i += 64) { + while (len > 0) { se050_chacha20_block(block, key, counter++, nonce); - size_t chunk = (i + 64 <= len) ? 64 : (len - i); - for (size_t j = 0; j < chunk; j++) { - output[i + j] = input[i + j] ^ block[j]; - } + size_t chunk = len < 64u ? len : 64u; + for (size_t j = 0; j < chunk; j++) + output[j] = input[j] ^ block[j]; + output += chunk; + input += chunk; + len -= chunk; } + memzero_explicit(block, sizeof(block)); } /* ============================================================================ - * Poly1305 Implementation + * Poly1305 — 5 × 26-bit limb representation + * + * State: h = h[0] + h[1]*2^26 + h[2]*2^52 + h[3]*2^78 + h[4]*2^104 + * r, s loaded from the 32-byte one-time key. + * + * Reference: RFC 8439 §2.5, D.J. Bernstein's original paper. * ============================================================================ */ -#if SE050_CHACHA20_ESP32 -/* ESP32 32-bit optimized Poly1305 */ - typedef struct { - uint32_t r[5]; - uint32_t h[5]; - uint32_t s[2]; - uint8_t buf[16]; - size_t left; + uint64_t r[5]; /* clamped r, 26-bit limbs */ + uint64_t h[5]; /* accumulator, 26-bit limbs */ + uint32_t s[4]; /* s = key[16..31], four 32-bit words */ + uint8_t buf[16]; + size_t left; /* bytes in buf (0..15) */ } poly1305_state_t; static void poly1305_init(poly1305_state_t *st, const uint8_t key[32]) { - /* Clamp r: r &= 0x0ffffffc0ffffffc0ffffffc0fffffff */ - uint32_t r0 = ((uint32_t)key[0]) | ((uint32_t)key[1] << 8) | - ((uint32_t)key[2] << 16) | ((uint32_t)key[3] << 24); - uint32_t r1 = ((uint32_t)key[4]) | ((uint32_t)key[5] << 8) | - ((uint32_t)key[6] << 16) | ((uint32_t)key[7] << 24); - uint32_t r2 = ((uint32_t)key[8]) | ((uint32_t)key[9] << 8) | - ((uint32_t)key[10] << 16) | ((uint32_t)key[11] << 24); - uint32_t r3 = ((uint32_t)key[12]) | ((uint32_t)key[13] << 8) | - ((uint32_t)key[14] << 16) | ((uint32_t)key[15] << 24); - - st->r[0] = r0 & 0x3ffffff; - st->r[1] = (r1 >> 2) & 0x3ffff03; - st->r[2] = ((r1 >> 30) | (r2 << 4)) & 0x3ffc0ff; - st->r[3] = ((r2 >> 22) | (r3 << 12)) & 0x3f03fff; - st->r[4] = (r3 >> 10) & 0x00fffff; - - /* s0, s1 = (r0 + 5) mod 2^13, (r1 + 5) mod 2^13, ... for carry handling */ - st->s[0] = ((uint32_t)key[16]) | ((uint32_t)key[17] << 8) | - ((uint32_t)key[18] << 16) | ((uint32_t)key[19] << 24); - st->s[1] = ((uint32_t)key[20]) | ((uint32_t)key[21] << 8) | - ((uint32_t)key[22] << 16) | ((uint32_t)key[23] << 24); - - for (int i = 0; i < 5; i++) st->h[i] = 0; + /* + * r = key[0..15], split into 26-bit limbs with clamping. + * Clamping clears specific bits per RFC 8439 §2.5.1. + * + * The 128 bits of r span key[0..15]: + * r[0] = bits 0..25 of r → key[0..3] & 0x3ffffff + * r[1] = bits 26..51 of r → clamp mask 0x3ffff03 + * r[2] = bits 52..77 of r → clamp mask 0x3ffc0ff + * r[3] = bits 78..103 of r → clamp mask 0x3f03fff + * r[4] = bits 104..127 of r → clamp mask 0x00fffff + */ + uint64_t t0 = (uint64_t)key[ 0] | ((uint64_t)key[ 1] << 8) + | ((uint64_t)key[ 2] << 16) | ((uint64_t)key[ 3] << 24); + uint64_t t1 = (uint64_t)key[ 4] | ((uint64_t)key[ 5] << 8) + | ((uint64_t)key[ 6] << 16) | ((uint64_t)key[ 7] << 24); + uint64_t t2 = (uint64_t)key[ 8] | ((uint64_t)key[ 9] << 8) + | ((uint64_t)key[10] << 16) | ((uint64_t)key[11] << 24); + uint64_t t3 = (uint64_t)key[12] | ((uint64_t)key[13] << 8) + | ((uint64_t)key[14] << 16) | ((uint64_t)key[15] << 24); + + st->r[0] = t0 & 0x3ffffffULL; + st->r[1] = ((t0 >> 26) | (t1 << 6)) & 0x3ffff03ULL; + st->r[2] = ((t1 >> 20) | (t2 << 12)) & 0x3ffc0ffULL; + st->r[3] = ((t2 >> 14) | (t3 << 18)) & 0x3f03fffULL; + st->r[4] = (t3 >> 8) & 0x00fffffULL; + + /* s = key[16..31], four little-endian 32-bit words */ + st->s[0] = (uint32_t)key[16] | ((uint32_t)key[17] << 8) + | ((uint32_t)key[18] << 16) | ((uint32_t)key[19] << 24); + st->s[1] = (uint32_t)key[20] | ((uint32_t)key[21] << 8) + | ((uint32_t)key[22] << 16) | ((uint32_t)key[23] << 24); + st->s[2] = (uint32_t)key[24] | ((uint32_t)key[25] << 8) + | ((uint32_t)key[26] << 16) | ((uint32_t)key[27] << 24); + st->s[3] = (uint32_t)key[28] | ((uint32_t)key[29] << 8) + | ((uint32_t)key[30] << 16) | ((uint32_t)key[31] << 24); + + st->h[0] = st->h[1] = st->h[2] = st->h[3] = st->h[4] = 0; st->left = 0; } +/* + * Process one 16-byte block: h = (h + m) * r mod 2^130 - 5 + * + * @param pad128 true → full block, append bit 2^128 (complete message chunk) + * false → partial final block, append 2^(8*actual_len) already + * encoded in the loaded words (handled by caller) + */ +static void poly1305_block(poly1305_state_t *st, + const uint8_t m[16], int pad128) +{ + /* Load 16 bytes as four 32-bit LE words, split into 26-bit limbs */ + uint64_t d0 = (uint64_t)m[ 0] | ((uint64_t)m[ 1] << 8) + | ((uint64_t)m[ 2] << 16) | ((uint64_t)m[ 3] << 24); + uint64_t d1 = (uint64_t)m[ 4] | ((uint64_t)m[ 5] << 8) + | ((uint64_t)m[ 6] << 16) | ((uint64_t)m[ 7] << 24); + uint64_t d2 = (uint64_t)m[ 8] | ((uint64_t)m[ 9] << 8) + | ((uint64_t)m[10] << 16) | ((uint64_t)m[11] << 24); + uint64_t d3 = (uint64_t)m[12] | ((uint64_t)m[13] << 8) + | ((uint64_t)m[14] << 16) | ((uint64_t)m[15] << 24); + + /* Split into 26-bit limbs */ + uint64_t m0 = d0 & 0x3ffffffULL; + uint64_t m1 = ((d0 >> 26) | (d1 << 6)) & 0x3ffffffULL; + uint64_t m2 = ((d1 >> 20) | (d2 << 12)) & 0x3ffffffULL; + uint64_t m3 = ((d2 >> 14) | (d3 << 18)) & 0x3ffffffULL; + uint64_t m4 = (d3 >> 8); + if (pad128) m4 |= (1ULL << 24); /* 2^128 in 26-bit limb representation */ + + /* h += m */ + uint64_t h0 = st->h[0] + m0; + uint64_t h1 = st->h[1] + m1; + uint64_t h2 = st->h[2] + m2; + uint64_t h3 = st->h[3] + m3; + uint64_t h4 = st->h[4] + m4; + + /* h = h * r mod 2^130 - 5 + * Using the identity: x * 2^130 ≡ x * 5 (mod 2^130 - 5) + * so r[i]*5 terms fold the high limbs back in. + */ + uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], + r3 = st->r[3], r4 = st->r[4]; + uint64_t r1_5 = r1 * 5, r2_5 = r2 * 5, r3_5 = r3 * 5, r4_5 = r4 * 5; + + uint64_t t0 = h0*r0 + h1*r4_5 + h2*r3_5 + h3*r2_5 + h4*r1_5; + uint64_t t1 = h0*r1 + h1*r0 + h2*r4_5 + h3*r3_5 + h4*r2_5; + uint64_t t2 = h0*r2 + h1*r1 + h2*r0 + h3*r4_5 + h4*r3_5; + uint64_t t3 = h0*r3 + h1*r2 + h2*r1 + h3*r0 + h4*r4_5; + uint64_t t4 = h0*r4 + h1*r3 + h2*r2 + h3*r1 + h4*r0; + + /* Carry propagation (all limbs → 26 bits) */ + uint64_t c; + c = t0 >> 26; st->h[0] = t0 & 0x3ffffffULL; t1 += c; + c = t1 >> 26; st->h[1] = t1 & 0x3ffffffULL; t2 += c; + c = t2 >> 26; st->h[2] = t2 & 0x3ffffffULL; t3 += c; + c = t3 >> 26; st->h[3] = t3 & 0x3ffffffULL; t4 += c; + c = t4 >> 26; st->h[4] = t4 & 0x3ffffffULL; + /* Wrap-around carry: h[4] overflow → h[0] += 5 * overflow */ + st->h[0] += c * 5; + c = st->h[0] >> 26; st->h[0] &= 0x3ffffffULL; + st->h[1] += c; +} + static void poly1305_update(poly1305_state_t *st, const uint8_t *data, size_t len) { - /* Handle partial buffer */ - if (st->left) { - size_t needed = 16 - st->left; - if (len < needed) { + /* Fill partial buffer first */ + if (st->left > 0) { + size_t need = 16 - st->left; + if (len < need) { memcpy(st->buf + st->left, data, len); st->left += len; return; } - memcpy(st->buf + st->left, data, needed); - data += needed; - len -= needed; - - /* Process 16-byte block */ - uint32_t hibit = 0x01000000; /* 2^128 as high bit in 4th word */ - uint32_t d0 = st->buf[0] | (st->buf[1] << 8) | (st->buf[2] << 16) | ((st->buf[3] | hibit) << 24); - uint32_t d1 = (st->buf[4] | (st->buf[5] << 8) | (st->buf[6] << 16) | (st->buf[7] << 24)) & 0x3ffff03; - uint32_t d2 = (st->buf[8] | (st->buf[9] << 8) | (st->buf[10] << 16) | (st->buf[11] << 24)) & 0x3ffc0ff; - uint32_t d3 = (st->buf[12] | (st->buf[13] << 8) | (st->buf[14] << 16) | (st->buf[15] | hibit) << 24) & 0x3f03fff; - - st->h[0] += d0; - st->h[1] += d1; - st->h[2] += d2; - st->h[3] += d3; - st->h[4] += 0; /* hibit goes to position 4 */ - - /* Multiply by r and reduce mod 2^130 - 5 */ - uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; - uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; - - /* Compute h = h * r mod 2^130 - 5 */ - uint64_t t0 = h0 * r0 + h1 * (r4 * 5) + h2 * (r3 * 5) + h3 * (r2 * 5) + h4 * (r1 * 5); - uint64_t t1 = h0 * r1 + h1 * r0 + h2 * (r4 * 5) + h3 * (r3 * 5) + h4 * (r2 * 5); - uint64_t t2 = h0 * r2 + h1 * r1 + h2 * r0 + h3 * (r4 * 5) + h4 * (r3 * 5); - uint64_t t3 = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * (r4 * 5); - uint64_t t4 = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0; - - /* Carry propagation */ - uint32_t c = (uint32_t)(t0 >> 26); - st->h[0] = (uint32_t)(t0 & 0x3ffffff); - t1 += c; - c = (uint32_t)(t1 >> 22); - st->h[1] = (uint32_t)(t1 & 0x3ffffff); - t2 += c; - c = (uint32_t)(t2 >> 26); - st->h[2] = (uint32_t)(t2 & 0x3ffffff); - t3 += c; - c = (uint32_t)(t3 >> 26); - st->h[3] = (uint32_t)(t3 & 0x3ffffff); - t4 += c; - st->h[4] = (uint32_t)(t4 & 0x3ffffff); - + memcpy(st->buf + st->left, data, need); + poly1305_block(st, st->buf, 1); st->left = 0; + data += need; + len -= need; } - - /* Process full 16-byte blocks */ + + /* Full blocks */ while (len >= 16) { - uint32_t hibit = 0x01000000; - uint32_t d0 = data[0] | (data[1] << 8) | (data[2] << 16) | ((data[3] | hibit) << 24); - uint32_t d1 = (data[4] | (data[5] << 8) | (data[6] << 16) | (data[7] << 24)) & 0x3ffff03; - uint32_t d2 = (data[8] | (data[9] << 8) | (data[10] << 16) | (data[11] << 24)) & 0x3ffc0ff; - uint32_t d3 = (data[12] | (data[13] << 8) | (data[14] << 16) | ((data[15] | hibit) << 24)) & 0x3f03fff; - - st->h[0] += d0; - st->h[1] += d1; - st->h[2] += d2; - st->h[3] += d3; - st->h[4] += 0; - - /* Multiply by r */ - uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; - uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; - - uint64_t t0 = h0 * r0 + h1 * (r4 * 5) + h2 * (r3 * 5) + h3 * (r2 * 5) + h4 * (r1 * 5); - uint64_t t1 = h0 * r1 + h1 * r0 + h2 * (r4 * 5) + h3 * (r3 * 5) + h4 * (r2 * 5); - uint64_t t2 = h0 * r2 + h1 * r1 + h2 * r0 + h3 * (r4 * 5) + h4 * (r3 * 5); - uint64_t t3 = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * (r4 * 5); - uint64_t t4 = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0; - - uint32_t c = (uint32_t)(t0 >> 26); - st->h[0] = (uint32_t)(t0 & 0x3ffffff); - t1 += c; - c = (uint32_t)(t1 >> 22); - st->h[1] = (uint32_t)(t1 & 0x3ffffff); - t2 += c; - c = (uint32_t)(t2 >> 26); - st->h[2] = (uint32_t)(t2 & 0x3ffffff); - t3 += c; - c = (uint32_t)(t3 >> 26); - st->h[3] = (uint32_t)(t3 & 0x3ffffff); - t4 += c; - st->h[4] = (uint32_t)(t4 & 0x3ffffff); - + poly1305_block(st, data, 1); data += 16; - len -= 16; + len -= 16; } - - /* Save remaining data */ - if (len) { - memcpy(st->buf + st->left, data, len); - st->left += len; + + /* Buffer remainder */ + if (len > 0) { + memcpy(st->buf, data, len); + st->left = len; } } static void poly1305_final(poly1305_state_t *st, uint8_t mac[16]) { - /* Process remaining bytes */ - if (st->left) { - uint32_t hibit = 0x01000000; - /* Pad with 0x01 byte after data */ - st->buf[st->left] = 1; - for (size_t i = st->left + 1; i < 16; i++) { - st->buf[i] = 0; - } - - uint32_t d0 = st->buf[0] | (st->buf[1] << 8) | (st->buf[2] << 16) | ((st->buf[3] | hibit) << 24); - uint32_t d1 = (st->buf[4] | (st->buf[5] << 8) | (st->buf[6] << 16) | (st->buf[7] << 24)) & 0x3ffff03; - uint32_t d2 = (st->buf[8] | (st->buf[9] << 8) | (st->buf[10] << 16) | (st->buf[11] << 24)) & 0x3ffc0ff; - uint32_t d3 = (st->buf[12] | (st->buf[13] << 8) | (st->buf[14] << 16) | ((st->buf[15] | hibit) << 24)) & 0x3f03fff; - - st->h[0] += d0; - st->h[1] += d1; - st->h[2] += d2; - st->h[3] += d3; - st->h[4] += 0; - - /* Multiply by r one last time */ - uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; - uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; - - uint64_t t0 = h0 * r0 + h1 * (r4 * 5) + h2 * (r3 * 5) + h3 * (r2 * 5) + h4 * (r1 * 5); - uint64_t t1 = h0 * r1 + h1 * r0 + h2 * (r4 * 5) + h3 * (r3 * 5) + h4 * (r2 * 5); - uint64_t t2 = h0 * r2 + h1 * r1 + h2 * r0 + h3 * (r4 * 5) + h4 * (r3 * 5); - uint64_t t3 = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * (r4 * 5); - uint64_t t4 = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0; - - uint32_t c = (uint32_t)(t0 >> 26); - st->h[0] = (uint32_t)(t0 & 0x3ffffff); - t1 += c; - c = (uint32_t)(t1 >> 22); - st->h[1] = (uint32_t)(t1 & 0x3ffffff); - t2 += c; - c = (uint32_t)(t2 >> 26); - st->h[2] = (uint32_t)(t2 & 0x3ffffff); - t3 += c; - c = (uint32_t)(t3 >> 26); - st->h[3] = (uint32_t)(t3 & 0x3ffffff); - t4 += c; - st->h[4] = (uint32_t)(t4 & 0x3ffffff); + /* Process partial final block (if any) with 2^(8*n) pad instead of 2^128 */ + if (st->left > 0) { + uint8_t padded[16]; + memcpy(padded, st->buf, st->left); + padded[st->left] = 0x01; + memset(padded + st->left + 1, 0, 16 - st->left - 1); + /* The 0x01 byte IS the pad bit; do NOT set pad128 flag */ + poly1305_block(st, padded, 0); + memzero_explicit(padded, sizeof(padded)); } - - /* Final reduction: add 5 * carry from h[4] */ - uint32_t c = st->h[4] + 5; - st->h[4] &= 0x3ffffff; - st->h[0] += (c >> 26); - st->h[1] += (st->h[0] >> 26); - st->h[0] &= 0x3ffffff; - st->h[2] += (st->h[1] >> 22); - st->h[1] &= 0x3ffffff; - st->h[3] += (st->h[2] >> 26); - st->h[2] &= 0x3ffffff; - st->h[4] += (st->h[3] >> 22); - st->h[3] &= 0x3ffffff; - - /* Add s[0], s[1] */ - uint64_t mac0 = (uint64_t)st->h[0] + st->s[0]; - uint64_t mac1 = (uint64_t)st->h[1] + st->s[1] + (mac0 >> 32); - mac0 &= 0xFFFFFFFF; - mac1 &= 0xFFFFFFFF; - - mac[0] = (uint8_t)mac0; - mac[1] = (uint8_t)(mac0 >> 8); - mac[2] = (uint8_t)(mac0 >> 16); - mac[3] = (uint8_t)(mac0 >> 24); - mac[4] = (uint8_t)mac1; - mac[5] = (uint8_t)(mac1 >> 8); - mac[6] = (uint8_t)(mac1 >> 16); - mac[7] = (uint8_t)(mac1 >> 24); - for (int i = 8; i < 16; i++) mac[i] = 0; -} -#else -/* Standard 64-bit Poly1305 */ + /* + * Final reduction: bring h fully into [0, 2^130-5). + * + * Two passes: + * 1. Propagate any carry from h[4] back through h[0..4]. + * 2. Conditionally subtract p = 2^130 - 5 if h >= p. + * Done in constant-time using a mask derived from the borrow. + */ + uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], + h3 = st->h[3], h4 = st->h[4]; -typedef struct { - uint64_t r[5]; - uint64_t h[5]; - uint64_t s[2]; - uint8_t buf[16]; - size_t left; -} poly1305_state_t; + /* Pass 1: full carry propagation */ + uint64_t c; + c = h4 >> 26; h4 &= 0x3ffffffULL; h0 += c * 5; + c = h0 >> 26; h0 &= 0x3ffffffULL; h1 += c; + c = h1 >> 26; h1 &= 0x3ffffffULL; h2 += c; + c = h2 >> 26; h2 &= 0x3ffffffULL; h3 += c; + c = h3 >> 26; h3 &= 0x3ffffffULL; h4 += c; -static void poly1305_init(poly1305_state_t *st, const uint8_t key[32]) -{ - /* r = key[0..15], clamp された値を 26 ビットリムに展開 */ - uint64_t t0 = (uint64_t)key[0] | ((uint64_t)key[1] << 8) | - ((uint64_t)key[2] << 16) | ((uint64_t)key[3] << 24); - uint64_t t1 = (uint64_t)key[4] | ((uint64_t)key[5] << 8) | - ((uint64_t)key[6] << 16) | ((uint64_t)key[7] << 24); - uint64_t t2 = (uint64_t)key[8] | ((uint64_t)key[9] << 8) | - ((uint64_t)key[10] << 16) | ((uint64_t)key[11] << 24); - uint64_t t3 = (uint64_t)key[12] | ((uint64_t)key[13] << 8) | - ((uint64_t)key[14] << 16) | ((uint64_t)key[15] << 24); + /* Pass 2: try h - p = h + 5 - 2^130; keep if no borrow */ + uint64_t g0 = h0 + 5; + c = g0 >> 26; g0 &= 0x3ffffffULL; + uint64_t g1 = h1 + c; c = g1 >> 26; g1 &= 0x3ffffffULL; + uint64_t g2 = h2 + c; c = g2 >> 26; g2 &= 0x3ffffffULL; + uint64_t g3 = h3 + c; c = g3 >> 26; g3 &= 0x3ffffffULL; + uint64_t g4 = h4 + c - (1ULL << 26); - /* 26 ビットリムに分割して clamp */ - st->r[0] = t0 & 0x3ffffff; - st->r[1] = ((t0 >> 26) | (t1 << 6)) & 0x3ffff03; - st->r[2] = ((t1 >> 20) | (t2 << 12)) & 0x3ffc0ff; - st->r[3] = ((t2 >> 14) | (t3 << 18)) & 0x3f03fff; - st->r[4] = (t3 >> 8) & 0x00fffff; + /* mask = 0xfff...fff if g4 did NOT underflow (i.e. h >= p), else 0 */ + uint64_t mask = (uint64_t)(-(int64_t)(1 - (g4 >> 63))); + h0 = (h0 & ~mask) | (g0 & mask); + h1 = (h1 & ~mask) | (g1 & mask); + h2 = (h2 & ~mask) | (g2 & mask); + h3 = (h3 & ~mask) | (g3 & mask); + h4 = h4 & ~mask; /* when h>=p, h4 reduces to 0 after subtract */ - /* s = key[16..31] */ - st->s[0] = (uint64_t)key[16] | ((uint64_t)key[17] << 8) | - ((uint64_t)key[18] << 16) | ((uint64_t)key[19] << 24); - st->s[1] = (uint64_t)key[20] | ((uint64_t)key[21] << 8) | - ((uint64_t)key[22] << 16) | ((uint64_t)key[23] << 24); - - for (int i = 0; i < 5; i++) st->h[i] = 0; - st->left = 0; -} + /* + * Reconstruct 128-bit h from five 26-bit limbs, then add s = key[16..31]. + * + * After reduction h4 == 0, so h fits in 104 bits. + * We extract four 32-bit words from h before adding s to keep carry + * propagation correct (each f-word carry is at most 1 bit). + * + * h bits 0.. 31 → w0 = h0 | (h1 << 26) [bits 0..51, take lower 32] + * h bits 32.. 63 → w1 = (h1 >> 6) | (h2 << 20) [bits 32..77, take lower 32] + * h bits 64.. 95 → w2 = (h2 >> 12) | (h3 << 14) [bits 64..103, take lower 32] + * h bits 96..127 → w3 = (h3 >> 18) | (h4 << 8) [bits 96..129, take lower 32] + * + * Each w-value must be masked to 32 bits BEFORE adding s so that the + * carry into the next word is exactly 0 or 1. + */ + uint64_t w0 = (h0 | (h1 << 26)) & 0xffffffffULL; + uint64_t w1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffffULL; + uint64_t w2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffffULL; + uint64_t w3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffffULL; -static void poly1305_update(poly1305_state_t *st, const uint8_t *data, size_t len) -{ - if (st->left) { - size_t needed = 16 - st->left; - if (len < needed) { - memcpy(st->buf + st->left, data, len); - st->left += len; - return; - } - memcpy(st->buf + st->left, data, needed); - data += needed; - len -= needed; - - /* Add buffer data to h */ - st->h[0] += (uint64_t)st->buf[0] | ((uint64_t)st->buf[1] << 8) | - ((uint64_t)st->buf[2] << 16) | ((uint64_t)st->buf[3] << 24); - st->h[1] += ((uint64_t)st->buf[4] | ((uint64_t)st->buf[5] << 8) | - ((uint64_t)st->buf[6] << 16) | ((uint64_t)st->buf[7] << 24)) & 0x3ffff03; - st->h[2] += ((uint64_t)st->buf[8] | ((uint64_t)st->buf[9] << 8) | - ((uint64_t)st->buf[10] << 16) | ((uint64_t)st->buf[11] << 24)) & 0x3ffc0ff; - st->h[3] += ((uint64_t)st->buf[12] | ((uint64_t)st->buf[13] << 8) | - ((uint64_t)st->buf[14] << 16) | ((uint64_t)st->buf[15] << 24)) & 0x3f03fff; - st->h[4] += (1ULL << 24); /* 2^128 in 26-bit limb representation */ - - /* Multiply by r */ - uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; - uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; - - uint64_t t0 = h0 * r0 + h1 * (r4 * 5) + h2 * (r3 * 5) + h3 * (r2 * 5) + h4 * (r1 * 5); - uint64_t t1 = h0 * r1 + h1 * r0 + h2 * (r4 * 5) + h3 * (r3 * 5) + h4 * (r2 * 5); - uint64_t t2 = h0 * r2 + h1 * r1 + h2 * r0 + h3 * (r4 * 5) + h4 * (r3 * 5); - uint64_t t3 = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * (r4 * 5); - uint64_t t4 = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0; - - uint32_t c = (uint32_t)(t0 >> 26); - st->h[0] = (uint32_t)(t0 & 0x3ffffff); - t1 += c; - c = (uint32_t)(t1 >> 22); - st->h[1] = (uint32_t)(t1 & 0x3ffffff); - t2 += c; - c = (uint32_t)(t2 >> 26); - st->h[2] = (uint32_t)(t2 & 0x3ffffff); - t3 += c; - c = (uint32_t)(t3 >> 26); - st->h[3] = (uint32_t)(t3 & 0x3ffffff); - t4 += c; - st->h[4] = (uint32_t)(t4 & 0x3ffffff); - - st->left = 0; - } - - while (len >= 16) { - /* 2^128 in 26-bit limb */ - st->h[0] += (uint64_t)data[0] | ((uint64_t)data[1] << 8) | - ((uint64_t)data[2] << 16) | ((uint64_t)data[3] << 24); - st->h[1] += ((uint64_t)data[4] | ((uint64_t)data[5] << 8) | - ((uint64_t)data[6] << 16) | ((uint64_t)data[7] << 24)) & 0x3ffff03; - st->h[2] += ((uint64_t)data[8] | ((uint64_t)data[9] << 8) | - ((uint64_t)data[10] << 16) | ((uint64_t)data[11] << 24)) & 0x3ffc0ff; - st->h[3] += ((uint64_t)data[12] | ((uint64_t)data[13] << 8) | - ((uint64_t)data[14] << 16) | ((uint64_t)data[15] << 24)) & 0x3f03fff; - st->h[4] += (1ULL << 24); /* 2^128 */ - - /* Multiply by r */ - uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; - uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; - - uint64_t t0 = h0 * r0 + h1 * (r4 * 5) + h2 * (r3 * 5) + h3 * (r2 * 5) + h4 * (r1 * 5); - uint64_t t1 = h0 * r1 + h1 * r0 + h2 * (r4 * 5) + h3 * (r3 * 5) + h4 * (r2 * 5); - uint64_t t2 = h0 * r2 + h1 * r1 + h2 * r0 + h3 * (r4 * 5) + h4 * (r3 * 5); - uint64_t t3 = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * (r4 * 5); - uint64_t t4 = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0; - - uint32_t c = (uint32_t)(t0 >> 26); - st->h[0] = (uint32_t)(t0 & 0x3ffffff); - t1 += c; - c = (uint32_t)(t1 >> 22); - st->h[1] = (uint32_t)(t1 & 0x3ffffff); - t2 += c; - c = (uint32_t)(t2 >> 26); - st->h[2] = (uint32_t)(t2 & 0x3ffffff); - t3 += c; - c = (uint32_t)(t3 >> 26); - st->h[3] = (uint32_t)(t3 & 0x3ffffff); - t4 += c; - st->h[4] = (uint32_t)(t4 & 0x3ffffff); - - data += 16; - len -= 16; - } - - if (len) { - memcpy(st->buf + st->left, data, len); - st->left += len; - } -} + uint64_t f0 = w0 + (uint64_t)st->s[0]; + uint64_t f1 = w1 + (uint64_t)st->s[1] + (f0 >> 32); + uint64_t f2 = w2 + (uint64_t)st->s[2] + (f1 >> 32); + uint64_t f3 = w3 + (uint64_t)st->s[3] + (f2 >> 32); -static void poly1305_final(poly1305_state_t *st, uint8_t mac[16]) -{ - /* Process remaining bytes */ - if (st->left) { - uint64_t hibit = ((uint64_t)1) << (8 * st->left); - st->buf[st->left] = 1; - for (size_t i = st->left + 1; i < 16; i++) { - st->buf[i] = 0; - } - - st->h[0] += (uint64_t)st->buf[0] | ((uint64_t)st->buf[1] << 8) | - ((uint64_t)st->buf[2] << 16) | ((uint64_t)st->buf[3] << 24); - st->h[1] += ((uint64_t)st->buf[4] | ((uint64_t)st->buf[5] << 8) | - ((uint64_t)st->buf[6] << 16) | ((uint64_t)st->buf[7] << 24)) & 0x3ffff03; - st->h[2] += ((uint64_t)st->buf[8] | ((uint64_t)st->buf[9] << 8) | - ((uint64_t)st->buf[10] << 16) | ((uint64_t)st->buf[11] << 24)) & 0x3ffc0ff; - st->h[3] += ((uint64_t)st->buf[12] | ((uint64_t)st->buf[13] << 8) | - ((uint64_t)st->buf[14] << 16) | ((uint64_t)st->buf[15] << 24)) & 0x3f03fff; - st->h[4] += (1ULL << 24); /* 2^128 */ - - /* Multiply by r one last time */ - uint64_t r0 = st->r[0], r1 = st->r[1], r2 = st->r[2], r3 = st->r[3], r4 = st->r[4]; - uint64_t h0 = st->h[0], h1 = st->h[1], h2 = st->h[2], h3 = st->h[3], h4 = st->h[4]; - - uint64_t t0 = h0 * r0 + h1 * (r4 * 5) + h2 * (r3 * 5) + h3 * (r2 * 5) + h4 * (r1 * 5); - uint64_t t1 = h0 * r1 + h1 * r0 + h2 * (r4 * 5) + h3 * (r3 * 5) + h4 * (r2 * 5); - uint64_t t2 = h0 * r2 + h1 * r1 + h2 * r0 + h3 * (r4 * 5) + h4 * (r3 * 5); - uint64_t t3 = h0 * r3 + h1 * r2 + h2 * r1 + h3 * r0 + h4 * (r4 * 5); - uint64_t t4 = h0 * r4 + h1 * r3 + h2 * r2 + h3 * r1 + h4 * r0; - - uint32_t c = (uint32_t)(t0 >> 26); - st->h[0] = (uint32_t)(t0 & 0x3ffffff); - t1 += c; - c = (uint32_t)(t1 >> 22); - st->h[1] = (uint32_t)(t1 & 0x3ffffff); - t2 += c; - c = (uint32_t)(t2 >> 26); - st->h[2] = (uint32_t)(t2 & 0x3ffffff); - t3 += c; - c = (uint32_t)(t3 >> 26); - st->h[3] = (uint32_t)(t3 & 0x3ffffff); - t4 += c; - st->h[4] = (uint32_t)(t4 & 0x3ffffff); - } - - /* Final reduction */ - uint64_t c = st->h[4] + 5; - st->h[4] &= 0x3ffffff; - st->h[0] += (c >> 26); - st->h[1] += (st->h[0] >> 26); - st->h[0] &= 0x3ffffff; - st->h[2] += (st->h[1] >> 22); - st->h[1] &= 0x3ffffff; - st->h[3] += (st->h[2] >> 26); - st->h[2] &= 0x3ffffff; - st->h[4] += (st->h[3] >> 22); - st->h[3] &= 0x3ffffff; - - /* Add s[0], s[1] and output full 128-bit MAC */ - uint64_t f0 = st->h[0] + st->s[0]; - uint64_t f1 = st->h[1] + st->s[1] + (f0 >> 32); - uint64_t f2 = st->h[2] + (f1 >> 32); - uint64_t f3 = st->h[3] + (f2 >> 32); - - mac[0] = (uint8_t)(f0); mac[1] = (uint8_t)(f0 >> 8); - mac[2] = (uint8_t)(f0 >> 16); mac[3] = (uint8_t)(f0 >> 24); - mac[4] = (uint8_t)(f1); mac[5] = (uint8_t)(f1 >> 8); - mac[6] = (uint8_t)(f1 >> 16); mac[7] = (uint8_t)(f1 >> 24); - mac[8] = (uint8_t)(f2); mac[9] = (uint8_t)(f2 >> 8); + mac[ 0] = (uint8_t)(f0); mac[ 1] = (uint8_t)(f0 >> 8); + mac[ 2] = (uint8_t)(f0 >> 16); mac[ 3] = (uint8_t)(f0 >> 24); + mac[ 4] = (uint8_t)(f1); mac[ 5] = (uint8_t)(f1 >> 8); + mac[ 6] = (uint8_t)(f1 >> 16); mac[ 7] = (uint8_t)(f1 >> 24); + mac[ 8] = (uint8_t)(f2); mac[ 9] = (uint8_t)(f2 >> 8); mac[10] = (uint8_t)(f2 >> 16); mac[11] = (uint8_t)(f2 >> 24); - mac[12] = (uint8_t)(f3); mac[13] = (uint8_t)(f3 >> 8); + mac[12] = (uint8_t)(f3); mac[13] = (uint8_t)(f3 >> 8); mac[14] = (uint8_t)(f3 >> 16); mac[15] = (uint8_t)(f3 >> 24); } -#endif /* SE050_CHACHA20_ESP32 */ +/* ============================================================================ + * ChaCha20-Poly1305 AEAD (RFC 8439 §2.8) + * ============================================================================ + * + * Poly1305 input layout: + * AAD || pad(AAD,16) || ciphertext || pad(CT,16) + * || LE64(len(AAD)) || LE64(len(ciphertext)) + * ============================================================================ */ -void se050_poly1305_mac(uint8_t mac[16], const uint8_t key[32], - const uint8_t *data, size_t len) +static void store64_le(uint8_t *p, uint64_t v) { - poly1305_state_t st; - poly1305_init(&st, key); - poly1305_update(&st, data, len); - poly1305_final(&st, mac); + for (int i = 0; i < 8; i++) { p[i] = (uint8_t)v; v >>= 8; } } -/* ============================================================================ - * ChaCha20-Poly1305 AEAD - * ============================================================================ */ +static void aead_poly1305_input(poly1305_state_t *st, + const uint8_t *aad, size_t aad_len, + const uint8_t *ct, size_t ct_len) +{ + static const uint8_t zeros[16] = {0}; + + poly1305_update(st, aad, aad_len); + size_t aad_pad = (16u - (aad_len & 15u)) & 15u; + if (aad_pad) poly1305_update(st, zeros, aad_pad); + + poly1305_update(st, ct, ct_len); + size_t ct_pad = (16u - (ct_len & 15u)) & 15u; + if (ct_pad) poly1305_update(st, zeros, ct_pad); + + uint8_t lengths[16]; + store64_le(lengths + 0, (uint64_t)aad_len); + store64_le(lengths + 8, (uint64_t)ct_len); + poly1305_update(st, lengths, 16); +} + +/* --- Context init / zeroize --- */ int se050_chacha20_poly1305_init(se050_chacha20_poly1305_ctx_t *ctx, const uint8_t key[CHACHA20_KEY_SIZE]) @@ -579,138 +389,81 @@ int se050_chacha20_poly1305_init(se050_chacha20_poly1305_ctx_t *ctx, return 0; } -int se050_chacha20_poly1305_encrypt(se050_chacha20_poly1305_ctx_t *ctx, - const uint8_t nonce[CHACHA20_NONCE_SIZE], - const uint8_t *plaintext, size_t plaintext_len, - const uint8_t *aad, size_t aad_len, - uint8_t *ciphertext, uint8_t tag[POLY1305_TAG_SIZE]) +void se050_chacha20_poly1305_zeroize(se050_chacha20_poly1305_ctx_t *ctx) { - if (!nonce || !plaintext || !ciphertext || !tag) return -1; - - /* Get key from context */ - const uint8_t *key; - if (ctx) { - key = ctx->key; - } else { - return -1; - } - - /* Generate Poly1305 key using ChaCha20 */ - uint8_t poly_key[32] = {0}; - uint8_t block[64]; - se050_chacha20_block(block, key, 0, nonce); - memcpy(poly_key, block, 32); - - /* Compute MAC over AAD + ciphertext using state */ - uint8_t mac_key[32]; - memcpy(mac_key, poly_key, 32); - + if (ctx) memzero_explicit(ctx->key, CHACHA20_KEY_SIZE); +} + +/* --- Encrypt --- */ + +int se050_chacha20_poly1305_encrypt(se050_chacha20_poly1305_ctx_t *ctx, + const uint8_t nonce[CHACHA20_NONCE_SIZE], + const uint8_t *plaintext, size_t pt_len, + const uint8_t *aad, size_t aad_len, + uint8_t *ciphertext, + uint8_t tag[POLY1305_TAG_SIZE]) +{ + if (!ctx || !nonce || !ciphertext || !tag) return -1; + if (pt_len > 0 && !plaintext) return -1; + if (aad_len > 0 && !aad) return -1; + + /* 1. Derive one-time Poly1305 key from counter=0 block */ + uint8_t otk[64]; + se050_chacha20_block(otk, ctx->key, 0, nonce); + + /* 2. Encrypt plaintext (counter starts at 1) */ + se050_chacha20(ciphertext, plaintext, pt_len, ctx->key, 1, nonce); + + /* 3. Compute tag over AAD + ciphertext */ poly1305_state_t st; - poly1305_init(&st, mac_key); - - /* Process AAD */ - poly1305_update(&st, aad, aad_len); - - /* Pad AAD */ - uint8_t pad[16] = {0}; - size_t aad_pad = (16 - (aad_len % 16)) % 16; - if (aad_pad) poly1305_update(&st, pad, aad_pad); - - /* Encrypt plaintext */ - se050_chacha20(ciphertext, plaintext, plaintext_len, key, 1, nonce); - - /* Process ciphertext */ - poly1305_update(&st, ciphertext, plaintext_len); - - /* Pad ciphertext */ - size_t ct_pad = (16 - (plaintext_len % 16)) % 16; - if (ct_pad) poly1305_update(&st, pad, ct_pad); - - /* Append lengths */ - uint8_t lengths[16]; - memset(lengths, 0, 16); - memcpy(lengths, &aad_len, 8); - memcpy(lengths + 8, &plaintext_len, 8); - poly1305_update(&st, lengths, 16); - - /* Finalize MAC */ + poly1305_init(&st, otk); /* otk[0..31] is the Poly1305 key */ + aead_poly1305_input(&st, aad, aad_len, ciphertext, pt_len); poly1305_final(&st, tag); - - /* Zeroize poly key */ - memzero_explicit(poly_key, 32); - memzero_explicit(mac_key, 32); - memzero_explicit(block, 64); - + + memzero_explicit(otk, sizeof(otk)); + memzero_explicit(&st, sizeof(st)); return 0; } +/* --- Decrypt --- */ + int se050_chacha20_poly1305_decrypt(se050_chacha20_poly1305_ctx_t *ctx, - const uint8_t nonce[CHACHA20_NONCE_SIZE], - const uint8_t *ciphertext, size_t ciphertext_len, - const uint8_t *aad, size_t aad_len, - const uint8_t tag[POLY1305_TAG_SIZE], - uint8_t *plaintext) + const uint8_t nonce[CHACHA20_NONCE_SIZE], + const uint8_t *ciphertext, size_t ct_len, + const uint8_t *aad, size_t aad_len, + const uint8_t tag[POLY1305_TAG_SIZE], + uint8_t *plaintext) { if (!ctx || !nonce || !ciphertext || !tag || !plaintext) return -1; - - /* Generate Poly1305 key */ - uint8_t poly_key[32] = {0}; - uint8_t block[64]; - se050_chacha20_block(block, ctx->key, 0, nonce); - memcpy(poly_key, block, 32); - - uint8_t mac_key[32]; - memcpy(mac_key, poly_key, 32); - - /* Compute expected MAC using state */ + if (ct_len > 0 && !ciphertext) return -1; + if (aad_len > 0 && !aad) return -1; + + /* 1. Derive one-time Poly1305 key */ + uint8_t otk[64]; + se050_chacha20_block(otk, ctx->key, 0, nonce); + + /* 2. Verify tag BEFORE decrypting (authenticate-then-decrypt) */ poly1305_state_t st; - poly1305_init(&st, mac_key); - - /* Process AAD */ - poly1305_update(&st, aad, aad_len); - - /* Pad AAD */ - uint8_t pad[16] = {0}; - size_t aad_pad = (16 - (aad_len % 16)) % 16; - if (aad_pad) poly1305_update(&st, pad, aad_pad); - - /* Process ciphertext */ - poly1305_update(&st, ciphertext, ciphertext_len); - - /* Pad ciphertext */ - size_t ct_pad = (16 - (ciphertext_len % 16)) % 16; - if (ct_pad) poly1305_update(&st, pad, ct_pad); - - /* Append lengths */ - uint8_t lengths[16]; - memset(lengths, 0, 16); - memcpy(lengths, &aad_len, 8); - memcpy(lengths + 8, &ciphertext_len, 8); - poly1305_update(&st, lengths, 16); - - /* Finalize MAC */ - uint8_t expected_tag[16]; - poly1305_final(&st, expected_tag); - + poly1305_init(&st, otk); + aead_poly1305_input(&st, aad, aad_len, ciphertext, ct_len); + uint8_t expected[POLY1305_TAG_SIZE]; + poly1305_final(&st, expected); + /* Constant-time comparison */ - int ret = 0; - if (crypto_memneq(expected_tag, tag, 16) != 0) { - ret = -1; - } - - /* Only decrypt if MAC is valid */ - if (ret == 0) { - se050_chacha20(plaintext, ciphertext, ciphertext_len, ctx->key, 1, nonce); - } - - /* Zeroize sensitive data */ - memzero_explicit(poly_key, 32); - memzero_explicit(mac_key, 32); - memzero_explicit(block, 64); - - return ret; + int ok = crypto_memneq(expected, tag, POLY1305_TAG_SIZE) == 0 ? 0 : -1; + + /* 3. Decrypt only if tag is valid */ + if (ok == 0) + se050_chacha20(plaintext, ciphertext, ct_len, ctx->key, 1, nonce); + + memzero_explicit(otk, sizeof(otk)); + memzero_explicit(&st, sizeof(st)); + memzero_explicit(expected, sizeof(expected)); + return ok; } +/* --- WireGuard convenience wrappers (no AAD) --- */ + int se050_wireguard_encrypt(const uint8_t key[WG_KEY_SIZE], const uint8_t nonce[WG_NONCE_SIZE], const uint8_t *plaintext, size_t len, @@ -719,10 +472,10 @@ int se050_wireguard_encrypt(const uint8_t key[WG_KEY_SIZE], se050_chacha20_poly1305_ctx_t ctx; int ret = se050_chacha20_poly1305_init(&ctx, key); if (ret != 0) return ret; - - ret = se050_chacha20_poly1305_encrypt(&ctx, nonce, plaintext, len, - NULL, 0, ciphertext, tag); - + ret = se050_chacha20_poly1305_encrypt(&ctx, nonce, + plaintext, len, + NULL, 0, + ciphertext, tag); se050_chacha20_poly1305_zeroize(&ctx); return ret; } @@ -736,138 +489,154 @@ int se050_wireguard_decrypt(const uint8_t key[WG_KEY_SIZE], se050_chacha20_poly1305_ctx_t ctx; int ret = se050_chacha20_poly1305_init(&ctx, key); if (ret != 0) return ret; - - ret = se050_chacha20_poly1305_decrypt(&ctx, nonce, ciphertext, len, - NULL, 0, tag, plaintext); - + ret = se050_chacha20_poly1305_decrypt(&ctx, nonce, + ciphertext, len, + NULL, 0, + tag, plaintext); se050_chacha20_poly1305_zeroize(&ctx); return ret; } -void se050_chacha20_poly1305_zeroize(se050_chacha20_poly1305_ctx_t *ctx) -{ - if (ctx) { - memzero_explicit(ctx->key, CHACHA20_KEY_SIZE); - } -} +/* ============================================================================ + * Self-test (compile with -DCHACHA20_POLY1305_TEST) + * ============================================================================ */ #ifdef CHACHA20_POLY1305_TEST #include -/* RFC 7539 Section 2.8.2 Test Vector */ -static const uint8_t RFC7539_KEY[32] = { +/* RFC 8439 §2.8.2 */ +static const uint8_t TV_KEY[32] = { 0x80,0x81,0x82,0x83,0x84,0x85,0x86,0x87, 0x88,0x89,0x8a,0x8b,0x8c,0x8d,0x8e,0x8f, 0x90,0x91,0x92,0x93,0x94,0x95,0x96,0x97, 0x98,0x99,0x9a,0x9b,0x9c,0x9d,0x9e,0x9f }; - -static const uint8_t RFC7539_NONCE[12] = { +static const uint8_t TV_NONCE[12] = { 0x07,0x00,0x00,0x00,0x40,0x41,0x42,0x43, 0x44,0x45,0x46,0x47 }; - -static const uint8_t RFC7539_AAD[16] = { +/* RFC 8439 §2.8.2: AAD is 12 bytes */ +static const uint8_t TV_AAD[12] = { 0x50,0x51,0x52,0x53,0xc0,0xc1,0xc2,0xc3, - 0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xcb + 0xc4,0xc5,0xc6,0xc7 }; - -static const uint8_t RFC7539_PLAINTEXT[114] = { - 0x4c,0x61,0x64,0x69,0x65,0x73,0x20,0x61, - 0x6e,0x64,0x20,0x47,0x65,0x6e,0x74,0x6c, - 0x65,0x6d,0x65,0x6e,0x20,0x6f,0x66,0x20, - 0x74,0x68,0x65,0x20,0x63,0x6c,0x61,0x73, - 0x73,0x20,0x6f,0x66,0x20,0x27,0x39,0x39, - 0x3a,0x20,0x49,0x66,0x20,0x49,0x20,0x63, - 0x6f,0x75,0x6c,0x64,0x20,0x6f,0x66,0x66, - 0x65,0x72,0x20,0x79,0x6f,0x75,0x20,0x6f, - 0x6e,0x6c,0x79,0x20,0x6f,0x6e,0x65,0x20, - 0x74,0x69,0x70,0x20,0x66,0x6f,0x72,0x20, - 0x74,0x68,0x65,0x20,0x66,0x75,0x74,0x75, - 0x72,0x65,0x2c,0x20,0x73,0x75,0x6e,0x73, - 0x63,0x72,0x65,0x65,0x6e,0x20,0x77,0x6f, - 0x75,0x6c,0x64,0x20,0x62,0x65,0x20,0x69, +static const uint8_t TV_PT[114] = { + 0x4c,0x61,0x64,0x69,0x65,0x73,0x20,0x61,0x6e,0x64,0x20,0x47,0x65,0x6e, + 0x74,0x6c,0x65,0x6d,0x65,0x6e,0x20,0x6f,0x66,0x20,0x74,0x68,0x65,0x20, + 0x63,0x6c,0x61,0x73,0x73,0x20,0x6f,0x66,0x20,0x27,0x39,0x39,0x3a,0x20, + 0x49,0x66,0x20,0x49,0x20,0x63,0x6f,0x75,0x6c,0x64,0x20,0x6f,0x66,0x66, + 0x65,0x72,0x20,0x79,0x6f,0x75,0x20,0x6f,0x6e,0x6c,0x79,0x20,0x6f,0x6e, + 0x65,0x20,0x74,0x69,0x70,0x20,0x66,0x6f,0x72,0x20,0x74,0x68,0x65,0x20, + 0x66,0x75,0x74,0x75,0x72,0x65,0x2c,0x20,0x73,0x75,0x6e,0x73,0x63,0x72, + 0x65,0x65,0x6e,0x20,0x77,0x6f,0x75,0x6c,0x64,0x20,0x62,0x65,0x20,0x69, 0x74,0x2e }; - -static const uint8_t RFC7539_CIPHERTEXT[114] = { - 0xd3,0x1a,0x8d,0x34,0x64,0x8e,0x60,0xdb, - 0x7b,0x86,0xaf,0xbc,0x53,0xef,0x7e,0xc2, - 0xa4,0xad,0xed,0x51,0x29,0x6e,0x08,0xfe, - 0xa9,0xe2,0xb5,0xa7,0x36,0xee,0x62,0xd6, - 0x3d,0xbe,0xa4,0x5e,0x8c,0xa9,0x67,0x12, - 0x82,0xfa,0xfb,0x69,0xda,0x92,0x72,0x8b, - 0x1a,0x71,0xde,0x0a,0x9e,0x06,0x0b,0x29, - 0x05,0xd6,0xa5,0xb6,0x7e,0xcd,0x3b,0x36, - 0x92,0xdd,0xbd,0x7f,0x2d,0x77,0x8b,0x8c, - 0x98,0x03,0xae,0xe3,0x28,0x09,0x1b,0x58, - 0xfa,0xb3,0x24,0xe4,0xfa,0xd6,0x75,0x94, - 0x55,0x85,0x80,0x8b,0x48,0x31,0xd7,0xbc, - 0x3f,0xf4,0xde,0xf0,0x8e,0x4b,0x7a,0x9d, - 0xe0,0xa8,0x2a,0xb4,0x68,0x08,0xd6,0x1b, - 0x9b,0x39,0x87,0x65,0x43,0x21,0x00,0x00, - 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, - 0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, - 0x00,0x00,0x00,0x00,0x00,0x00,0x00 +/* + * Verified with Python reference implementation against RFC 8439 §2.8.2. + * PT = "Ladies and Gentlemen of the class of '99: If I could offer you only + * one tip for the future, sunscreen would be it." (114 bytes) + */ +static const uint8_t TV_CT[114] = { + 0xd3,0x1a,0x8d,0x34,0x64,0x8e,0x60,0xdb,0x7b,0x86,0xaf,0xbc,0x53,0xef, + 0x7e,0xc2,0xa4,0xad,0xed,0x51,0x29,0x6e,0x08,0xfe,0xa9,0xe2,0xb5,0xa7, + 0x36,0xee,0x62,0xd6,0x3d,0xbe,0xa4,0x5e,0x8c,0xa9,0x67,0x12,0x82,0xfa, + 0xfb,0x69,0xda,0x92,0x72,0x8b,0x1a,0x71,0xde,0x0a,0x9e,0x06,0x0b,0x29, + 0x05,0xd6,0xa5,0xb6,0x7e,0xcd,0x3b,0x36,0x92,0xdd,0xbd,0x7f,0x2d,0x77, + 0x8b,0x8c,0x98,0x03,0xae,0xe3,0x28,0x09,0x1b,0x58,0xfa,0xb3,0x24,0xe4, + 0xfa,0xd6,0x75,0x94,0x55,0x85,0x80,0x8b,0x48,0x31,0xd7,0xbc,0x3f,0xf4, + 0xde,0xf0,0x8e,0x4b,0x7a,0x9d,0xe5,0x76,0xd2,0x65,0x86,0xce,0xc6,0x4b, + 0x61,0x16 }; - -static const uint8_t RFC7539_TAG[16] = { +static const uint8_t TV_TAG[16] = { 0x1a,0xe1,0x0b,0x59,0x4f,0x09,0xe2,0x6a, 0x7e,0x90,0x2e,0xcb,0xd0,0x60,0x06,0x91 }; -static void print_hex(const char *label, const uint8_t *buf, size_t len) +static void print_hex(const char *label, const uint8_t *b, size_t n) { - printf("%s: ", label); - for (size_t i = 0; i < len; i++) printf("%02x", buf[i]); + printf(" %-12s: ", label); + for (size_t i = 0; i < n; i++) printf("%02x", b[i]); printf("\n"); } int main(void) { - printf("ChaCha20-Poly1305 Test Suite\n"); - printf("============================\n\n"); - - /* Test 1: ChaCha20 block function */ - printf("Test 1: ChaCha20 Block Function\n"); - uint8_t block[64]; - uint8_t key[32] = {0}; - uint8_t nonce[12] = {0}; - - for (int i = 0; i < 32; i++) key[i] = i; - nonce[8] = 0x4a; - - se050_chacha20_block(block, key, 1, nonce); - print_hex("Block output (first 32 bytes)", block, 32); - printf("[INFO] ChaCha20 block computed\n\n"); - - /* Test 2: ChaCha20-Poly1305 simple AEAD */ - printf("Test 2: ChaCha20-Poly1305 Simple AEAD\n"); - uint8_t plaintext[16] = {0}; - uint8_t ciphertext[16]; - uint8_t tag[16]; - uint8_t decrypted[16]; - - for (int i = 0; i < 16; i++) plaintext[i] = i; - - se050_chacha20_poly1305_ctx_t ctx; - se050_chacha20_poly1305_init(&ctx, key); - se050_chacha20_poly1305_encrypt(&ctx, nonce, plaintext, 16, NULL, 0, ciphertext, tag); - - printf("Tag: "); - print_hex("", tag, 16); - - int ret = se050_chacha20_poly1305_decrypt(&ctx, nonce, ciphertext, 16, NULL, 0, tag, decrypted); - - if (ret == 0 && memcmp(plaintext, decrypted, 16) == 0) { - printf("[PASS] ChaCha20-Poly1305 AEAD\n"); - } else { - printf("[FAIL] ChaCha20-Poly1305 AEAD (ret=%d)\n", ret); + int fail = 0; + printf("ChaCha20-Poly1305 Test Suite\n============================\n\n"); + + /* Test 1: RFC 8439 §2.8.2 encrypt */ + { + printf("Test 1: RFC 8439 §2.8.2 encrypt\n"); + uint8_t ct[114], tag[16]; + se050_chacha20_poly1305_ctx_t ctx; + se050_chacha20_poly1305_init(&ctx, TV_KEY); + se050_chacha20_poly1305_encrypt(&ctx, TV_NONCE, + TV_PT, sizeof(TV_PT), + TV_AAD, sizeof(TV_AAD), + ct, tag); + int ct_ok = memcmp(ct, TV_CT, sizeof(TV_CT)) == 0; + int tag_ok = memcmp(tag, TV_TAG, 16) == 0; + print_hex("tag expected", TV_TAG, 16); + print_hex("tag computed", tag, 16); + if (ct_ok && tag_ok) printf(" [PASS]\n\n"); + else { printf(" [FAIL] ct_ok=%d tag_ok=%d\n\n", ct_ok, tag_ok); fail++; } + se050_chacha20_poly1305_zeroize(&ctx); } - - se050_chacha20_poly1305_zeroize(&ctx); - + + /* Test 2: RFC 8439 §2.8.2 decrypt */ + { + printf("Test 2: RFC 8439 §2.8.2 decrypt\n"); + uint8_t pt[114]; + se050_chacha20_poly1305_ctx_t ctx; + se050_chacha20_poly1305_init(&ctx, TV_KEY); + int ret = se050_chacha20_poly1305_decrypt(&ctx, TV_NONCE, + TV_CT, sizeof(TV_CT), + TV_AAD, sizeof(TV_AAD), + TV_TAG, pt); + int pt_ok = (ret == 0) && (memcmp(pt, TV_PT, sizeof(TV_PT)) == 0); + if (pt_ok) printf(" [PASS]\n\n"); + else { printf(" [FAIL] ret=%d\n\n", ret); fail++; } + se050_chacha20_poly1305_zeroize(&ctx); + } + + /* Test 3: tampered tag is rejected */ + { + printf("Test 3: tampered tag rejected\n"); + uint8_t bad_tag[16]; + memcpy(bad_tag, TV_TAG, 16); + bad_tag[0] ^= 0xff; + uint8_t pt[114]; + se050_chacha20_poly1305_ctx_t ctx; + se050_chacha20_poly1305_init(&ctx, TV_KEY); + int ret = se050_chacha20_poly1305_decrypt(&ctx, TV_NONCE, + TV_CT, sizeof(TV_CT), + TV_AAD, sizeof(TV_AAD), + bad_tag, pt); + if (ret != 0) printf(" [PASS]\n\n"); + else { printf(" [FAIL] should have rejected\n\n"); fail++; } + se050_chacha20_poly1305_zeroize(&ctx); + } + + /* Test 4: encrypt→decrypt round-trip, no AAD */ + { + printf("Test 4: round-trip (no AAD)\n"); + uint8_t key[32] = {0}; + uint8_t nonce[12] = {0}; + uint8_t msg[37], ct[37], pt[37], tag[16]; + for (int i = 0; i < 32; i++) key[i] = (uint8_t)(i + 1); + for (int i = 0; i < 37; i++) msg[i] = (uint8_t)(i ^ 0xaa); + + se050_chacha20_poly1305_ctx_t ctx; + se050_chacha20_poly1305_init(&ctx, key); + se050_chacha20_poly1305_encrypt(&ctx, nonce, msg, 37, NULL, 0, ct, tag); + int ret = se050_chacha20_poly1305_decrypt(&ctx, nonce, ct, 37, NULL, 0, tag, pt); + if (ret == 0 && memcmp(msg, pt, 37) == 0) printf(" [PASS]\n\n"); + else { printf(" [FAIL]\n\n"); fail++; } + se050_chacha20_poly1305_zeroize(&ctx); + } + printf("============================\n"); - return 0; + printf("Result: %s\n", fail ? "FAIL" : "ALL PASS"); + return fail ? 1 : 0; } -#endif +#endif /* CHACHA20_POLY1305_TEST */