Skip to content

Commit 924cba6

Browse files
committed
ML-KEM iNTT proof for x86 AVX2 implementation
This change adds the implementation of ML-KEM Inverse NTT function and its HolLight proof of correctness. At this moment, the implementation is the same as the one in mlkem-native repository at https://github.com/pq-code-package/mlkem-native. The proof was done in collaboration with @jargh. Signed-off-by: dkostic <[email protected]>
1 parent dcb8fc4 commit 924cba6

File tree

12 files changed

+3090
-11
lines changed

12 files changed

+3090
-11
lines changed

benchmarks/benchmark.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -831,14 +831,14 @@ void call_mldsa_ntt(void) repeat(mldsa_ntt((int32_t*)b0,(const int32_t*)b1))
831831
void call_mldsa_poly_reduce(void) repeat(mldsa_poly_reduce((int32_t*)b0))
832832

833833
void call_mlkem_ntt(void) repeat(mlkem_ntt_x86((int16_t*)b0,(int16_t*)b1))
834+
void call_mlkem_intt(void) repeat(mlkem_intt_x86((int16_t*)b0,(int16_t*)b1))
834835

835836
void call_bignum_copy_row_from_table_8n__32_16(void) {}
836837
void call_bignum_copy_row_from_table_8n__32_32(void) {}
837838
void call_bignum_copy_row_from_table_16__32(void) {}
838839
void call_bignum_copy_row_from_table_32__32(void) {}
839840

840841
void call_bignum_emontredc_8n_cdiff__32(void) {}
841-
void call_mlkem_intt(void) {}
842842
void call_mlkem_mulcache_compute(void) {}
843843
void call_mlkem_tobytes(void) {}
844844
void call_mlkem_tomont(void) {}
@@ -1528,7 +1528,7 @@ int main(int argc, char *argv[])
15281528
timingtest(all,"mlkem_basemul_k2",call_mlkem_basemul_k2);
15291529
timingtest(all,"mlkem_basemul_k3",call_mlkem_basemul_k3);
15301530
timingtest(all,"mlkem_basemul_k4",call_mlkem_basemul_k4);
1531-
timingtest(arm,"mlkem_intt",call_mlkem_intt);
1531+
timingtest(all,"mlkem_intt",call_mlkem_intt);
15321532
timingtest(arm,"mlkem_mulcache_compute",call_mlkem_mulcache_compute);
15331533
timingtest(all,"mlkem_ntt",call_mlkem_ntt);
15341534
timingtest(all,"mlkem_reduce",call_mlkem_reduce);

common/mlkem_mldsa.ml

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,55 @@ let avx2_ntt_order = define
5252
`avx2_ntt_order i =
5353
bitreverse7(64 * (i DIV 64) + ((i MOD 64) DIV 16) + 4 * (i MOD 16))`;;
5454

