Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/hol_light.yml
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ jobs:
needs: ["mldsa_specs.ml", "mldsa_utils.ml", "mldsa_zetas.ml", "subroutine_signatures.ml"]
- name: intt_avx2_asm
needs: ["mldsa_specs.ml", "mldsa_utils.ml", "mldsa_zetas.ml", "subroutine_signatures.ml"]
- name: rej_uniform_avx2_asm
needs: ["mldsa_specs.ml", "mldsa_rej_uniform_table.ml"]
- name: nttunpack_avx2_asm
needs: ["mldsa_specs.ml", "subroutine_signatures.ml"]
- name: pointwise_avx2_asm
Expand Down
5 changes: 3 additions & 2 deletions BIBLIOGRAPHY.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ source code and documentation.
- [dev/x86_64/src/poly_use_hint_88_avx2.c](dev/x86_64/src/poly_use_hint_88_avx2.c)
- [dev/x86_64/src/polyz_unpack_17_avx2.c](dev/x86_64/src/polyz_unpack_17_avx2.c)
- [dev/x86_64/src/polyz_unpack_19_avx2.c](dev/x86_64/src/polyz_unpack_19_avx2.c)
- [dev/x86_64/src/rej_uniform_avx2.c](dev/x86_64/src/rej_uniform_avx2.c)
- [dev/x86_64/src/rej_uniform_avx2_asm.S](dev/x86_64/src/rej_uniform_avx2_asm.S)
- [dev/x86_64/src/rej_uniform_eta2_avx2.c](dev/x86_64/src/rej_uniform_eta2_avx2.c)
- [dev/x86_64/src/rej_uniform_eta4_avx2.c](dev/x86_64/src/rej_uniform_eta4_avx2.c)
- [mldsa/src/native/x86_64/src/intt_avx2_asm.S](mldsa/src/native/x86_64/src/intt_avx2_asm.S)
Expand All @@ -253,7 +253,7 @@ source code and documentation.
- [mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c](mldsa/src/native/x86_64/src/poly_use_hint_88_avx2.c)
- [mldsa/src/native/x86_64/src/polyz_unpack_17_avx2.c](mldsa/src/native/x86_64/src/polyz_unpack_17_avx2.c)
- [mldsa/src/native/x86_64/src/polyz_unpack_19_avx2.c](mldsa/src/native/x86_64/src/polyz_unpack_19_avx2.c)
- [mldsa/src/native/x86_64/src/rej_uniform_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_avx2.c)
- [mldsa/src/native/x86_64/src/rej_uniform_avx2_asm.S](mldsa/src/native/x86_64/src/rej_uniform_avx2_asm.S)
- [mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta2_avx2.c)
- [mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c](mldsa/src/native/x86_64/src/rej_uniform_eta4_avx2.c)
- [proofs/hol_light/x86_64/mldsa/intt_avx2_asm.S](proofs/hol_light/x86_64/mldsa/intt_avx2_asm.S)
Expand All @@ -264,6 +264,7 @@ source code and documentation.
- [proofs/hol_light/x86_64/mldsa/pointwise_acc_l7_avx2_asm.S](proofs/hol_light/x86_64/mldsa/pointwise_acc_l7_avx2_asm.S)
- [proofs/hol_light/x86_64/mldsa/pointwise_avx2_asm.S](proofs/hol_light/x86_64/mldsa/pointwise_avx2_asm.S)
- [proofs/hol_light/x86_64/mldsa/poly_caddq_avx2_asm.S](proofs/hol_light/x86_64/mldsa/poly_caddq_avx2_asm.S)
- [proofs/hol_light/x86_64/mldsa/rej_uniform_avx2_asm.S](proofs/hol_light/x86_64/mldsa/rej_uniform_avx2_asm.S)

### `Round3_Spec`

Expand Down
3 changes: 2 additions & 1 deletion dev/x86_64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ static MLD_INLINE int mld_rej_uniform_native(int32_t *r, unsigned len,
}

/* Safety: outlen is at most MLDSA_N and, hence, this cast is safe. */
return (int)mld_rej_uniform_avx2(r, buf);
return (int)mld_rej_uniform_avx2_asm(r, buf,
(const uint8_t *)mld_rej_uniform_table);
}

#if !defined(MLD_CONFIG_NO_KEYPAIR_API)
Expand Down
17 changes: 14 additions & 3 deletions dev/x86_64/src/arith_native_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,21 @@ __contract__(
r[i] == old(*(int32_t (*)[MLDSA_N])r)[j])))
);

