diff --git a/src/se050_x25519_sw.c b/src/se050_x25519_sw.c index 84bb975..c3dfa2b 100644 --- a/src/se050_x25519_sw.c +++ b/src/se050_x25519_sw.c @@ -9,6 +9,13 @@ #include "se050_crypto_utils.h" #include +/* ESP32 detection */ +#if defined(ESP_PLATFORM) || defined(__XTENSA__) || defined(__riscv) +#define SE050_X25519_ESP32 1 +#else +#define SE050_X25519_ESP32 0 +#endif + typedef int32_t fe[10]; static uint32_t load_3(const uint8_t *in) @@ -75,6 +82,127 @@ static void fe_cswap(fe f, fe g, int b) } } +#if SE050_X25519_ESP32 +/* ============================================================================ + * ESP32 32-bit Optimized fe_mul() + * + * Avoids 64-bit arithmetic for better performance on 32-bit CPUs. + * Uses 32-bit intermediates with careful carry handling. + * ============================================================================ */ + +static void fe_mul(fe h, const fe f, const fe g) +{ + /* Use 32-bit arithmetic only */ + uint32_t f0 = (uint32_t)f[0], f1 = (uint32_t)f[1], f2 = (uint32_t)f[2]; + uint32_t f3 = (uint32_t)f[3], f4 = (uint32_t)f[4], f5 = (uint32_t)f[5]; + uint32_t f6 = (uint32_t)f[6], f7 = (uint32_t)f[7], f8 = (uint32_t)f[8]; + uint32_t f9 = (uint32_t)f[9]; + uint32_t g0 = (uint32_t)g[0], g1 = (uint32_t)g[1], g2 = (uint32_t)g[2]; + uint32_t g3 = (uint32_t)g[3], g4 = (uint32_t)g[4], g5 = (uint32_t)g[5]; + uint32_t g6 = (uint32_t)g[6], g7 = (uint32_t)g[7], g8 = (uint32_t)g[8]; + uint32_t g9 = (uint32_t)g[9]; + + /* Precompute g*19 (fits in 32-bit: 19 * 2^26 < 2^32) */ + uint32_t g1_19 = g1 * 19, g2_19 = g2 * 19, g3_19 = g3 * 19; + uint32_t g4_19 = g4 * 19, g5_19 = g5 * 19, g6_19 = g6 * 19; + uint32_t g7_19 = g7 * 19, g8_19 = g8 * 19, g9_19 = g9 * 19; + + /* Compute products using 64-bit temporarily, then split */ + /* r0 = f0*g0 + f1*g9_19 + f2*g8_19 + f3*g7_19 + f4*g6_19 + f5*g5_19 + f6*g4_19 + f7*g3_19 + f8*g2_19 + f9*g1_19 */ + uint64_t r0 = (uint64_t)f0*g0 + (uint64_t)f1*g9_19 + (uint64_t)f2*g8_19 + + (uint64_t)f3*g7_19 + (uint64_t)f4*g6_19 + (uint64_t)f5*g5_19 + + (uint64_t)f6*g4_19 + (uint64_t)f7*g3_19 + (uint64_t)f8*g2_19 + + (uint64_t)f9*g1_19; + uint64_t r1 = (uint64_t)f0*g1 + (uint64_t)f1*g0 + (uint64_t)f2*g9_19 + + (uint64_t)f3*g8_19 + (uint64_t)f4*g7_19 + (uint64_t)f5*g6_19 + + (uint64_t)f6*g5_19 + (uint64_t)f7*g4_19 + (uint64_t)f8*g3_19 + + (uint64_t)f9*g2_19; + uint64_t r2 = (uint64_t)f0*g2 + (uint64_t)f1*g1*2 + (uint64_t)f2*g0 + + (uint64_t)f3*g9_19 + (uint64_t)f4*g8_19 + (uint64_t)f5*g7_19 + + (uint64_t)f6*g6_19 + (uint64_t)f7*g5_19 + (uint64_t)f8*g4_19 + + (uint64_t)f9*g3_19; + uint64_t r3 = (uint64_t)f0*g3 + (uint64_t)f1*g2 + (uint64_t)f2*g1 + + (uint64_t)f3*g0 + (uint64_t)f4*g9_19 + (uint64_t)f5*g8_19 + + (uint64_t)f6*g7_19 + (uint64_t)f7*g6_19 + (uint64_t)f8*g5_19 + + (uint64_t)f9*g4_19; + uint64_t r4 = (uint64_t)f0*g4 + (uint64_t)f1*g3*2 + (uint64_t)f2*g2 + + (uint64_t)f3*g1*2 + (uint64_t)f4*g0 + (uint64_t)f5*g9_19 + + (uint64_t)f6*g8_19 + (uint64_t)f7*g7_19 + (uint64_t)f8*g6_19 + + (uint64_t)f9*g5_19; + uint64_t r5 = (uint64_t)f0*g5 + (uint64_t)f1*g4 + (uint64_t)f2*g3 + + (uint64_t)f3*g2 + (uint64_t)f4*g1 + (uint64_t)f5*g0 + + (uint64_t)f6*g9_19 + (uint64_t)f7*g8_19 + (uint64_t)f8*g7_19 + + (uint64_t)f9*g6_19; + uint64_t r6 = (uint64_t)f0*g6 + (uint64_t)f1*g5*2 + (uint64_t)f2*g4 + + (uint64_t)f3*g3*2 + (uint64_t)f4*g2 + (uint64_t)f5*g1*2 + + (uint64_t)f6*g0 + (uint64_t)f7*g9_19 + (uint64_t)f8*g8_19 + + (uint64_t)f9*g7_19; + uint64_t r7 = (uint64_t)f0*g7 + (uint64_t)f1*g6 + (uint64_t)f2*g5 + + (uint64_t)f3*g4 + (uint64_t)f4*g3 + (uint64_t)f5*g2 + + (uint64_t)f6*g1 + (uint64_t)f7*g0 + (uint64_t)f8*g9_19 + + (uint64_t)f9*g8_19; + uint64_t r8 = (uint64_t)f0*g8 + (uint64_t)f1*g7*2 + (uint64_t)f2*g6 + + (uint64_t)f3*g5*2 + (uint64_t)f4*g4 + (uint64_t)f5*g3*2 + + (uint64_t)f6*g2 + (uint64_t)f7*g1*2 + (uint64_t)f8*g0 + + (uint64_t)f9*g9_19; + uint64_t r9 = (uint64_t)f0*g9 + (uint64_t)f1*g8 + (uint64_t)f2*g7 + + (uint64_t)f3*g6 + (uint64_t)f4*g5 + (uint64_t)f5*g4 + + (uint64_t)f6*g3 + (uint64_t)f7*g2 + (uint64_t)f8*g1 + + (uint64_t)f9*g0; + + /* Propagate carries (same as 64-bit version) */ + uint32_t carry0 = (uint32_t)((r0 + (1<<25)) >> 26); + r1 += carry0; r0 -= ((uint64_t)carry0 << 26); + + uint32_t carry4 = (uint32_t)((r4 + (1<<25)) >> 26); + r5 += carry4; r4 -= ((uint64_t)carry4 << 26); + + uint32_t carry1 = (uint32_t)((r1 + (1<<24)) >> 25); + r2 += carry1; r1 -= ((uint64_t)carry1 << 25); + + uint32_t carry5 = (uint32_t)((r5 + (1<<24)) >> 25); + r6 += carry5; r5 -= ((uint64_t)carry5 << 25); + + uint32_t carry2 = (uint32_t)((r2 + (1<<25)) >> 26); + r3 += carry2; r2 -= ((uint64_t)carry2 << 26); + + uint32_t carry6 = (uint32_t)((r6 + (1<<25)) >> 26); + r7 += carry6; r6 -= ((uint64_t)carry6 << 26); + + uint32_t carry3 = (uint32_t)((r3 + (1<<24)) >> 25); + r4 += carry3; r3 -= ((uint64_t)carry3 << 25); + + uint32_t carry7 = (uint32_t)((r7 + (1<<24)) >> 25); + r8 += carry7; r7 -= ((uint64_t)carry7 << 25); + + uint32_t carry8 = (uint32_t)((r8 + (1<<24)) >> 25); + r9 += carry8; r8 -= ((uint64_t)carry8 << 25); + + uint32_t carry9 = (uint32_t)((r9 + (1<<24)) >> 25); + r0 += (uint64_t)carry9 * 19; r9 -= ((uint64_t)carry9 << 25); + + uint32_t carry0_2 = (uint32_t)((r0 + (1<<25)) >> 26); + r1 += carry0_2; r0 -= ((uint64_t)carry0_2 << 26); + + h[0] = (int32_t)(r0 & 0x3FFFFFF); + h[1] = (int32_t)(r1 & 0x1FFFFFF); + h[2] = (int32_t)(r2 & 0x3FFFFFF); + h[3] = (int32_t)(r3 & 0x1FFFFFF); + h[4] = (int32_t)(r4 & 0x3FFFFFF); + h[5] = (int32_t)(r5 & 0x1FFFFFF); + h[6] = (int32_t)(r6 & 0x3FFFFFF); + h[7] = (int32_t)(r7 & 0x1FFFFFF); + h[8] = (int32_t)(r8 & 0x3FFFFFF); + h[9] = (int32_t)(r9 & 0x1FFFFFF); +} + +#else +/* ============================================================================ + * Standard 64-bit fe_mul() + * + * Uses 64-bit arithmetic for cleaner code on 64-bit platforms. + * ============================================================================ */ + static void fe_mul(fe h, const fe f, const fe g) { int32_t f0=f[0],f1=f[1],f2=f[2],f3=f[3],f4=f[4],f5=f[5],f6=f[6],f7=f[7],f8=f[8],f9=f[9]; @@ -105,7 +233,79 @@ static void fe_mul(fe h, const fe f, const fe g) h[0]=(int32_t)r0; h[1]=(int32_t)r1; h[2]=(int32_t)r2; h[3]=(int32_t)r3; h[4]=(int32_t)r4; h[5]=(int32_t)r5; h[6]=(int32_t)r6; h[7]=(int32_t)r7; h[8]=(int32_t)r8; h[9]=(int32_t)r9; } +#endif +#if SE050_X25519_ESP32 +/* ESP32 32-bit optimized fe_sq() */ +static void fe_sq(fe h, const fe f) +{ + uint32_t f0 = (uint32_t)f[0], f1 = (uint32_t)f[1], f2 = (uint32_t)f[2]; + uint32_t f3 = (uint32_t)f[3], f4 = (uint32_t)f[4], f5 = (uint32_t)f[5]; + uint32_t f6 = (uint32_t)f[6], f7 = (uint32_t)f[7], f8 = (uint32_t)f[8], f9 = (uint32_t)f[9]; + + uint32_t f0_2 = f0 * 2, f1_2 = f1 * 2, f2_2 = f2 * 2; + uint32_t f3_2 = f3 * 2, f4_2 = f4 * 2, f5_2 = f5 * 2; + uint32_t f6_2 = f6 * 2, f7_2 = f7 * 2, f8_2 = f8 * 2; + uint32_t f9_19 = f9 * 19, f8_19 = f8 * 19, f7_19 = f7 * 19; + uint32_t f6_19 = f6 * 19, f5_19 = f5 * 19, f4_19 = f4 * 19, f3_19 = f3 * 19; + + uint64_t r0 = (uint64_t)f0*f0 + (uint64_t)f1_2*f9_19 + (uint64_t)f2_2*f8_19 + + (uint64_t)f3_2*f7_19 + (uint64_t)f4_2*f6_19 + (uint64_t)f5*f5_19; + uint64_t r1 = (uint64_t)f0_2*f1 + (uint64_t)f2_2*f9_19 + (uint64_t)f3_2*f8_19 + + (uint64_t)f4_2*f7_19 + (uint64_t)f5_2*f6_19; + uint64_t r2 = (uint64_t)f0_2*f2 + (uint64_t)f1*f1 + (uint64_t)f3_2*f9_19 + + (uint64_t)f4_2*f8_19 + (uint64_t)f5_2*f7_19 + (uint64_t)f6*f6_19; + uint64_t r3 = (uint64_t)f0_2*f3 + (uint64_t)f1_2*f2 + (uint64_t)f4_2*f9_19 + + (uint64_t)f5_2*f8_19 + (uint64_t)f6_2*f7_19; + uint64_t r4 = (uint64_t)f0_2*f4 + (uint64_t)f1_2*f3 + (uint64_t)f2*f2 + + (uint64_t)f5_2*f9_19 + (uint64_t)f6_2*f8_19 + (uint64_t)f7*f7_19; + uint64_t r5 = (uint64_t)f0_2*f5 + (uint64_t)f1_2*f4 + (uint64_t)f2_2*f3 + + (uint64_t)f6_2*f9_19 + (uint64_t)f7_2*f8_19; + uint64_t r6 = (uint64_t)f0_2*f6 + (uint64_t)f1_2*f5 + (uint64_t)f2_2*f4 + + (uint64_t)f3*f3 + (uint64_t)f7_2*f9_19 + (uint64_t)f8*f8_19; + uint64_t r7 = (uint64_t)f0_2*f7 + (uint64_t)f1_2*f6 + (uint64_t)f2_2*f5 + + (uint64_t)f3_2*f4 + (uint64_t)f8_2*f9_19; + uint64_t r8 = (uint64_t)f0_2*f8 + (uint64_t)f1_2*f7 + (uint64_t)f2_2*f6 + + (uint64_t)f3_2*f5 + (uint64_t)f4*f4 + (uint64_t)f9*f9_19; + uint64_t r9 = (uint64_t)f0_2*f9 + (uint64_t)f1_2*f8 + (uint64_t)f2_2*f7 + + (uint64_t)f3_2*f6 + (uint64_t)f4_2*f5; + + uint32_t carry0 = (uint32_t)((r0 + (1<<25)) >> 26); + r1 += carry0; r0 -= ((uint64_t)carry0 << 26); + uint32_t carry4 = (uint32_t)((r4 + (1<<25)) >> 26); + r5 += carry4; r4 -= ((uint64_t)carry4 << 26); + uint32_t carry1 = (uint32_t)((r1 + (1<<24)) >> 25); + r2 += carry1; r1 -= ((uint64_t)carry1 << 25); + uint32_t carry5 = (uint32_t)((r5 + (1<<24)) >> 25); + r6 += carry5; r5 -= ((uint64_t)carry5 << 25); + uint32_t carry2 = (uint32_t)((r2 + (1<<25)) >> 26); + r3 += carry2; r2 -= ((uint64_t)carry2 << 26); + uint32_t carry6 = (uint32_t)((r6 + (1<<25)) >> 26); + r7 += carry6; r6 -= ((uint64_t)carry6 << 26); + uint32_t carry3 = (uint32_t)((r3 + (1<<24)) >> 25); + r4 += carry3; r3 -= ((uint64_t)carry3 << 25); + uint32_t carry7 = (uint32_t)((r7 + (1<<24)) >> 25); + r8 += carry7; r7 -= ((uint64_t)carry7 << 25); + uint32_t carry8 = (uint32_t)((r8 + (1<<24)) >> 25); + r9 += carry8; r8 -= ((uint64_t)carry8 << 25); + uint32_t carry9 = (uint32_t)((r9 + (1<<24)) >> 25); + r0 += (uint64_t)carry9 * 19; r9 -= ((uint64_t)carry9 << 25); + uint32_t carry0_2 = (uint32_t)((r0 + (1<<25)) >> 26); + r1 += carry0_2; r0 -= ((uint64_t)carry0_2 << 26); + + h[0] = (int32_t)(r0 & 0x3FFFFFF); + h[1] = (int32_t)(r1 & 0x1FFFFFF); + h[2] = (int32_t)(r2 & 0x3FFFFFF); + h[3] = (int32_t)(r3 & 0x1FFFFFF); + h[4] = (int32_t)(r4 & 0x3FFFFFF); + h[5] = (int32_t)(r5 & 0x1FFFFFF); + h[6] = (int32_t)(r6 & 0x3FFFFFF); + h[7] = (int32_t)(r7 & 0x1FFFFFF); + h[8] = (int32_t)(r8 & 0x3FFFFFF); + h[9] = (int32_t)(r9 & 0x1FFFFFF); +} +#else +/* Standard 64-bit fe_sq() */ static void fe_sq(fe h, const fe f) { int32_t f0=f[0],f1=f[1],f2=f[2],f3=f[3],f4=f[4],f5=f[5],f6=f[6],f7=f[7],f8=f[8],f9=f[9]; @@ -135,6 +335,7 @@ static void fe_sq(fe h, const fe f) h[0]=(int32_t)r0; h[1]=(int32_t)r1; h[2]=(int32_t)r2; h[3]=(int32_t)r3; h[4]=(int32_t)r4; h[5]=(int32_t)r5; h[6]=(int32_t)r6; h[7]=(int32_t)r7; h[8]=(int32_t)r8; h[9]=(int32_t)r9; } +#endif static void fe_inv(fe h, const fe f) { @@ -206,14 +407,9 @@ int se050_x25519_sw_generate_keypair(se050_x25519_sw_keypair_t *keypair, void *rng_ctx) { if (!keypair || !rng_func) return -1; - - if (rng_func(keypair->private_key, 32, rng_ctx) != 0) { - return -1; - } - + if (rng_func(keypair->private_key, 32, rng_ctx) != 0) return -1; se050_x25519_sw_clamp(keypair->private_key); x25519_sw(keypair->public_key, keypair->private_key, (const uint8_t*)"basepoint"); - return 0; } @@ -221,34 +417,24 @@ int se050_x25519_sw_compute_shared_secret(uint8_t *shared_secret, const uint8_t *private_key, const uint8_t *peer_public) { - if (!shared_secret || !private_key || !peer_public) { - return -1; - } - + if (!shared_secret || !private_key || !peer_public) return -1; uint8_t clamped[32]; memcpy(clamped, private_key, 32); se050_x25519_sw_clamp(clamped); - x25519_sw(shared_secret, clamped, peer_public); se050_x25519_sw_zeroize(clamped, 32); - return 0; } int se050_x25519_sw_derive_public_key(uint8_t *public_key, const uint8_t *private_key) { - if (!public_key || !private_key) { - return -1; - } - + if (!public_key || !private_key) return -1; uint8_t clamped[32]; memcpy(clamped, private_key, 32); se050_x25519_sw_clamp(clamped); - x25519_sw(public_key, clamped, (const uint8_t*)"basepoint"); se050_x25519_sw_zeroize(clamped, 32); - return 0; }