ML-DSA AArch64 HOL-Light proof for mldsa_rej_uniform_eta4 #402

Open
jakemas wants to merge 11 commits into awslabs:main from jakemas:mldsa-rej-uniform-eta4

Contributor

@jakemas jakemas commented May 6, 2026

Summary

Add formal verification proof for mldsa_rej_uniform_eta4, the AArch64 NEON rejection-sampling routine used in ML-DSA-65 (parameter set with eta=4). The proof establishes full functional correctness: the assembly filters 4-bit nibbles < 9 from a SHAKE256 buffer, maps each accepted value n to (4 - n) as int32, and writes up to 256 results to the output buffer matching num_of_wordlist(SUB_LIST(0,256) (REJ_SAMPLE_ETA4 inlist)).
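As a rough illustration of the behaviour described above, here is a minimal scalar sketch in C. It is not the verified assembly and not the HOL Light specification. The name `rej_eta4_model` is hypothetical, and scanning the low nibble of each byte before the high nibble is an assumption about the ordering.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical scalar model, not the verified assembly. Scans buf one byte
 * at a time, low nibble first (an assumption about nibble order), keeps
 * nibbles below 9 and stores 4 - n. Returns the number of coefficients
 * written, which never exceeds 256. */
static unsigned rej_eta4_model(int32_t r[256], const uint8_t *buf, size_t buflen)
{
    unsigned ctr = 0;
    for (size_t pos = 0; pos < buflen && ctr < 256; pos++) {
        uint8_t lo = buf[pos] & 0x0F;
        uint8_t hi = buf[pos] >> 4;
        if (lo < 9)
            r[ctr++] = 4 - (int32_t)lo;   /* accepted: map n to 4 - n */
        if (hi < 9 && ctr < 256)
            r[ctr++] = 4 - (int32_t)hi;
    }
    return ctr;
}
```

For the byte 0x08, for example, this model accepts both nibbles and emits -4 followed by 4, while 0xFF contributes nothing. The return value therefore depends on the data, not only on buflen.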

Proof architecture

  • WOP characterization of loop iteration count N (smallest i where the buffer is exhausted or 256 samples already collected).
  • ENSURES_WHILE_UP_TAC over the main 240-byte sampling loop: 75-step preamble + 16 unrolled body iterations + back-edge + post-loop exit.
  • Writeback phase (245 ARM steps): SUB+SSHLL cascade unpacking nibbles to int32 with sign extension of (4 − n), stored via 64-byte chunks to res. Split into Case A (niblen ≥ 256, uses full stack) and Case B (niblen < 256).
  • Case B small (niblen < 256) uses a virtual stack list L of length 256 materialized via BYTES_EXISTS_WORDLIST: the TBL writeback leaves garbage past the committed niblist, so the 1024-byte output identity doesn't hold against a zero-padded STACK_CONTENT. Instead we define L so that bytes(sp, 512) s = num_of_wordlist L (pure existence over bytes), prove niblist = SUB_LIST(0, niblen) L, run the Case A closure chain on L, and MOD-truncate the 1024-byte identity to 4*niblen.
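The Case B argument relies on the writeback being element-wise, so the first niblen outputs depend only on the first niblen stack slots. A small C sketch of that idea follows. It assumes one accepted nibble per 16-bit stack slot (consistent with 512 bytes holding a list of length 256), and the names `writeback_model` and `prefix_equal` are hypothetical.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical model of the writeback: each stack slot is assumed to hold
 * one nibble as a 16-bit value, and the output is 4 - n widened to int32. */
static void writeback_model(int32_t out[256], const uint16_t stack[256])
{
    for (size_t i = 0; i < 256; i++)
        out[i] = 4 - (int32_t)stack[i];
}

/* Returns 1 when the first niblen outputs of a and b agree, 0 otherwise. */
static int prefix_equal(const int32_t a[256], const int32_t b[256], size_t niblen)
{
    for (size_t i = 0; i < niblen; i++)
        if (a[i] != b[i])
            return 0;
    return 1;
}
```

Two stacks that agree on their first niblen slots but differ afterwards (zero padding versus leftover garbage) give the same first niblen outputs under this model, which is the property the truncation to 4*niblen bytes captures.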

Function signature

extern uint64_t mldsa_rej_uniform_eta4(
    int32_t r[static 256], const uint8_t *buf, unsigned buflen,
    const uint8_t table[static 4096]);

The table argument is a 256-entry (16 bytes each) TBL permutation lookup indexed by 8-bit acceptance masks. The routine uses 576 bytes of stack as a scratch area for nibble packing.
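To make the table shape concrete, here is a sketch of how one 16-byte entry of such a compaction table could be built. The layout is an assumption (eight 16-bit lanes, accepted lanes moved to the front, unused slots set to 0xFF); the actual table contents are defined in the PR and may differ. `build_tbl_entry` is a hypothetical name.

```c
#include <stdint.h>

/* Hypothetical generator for one table entry. For each set bit in mask, the
 * two byte indices of that 16-bit lane are appended, so a TBL shuffle with
 * this entry packs the accepted lanes to the front. Remaining slots are
 * filled with 0xFF, an out-of-range index for which TBL produces zero. */
static void build_tbl_entry(uint8_t entry[16], uint8_t mask)
{
    unsigned k = 0;
    for (unsigned lane = 0; lane < 8; lane++) {
        if (mask & (1u << lane)) {
            entry[k++] = (uint8_t)(2 * lane);
            entry[k++] = (uint8_t)(2 * lane + 1);
        }
    }
    while (k < 16)
        entry[k++] = 0xFF;
}
```

With 256 possible masks and 16 bytes per entry, a table built this way occupies 4096 bytes, matching the size in the signature.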

The function source and byte-identical assembly are maintained in mldsa-native at:
https://github.com/pq-code-package/mldsa-native/blob/main/mldsa/src/native/aarch64/src/rej_uniform_eta4_asm.S

jakemas added a commit to pq-code-package/mldsa-native that referenced this pull request May 6, 2026
Convert the proof from interactive set_goal/e(...) style into named
let MLDSA_REJ_UNIFORM_ETA4_CORRECT = prove (...) form, then wrap with
ARM_ADD_RETURN_STACK_TAC to produce _SUBROUTINE_CORRECT and
PROVE_SAFETY_SPEC_TAC to produce _SUBROUTINE_SAFE.

This is the standard s2n-bignum structure used by mldsa_pointwise,
poly_use_hint_32_aarch64_asm, etc. Matches the upstream draft PR
awslabs/s2n-bignum#402.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jake Massimo <jakemas@amazon.com>
jakemas added a commit to pq-code-package/mldsa-native that referenced this pull request May 15, 2026
@jakemas jakemas force-pushed the mldsa-rej-uniform-eta4 branch from bb1d497 to ed7b774 Compare May 19, 2026 00:26
jakemas and others added 6 commits May 19, 2026 00:47
Add formal verification proof for mldsa_rej_uniform_eta4, the AArch64 NEON
rejection-sampling routine used in ML-DSA-65 (parameter set with eta=4).
Establishes full functional correctness: the assembly filters 4-bit
nibbles < 9 from a SHAKE256 buffer, maps each accepted value n to
(4 - n) as int32, and writes up to 256 results to the output buffer
matching num_of_wordlist(SUB_LIST(0,256) (REJ_SAMPLE_ETA4 inlist)).

- WOP characterization of loop iteration count N (smallest i where the
  buffer is exhausted or 256 samples already collected).
- ENSURES_WHILE_UP_TAC over the main 240-byte sampling loop: 75-step
  preamble + 16 unrolled body iterations + back-edge + post-loop exit.
- Writeback phase (245 ARM steps): SUB+SSHLL cascade unpacking nibbles
  to int32 with sign extension of (4 - n), stored via 64-byte chunks to
  res. Split into Case A (niblen >= 256, uses full stack) and Case B
  (niblen < 256).
- Case B small (niblen < 256) uses a "virtual stack list" L of length
  256 materialized via BYTES_EXISTS_WORDLIST: the TBL writeback leaves
  garbage past the committed niblist, so the 1024-byte output identity
  doesn't hold against a zero-padded STACK_CONTENT. Instead we define L
  so that bytes(sp, 512) s = num_of_wordlist L (L exists by pure
  existence over bytes), prove niblist = SUB_LIST(0, niblen) L, run the
  Case A closure chain on L, and MOD-truncate the 1024-byte identity
  to 4*niblen.

decode.ml gains a dispatch for MOVI (op=0, cmode=1000, 16-bit element,
no shift). The existing `arm_adv_simd_expand_imm` already handled this
cmode; only the top-level bitmatch was missing a pattern. This is needed
for `movi v30.8h, #0x9` (0x4f0085fe etc.) in the preamble.

    extern uint64_t mldsa_rej_uniform_eta4(
        int32_t r[static 256], const uint8_t *buf, unsigned buflen,
        const uint8_t table[static 4096]);

The table is a 256-entry (16 bytes each) TBL permutation lookup indexed
by 8-bit acceptance masks. The routine uses 576 bytes of stack as a
scratch area for nibble packing.

The function source and byte-for-byte identical assembly are maintained
in mldsa-native at
https://github.com/pq-code-package/mldsa-native/blob/main/mldsa/src/native/aarch64/src/rej_uniform_eta4_asm.S

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Ubuntu <ubuntu@ip-172-31-30-87.us-west-2.compute.internal>
The original comment block had:
  // Input buf (buflen bytes, buflen a multiple of 8, >= 8), table (4096 bytes, lookup table)
  // Output r[256] (signed 32-bit words); returns number of coefficients written (<= 256)

tools/collect-signatures.py splits on `;` and expects exactly 3 parts
(inputs/outputs/temporaries) when `;` appears more than once. The stray `;`
in "Output r[256] ...; returns ..." caused the third part "returns number of
coefficients written (<= 256)" to be parsed as a temporary buffer. Because
it has no `[...]` bracket, parseArr() returns None, which then fails to
unpack in the outer `for argname, bufferlen in meminout.temporaries` loop
with TypeError.
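The failure mode can be modelled in a few lines of C. This is only a sketch of the splitting rule described in this commit message, not a port of tools/collect-signatures.py; `count_parts` and `last_part_has_brackets` are hypothetical helpers.

```c
#include <string.h>

/* Number of ';'-separated parts in a header comment. */
static int count_parts(const char *comment)
{
    int parts = 1;
    for (const char *p = comment; *p != '\0'; p++)
        if (*p == ';')
            parts++;
    return parts;
}

/* Returns 1 if the text after the last ';' contains a '[' followed by a
 * ']', i.e. something that looks like an array declaration; 0 otherwise. */
static int last_part_has_brackets(const char *comment)
{
    const char *last = strrchr(comment, ';');
    const char *part = last ? last + 1 : comment;
    const char *open = strchr(part, '[');
    return open != NULL && strchr(open, ']') != NULL;
}
```

A comment ending in "; Output r[256] (signed 32-bit words); returns number of coefficients written (<= 256)" yields three parts under this model, and the last one has no brackets. The replacement format has a single `;`, so there is no third part to misread as a temporary.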

Align the comment with mlkem_rej_uniform_VARIABLE_TIME's format (a
structurally identical function):
  // Inputs *buf (unsigned bytes), buflen, table (unsigned bytes); output r[256] (signed 32-bit words), return

Regenerated subroutine_signatures.ml to match.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Ubuntu <ubuntu@claude.local>
Wrap the core _CORRECT with ARM_ADD_RETURN_STACK_TAC (pre=1 SUB SP,
post=1 for RET; ADD SP is absorbed by the return adjustment), producing
the subroutine form with returnaddress tracking and the proper
(word_sub stackpointer (word 576), 576) stack region.

Add _SUBROUTINE_SAFE via mk_safety_spec + PROVE_SAFETY_SPEC_TAC for the
constant-time/memory-safety proof.

Also fix the include/s2n-bignum.h comment so collect-signatures.py can
parse it (match the mlkem_rej_uniform_VARIABLE_TIME header shape: "Inputs
*buf (...), buflen, table[4096] (...); output r[256] (...), return").

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Ubuntu <ubuntu@ip-172-31-30-87.us-west-2.compute.internal>
Drop the DBG no-op tracer definition and all 140 DBG "..." THEN call
sites. They were development progress markers; with DBG bound to
ALL_TAC they are harmless but noisy.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Ubuntu <ubuntu@claude.local>
The `movi v26.8b, #0xf` NEON instruction operates on the lower 64 bits of
V26 (i.e., D26), not the full 128-bit Q26. Fix the annotation to match
what print_literal_from_elf actually emits.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Ubuntu <ubuntu@claude.local>
The prior _SUBROUTINE_SAFE used mk_safety_spec + PROVE_SAFETY_SPEC_TAC,
which assumes constant-time execution. mldsa_rej_uniform_eta4 is a
variable-time rejection sampler (which nibbles pass the < 9 filter
depends on data), so PROVE_SAFETY_SPEC_TAC fails (No read events).

Replace with the mlkem_rej_uniform_VARIABLE_TIME pattern: a custom
MEMSAFE spec with ENSURES_SEQUENCE_TAC + ENSURES_WHILE_UP_TAC and an
event-tracking loop invariant, then DISCHARGE_MEMSAFE_TAC at boundaries.

Helpers added:
- DISCHARGE_MEMSAFE_TAC, DISCHARGE_MEMSAFE_ASM_TAC for the two flavors
  of memsafe postcondition (top-level vs loop-body)
- MEMSAFE_ARITH_TAC: SIMPLE_ARITH_TAC variant that allows val/word
- CONTAINED_ASM_TAC: ASM-aware containment for symbolic addresses
- STRIP_EXISTS_ASSUM_TAC: unpack ?e_acc. read events s = APPEND e_acc e

Bound lemmas (closing the val(idx0/idx1) < 256 step):
- WORD_ZX_INT32_INT64, VAL_WORD_SUBWORD_0_32: width-conversion identities
- SUM_8_BIT_BOUND_POLY: val(word_add of 8 :N word values bounded by
  1, 2, 4, ..., 128) <= 255 when 256 <= 2^dimindex(:N)
- SBND_K_POLY: val(word_and (word k:N word) X) <= k for k <= 128 and
  8 <= dimindex(:N)
These are polymorphic in N because the eta4 ARM stepper produces
popcount summands at int128 width; MATCH_MP_TAC with non-polymorphic
versions of these lemmas failed to unify against the actual goal form.
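The arithmetic behind the two bound lemmas can be checked numerically. The C sketch below is an informal sanity check, not a substitute for the HOL Light proofs. It reads the summand bounds as the powers of two from 1 to 128, and the function names are hypothetical.

```c
#include <stdint.h>

/* Largest possible sum of eight summands x_i with x_i <= 2^i, i = 0..7. */
static unsigned max_weighted_sum(void)
{
    unsigned sum = 0;
    for (unsigned i = 0; i < 8; i++)
        sum += 1u << i;
    return sum; /* 1 + 2 + 4 + ... + 128 */
}

/* Checks (k & x) <= k for every 16-bit x; returns 1 if it always holds. */
static int and_bounded_by(uint16_t k)
{
    for (uint32_t x = 0; x <= 0xFFFF; x++)
        if ((k & x) > k)
            return 0;
    return 1;
}
```

Since the maximum sum fits in 8 bits, the addition cannot wrap at any word width of at least 8 bits, which is why the side condition 256 <= 2^dimindex(:N) suffices.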

Proof structure (mlkem_rej_uniform_VARIABLE_TIME-inspired):
- Phase 0: setup
- Phase 1: ENSURES_SEQUENCE at pc+256 splitting main computation from
  writeback branch (245 ARM steps)
- Phase 2: WOP characterization for N
- ENSURES_WHILE_UP_TAC subgoals 1-5 (0<N, pre-loop init, loop body
  3a/3b SIMD compute + ST1 stores, backedge, post-loop with 3-way case
  split)

0 CHEATs. Total proof execution time: ~30 minutes.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
@jakemas jakemas force-pushed the mldsa-rej-uniform-eta4 branch from ed7b774 to ad3ba62 Compare May 19, 2026 00:48
jakemas added 4 commits May 18, 2026 17:48
Add test_mldsa_rej_uniform_eta4 plus a pure-C reference
(reference_mldsa_rej_uniform_eta4) and a 4096-byte
mldsa_rej_uniform_eta_table mirroring arm/proofs/mldsa_rej_uniform_eta_table.ml.
Test runs ARM-only (gated via functionaltest(arm,...) and an explicit
arch check at function entry), with buflens that satisfy the asm
precondition 8 | buflen and 8 <= buflen.

Also regenerate include/s2n-bignum-c89.h so the new prototype is
available in the C89 variant header.
The previous arm-runtime check still allowed the link-time symbol
reference to mldsa_rej_uniform_eta4, which is undefined on x86 since
the function is ARM-only. Gate the test body with #ifdef __x86_64__
(returning 0 on x86) so the symbol is never referenced when targeting x86.
Using functionaltest(arm, ...) puts the test in the 'inapplicable'
bucket on x86, which the summary equation
  successes + skipped == tested
does not count, producing 'Testing all passed but is incomplete' and
a non-zero exit. Switch to functionaltest(all, ...) so the function
runs everywhere; the body is already gated with #ifdef __x86_64__ to
return 0 without referencing the ARM-only symbol.

Mirrors the pattern used for test_mldsa_poly_use_hint_32 in PR awslabs#372.
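The gating pattern described in these commits can be sketched as follows. The names are hypothetical, and a local stub stands in for the ARM-only assembly symbol so that the sketch compiles on its own.

```c
#include <stdint.h>

#ifndef __x86_64__
/* Stand-in for the ARM-only routine (hypothetical). In the real test this
 * would be the assembly symbol, which does not exist in x86 builds. */
static uint64_t arm_only_routine_stub(void) { return 0; }
#endif

/* Returns 0 on success. On x86-64 the function returns immediately and the
 * ARM-only symbol is never named, so the linker has nothing to resolve. */
static int test_arm_only_sketch(void)
{
#ifdef __x86_64__
    return 0;
#else
    return arm_only_routine_stub() == 0 ? 0 : 1;
#endif
}
```

Because the function exists and succeeds on every architecture, a harness that registers it for all targets counts it as tested and passed on x86 too.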
@jakemas jakemas marked this pull request as ready for review May 20, 2026 03:02