#define mld_rej_uniform_avx2 MLD_NAMESPACE(mld_rej_uniform_avx2)
#define mld_rej_uniform_avx2_asm MLD_NAMESPACE(rej_uniform_avx2_asm)
/* This contract must be kept in sync with the HOL-Light specification
* in proofs/hol_light/x86_64/proofs/rej_uniform_avx2_asm.ml */
MLD_MUST_CHECK_RETURN_VALUE
unsigned mld_rej_uniform_avx2(int32_t *r,
const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN]);
unsigned mld_rej_uniform_avx2_asm(
int32_t *r, const uint8_t buf[MLD_AVX2_REJ_UNIFORM_BUFLEN],
const uint8_t *table)
__contract__(
requires(memory_no_alias(r, sizeof(int32_t) * MLDSA_N))
requires(memory_no_alias(buf, 840))
requires(table == (const uint8_t *)mld_rej_uniform_table)
assigns(memory_slice(r, sizeof(int32_t) * MLDSA_N))
ensures(return_value <= MLDSA_N)
ensures(array_bound(r, 0, return_value, 0, MLDSA_Q))
);

#if !defined(MLD_CONFIG_NO_KEYPAIR_API)
#define mld_rej_uniform_eta2_avx2 MLD_NAMESPACE(mld_rej_uniform_eta2_avx2)
Expand Down
126 changes: 0 additions & 126 deletions dev/x86_64/src/rej_uniform_avx2.c

This file was deleted.

172 changes: 172 additions & 0 deletions dev/x86_64/src/rej_uniform_avx2_asm.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

/* References
* ==========
*
* - [REF_AVX2]
* CRYSTALS-Dilithium optimized AVX2 implementation
* Bai, Ducas, Kiltz, Lepoint, Lyubashevsky, Schwabe, Seiler, Stehlé
* https://github.com/pq-crystals/dilithium/tree/master/avx2
*/

/*
* This file is derived from the public domain
* AVX2 Dilithium implementation @[REF_AVX2].
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_X86_64_DEFAULT) && \
!defined(MLD_CONFIG_MULTILEVEL_NO_SHARED)
/* simpasm: header-end */

#define out %rdi
#define in %rsi
#define tab %rdx

#define ctr %eax
#define pos %ecx

#define good %r8d
#define cnt %r9d
#define tmp %r10

#define idx8 %ymm0
#define mask %ymm1
#define bound %ymm2
#define data %ymm3
#define cmp_result %ymm4

.text

/*
* unsigned mld_rej_uniform_avx2_asm(int32_t *r, const uint8_t *buf,
* const uint8_t *table)
*
* Rejection sampling of uniform polynomial coefficients.
* Extracts 23-bit values from a byte buffer and accepts those < MLDSA_Q.
*
* Arguments: out (rdi): pointer to output coefficient array r (256 x int32_t)
* in (rsi): pointer to input byte buffer buf (840 bytes)
* tab (rdx): pointer to rejection sampling lookup table (256x8 bytes)
*
* Returns: ctr (eax): number of valid coefficients written to r
*/
.balign 4
.global MLD_ASM_NAMESPACE(rej_uniform_avx2_asm)
MLD_ASM_FN_SYMBOL(rej_uniform_avx2_asm)

/*
* Construct the shuffle mask for extracting 8 x 23-bit values from 24 bytes.
*
* After vpermq with 0x94, the 32 loaded bytes are rearranged as:
* Low 128 bits: bytes [0..15] (original 64-bit lanes 0, 1)
* High 128 bits: bytes [8..23] (original 64-bit lanes 1, 2)
*
* vpshufb then picks 3-byte groups and zero-pads each to a 32-bit lane:
* Low half: {0,1,2,FF, 3,4,5,FF, 6,7,8,FF, 9,10,11,FF}
* High half: {4,5,6,FF, 7,8,9,FF, 10,11,12,FF, 13,14,15,FF}
*
* This extracts 8 non-overlapping 3-byte windows from the first 24 input bytes.
*/
movq $0xFF050403FF020100, tmp
vmovq tmp, %xmm0
movq $0xFF0B0A09FF080706, tmp
vpinsrq $1, tmp, %xmm0, %xmm0
movq $0xFF090807FF060504, tmp
vmovq tmp, %xmm3
movq $0xFF0F0E0DFF0C0B0A, tmp
vpinsrq $1, tmp, %xmm3, %xmm3
vinserti128 $1, %xmm3, idx8, idx8

// Construct broadcast constants
movl $0x7FFFFF, good
vmovd good, %xmm1
vpbroadcastd %xmm1, mask // mask: 23-bit extraction

movl $8380417, good // MLDSA_Q
vmovd good, %xmm2
vpbroadcastd %xmm2, bound // bound: rejection threshold

// Initialize counters
xorl ctr, ctr // ctr = 0
xorl pos, pos // pos = 0

/*
* Main SIMD loop: process 24 input bytes into up to 8 coefficients
* per iteration. Loops while ctr <= MLDSA_N - 8 and pos <= BUFLEN - 32.
*/
rej_uniform_avx2_asm_loop:
cmpl $248, ctr // MLDSA_N - 8
ja rej_uniform_avx2_asm_scalar
cmpl $808, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 32
ja rej_uniform_avx2_asm_scalar

vmovdqu (in, %rcx), data // load 32 bytes from buf[pos]
addl $24, pos // advance pos
vpermq $0x94, data, data // rearrange 64-bit lanes: [2,1,1,0]
vpshufb idx8, data, data // extract 8 x 3-byte groups
vpand mask, data, data // mask to 23 bits

vpsubd bound, data, cmp_result // d - Q: negative if d < Q (valid)
vmovmskps cmp_result, good // extract sign bits as 8-bit mask

popcntl good, cnt // count valid coefficients

vmovq (tab, %r8, 8), %xmm4 // load permutation from table[good]
vpmovzxbd %xmm4, cmp_result // zero-extend to 8 dword indices
vpermd data, cmp_result, data // compact valid coefficients to front

vmovdqu data, (out, %rax, 4) // store at r[ctr]
addl cnt, ctr // ctr += popcount(good)

jmp rej_uniform_avx2_asm_loop

/*
* Scalar fallback loop: process remaining bytes one coefficient at a time.
* Loops while ctr < MLDSA_N and pos <= BUFLEN - 3.
*/
rej_uniform_avx2_asm_scalar:
cmpl $256, ctr // MLDSA_N
jae rej_uniform_avx2_asm_done
cmpl $837, pos // MLD_AVX2_REJ_UNIFORM_BUFLEN - 3
ja rej_uniform_avx2_asm_done

movzwl (in, %rcx), good // load 2 bytes at buf[pos]
movzbl 2(in, %rcx), cnt // load third byte
shll $16, cnt
orl cnt, good
andl $0x7FFFFF, good // mask to 23 bits
addl $3, pos // advance pos

cmpl $8380417, good // MLDSA_Q
jae rej_uniform_avx2_asm_scalar // reject if >= Q

movl good, (out, %rax, 4) // store valid coefficient
addl $1, ctr // ctr++
jmp rej_uniform_avx2_asm_scalar

rej_uniform_avx2_asm_done:
ret

/* To facilitate single-compilation-unit (SCU) builds, undefine all macros.
* Don't modify by hand -- this is auto-generated by scripts/autogen. */
#undef out
#undef in
#undef tab
#undef ctr
#undef pos
#undef good
#undef cnt
#undef tmp
#undef idx8
#undef mask
#undef bound
#undef data
#undef cmp_result

/* simpasm: footer-start */
#endif /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \
*/
3 changes: 1 addition & 2 deletions mldsa/mldsa_native.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@
#include "src/native/x86_64/src/poly_use_hint_88_avx2.c"
#include "src/native/x86_64/src/polyz_unpack_17_avx2.c"
#include "src/native/x86_64/src/polyz_unpack_19_avx2.c"
#include "src/native/x86_64/src/rej_uniform_avx2.c"
#include "src/native/x86_64/src/rej_uniform_eta2_avx2.c"
#include "src/native/x86_64/src/rej_uniform_eta4_avx2.c"
#include "src/native/x86_64/src/rej_uniform_table.c"
Expand Down Expand Up @@ -759,7 +758,7 @@
#undef mld_poly_use_hint_88_avx2
#undef mld_polyz_unpack_17_avx2
#undef mld_polyz_unpack_19_avx2
#undef mld_rej_uniform_avx2
#undef mld_rej_uniform_avx2_asm
#undef mld_rej_uniform_eta2_avx2
#undef mld_rej_uniform_eta4_avx2
#undef mld_rej_uniform_table
Expand Down
Loading
Loading