@@ -52,6 +52,55 @@ let avx2_ntt_order = define
5252 `avx2_ntt_order i =
5353 bitreverse7(64 * (i DIV 64 ) + ((i MOD 64 ) DIV 16 ) + 4 * (i MOD 16 ))`;;
5454
(* Inverse of avx2_ntt_order on 7-bit indices: first undo the bit reversal,  *)
(* then undo the field shuffle. With j = 64 * a + 4 * c + b (a < 2, c < 16,  *)
(* b < 4) the result is 64 * a + 16 * b + c. The two-sided inverse property  *)
(* for n < 128 is stated as AVX2_NTT_ORDER_INVOLUTION.                       *)

let avx2_ntt_order' = define
 `avx2_ntt_order' i =
    let j = bitreverse7 i in
    (64 * (j DIV 64) + 16 * (j MOD 4) + (j MOD 64) DIV 4)`;;
59+
(* Index map used as the reordering in AVX2_FORWARD_NTT. It takes r to be    *)
(* bit 4 of i and q to be i with bit 4 deleted, then returns r as the low    *)
(* bit below avx2_ntt_order q, i.e. 2 * avx2_ntt_order q + r.                *)

let avx2_reorder = define
 `avx2_reorder i =
    let r = (i DIV 16) MOD 2
    and q = 16 * (i DIV 32) + i MOD 16 in
    2 * avx2_ntt_order q + r`;;