55+
let avx2_ntt_order' = define
56+
`avx2_ntt_order' i =
57+
let j = bitreverse7 i in
58+
(64 * (j DIV 64) + 16 * (j MOD 4) + (j MOD 64) DIV 4)`;;
59+
60+
let avx2_reorder = define
61+
`avx2_reorder i =
62+
let r = (i DIV 16) MOD 2
63+
and q = 16 * (i DIV 32) + i MOD 16 in
64+
2 * avx2_ntt_order q + r`;;
65+
66+
let avx2_reorder' = define
67+
`avx2_reorder' i =
68+
let r = i MOD 2
69+
and q = avx2_ntt_order'(i DIV 2) in
70+
(q DIV 16) * 32 + r * 16 + q MOD 16`;;
71+
72+
(* ------------------------------------------------------------------------- *)
73+
(* The simpler ones as used on ARM are actually involutions. *)
74+
(* ------------------------------------------------------------------------- *)
75+
76+
let BITREVERSE7_INVOLUTION = prove
77+
(`!n. n < 128 ==> bitreverse7(bitreverse7 n) = n`,
78+
CONV_TAC EXPAND_CASES_CONV THEN REWRITE_TAC[bitreverse7] THEN
79+
CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
80+
81+
let BITREVERSE_PAIRS_INVOLUTION = prove
82+
(`!n. n < 256 ==> bitreverse_pairs(bitreverse_pairs n) = n`,
83+
CONV_TAC EXPAND_CASES_CONV THEN
84+
REWRITE_TAC[bitreverse_pairs; bitreverse7] THEN
85+
CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
86+
87+
let AVX2_NTT_ORDER_INVOLUTION = prove
88+
(`!n. n < 128 ==> avx2_ntt_order'(avx2_ntt_order n) = n /\
89+
avx2_ntt_order(avx2_ntt_order' n) = n`,
90+
CONV_TAC EXPAND_CASES_CONV THEN
91+
REWRITE_TAC[avx2_ntt_order; avx2_ntt_order'; bitreverse7] THEN
92+
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
93+
CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
94+
95+
let AVX2_REORDER_INVOLUTION = prove
96+
(`!n. n < 256 ==> avx2_reorder'(avx2_reorder n) = n /\
97+
avx2_reorder(avx2_reorder' n) = n`,
98+
CONV_TAC EXPAND_CASES_CONV THEN
99+
REWRITE_TAC[avx2_reorder; avx2_reorder';
100+
avx2_ntt_order; avx2_ntt_order'; bitreverse7] THEN
101+
CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
102+
CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
103+
55104
(* ------------------------------------------------------------------------- *)
56105
(* AVX2-optimized ordering for ML-DSA NTT (swaps bit fields then reverses) *)
57106
(* ------------------------------------------------------------------------- *)
@@ -107,6 +156,15 @@ let avx2_forward_ntt = define
107156
&17 pow ((2 * avx2_ntt_order q + 1) * j))
108157
rem &3329`;;
109158

159+
let avx2_inverse_ntt = define
160+
`avx2_inverse_ntt f k =
161+
(&512 * isum (0..127)
162+
(\j. f(avx2_ntt_order' j DIV 16 * 32 +
163+
k MOD 2 * 16 +
164+
avx2_ntt_order' j MOD 16) *
165+
&1175 pow ((2 * j + 1) * k DIV 2)))
166+
rem &3329`;;
167+
110168
let mldsa_forward_ntt = define
111169
`mldsa_forward_ntt f k =
112170
isum (0..255) (\j. f j * &1753 pow ((2 * mldsa_avx2_ntt_order k + 1) * j))
@@ -133,6 +191,26 @@ let INVERSE_NTT = prove
133191
CONV_TAC INT_REM_DOWN_CONV THEN REWRITE_TAC[INT_MUL_ASSOC] THEN
134192
ONCE_REWRITE_TAC[GSYM INT_MUL_REM] THEN CONV_TAC INT_REDUCE_CONV);;
135193

194+
let AVX2_FORWARD_NTT = prove
195+
(`avx2_forward_ntt = reorder avx2_reorder o pure_forward_ntt`,
196+
REWRITE_TAC[FUN_EQ_THM; o_DEF; avx2_reorder; reorder] THEN
197+
REWRITE_TAC[avx2_forward_ntt; pure_forward_ntt] THEN
198+
MAP_EVERY X_GEN_TAC [`x:num->int`; `k:num`] THEN
199+
CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN
200+
SIMP_TAC[MOD_MULT_ADD; DIV_MULT_ADD; ARITH_EQ; MOD_MOD_REFL] THEN
201+
REWRITE_TAC[ARITH_RULE `x MOD 2 DIV 2 = 0`; ADD_CLAUSES]);;
202+
203+
let AVX2_INVERSE_NTT = prove
204+
(`avx2_inverse_ntt = tomont_3329 o pure_inverse_ntt o reorder avx2_reorder'`,
205+
REWRITE_TAC[FUN_EQ_THM; o_DEF; avx2_reorder'; reorder] THEN
206+
REWRITE_TAC[avx2_inverse_ntt; pure_inverse_ntt; tomont_3329] THEN
207+
REWRITE_TAC[ARITH_RULE `(2 * x + i MOD 2) DIV 2 = x`] THEN
208+
REWRITE_TAC[MOD_MULT_ADD; MOD_MOD_REFL] THEN
209+
MAP_EVERY X_GEN_TAC [`x:num->int`; `k:num`] THEN
210+
CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN
211+
CONV_TAC INT_REM_DOWN_CONV THEN REWRITE_TAC[INT_MUL_ASSOC] THEN
212+
ONCE_REWRITE_TAC[GSYM INT_MUL_REM] THEN CONV_TAC INT_REDUCE_CONV);;
213+
136214
let MLDSA_FORWARD_NTT = prove
137215
(`mldsa_forward_ntt f k =
138216
isum (0..255) (\j. f j * &1753 pow ((2 * mldsa_avx2_ntt_order k + 1) * j)) rem &8380417`,
@@ -198,6 +276,25 @@ let INVERSE_NTT_ALT = prove
198276
CONV_TAC INT_REM_DOWN_CONV THEN
199277
AP_THM_TAC THEN AP_TERM_TAC THEN CONV_TAC INT_ARITH);;
200278

279+
let AVX2_INVERSE_NTT_ALT = prove
280+
(`avx2_inverse_ntt f k =
281+
isum (0..127)
282+
(\j. f(avx2_ntt_order' j DIV 16 * 32 +
283+
k MOD 2 * 16 +
284+
avx2_ntt_order' j MOD 16) *
285+
(&512 *
286+
(&1175 pow ((2 * j + 1) * k DIV 2)) rem &3329)
287+
rem &3329) rem &3329`,
288+
REWRITE_TAC[avx2_inverse_ntt; GSYM ISUM_LMUL] THEN
289+
MATCH_MP_TAC (REWRITE_RULE[] (ISPEC
290+
`(\x y. x rem &3329 = y rem &3329)` ISUM_RELATED)) THEN
291+
REWRITE_TAC[INT_REM_EQ; FINITE_NUMSEG; INT_CONG_ADD] THEN
292+
X_GEN_TAC `i:num` THEN DISCH_TAC THEN
293+
REWRITE_TAC[GSYM INT_OF_NUM_REM; GSYM INT_OF_NUM_CLAUSES;
294+
GSYM INT_REM_EQ] THEN
295+
CONV_TAC INT_REM_DOWN_CONV THEN
296+
AP_THM_TAC THEN AP_TERM_TAC THEN CONV_TAC INT_ARITH);;
297+
201298
let FORWARD_NTT_CONV =
202299
GEN_REWRITE_CONV I [FORWARD_NTT_ALT] THENC
203300
LAND_CONV EXPAND_ISUM_CONV THENC
@@ -212,6 +309,12 @@ let AVX2_NTT_ORDER_CLAUSES = end_itlist CONJ (map
212309
GEN_REWRITE_CONV I [BITREVERSE7_CLAUSES])
213310
(map (curry mk_comb `avx2_ntt_order` o mk_small_numeral) (0--127)));;
214311

312+
let AVX2_NTT_ORDER_CLAUSES' = end_itlist CONJ (map
313+
(GEN_REWRITE_CONV I [avx2_ntt_order'] THENC DEPTH_CONV WORD_NUM_RED_CONV THENC
314+
DEPTH_CONV let_CONV THENC
315+
GEN_REWRITE_CONV ONCE_DEPTH_CONV [BITREVERSE7_CLAUSES] THENC NUM_REDUCE_CONV)
316+
(map (curry mk_comb `avx2_ntt_order'` o mk_small_numeral) (0--127)));;
317+
215318
let AVX2_FORWARD_NTT_CONV =
216319
GEN_REWRITE_CONV I [AVX2_FORWARD_NTT_ALT] THENC
217320
NUM_REDUCE_CONV THENC ONCE_DEPTH_CONV let_CONV THENC
@@ -231,6 +334,16 @@ let INVERSE_NTT_CONV =
231334
GEN_REWRITE_CONV DEPTH_CONV [INT_OF_NUM_POW; INT_OF_NUM_REM] THENC
232335
ONCE_DEPTH_CONV EXP_MOD_CONV THENC INT_REDUCE_CONV;;
233336

337+
let AVX2_INVERSE_NTT_CONV =
338+
GEN_REWRITE_CONV I [AVX2_INVERSE_NTT_ALT] THENC
339+
NUM_REDUCE_CONV THENC ONCE_DEPTH_CONV let_CONV THENC
340+
LAND_CONV EXPAND_ISUM_CONV THENC
341+
DEPTH_CONV NUM_RED_CONV THENC
342+
GEN_REWRITE_CONV ONCE_DEPTH_CONV [AVX2_NTT_ORDER_CLAUSES'] THENC
343+
DEPTH_CONV NUM_RED_CONV THENC
344+
GEN_REWRITE_CONV DEPTH_CONV [INT_OF_NUM_POW; INT_OF_NUM_REM] THENC
345+
ONCE_DEPTH_CONV EXP_MOD_CONV THENC INT_REDUCE_CONV;;
346+
234347
(* ------------------------------------------------------------------------- *)
235348
(* Explicit computation rules to evaluate mod-8380417 powers less naively. *)
236349
(* ------------------------------------------------------------------------- *)
@@ -672,7 +785,7 @@ let CONGBOUND_BARRED_X86 = prove
672785
(`!a a' l u.
673786
((ival a == a') (mod &3329) /\ l <= ival a /\ ival a <= u)
674787
==> (ival(barred_x86 a) == a') (mod &3329) /\
675-
&0 <= ival(barred_x86 a) /\ ival(barred_x86 a) < &6658`,
788+
&0 <= ival(barred_x86 a) /\ ival(barred_x86 a) <= &6657`,
676789
REPEAT GEN_TAC THEN STRIP_TAC THEN REWRITE_TAC[barred_x86] THEN
677790
REWRITE_TAC[WORD_BLAST
678791
`word_ishr (word_subword (x:int32) (16,16):int16) 10 =

include/s2n-bignum.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,10 @@ extern void mlkem_basemul_k4(int16_t r[S2N_BIGNUM_STATIC 256],const int16_t a[S2
10021002
// Input a[256] (signed 16-bit words), z_01234[80] (signed 16-bit words), z_56[384] (signed 16-bit words); output a[256] (signed 16-bit words)
10031003
extern void mlkem_intt(int16_t a[S2N_BIGNUM_STATIC 256],const int16_t z_01234[S2N_BIGNUM_STATIC 80],const int16_t z_56[S2N_BIGNUM_STATIC 384]);
10041004

1005+
// Inverse number-theoretic transform from ML-KEM
1006+
// Input a[256] (signed 16-bit words), qdata[624]; output a[256] (signed 16-bit words)
1007+
extern void mlkem_intt_x86(int16_t a[S2N_BIGNUM_STATIC 256],const int16_t qdata[S2N_BIGNUM_STATIC 624]);
1008+
10051009
// Precompute the mulcache data for a polynomial in the NTT domain
10061010
// Inputs a[256], z[128] and t[128] (signed 16-bit words); output x[128] (signed 16-bit words)
10071011
extern void mlkem_mulcache_compute(int16_t x[S2N_BIGNUM_STATIC 128],const int16_t a[S2N_BIGNUM_STATIC 256],const int16_t z[S2N_BIGNUM_STATIC 128],const int16_t t[S2N_BIGNUM_STATIC 128]);

tests/test.c

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12106,26 +12106,33 @@ uint64_t t, i;
1210612106

1210712107
int test_mlkem_intt(void)
1210812108
{
12109-
#ifdef __x86_64__
12110-
return 1;
12111-
#else
1211212109
uint64_t t, i;
12113-
int16_t a[256], b[256], c[256];
12110+
int16_t a[256] __attribute__((aligned(32)));
12111+
int16_t b[256] __attribute__((aligned(32)));
12112+
int16_t c[256] __attribute__((aligned(32)));
1211412113
printf("Testing mlkem_intt with %d cases\n",tests);
1211512114

1211612115
for (t = 0; t < tests; ++t)
1211712116
{ for (i = 0; i < 256; ++i)
1211812117
a[i] = (int16_t) (random64()); // any int16_t inputs allowed
1211912118
for (i = 0; i < 256; ++i) b[i] = a[i];
12119+
#ifdef __x86_64__
12120+
mlkem_poly_to_avx2_layout(b);
12121+
mlkem_intt_x86(b,mlkem_qdata);
12122+
#else
1212012123
mlkem_intt(b,intt_zetas_layer01234,intt_zetas_layer56);
12124+
#endif
12125+
1212112126
reference_bitreverse(c,a);
1212212127
reference_inverse_ntt(c,c);
1212312128
reference_tomont3329(c,c);
12129+
12130+
1212412131
for (i = 0; i < 256; ++i)
1212512132
{ if (rem_3329(b[i]) != rem_3329(c[i]))
1212612133
{ printf("Error in iNTT element i = %"PRIu64"; code[i] = 0x%04"PRIx16
1212712134
" while reference[i] = 0x%04"PRIx16"\n",
12128-
i,b[i],c[i]);
12135+
i,rem_3329(b[i]),rem_3329(c[i]));
1212912136
return 1;
1213012137
}
1213112138
}
@@ -12140,7 +12147,6 @@ int test_mlkem_intt(void)
1214012147
}
1214112148
printf("All OK\n");
1214212149
return 0;
12143-
#endif
1214412150
}
1214512151

1214612152
int test_mlkem_mulcache_compute(void)
@@ -15622,6 +15628,7 @@ int main(int argc, char *argv[])
1562215628
functionaltest(all,"mlkem_basemul_k2",test_mlkem_basemul_k2);
1562315629
functionaltest(all,"mlkem_basemul_k3",test_mlkem_basemul_k3);
1562415630
functionaltest(all,"mlkem_basemul_k4",test_mlkem_basemul_k4);
15631+
functionaltest(all,"mlkem_intt",test_mlkem_intt);
1562515632
functionaltest(all,"mlkem_ntt",test_mlkem_ntt);
1562615633
functionaltest(all,"mlkem_reduce",test_mlkem_reduce);
1562715634
functionaltest(bmi,"p256_montjadd",test_p256_montjadd);
@@ -15682,7 +15689,6 @@ int main(int argc, char *argv[])
1568215689
functionaltest(all,"bignum_copy_row_from_table_16",test_bignum_copy_row_from_table_16);
1568315690
functionaltest(all,"bignum_copy_row_from_table_32",test_bignum_copy_row_from_table_32);
1568415691
functionaltest(all,"bignum_emontredc_8n_cdiff",test_bignum_emontredc_8n_cdiff);
15685-
functionaltest(arm,"mlkem_intt",test_mlkem_intt);
1568615692
functionaltest(arm,"mlkem_mulcache_compute",test_mlkem_mulcache_compute);
1568715693
functionaltest(arm,"mlkem_tobytes",test_mlkem_tobytes);
1568815694
functionaltest(arm,"mlkem_tomont",test_mlkem_tomont);

tools/collect-signatures.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ def stripPrefixes(s, prefixes):
333333
"mldsa_ntt",
334334
"mldsa_poly_reduce",
335335
"mlkem_ntt_x86",
336+
"mlkem_intt_x86",
336337
]
337338

338339
for arch in ["arm","x86"]:

x86/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ BIGNUM_OBJ = curve25519/bignum_add_p25519.o \
255255
mlkem/mlkem_basemul_k3.o \
256256
mlkem/mlkem_basemul_k4.o \
257257
mlkem/mlkem_ntt.o \
258+
mlkem/mlkem_intt.o \
258259
mlkem/mlkem_reduce.o \
259260
p256/bignum_add_p256.o \
260261
p256/bignum_bigendian_4.o \

0 commit comments

Comments
 (0)