Skip to content

ML-DSA x86 AVX2 rej_uniform CORRECT proof#401

Open
jakemas wants to merge 24 commits into
awslabs:mainfrom
jakemas:jakemas/mldsa-rej-uniform-x86-proof
Open

ML-DSA x86 AVX2 rej_uniform CORRECT proof#401
jakemas wants to merge 24 commits into
awslabs:mainfrom
jakemas:jakemas/mldsa-rej-uniform-x86-proof

Conversation

@jakemas
Copy link
Copy Markdown
Contributor

@jakemas jakemas commented May 6, 2026

Issue #, if available:

Adds a CORRECT proof of the x86_64 AVX2 rej_uniform.S used in

This function implements ML-DSA's 23-bit rejection sampling. A main AVX2 loop processes 24 bytes (8 coefficients) per iteration via VPERMQ + VPSHUFB extraction, VPAND masking, VPSUBD + VMOVMSKPS rejection, and VPERMD + table compaction; a scalar tail handles any remaining bytes.

PR dependencies

This PR is stacked on two prerequisite PRs — the diff against main will include their changes until they land:

Once #378 and #387 merge, this branch will be rebased onto main.

Description of changes:

Adds proof for

  • MLDSA_REJ_UNIFORM_CORRECT

The specification matches #378's ARM counterpart: after the function the output buffer contains num_of_wordlist (SUB_LIST (0, 256) (REJ_SAMPLE inlist)) with RAX = its length, where REJ_SAMPLE is the FIPS-204 reference filter

REJ_SAMPLE l = FILTER (\x. val x < 8380417)
                      (MAP (\x. word (val x MOD 2^23)) l)

This is full functional correctness, not just output bounds. Subroutine variants (MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT, MLDSA_REJ_UNIFORM_SUBROUTINE_CORRECT, Windows variants) will be added in a follow-up once this CORRECT form is reviewed.