65+
(* Inverse direction of avx2_reorder: split off the low bit r, map the rest  *)
(* through avx2_ntt_order' to get q, then reinsert r as bit 4 of q. The      *)
(* inverse property for n < 256 is stated as AVX2_REORDER_INVOLUTION.        *)

let avx2_reorder' = define
 `avx2_reorder' i =
    let r = i MOD 2
    and q = avx2_ntt_order'(i DIV 2) in
    (q DIV 16) * 32 + r * 16 + q MOD 16`;;
71+
(* ------------------------------------------------------------------------- *)
(* The simpler ones as used on ARM are actually involutions. The AVX2        *)
(* orderings are not, but each has an explicit two-sided inverse (primed).   *)
(* ------------------------------------------------------------------------- *)
75+
(* bitreverse7 is its own inverse on 0..127. Proved by brute force: expand   *)
(* the bounded quantifier into cases, unfold bitreverse7, and evaluate.      *)

let BITREVERSE7_INVOLUTION = prove
 (`!n. n < 128 ==> bitreverse7(bitreverse7 n) = n`,
  CONV_TAC EXPAND_CASES_CONV THEN REWRITE_TAC[bitreverse7] THEN
  CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
80+
(* bitreverse_pairs is its own inverse on 0..255. Same brute-force proof:    *)
(* expand into cases, unfold both definitions, and evaluate.                 *)

let BITREVERSE_PAIRS_INVOLUTION = prove
 (`!n. n < 256 ==> bitreverse_pairs(bitreverse_pairs n) = n`,
  CONV_TAC EXPAND_CASES_CONV THEN
  REWRITE_TAC[bitreverse_pairs; bitreverse7] THEN
  CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
86+
(* avx2_ntt_order and avx2_ntt_order' are mutually inverse on 0..127 (this   *)
(* is an inverse pair, not an involution of a single map). Proved by case    *)
(* expansion; let_CONV is needed to eliminate the let in avx2_ntt_order'     *)
(* before numeric evaluation.                                                *)

let AVX2_NTT_ORDER_INVOLUTION = prove
 (`!n. n < 128 ==> avx2_ntt_order'(avx2_ntt_order n) = n /\
                   avx2_ntt_order(avx2_ntt_order' n) = n`,
  CONV_TAC EXPAND_CASES_CONV THEN
  REWRITE_TAC[avx2_ntt_order; avx2_ntt_order'; bitreverse7] THEN
  CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
  CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
94+
(* avx2_reorder and avx2_reorder' are mutually inverse on 0..255 (again an   *)
(* inverse pair rather than an involution). Proved by expanding all 256      *)
(* cases, unfolding down to bitreverse7, removing lets, and evaluating.      *)

let AVX2_REORDER_INVOLUTION = prove
 (`!n. n < 256 ==> avx2_reorder'(avx2_reorder n) = n /\
                   avx2_reorder(avx2_reorder' n) = n`,
  CONV_TAC EXPAND_CASES_CONV THEN
  REWRITE_TAC[avx2_reorder; avx2_reorder';
              avx2_ntt_order; avx2_ntt_order'; bitreverse7] THEN
  CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN
  CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV));;
103+
(* ------------------------------------------------------------------------- *)
(* AVX2-optimized ordering for ML-DSA NTT (swaps bit fields then reverses)   *)
(* ------------------------------------------------------------------------- *)
@@ -107,6 +156,15 @@ let avx2_forward_ntt = define
107156 & 17 pow ((2 * avx2_ntt_order q + 1 ) * j))
108157 rem & 3329 `;;
109158
(* Inverse NTT in the AVX2 data layout, as a function of the input f and the *)
(* output index k. The sum ranges over j in 0..127; the input is read at the *)
(* index obtained by inserting k MOD 2 as bit 4 of avx2_ntt_order' j, and    *)
(* weighted by 1175 pow ((2 * j + 1) * (k DIV 2)). The whole sum is scaled   *)
(* by 512 and reduced mod 3329. AVX2_INVERSE_NTT relates this to             *)
(* pure_inverse_ntt.                                                         *)

let avx2_inverse_ntt = define
 `avx2_inverse_ntt f k =
    (&512 * isum (0..127)
                 (\j. f(avx2_ntt_order' j DIV 16 * 32 +
                        k MOD 2 * 16 +
                        avx2_ntt_order' j MOD 16) *
                      &1175 pow ((2 * j + 1) * k DIV 2)))
    rem &3329`;;
167+
110168let mldsa_forward_ntt = define
111169 `mldsa_forward_ntt f k =
112170 isum (0. .255 ) (\j. f j * & 1753 pow ((2 * mldsa_avx2_ntt_order k + 1 ) * j))
@@ -133,6 +191,26 @@ let INVERSE_NTT = prove
133191 CONV_TAC INT_REM_DOWN_CONV THEN REWRITE_TAC [INT_MUL_ASSOC ] THEN
134192 ONCE_REWRITE_TAC [GSYM INT_MUL_REM ] THEN CONV_TAC INT_REDUCE_CONV );;
135193
(* The AVX2 forward NTT is the pure forward NTT followed by the avx2_reorder *)
(* index permutation. After unfolding both sides at a generic input x and    *)
(* index k, the let-bindings from avx2_reorder are eliminated and the        *)
(* remaining DIV/MOD side conditions are discharged by arithmetic rewrites.  *)

let AVX2_FORWARD_NTT = prove
 (`avx2_forward_ntt = reorder avx2_reorder o pure_forward_ntt`,
  REWRITE_TAC[FUN_EQ_THM; o_DEF; avx2_reorder; reorder] THEN
  REWRITE_TAC[avx2_forward_ntt; pure_forward_ntt] THEN
  MAP_EVERY X_GEN_TAC [`x:num->int`; `k:num`] THEN
  CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN
  SIMP_TAC[MOD_MULT_ADD; DIV_MULT_ADD; ARITH_EQ; MOD_MOD_REFL] THEN
  REWRITE_TAC[ARITH_RULE `x MOD 2 DIV 2 = 0`; ADD_CLAUSES]);;
202+
(* The AVX2 inverse NTT is: permute the input by avx2_reorder', apply the    *)
(* pure inverse NTT, then apply tomont_3329. The proof unfolds everything,   *)
(* simplifies the DIV/MOD index arithmetic, removes the lets, and finishes   *)
(* by normalizing the nested rem &3329 and evaluating the integer constants. *)

let AVX2_INVERSE_NTT = prove
 (`avx2_inverse_ntt = tomont_3329 o pure_inverse_ntt o reorder avx2_reorder'`,
  REWRITE_TAC[FUN_EQ_THM; o_DEF; avx2_reorder'; reorder] THEN
  REWRITE_TAC[avx2_inverse_ntt; pure_inverse_ntt; tomont_3329] THEN
  REWRITE_TAC[ARITH_RULE `(2 * x + i MOD 2) DIV 2 = x`] THEN
  REWRITE_TAC[MOD_MULT_ADD; MOD_MOD_REFL] THEN
  MAP_EVERY X_GEN_TAC [`x:num->int`; `k:num`] THEN
  CONV_TAC(ONCE_DEPTH_CONV let_CONV) THEN
  CONV_TAC INT_REM_DOWN_CONV THEN REWRITE_TAC[INT_MUL_ASSOC] THEN
  ONCE_REWRITE_TAC[GSYM INT_MUL_REM] THEN CONV_TAC INT_REDUCE_CONV);;
213+
136214let MLDSA_FORWARD_NTT = prove
137215 (`mldsa_forward_ntt f k =
138216 isum (0. .255 ) (\j. f j * & 1753 pow ((2 * mldsa_avx2_ntt_order k + 1 ) * j)) rem & 8380417 `,
@@ -198,6 +276,25 @@ let INVERSE_NTT_ALT = prove
198276 CONV_TAC INT_REM_DOWN_CONV THEN
199277 AP_THM_TAC THEN AP_TERM_TAC THEN CONV_TAC INT_ARITH );;
200278
(* Alternative form of avx2_inverse_ntt in which the factor 512 is moved     *)
(* inside the sum and both the power of 1175 and each coefficient are        *)
(* reduced mod 3329 termwise. This is the form AVX2_INVERSE_NTT_CONV         *)
(* rewrites with. The proof pushes the constant in with ISUM_LMUL, then      *)
(* shows the two sums are congruent mod 3329 summand by summand.             *)

let AVX2_INVERSE_NTT_ALT = prove
 (`avx2_inverse_ntt f k =
   isum (0..127)
        (\j. f(avx2_ntt_order' j DIV 16 * 32 +
               k MOD 2 * 16 +
               avx2_ntt_order' j MOD 16) *
             (&512 *
              (&1175 pow ((2 * j + 1) * k DIV 2)) rem &3329)
             rem &3329) rem &3329`,
  REWRITE_TAC[avx2_inverse_ntt; GSYM ISUM_LMUL] THEN
  MATCH_MP_TAC(REWRITE_RULE[] (ISPEC
      `(\x y. x rem &3329 = y rem &3329)` ISUM_RELATED)) THEN
  REWRITE_TAC[INT_REM_EQ; FINITE_NUMSEG; INT_CONG_ADD] THEN
  X_GEN_TAC `i:num` THEN DISCH_TAC THEN
  REWRITE_TAC[GSYM INT_OF_NUM_REM; GSYM INT_OF_NUM_CLAUSES;
              GSYM INT_REM_EQ] THEN
  CONV_TAC INT_REM_DOWN_CONV THEN
  AP_THM_TAC THEN AP_TERM_TAC THEN CONV_TAC INT_ARITH);;
297+
201298let FORWARD_NTT_CONV =
202299 GEN_REWRITE_CONV I [FORWARD_NTT_ALT ] THENC
203300 LAND_CONV EXPAND_ISUM_CONV THENC
@@ -212,6 +309,12 @@ let AVX2_NTT_ORDER_CLAUSES = end_itlist CONJ (map
212309 GEN_REWRITE_CONV I [BITREVERSE7_CLAUSES ])
213310 (map (curry mk_comb `avx2_ntt_order ` o mk_small_numeral) (0 -- 127 )));;
214311
(* Lookup table for avx2_ntt_order': the conjunction of one evaluated        *)
(* equation per argument 0..127. Each term `avx2_ntt_order' n` is unfolded,  *)
(* its let is eliminated, bitreverse7 is rewritten via BITREVERSE7_CLAUSES,  *)
(* and the remaining arithmetic is reduced.                                  *)

let AVX2_NTT_ORDER_CLAUSES' = end_itlist CONJ (map
 (GEN_REWRITE_CONV I [avx2_ntt_order'] THENC DEPTH_CONV WORD_NUM_RED_CONV THENC
  DEPTH_CONV let_CONV THENC
  GEN_REWRITE_CONV ONCE_DEPTH_CONV [BITREVERSE7_CLAUSES] THENC NUM_REDUCE_CONV)
 (map (curry mk_comb `avx2_ntt_order'` o mk_small_numeral) (0--127)));;
317+
215318let AVX2_FORWARD_NTT_CONV =
216319 GEN_REWRITE_CONV I [AVX2_FORWARD_NTT_ALT ] THENC
217320 NUM_REDUCE_CONV THENC ONCE_DEPTH_CONV let_CONV THENC
@@ -231,6 +334,16 @@ let INVERSE_NTT_CONV =
231334 GEN_REWRITE_CONV DEPTH_CONV [INT_OF_NUM_POW ; INT_OF_NUM_REM ] THENC
232335 ONCE_DEPTH_CONV EXP_MOD_CONV THENC INT_REDUCE_CONV ;;
233336
(* Conversion for terms of the form `avx2_inverse_ntt f k`. It rewrites with *)
(* AVX2_INVERSE_NTT_ALT, reduces the numeric index arithmetic, expands the   *)
(* isum on the left of the outer rem, substitutes avx2_ntt_order' values     *)
(* from the AVX2_NTT_ORDER_CLAUSES' table, and finally evaluates the powers  *)
(* and remainders (EXP_MOD_CONV presumably computes the modular powers; it   *)
(* is defined elsewhere in the file).                                        *)

let AVX2_INVERSE_NTT_CONV =
  GEN_REWRITE_CONV I [AVX2_INVERSE_NTT_ALT] THENC
  NUM_REDUCE_CONV THENC ONCE_DEPTH_CONV let_CONV THENC
  LAND_CONV EXPAND_ISUM_CONV THENC
  DEPTH_CONV NUM_RED_CONV THENC
  GEN_REWRITE_CONV ONCE_DEPTH_CONV [AVX2_NTT_ORDER_CLAUSES'] THENC
  DEPTH_CONV NUM_RED_CONV THENC
  GEN_REWRITE_CONV DEPTH_CONV [INT_OF_NUM_POW; INT_OF_NUM_REM] THENC
  ONCE_DEPTH_CONV EXP_MOD_CONV THENC INT_REDUCE_CONV;;
346+
(* ------------------------------------------------------------------------- *)
(* Explicit computation rules to evaluate mod-8380417 powers less naively.   *)
(* ------------------------------------------------------------------------- *)
@@ -672,7 +785,7 @@ let CONGBOUND_BARRED_X86 = prove
672785 (`! a a' l u.
673786 ((ival a == a') (mod & 3329 ) / \ l < = ival a / \ ival a < = u)
674787 ==> (ival(barred_x86 a) == a') (mod & 3329 ) / \
675- & 0 < = ival(barred_x86 a) / \ ival(barred_x86 a) < & 6658 `,
788+ & 0 < = ival(barred_x86 a) / \ ival(barred_x86 a) < = & 6657 `,
676789 REPEAT GEN_TAC THEN STRIP_TAC THEN REWRITE_TAC [barred_x86] THEN
677790 REWRITE_TAC [WORD_BLAST
678791 `word_ishr (word_subword (x:int32 ) (16 ,16 ):int16) 10 =
0 commit comments