diff --git a/arm/Makefile b/arm/Makefile index 01c8ec7ad..93942721d 100644 --- a/arm/Makefile +++ b/arm/Makefile @@ -258,6 +258,7 @@ BIGNUM_OBJ = curve25519/bignum_add_p25519.o \ mldsa/mldsa_pointwise_acc_l7.o \ mldsa/mldsa_ntt.o \ mldsa/mldsa_pointwise.o \ + mldsa/mldsa_poly_use_hint_32.o \ mlkem/mlkem_basemul_k2.o \ mlkem/mlkem_basemul_k3.o \ mlkem/mlkem_basemul_k4.o \ diff --git a/arm/mldsa/mldsa_poly_use_hint_32.S b/arm/mldsa/mldsa_poly_use_hint_32.S new file mode 100644 index 000000000..b28e6f711 --- /dev/null +++ b/arm/mldsa/mldsa_poly_use_hint_32.S @@ -0,0 +1,157 @@ +// Copyright (c) The mldsa-native project authors +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +// ---------------------------------------------------------------------------- +// Use hint to correct high bits of decomposition (parameter sets 65/87) +// Inputs a[256] (unsigned 32-bit, in [0,Q)), h[256] (hint bits, 0 or 1) +// Output b[256] (unsigned 32-bit, in [0,16)) +// +// Implements mld_use_hint for ML-DSA parameter sets 65/87: +// GAMMA2 = (Q-1)/32 = 261888 +// 2*GAMMA2 = 523776 +// Output range: [0, 15] +// +// Algorithm per coefficient: +// 1. Decompose: a1 = round_down(a / 523776), a0 = a - a1*523776 +// If a > 31*GAMMA2 = 8118528, wrap: a1=0, a0=a-Q +// 2. delta = (a0 <= 0) ? -1 : 1 +// 3. b = (a1 + delta * h) & 15 +// +// extern void mldsa_poly_use_hint_32 +// (int32_t b[static 256], const int32_t a[static 256], +// const int32_t h[static 256]); +// +// Standard ARM ABI: X0 = b, X1 = a, X2 = h +// ---------------------------------------------------------------------------- +#include "_internal_s2n_bignum_arm.h" + + S2N_BN_SYM_VISIBILITY_DIRECTIVE(mldsa_poly_use_hint_32) + S2N_BN_FUNCTION_TYPE_DIRECTIVE(mldsa_poly_use_hint_32) + S2N_BN_SYM_PRIVACY_DIRECTIVE(mldsa_poly_use_hint_32) + .text + .balign 4 + +S2N_BN_SYMBOL(mldsa_poly_use_hint_32): + CFI_START + +// This matches the code in the mldsa-native repository +// https://github.com/pq-code-package/mldsa-native/blob/main/mldsa/src/native/aarch64/src/poly_use_hint_32_asm.S + +// Load constants into SIMD registers + +// v20 = Q = 8380417 (unused in computation but part of original code) + mov w4, #0xe001 + movk w4, #0x7f, lsl #16 + dup v20.4s, w4 + +// v21 = 31*GAMMA2 = 8118528 = 0x7be100 (wraparound threshold) + mov w5, #0xe100 + movk w5, #0x7b, lsl #16 + dup v21.4s, w5 + +// v22 = 2*GAMMA2 = 523776 = 0x7fe00 (decompose multiplier) + mov w7, #0xfe00 + movk w7, #0x7, lsl #16 + dup v22.4s, w7 + +// v23 = Barrett constant = 0x40100401 = 1074791425 +// Used for SQDMULH-based Barrett reduction: a1 ~= (2*a*c) >> 49 + mov w11, #0x0401 + movk w11, #0x4010, lsl #16 + dup v23.4s, w11 + +// v24 = mask 15 = 0x0000000f (for final AND to compute mod 16) + movi v24.4s, #0xf + +// Loop counter: 16 iterations, processing 16 coefficients per iteration +// 16 * 16 = 256 total coefficients + mov x3, #0x10 + +Lmldsa_poly_use_hint_32_loop: + // Load 16 coefficients from a (4 vectors of 4 int32s) + ldr q1, [x1, #0x10] + ldr q2, [x1, #0x20] + ldr q3, [x1, #0x30] + ldr q0, [x1], #0x40 + + // Load 16 hint bits from h (4 vectors of 4 int32s) + ldr q5, [x2, #0x10] + ldr q6, [x2, #0x20] + ldr q7, [x2, #0x30] + ldr q4, [x2], #0x40 + + // --- Process v1 (coefficients at offset +16) --- + // Decompose: a1 = sqdmulh(a, barrett_const) >> 18 + sqdmulh v17.4s, v1.4s, v23.4s + srshr v17.4s, v17.4s, #0x12 + // Check wraparound: mask = (a > 31*GAMMA2) ? all_ones : 0 + cmgt v25.4s, v1.4s, v21.4s + // a0 = a - a1 * 2*GAMMA2 + mls v1.4s, v17.4s, v22.4s + // If wraparound: a1 = 0 (clear a1 where mask is set) + bic v17.16b, v17.16b, v25.16b + // If wraparound: a0 -= 1 (add all_ones = -1) + add v1.4s, v1.4s, v25.4s + // delta = (a0 <= 0) ? all_ones : 0 + cmle v1.4s, v1.4s, #0 + // delta = (a0 <= 0) ? -1 : 1 (set bit 0) + orr v1.4s, #0x1 + // b = a1 + delta * hint + mla v17.4s, v1.4s, v5.4s + // b = b & 15 + and v17.16b, v17.16b, v24.16b + + // --- Process v2 (coefficients at offset +32) --- + sqdmulh v18.4s, v2.4s, v23.4s + srshr v18.4s, v18.4s, #0x12 + cmgt v25.4s, v2.4s, v21.4s + mls v2.4s, v18.4s, v22.4s + bic v18.16b, v18.16b, v25.16b + add v2.4s, v2.4s, v25.4s + cmle v2.4s, v2.4s, #0 + orr v2.4s, #0x1 + mla v18.4s, v2.4s, v6.4s + and v18.16b, v18.16b, v24.16b + + // --- Process v3 (coefficients at offset +48) --- + sqdmulh v19.4s, v3.4s, v23.4s + srshr v19.4s, v19.4s, #0x12 + cmgt v25.4s, v3.4s, v21.4s + mls v3.4s, v19.4s, v22.4s + bic v19.16b, v19.16b, v25.16b + add v3.4s, v3.4s, v25.4s + cmle v3.4s, v3.4s, #0 + orr v3.4s, #0x1 + mla v19.4s, v3.4s, v7.4s + and v19.16b, v19.16b, v24.16b + + // --- Process v0 (coefficients at offset +0, loaded last for post-increment) --- + sqdmulh v16.4s, v0.4s, v23.4s + srshr v16.4s, v16.4s, #0x12 + cmgt v25.4s, v0.4s, v21.4s + mls v0.4s, v16.4s, v22.4s + bic v16.16b, v16.16b, v25.16b + add v0.4s, v0.4s, v25.4s + cmle v0.4s, v0.4s, #0 + orr v0.4s, #0x1 + mla v16.4s, v0.4s, v4.4s + and v16.16b, v16.16b, v24.16b + + // Store 16 output coefficients + str q17, [x0, #0x10] + str q18, [x0, #0x20] + str q19, [x0, #0x30] + str q16, [x0], #0x40 + + // Decrement loop counter and branch + subs x3, x3, #0x1 + b.ne Lmldsa_poly_use_hint_32_loop + + CFI_RET + +S2N_BN_SIZE_DIRECTIVE(mldsa_poly_use_hint_32) + +#if defined(__linux__) && defined(__ELF__) +.section .note.GNU-stack, "", %progbits +#endif diff --git a/arm/proofs/mldsa_poly_use_hint_32.ml b/arm/proofs/mldsa_poly_use_hint_32.ml new file mode 100644 index 000000000..6d133ea6b --- /dev/null +++ b/arm/proofs/mldsa_poly_use_hint_32.ml @@ -0,0 +1,921 @@ +(* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Use hint to correct high bits of decomposition (ML-DSA, param 65/87). *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "common/mlkem_mldsa.ml";; + + +(**** print_literal_from_elf "arm/mldsa/mldsa_poly_use_hint_32.o";; + ****) + +let mldsa_poly_use_hint_32_mc = define_assert_from_elf + "mldsa_poly_use_hint_32_mc" "arm/mldsa/mldsa_poly_use_hint_32.o" +(*** BYTECODE START ***) +[ + 0x529c0024; (* arm_MOV W4 (rvalue (word 57345)) *) + 0x72a00fe4; (* arm_MOVK W4 (word 127) 16 *) + 0x4e040c94; (* arm_DUP_GEN Q20 X4 32 128 *) + 0x529c2005; (* arm_MOV W5 (rvalue (word 57600)) *) + 0x72a00f65; (* arm_MOVK W5 (word 123) 16 *) + 0x4e040cb5; (* arm_DUP_GEN Q21 X5 32 128 *) + 0x529fc007; (* arm_MOV W7 (rvalue (word 65024)) *) + 0x72a000e7; (* arm_MOVK W7 (word 7) 16 *) + 0x4e040cf6; (* arm_DUP_GEN Q22 X7 32 128 *) + 0x5280802b; (* arm_MOV W11 (rvalue (word 1025)) *) + 0x72a8020b; (* arm_MOVK W11 (word 16400) 16 *) + 0x4e040d77; (* arm_DUP_GEN Q23 X11 32 128 *) + 0x4f0005f8; (* arm_MOVI Q24 (word 64424509455) *) + 0xd2800203; (* arm_MOV X3 (rvalue (word 16)) *) + 0x3dc00421; (* arm_LDR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3dc00822; (* arm_LDR Q2 X1 (Immediate_Offset (word 32)) *) + 0x3dc00c23; (* arm_LDR Q3 X1 (Immediate_Offset (word 48)) *) + 0x3cc40420; (* arm_LDR Q0 X1 (Postimmediate_Offset (word 64)) *) + 0x3dc00445; (* arm_LDR Q5 X2 (Immediate_Offset (word 16)) *) + 0x3dc00846; (* arm_LDR Q6 X2 (Immediate_Offset (word 32)) *) + 0x3dc00c47; (* arm_LDR Q7 X2 (Immediate_Offset (word 48)) *) + 0x3cc40444; (* arm_LDR Q4 X2 (Postimmediate_Offset (word 64)) *) + 0x4eb7b431; (* arm_SQDMULH_VEC Q17 Q1 Q23 32 128 *) + 0x4f2e2631; (* arm_SRSHR_VEC Q17 Q17 18 32 128 *) + 0x4eb53439; (* arm_CMGT_VEC Q25 Q1 Q21 32 128 *) + 0x6eb69621; (* arm_MLS_VEC Q1 Q17 Q22 32 128 *) + 0x4e791e31; (* arm_BIC_VEC Q17 Q17 Q25 128 *) + 0x4eb98421; (* arm_ADD_VEC Q1 Q1 Q25 32 128 *) + 0x6ea09821; (* arm_CMLE_VEC_ZERO Q1 Q1 32 128 *) + 0x4f001421; (* arm_ORR_VEC Q1 Q1 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea59431; (* arm_MLA_VEC Q17 Q1 Q5 32 128 *) + 0x4e381e31; (* arm_AND_VEC Q17 Q17 Q24 128 *) + 0x4eb7b452; (* arm_SQDMULH_VEC Q18 Q2 Q23 32 128 *) + 0x4f2e2652; (* arm_SRSHR_VEC Q18 Q18 18 32 128 *) + 0x4eb53459; (* arm_CMGT_VEC Q25 Q2 Q21 32 128 *) + 0x6eb69642; (* arm_MLS_VEC Q2 Q18 Q22 32 128 *) + 0x4e791e52; (* arm_BIC_VEC Q18 Q18 Q25 128 *) + 0x4eb98442; (* arm_ADD_VEC Q2 Q2 Q25 32 128 *) + 0x6ea09842; (* arm_CMLE_VEC_ZERO Q2 Q2 32 128 *) + 0x4f001422; (* arm_ORR_VEC Q2 Q2 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea69452; (* arm_MLA_VEC Q18 Q2 Q6 32 128 *) + 0x4e381e52; (* arm_AND_VEC Q18 Q18 Q24 128 *) + 0x4eb7b473; (* arm_SQDMULH_VEC Q19 Q3 Q23 32 128 *) + 0x4f2e2673; (* arm_SRSHR_VEC Q19 Q19 18 32 128 *) + 0x4eb53479; (* arm_CMGT_VEC Q25 Q3 Q21 32 128 *) + 0x6eb69663; (* arm_MLS_VEC Q3 Q19 Q22 32 128 *) + 0x4e791e73; (* arm_BIC_VEC Q19 Q19 Q25 128 *) + 0x4eb98463; (* arm_ADD_VEC Q3 Q3 Q25 32 128 *) + 0x6ea09863; (* arm_CMLE_VEC_ZERO Q3 Q3 32 128 *) + 0x4f001423; (* arm_ORR_VEC Q3 Q3 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea79473; (* arm_MLA_VEC Q19 Q3 Q7 32 128 *) + 0x4e381e73; (* arm_AND_VEC Q19 Q19 Q24 128 *) + 0x4eb7b410; (* arm_SQDMULH_VEC Q16 Q0 Q23 32 128 *) + 0x4f2e2610; (* arm_SRSHR_VEC Q16 Q16 18 32 128 *) + 0x4eb53419; (* arm_CMGT_VEC Q25 Q0 Q21 32 128 *) + 0x6eb69600; (* arm_MLS_VEC Q0 Q16 Q22 32 128 *) + 0x4e791e10; (* arm_BIC_VEC Q16 Q16 Q25 128 *) + 0x4eb98400; (* arm_ADD_VEC Q0 Q0 Q25 32 128 *) + 0x6ea09800; (* arm_CMLE_VEC_ZERO Q0 Q0 32 128 *) + 0x4f001420; (* arm_ORR_VEC Q0 Q0 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea49410; (* arm_MLA_VEC Q16 Q0 Q4 32 128 *) + 0x4e381e10; (* arm_AND_VEC Q16 Q16 Q24 128 *) + 0x3d800411; (* arm_STR Q17 X0 (Immediate_Offset (word 16)) *) + 0x3d800812; (* arm_STR Q18 X0 (Immediate_Offset (word 32)) *) + 0x3d800c13; (* arm_STR Q19 X0 (Immediate_Offset (word 48)) *) + 0x3c840410; (* arm_STR Q16 X0 (Postimmediate_Offset (word 64)) *) + 0xf1000463; (* arm_SUBS X3 X3 (rvalue (word 1)) *) + 0x54fff961; (* arm_BNE (word 2096940) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let MLDSA_USE_HINT_32_EXEC = ARM_MK_EXEC_RULE mldsa_poly_use_hint_32_mc;; + +(* Per-element word function matching the assembly computation *) +let mldsa_use_hint_32_asm = new_definition + `mldsa_use_hint_32_asm (a:int32) (h:int32) : int32 = + let a1 = word_ishr_round (word_2smulh a (word 1074791425)) 18 in + let m:int32 = word_neg(word(bitval(word_igt a (word 8118528)))) in + let a0 = word_add (word_sub a (word_mul a1 (word 523776))) m in + let a1' = word_and a1 (word_not m) in + let delta:int32 = word_or (word_neg(word(bitval(word_ile a0 (word 0))))) (word 1) in + word_and (word_add a1' (word_mul delta h)) (word 15)`;; + +(* Numeric description of the assembly's UseHint path, exposing the Barrett + approximation used by the code. Connected to the FIPS 204 definition + mldsa_use_hint_32 via MLDSA_USE_HINT_32_EQUIV below. *) +let mldsa_use_hint_32_code = new_definition + `mldsa_use_hint_32_code (a:num) (h:num) = + let a1 = ((((a + 127) DIV 128) * 1025 + 2097152) DIV 4194304) MOD 16 in + let a0:int = &a - &a1 * &523776 in + let a0' = if a0 > &4190208 then a0 - &8380417 else a0 in + if h = 0 then a1 + else if a0' > &0 then (a1 + 1) MOD 16 + else (a1 + 15) MOD 16`;; + +(* ========================================================================= *) +(* Functional correctness helper lemmas *) +(* ========================================================================= *) + +let WORD_2SMULH_NOSATURATE_32 = prove( + `!a:int32. val a < 8380417 + ==> word_2smulh a (word 1074791425:int32) : int32 = + iword((&2 * &(val a) * &1074791425) div &2 pow 32)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[word_2smulh; DIMINDEX_32] THEN + ASM_SIMP_TAC[MLDSA_IVAL_VAL] THEN + CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[iword_saturate; word_INT_MIN; word_INT_MAX; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN + SIMP_TAC[INT_OF_NUM_DIV] THEN + REWRITE_TAC[INT_POS_NEG_BOUND] THEN + SUBGOAL_THEN `~(&((2 * val(a:int32) * 1074791425) DIV 4294967296):int > &2147483647)` + (fun th -> REWRITE_TAC[th]) THEN + SUBGOAL_THEN `(2 * val(a:int32) * 1074791425) DIV 4294967296 <= 2147483647` + (fun th -> MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LE] th) THEN INT_ARITH_TAC) THEN + TRANS_TAC LE_TRANS `(2 * 8380416 * 1074791425) DIV 4294967296` THEN + CONJ_TAC THENL + [MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV]);; + +let VAL_DECOMPOSE_A1 = prove( + `!a:int32. val a < 8380417 + ==> val(word_ishr_round (word_2smulh a (word 1074791425:int32)) 18 : int32) + = ((2 * val a * 1074791425) DIV 4294967296 + 131072) DIV 262144`, + GEN_TAC THEN DISCH_TAC THEN + ASM_SIMP_TAC[WORD_2SMULH_NOSATURATE_32] THEN + SUBGOAL_THEN `(2 * val(a:int32) * 1074791425) DIV 4294967296 < 2147483648` + ASSUME_TAC THENL + [TRANS_TAC LT_TRANS `(2 * 8380416 * 1074791425) DIV 4294967296 + 1` THEN + CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_SIMP_TAC[VAL_IWORD_NUM_32] THEN + ABBREV_TAC `t:int32 = iword(&((2 * val(a:int32) * 1074791425) DIV 4294967296))` THEN + SUBGOAL_THEN `val(t:int32) = (2 * val(a:int32) * 1074791425) DIV 4294967296` + ASSUME_TAC THENL + [EXPAND_TAC "t" THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `val(t:int32) < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[word_ishr_round] THEN CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV THEN + SUBGOAL_THEN `ival(t:int32) = &(val t)` ASSUME_TAC THENL + [SIMP_TAC[ival; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `(val(t:int32) + 131072) DIV 262144 < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + TRANS_TAC LT_TRANS `(4194303 + 131072) DIV 262144 + 1` THEN CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + ASM_SIMP_TAC[VAL_IWORD_NUM_32] THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN + UNDISCH_THEN `val(t:int32) = (2 * val(a:int32) * 1074791425) DIV 4294967296` + (SUBST1_TAC o SYM) THEN ASM_REWRITE_TAC[]);; + +let WORD_IGT_THRESHOLD_32 = BITBLAST_RULE + `!a:int32. val a < 8380417 + ==> word_igt a (word 8118528:int32) <=> val a > 8118528`;; + +let A1_BOUND = prove( + `!a. a < 8380417 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 <= 16`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8380416 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `69205952` (SPEC `d * 1025 + 2097152` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 1025 <= 65472 * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A1_WRAP = prove( + `!a. 8118528 < a /\ a < 8380417 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 = 16`, + GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `16 <= ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304` + ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `a + 127` (SPEC `8118529 + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `d * 1025 + 2097152` (SPEC `67108977` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `63427 * 1025 <= d * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + MP_TAC(SPEC `a:num` A1_BOUND) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_ARITH_TAC);; + +let A1_BOUND_NOWRAP = prove( + `!a. a <= 8118528 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 <= 15`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8118528 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `67108802` (SPEC `d * 1025 + 2097152` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 1025 <= 63426 * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A0_UPPER_32 = prove( + `!a. a <= 8118528 + ==> a < (((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 + 1) * 523776`, + GEN_TAC THEN DISCH_TAC THEN + ABBREV_TAC `nv = ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304` THEN + SUBGOAL_THEN `nv * 4194304 <= (a + 127) DIV 128 * 1025 + 2097152` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN + MP_TAC(SPECL [`(a + 127) DIV 128 * 1025 + 2097152`; `4194304`] (CONJUNCT1 DIVISION_SIMP)) THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(a + 127) DIV 128 <= 63426` ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `8118528 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 4194304 <= 63426 * 1025 + 2097152` ASSUME_TAC THENL + [SUBGOAL_THEN `(a + 127) DIV 128 * 1025 <= 63426 * 1025` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ASM_ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 15` ASSUME_TAC THENL + [CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_ARITH_TAC);; + +let WORD_SUB_SIGN_32 = BITBLAST_RULE + `!a:int32 b:int32. val b <= 7856640 /\ val a <= 8118528 ==> + ((bit 31 (word_sub a b) \/ word_sub a b = word 0) <=> val a <= val b)`;; + +let WRAP_A0_NEGATIVE = BITBLAST_RULE + `!a:int32. val a < 8380417 /\ val a > 8118528 + ==> bit 31 (word_add (word_sub a (word 8380416:int32)) (word 4294967295:int32))`;; + +(* Barrett equivalence: assembly and C decomposition formulas agree. + Both compute round_half_down(a / 523776) via different Barrett + approximation paths. Proved by case analysis on 17 output intervals + using DIV_MONO to sandwich both LHS and RHS to the same constant. *) +let BARRETT_INTERVAL_32 = prove( + `!a lo hi k. + lo <= a /\ a <= hi /\ + k * 262144 <= (2 * lo * 1074791425) DIV 4294967296 + 131072 /\ + (2 * hi * 1074791425) DIV 4294967296 + 131072 < (k + 1) * 262144 /\ + k * 4194304 <= (lo + 127) DIV 128 * 1025 + 2097152 /\ + (hi + 127) DIV 128 * 1025 + 2097152 < (k + 1) * 4194304 + ==> ((2 * a * 1074791425) DIV 4294967296 + 131072) DIV 262144 = k /\ + ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + CONJ_TAC THEN MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THENL + [CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(2 * lo * 1074791425) DIV 4294967296 + 131072` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 131072 <= y + 131072 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(2 * hi * 1074791425) DIV 4294967296 + 131072` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 131072 <= y + 131072 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]; + CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(lo + 127) DIV 128 * 1025 + 2097152` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 2097152 <= y + 2097152 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(hi + 127) DIV 128 * 1025 + 2097152` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 2097152 <= y + 2097152 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]]);; + +let BARRETT_EQUIV = prove( + `!a. a < 8380417 ==> + ((2 * a * 1074791425) DIV 4294967296 + 131072) DIV 262144 = + ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304`, + GEN_TAC THEN DISCH_TAC THEN + let intervals = [ + (0, 261888); (261889, 785664); (785665, 1309440); (1309441, 1833216); + (1833217, 2356992); (2356993, 2880768); (2880769, 3404544); + (3404545, 3928320); (3928321, 4452096); (4452097, 4975872); + (4975873, 5499648); (5499649, 6023424); (6023425, 6547200); + (6547201, 7070976); (7070977, 7594752); (7594753, 8118528); + (8118529, 8380416)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("a",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`a:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_32 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + + + +(* ========================================================================= *) +(* Element-level functional correctness *) +(* ========================================================================= *) + +let ELEMENT_CORRECT = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> val(mldsa_use_hint_32_asm a h) = mldsa_use_hint_32_code (val a) (val h)`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[mldsa_use_hint_32_asm; mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ABBREV_TAC `nv = ((val(a:int32) + 127) DIV 128 * 1025 + 2097152) DIV 4194304` THEN + SUBGOAL_THEN `val(word_ishr_round (word_2smulh (a:int32) (word 1074791425)) 18 : int32) = nv` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN TRANS_TAC EQ_TRANS `((2 * val(a:int32) * 1074791425) DIV 4294967296 + 131072) DIV 262144` THEN CONJ_TAC THENL [MATCH_MP_TAC VAL_DECOMPOSE_A1 THEN ASM_REWRITE_TAC[]; MATCH_MP_TAC BARRETT_EQUIV THEN ASM_REWRITE_TAC[]]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 16` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND) THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_igt (a:int32) (word 8118528:int32) <=> val a > 8118528` SUBST1_TAC THENL [MP_TAC(SPEC `a:int32` WORD_IGT_THRESHOLD_32) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(a:int32) > 8118528` THEN ASM_REWRITE_TAC[bitval] THENL + [ + SUBGOAL_THEN `nv = 16` SUBST_ALL_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_WRAP) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_ishr_round (word_2smulh (a:int32) (word 1074791425)) 18 = (word 16:int32)` (fun th -> REWRITE_TAC[th]) THENL [ONCE_REWRITE_TAC[GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[WORD_ILE_ZERO_32] THEN + MP_TAC(SPEC `a:int32` WRAP_A0_NEGATIVE) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + SUBGOAL_THEN `~((if int_gt (&(val(a:int32))) (&4190208) then &(val a) - &8380417 else &(val a):int) > &0)` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN COND_CASES_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT] (ASSUME `val(a:int32) < 8380417`)) THEN INT_ARITH_TAC; + POP_ASSUM(MP_TAC o REWRITE_RULE[INT_GT; INT_NOT_LT]) THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_GT; INT_GT] (ASSUME `val(a:int32) > 8118528`)) THEN INT_ARITH_TAC]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN CONV_TAC WORD_REDUCE_CONV; + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN CONV_TAC WORD_REDUCE_CONV] + ; + SUBGOAL_THEN `nv <= 15` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND_NOWRAP) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `nv MOD 16 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `word_ishr_round (word_2smulh (a:int32) (word 1074791425)) 18 = (word nv:int32)` (fun th -> REWRITE_TAC[th]) THENL [GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_REFL] THEN REWRITE_TAC[WORD_ILE_ZERO_32; WORD_ADD_0] THEN + SUBGOAL_THEN `nv * 523776 <= 7856640` ASSUME_TAC THENL [SUBGOAL_THEN `nv * 523776 <= 15 * 523776` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 523776:int32)) = nv * 523776` ASSUME_TAC THENL [REWRITE_TAC[VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 523776:int32)) <= 7856640` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(bit 31 (word_sub (a:int32) (word_mul (word nv:int32) (word 523776:int32))) \/ word_sub a (word_mul (word nv) (word 523776)) = word 0) <=> ~(&(val a) - &nv * &523776 > &0)` SUBST1_TAC THENL + [MP_TAC(ISPECL [`a:int32`; `word_mul (word nv:int32) (word 523776:int32)`] WORD_SUB_SIGN_32) THEN ASM_REWRITE_TAC[] THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 523776` THENL + [ASM_REWRITE_TAC[] THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] (REWRITE_RULE[GSYM INT_OF_NUM_LE] (ASSUME `val(a:int32) <= nv * 523776`))) THEN INT_ARITH_TAC; + ASM_REWRITE_TAC[] THEN SUBGOAL_THEN `nv * 523776 < val(a:int32)` ASSUME_TAC THENL [UNDISCH_TAC `~(val(a:int32) <= nv * 523776)` THEN ARITH_TAC; ALL_TAC] THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] (REWRITE_RULE[GSYM INT_OF_NUM_LT] (ASSUME `nv * 523776 < val(a:int32)`))) THEN INT_ARITH_TAC]; ALL_TAC] THEN + REWRITE_TAC[bitval] THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN REWRITE_TAC[WORD_MUL_0; WORD_ADD_0; WORD_AND_ONES_32] THEN REWRITE_TAC[VAL_WORD_AND_15_32] THEN SUBGOAL_THEN `val(word nv:int32) = nv` SUBST1_TAC THENL [MATCH_MP_TAC VAL_WORD_EQ THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_1_32; WORD_AND_ONES_32] THEN + SUBGOAL_THEN `val(word nv:int32) = nv` ASSUME_TAC THENL [MATCH_MP_TAC VAL_WORD_EQ THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word nv:int32) <= 15` ASSUME_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(int_gt (&(val(a:int32)) - &nv * &523776) (&4190208))` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN MP_TAC(SPEC `val(a:int32)` A0_UPPER_32) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD] (ASSUME `val(a:int32) < (nv + 1) * 523776`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 523776` THENL + [MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `523776`] REAL_INT_GT_BRIDGE) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]; + MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `523776`] REAL_INT_GT_BRIDGE_POS) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]] THEN + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_AND_15_32; VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `16 = 2 EXP 4`; ARITH_RULE `4294967296 = 2 EXP 32`; MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[ARITH_RULE `4294967295 = 15 + 268435455 * 16`; ARITH_RULE `n + (15 + 268435455 * 16) = (n + 15) + 268435455 * 16`; MOD_MULT_ADD] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` (fun th -> REWRITE_TAC[th]) THEN TRY(MATCH_MP_TAC MOD_LT) THEN ASM_ARITH_TAC]);; + + +let ELEMENT_CORRECT_WORD = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> mldsa_use_hint_32_asm a h = + word(mldsa_use_hint_32_code (val a) (val h))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN + AP_TERM_TAC THEN MP_TAC(SPECL [`a:int32`; `h:int32`] ELEMENT_CORRECT) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]));; + + +(* ========================================================================= *) +(* FIPS 204 = code-aligned equivalence *) +(* ========================================================================= *) +(* *) +(* Bridges mldsa_use_hint_32 (FIPS 204 Algorithm 40, used in the public *) +(* postcondition) to mldsa_use_hint_32_code (the Barrett-style numeric form *) +(* the assembly actually computes). The main correctness proof states its *) +(* postcondition in FIPS 204 terms and rewrites with this equivalence in the *) +(* strengthening branch to expose the code-aligned form for symbolic *) +(* execution. *) +(* ========================================================================= *) + +let LINEARIZE_DIV_MOD_TAC = + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +(* Prove r DIV 523776 = k via DIV_SANDWICH + LE_MULT_RCANCEL *) +let DIV_523776_TAC k = + let k_num = mk_small_numeral k and km1 = mk_small_numeral (k-1) + and kp1 = mk_small_numeral (k+1) + and q = mk_var("q",`:num`) and le = `(<=):num->num->bool` + and lt = `(<):num->num->bool` + and c = `523776` in + let mk_mul a b = mk_binop (rator(rator `0*0`)) a b in + MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; c] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, q) THEN + REPEAT GEN_TAC THEN STRIP_TAC THEN + ASM_CASES_TAC(mk_comb(mk_comb(le, q), km1)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul q c), + mk_mul km1 c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul k_num c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_CASES_TAC(mk_comb(mk_comb(lt, k_num), q)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul kp1 c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC]];; + +(* Replace (r - r MOD 523776) DIV 523776 with r DIV 523776 *) +let DIV_MOD_TO_DIV_TAC = + SUBGOAL_THEN `(r - r MOD 523776) DIV 523776 = r DIV 523776` SUBST1_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 = 523776 * r DIV 523776` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV; ALL_TAC];; + +(* Lower half nowrap: dismiss wrap cond, reduce, prove r DIV 523776 = k *) +let DECOMPOSE_R1_LOWER_TAC = + SUBGOAL_THEN `~((&r:int) - &(r MOD 523776) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN LINEARIZE_DIV_MOD_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + DIV_MOD_TO_DIV_TAC THEN + CONV_TAC SYM_CONV THEN + LINEARIZE_DIV_MOD_TAC;; + +(* Upper half nowrap: dismiss wrap cond, reduce, prove r DIV 523776 + 1 = k *) +let DECOMPOSE_R1_UPPER_TAC = + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~((&r:int) - (&(r MOD 523776) - &523776) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [REWRITE_TAC[INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`] THEN + ASM_SIMP_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + LINEARIZE_DIV_MOD_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 + 523776 = 523776 * (r DIV 523776 + 1)` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776 + 1`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN SUBST1_TAC THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +let DECOMPOSE_R1_NOWRAP_TAC = + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THEN + TRY DECOMPOSE_R1_LOWER_TAC THEN TRY DECOMPOSE_R1_UPPER_TAC;; + +let DECOMPOSE_32_R1_EQUIV = time prove( + `!r. r < 8380417 ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = + decompose_32_r1 r`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `r <= 8118528` THENL + [ALL_TAC; + (* Wrap zone *) + SUBGOAL_THEN `8118528 < r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `decompose_32_r1 r = 0` SUBST1_TAC THENL + [REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THENL + [(* Lower wrap: r DIV 523776 = 16 *) + SUBGOAL_THEN `r DIV 523776 = 16` ASSUME_TAC THENL + [DIV_523776_TAC 16; ALL_TAC] THEN + SUBGOAL_THEN `16 * 523776 + r MOD 523776 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + ASM_ARITH_TAC; + (* Upper wrap: r DIV 523776 = 15 *) + SUBGOAL_THEN `r DIV 523776 = 15` ASSUME_TAC THENL + [DIV_523776_TAC 15; ALL_TAC] THEN + SUBGOAL_THEN `15 * 523776 + r MOD 523776 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + ALL_TAC] THEN + MP_TAC(SPEC `r:num` A1_WRAP) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONV_TAC NUM_REDUCE_CONV] THEN + (* Nowrap zone: unfold and do interval cascade *) + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + let intervals = [ + (0, 261888); (261889, 785664); (785665, 1309440); + (1309441, 1833216); (1833217, 2356992); (2356993, 2880768); + (2880769, 3404544); (3404545, 3928320); (3928321, 4452096); + (4452097, 4975872); (4975873, 5499648); (5499649, 6023424); + (6023425, 6547200); (6547201, 7070976); (7070977, 7594752); + (7594753, 8118528)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("r",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`r:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_32 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + DECOMPOSE_R1_NOWRAP_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + +let R1_IS_DIV_LOWER = prove( + `!r. r < 8380417 /\ r MOD 523776 * 2 <= 523776 /\ + ~((&r:int) - &(r MOD 523776) = &8380416) ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = r DIV 523776`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + MP_TAC(SPEC `r:num` LOWER_NONWRAP_R1) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +let R1_IS_DIV_PLUS1_UPPER = prove( + `!r. r < 8380417 /\ ~(r MOD 523776 * 2 <= 523776) /\ + ~((&r:int) - (&(r MOD 523776) - &523776) = &8380416) ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = + r DIV 523776 + 1`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + MP_TAC(SPEC `r:num` UPPER_NONWRAP_R1) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +(* Upper nowrap: substitute Barrett = r DIV 523776 + 1, use INT_MOD_RESIDUE *) +let R0_SIGN_UPPER_NOWRAP_TAC = + MP_TAC(SPEC `r:num` R1_IS_DIV_PLUS1_UPPER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_MUL] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `(a:int) - (b + &1) * c = a - b * c - c`] THEN + REWRITE_TAC[INT_ARITH `x - &523776 > &0 <=> x > &523776`; + INT_ARITH `x - &523776 - &8380417 > &0 <=> x > &8904193`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Lower nowrap: substitute Barrett = r DIV 523776, use INT_MOD_RESIDUE *) +let R0_SIGN_LOWER_NOWRAP_TAC = + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Wrap: derive 8118528 < r, use DECOMPOSE_32_R1_EQUIV to get Barrett = 0 *) +let R0_SIGN_WRAP_TAC = + SUBGOAL_THEN `8118528 < r` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> t = `&8380416:int`)) (concl th) && + not(is_neg(concl th)))) THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ; + INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`; + GSYM INT_OF_NUM_ADD] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + REWRITE_TAC[INT_ARITH `x - &1 > &0 <=> x > &1`; + INT_ARITH `(x - &523776) - &1 > &0 <=> x > &523777`; + INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let DECOMPOSE_32_R0_SIGN = time prove( + `!r. r < 8380417 ==> + let a1 = (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 in + let a0':int = if (&r:int) - &a1 * &523776 > &4190208 + then &r - &a1 * &523776 - &8380417 + else &r - &a1 * &523776 in + (decompose_32_r0 r > &0 <=> a0' > &0) /\ + (decompose_32_r0 r <= &0 <=> ~(a0' > &0))`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[INT_ARITH `(x:int) <= &0 <=> ~(x > &0)`] THEN + MATCH_MP_TAC(TAUT `(p <=> q) ==> (p <=> q) /\ (~p <=> ~q)`) THEN + REWRITE_TAC[decompose_32_r0; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ONCE_REWRITE_TAC[COND_RAND] THEN REWRITE_TAC[SND; FST] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + TRY R0_SIGN_LOWER_NOWRAP_TAC THEN + TRY R0_SIGN_UPPER_NOWRAP_TAC THEN + TRY R0_SIGN_WRAP_TAC THEN + TRY( + (* Contradiction: lower nowrap with > 4190208 *) + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - &((((r + 127) DIV 128 * 1025 + 2097152) DIV + 4194304) MOD 16) * &523776 = &(r MOD 523776)` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(&(r MOD 523776) > (&4190208:int))` MP_TAC THENL + [REWRITE_TAC[INT_NOT_LT; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[INT_OF_NUM_GT] THEN ASM_ARITH_TAC + ));; + +let MLDSA_USE_HINT_32_EQUIV = prove( + `!r h. r < 8380417 /\ h <= 1 + ==> mldsa_use_hint_32 h r = mldsa_use_hint_32_code r h`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[MLDSA_USE_HINT_32_UNFOLD] THEN + REWRITE_TAC[mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R0_SIGN) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN STRIP_TAC THEN + ASM_CASES_TAC `h = 0` THENL + [ASM_REWRITE_TAC[ARITH_RULE `~(0 = 1)`]; ALL_TAC] THEN + SUBGOAL_THEN `h = 1` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `decompose_32_r0 r > &0` THEN ASM_REWRITE_TAC[] THEN + ASM_MESON_TAC[]);; + +(* ========================================================================= *) +(* Strengthen-post utility for the FIPS-aligned correctness proof *) +(* ========================================================================= *) + +let ENSURES_STRENGTHEN_POST = prove( + `!P (Q:armstate->bool) Q' R. + ensures arm P Q' R /\ (!s. Q' s ==> Q s) ==> ensures arm P Q R`, + REPEAT GEN_TAC THEN DISCH_THEN(CONJUNCTS_THEN2 MP_TAC ASSUME_TAC) THEN + REWRITE_TAC[ensures] THEN MATCH_MP_TAC MONO_FORALL THEN + X_GEN_TAC `s0:armstate` THEN MATCH_MP_TAC MONO_IMP THEN REWRITE_TAC[] THEN + MP_TAC(BETA_RULE(ISPECL [`arm`; + `\s':armstate. (Q':armstate->bool) s' /\ (R:armstate->armstate->bool) (s0:armstate) s'`; + `\s':armstate. (Q:armstate->bool) s' /\ (R:armstate->armstate->bool) (s0:armstate) s'`] + EVENTUALLY_MONO)) THEN + ANTS_TAC THENL [ASM_MESON_TAC[]; MESON_TAC[]]);; + + +(* ========================================================================= *) +(* Correctness (FIPS 204-aligned) *) +(* ========================================================================= *) + +(* Postcondition is stated in terms of mldsa_use_hint_32 from FIPS 204 + (Algorithm 40), with the output bound < 16 as a corollary. The bounds + on val(x i) / val(y i) appear as antecedents inside the postcondition + (decompose-style): the assembly executes regardless of input ranges, + and only the FIPS-equivalence + output bound require the input bounds. *) +let MLDSA_USE_HINT_32_CORRECT = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH mldsa_poly_use_hint_32_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_use_hint_32_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH mldsa_poly_use_hint_32_mc - 4) /\ + ((!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1) + ==> (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32 + (word_add b (word(4 * i)))) s) < 16))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + + MAP_EVERY X_GEN_TAC + [`b:int64`; `a:int64`; `h:int64`; + `x:num->int32`; `y:num->int32`; `pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; + NONOVERLAPPING_CLAUSES; ALL; + fst MLDSA_USE_HINT_32_EXEC] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + GLOBALIZE_PRECONDITION_TAC THEN + CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV EXPAND_CASES_CONV))) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC[SOME_FLAGS; MODIFIABLE_SIMD_REGS] THEN + + (* Initialize and merge memory (input bounds NOT used yet). *) + ENSURES_INIT_TAC "s0" THEN + MEMORY_128_FROM_32_TAC "a" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + MEMORY_128_FROM_32_TAC "h" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + + (* Simulate 878 instructions (the assembly is bound-independent). *) + MAP_EVERY (fun n -> ARM_STEPS_TAC MLDSA_USE_HINT_32_EXEC [n] THEN + SIMD_SIMPLIFY_TAC[]) + (1--878) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + (* Pick up the postcondition's input-bound antecedents + (val(x i) < 8380417 /\ val(y i) <= 1) as assumptions. *) + DISCH_THEN(CONJUNCTS_THEN ASSUME_TAC) THEN + + (* Split bytes128 -> bytes32 for output memory. *) + REPEAT(FIRST_X_ASSUM(STRIP_ASSUME_TAC o + CONV_RULE (SIMD_SIMPLIFY_CONV []) o + CONV_RULE(READ_MEMORY_SPLIT_CONV 2) o + check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN + + (* Expand output cases, substitute. *) + CONV_TAC(TOP_DEPTH_CONV EXPAND_CASES_CONV) THEN + CONV_TAC(DEPTH_CONV NUM_MULT_CONV THENC DEPTH_CONV NUM_ADD_CONV) THEN + REWRITE_TAC[WORD_ADD_0] THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN ASM_REWRITE_TAC[] THEN + + (* Push word_subword through SIMD ops to per-element form. *) + REWRITE_TAC[WORD_SUBWORD_AND; WORD_SUBWORD_OR] THEN + let WSN_TAC = REWRITE_TAC(map (fun n -> prove( + subst [mk_small_numeral n, `n:num`] + `!x:int128. word_subword(word_not x) (n,32):int32 = word_not(word_subword x (n,32))`, + GEN_TAC THEN MATCH_MP_TAC WORD_SUBWORD_NOT THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_128] THEN ARITH_TAC)) [0;32;64;96]) in + WSN_TAC THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + + (* Build the per-element FIPS-eq lemma EC_FINAL by composing + ELEMENT_CORRECT_WORD with the asm definition unfold. *) + let EC_DEEP = + CONV_RULE(DEPTH_CONV WORD_NUM_RED_CONV) + (CONV_RULE(DEPTH_CONV(INT_RED_CONV ORELSEC NUM_RED_CONV)) + (CONV_RULE(TOP_DEPTH_CONV let_CONV) + (REWRITE_RULE[mldsa_use_hint_32_asm; word_2smulh; word_ishr_round; + DIMINDEX_32] ELEMENT_CORRECT_WORD))) in + let EC_FINAL = ONCE_REWRITE_RULE[WORD_AND_SYM] + (ONCE_REWRITE_RULE[WORD_OR_SYM] EC_DEEP) in + + (* Pre-rewrite mldsa_use_hint_32 -> _code via the equivalence at all + occurrences in the goal. IMP_REWRITE_TAC handles the conditional lemma + and leaves index-bound side conditions (i < 256) which we close + uniformly via ARITH below. *) + REPEAT (IMP_REWRITE_TAC[MLDSA_USE_HINT_32_EQUIV]) THEN + + (* Split into per-element leaf goals (FIPS-eq + val<16) plus index-bound + side conditions left over from IMP_REWRITE_TAC. Each FIPS-eq leaf is + closed by EC_FINAL; each val<16 leaf is closed by reducing val(word ..) + to MOD via VAL_WORD then bounding mldsa_use_hint_32_code < 16. *) + REPEAT CONJ_TAC THEN + (FIRST [ + MATCH_MP_TAC EC_FINAL THEN CONJ_TAC THEN + FIRST_X_ASSUM MATCH_MP_TAC THEN ARITH_TAC; + REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC(ARITH_RULE `x < 16 ==> x MOD 4294967296 < 16`) THEN + REWRITE_TAC[mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN + REWRITE_TAC[MOD_LT_EQ; ARITH_EQ]; + ARITH_TAC]));; + + +(* ========================================================================= *) +(* Public subroutine correctness (FIPS 204-aligned) *) +(* ========================================================================= *) + +(* Subroutine form: derives directly from MLDSA_USE_HINT_32_CORRECT by adding + the X30 -> RET return wiring via ARM_ADD_RETURN_NOSTACK_TAC. The bound + antecedents inside the postcondition pass through unchanged (decompose + pattern). *) +let MLDSA_USE_HINT_32_SUBROUTINE_CORRECT = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH mldsa_poly_use_hint_32_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) mldsa_poly_use_hint_32_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + ((!i. i < 256 ==> val(x i:int32) < 8380417) /\ + (!i. i < 256 ==> val(y i:int32) <= 1) + ==> (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32 + (word_add b (word(4 * i)))) s) < 16))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REWRITE_TAC[fst MLDSA_USE_HINT_32_EXEC] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ARM_ADD_RETURN_NOSTACK_TAC MLDSA_USE_HINT_32_EXEC + (CONV_RULE(ONCE_DEPTH_CONV NUM_REDUCE_CONV) + (REWRITE_RULE[fst MLDSA_USE_HINT_32_EXEC] + MLDSA_USE_HINT_32_CORRECT)));; + + +(* ========================================================================= *) +(* Constant-time and memory safety proof. *) +(* ========================================================================= *) + +needs "arm/proofs/consttime.ml";; +needs "arm/proofs/subroutine_signatures.ml";; + + +let full_spec,public_vars = mk_safety_spec + ~keep_maychanges:false + (assoc "mldsa_poly_use_hint_32" subroutine_signatures) + MLDSA_USE_HINT_32_SUBROUTINE_CORRECT + MLDSA_USE_HINT_32_EXEC;; + +let MLDSA_USE_HINT_32_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e b a h pc returnaddress. + nonoverlapping (word pc,LENGTH mldsa_poly_use_hint_32_mc) (b,1024) /\ + nonoverlapping (b,1024) (a,1024) /\ + nonoverlapping (b,1024) (h,1024) + ==> ensures arm + (\s. + aligned_bytes_loaded s (word pc) + mldsa_poly_use_hint_32_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + read events s = e) + (\s. + read PC s = returnaddress /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h b pc returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; b,1024] + [b,1024])) + (\s s'. true)`, + ASSERT_CONCL_TAC full_spec THEN + PROVE_SAFETY_SPEC_TAC ~public_vars:public_vars MLDSA_USE_HINT_32_EXEC);; + diff --git a/arm/proofs/specifications.txt b/arm/proofs/specifications.txt index 0115199f2..ace018c46 100644 --- a/arm/proofs/specifications.txt +++ b/arm/proofs/specifications.txt @@ -335,6 +335,8 @@ MLDSA_POINTWISE_ACC_L7_SUBROUTINE_CORRECT MLDSA_POINTWISE_ACC_L7_SUBROUTINE_SAFE MLDSA_POINTWISE_SUBROUTINE_CORRECT MLDSA_POINTWISE_SUBROUTINE_SAFE +MLDSA_USE_HINT_32_SUBROUTINE_CORRECT +MLDSA_USE_HINT_32_SUBROUTINE_SAFE MLKEM_BASEMUL_K2_SUBROUTINE_CORRECT MLKEM_BASEMUL_K2_SUBROUTINE_SAFE MLKEM_BASEMUL_K3_SUBROUTINE_CORRECT diff --git a/arm/proofs/subroutine_signatures.ml b/arm/proofs/subroutine_signatures.ml index ac5e134a6..c15730f5b 100644 --- a/arm/proofs/subroutine_signatures.ml +++ b/arm/proofs/subroutine_signatures.ml @@ -4555,6 +4555,24 @@ let subroutine_signatures = [ ]) ); +("mldsa_poly_use_hint_32", + ([(*args*) + ("b", "int32_t[static 256]", (*is const?*)"false"); + ("a", "int32_t[static 256]", (*is const?*)"true"); + ("h", "int32_t[static 256]", (*is const?*)"true"); + ], + "void", + [(* input buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ("h", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* output buffers *) + ("b", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + ("mlkem_basemul_k2", ([(*args*) ("r", "int16_t[static 256]", (*is const?*)"false"); diff --git a/benchmarks/benchmark.c b/benchmarks/benchmark.c index d64e5c4ee..6517b3979 100644 --- a/benchmarks/benchmark.c +++ b/benchmarks/benchmark.c @@ -1112,6 +1112,7 @@ void call_mldsa_pointwise_acc_l4(void) repeat(mldsa_pointwise_acc_l4_x86((int32_ void call_mldsa_pointwise_acc_l5(void) repeat(mldsa_pointwise_acc_l5_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) void call_mldsa_pointwise_acc_l7(void) repeat(mldsa_pointwise_acc_l7_x86((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2,mldsa_avx2_qdata)) void call_mldsa_reduce(void) repeat(mldsa_reduce((int32_t*)b0)) +void call_mldsa_poly_use_hint_32(void) {} void call_mlkem_frombytes(void) repeat(mlkem_frombytes((uint16_t*)b0,(int8_t*)b1)) void call_mlkem_intt(void) repeat(mlkem_intt_x86((int16_t*)b0,(int16_t*)b1)) @@ -1155,6 +1156,7 @@ void call_mldsa_pointwise(void) repeat(mldsa_pointwise((int32_t*)b0,(int32_t*)b1 void call_mldsa_pointwise_acc_l4(void) repeat(mldsa_pointwise_acc_l4((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) void call_mldsa_pointwise_acc_l5(void) repeat(mldsa_pointwise_acc_l5((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) void call_mldsa_pointwise_acc_l7(void) repeat(mldsa_pointwise_acc_l7((int32_t*)b0,(const int32_t*)b1,(const int32_t*)b2)) +void call_mldsa_poly_use_hint_32(void) repeat(mldsa_poly_use_hint_32((int32_t*)b0,(int32_t*)b1,(int32_t*)b2)) void call_mldsa_reduce(void) {} void call_bignum_copy_row_from_table_8n__32_16(void) \ @@ -1629,6 +1631,7 @@ int main(int argc, char *argv[]) timingtest(all,"mldsa_pointwise_acc_l4",call_mldsa_pointwise_acc_l4); timingtest(all,"mldsa_pointwise_acc_l5",call_mldsa_pointwise_acc_l5); timingtest(all,"mldsa_pointwise_acc_l7",call_mldsa_pointwise_acc_l7); + timingtest(arm,"mldsa_poly_use_hint_32",call_mldsa_poly_use_hint_32); timingtest(!arm,"mldsa_reduce",call_mldsa_reduce); timingtest(bmi,"p256_montjadd",call_p256_montjadd); timingtest(all,"p256_montjadd_alt",call_p256_montjadd_alt); diff --git a/common/mlkem_mldsa.ml b/common/mlkem_mldsa.ml index 2858dc298..a495f3a69 100644 --- a/common/mlkem_mldsa.ml +++ b/common/mlkem_mldsa.ml @@ -871,6 +871,20 @@ let MEMORY_128_FROM_32_TAC = READ_MEMORY_MERGE_CONV 2 (subst[itm,n_tm] pat') in MP_TAC(end_itlist CONJ (map f (0--(n-1))));; +(* ------------------------------------------------------------------------- *) +(* ML-DSA use_hint Merge 4 x bytes32 into bytes128 at a given base+offset *) +(* ------------------------------------------------------------------------- *) + +let USE_HINT_MEMORY_128_FROM_32_TAC = + let a_tm = `a:int64` and n_tm = `n:num` and i64_ty = `:int64` + and pat = `read (memory :> bytes128(word_add a (word n))) s0` in + fun v boff n -> + let pat' = subst[mk_var(v,i64_ty),a_tm] pat in + let f i = + let itm = mk_small_numeral(boff + 16*i) in + READ_MEMORY_MERGE_CONV 2 (subst[itm,n_tm] pat') in + MP_TAC(end_itlist CONJ (map f (0--(n-1))));; + (* ------------------------------------------------------------------------- *) (* From |- (x == y) (mod m) /\ P to |- (x == y) (mod n) /\ P *) (* ------------------------------------------------------------------------- *) @@ -1949,3 +1963,176 @@ let SIMD_SIMPLIFY_ABBREV_TAC = let tms = sort free_in (find_terms pam (rand(concl th''))) in (MP_TAC th'' THEN MAP_EVERY AUTO_ABBREV_TAC tms THEN DISCH_TAC) (asl,w) in TRY(FIRST_X_ASSUM(ttac o check (simdable o concl)));; + +(* ========================================================================= *) +(* ML-DSA use_hint shared infrastructure lemmas *) +(* Used by both poly_use_hint_32 and poly_use_hint_88 proofs *) +(* ========================================================================= *) + +(* ival equals val for values in [0, Q) where Q = 8380417 < 2^31 *) +let MLDSA_IVAL_VAL = prove( + `!a:int32. val a < 8380417 ==> ival a = &(val a)`, + GEN_TAC THEN DISCH_TAC THEN + SIMP_TAC[ival; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC);; + +(* For natural numbers, &n is never < -2^31 *) +let INT_POS_NEG_BOUND = prove(`!n. ~((&n:int) < --(&2147483648))`, + GEN_TAC THEN REWRITE_TAC[INT_NOT_LT] THEN + MP_TAC(SPEC `n:num` INT_POS) THEN INT_ARITH_TAC);; + +(* val(iword(&n)) = n for n < 2^31 *) +let VAL_IWORD_NUM_32 = prove( + `!n. n < 2147483648 ==> val(iword(&n):int32) = n`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(ISPECL [`&n:int`] (INST_TYPE [`:32`,`:N`] INT_VAL_IWORD)) THEN + REWRITE_TAC[DIMINDEX_32; INT_POS] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; + REWRITE_TAC[INT_OF_NUM_EQ] THEN SIMP_TAC[]]);; + +(* word_ile x 0 in terms of bit 31 (signed non-positive check) *) +let WORD_ILE_ZERO_32 = BITBLAST_RULE + `!x:int32. word_ile x (word 0) <=> bit 31 x \/ x = word 0`;; + +(* val(word_and x (word 15)) = val x MOD 16 *) +let VAL_WORD_AND_15_32 = BITBLAST_RULE + `!x:int32. val(word_and x (word 15:int32)) = val x MOD 16`;; + +(* word_and x all-ones = x *) +let WORD_AND_ONES_32 = prove( + `!x:int32. word_and x (word 4294967295) = x`, + GEN_TAC THEN SUBGOAL_THEN `(word 4294967295 : int32) = word_not(word 0)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; REWRITE_TAC[WORD_AND_NOT0]]);; + +(* word_mul x 1 = x *) +let WORD_MUL_1_32 = prove( + `!x:int32. word_mul x (word 1) = x`, + GEN_TAC THEN REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN REWRITE_TAC[MULT_CLAUSES] THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `x:int32` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV);; + +(* Bridge lemmas: derive both real_gt and int_gt from a single NUM fact. + Needed for native mode where real_gt and int_gt are distinct types. *) +let REAL_INT_GT_BRIDGE = prove( + `!a:num b c. a <= b * c ==> + ~(real_gt (&a - &b * &c) (&0)) /\ ~(int_gt (&a - &b * &c) (&0))`, + REPEAT GEN_TAC THEN DISCH_TAC THEN CONJ_TAC THENL + [REWRITE_TAC[real_gt; REAL_NOT_LT] THEN + MP_TAC(REWRITE_RULE[GSYM REAL_OF_NUM_LE; GSYM REAL_OF_NUM_MUL] (ASSUME `a <= b * c`)) THEN REAL_ARITH_TAC; + REWRITE_TAC[INT_GT; INT_NOT_LT] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_MUL] (ASSUME `a <= b * c`)) THEN INT_ARITH_TAC]);; + +let REAL_INT_GT_BRIDGE_POS = prove( + `!a:num b c. ~(a <= b * c) ==> + real_gt (&a - &b * &c) (&0) /\ int_gt (&a - &b * &c) (&0)`, + REPEAT GEN_TAC THEN REWRITE_TAC[NOT_LE] THEN DISCH_TAC THEN CONJ_TAC THENL + [REWRITE_TAC[real_gt] THEN + MP_TAC(REWRITE_RULE[GSYM REAL_OF_NUM_LT; GSYM REAL_OF_NUM_MUL] (ASSUME `b * c < a`)) THEN REAL_ARITH_TAC; + REWRITE_TAC[INT_GT] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL] (ASSUME `b * c < a`)) THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Shared helper lemmas for UseHint proofs *) +(* ========================================================================= *) + +let DIV_SANDWICH = prove( + `!x d k. ~(d = 0) /\ k * d <= x /\ x < (k + 1) * d ==> x DIV d = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `k <= x DIV d` ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_RDIV_EQ] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `x DIV d < k + 1` ASSUME_TAC THENL + [ASM_SIMP_TAC[RDIV_LT_EQ] THEN ASM_ARITH_TAC; ASM_ARITH_TAC]);; + +let INT_MOD_RESIDUE = prove( + `!r m. ~(m = 0) ==> (&r:int) - &(r DIV m) * &m = &(r MOD m)`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPECL [`r:num`; `m:num`] (CONJUNCT1 DIVISION_SIMP)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD; + GSYM INT_OF_NUM_EQ] THEN + INT_ARITH_TAC);; + +(* ========================================================================= *) +(* FIPS 204 UseHint definitions (Algorithms 36 and 40) *) +(* ========================================================================= *) + +let mldsa_cmod = new_definition + `mldsa_cmod (r:num) (m:num) : int = + if (r MOD m) * 2 <= m then &(r MOD m) else &(r MOD m) - &m`;; + +let mldsa_decompose_32 = new_definition + `mldsa_decompose_32 (r:num) : num # int = + let r0 = mldsa_cmod r 523776 in + if &r - r0 = &8380416 then (0, r0 - &1) + else (num_of_int((&r - r0) div &523776), r0)`;; + +let decompose_32_r1 = new_definition + `decompose_32_r1 (r:num) : num = FST(mldsa_decompose_32 r)`;; + +let decompose_32_r0 = new_definition + `decompose_32_r0 (r:num) : int = SND(mldsa_decompose_32 r)`;; + +let mldsa_use_hint_32 = new_definition + `mldsa_use_hint_32 (h:num) (r:num) : num = + let (r1, r0) = mldsa_decompose_32 r in + if h = 1 /\ r0 > &0 then (r1 + 1) MOD 16 + else if h = 1 /\ r0 <= &0 then (r1 + 15) MOD 16 + else r1`;; + +let LOWER_NONWRAP_R1 = prove( + `!r. r MOD 523776 * 2 <= 523776 /\ + ~((&r:int) - &(r MOD 523776) = &8380416) ==> + decompose_32_r1 r = r DIV 523776`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; + NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 = 523776 * r DIV 523776` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV);; + +let UPPER_NONWRAP_R1 = prove( + `!r. ~(r MOD 523776 * 2 <= 523776) /\ + ~((&r:int) - (&(r MOD 523776) - &523776) = &8380416) ==> + decompose_32_r1 r = r DIV 523776 + 1`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` ASSUME_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 + 523776 = (r DIV 523776 + 1) * 523776` + ASSUME_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + MP_TAC(SPECL [`(r DIV 523776 + 1) * 523776`; `523776`] DIV_MULT) THEN + ARITH_TAC);; + +let MLDSA_USE_HINT_32_UNFOLD = prove( + `!h r. mldsa_use_hint_32 h r = + (if h = 1 /\ decompose_32_r0 r > &0 then (decompose_32_r1 r + 1) MOD 16 + else if h = 1 /\ decompose_32_r0 r <= &0 + then (decompose_32_r1 r + 15) MOD 16 + else decompose_32_r1 r)`, + REPEAT GEN_TAC THEN + REWRITE_TAC[mldsa_use_hint_32; decompose_32_r1; decompose_32_r0] THEN + SPEC_TAC(`mldsa_decompose_32 r`, `p:num#int`) THEN + REWRITE_TAC[FORALL_PAIR_THM] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN REWRITE_TAC[]);; diff --git a/include/s2n-bignum.h b/include/s2n-bignum.h index 2c9273750..c4f95bacb 100644 --- a/include/s2n-bignum.h +++ b/include/s2n-bignum.h @@ -1054,6 +1054,10 @@ extern void mldsa_pointwise_acc_l7_x86(int32_t c[S2N_BIGNUM_STATIC 256], const i // Input a[256] (signed 32-bit words); output a[256] (signed 32-bit words) extern void mldsa_reduce(int32_t a[S2N_BIGNUM_STATIC 256]); +// Use hint to correct high bits of decomposition for ML-DSA (parameter sets 65/87) +// Inputs a[256], h[256] (signed 32-bit words); output b[256] (signed 32-bit words) +extern void mldsa_poly_use_hint_32(int32_t b[S2N_BIGNUM_STATIC 256], const int32_t a[S2N_BIGNUM_STATIC 256], const int32_t h[S2N_BIGNUM_STATIC 256]); + // Scalar product of 2-element polynomial vectors in NTT domain, with mulcache // Inputs a[512], b[512], bt[256] (signed 16-bit words); output r[256] (signed 16-bit words) extern void mlkem_basemul_k2(int16_t r[S2N_BIGNUM_STATIC 256],const int16_t a[S2N_BIGNUM_STATIC 512],const int16_t b[S2N_BIGNUM_STATIC 512],const int16_t bt[S2N_BIGNUM_STATIC 256]); diff --git a/tests/test.c b/tests/test.c index a545ceeac..a254245e9 100644 --- a/tests/test.c +++ b/tests/test.c @@ -12288,6 +12288,89 @@ static void mlkem_poly_mulcache_to_avx2_layout(int16_t a[128]) } #endif +// Reference implementation of mldsa_poly_use_hint_32 for ML-DSA parameter sets 65/87 +// GAMMA2 = (Q-1)/32 = 261888, output range [0, 15] +// Matches the exact assembly algorithm using SQDMULH-based Barrett decomposition +void reference_mldsa_poly_use_hint_32(int32_t b[256], const int32_t a[256], const int32_t h[256]) +{ + const int32_t TWO_GAMMA2 = 523776; + const int32_t THRESHOLD = 8118528; // 31 * GAMMA2 + const int32_t BARRETT = 1074791425; // 0x40100401 + for (int i = 0; i < 256; i++) { + int32_t ai = a[i]; + // Decompose using SQDMULH + SRSHR (matching assembly) + // sqdmulh: (2 * ai * BARRETT) >> 32 + int32_t sqdmulh_result = (int32_t)(((int64_t)2 * ai * BARRETT) >> 32); + // srshr by 18: (x + (1 << 17)) >> 18 (signed rounding shift right) + int32_t a1 = (sqdmulh_result + (1 << 17)) >> 18; + // a0 = ai - a1 * 2*GAMMA2 + int32_t a0 = ai - a1 * TWO_GAMMA2; + // Wraparound: if ai > threshold, set a1=0, a0 += -1 (since mask = -1) + if (ai > THRESHOLD) { + a1 = 0; + a0 = a0 + (-1); // add the all-ones mask + } + // delta = (a0 <= 0) ? -1 : 1 + int32_t delta = (a0 <= 0) ? -1 : 1; + // b = (a1 + delta * hint) & 15 + b[i] = (a1 + delta * h[i]) & 15; + } +} + +int test_mldsa_poly_use_hint_32(void) +{ + // Skip test on non-aarch64 architectures (ARM-only function) + if (get_arch_name() != ARCH_AARCH64) { + return 0; + } + +#ifdef __aarch64__ + uint64_t t, i; + int32_t a[256] __attribute__((aligned(32))); + int32_t h[256] __attribute__((aligned(32))); + int32_t b_asm[256] __attribute__((aligned(32))); + int32_t b_ref[256] __attribute__((aligned(32))); + + printf("Testing mldsa_poly_use_hint_32 with %d cases\n", tests); + + for (t = 0; t < tests; ++t) { + // Generate random coefficients in [0, Q) + for (i = 0; i < 256; ++i) { + a[i] = (int32_t)(random64() % 8380417); + h[i] = (int32_t)(random64() % 2); // hint is 0 or 1 + } + + // Compute reference result + reference_mldsa_poly_use_hint_32(b_ref, a, h); + + // Call the assembly implementation + mldsa_poly_use_hint_32(b_asm, a, h); + + // Compare results + for (i = 0; i < 256; ++i) { + if (b_asm[i] != b_ref[i]) { + printf("Error in mldsa_poly_use_hint_32 element i = %"PRIu64"; " + "asm = %"PRId32" ref = %"PRId32" " + "(a[i] = %"PRId32", h[i] = %"PRId32")\n", + i, b_asm[i], b_ref[i], a[i], h[i]); + return 1; + } + } + + if (VERBOSE) { + printf("OK: mldsa_poly_use_hint_32: a[0]=0x%08"PRIx32", h[0]=%"PRId32" => b[0]=%"PRId32"\n", + a[0], h[0], b_asm[0]); + } + } + + printf("All OK\n"); + return 0; +#else + return 0; +#endif +} + + int test_mlkem_basemul_k2(void) { uint64_t t, i; @@ -16784,6 +16867,7 @@ int main(int argc, char *argv[]) functionaltest(all,"mldsa_pointwise_acc_l5",test_mldsa_pointwise_acc_l5); functionaltest(all,"mldsa_pointwise_acc_l7",test_mldsa_pointwise_acc_l7); functionaltest(all,"mldsa_reduce",test_mldsa_reduce); + functionaltest(all,"mldsa_poly_use_hint_32",test_mldsa_poly_use_hint_32); functionaltest(all,"mlkem_basemul_k2",test_mlkem_basemul_k2); functionaltest(all,"mlkem_basemul_k3",test_mlkem_basemul_k3); functionaltest(all,"mlkem_basemul_k4",test_mlkem_basemul_k4); diff --git a/tools/collect-signatures.py b/tools/collect-signatures.py index 8e28befb3..b622ba632 100644 --- a/tools/collect-signatures.py +++ b/tools/collect-signatures.py @@ -305,6 +305,7 @@ def stripPrefixes(s, prefixes): "mldsa_pointwise_acc_l4", "mldsa_pointwise_acc_l5", "mldsa_pointwise_acc_l7", + "mldsa_poly_use_hint_32", "mlkem_ntt", "mlkem_intt", "mlkem_mulcache_compute",