Files added:

  • x86/mldsa/mldsa_rej_uniform.S — Intel-syntax assembly (same bytecode as the mldsa-native AT&T source, rewritten in s2n-bignum's Intel-primary convention)
  • x86_att/mldsa/mldsa_rej_uniform.S — auto-generated from the Intel source via the existing attrofy.sed pipeline; cmp-verified to produce the same object file for both WINDOWS_ABI=0 and WINDOWS_ABI=1
  • x86/proofs/mldsa_rej_uniform.ml — single-file bundled proof (~4300 lines); all helper lemmas inlined in dependency order
  • x86/proofs/mldsa_rej_uniform_table.ml — lookup table definition

Files modified:

  • x86/Makefile, x86_att/Makefile — add the new .o target
  • x86_att/attrofy.sed — extended to handle AVX2 operand-size hints (YMMWORD/XMMWORD/DWORD/WORD/BYTE PTR) and the Intel movzx reg, WORD PTR mem / movzx reg, BYTE PTR mem forms (translated to AT&T movzwl/movzbl)

The top-level theorem proves with 0 CHEATs and only the 3 standard HOL Light axioms (INFINITY_AX, ETA_AX, SELECT_AX). Total load time on a modern x86_64 machine is ~261s.

Testing

Object-code round-trip for both ABIs:

$ make -C x86 mldsa/mldsa_rej_uniform.o
$ make -C x86_att mldsa/mldsa_rej_uniform.S
# => cmp-s succeeds for WINDOWS_ABI=0 and WINDOWS_ABI=1

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

jakemas and others added 8 commits April 8, 2026 19:40
Add semantic models for three x86 AVX instructions needed by
mldsa-native's rej_uniform formal verification:

- VMOVMSKPS: Extract sign bits (bit 31) from each 32-bit lane of a
  YMM/XMM register into a GPR. Used to build a comparison mask after
  VPSUBD rejection testing.

- VPMOVZXBD: Zero-extend bytes to dwords (8->32 bit). Used to expand
  table lookup indices from the VPERMD compaction table.

- VZEROUPPER: Clear upper 128 bits of all YMM registers. Modeled as
  a no-op (like ENDBR64) since the proof framework tracks YMM
  registers as full 256-bit values and the instruction only affects
  performance, not correctness.

Includes decoder entries, instruction type constructors, semantic
definitions, simulator test cases, and execution dispatch.

Signed-off-by: jakemas <jakemas@amazon.com>
VZEROUPPER zeros bits 128-511 of ZMM0-ZMM15 while preserving the
lower 128 bits (XMM values). Model this by writing each XMM register's
current value back through the zero-extending XMM component path,
which automatically zeros the upper bits of the containing ZMM register.
The XMM-based model (XMM := read XMM s) chains through two zerotop
wrappers (zerotop_128 then zerotop_256), creating deeply nested
word_zx terms that the sematest cosimulation tactics cannot simplify.

Write to YMM directly with explicit word_subword/word_zx instead,
which only goes through one zerotop layer (zerotop_256 to ZMM).
Semantically identical: preserves lower 128 bits, zeros upper bits.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jakemas jakemas marked this pull request as draft May 6, 2026 00:47
jakemas and others added 2 commits May 6, 2026 01:25
Adds proof of rej_uniform.S as in mldsa-native:
- https://github.com/pq-code-package/mldsa-native/blob/main/mldsa/src/native/x86_64/src/rej_uniform_avx2.S
- dilithium (reference): https://github.com/pq-crystals/dilithium/blob/master/avx2/rejsample.S

Implements the 23-bit rejection sampling loop for ML-DSA. A main AVX2
loop processes 24 bytes (8 coefficients) per iteration via
VPERMQ+VPSHUFB extraction, VPAND masking, VPSUBD+VMOVMSKPS rejection,
and VPERMD+table compaction. A scalar tail handles any remaining bytes.

Specification proved:

- MLDSA_REJ_UNIFORM_CORRECT

Functional correctness: the output buffer contains
  num_of_wordlist (SUB_LIST(0, 256) (REJ_SAMPLE inlist))
and the return value (RAX) is its length, where REJ_SAMPLE is the
FIPS-204 reference
  FILTER (\x. val x < 8380417)
         (MAP (\x. word(val x MOD 2^23)) inlist)

Infrastructure:

- x86/mldsa/mldsa_rej_uniform.S: Intel-syntax assembly (same bytecode
  as the mldsa-native AT&T source, just rewritten in the s2n-bignum
  Intel-primary convention).
- x86_att/mldsa/mldsa_rej_uniform.S: auto-generated from the Intel
  source via the existing attrofy.sed pipeline; cmp-verified to
  produce the same object file.
- x86_att/attrofy.sed: extended to handle AVX2 operand-size hints
  (YMMWORD/XMMWORD/DWORD/WORD/BYTE PTR) and the Intel
  "movzx <reg>, WORD PTR <mem>" / "movzx <reg>, BYTE PTR <mem>" forms
  (translated to AT&T movzwl/movzbl).
- x86/Makefile and x86_att/Makefile: add the new .o target.

Proof file is a single self-contained ~4300-line bundle; helper
lemmas (READ_3BYTES_EL, BYTE_BRIDGE_3BYTES, BYTE_BRIDGE_3BYTES_2STATE,
BYTES32_TO_BYTES, ACCEPT_REJ_SAMPLE_SINGLETON, PIVOT_VAL_EQ,
MEMORY_CONJUNCT_CLOSURE, VAL_RCX_ADD3, SCALAR_BODY_LEMMA) are inlined
in dependency order. Top-level theorem proves with 0 CHEATs and only
the 3 standard HOL Light axioms (INFINITY_AX, ETA_AX, SELECT_AX).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jake Massimo <jakemas@amazon.com>
…ature

The collect-signatures.py check was failing on the new PR because:
1. include/s2n-bignum.h lacked an extern declaration for mldsa_rej_uniform
2. ARM's rej_uniform uses a different symbol name (PQCP_MLDSA_NATIVE_..._asm),
   so mldsa_rej_uniform only exists in x86 — needs to be on the
   onlyInX86 allowlist in tools/collect-signatures.py

Adds the extern to both the C99 and C89 public headers, appends to the
onlyInX86 list, and regenerates x86/proofs/subroutine_signatures.ml.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jake Massimo <jakemas@amazon.com>
@jakemas jakemas force-pushed the jakemas/mldsa-rej-uniform-x86-proof branch from 38c3035 to 6a81e1d Compare May 6, 2026 01:25
jakemas and others added 10 commits May 6, 2026 03:36
Restructure MLDSA_REJ_UNIFORM_CORRECT to match the standard s2n-bignum
shape (no stackpointer in forall, no `read RSP s = stackpointer` in
state predicate, RSP not in MAYCHANGE). GHOST_INTRO_TAC binds
stackpointer internally so SCALAR_BODY_LEMMA's invariants continue to
work. Drop the unused `nonoverlapping (stackpointer, 8) (res, 1024)`
precondition from SCALAR_BODY_LEMMA.

Add the two user-facing subroutine wrappers:

  - MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT (via
    X86_PROMOTE_RETURN_NOSTACK_TAC)
  - MLDSA_REJ_UNIFORM_SUBROUTINE_CORRECT (via ADD_IBT_RULE)

Remove unused helper lemmas (SHUFFLE_LOW_LANE, SHUFFLE_HIGH_LANE,
LITTLE_ENDIAN_3BYTES, SUB_LIST_0_LENGTH, ODD_ADD_2,
REJ_SAMPLE_ITERATION, REJ_SAMPLE_SUBLIST_256_BOUNDED,
R9_POPCNT_BRIDGE), progress-logging debug helpers (LOG/DBG), all
Printf.printf checkpoint prints, and the disabled post-loop
scratch block. All three theorems verify with no CHEATs and the
three standard HOL Light axioms.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add the two new theorems that ship in the x86 proof:

  - MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_CORRECT
  - MLDSA_REJ_UNIFORM_SUBROUTINE_CORRECT

No Windows variants are included (the assembly is linux-only per
mldsa-native).

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Extends the x86 AVX2 mldsa_rej_uniform proof with a Windows x64 ABI
variant, mirroring the pattern used by mldsa_reduce.

In the .S, a Windows-only prolog pushes rdi/rsi, reserves 160 bytes,
spills xmm6..xmm15 (callee-saved under Windows but clobbered by
vzeroupper and the body's ymm use), and marshals rcx/rdx/r8 into
rdi/rsi/rdx so the SysV body can run unchanged. A matching epilog
restores xmm6..xmm15 and unwinds the stack.

The HOL Light proof introduces mldsa_rej_uniform_windows_mc/tmc from
the .obj, defines MLDSA_REJ_UNIFORM_WINDOWS_TMC_EXEC, and proves
NOIBT_WINDOWS_SUBROUTINE_CORRECT by: stepping the 16-insn prolog to
pc+91, invoking the existing MLDSA_REJ_UNIFORM_CORRECT linux spec via
X86_BIGSTEP_TAC, stepping the 14-insn epilog, and using
ENSURES_PRESERVED_TAC over RDI/RSI and ZMM6..ZMM15 :> bottomhalf ::
bottomhalf to discharge the XMM6..XMM15 preservation required by the
Windows MAYCHANGE. Final subsumption closes with MAYCHANGE_ZMM_QUARTER /
MAYCHANGE_YMM_SSE_QUARTER rewrites and WORD_BLAST. The mc-level
SUBROUTINE_CORRECT wrapper follows via ADD_IBT_RULE.

Linux .o is bytewise unchanged.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Mirrors the Windows ABI prolog/epilog added in Intel syntax.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
None of the other ML-DSA x86_64 assembly routines issue vzeroupper;
drop it from rej_uniform for consistency. Shaves 3 bytes off the
routine.

The HOL Light proof is adjusted accordingly:
  - nonoverlapping (word pc, 246) → 243 in both SCALAR_BODY_LEMMA and
    MLDSA_REJ_UNIFORM_CORRECT (LENGTH mldsa_rej_uniform_tmc dropped
    from 246 to 243).
  - MLDSA_REJ_UNIFORM_CORRECT post-condition RIP = word(pc + 245) →
    word(pc + 242) (ret now sits where vzeroupper used to).
  - The two X86_STEPS_TAC [55] invocations that stepped vzeroupper in
    the post-loop exit are removed; RIP is already at ret.
  - Windows epilog simulation keeps (18--31) — step 31 is now the ret
    itself (was ret-after-vzeroupper previously).
  - Bytecode literal in the `define_assert_from_elf` block is
    regenerated.

Linux and Windows SUBROUTINE_CORRECT theorems both still prove in
~7-8 min native build. AT&T translation regenerated.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Per the s2n-bignum PR checklist, adds the previously-deferred
tests/test.c and benchmarks/benchmark.c entries for mldsa_rej_uniform
(x86-only, !arm-gated):

- test_mldsa_rej_uniform + reference_mldsa_rej_uniform: cross-checks
  the asm against a 23-bit FIPS-204 scalar reference. Alternating
  iterations overwrite ~25% of 24-bit groups with values in
  [q, 2^23-1] so both the AVX reject-and-compact path and the
  scalar-tail early-exit branch are exercised (not just the trivial
  ~0.1% natural-rejection case that fills 256/256 every time).
- call_mldsa_rej_uniform + timingtest(!arm, ...) in benchmark.c.
- mldsa_rej_uniform_table (2048 bytes) in both files, derived from
  x86/proofs/mldsa_rej_uniform_table.ml and cast to const uint64_t*
  at the call site to match the public signature.

Verified locally: ./tests/test -400 mldsa_rej_uniform -> All OK, with
200/400 iterations hitting the rejection path across ~30 distinct
accepted counts; ./benchmarks/benchmark -50 mldsa_rej_uniform runs
clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For stylistic consistency with the ARM convention (Lfoo_bar naming)
and to preempt macOS assembler gripes about bare local labels in x86
assembly, rename the three internal branch targets:

  mldsa_rej_uniform_loop   -> Lmldsa_rej_uniform_loop
  mldsa_rej_uniform_scalar -> Lmldsa_rej_uniform_scalar
  mldsa_rej_uniform_done   -> Lmldsa_rej_uniform_done

The AT&T translation is regenerated. Machine code is byte-identical:
only the ELF symbol-table strings change. Verified via objdump -d:
all rel8/rel32 displacements (77 46, eb b3, 73 cc, ...) are unchanged
at the same offsets, and cmp-s succeeds between the Intel and AT&T
assembled objects for both WINDOWS_ABI=0 and WINDOWS_ABI=1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jakemas jakemas marked this pull request as ready for review May 9, 2026 03:45
CI failure root cause: on the ARM test job, the dispatch-table entry
used functionaltest(!arm, ...), so the wrapper logged the test as
"inapplicable" (not "skipped"). That made successes + skipped !=
tested, which triggers the "Testing all passed but is incomplete"
early-exit path and non-zero make status, even though every test
that actually ran passed.

The correct pattern (used by mldsa_reduce, mldsa_nttunpack, etc. for
other x86-only mldsa functions) is functionaltest(all, ...); the
test function itself already has the
get_arch_name() != ARCH_X86_64 early-return guard, so on ARM it now
counts as a normal successful test that returns 0 without calling the
x86 assembly.

The x86 CI job was canceled (not failed on its own) because the arm
job failing first triggered the matrix fail-fast.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
jakemas added 3 commits May 13, 2026 14:24
Append the MEMSAFE proof to the existing functional correctness proof
file. The proof establishes that all memory accesses by
mldsa_rej_uniform fall within the provided buffer, table, result, and
stack regions.

Theorems added:
  - SCALAR_BODY_LEMMA_MEMSAFE: events-tracking variant of the
    SCALAR_BODY_LEMMA helper, used in the post-loop scalar tail
    invariant body.
  - MLDSA_REJ_UNIFORM_MEMSAFE: core memory-safety theorem for the
    BUTLAST tmc machine code (no IBT prefix).
  - MLDSA_REJ_UNIFORM_NOIBT_SUBROUTINE_SAFE,
    MLDSA_REJ_UNIFORM_SUBROUTINE_SAFE: subroutine variants for the
    full machine code (SystemV ABI).

Mirrors the pattern in x86/proofs/mlkem_rej_uniform_VARIABLE_TIME.ml:
a separate (exists e2. read events s = APPEND e2 e /\
memaccess_inbounds e2 R W) post that is independent of private input,
since rejection sampling has variable-time memory access by design
(the sequence of writes depends on which input bytes pass the
< MLDSA_Q filter).

Windows ABI memory safety variants are intentionally omitted, mirroring
the mlkem-native VARIABLE_TIME pattern (see TODO comment there
explaining the WINDOWS_X86_WRAP_STACK_TAC limitation with non-safety
postconditions).

The proof has zero CHEAT_TACs and passes check_axioms() with no
user-introduced axioms.

Signed-off-by: Jake Massimo <jakemas@amazon.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant