diff --git a/arm/Makefile b/arm/Makefile index 01c8ec7ad..05ddfb2bc 100644 --- a/arm/Makefile +++ b/arm/Makefile @@ -258,6 +258,7 @@ BIGNUM_OBJ = curve25519/bignum_add_p25519.o \ mldsa/mldsa_pointwise_acc_l7.o \ mldsa/mldsa_ntt.o \ mldsa/mldsa_pointwise.o \ + mldsa/mldsa_poly_use_hint_88.o \ mlkem/mlkem_basemul_k2.o \ mlkem/mlkem_basemul_k3.o \ mlkem/mlkem_basemul_k4.o \ diff --git a/arm/mldsa/mldsa_poly_use_hint_88.S b/arm/mldsa/mldsa_poly_use_hint_88.S new file mode 100644 index 000000000..afbedc916 --- /dev/null +++ b/arm/mldsa/mldsa_poly_use_hint_88.S @@ -0,0 +1,166 @@ +// Copyright (c) The mldsa-native project authors +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +// ---------------------------------------------------------------------------- +// Use hint to correct high bits of decomposition (parameter set 44) +// Inputs a[256] (unsigned 32-bit, in [0,Q)), h[256] (hint bits, 0 or 1) +// Output b[256] (unsigned 32-bit, in [0,44)) +// +// Implements mld_use_hint for ML-DSA parameter set 44: +// GAMMA2 = (Q-1)/88 = 95232 +// 2*GAMMA2 = 190464 +// Output range: [0, 43] +// +// Algorithm per coefficient: +// 1. Decompose: a1 = round(a / 190464) (to nearest, ties down), a0 = a - a1*190464 +// If a > 87*GAMMA2 = 8285184, wrap: a1=0, a0=a-Q +// 2. delta = (a0 <= 0) ? -1 : 1 +// 3. 
b = min((a1 + delta * h) & ~mask_gt_43, 43) +// where mask_gt_43 is all-ones iff the signed sum > 43 (44 -> 0); the unsigned min maps -1 -> 43 +// +// extern void mldsa_poly_use_hint_88 +// (int32_t b[static 256], const int32_t a[static 256], +// const int32_t h[static 256]); +// +// Standard ARM ABI: X0 = b, X1 = a, X2 = h +// ---------------------------------------------------------------------------- +#include "_internal_s2n_bignum_arm.h" + + S2N_BN_SYM_VISIBILITY_DIRECTIVE(mldsa_poly_use_hint_88) + S2N_BN_FUNCTION_TYPE_DIRECTIVE(mldsa_poly_use_hint_88) + S2N_BN_SYM_PRIVACY_DIRECTIVE(mldsa_poly_use_hint_88) + .text + .balign 4 + +S2N_BN_SYMBOL(mldsa_poly_use_hint_88): + CFI_START + +// This matches the code in the mldsa-native repository +// https://github.com/pq-code-package/mldsa-native/blob/main/mldsa/src/native/aarch64/src/poly_use_hint_88_asm.S + +// Load constants into SIMD registers + +// v20 = Q = 8380417 = 0x7fe001 + mov w4, #0xe001 + movk w4, #0x7f, lsl #16 + dup v20.4s, w4 + +// v21 = 87*GAMMA2 = 8285184 = 0x7e6c00 (wraparound threshold) + mov w5, #0x6c00 + movk w5, #0x7e, lsl #16 + dup v21.4s, w5 + +// v22 = 2*GAMMA2 = 190464 = 0x2e800 (decompose multiplier) + mov w7, #0xe800 + movk w7, #0x2, lsl #16 + dup v22.4s, w7 + +// v23 = Barrett constant c = 0x58160581 = 1477838209 +// Used for SQDMULH-based Barrett reduction: a1 = (((2*a*c) >> 32) + 2^16) >> 17 + mov w11, #0x0581 + movk w11, #0x5816, lsl #16 + dup v23.4s, w11 + +// v24 = 43 = 0x0000002b (for clamping with umin) + movi v24.4s, #0x2b + +// Loop counter: 16 iterations, processing 16 coefficients per iteration +// 16 * 16 = 256 total coefficients + mov x3, #0x10 + +Lmldsa_poly_use_hint_88_loop: + // Load 16 coefficients from a (4 vectors of 4 int32s) + ldr q1, [x1, #0x10] + ldr q2, [x1, #0x20] + ldr q3, [x1, #0x30] + ldr q0, [x1], #0x40 + + // Load 16 hint bits from h (4 vectors of 4 int32s) + ldr q5, [x2, #0x10] + ldr q6, [x2, #0x20] + ldr q7, [x2, #0x30] + ldr q4, [x2], #0x40 + + // --- Process v1 (coefficients at offset +16) --- + // Decompose: 
a1 = sqdmulh(a, barrett_const) >> 17 + sqdmulh v17.4s, v1.4s, v23.4s + srshr v17.4s, v17.4s, #0x11 + // Check wraparound: mask = (a > 43*GAMMA2) ? all_ones : 0 + cmgt v25.4s, v1.4s, v21.4s + // a0 = a - a1 * 2*GAMMA2 + mls v1.4s, v17.4s, v22.4s + // If wraparound: a1 = 0 (clear a1 where mask is set) + bic v17.16b, v17.16b, v25.16b + // If wraparound: a0 -= 1 (add all_ones = -1) + add v1.4s, v1.4s, v25.4s + // delta = (a0 <= 0) ? all_ones : 0 + cmle v1.4s, v1.4s, #0 + // delta = (a0 <= 0) ? -1 : 1 (set bit 0) + orr v1.4s, #0x1 + // b = a1 + delta * hint + mla v17.4s, v1.4s, v5.4s + // Clamp to [0, 43]: clear if > 43, then min with 43 + cmgt v25.4s, v17.4s, v24.4s + bic v17.16b, v17.16b, v25.16b + umin v17.4s, v17.4s, v24.4s + + // --- Process v2 (coefficients at offset +32) --- + sqdmulh v18.4s, v2.4s, v23.4s + srshr v18.4s, v18.4s, #0x11 + cmgt v25.4s, v2.4s, v21.4s + mls v2.4s, v18.4s, v22.4s + bic v18.16b, v18.16b, v25.16b + add v2.4s, v2.4s, v25.4s + cmle v2.4s, v2.4s, #0 + orr v2.4s, #0x1 + mla v18.4s, v2.4s, v6.4s + cmgt v25.4s, v18.4s, v24.4s + bic v18.16b, v18.16b, v25.16b + umin v18.4s, v18.4s, v24.4s + + // --- Process v3 (coefficients at offset +48) --- + sqdmulh v19.4s, v3.4s, v23.4s + srshr v19.4s, v19.4s, #0x11 + cmgt v25.4s, v3.4s, v21.4s + mls v3.4s, v19.4s, v22.4s + bic v19.16b, v19.16b, v25.16b + add v3.4s, v3.4s, v25.4s + cmle v3.4s, v3.4s, #0 + orr v3.4s, #0x1 + mla v19.4s, v3.4s, v7.4s + cmgt v25.4s, v19.4s, v24.4s + bic v19.16b, v19.16b, v25.16b + umin v19.4s, v19.4s, v24.4s + + // --- Process v0 (coefficients at offset +0, loaded last for post-increment) --- + sqdmulh v16.4s, v0.4s, v23.4s + srshr v16.4s, v16.4s, #0x11 + cmgt v25.4s, v0.4s, v21.4s + mls v0.4s, v16.4s, v22.4s + bic v16.16b, v16.16b, v25.16b + add v0.4s, v0.4s, v25.4s + cmle v0.4s, v0.4s, #0 + orr v0.4s, #0x1 + mla v16.4s, v0.4s, v4.4s + cmgt v25.4s, v16.4s, v24.4s + bic v16.16b, v16.16b, v25.16b + umin v16.4s, v16.4s, v24.4s + + // Store 16 output coefficients + str q17, [x0, 
#0x10] + str q18, [x0, #0x20] + str q19, [x0, #0x30] + str q16, [x0], #0x40 + + // Decrement loop counter and branch + subs x3, x3, #0x1 + b.ne Lmldsa_poly_use_hint_88_loop + + CFI_RET + +S2N_BN_SIZE_DIRECTIVE(mldsa_poly_use_hint_88) + +#if defined(__linux__) && defined(__ELF__) +.section .note.GNU-stack, "", %progbits +#endif diff --git a/arm/proofs/mldsa_poly_use_hint_88.ml b/arm/proofs/mldsa_poly_use_hint_88.ml new file mode 100644 index 000000000..287fa3c25 --- /dev/null +++ b/arm/proofs/mldsa_poly_use_hint_88.ml @@ -0,0 +1,1030 @@ +(* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Use hint to correct high bits of decomposition (ML-DSA, param 44). *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "common/mlkem_mldsa.ml";; + + +(**** print_literal_from_elf "arm/mldsa/mldsa_poly_use_hint_88.o";; + ****) + +let mldsa_poly_use_hint_88_mc = define_assert_from_elf + "mldsa_poly_use_hint_88_mc" "arm/mldsa/mldsa_poly_use_hint_88.o" +(*** BYTECODE START ***) +[ + 0x529c0024; (* arm_MOV W4 (rvalue (word 57345)) *) + 0x72a00fe4; (* arm_MOVK W4 (word 127) 16 *) + 0x4e040c94; (* arm_DUP_GEN Q20 X4 32 128 *) + 0x528d8005; (* arm_MOV W5 (rvalue (word 27648)) *) + 0x72a00fc5; (* arm_MOVK W5 (word 126) 16 *) + 0x4e040cb5; (* arm_DUP_GEN Q21 X5 32 128 *) + 0x529d0007; (* arm_MOV W7 (rvalue (word 59392)) *) + 0x72a00047; (* arm_MOVK W7 (word 2) 16 *) + 0x4e040cf6; (* arm_DUP_GEN Q22 X7 32 128 *) + 0x5280b02b; (* arm_MOV W11 (rvalue (word 1409)) *) + 0x72ab02cb; (* arm_MOVK W11 (word 22550) 16 *) + 0x4e040d77; (* arm_DUP_GEN Q23 X11 32 128 *) + 0x4f010578; (* arm_MOVI Q24 (word 184683593771) *) + 0xd2800203; (* arm_MOV X3 (rvalue (word 16)) *) + 0x3dc00421; (* arm_LDR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3dc00822; (* arm_LDR Q2 X1 
(Immediate_Offset (word 32)) *) + 0x3dc00c23; (* arm_LDR Q3 X1 (Immediate_Offset (word 48)) *) + 0x3cc40420; (* arm_LDR Q0 X1 (Postimmediate_Offset (word 64)) *) + 0x3dc00445; (* arm_LDR Q5 X2 (Immediate_Offset (word 16)) *) + 0x3dc00846; (* arm_LDR Q6 X2 (Immediate_Offset (word 32)) *) + 0x3dc00c47; (* arm_LDR Q7 X2 (Immediate_Offset (word 48)) *) + 0x3cc40444; (* arm_LDR Q4 X2 (Postimmediate_Offset (word 64)) *) + 0x4eb7b431; (* arm_SQDMULH_VEC Q17 Q1 Q23 32 128 *) + 0x4f2f2631; (* arm_SRSHR_VEC Q17 Q17 17 32 128 *) + 0x4eb53439; (* arm_CMGT_VEC Q25 Q1 Q21 32 128 *) + 0x6eb69621; (* arm_MLS_VEC Q1 Q17 Q22 32 128 *) + 0x4e791e31; (* arm_BIC_VEC Q17 Q17 Q25 128 *) + 0x4eb98421; (* arm_ADD_VEC Q1 Q1 Q25 32 128 *) + 0x6ea09821; (* arm_CMLE_VEC_ZERO Q1 Q1 32 128 *) + 0x4f001421; (* arm_ORR_VEC Q1 Q1 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea59431; (* arm_MLA_VEC Q17 Q1 Q5 32 128 *) + 0x4eb83639; (* arm_CMGT_VEC Q25 Q17 Q24 32 128 *) + 0x4e791e31; (* arm_BIC_VEC Q17 Q17 Q25 128 *) + 0x6eb86e31; (* arm_UMIN_VEC Q17 Q17 Q24 32 128 *) + 0x4eb7b452; (* arm_SQDMULH_VEC Q18 Q2 Q23 32 128 *) + 0x4f2f2652; (* arm_SRSHR_VEC Q18 Q18 17 32 128 *) + 0x4eb53459; (* arm_CMGT_VEC Q25 Q2 Q21 32 128 *) + 0x6eb69642; (* arm_MLS_VEC Q2 Q18 Q22 32 128 *) + 0x4e791e52; (* arm_BIC_VEC Q18 Q18 Q25 128 *) + 0x4eb98442; (* arm_ADD_VEC Q2 Q2 Q25 32 128 *) + 0x6ea09842; (* arm_CMLE_VEC_ZERO Q2 Q2 32 128 *) + 0x4f001422; (* arm_ORR_VEC Q2 Q2 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea69452; (* arm_MLA_VEC Q18 Q2 Q6 32 128 *) + 0x4eb83659; (* arm_CMGT_VEC Q25 Q18 Q24 32 128 *) + 0x4e791e52; (* arm_BIC_VEC Q18 Q18 Q25 128 *) + 0x6eb86e52; (* arm_UMIN_VEC Q18 Q18 Q24 32 128 *) + 0x4eb7b473; (* arm_SQDMULH_VEC Q19 Q3 Q23 32 128 *) + 0x4f2f2673; (* arm_SRSHR_VEC Q19 Q19 17 32 128 *) + 0x4eb53479; (* arm_CMGT_VEC Q25 Q3 Q21 32 128 *) + 0x6eb69663; (* arm_MLS_VEC Q3 Q19 Q22 32 128 *) + 0x4e791e73; (* arm_BIC_VEC Q19 Q19 Q25 128 *) + 0x4eb98463; (* arm_ADD_VEC Q3 Q3 Q25 
32 128 *) + 0x6ea09863; (* arm_CMLE_VEC_ZERO Q3 Q3 32 128 *) + 0x4f001423; (* arm_ORR_VEC Q3 Q3 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea79473; (* arm_MLA_VEC Q19 Q3 Q7 32 128 *) + 0x4eb83679; (* arm_CMGT_VEC Q25 Q19 Q24 32 128 *) + 0x4e791e73; (* arm_BIC_VEC Q19 Q19 Q25 128 *) + 0x6eb86e73; (* arm_UMIN_VEC Q19 Q19 Q24 32 128 *) + 0x4eb7b410; (* arm_SQDMULH_VEC Q16 Q0 Q23 32 128 *) + 0x4f2f2610; (* arm_SRSHR_VEC Q16 Q16 17 32 128 *) + 0x4eb53419; (* arm_CMGT_VEC Q25 Q0 Q21 32 128 *) + 0x6eb69600; (* arm_MLS_VEC Q0 Q16 Q22 32 128 *) + 0x4e791e10; (* arm_BIC_VEC Q16 Q16 Q25 128 *) + 0x4eb98400; (* arm_ADD_VEC Q0 Q0 Q25 32 128 *) + 0x6ea09800; (* arm_CMLE_VEC_ZERO Q0 Q0 32 128 *) + 0x4f001420; (* arm_ORR_VEC Q0 Q0 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea49410; (* arm_MLA_VEC Q16 Q0 Q4 32 128 *) + 0x4eb83619; (* arm_CMGT_VEC Q25 Q16 Q24 32 128 *) + 0x4e791e10; (* arm_BIC_VEC Q16 Q16 Q25 128 *) + 0x6eb86e10; (* arm_UMIN_VEC Q16 Q16 Q24 32 128 *) + 0x3d800411; (* arm_STR Q17 X0 (Immediate_Offset (word 16)) *) + 0x3d800812; (* arm_STR Q18 X0 (Immediate_Offset (word 32)) *) + 0x3d800c13; (* arm_STR Q19 X0 (Immediate_Offset (word 48)) *) + 0x3c840410; (* arm_STR Q16 X0 (Postimmediate_Offset (word 64)) *) + 0xf1000463; (* arm_SUBS X3 X3 (rvalue (word 1)) *) + 0x54fff861; (* arm_BNE (word 2096908) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let MLDSA_USE_HINT_88_EXEC = ARM_MK_EXEC_RULE mldsa_poly_use_hint_88_mc;; + +(* Per-element word function matching the assembly computation *) +let mldsa_use_hint_88_asm = new_definition + `mldsa_use_hint_88_asm (a:int32) (h:int32) : int32 = + let a1 = word_ishr_round (word_2smulh a (word 1477838209)) 17 in + let m:int32 = word_neg(word(bitval(word_igt a (word 8285184)))) in + let a0 = word_add (word_sub a (word_mul a1 (word 190464))) m in + let a1' = word_and a1 (word_not m) in + let delta:int32 = word_or (word_neg(word(bitval(word_ile a0 (word 0))))) (word 1) in + let tmp = 
word_add a1' (word_mul delta h) in + let neg_mask:int32 = word_neg(word(bitval(word_igt tmp (word 43)))) in + let tmp' = word_and tmp (word_not neg_mask) in + word_umin tmp' (word 43)`;; + +(* Numeric description of the assembly's UseHint path, exposing the Barrett + approximation used by the code. Connected to the FIPS 204 definition + mldsa_use_hint_88 via MLDSA_USE_HINT_88_EQUIV below. *) +let mldsa_use_hint_88_code = new_definition + `mldsa_use_hint_88_code (a:num) (h:num) = + let a1_raw = ((((a + 127) DIV 128) * 11275 + 8388608) DIV 16777216) in + let a1 = if a1_raw > 43 then 0 else a1_raw in + let a0:int = &a - &a1 * &190464 in + let a0' = if a0 > &4190208 then a0 - &8380417 else a0 in + if h = 0 then a1 + else if a0' > &0 then if a1 = 43 then 0 else a1 + 1 + else if a1 = 0 then 43 else a1 - 1`;; + +(* ========================================================================= *) +(* Helper lemmas *) +(* ========================================================================= *) + +let WORD_2SMULH_NOSATURATE_88 = prove( + `!a:int32. 
val a < 8380417 + ==> word_2smulh a (word 1477838209:int32) : int32 = + iword((&2 * &(val a) * &1477838209) div &2 pow 32)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[word_2smulh; DIMINDEX_32] THEN + ASM_SIMP_TAC[MLDSA_IVAL_VAL] THEN + CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[iword_saturate; word_INT_MIN; word_INT_MAX; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN + SIMP_TAC[INT_OF_NUM_DIV] THEN + REWRITE_TAC[INT_POS_NEG_BOUND] THEN + SUBGOAL_THEN `~(&((2 * val(a:int32) * 1477838209) DIV 4294967296):int > &2147483647)` + (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_ARITH `~(x:int > y) <=> x <= y`; INT_OF_NUM_LE] THEN + TRANS_TAC LE_TRANS `(2 * 8380416 * 1477838209) DIV 4294967296` THEN + CONJ_TAC THENL + [MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV]);; + +let VAL_DECOMPOSE_A1_88 = prove( + `!a:int32. val a < 8380417 + ==> val(word_ishr_round (word_2smulh a (word 1477838209:int32)) 17 : int32) + = ((2 * val a * 1477838209) DIV 4294967296 + 65536) DIV 131072`, + GEN_TAC THEN DISCH_TAC THEN + ASM_SIMP_TAC[WORD_2SMULH_NOSATURATE_88] THEN + SUBGOAL_THEN `(2 * val(a:int32) * 1477838209) DIV 4294967296 < 2147483648` + ASSUME_TAC THENL + [TRANS_TAC LT_TRANS `(2 * 8380416 * 1477838209) DIV 4294967296 + 1` THEN + CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ABBREV_TAC `t:int32 = iword(&((2 * val(a:int32) * 1477838209) DIV 4294967296))` THEN + SUBGOAL_THEN `val(t:int32) = (2 * val(a:int32) * 1477838209) DIV 4294967296` + ASSUME_TAC THENL + [EXPAND_TAC "t" THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `val(t:int32) < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; 
ALL_TAC] THEN + REWRITE_TAC[word_ishr_round] THEN CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV THEN + SUBGOAL_THEN `ival(t:int32) = &(val t)` ASSUME_TAC THENL + [SIMP_TAC[ival; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `(val(t:int32) + 65536) DIV 131072 < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + TRANS_TAC LT_TRANS `(5767167 + 65536) DIV 131072 + 1` THEN CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + ASM_SIMP_TAC[VAL_IWORD_NUM_32] THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN + UNDISCH_THEN `val(t:int32) = (2 * val(a:int32) * 1477838209) DIV 4294967296` + (SUBST1_TAC o SYM) THEN ASM_REWRITE_TAC[]);; + +let WORD_IGT_THRESHOLD_88 = BITBLAST_RULE + `!a:int32. val a < 8380417 + ==> word_igt a (word 8285184:int32) <=> val a > 8285184`;; + +let A1_BOUND_88 = prove( + `!a. a < 8380417 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 44`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8380416 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `751819508` (SPEC `d * 11275 + 8388608` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 11275 <= 65472 * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A1_BOUND_NOWRAP_88 = prove( + `!a. 
a <= 8285184 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 43`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8285184 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `738196808` (SPEC `d * 11275 + 8388608` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 11275 <= 64728 * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +(* Barrett equivalence for _88: 45-interval case analysis *) +let BARRETT_INTERVAL_88 = prove( + `!a lo hi k. + lo <= a /\ a <= hi /\ + k * 131072 <= (2 * lo * 1477838209) DIV 4294967296 + 65536 /\ + (2 * hi * 1477838209) DIV 4294967296 + 65536 < (k + 1) * 131072 /\ + k * 16777216 <= (lo + 127) DIV 128 * 11275 + 8388608 /\ + (hi + 127) DIV 128 * 11275 + 8388608 < (k + 1) * 16777216 + ==> ((2 * a * 1477838209) DIV 4294967296 + 65536) DIV 131072 = k /\ + ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + CONJ_TAC THEN MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THENL + [CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(2 * lo * 1477838209) DIV 4294967296 + 65536` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 65536 <= y + 65536 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(2 * hi * 1477838209) DIV 4294967296 + 65536` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 65536 <= y + 65536 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]; + CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(lo + 127) DIV 128 * 11275 + 8388608` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 8388608 <= y + 8388608 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(hi + 127) DIV 128 * 11275 
+ 8388608` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 8388608 <= y + 8388608 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]]);; + +let BARRETT_EQUIV_88 = prove( + `!a. a < 8380417 ==> + ((2 * a * 1477838209) DIV 4294967296 + 65536) DIV 131072 = + ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216`, + GEN_TAC THEN DISCH_TAC THEN + let intervals = [ + (0, 95232); (95233, 285696); (285697, 476160); (476161, 666624); + (666625, 857088); (857089, 1047552); (1047553, 1238016); + (1238017, 1428480); (1428481, 1618944); (1618945, 1809408); + (1809409, 1999872); (1999873, 2190336); (2190337, 2380800); + (2380801, 2571264); (2571265, 2761728); (2761729, 2952192); + (2952193, 3142656); (3142657, 3333120); (3333121, 3523584); + (3523585, 3714048); (3714049, 3904512); (3904513, 4094976); + (4094977, 4285440); (4285441, 4475904); (4475905, 4666368); + (4666369, 4856832); (4856833, 5047296); (5047297, 5237760); + (5237761, 5428224); (5428225, 5618688); (5618689, 5809152); + (5809153, 5999616); (5999617, 6190080); (6190081, 6380544); + (6380545, 6571008); (6571009, 6761472); (6761473, 6951936); + (6951937, 7142400); (7142401, 7332864); (7332865, 7523328); + (7523329, 7713792); (7713793, 7904256); (7904257, 8094720); + (8094721, 8285184); (8285185, 8380416)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("a",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`a:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_88 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + + + +let WORD_SUB_SIGN_88 = BITBLAST_RULE + `!a:int32 b:int32. 
val b <= 8189952 /\ val a <= 8285184 ==> + ((bit 31 (word_sub a b) \/ word_sub a b = word 0) <=> val a <= val b)`;; + +let WRAP_A0_NEGATIVE_88 = BITBLAST_RULE + `!a:int32. val a < 8380417 /\ val a > 8285184 + ==> bit 31 (word_add (word_sub a (word 8380416:int32)) (word 4294967295:int32))`;; + +let A1_WRAP_88 = prove( + `!a. 8285184 < a /\ a < 8380417 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 = 44`, + GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `44 <= ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216` + ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `a + 127` (SPEC `8285185 + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `d * 11275 + 8388608` (SPEC `738208083` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `64729 * 11275 <= d * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + MP_TAC(SPEC `a:num` A1_BOUND_88) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_ARITH_TAC);; + +let A0_UPPER_88 = prove( + `!a. 
a <= 8285184 + ==> a < (((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 + 1) * 190464`, + GEN_TAC THEN DISCH_TAC THEN + ABBREV_TAC `nv = ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216` THEN + SUBGOAL_THEN `nv * 16777216 <= (a + 127) DIV 128 * 11275 + 8388608` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN MP_TAC(SPECL [`(a + 127) DIV 128 * 11275 + 8388608`; `16777216`] (CONJUNCT1 DIVISION_SIMP)) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(a + 127) DIV 128 <= 64728` ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `8285184 + 127` (SPEC `a + 127` DIV_MONO))) THEN ANTS_TAC THENL [ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 16777216 <= 64728 * 11275 + 8388608` ASSUME_TAC THENL + [SUBGOAL_THEN `(a + 127) DIV 128 * 11275 <= 64728 * 11275` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ASM_ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 43` ASSUME_TAC THENL [CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_ARITH_TAC);; + +let WORD_IGT_43_BOUND = BITBLAST_RULE + `!a:int32. val a <= 43 ==> ~(word_igt a (word 43:int32))`;; +let WORD_IGT_43_ADD1 = BITBLAST_RULE + `!a:int32. val a <= 42 ==> ~(word_igt (word_add a (word 1:int32)) (word 43:int32))`;; +let WORD_IGT_43_SUB1 = BITBLAST_RULE + `!a:int32. val a <= 43 /\ ~(val a = 0) ==> + ~(word_igt (word_add a (word 4294967295:int32)) (word 43:int32))`;; +let WORD_IGT_43_TRUE = BITBLAST_RULE + `word_igt (word 44:int32) (word 43:int32)`;; + +let ELEMENT_CORRECT_88 = prove( + `!a:int32 h:int32. 
+ val a < 8380417 /\ val h <= 1 + ==> val(mldsa_use_hint_88_asm a h) = mldsa_use_hint_88_code (val a) (val h)`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[mldsa_use_hint_88_asm; mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ABBREV_TAC `nv = ((val(a:int32) + 127) DIV 128 * 11275 + 8388608) DIV 16777216` THEN + SUBGOAL_THEN `val(word_ishr_round (word_2smulh (a:int32) (word 1477838209)) 17 : int32) = nv` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN TRANS_TAC EQ_TRANS `((2 * val(a:int32) * 1477838209) DIV 4294967296 + 65536) DIV 131072` THEN CONJ_TAC THENL [MATCH_MP_TAC VAL_DECOMPOSE_A1_88 THEN ASM_REWRITE_TAC[]; MATCH_MP_TAC BARRETT_EQUIV_88 THEN ASM_REWRITE_TAC[]]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 44` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND_88) THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_igt (a:int32) (word 8285184:int32) <=> val a > 8285184` SUBST1_TAC THENL [MP_TAC(SPEC `a:int32` WORD_IGT_THRESHOLD_88) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(a:int32) > 8285184` THENL + [(* Wrap case: val a > 8285184, nv = 44 *) + REWRITE_TAC[ASSUME `val(a:int32) > 8285184`; bitval] THEN + SUBGOAL_THEN `nv = 44` SUBST_ALL_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_WRAP_88) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + ABBREV_TAC `a1w = word_ishr_round (word_2smulh (a:int32) (word 1477838209)) 17 : int32` THEN + SUBGOAL_THEN `a1w = (word 44:int32)` SUBST_ALL_TAC THENL [EXPAND_TAC "a1w" THEN ONCE_REWRITE_TAC[GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_ILE_ZERO_32] THEN + MP_TAC(SPEC `a:int32` WRAP_A0_NEGATIVE_88) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + SUBGOAL_THEN `~((if int_gt (&(val(a:int32))) (&4190208) then &(val a) - 
&8380417 else &(val a):int) > &0)` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN COND_CASES_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT] (ASSUME `val(a:int32) < 8380417`)) THEN INT_ARITH_TAC; + POP_ASSUM(MP_TAC o REWRITE_RULE[INT_GT; INT_NOT_LT]) THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_GT; INT_GT] (ASSUME `val(a:int32) > 8285184`)) THEN INT_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(h:int32) = 0` THENL + [REWRITE_TAC[ASSUME `val(h:int32) = 0`] THEN + SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV; + REWRITE_TAC[ASSUME `~(val(h:int32) = 0)`] THEN + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV] + ;(* Nowrap case: val a <= 8285184 *) + REWRITE_TAC[ASSUME `~(val(a:int32) > 8285184)`; bitval] THEN + SUBGOAL_THEN `nv <= 43` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND_NOWRAP_88) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `(if nv > 43 then 0 else nv) = nv` SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ABBREV_TAC `a1w = word_ishr_round (word_2smulh (a:int32) (word 1477838209)) 17 : int32` THEN + SUBGOAL_THEN `a1w = (word nv:int32)` SUBST_ALL_TAC THENL [EXPAND_TAC "a1w" THEN GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_REFL; WORD_ILE_ZERO_32; WORD_ADD_0; WORD_AND_ONES_32] THEN + SUBGOAL_THEN `nv * 190464 <= 8189952` ASSUME_TAC THENL [SUBGOAL_THEN `nv * 190464 <= 43 * 190464` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] 
THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 190464:int32)) = nv * 190464` ASSUME_TAC THENL [REWRITE_TAC[VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 190464:int32)) <= 8189952` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(bit 31 (word_sub (a:int32) (word_mul (word nv:int32) (word 190464:int32))) \/ word_sub a (word_mul (word nv) (word 190464)) = word 0) <=> val a <= nv * 190464` SUBST1_TAC THENL + [MP_TAC(ISPECL [`a:int32`; `word_mul (word nv:int32) (word 190464:int32)`] WORD_SUB_SIGN_88) THEN ASM_REWRITE_TAC[] THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[bitval] THEN + SUBGOAL_THEN `val(word nv:int32) = nv` ASSUME_TAC THENL [MATCH_MP_TAC VAL_WORD_EQ THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~(word_igt (word nv:int32) (word 43:int32))` ASSUME_TAC THENL [MP_TAC(SPEC `word nv:int32` WORD_IGT_43_BOUND) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(h:int32) = 0` THENL + [(* h = 0 nowrap *) + REWRITE_TAC[ASSUME `val(h:int32) = 0`] THEN + SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_0; WORD_ADD_0] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32; VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[MIN] THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC + ;(* h = 1 nowrap *) + REWRITE_TAC[ASSUME `~(val(h:int32) = 0)`] THEN + 
SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_1_32; WORD_AND_ONES_32] THEN + SUBGOAL_THEN `val(word nv:int32) <= 43` ASSUME_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(int_gt (&(val(a:int32)) - &nv * &190464) (&4190208))` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN MP_TAC(SPEC `val(a:int32)` A0_UPPER_88) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD] (ASSUME `val(a:int32) < (nv + 1) * 190464`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 190464` THENL + [MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `190464`] REAL_INT_GT_BRIDGE) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]; + MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `190464`] REAL_INT_GT_BRIDGE_POS) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]] THENL + [(* delta = -1: a0' <= 0 *) + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32] THEN + ASM_CASES_TAC `nv = 0` THENL + [ASM_REWRITE_TAC[] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV + ;REWRITE_TAC[ASSUME `~(nv = 0)`] THEN + SUBGOAL_THEN `~word_igt (word_add (word nv:int32) (word 4294967295)) (word 43:int32)` ASSUME_TAC THENL + [MP_TAC(SPEC `word nv:int32` WORD_IGT_43_SUB1) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN MATCH_MP_TAC THEN CONJ_TAC THENL [ASM_REWRITE_TAC[]; REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[] THEN ASM_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32; VAL_WORD_UMIN; VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` 
THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(nv + 4294967295) MOD 4294967296 = nv - 1` SUBST1_TAC THENL [SUBGOAL_THEN `nv + 4294967295 = (nv - 1) + 1 * 4294967296` SUBST1_TAC THENL [UNDISCH_TAC `~(nv = 0)` THEN ARITH_TAC; REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC]; ALL_TAC] THEN + REWRITE_TAC[MIN] THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC] + ;(* delta = +1: a0' > 0 *) + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32] THEN + ASM_CASES_TAC `nv = 43` THENL + [ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN MP_TAC WORD_IGT_43_TRUE THEN DISCH_TAC THEN ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV + ;REWRITE_TAC[ASSUME `~(nv = 43)`] THEN + SUBGOAL_THEN `~word_igt (word_add (word nv:int32) (word 1)) (word 43:int32)` ASSUME_TAC THENL + [MP_TAC(SPEC `word nv:int32` WORD_IGT_43_ADD1) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN MATCH_MP_TAC THEN UNDISCH_TAC `nv <= 43` THEN UNDISCH_TAC `~(nv = 43)` THEN UNDISCH_TAC `val(word nv:int32) = nv` THEN ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32; VAL_WORD_UMIN; VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(nv + 1) MOD 4294967296 = nv + 1` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[MIN] THEN UNDISCH_TAC `nv <= 43` THEN UNDISCH_TAC `~(nv = 43)` THEN ARITH_TAC]]]]);; + +let ELEMENT_CORRECT_WORD_88 = prove( + `!a:int32 h:int32. 
+ val a < 8380417 /\ val h <= 1 + ==> mldsa_use_hint_88_asm a h = + word(mldsa_use_hint_88_code (val a) (val h))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN + AP_TERM_TAC THEN MP_TAC(SPECL [`a:int32`; `h:int32`] ELEMENT_CORRECT_88) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]));; + +(* ========================================================================= *) +(* Correctness proof, code-aligned spec (intermediate) *) +(* ========================================================================= *) + +let MLDSA_USE_HINT_88_CORRECT_CODE = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH mldsa_poly_use_hint_88_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_use_hint_88_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH mldsa_poly_use_hint_88_mc - 4) /\ + (!i. 
i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i)) (val(y i))))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + + MAP_EVERY X_GEN_TAC + [`b:int64`; `a:int64`; `h:int64`; + `x:num->int32`; `y:num->int32`; `pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; + NONOVERLAPPING_CLAUSES; ALL; + fst MLDSA_USE_HINT_88_EXEC] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + GLOBALIZE_PRECONDITION_TAC THEN + CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV EXPAND_CASES_CONV))) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC[SOME_FLAGS; MODIFIABLE_SIMD_REGS] THEN + + ENSURES_INIT_TAC "s0" THEN + MEMORY_128_FROM_32_TAC "a" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + MEMORY_128_FROM_32_TAC "h" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + + MAP_EVERY (fun n -> ARM_STEPS_TAC MLDSA_USE_HINT_88_EXEC [n] THEN + SIMD_SIMPLIFY_TAC[]) + (1--1006) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + REPEAT(FIRST_X_ASSUM(STRIP_ASSUME_TAC o + CONV_RULE (SIMD_SIMPLIFY_CONV []) o + CONV_RULE(READ_MEMORY_SPLIT_CONV 2) o + check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN + + CONV_TAC(TOP_DEPTH_CONV EXPAND_CASES_CONV) THEN + CONV_TAC(DEPTH_CONV NUM_MULT_CONV THENC DEPTH_CONV NUM_ADD_CONV) THEN + REWRITE_TAC[WORD_ADD_0] THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN ASM_REWRITE_TAC[] THEN + + (* Push word_subword through SIMD ops to per-element form *) + REWRITE_TAC[WORD_SUBWORD_AND; WORD_SUBWORD_OR] THEN + let WSN_TAC = REWRITE_TAC(map (fun n -> prove( + subst [mk_small_numeral n, `n:num`] + `!x:int128. 
word_subword(word_not x) (n,32):int32 = word_not(word_subword x (n,32))`, + GEN_TAC THEN MATCH_MP_TAC WORD_SUBWORD_NOT THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_128] THEN ARITH_TAC)) [0;32;64;96]) in + WSN_TAC THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + let EC_DEEP = CONV_RULE(DEPTH_CONV WORD_NUM_RED_CONV) + (CONV_RULE(DEPTH_CONV(INT_RED_CONV ORELSEC NUM_RED_CONV)) + (CONV_RULE(TOP_DEPTH_CONV let_CONV) + (REWRITE_RULE[mldsa_use_hint_88_asm; word_2smulh; word_ishr_round; + DIMINDEX_32] ELEMENT_CORRECT_WORD_88))) in + let EC_OR = ONCE_REWRITE_RULE[WORD_OR_SYM] EC_DEEP in + REPEAT CONJ_TAC THEN + (MATCH_MP_TAC EC_OR ORELSE MATCH_MP_TAC EC_DEEP) THEN + CONJ_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ARITH_TAC);; + +(* ========================================================================= *) +(* Subroutine form (intermediate, code-aligned) *) +(* ========================================================================= *) + +let ENSURES_STRENGTHEN_POST = prove( + `!P (Q:armstate->bool) Q' R. + ensures arm P Q' R /\ (!s. Q' s ==> Q s) ==> ensures arm P Q R`, + REPEAT GEN_TAC THEN DISCH_THEN(CONJUNCTS_THEN2 MP_TAC ASSUME_TAC) THEN + REWRITE_TAC[ensures] THEN MATCH_MP_TAC MONO_FORALL THEN + X_GEN_TAC `s0:armstate` THEN MATCH_MP_TAC MONO_IMP THEN REWRITE_TAC[] THEN + MP_TAC(BETA_RULE(ISPECL [`arm`; + `\s':armstate. (Q':armstate->bool) s' /\ (R:armstate->armstate->bool) (s0:armstate) s'`; + `\s':armstate. (Q:armstate->bool) s' /\ (R:armstate->armstate->bool) (s0:armstate) s'`] + EVENTUALLY_MONO)) THEN + ANTS_TAC THENL [ASM_MESON_TAC[]; MESON_TAC[]]);; + +let MLDSA_USE_HINT_88_CORRECT_BOUND_CODE = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH mldsa_poly_use_hint_88_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_use_hint_88_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. 
i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH mldsa_poly_use_hint_88_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i)) (val(y i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + MATCH_MP_TAC ENSURES_STRENGTHEN_POST THEN + EXISTS_TAC + `\s. read PC s = word(pc + LENGTH mldsa_poly_use_hint_88_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i:int32)) (val(y i:int32))))` THEN + CONJ_TAC THENL + [MATCH_MP_TAC MLDSA_USE_HINT_88_CORRECT_CODE THEN ASM_REWRITE_TAC[]; + REWRITE_TAC[] THEN REPEAT STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + FIRST_X_ASSUM(MP_TAC o SPEC `i:num`) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC(ARITH_RULE `x < 44 ==> x MOD 4294967296 < 44`) THEN + REWRITE_TAC[mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN ASM_ARITH_TAC]);; + +(* Intermediate subroutine correctness against the code-aligned spec. + Bridged to the public FIPS 204-aligned theorem below via + MLDSA_USE_HINT_88_EQUIV. *) +let MLDSA_USE_HINT_88_SUBROUTINE_CORRECT_CODE = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH mldsa_poly_use_hint_88_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. 
aligned_bytes_loaded s (word pc) mldsa_poly_use_hint_88_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i)) (val(y i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REWRITE_TAC[fst MLDSA_USE_HINT_88_EXEC] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ARM_ADD_RETURN_NOSTACK_TAC MLDSA_USE_HINT_88_EXEC + (CONV_RULE(ONCE_DEPTH_CONV NUM_REDUCE_CONV) + (REWRITE_RULE[fst MLDSA_USE_HINT_88_EXEC] + MLDSA_USE_HINT_88_CORRECT_BOUND_CODE)));; + + +(* ========================================================================= *) +(* FIPS 204 = code-aligned equivalence *) +(* ========================================================================= *) + +let LINEARIZE_DIV_MOD_TAC_88 = + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +(* Prove r DIV 190464 = k via DIV_SANDWICH + LE_MULT_RCANCEL *) +let DIV_190464_TAC k = + let k_num = mk_small_numeral k and km1 = mk_small_numeral (k-1) + and kp1 = mk_small_numeral (k+1) + and q = mk_var("q",`:num`) and le = `(<=):num->num->bool` + and lt = `(<):num->num->bool` + and c = `190464` in + let mk_mul a b = mk_binop (rator(rator `0*0`)) a b in + MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THEN + 
REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; c] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, q) THEN + REPEAT GEN_TAC THEN STRIP_TAC THEN + ASM_CASES_TAC(mk_comb(mk_comb(le, q), km1)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul q c), + mk_mul km1 c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul k_num c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_CASES_TAC(mk_comb(mk_comb(lt, k_num), q)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul kp1 c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC]];; + +(* Replace (r - r MOD 190464) DIV 190464 with r DIV 190464 *) +let DIV_MOD_TO_DIV_TAC_88 = + SUBGOAL_THEN `(r - r MOD 190464) DIV 190464 = r DIV 190464` SUBST1_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 = 190464 * r DIV 190464` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV; ALL_TAC];; + +(* Lower half nowrap: dismiss wrap cond, reduce, prove r DIV 190464 = k *) +let DECOMPOSE_R1_LOWER_TAC_88 = + SUBGOAL_THEN `~((&r:int) - &(r MOD 190464) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN LINEARIZE_DIV_MOD_TAC_88; + ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + DIV_MOD_TO_DIV_TAC_88 THEN + CONV_TAC SYM_CONV THEN + LINEARIZE_DIV_MOD_TAC_88;; + +(* Upper half nowrap: dismiss wrap cond, reduce, prove r DIV 
190464 + 1 = k *) +let DECOMPOSE_R1_UPPER_TAC_88 = + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~((&r:int) - (&(r MOD 190464) - &190464) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [REWRITE_TAC[INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`] THEN + ASM_SIMP_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + LINEARIZE_DIV_MOD_TAC_88; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 + 190464 = 190464 * (r DIV 190464 + 1)` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464 + 1`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN SUBST1_TAC THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +let DECOMPOSE_R1_NOWRAP_TAC_88 = + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THEN + TRY DECOMPOSE_R1_LOWER_TAC_88 THEN TRY DECOMPOSE_R1_UPPER_TAC_88;; + +let DECOMPOSE_88_R1_EQUIV = time prove( + `!r. 
r < 8380417 ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = + decompose_88_r1 r`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_CASES_TAC `r <= 8285184` THENL + [ALL_TAC; + (* Wrap zone *) + SUBGOAL_THEN `8285184 < r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `decompose_88_r1 r = 0` SUBST1_TAC THENL + [REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; + ALL_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THENL + [(* Lower wrap: r DIV 190464 = 44 *) + SUBGOAL_THEN `r DIV 190464 = 44` ASSUME_TAC THENL + [DIV_190464_TAC 44; ALL_TAC] THEN + SUBGOAL_THEN `44 * 190464 + r MOD 190464 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + ASM_ARITH_TAC; + (* Upper wrap: r DIV 190464 = 43 *) + SUBGOAL_THEN `r DIV 190464 = 43` ASSUME_TAC THENL + [DIV_190464_TAC 43; ALL_TAC] THEN + SUBGOAL_THEN `43 * 190464 + r MOD 190464 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + ALL_TAC] THEN + MP_TAC(SPEC `r:num` A1_WRAP_88) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONV_TAC NUM_REDUCE_CONV] THEN + (* Nowrap zone: Barrett <= 43, so if > 43 then 0 else Barrett = Barrett *) + SUBGOAL_THEN 
`((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 43` + ASSUME_TAC THENL + [MATCH_MP_TAC A1_BOUND_NOWRAP_88 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 > 43)` + (fun th -> REWRITE_TAC[th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + (* Nowrap zone: unfold and do interval cascade *) + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + let intervals = [ + (0, 95232); (95233, 285696); (285697, 476160); (476161, 666624); + (666625, 857088); (857089, 1047552); (1047553, 1238016); + (1238017, 1428480); (1428481, 1618944); (1618945, 1809408); + (1809409, 1999872); (1999873, 2190336); (2190337, 2380800); + (2380801, 2571264); (2571265, 2761728); (2761729, 2952192); + (2952193, 3142656); (3142657, 3333120); (3333121, 3523584); + (3523585, 3714048); (3714049, 3904512); (3904513, 4094976); + (4094977, 4285440); (4285441, 4475904); (4475905, 4666368); + (4666369, 4856832); (4856833, 5047296); (5047297, 5237760); + (5237761, 5428224); (5428225, 5618688); (5618689, 5809152); + (5809153, 5999616); (5999617, 6190080); (6190081, 6380544); + (6380545, 6571008); (6571009, 6761472); (6761473, 6951936); + (6951937, 7142400); (7142401, 7332864); (7332865, 7523328); + (7523329, 7713792); (7713793, 7904256); (7904257, 8094720); + (8094721, 8285184)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("r",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`r:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_88 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC 
NUM_REDUCE_CONV THEN + DECOMPOSE_R1_NOWRAP_TAC_88 in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + +let R1_IS_DIV_LOWER_88 = prove( + `!r. r < 8380417 /\ r MOD 190464 * 2 <= 190464 /\ + ~((&r:int) - &(r MOD 190464) = &8380416) ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = r DIV 190464`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` LOWER_NONWRAP_R1_88) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +let R1_IS_DIV_PLUS1_UPPER_88 = prove( + `!r. r < 8380417 /\ ~(r MOD 190464 * 2 <= 190464) /\ + ~((&r:int) - (&(r MOD 190464) - &190464) = &8380416) ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = r DIV 190464 + 1`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` UPPER_NONWRAP_R1_88) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +(* Upper nowrap: substitute Barrett = r DIV 190464 + 1, use INT_MOD_RESIDUE *) +let R0_SIGN_UPPER_NOWRAP_TAC_88 = + MP_TAC(SPEC `r:num` R1_IS_DIV_PLUS1_UPPER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_MUL] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `(a:int) - (b + &1) * c = a - b * c - c`] THEN + REWRITE_TAC[INT_ARITH `x - &190464 > &0 <=> x > &190464`; + INT_ARITH `x - &190464 - &8380417 > &0 <=> x > &8570881`; 
+ INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Lower nowrap: substitute Barrett = r DIV 190464, use INT_MOD_RESIDUE *) +let R0_SIGN_LOWER_NOWRAP_TAC_88 = + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Wrap: derive 8285184 < r, use DECOMPOSE_88_R1_EQUIV to get Barrett = 0 *) +let R0_SIGN_WRAP_TAC_88 = + SUBGOAL_THEN `8285184 < r` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> t = `&8380416:int`)) (concl th) && + not(is_neg(concl th)))) THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ; + INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`; + GSYM INT_OF_NUM_ADD] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + REWRITE_TAC[INT_ARITH `x - &1 > &0 <=> x > &1`; + INT_ARITH `(x - &190464) - &1 > &0 <=> x > &190465`; + INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let DECOMPOSE_88_R0_SIGN = time prove( + `!r. 
r < 8380417 ==> + let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + let a1 = if a1_raw > 43 then 0 else a1_raw in + let a0':int = if (&r:int) - &a1 * &190464 > &4190208 + then &r - &a1 * &190464 - &8380417 + else &r - &a1 * &190464 in + (decompose_88_r0 r > &0 <=> a0' > &0) /\ + (decompose_88_r0 r <= &0 <=> ~(a0' > &0))`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[INT_ARITH `(x:int) <= &0 <=> ~(x > &0)`] THEN + MATCH_MP_TAC(TAUT `(p <=> q) ==> (p <=> q) /\ (~p <=> ~q)`) THEN + REWRITE_TAC[decompose_88_r0; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ONCE_REWRITE_TAC[COND_RAND] THEN REWRITE_TAC[SND; FST] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + TRY R0_SIGN_LOWER_NOWRAP_TAC_88 THEN + TRY R0_SIGN_UPPER_NOWRAP_TAC_88 THEN + TRY R0_SIGN_WRAP_TAC_88 THEN + TRY( + (* Contradiction: lower nowrap with > 4190208 *) + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - &((let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw)) * &190464 = &(r MOD 190464)` ASSUME_TAC THENL + [CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(&(r MOD 190464) > (&4190208:int))` MP_TAC THENL + [REWRITE_TAC[INT_NOT_LT; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[INT_OF_NUM_GT] THEN ASM_ARITH_TAC + ));; + +(* R1 
equivalence for _88 is DECOMPOSE_88_R1_EQUIV above (used by MLDSA_USE_HINT_88_EQUIV below). + + It shows decompose_88_r1 r equals the a1 value from the code-aligned spec: + let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw *) + +let MLDSA_USE_HINT_88_EQUIV = prove( + `!r h. r < 8380417 /\ h <= 1 + ==> mldsa_use_hint_88 h r = mldsa_use_hint_88_code r h`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[MLDSA_USE_HINT_88_UNFOLD] THEN + REWRITE_TAC[mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + DISCH_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R0_SIGN) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN STRIP_TAC THEN + ASM_CASES_TAC `h = 0` THENL + [ASM_REWRITE_TAC[ARITH_RULE `~(0 = 1)`]; ALL_TAC] THEN + SUBGOAL_THEN `h = 1` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `decompose_88_r0 r > &0` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[]);; + +(* ========================================================================= *) +(* Public subroutine correctness (FIPS 204-aligned) *) +(* ========================================================================= *) + +(* Postcondition is stated in terms of mldsa_use_hint_88 from FIPS 204 + (Algorithm 40), with the output bound < 44 as a corollary. + Derived from MLDSA_USE_HINT_88_SUBROUTINE_CORRECT_CODE by + rewriting mldsa_use_hint_88_code -> mldsa_use_hint_88 via + MLDSA_USE_HINT_88_EQUIV. *) +let MLDSA_USE_HINT_88_SUBROUTINE_CORRECT = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH mldsa_poly_use_hint_88_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) /\ + (!i. i < 256 ==> val((x:num->int32) i) < 8380417) /\ + (!i. i < 256 ==> val((y:num->int32) i) <= 1) + ==> ensures arm + (\s.
aligned_bytes_loaded s (word pc) mldsa_poly_use_hint_88_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REPEAT GEN_TAC THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + SUBGOAL_THEN + `!i. i < 256 ==> + mldsa_use_hint_88 (val((y:num->int32) i)) (val((x:num->int32) i)) = + mldsa_use_hint_88_code (val(x i)) (val(y i))` + (fun th -> SIMP_TAC[th]) THENL + [REPEAT STRIP_TAC THEN MATCH_MP_TAC MLDSA_USE_HINT_88_EQUIV THEN + CONJ_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]; + MATCH_MP_TAC MLDSA_USE_HINT_88_SUBROUTINE_CORRECT_CODE THEN + ASM_REWRITE_TAC[]]);; + + +(* ========================================================================= *) +(* Constant-time and memory safety proof. *) +(* ========================================================================= *) + +needs "arm/proofs/consttime.ml";; +needs "arm/proofs/subroutine_signatures.ml";; + + +let full_spec,public_vars = mk_safety_spec + ~keep_maychanges:false + (assoc "mldsa_poly_use_hint_88" subroutine_signatures) + MLDSA_USE_HINT_88_SUBROUTINE_CORRECT_CODE + MLDSA_USE_HINT_88_EXEC;; + +let MLDSA_USE_HINT_88_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e b a h pc returnaddress. + nonoverlapping (word pc,LENGTH mldsa_poly_use_hint_88_mc) (b,1024) /\ + nonoverlapping (b,1024) (a,1024) /\ + nonoverlapping (b,1024) (h,1024) + ==> ensures arm + (\s. 
+ aligned_bytes_loaded s (word pc) + mldsa_poly_use_hint_88_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + read events s = e) + (\s. + read PC s = returnaddress /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h b pc returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; b,1024] + [b,1024])) + (\s s'. true)`, + ASSERT_CONCL_TAC full_spec THEN + PROVE_SAFETY_SPEC_TAC ~public_vars:public_vars MLDSA_USE_HINT_88_EXEC);; diff --git a/arm/proofs/specifications.txt b/arm/proofs/specifications.txt index 0115199f2..79c0b1a71 100644 --- a/arm/proofs/specifications.txt +++ b/arm/proofs/specifications.txt @@ -335,6 +335,8 @@ MLDSA_POINTWISE_ACC_L7_SUBROUTINE_CORRECT MLDSA_POINTWISE_ACC_L7_SUBROUTINE_SAFE MLDSA_POINTWISE_SUBROUTINE_CORRECT MLDSA_POINTWISE_SUBROUTINE_SAFE +MLDSA_USE_HINT_88_SUBROUTINE_CORRECT +MLDSA_USE_HINT_88_SUBROUTINE_SAFE MLKEM_BASEMUL_K2_SUBROUTINE_CORRECT MLKEM_BASEMUL_K2_SUBROUTINE_SAFE MLKEM_BASEMUL_K3_SUBROUTINE_CORRECT diff --git a/arm/proofs/subroutine_signatures.ml b/arm/proofs/subroutine_signatures.ml index ac5e134a6..0b363775e 100644 --- a/arm/proofs/subroutine_signatures.ml +++ b/arm/proofs/subroutine_signatures.ml @@ -4555,6 +4555,24 @@ let subroutine_signatures = [ ]) ); +("mldsa_poly_use_hint_88", + ([(*args*) + ("b", "int32_t[static 256]", (*is const?*)"false"); + ("a", "int32_t[static 256]", (*is const?*)"true"); + ("h", "int32_t[static 256]", (*is const?*)"true"); + ], + "void", + [(* input buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ("h", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* output buffers *) + ("b", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + ("mlkem_basemul_k2", ([(*args*) ("r", "int16_t[static 256]", (*is const?*)"false"); diff --git a/benchmarks/benchmark.c b/benchmarks/benchmark.c index d64e5c4ee..eef014cd8 100644 --- a/benchmarks/benchmark.c +++ b/benchmarks/benchmark.c @@ -1111,6 
+1111,7 @@ void call_mldsa_pointwise(void) repeat(mldsa_pointwise_x86((int32_t*)b0,(int32_t void call_mldsa_pointwise_acc_l4(void) repeat(mldsa_pointwise_acc_l4_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) void call_mldsa_pointwise_acc_l5(void) repeat(mldsa_pointwise_acc_l5_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) void call_mldsa_pointwise_acc_l7(void) repeat(mldsa_pointwise_acc_l7_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) +void call_mldsa_poly_use_hint_88(void) {} void call_mldsa_reduce(void) repeat(mldsa_reduce((int32_t*)b0)) void call_mlkem_frombytes(void) repeat(mlkem_frombytes((uint16_t*)b0,(int8_t*)b1)) @@ -1155,6 +1156,7 @@ void call_mldsa_pointwise(void) repeat(mldsa_pointwise((int32_t*)b0,(int32_t*)b1 void call_mldsa_pointwise_acc_l4(void) repeat(mldsa_pointwise_acc_l4((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) void call_mldsa_pointwise_acc_l5(void) repeat(mldsa_pointwise_acc_l5((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) void call_mldsa_pointwise_acc_l7(void) repeat(mldsa_pointwise_acc_l7((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) +void call_mldsa_poly_use_hint_88(void) repeat(mldsa_poly_use_hint_88((int32_t*)b0,(int32_t*)b1,(int32_t*)b2)) void call_mldsa_reduce(void) {} void call_bignum_copy_row_from_table_8n__32_16(void) \ @@ -1629,6 +1631,7 @@ int main(int argc, char *argv[]) timingtest(all,"mldsa_pointwise_acc_l4",call_mldsa_pointwise_acc_l4); timingtest(all,"mldsa_pointwise_acc_l5",call_mldsa_pointwise_acc_l5); timingtest(all,"mldsa_pointwise_acc_l7",call_mldsa_pointwise_acc_l7); + timingtest(arm,"mldsa_poly_use_hint_88",call_mldsa_poly_use_hint_88); timingtest(!arm,"mldsa_reduce",call_mldsa_reduce); timingtest(bmi,"p256_montjadd",call_p256_montjadd); timingtest(all,"p256_montjadd_alt",call_p256_montjadd_alt); diff --git a/common/mlkem_mldsa.ml b/common/mlkem_mldsa.ml index 2858dc298..06c3afad2 100644 --- a/common/mlkem_mldsa.ml 
+++ b/common/mlkem_mldsa.ml @@ -832,7 +832,7 @@ let Q_MUL_COMM = WORD_RULE (* Normalization rules for VPSRLQ/VMOVSHDUP patterns *) let USHR32_SUBWORD = WORD_BLAST `word_subword (word_ushr (x:int64) 32) (0,32):int32 = word_subword x (32,32)`;; - + let DUP32_SUBWORD = WORD_BLAST `word_subword (word_duplicate (word_subword (x:int64) (32,32):int32):int64) (0,32):int32 = word_subword x (32,32)`;; @@ -871,6 +871,19 @@ let MEMORY_128_FROM_32_TAC = READ_MEMORY_MERGE_CONV 2 (subst[itm,n_tm] pat') in MP_TAC(end_itlist CONJ (map f (0--(n-1))));; +(* ML-DSA use_hint: merge 4 x bytes32 into bytes128 at v+boff+16*i, i<n *) +(* ------------------------------------------------------------------------- *) + +let USE_HINT_MEMORY_128_FROM_32_TAC = + let a_tm = `a:int64` and n_tm = `n:num` and i64_ty = `:int64` + and pat = `read (memory :> bytes128(word_add a (word n))) s0` in + fun v boff n -> + let pat' = subst[mk_var(v,i64_ty),a_tm] pat in + let f i = + let itm = mk_small_numeral(boff + 16*i) in + READ_MEMORY_MERGE_CONV 2 (subst[itm,n_tm] pat') in + MP_TAC(end_itlist CONJ (map f (0--(n-1))));; + (* ------------------------------------------------------------------------- *) (* From |- (x == y) (mod m) /\ P to |- (x == y) (mod n) /\ P *) (* ------------------------------------------------------------------------- *) @@ -1949,3 +1962,175 @@ let SIMD_SIMPLIFY_ABBREV_TAC = let tms = sort free_in (find_terms pam (rand(concl th''))) in (MP_TAC th'' THEN MAP_EVERY AUTO_ABBREV_TAC tms THEN DISCH_TAC) (asl,w) in TRY(FIRST_X_ASSUM(ttac o check (simdable o concl)));; +(* ML-DSA use_hint shared infrastructure lemmas *) +(* Used by both poly_use_hint_32 and poly_use_hint_88 proofs *) +(* ========================================================================= *) + +(* ival equals val for values in [0, Q) where Q = 8380417 < 2^31 *) +let MLDSA_IVAL_VAL = prove( + `!a:int32.
val a < 8380417 ==> ival a = &(val a)`, + GEN_TAC THEN DISCH_TAC THEN + SIMP_TAC[ival; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC);; + +(* For natural numbers, &n is never < -2^31 *) +let INT_POS_NEG_BOUND = prove(`!n. ~((&n:int) < --(&2147483648))`, + GEN_TAC THEN REWRITE_TAC[INT_NOT_LT] THEN + MP_TAC(SPEC `n:num` INT_POS) THEN INT_ARITH_TAC);; + +(* val(iword(&n)) = n for n < 2^31 *) +let VAL_IWORD_NUM_32 = prove( + `!n. n < 2147483648 ==> val(iword(&n):int32) = n`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(ISPECL [`&n:int`] (INST_TYPE [`:32`,`:N`] INT_VAL_IWORD)) THEN + REWRITE_TAC[DIMINDEX_32; INT_POS] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; + REWRITE_TAC[INT_OF_NUM_EQ] THEN SIMP_TAC[]]);; + +(* word_ile x 0 in terms of bit 31 (signed non-positive check) *) +let WORD_ILE_ZERO_32 = BITBLAST_RULE + `!x:int32. word_ile x (word 0) <=> bit 31 x \/ x = word 0`;; + +(* val(word_and x (word 15)) = val x MOD 16 *) +let VAL_WORD_AND_15_32 = BITBLAST_RULE + `!x:int32. val(word_and x (word 15:int32)) = val x MOD 16`;; + +(* word_and x all-ones = x *) +let WORD_AND_ONES_32 = prove( + `!x:int32. word_and x (word 4294967295) = x`, + GEN_TAC THEN SUBGOAL_THEN `(word 4294967295 : int32) = word_not(word 0)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; REWRITE_TAC[WORD_AND_NOT0]]);; + +(* word_mul x 1 = x *) +let WORD_MUL_1_32 = prove( + `!x:int32. word_mul x (word 1) = x`, + GEN_TAC THEN REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN REWRITE_TAC[MULT_CLAUSES] THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `x:int32` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV);; + +(* Bridge lemmas: derive both real_gt and int_gt from a single NUM fact. + Needed for native mode where real_gt and int_gt are distinct types. *) +let REAL_INT_GT_BRIDGE = prove( + `!a:num b c. 
a <= b * c ==> + ~(real_gt (&a - &b * &c) (&0)) /\ ~(int_gt (&a - &b * &c) (&0))`, + REPEAT GEN_TAC THEN DISCH_TAC THEN CONJ_TAC THENL + [REWRITE_TAC[real_gt; REAL_NOT_LT] THEN + MP_TAC(REWRITE_RULE[GSYM REAL_OF_NUM_LE; GSYM REAL_OF_NUM_MUL] (ASSUME `a <= b * c`)) THEN REAL_ARITH_TAC; + REWRITE_TAC[INT_GT; INT_NOT_LT] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_MUL] (ASSUME `a <= b * c`)) THEN INT_ARITH_TAC]);; + +let REAL_INT_GT_BRIDGE_POS = prove( + `!a:num b c. ~(a <= b * c) ==> + real_gt (&a - &b * &c) (&0) /\ int_gt (&a - &b * &c) (&0)`, + REPEAT GEN_TAC THEN REWRITE_TAC[NOT_LE] THEN DISCH_TAC THEN CONJ_TAC THENL + [REWRITE_TAC[real_gt] THEN + MP_TAC(REWRITE_RULE[GSYM REAL_OF_NUM_LT; GSYM REAL_OF_NUM_MUL] (ASSUME `b * c < a`)) THEN REAL_ARITH_TAC; + REWRITE_TAC[INT_GT] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL] (ASSUME `b * c < a`)) THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Shared helper lemmas for UseHint proofs *) +(* ========================================================================= *) + +let DIV_SANDWICH = prove( + `!x d k. ~(d = 0) /\ k * d <= x /\ x < (k + 1) * d ==> x DIV d = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `k <= x DIV d` ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_RDIV_EQ] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `x DIV d < k + 1` ASSUME_TAC THENL + [ASM_SIMP_TAC[RDIV_LT_EQ] THEN ASM_ARITH_TAC; ASM_ARITH_TAC]);; + +let INT_MOD_RESIDUE = prove( + `!r m. 
~(m = 0) ==> (&r:int) - &(r DIV m) * &m = &(r MOD m)`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPECL [`r:num`; `m:num`] (CONJUNCT1 DIVISION_SIMP)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD; + GSYM INT_OF_NUM_EQ] THEN + INT_ARITH_TAC);; + +(* ========================================================================= *) +(* FIPS 204 UseHint definitions (Algorithms 36 and 40) *) +(* ========================================================================= *) + +let mldsa_cmod = new_definition + `mldsa_cmod (r:num) (m:num) : int = + if (r MOD m) * 2 <= m then &(r MOD m) else &(r MOD m) - &m`;; + +let mldsa_decompose_88 = new_definition + `mldsa_decompose_88 (r:num) : num # int = + let r0 = mldsa_cmod r 190464 in + if &r - r0 = &8380416 then (0, r0 - &1) + else (num_of_int((&r - r0) div &190464), r0)`;; + +let decompose_88_r1 = new_definition + `decompose_88_r1 (r:num) : num = FST(mldsa_decompose_88 r)`;; + +let decompose_88_r0 = new_definition + `decompose_88_r0 (r:num) : int = SND(mldsa_decompose_88 r)`;; + +let mldsa_use_hint_88 = new_definition + `mldsa_use_hint_88 (h:num) (r:num) : num = + let (r1, r0) = mldsa_decompose_88 r in + if h = 1 /\ r0 > &0 then if r1 = 43 then 0 else r1 + 1 + else if h = 1 /\ r0 <= &0 then if r1 = 0 then 43 else r1 - 1 + else r1`;; + +let LOWER_NONWRAP_R1_88 = prove( + `!r. 
r MOD 190464 * 2 <= 190464 /\ + ~((&r:int) - &(r MOD 190464) = &8380416) ==> + decompose_88_r1 r = r DIV 190464`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; + NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 = 190464 * r DIV 190464` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV);; + +let UPPER_NONWRAP_R1_88 = prove( + `!r. ~(r MOD 190464 * 2 <= 190464) /\ + ~((&r:int) - (&(r MOD 190464) - &190464) = &8380416) ==> + decompose_88_r1 r = r DIV 190464 + 1`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` ASSUME_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 + 190464 = (r DIV 190464 + 1) * 190464` + ASSUME_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + MP_TAC(SPECL [`(r DIV 190464 + 1) * 190464`; `190464`] DIV_MULT) THEN + ARITH_TAC);; + +let MLDSA_USE_HINT_88_UNFOLD = prove( + `!h r. 
mldsa_use_hint_88 h r = + (if h = 1 /\ decompose_88_r0 r > &0 + then if decompose_88_r1 r = 43 then 0 else decompose_88_r1 r + 1 + else if h = 1 /\ decompose_88_r0 r <= &0 + then if decompose_88_r1 r = 0 then 43 else decompose_88_r1 r - 1 + else decompose_88_r1 r)`, + REPEAT GEN_TAC THEN + REWRITE_TAC[mldsa_use_hint_88; decompose_88_r1; decompose_88_r0] THEN + SPEC_TAC(`mldsa_decompose_88 r`, `p:num#int`) THEN + REWRITE_TAC[FORALL_PAIR_THM] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN REWRITE_TAC[]);; diff --git a/include/s2n-bignum.h b/include/s2n-bignum.h index 2c9273750..cc0f0513b 100644 --- a/include/s2n-bignum.h +++ b/include/s2n-bignum.h @@ -1054,6 +1054,10 @@ extern void mldsa_pointwise_acc_l7_x86(int32_t c[S2N_BIGNUM_STATIC 256], const i // Input a[256] (signed 32-bit words); output a[256] (signed 32-bit words) extern void mldsa_reduce(int32_t a[S2N_BIGNUM_STATIC 256]); +// Use hint to correct high bits of decomposition for ML-DSA (parameter set 44) +// Inputs a[256], h[256] (signed 32-bit words); output b[256] (signed 32-bit words) +extern void mldsa_poly_use_hint_88(int32_t b[S2N_BIGNUM_STATIC 256], const int32_t a[S2N_BIGNUM_STATIC 256], const int32_t h[S2N_BIGNUM_STATIC 256]); + // Scalar product of 2-element polynomial vectors in NTT domain, with mulcache // Inputs a[512], b[512], bt[256] (signed 16-bit words); output r[256] (signed 16-bit words) extern void mlkem_basemul_k2(int16_t r[S2N_BIGNUM_STATIC 256],const int16_t a[S2N_BIGNUM_STATIC 512],const int16_t b[S2N_BIGNUM_STATIC 512],const int16_t bt[S2N_BIGNUM_STATIC 256]); diff --git a/tests/test.c b/tests/test.c index a545ceeac..fedba5751 100644 --- a/tests/test.c +++ b/tests/test.c @@ -13290,7 +13290,7 @@ int test_mldsa_intt(void) } // Reference implementation for mldsa_nttunpack -// +// // SPECIFICATION: // This function performs an 8x8 matrix transpose within each of 4 blocks of 64 coefficients. // It converts from AVX2 lane-interleaved layout to sequential layout. 
@@ -13311,16 +13311,16 @@ void reference_mldsa_nttunpack(int32_t a[256]) { int32_t temp[256]; int i; - + // Copy input to temp for (i = 0; i < 256; i++) { temp[i] = a[i]; } - + // Apply the transpose specification to each of 4 blocks for (int block = 0; block < 4; block++) { int base = block * 64; - + for (i = 0; i < 64; i++) { // Specification: output[base + i] = input[base + (i % 8) * 8 + (i / 8)] int src_index = base + (i % 8) * 8 + (i / 8); @@ -13331,7 +13331,7 @@ void reference_mldsa_nttunpack(int32_t a[256]) int test_mldsa_nttunpack(void) { - // Skip test on non-x86_64 architectures + // Skip test on non-x86_64 architectures if (get_arch_name() != ARCH_X86_64) { return 0; } @@ -13386,6 +13386,97 @@ int test_mldsa_nttunpack(void) #endif } +// Reference implementation of mldsa_poly_use_hint_88 for ML-DSA parameter set 44 +// GAMMA2 = (Q-1)/88 = 95232, output range [0, 43] +// Matches the exact assembly algorithm using SQDMULH-based Barrett decomposition +void reference_mldsa_poly_use_hint_88(int32_t b[256], const int32_t a[256], const int32_t h[256]) +{ + const int32_t TWO_GAMMA2 = 190464; + const int32_t THRESHOLD = 8285184; // 87 * GAMMA2 + const int32_t BARRETT = 1477838209; // 0x58160581 + for (int i = 0; i < 256; i++) { + int32_t ai = a[i]; + // Decompose using SQDMULH + SRSHR (matching assembly) + // sqdmulh: (2 * ai * BARRETT) >> 32 + int32_t sqdmulh_result = (int32_t)(((int64_t)2 * ai * BARRETT) >> 32); + // srshr by 17: (x + (1 << 16)) >> 17 (signed rounding shift right) + int32_t a1 = (sqdmulh_result + (1 << 16)) >> 17; + // a0 = ai - a1 * 2*GAMMA2 + int32_t a0 = ai - a1 * TWO_GAMMA2; + // Wraparound: for ai in (THRESHOLD, Q) a1 is 44 here; set a1=0, a0 = ai - Q + if (ai > THRESHOLD) { + a1 = 0; + a0 = a0 + (-1); + } + // delta = (a0 <= 0) ? -1 : 1 + int32_t delta = (a0 <= 0) ? -1 : 1; + // result = a1 + delta * hint + int32_t result = a1 + delta * h[i]; + // Assembly uses CMGT(signed)+BIC+UMIN: if result > 43 (signed), zero it, + // then unsigned min with 43.
For negative values (-1), signed compare + // with 43 is false so BIC keeps it, then UMIN(0xFFFFFFFF, 43) = 43. + // This matches ML-DSA spec where -1 mod 44 = 43. + uint32_t uresult = (uint32_t)result; + if (result > 43) uresult = 0; + if (uresult > 43) uresult = 43; + b[i] = (int32_t)uresult; + } +} + +int test_mldsa_poly_use_hint_88(void) +{ + // Skip test on non-aarch64 architectures (ARM-only function) + if (get_arch_name() != ARCH_AARCH64) { + return 0; + } + +#ifdef __aarch64__ + uint64_t t, i; + int32_t a[256] __attribute__((aligned(32))); + int32_t h[256] __attribute__((aligned(32))); + int32_t b_asm[256] __attribute__((aligned(32))); + int32_t b_ref[256] __attribute__((aligned(32))); + + printf("Testing mldsa_poly_use_hint_88 with %d cases\n", tests); + + for (t = 0; t < tests; ++t) { + // Generate random coefficients in [0, Q) + for (i = 0; i < 256; ++i) { + a[i] = (int32_t)(random64() % 8380417); + h[i] = (int32_t)(random64() % 2); // hint is 0 or 1 + } + + // Compute reference result + reference_mldsa_poly_use_hint_88(b_ref, a, h); + + // Call the assembly implementation + mldsa_poly_use_hint_88(b_asm, a, h); + + // Compare results + for (i = 0; i < 256; ++i) { + if (b_asm[i] != b_ref[i]) { + printf("Error in mldsa_poly_use_hint_88 element i = %"PRIu64"; " + "asm = %"PRId32" ref = %"PRId32" " + "(a[i] = %"PRId32", h[i] = %"PRId32")\n", + i, b_asm[i], b_ref[i], a[i], h[i]); + return 1; + } + } + + if (VERBOSE) { + printf("OK: mldsa_poly_use_hint_88: a[0]=0x%08"PRIx32", h[0]=%"PRId32" => b[0]=%"PRId32"\n", + (uint32_t)a[0], h[0], b_asm[0]); + } + } + + printf("All OK\n"); + return 0; +#else + return 0; +#endif +} + + int test_p256_montjadd(void) { uint64_t t, k; printf("Testing p256_montjadd with %d cases\n",tests); @@ -16784,6 +16875,7 @@ int main(int argc, char *argv[]) functionaltest(all,"mldsa_pointwise_acc_l5",test_mldsa_pointwise_acc_l5); functionaltest(all,"mldsa_pointwise_acc_l7",test_mldsa_pointwise_acc_l7);
functionaltest(all,"mldsa_reduce",test_mldsa_reduce); + functionaltest(all,"mldsa_poly_use_hint_88",test_mldsa_poly_use_hint_88); functionaltest(all,"mlkem_basemul_k2",test_mlkem_basemul_k2); functionaltest(all,"mlkem_basemul_k3",test_mlkem_basemul_k3); functionaltest(all,"mlkem_basemul_k4",test_mlkem_basemul_k4); diff --git a/tools/collect-signatures.py b/tools/collect-signatures.py index 8e28befb3..5f7e4ec91 100644 --- a/tools/collect-signatures.py +++ b/tools/collect-signatures.py @@ -305,6 +305,7 @@ def stripPrefixes(s, prefixes): "mldsa_pointwise_acc_l4", "mldsa_pointwise_acc_l5", "mldsa_pointwise_acc_l7", + "mldsa_poly_use_hint_88", "mlkem_ntt", "mlkem_intt", "mlkem_mulcache_compute",