diff --git a/src/se050_x25519_sw.c b/src/se050_x25519_sw.c index a548bd9..3727b7a 100644 --- a/src/se050_x25519_sw.c +++ b/src/se050_x25519_sw.c @@ -1,7 +1,7 @@ /** * @file se050_x25519_sw.c - * @brief Software X25519 ECDH Implementation - * Based on RFC 7748 reference implementation + * @brief Software X25519 ECDH Implementation (Clean-room RFC7748) + * Based on RFC 7748 reference implementation with 5×51-bit limbs * License: MIT (Clean-room implementation) */ @@ -9,505 +9,444 @@ #include "se050_crypto_utils.h" #include -/* ESP32 detection */ -#if defined(ESP_PLATFORM) || defined(__XTENSA__) || defined(__riscv) -#define SE050_X25519_ESP32 1 -#else // Always use standard version -#define SE050_X25519_ESP32 0 -#endif - -typedef int32_t fe[10]; - -/* X25519 constants */ -#define CURVE_A 486662 -#define CURVE_A24 121665 /* (A - 2) / 4 */ - -static uint32_t load_3(const uint8_t *in) -{ return (uint32_t)in[0] | ((uint32_t)in[1] << 8) | ((uint32_t)in[2] << 16); } - -static uint64_t load_4(const uint8_t *in) -{ return (uint64_t)in[0] | ((uint64_t)in[1] << 8) | ((uint64_t)in[2] << 16) | ((uint64_t)in[3] << 24); } - -static void store_4(uint8_t *out, uint64_t in) -{ out[0] = (uint8_t)in; out[1] = (uint8_t)(in >> 8); out[2] = (uint8_t)(in >> 16); out[3] = (uint8_t)(in >> 24); } - -static void fe_0(fe h) { for (int i = 0; i < 10; i++) h[i] = 0; } -static void fe_1(fe h) { h[0] = 1; for (int i = 1; i < 10; i++) h[i] = 0; } - -static void fe_frombytes(fe h, const uint8_t *s) -{ - uint64_t h0 = load_4(s); - uint64_t h1 = load_3(s + 4) << 6; - uint64_t h2 = load_4(s + 7) >> 2; - uint64_t h3 = load_3(s + 11) << 5; - uint64_t h4 = load_3(s + 14) >> 1; - uint64_t h5 = load_4(s + 17) << 2; - uint64_t h6 = load_4(s + 21) >> 3; - uint64_t h7 = load_3(s + 24) << 6; - uint64_t h8 = load_3(s + 27) >> 1; - uint64_t h9 = (load_4(s + 30) & 0x7FFFFF) << 5; - h[0] = (int32_t)h0; h[1] = (int32_t)h1; h[2] = (int32_t)h2; h[3] = (int32_t)h3; - h[4] = (int32_t)h4; h[5] = (int32_t)h5; h[6] = (int32_t)h6; h[7] = (int32_t)h7; - h[8] = (int32_t)h8; h[9] = (int32_t)h9; -} - -static void fe_tobytes(uint8_t *s, const fe h) -{ - int32_t h0=h[0],h1=h[1],h2=h[2],h3=h[3],h4=h[4],h5=h[5],h6=h[6],h7=h[7],h8=h[8],h9=h[9]; - int32_t carry9=(h9+65536)>>16; h0+=carry9*19; h9-=carry9<<16; - int32_t carry1=(h1+65536)>>16; h2+=carry1; h1-=carry1<<16; - int32_t carry3=(h3+65536)>>16; h4+=carry3; h3-=carry3<<16; - int32_t carry5=(h5+65536)>>16; h6+=carry5; h5-=carry5<<25; - int32_t carry7=(h7+65536)>>16; h8+=carry7; h7-=carry7<<16; - int32_t carry0=(h0+65536)>>16; h1+=carry0; h0-=carry0<<16; - int32_t carry2=(h2+65536)>>16; h3+=carry2; h2-=carry2<<16; - int32_t carry4=(h4+65536)>>16; h5+=carry4; h4-=carry4<<16; - int32_t carry6=(h6+65536)>>16; h7+=carry6; h6-=carry6<<16; - store_4(s, h0); - store_4(s+4, h1); - store_4(s+8, h2); - store_4(s+12, h3); - store_4(s+16, h4); - store_4(s+20, h5); - store_4(s+24, h6); - s[28] = h7 & 0xff; - s[29] = (h7 >> 8) & 0xff; - s[30] = (h7 >> 16) & 0xff; - s[31] = ((h7 >> 24) | ((h8 & 0x0f) << 4)) & 0xff; -} - -/* Field operations from RFC 7748 ref10 implementation */ - -/* h = f + g */ -static void fe_add(fe h, const fe f, const fe g) -{ - for (int i = 0; i < 10; i++) { - h[i] = f[i] + g[i]; - } - /* Carry propagation */ - int32_t carry; - carry = (h[0] + 65536) >> 16; h[1] += carry; h[0] -= carry << 16; - carry = (h[2] + 65536) >> 16; h[3] += carry; h[2] -= carry << 16; - carry = (h[4] + 65536) >> 16; h[5] += carry; h[4] -= carry << 16; - carry = (h[6] + 65536) >> 16; h[7] += carry; h[6] -= carry << 16; - carry = (h[8] + 65536) >> 16; h[9] += carry; h[8] -= carry << 16; - carry = (h[1] + 65536) >> 16; h[2] += carry; h[1] -= carry << 16; - carry = (h[3] + 65536) >> 16; h[4] += carry; h[3] -= carry << 16; - carry = (h[5] + 65536) >> 16; h[6] += carry; h[5] -= carry << 16; - carry = (h[7] + 65536) >> 16; h[8] += carry; h[7] -= carry << 16; - carry = (h[9] + 65536) >> 16; h[0] += carry * 19; h[9] -= carry << 16; - carry = (h[0] + 65536) >> 16; h[1] += carry; h[0] -= carry << 16; -} - -/* h = f - g */ -static void fe_sub(fe h, const fe f, const fe g) -{ - for (int i = 0; i < 10; i++) { - h[i] = f[i] - g[i]; - } - /* Carry propagation */ - int32_t carry; - carry = (h[0] + 65536) >> 16; h[1] += carry; h[0] -= carry << 16; - carry = (h[2] + 65536) >> 16; h[3] += carry; h[2] -= carry << 16; - carry = (h[4] + 65536) >> 16; h[5] += carry; h[4] -= carry << 16; - carry = (h[6] + 65536) >> 16; h[7] += carry; h[6] -= carry << 16; - carry = (h[8] + 65536) >> 16; h[9] += carry; h[8] -= carry << 16; - carry = (h[1] + 65536) >> 16; h[2] += carry; h[1] -= carry << 16; - carry = (h[3] + 65536) >> 16; h[4] += carry; h[3] -= carry << 16; - carry = (h[5] + 65536) >> 16; h[6] += carry; h[5] -= carry << 16; - carry = (h[7] + 65536) >> 16; h[8] += carry; h[7] -= carry << 16; - carry = (h[9] + 65536) >> 16; h[0] += carry * 19; h[9] -= carry << 16; - carry = (h[0] + 65536) >> 16; h[1] += carry; h[0] -= carry << 16; - /* Normalize negative values */ - for (int i = 0; i < 10; i++) { - while (h[i] < 0) { - h[i] += (i & 1) ? 65536 : 1048576; - if (i < 9) h[i+1]--; - else h[0] -= 19; - } - } -} - -static void fe_copy(fe h, const fe f) -{ for (int i = 0; i < 10; i++) h[i] = f[i]; } - -static void fe_cswap(fe f, fe g, int b) -{ - int32_t mask = -b; - for (int i = 0; i < 10; i++) { - int32_t x = (f[i] ^ g[i]) & mask; - f[i] ^= x; g[i] ^= x; - } -} - -#if 0 -/* ============================================================================ - * ESP32 32-bit Optimized fe_mul() +/* ========================================================================= + * Field GF(2^255-19) * - * Avoids 64-bit arithmetic for better performance on 32-bit CPUs. - * Uses 32-bit intermediates with careful carry handling. - * ============================================================================ */ - -static void fe_mul(fe h, const fe f, const fe g) -{ - /* Use 32-bit arithmetic only */ - uint32_t f0 = (uint32_t)f[0], f1 = (uint32_t)f[1], f2 = (uint32_t)f[2]; - uint32_t f3 = (uint32_t)f[3], f4 = (uint32_t)f[4], f5 = (uint32_t)f[5]; - uint32_t f6 = (uint32_t)f[6], f7 = (uint32_t)f[7], f8 = (uint32_t)f[8]; - uint32_t f9 = (uint32_t)f[9]; - uint32_t g0 = (uint32_t)g[0], g1 = (uint32_t)g[1], g2 = (uint32_t)g[2]; - uint32_t g3 = (uint32_t)g[3], g4 = (uint32_t)g[4], g5 = (uint32_t)g[5]; - uint32_t g6 = (uint32_t)g[6], g7 = (uint32_t)g[7], g8 = (uint32_t)g[8]; - uint32_t g9 = (uint32_t)g[9]; - - /* Precompute g*19 (fits in 32-bit: 19 * 2^26 < 2^32) */ - uint32_t g1_19 = g1 * 19, g2_19 = g2 * 19, g3_19 = g3 * 19; - uint32_t g4_19 = g4 * 19, g5_19 = g5 * 19, g6_19 = g6 * 19; - uint32_t g7_19 = g7 * 19, g8_19 = g8 * 19, g9_19 = g9 * 19; - - /* Compute products using 64-bit temporarily, then split */ - /* r0 = f0*g0 + f1*g9_19 + f2*g8_19 + f3*g7_19 + f4*g6_19 + f5*g5_19 + f6*g4_19 + f7*g3_19 + f8*g2_19 + f9*g1_19 */ - uint64_t r0 = (uint64_t)f0*g0 + (uint64_t)f1*g9_19 + (uint64_t)f2*g8_19 + - (uint64_t)f3*g7_19 + (uint64_t)f4*g6_19 + (uint64_t)f5*g5_19 + - (uint64_t)f6*g4_19 + (uint64_t)f7*g3_19 + (uint64_t)f8*g2_19 + - (uint64_t)f9*g1_19; - uint64_t r1 = (uint64_t)f0*g1 + (uint64_t)f1*g0 + (uint64_t)f2*g9_19 + - (uint64_t)f3*g8_19 + (uint64_t)f4*g7_19 + (uint64_t)f5*g6_19 + - (uint64_t)f6*g5_19 + (uint64_t)f7*g4_19 + (uint64_t)f8*g3_19 + - (uint64_t)f9*g2_19; - uint64_t r2 = (uint64_t)f0*g2 + (uint64_t)f1*g1*2 + (uint64_t)f2*g0 + - (uint64_t)f3*g9_19 + (uint64_t)f4*g8_19 + (uint64_t)f5*g7_19 + - (uint64_t)f6*g6_19 + (uint64_t)f7*g5_19 + (uint64_t)f8*g4_19 + - (uint64_t)f9*g3_19; - uint64_t r3 = (uint64_t)f0*g3 + (uint64_t)f1*g2 + (uint64_t)f2*g1 + - (uint64_t)f3*g0 + (uint64_t)f4*g9_19 + (uint64_t)f5*g8_19 + - (uint64_t)f6*g7_19 + (uint64_t)f7*g6_19 + (uint64_t)f8*g5_19 + - (uint64_t)f9*g4_19; - uint64_t r4 = (uint64_t)f0*g4 + (uint64_t)f1*g3*2 + (uint64_t)f2*g2 + - (uint64_t)f3*g1*2 + (uint64_t)f4*g0 + (uint64_t)f5*g9_19 + - (uint64_t)f6*g8_19 + (uint64_t)f7*g7_19 + (uint64_t)f8*g6_19 + - (uint64_t)f9*g5_19; - uint64_t r5 = (uint64_t)f0*g5 + (uint64_t)f1*g4 + (uint64_t)f2*g3 + - (uint64_t)f3*g2 + (uint64_t)f4*g1 + (uint64_t)f5*g0 + - (uint64_t)f6*g9_19 + (uint64_t)f7*g8_19 + (uint64_t)f8*g7_19 + - (uint64_t)f9*g6_19; - uint64_t r6 = (uint64_t)f0*g6 + (uint64_t)f1*g5*2 + (uint64_t)f2*g4 + - (uint64_t)f3*g3*2 + (uint64_t)f4*g2 + (uint64_t)f5*g1*2 + - (uint64_t)f6*g0 + (uint64_t)f7*g9_19 + (uint64_t)f8*g8_19 + - (uint64_t)f9*g7_19; - uint64_t r7 = (uint64_t)f0*g7 + (uint64_t)f1*g6 + (uint64_t)f2*g5 + - (uint64_t)f3*g4 + (uint64_t)f4*g3 + (uint64_t)f5*g2 + - (uint64_t)f6*g1 + (uint64_t)f7*g0 + (uint64_t)f8*g9_19 + - (uint64_t)f9*g8_19; - uint64_t r8 = (uint64_t)f0*g8 + (uint64_t)f1*g7*2 + (uint64_t)f2*g6 + - (uint64_t)f3*g5*2 + (uint64_t)f4*g4 + (uint64_t)f5*g3*2 + - (uint64_t)f6*g2 + (uint64_t)f7*g1*2 + (uint64_t)f8*g0 + - (uint64_t)f9*g9_19; - uint64_t r9 = (uint64_t)f0*g9 + (uint64_t)f1*g8 + (uint64_t)f2*g7 + - (uint64_t)f3*g6 + (uint64_t)f4*g5 + (uint64_t)f5*g4 + - (uint64_t)f6*g3 + (uint64_t)f7*g2 + (uint64_t)f8*g1 + - (uint64_t)f9*g0; - - /* Propagate carries (same as 64-bit version) */ - uint32_t carry0 = (uint32_t)((r0 + (1<<25)) >> 26); - r1 += carry0; r0 -= ((uint64_t)carry0 << 26); - - uint32_t carry4 = (uint32_t)((r4 + (1<<25)) >> 26); - r5 += carry4; r4 -= ((uint64_t)carry4 << 26); - - uint32_t carry1 = (uint32_t)((r1 + (1<<24)) >> 25); - r2 += carry1; r1 -= ((uint64_t)carry1 << 25); - - uint32_t carry5 = (uint32_t)((r5 + (1<<24)) >> 25); - r6 += carry5; r5 -= ((uint64_t)carry5 << 25); - - uint32_t carry2 = (uint32_t)((r2 + (1<<25)) >> 26); - r3 += carry2; r2 -= ((uint64_t)carry2 << 26); - - uint32_t carry6 = (uint32_t)((r6 + (1<<25)) >> 26); - r7 += carry6; r6 -= ((uint64_t)carry6 << 26); - - uint32_t carry3 = (uint32_t)((r3 + (1<<24)) >> 25); - r4 += carry3; r3 -= ((uint64_t)carry3 << 25); - - uint32_t carry7 = (uint32_t)((r7 + (1<<24)) >> 25); - r8 += carry7; r7 -= ((uint64_t)carry7 << 25); - - uint32_t carry8 = (uint32_t)((r8 + (1<<24)) >> 25); - r9 += carry8; r8 -= ((uint64_t)carry8 << 25); - - uint32_t carry9 = (uint32_t)((r9 + (1<<24)) >> 25); - r0 += (uint64_t)carry9 * 19; r9 -= ((uint64_t)carry9 << 25); - - uint32_t carry0_2 = (uint32_t)((r0 + (1<<25)) >> 26); - r1 += carry0_2; r0 -= ((uint64_t)carry0_2 << 26); - - h[0] = (int32_t)(r0 & 0x3FFFFFF); - h[1] = (int32_t)(r1 & 0x1FFFFFF); - h[2] = (int32_t)(r2 & 0x3FFFFFF); - h[3] = (int32_t)(r3 & 0x1FFFFFF); - h[4] = (int32_t)(r4 & 0x3FFFFFF); - h[5] = (int32_t)(r5 & 0x1FFFFFF); - h[6] = (int32_t)(r6 & 0x3FFFFFF); - h[7] = (int32_t)(r7 & 0x1FFFFFF); - h[8] = (int32_t)(r8 & 0x3FFFFFF); - h[9] = (int32_t)(r9 & 0x1FFFFFF); -} - -#else // Always use standard version -/* ============================================================================ - * Standard 64-bit fe_mul() + * We represent field elements as arrays of 5 uint64_t limbs in radix 2^51. + * Each limb holds at most 51 bits in "loose" form. * - * Uses 64-bit arithmetic for cleaner code on 64-bit platforms. - * ============================================================================ */ + * value = limb[0] + limb[1] * 2^51 + limb[2] * 2^102 + limb[3] * 2^153 + limb[4] * 2^204 + * + * p = 2^255 - 19, so 2^255 ≡ 19 (mod p) + * ========================================================================= */ -static void fe_mul(fe h, const fe f, const fe g) -{ - int32_t f0=f[0],f1=f[1],f2=f[2],f3=f[3],f4=f[4],f5=f[5],f6=f[6],f7=f[7],f8=f[8],f9=f[9]; - int32_t g0=g[0],g1=g[1],g2=g[2],g3=g[3],g4=g[4],g5=g[5],g6=g[6],g7=g[7],g8=g[8],g9=g[9]; - int64_t g1_19=g1*19, g2_19=g2*19, g3_19=g3*19, g4_19=g4*19, g5_19=g5*19; - int64_t g6_19=g6*19, g7_19=g7*19, g8_19=g8*19, g9_19=g9*19; - int64_t r0=(int64_t)f0*g0+(int64_t)f1*g9_19+(int64_t)f2*g8_19+(int64_t)f3*g7_19+(int64_t)f4*g6_19+(int64_t)f5*g5_19+(int64_t)f6*g4_19+(int64_t)f7*g3_19+(int64_t)f8*g2_19+(int64_t)f9*g1_19; - int64_t r1=(int64_t)f0*g1+(int64_t)f1*g0+(int64_t)f2*g9_19+(int64_t)f3*g8_19+(int64_t)f4*g7_19+(int64_t)f5*g6_19+(int64_t)f6*g5_19+(int64_t)f7*g4_19+(int64_t)f8*g3_19+(int64_t)f9*g2_19; - int64_t r2=(int64_t)f0*g2+(int64_t)f1*g1*2+(int64_t)f2*g0+(int64_t)f3*g9_19+(int64_t)f4*g8_19+(int64_t)f5*g7_19+(int64_t)f6*g6_19+(int64_t)f7*g5_19+(int64_t)f8*g4_19+(int64_t)f9*g3_19; - int64_t r3=(int64_t)f0*g3+(int64_t)f1*g2+(int64_t)f2*g1+(int64_t)f3*g0+(int64_t)f4*g9_19+(int64_t)f5*g8_19+(int64_t)f6*g7_19+(int64_t)f7*g6_19+(int64_t)f8*g5_19+(int64_t)f9*g4_19; - int64_t r4=(int64_t)f0*g4+(int64_t)f1*g3*2+(int64_t)f2*g2+(int64_t)f3*g1*2+(int64_t)f4*g0+(int64_t)f5*g9_19+(int64_t)f6*g8_19+(int64_t)f7*g7_19+(int64_t)f8*g6_19+(int64_t)f9*g5_19; - int64_t r5=(int64_t)f0*g5+(int64_t)f1*g4+(int64_t)f2*g3+(int64_t)f3*g2+(int64_t)f4*g1+(int64_t)f5*g0+(int64_t)f6*g9_19+(int64_t)f7*g8_19+(int64_t)f8*g7_19+(int64_t)f9*g6_19; - int64_t r6=(int64_t)f0*g6+(int64_t)f1*g5*2+(int64_t)f2*g4+(int64_t)f3*g3*2+(int64_t)f4*g2+(int64_t)f5*g1*2+(int64_t)f6*g0+(int64_t)f7*g9_19+(int64_t)f8*g8_19+(int64_t)f9*g7_19; - int64_t r7=(int64_t)f0*g7+(int64_t)f1*g6+(int64_t)f2*g5+(int64_t)f3*g4+(int64_t)f4*g3+(int64_t)f5*g2+(int64_t)f6*g1+(int64_t)f7*g0+(int64_t)f8*g9_19+(int64_t)f9*g8_19; - int64_t r8=(int64_t)f0*g8+(int64_t)f1*g7*2+(int64_t)f2*g6+(int64_t)f3*g5*2+(int64_t)f4*g4+(int64_t)f5*g3*2+(int64_t)f6*g2+(int64_t)f7*g1*2+(int64_t)f8*g0+(int64_t)f9*g9_19; - int64_t r9=(int64_t)f0*g9+(int64_t)f1*g8+(int64_t)f2*g7+(int64_t)f3*g6+(int64_t)f4*g5+(int64_t)f5*g4+(int64_t)f6*g3+(int64_t)f7*g2+(int64_t)f8*g1+(int64_t)f9*g0; - int64_t carry0=(r0+(1<<25))>>26; r1+=carry0; r0-=carry0<<26; - int64_t carry4=(r4+(1<<25))>>26; r5+=carry4; r4-=carry4<<26; - int64_t carry1=(r1+(1<<24))>>25; r2+=carry1; r1-=carry1<<25; - int64_t carry5=(r5+(1<<24))>>25; r6+=carry5; r5-=carry5<<25; - int64_t carry2=(r2+(1<<25))>>26; r3+=carry2; r2-=carry2<<26; - int64_t carry6=(r6+(1<<25))>>26; r7+=carry6; r6-=carry6<<26; - int64_t carry3=(r3+(1<<24))>>25; r4+=carry3; r3-=carry3<<25; - int64_t carry7=(r7+(1<<24))>>25; r8+=carry7; r7-=carry7<<25; - int64_t carry8=(r8+(1<<24))>>25; r9+=carry8; r8-=carry8<<25; - int64_t carry9=(r9+(1<<24))>>25; r0+=carry9*19; r9-=carry9<<25; - int64_t carry0_2=(r0+(1<<25))>>26; r1+=carry0_2; r0-=carry0_2<<26; - h[0]=(int32_t)r0; h[1]=(int32_t)r1; h[2]=(int32_t)r2; h[3]=(int32_t)r3; h[4]=(int32_t)r4; - h[5]=(int32_t)r5; h[6]=(int32_t)r6; h[7]=(int32_t)r7; h[8]=(int32_t)r8; h[9]=(int32_t)r9; -} -#endif +#define NLIMBS 5 +typedef uint64_t fe[NLIMBS]; /* field element */ -#if 0 -/* ESP32 32-bit optimized fe_sq() */ -static void fe_sq(fe h, const fe f) -{ - uint32_t f0 = (uint32_t)f[0], f1 = (uint32_t)f[1], f2 = (uint32_t)f[2]; - uint32_t f3 = (uint32_t)f[3], f4 = (uint32_t)f[4], f5 = (uint32_t)f[5]; - uint32_t f6 = (uint32_t)f[6], f7 = (uint32_t)f[7], f8 = (uint32_t)f[8], f9 = (uint32_t)f[9]; - - uint32_t f0_2 = f0 * 2, f1_2 = f1 * 2, f2_2 = f2 * 2; - uint32_t f3_2 = f3 * 2, f4_2 = f4 * 2, f5_2 = f5 * 2; - uint32_t f6_2 = f6 * 2, f7_2 = f7 * 2, f8_2 = f8 * 2; - uint32_t f9_19 = f9 * 19, f8_19 = f8 * 19, f7_19 = f7 * 19; - uint32_t f6_19 = f6 * 19, f5_19 = f5 * 19, f4_19 = f4 * 19, f3_19 = f3 * 19; - - uint64_t r0 = (uint64_t)f0*f0 + (uint64_t)f1_2*f9_19 + (uint64_t)f2_2*f8_19 + - (uint64_t)f3_2*f7_19 + (uint64_t)f4_2*f6_19 + (uint64_t)f5*f5_19; - uint64_t r1 = (uint64_t)f0_2*f1 + (uint64_t)f2_2*f9_19 + (uint64_t)f3_2*f8_19 + - (uint64_t)f4_2*f7_19 + (uint64_t)f5_2*f6_19; - uint64_t r2 = (uint64_t)f0_2*f2 + (uint64_t)f1*f1 + (uint64_t)f3_2*f9_19 + - (uint64_t)f4_2*f8_19 + (uint64_t)f5_2*f7_19 + (uint64_t)f6*f6_19; - uint64_t r3 = (uint64_t)f0_2*f3 + (uint64_t)f1_2*f2 + (uint64_t)f4_2*f9_19 + - (uint64_t)f5_2*f8_19 + (uint64_t)f6_2*f7_19; - uint64_t r4 = (uint64_t)f0_2*f4 + (uint64_t)f1_2*f3 + (uint64_t)f2*f2 + - (uint64_t)f5_2*f9_19 + (uint64_t)f6_2*f8_19 + (uint64_t)f7*f7_19; - uint64_t r5 = (uint64_t)f0_2*f5 + (uint64_t)f1_2*f4 + (uint64_t)f2_2*f3 + - (uint64_t)f6_2*f9_19 + (uint64_t)f7_2*f8_19; - uint64_t r6 = (uint64_t)f0_2*f6 + (uint64_t)f1_2*f5 + (uint64_t)f2_2*f4 + - (uint64_t)f3*f3 + (uint64_t)f7_2*f9_19 + (uint64_t)f8*f8_19; - uint64_t r7 = (uint64_t)f0_2*f7 + (uint64_t)f1_2*f6 + (uint64_t)f2_2*f5 + - (uint64_t)f3_2*f4 + (uint64_t)f8_2*f9_19; - uint64_t r8 = (uint64_t)f0_2*f8 + (uint64_t)f1_2*f7 + (uint64_t)f2_2*f6 + - (uint64_t)f3_2*f5 + (uint64_t)f4*f4 + (uint64_t)f9*f9_19; - uint64_t r9 = (uint64_t)f0_2*f9 + (uint64_t)f1_2*f8 + (uint64_t)f2_2*f7 + - (uint64_t)f3_2*f6 + (uint64_t)f4_2*f5; - - uint32_t carry0 = (uint32_t)((r0 + (1<<25)) >> 26); - r1 += carry0; r0 -= ((uint64_t)carry0 << 26); - uint32_t carry4 = (uint32_t)((r4 + (1<<25)) >> 26); - r5 += carry4; r4 -= ((uint64_t)carry4 << 26); - uint32_t carry1 = (uint32_t)((r1 + (1<<24)) >> 25); - r2 += carry1; r1 -= ((uint64_t)carry1 << 25); - uint32_t carry5 = (uint32_t)((r5 + (1<<24)) >> 25); - r6 += carry5; r5 -= ((uint64_t)carry5 << 25); - uint32_t carry2 = (uint32_t)((r2 + (1<<25)) >> 26); - r3 += carry2; r2 -= ((uint64_t)carry2 << 26); - uint32_t carry6 = (uint32_t)((r6 + (1<<25)) >> 26); - r7 += carry6; r6 -= ((uint64_t)carry6 << 26); - uint32_t carry3 = (uint32_t)((r3 + (1<<24)) >> 25); - r4 += carry3; r3 -= ((uint64_t)carry3 << 25); - uint32_t carry7 = (uint32_t)((r7 + (1<<24)) >> 25); - r8 += carry7; r7 -= ((uint64_t)carry7 << 25); - uint32_t carry8 = (uint32_t)((r8 + (1<<24)) >> 25); - r9 += carry8; r8 -= ((uint64_t)carry8 << 25); - uint32_t carry9 = (uint32_t)((r9 + (1<<24)) >> 25); - r0 += (uint64_t)carry9 * 19; r9 -= ((uint64_t)carry9 << 25); - uint32_t carry0_2 = (uint32_t)((r0 + (1<<25)) >> 26); - r1 += carry0_2; r0 -= ((uint64_t)carry0_2 << 26); - - h[0] = (int32_t)(r0 & 0x3FFFFFF); - h[1] = (int32_t)(r1 & 0x1FFFFFF); - h[2] = (int32_t)(r2 & 0x3FFFFFF); - h[3] = (int32_t)(r3 & 0x1FFFFFF); - h[4] = (int32_t)(r4 & 0x3FFFFFF); - h[5] = (int32_t)(r5 & 0x1FFFFFF); - h[6] = (int32_t)(r6 & 0x3FFFFFF); - h[7] = (int32_t)(r7 & 0x1FFFFFF); - h[8] = (int32_t)(r8 & 0x3FFFFFF); - h[9] = (int32_t)(r9 & 0x1FFFFFF); -} -#else // Always use standard version -/* Standard 64-bit fe_sq() */ -static void fe_sq(fe h, const fe f) -{ - int32_t f0=f[0],f1=f[1],f2=f[2],f3=f[3],f4=f[4],f5=f[5],f6=f[6],f7=f[7],f8=f[8],f9=f[9]; - int32_t f0_2=f0*2,f1_2=f1*2,f2_2=f2*2,f3_2=f3*2,f4_2=f4*2,f5_2=f5*2,f6_2=f6*2,f7_2=f7*2,f8_2=f8*2; - int32_t f9_19=f9*19,f8_19=f8*19,f7_19=f7*19,f6_19=f6*19,f5_19=f5*19,f4_19=f4*19,f3_19=f3*19; - int64_t r0=(int64_t)f0*f0+(int64_t)f1_2*f9_19+(int64_t)f2_2*f8_19+(int64_t)f3_2*f7_19+(int64_t)f4_2*f6_19+(int64_t)f5*f5_19; - int64_t r1=(int64_t)f0_2*f1+(int64_t)f2_2*f9_19+(int64_t)f3_2*f8_19+(int64_t)f4_2*f7_19+(int64_t)f5_2*f6_19; - int64_t r2=(int64_t)f0_2*f2+(int64_t)f1*f1+(int64_t)f3_2*f9_19+(int64_t)f4_2*f8_19+(int64_t)f5_2*f7_19+(int64_t)f6*f6_19; - int64_t r3=(int64_t)f0_2*f3+(int64_t)f1_2*f2+(int64_t)f4_2*f9_19+(int64_t)f5_2*f8_19+(int64_t)f6_2*f7_19; - int64_t r4=(int64_t)f0_2*f4+(int64_t)f1_2*f3+(int64_t)f2*f2+(int64_t)f5_2*f9_19+(int64_t)f6_2*f8_19+(int64_t)f7*f7_19; - int64_t r5=(int64_t)f0_2*f5+(int64_t)f1_2*f4+(int64_t)f2_2*f3+(int64_t)f6_2*f9_19+(int64_t)f7_2*f8_19; - int64_t r6=(int64_t)f0_2*f6+(int64_t)f1_2*f5+(int64_t)f2_2*f4+(int64_t)f3*f3+(int64_t)f7_2*f9_19+(int64_t)f8*f8_19; - int64_t r7=(int64_t)f0_2*f7+(int64_t)f1_2*f6+(int64_t)f2_2*f5+(int64_t)f3_2*f4+(int64_t)f8_2*f9_19; - int64_t r8=(int64_t)f0_2*f8+(int64_t)f1_2*f7+(int64_t)f2_2*f6+(int64_t)f3_2*f5+(int64_t)f4*f4+(int64_t)f9*f9_19; - int64_t r9=(int64_t)f0_2*f9+(int64_t)f1_2*f8+(int64_t)f2_2*f7+(int64_t)f3_2*f6+(int64_t)f4_2*f5; - int64_t carry0=(r0+(1<<25))>>26; r1+=carry0; r0-=carry0<<26; - int64_t carry4=(r4+(1<<25))>>26; r5+=carry4; r4-=carry4<<26; - int64_t carry1=(r1+(1<<24))>>25; r2+=carry1; r1-=carry1<<25; - int64_t carry5=(r5+(1<<24))>>25; r6+=carry5; r5-=carry5<<25; - int64_t carry2=(r2+(1<<25))>>26; r3+=carry2; r2-=carry2<<26; - int64_t carry6=(r6+(1<<25))>>26; r7+=carry6; r6-=carry6<<26; - int64_t carry3=(r3+(1<<24))>>25; r4+=carry3; r3-=carry3<<25; - int64_t carry7=(r7+(1<<24))>>25; r8+=carry7; r7-=carry7<<25; - int64_t carry8=(r8+(1<<24))>>25; r9+=carry8; r8-=carry8<<25; - int64_t carry9=(r9+(1<<24))>>25; r0+=carry9*19; r9-=carry9<<25; - int64_t carry0_2=(r0+(1<<25))>>26; r1+=carry0_2; r0-=carry0_2<<26; - h[0]=(int32_t)r0; h[1]=(int32_t)r1; h[2]=(int32_t)r2; h[3]=(int32_t)r3; h[4]=(int32_t)r4; - h[5]=(int32_t)r5; h[6]=(int32_t)r6; h[7]=(int32_t)r7; h[8]=(int32_t)r8; h[9]=(int32_t)r9; -} -#endif +#define L51 ((uint64_t)1 << 51) +#define MASK51 (L51 - 1) -static void fe_inv(fe h, const fe f) +/* 128-bit helpers */ +static inline uint64_t u128_lo(unsigned __int128 x) { return (uint64_t)x; } +static inline uint64_t u128_hi(unsigned __int128 x) { return (uint64_t)(x >> 64); } + +/* --- Basic operations --- */ + +static void fe_zero(fe f) { f[0] = f[1] = f[2] = f[3] = f[4] = 0; } +static void fe_one(fe f) { f[0] = 1; f[1] = f[2] = f[3] = f[4] = 0; } + +static void fe_copy(fe out, const fe in) { - fe t0, t1; - fe_sq(t0, f); fe_sq(t1, t0); fe_sq(t1, t1); fe_mul(t1, f, t1); - fe_mul(t0, t0, t1); fe_sq(t0, t0); fe_mul(t0, t1, t0); fe_sq(t1, t0); - for (int i = 0; i < 4; i++) fe_sq(t1, t1); fe_mul(t0, t1, t0); - fe_sq(t1, t0); for (int i = 0; i < 9; i++) fe_sq(t1, t1); - fe_mul(t1, t1, t0); fe_sq(t1, t1); for (int i = 0; i < 19; i++) fe_sq(t1, t1); - fe_mul(t1, t1, t0); fe_sq(t1, t1); for (int i = 0; i < 9; i++) fe_sq(t1, t1); - fe_mul(t0, t1, t0); fe_sq(t0, t0); for (int i = 0; i < 49; i++) fe_sq(t0, t0); - fe_mul(t0, t0, t1); fe_sq(t0, t0); for (int i = 0; i < 9; i++) fe_sq(t0, t0); - fe_mul(t1, t0, t1); fe_sq(t1, t1); for (int i = 0; i < 99; i++) fe_sq(t1, t1); - fe_mul(t1, t1, t0); fe_sq(t1, t1); for (int i = 0; i < 49; i++) fe_sq(t1, t1); - fe_mul(t0, t1, t0); fe_sq(t0, t0); for (int i = 0; i < 9; i++) fe_sq(t0, t0); - fe_mul(t0, t0, t1); fe_sq(t0, t0); for (int i = 0; i < 4; i++) fe_sq(t0, t0); - fe_mul(h, t0, t1); + out[0] = in[0]; out[1] = in[1]; out[2] = in[2]; out[3] = in[3]; out[4] = in[4]; } -static void x25519_sw(uint8_t *out, const uint8_t *scalar, const uint8_t *point) +/* fe_add: out = a + b (loose, ≤ 2·2^51) */ +static void fe_add(fe out, const fe a, const fe b) { - fe x2, z2, x3, z3, a, aa, b, bb, e, c, d, da, cb, u_coord; - uint8_t e_arr[32]; - - /* Python: scalar = scalar_clamp(scalar) */ - memcpy(e_arr, scalar, 32); - e_arr[0] &= 248; e_arr[31] &= 127; e_arr[31] |= 64; - - /* Python: u = bytes_to_int(u_bytes) */ - fe_frombytes(u_coord, point); - - /* Python: x2, z2 = 1, 0 */ - fe_1(x2); fe_0(z2); - - /* Python: x3, z3 = u, 1 */ - fe_copy(x3, u_coord); fe_1(z3); - - /* Python: for i in range(254, -1, -1): */ - for (int i = 254; i >= 0; i--) { - /* Python: bit = (scalar[i // 8] >> (i % 8)) & 1 */ - int bit = (e_arr[i/8] >> (i&7)) & 1; - - /* Python: swap = bit; if swap: x2, x3 = x3, x2; z2, z3 = z3, z2 */ - if (bit) { - fe_cswap(x2, x3, 1); - fe_cswap(z2, z3, 1); - } - - /* Python: a = (x2 + z2) % P */ - fe_add(a, x2, z2); - - /* Python: aa = (a * a) % P */ - fe_sq(aa, a); - - /* Python: b = (x2 - z2) % P */ - fe_sub(b, x2, z2); - - /* Python: bb = (b * b) % P */ - fe_sq(bb, b); - - /* Python: e = (aa - bb) % P */ - fe_sub(e, aa, bb); - - /* Python: c = (x3 + z3) % P */ - fe_add(c, x3, z3); - - /* Python: d = (x3 - z3) % P */ - fe_sub(d, x3, z3); - - /* Python: da = (d * a) % P */ - fe_mul(da, d, a); - - /* Python: cb = (c * b) % P */ - fe_mul(cb, c, b); - - /* Python: x3 = ((da + cb) * (da + cb)) % P */ - fe_add(a, da, cb); - fe_sq(x3, a); - - /* Python: z3 = (u * ((da - cb) * (da - cb))) % P */ - fe_sub(a, da, cb); - fe_sq(z3, a); - fe_mul(z3, z3, u_coord); - - /* Python: x2 = (aa * bb) % P */ - fe_mul(x2, aa, bb); - - /* Python: z2 = (e * (aa + ((A - 2) // 4) * e)) % P */ - /* (A - 2) // 4 = 121665 */ - /* For proper A24 multiplication, we need to multiply field element e by scalar 121665 */ - /* Simplified: treat A24 as field element */ - fe a24 = {121665, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - fe_mul(a, a24, e); /* a = A24 * e */ - fe_add(aa, aa, a); /* aa = aa + A24*e */ - fe_mul(z2, e, aa); /* z2 = e * (aa + A24*e) */ - - /* Python: if swap: x2, x3 = x3, x2; z2, z3 = z3, z2 */ - if (bit) { - fe_cswap(x2, x3, 1); - fe_cswap(z2, z3, 1); - } + out[0] = a[0] + b[0]; + out[1] = a[1] + b[1]; + out[2] = a[2] + b[2]; + out[3] = a[3] + b[3]; + out[4] = a[4] + b[4]; +} + +/* fe_sub: out = a - b (loose, uses bias to avoid underflow) */ +static void fe_sub(fe out, const fe a, const fe b) +{ + out[0] = a[0] + 2*(L51 - 19) - b[0]; + out[1] = a[1] + 2*(L51 - 1) - b[1]; + out[2] = a[2] + 2*(L51 - 1) - b[2]; + out[3] = a[3] + 2*(L51 - 1) - b[3]; + out[4] = a[4] + 2*(L51 - 1) - b[4]; +} + +/* fe_reduce: propagate carries, keep limbs < 2^51 */ +static void fe_reduce(fe f) +{ + uint64_t c; + c = f[0] >> 51; f[0] &= MASK51; f[1] += c; + c = f[1] >> 51; f[1] &= MASK51; f[2] += c; + c = f[2] >> 51; f[2] &= MASK51; f[3] += c; + c = f[3] >> 51; f[3] &= MASK51; f[4] += c; + c = f[4] >> 51; f[4] &= MASK51; f[0] += 19 * c; + c = f[0] >> 51; f[0] &= MASK51; f[1] += c; +} + +/* --- Multiplication --- */ + +/* fe_mul: out = a * b mod p (128-bit accumulators) */ +static void fe_mul(fe out, const fe a, const fe b) +{ + unsigned __int128 t0, t1, t2, t3, t4; + uint64_t c; + uint64_t b1_19 = 19 * b[1], b2_19 = 19 * b[2], b3_19 = 19 * b[3], b4_19 = 19 * b[4]; + + t0 = (unsigned __int128)a[0] * b[0]; + t0 += (unsigned __int128)a[1] * b4_19; + t0 += (unsigned __int128)a[2] * b3_19; + t0 += (unsigned __int128)a[3] * b2_19; + t0 += (unsigned __int128)a[4] * b1_19; + + t1 = (unsigned __int128)a[0] * b[1]; + t1 += (unsigned __int128)a[1] * b[0]; + t1 += (unsigned __int128)a[2] * b4_19; + t1 += (unsigned __int128)a[3] * b3_19; + t1 += (unsigned __int128)a[4] * b2_19; + + t2 = (unsigned __int128)a[0] * b[2]; + t2 += (unsigned __int128)a[1] * b[1]; + t2 += (unsigned __int128)a[2] * b[0]; + t2 += (unsigned __int128)a[3] * b4_19; + t2 += (unsigned __int128)a[4] * b3_19; + + t3 = (unsigned __int128)a[0] * b[3]; + t3 += (unsigned __int128)a[1] * b[2]; + t3 += (unsigned __int128)a[2] * b[1]; + t3 += (unsigned __int128)a[3] * b[0]; + t3 += (unsigned __int128)a[4] * b4_19; + + t4 = (unsigned __int128)a[0] * b[4]; + t4 += (unsigned __int128)a[1] * b[3]; + t4 += (unsigned __int128)a[2] * b[2]; + t4 += (unsigned __int128)a[3] * b[1]; + t4 += (unsigned __int128)a[4] * b[0]; + + out[0] = u128_lo(t0) & MASK51; c = u128_lo(t0) >> 51 | u128_hi(t0) << 13; t1 += c; + out[1] = u128_lo(t1) & MASK51; c = u128_lo(t1) >> 51 | u128_hi(t1) << 13; t2 += c; + out[2] = u128_lo(t2) & MASK51; c = u128_lo(t2) >> 51 | u128_hi(t2) << 13; t3 += c; + out[3] = u128_lo(t3) & MASK51; c = u128_lo(t3) >> 51 | u128_hi(t3) << 13; t4 += c; + out[4] = u128_lo(t4) & MASK51; c = u128_lo(t4) >> 51 | u128_hi(t4) << 13; + out[0] += 19 * c; + + c = out[0] >> 51; out[0] &= MASK51; out[1] += c; +} + +/* fe_sq: out = a^2 mod p (optimized) */ +static void fe_sq(fe out, const fe a) +{ + unsigned __int128 t0, t1, t2, t3, t4; + uint64_t c; + uint64_t d1 = 2 * a[1], d2 = 2 * a[2], d3 = 2 * a[3]; + uint64_t a4_19 = 19 * a[4], d1_19 = 19 * d1, d2_19 = 19 * d2, a3_19 = 19 * a[3]; + + t0 = (unsigned __int128)a[0] * a[0]; + t0 += (unsigned __int128)d1_19 * a[4]; + t0 += (unsigned __int128)d2_19 * a[3]; + + t1 = (unsigned __int128)a[0] * d1; + t1 += (unsigned __int128)d2_19 * a[4]; + t1 += (unsigned __int128)a3_19 * a[3]; + + t2 = (unsigned __int128)a[0] * d2; + t2 += (unsigned __int128)a[1] * a[1]; + t2 += (unsigned __int128)d3 * a4_19; + + t3 = (unsigned __int128)a[0] * d3; + t3 += (unsigned __int128)d1 * a[2]; + t3 += (unsigned __int128)a[4] * a4_19; + + t4 = (unsigned __int128)a[0] * (2 * a[4]); + t4 += (unsigned __int128)d1 * a[3]; + t4 += (unsigned __int128)a[2] * a[2]; + + out[0] = u128_lo(t0) & MASK51; c = u128_lo(t0) >> 51 | u128_hi(t0) << 13; t1 += c; + out[1] = u128_lo(t1) & MASK51; c = u128_lo(t1) >> 51 | u128_hi(t1) << 13; t2 += c; + out[2] = u128_lo(t2) & MASK51; c = u128_lo(t2) >> 51 | u128_hi(t2) << 13; t3 += c; + out[3] = u128_lo(t3) & MASK51; c = u128_lo(t3) >> 51 | u128_hi(t3) << 13; t4 += c; + out[4] = u128_lo(t4) & MASK51; c = u128_lo(t4) >> 51 | u128_hi(t4) << 13; + out[0] += 19 * c; + + /* Final carry from limb 0 */ + c = out[0] >> 51; out[0] &= MASK51; out[1] += c; +} + +/* fe_mul_small: out = f * n (n < 2^22) */ +static void fe_mul_small(fe out, const fe f, uint64_t n) +{ + unsigned __int128 t0, t1, t2, t3, t4; + uint64_t c; + t0 = (unsigned __int128)f[0] * n; + t1 = (unsigned __int128)f[1] * n; + t2 = (unsigned __int128)f[2] * n; + t3 = (unsigned __int128)f[3] * n; + t4 = (unsigned __int128)f[4] * n; + out[0] = u128_lo(t0) & MASK51; c = u128_lo(t0) >> 51 | u128_hi(t0) << 13; t1 += c; + out[1] = u128_lo(t1) & MASK51; c = u128_lo(t1) >> 51 | u128_hi(t1) << 13; t2 += c; + out[2] = u128_lo(t2) & MASK51; c = u128_lo(t2) >> 51 | u128_hi(t2) << 13; t3 += c; + out[3] = u128_lo(t3) & MASK51; c = u128_lo(t3) >> 51 | u128_hi(t3) << 13; t4 += c; + out[4] = u128_lo(t4) & MASK51; c = u128_lo(t4) >> 51 | u128_hi(t4) << 13; + out[0] += 19 * c; + c = out[0] >> 51; out[0] &= MASK51; out[1] += c; +} + +/* --- Inversion --- */ + +/* fe_invert: out = a^(-1) = a^(p-2) using addition chain */ +static void fe_invert(fe out, const fe a) +{ + fe t0, t1, t2, t3; + int i; + + fe_sq(t0, a); /* t0 = a^2 */ + fe_sq(t1, t0); /* t1 = a^4 */ + fe_sq(t1, t1); /* t1 = a^8 */ + fe_mul(t1, t1, a); /* t1 = a^9 */ + fe_mul(t0, t0, t1); /* t0 = a^11 */ + fe_sq(t2, t0); /* t2 = a^22 */ + fe_mul(t1, t1, t2); /* t1 = a^31 */ + + fe_sq(t2, t1); + for (i = 1; i < 5; i++) fe_sq(t2, t2); + fe_mul(t1, t2, t1); /* t1 = a^(2^10-1) */ + + fe_sq(t2, t1); + for (i = 1; i < 10; i++) fe_sq(t2, t2); + fe_mul(t2, t2, t1); /* t2 = a^(2^20-1) */ + + fe_sq(t3, t2); + for (i = 1; i < 20; i++) fe_sq(t3, t3); + fe_mul(t2, t3, t2); /* t2 = a^(2^40-1) */ + + fe_sq(t2, t2); + for (i = 1; i < 10; i++) fe_sq(t2, t2); + fe_mul(t1, t2, t1); /* t1 = a^(2^50-1) */ + + fe_sq(t2, t1); + for (i = 1; i < 50; i++) fe_sq(t2, t2); + fe_mul(t2, t2, t1); /* t2 = a^(2^100-1) */ + + fe_sq(t3, t2); + for (i = 1; i < 100; i++) fe_sq(t3, t3); + fe_mul(t2, t3, t2); /* t2 = a^(2^200-1) */ + + fe_sq(t2, t2); + for (i = 1; i < 50; i++) fe_sq(t2, t2); + fe_mul(t1, t2, t1); /* t1 = a^(2^250-1) */ + + fe_sq(t1, t1); + fe_sq(t1, t1); + fe_sq(t1, t1); + fe_sq(t1, t1); + fe_sq(t1, t1); /* t1 = a^(2^255-2^5) */ + fe_mul(out, t1, t0); /* out = a^(2^255-21) = a^(p-2) */ +} + +/* --- Byte conversion --- */ + +/* fe_from_bytes: 32-byte little-endian → field element */ +static void fe_from_bytes(fe out, const uint8_t in[32]) +{ + uint8_t buf[32]; + memcpy(buf, in, 32); + buf[31] &= 0x7f; /* clear top bit per RFC 7748 §5 */ + + out[0] = ((uint64_t)buf[ 0]) + | ((uint64_t)buf[ 1] << 8) + | ((uint64_t)buf[ 2] << 16) + | ((uint64_t)buf[ 3] << 24) + | ((uint64_t)buf[ 4] << 32) + | ((uint64_t)buf[ 5] << 40) + | ((uint64_t)(buf[6] & 0x07) << 48); + + out[1] = ((uint64_t)buf[ 6] >> 3) + | ((uint64_t)buf[ 7] << 5) + | ((uint64_t)buf[ 8] << 13) + | ((uint64_t)buf[ 9] << 21) + | ((uint64_t)buf[10] << 29) + | ((uint64_t)buf[11] << 37) + | ((uint64_t)(buf[12] & 0x3f) << 45); + + out[2] = ((uint64_t)buf[12] >> 6) + | ((uint64_t)buf[13] << 2) + | ((uint64_t)buf[14] << 10) + | ((uint64_t)buf[15] << 18) + | ((uint64_t)buf[16] << 26) + | ((uint64_t)buf[17] << 34) + | ((uint64_t)buf[18] << 42) + | ((uint64_t)(buf[19] & 0x01) << 50); + + out[3] = ((uint64_t)buf[19] >> 1) + | ((uint64_t)buf[20] << 7) + | ((uint64_t)buf[21] << 15) + | ((uint64_t)buf[22] << 23) + | ((uint64_t)buf[23] << 31) + | ((uint64_t)buf[24] << 39) + | ((uint64_t)(buf[25] & 0x0f) << 47); + + out[4] = ((uint64_t)buf[25] >> 4) + | ((uint64_t)buf[26] << 4) + | ((uint64_t)buf[27] << 12) + | ((uint64_t)buf[28] << 20) + | ((uint64_t)buf[29] << 28) + | ((uint64_t)buf[30] << 36) + | ((uint64_t)(buf[31] & 0x7f) << 44); +} + +/* fe_to_bytes: field element → 32-byte little-endian */ +static void fe_to_bytes(uint8_t out[32], const fe in) +{ + fe f; + uint64_t c, t; + + fe_copy(f, in); + fe_reduce(f); + fe_reduce(f); + + /* Conditional subtract p = 2^255 - 19 */ + t = f[0] + 19; + c = t >> 51; t &= MASK51; uint64_t g0 = t; + t = f[1] + c; c = t >> 51; t &= MASK51; uint64_t g1 = t; + t = f[2] + c; c = t >> 51; t &= MASK51; uint64_t g2 = t; + t = f[3] + c; c = t >> 51; t &= MASK51; uint64_t g3 = t; + t = f[4] + c; uint64_t g4 = t & MASK51; + uint64_t mask = -((t >> 51) & 1); + f[0] = (f[0] & ~mask) | (g0 & mask); + f[1] = (f[1] & ~mask) | (g1 & mask); + f[2] = (f[2] & ~mask) | (g2 & mask); + f[3] = (f[3] & ~mask) | (g3 & mask); + f[4] = (f[4] & ~mask) | (g4 & mask); + + /* Unpack to bytes */ + out[ 0] = (uint8_t)(f[0]); + out[ 1] = (uint8_t)(f[0] >> 8); + out[ 2] = (uint8_t)(f[0] >> 16); + out[ 3] = (uint8_t)(f[0] >> 24); + out[ 4] = (uint8_t)(f[0] >> 32); + out[ 5] = (uint8_t)(f[0] >> 40); + out[ 6] = (uint8_t)((f[0] >> 48) | (f[1] << 3)); + out[ 7] = (uint8_t)(f[1] >> 5); + out[ 8] = (uint8_t)(f[1] >> 13); + out[ 9] = (uint8_t)(f[1] >> 21); + out[10] = (uint8_t)(f[1] >> 29); + out[11] = (uint8_t)(f[1] >> 37); + out[12] = (uint8_t)((f[1] >> 45) | (f[2] << 6)); + out[13] = (uint8_t)(f[2] >> 2); + out[14] = (uint8_t)(f[2] >> 10); + out[15] = (uint8_t)(f[2] >> 18); + out[16] = (uint8_t)(f[2] >> 26); + out[17] = (uint8_t)(f[2] >> 34); + out[18] = (uint8_t)(f[2] >> 42); + out[19] = (uint8_t)((f[2] >> 50) | (f[3] << 1)); + out[20] = (uint8_t)(f[3] >> 7); + out[21] = (uint8_t)(f[3] >> 15); + out[22] = (uint8_t)(f[3] >> 23); + out[23] = (uint8_t)(f[3] >> 31); + out[24] = (uint8_t)(f[3] >> 39); + out[25] = (uint8_t)((f[3] >> 47) | (f[4] << 4)); + out[26] = (uint8_t)(f[4] >> 4); + out[27] = (uint8_t)(f[4] >> 12); + out[28] = (uint8_t)(f[4] >> 20); + out[29] = (uint8_t)(f[4] >> 28); + out[30] = (uint8_t)(f[4] >> 36); + out[31] = (uint8_t)(f[4] >> 44); +} + +/* --- Montgomery ladder --- */ + +#define A24 121665ULL + +/* fe_cswap: conditional swap */ +static void fe_cswap(fe a, fe b, uint64_t swap) +{ + uint64_t mask = -(swap & 1); + for (int i = 0; i < NLIMBS; i++) { + uint64_t t = mask & (a[i] ^ b[i]); + a[i] ^= t; + b[i] ^= t; } - - /* Python: z2_inv = mod_inverse(z2, P) */ - fe_inv(z2, z2); - - /* Python: result = (x2 * z2_inv) % P */ - fe_mul(x2, x2, z2); - - /* Python: return int_to_bytes(result) */ - fe_tobytes(out, x2); } +/* ladder_step: one Montgomery ladder step */ +static void ladder_step( + fe X2, fe Z2, fe X3, fe Z3, + const fe X2_in, const fe Z2_in, + const fe X3_in, const fe Z3_in, + const fe x1) +{ + fe A, AA, B, BB, E, C, D, DA, CB, tmp, a24_E; + fe_add(A, X2_in, Z2_in); + fe_sq (AA, A); + fe_sub(B, X2_in, Z2_in); + fe_sq (BB, B); + fe_sub(E, AA, BB); + fe_add(C, X3_in, Z3_in); + fe_sub(D, X3_in, Z3_in); + fe_mul(DA, D, A); + fe_mul(CB, C, B); + + fe_add(tmp, DA, CB); + fe_sq (X3, tmp); + fe_sub(tmp, DA, CB); + fe_sq (tmp, tmp); + fe_mul(Z3, tmp, x1); + fe_mul(X2, AA, BB); + + fe_mul_small(a24_E, E, A24); + fe_add(tmp, AA, a24_E); + fe_mul(Z2, E, tmp); +} + +/* --- Public API --- */ + +const uint8_t X25519_BASE_POINT[32] = { 9 }; + +int x25519_sw(uint8_t out[32], const uint8_t scalar[32], const uint8_t point[32]) +{ + uint8_t e[32]; + fe x1, X2, Z2, X3, Z3; + uint64_t prev_bit, swap; + int i; + + /* Step 1: clamp scalar */ + memcpy(e, scalar, 32); + e[ 0] &= 248; + e[31] &= 127; + e[31] |= 64; + + /* Step 2: decode u-coordinate */ + if (point == NULL) + fe_from_bytes(x1, X25519_BASE_POINT); + else + fe_from_bytes(x1, point); + + /* Step 3: initialise projective points */ + fe_one (X2); fe_zero(Z2); + fe_copy(X3, x1); fe_one(Z3); + + /* Step 4: Montgomery ladder */ + prev_bit = 0; + for (i = 254; i >= 0; i--) { + uint64_t bit = (e[i / 8] >> (i % 8)) & 1; + swap = bit ^ prev_bit; + prev_bit = bit; + + fe_cswap(X2, X3, swap); + fe_cswap(Z2, Z3, swap); + + fe nX2, nZ2, nX3, nZ3; + ladder_step(nX2, nZ2, nX3, nZ3, X2, Z2, X3, Z3, x1); + fe_copy(X2, nX2); fe_copy(Z2, nZ2); + fe_copy(X3, nX3); fe_copy(Z3, nZ3); + } + fe_cswap(X2, X3, prev_bit); + fe_cswap(Z2, Z3, prev_bit); + + /* Step 5: convert from projective to affine */ + fe Z2_inv; + fe_invert(Z2_inv, Z2); + fe_mul(X2, X2, Z2_inv); + + /* Step 6: encode result */ + fe_to_bytes(out, X2); + + /* Step 7: reject all-zero output */ + uint8_t acc = 0; + for (i = 0; i < 32; i++) acc |= out[i]; + if (acc == 0) return -1; + + return 0; +} void se050_x25519_sw_clamp(uint8_t *scalar) {