diff --git a/.github/workflows/hol_light.yml b/.github/workflows/hol_light.yml index f00a1afa1..ec2da7f1b 100644 --- a/.github/workflows/hol_light.yml +++ b/.github/workflows/hol_light.yml @@ -97,6 +97,10 @@ jobs: needs: ["aarch64_utils.ml"] - name: mldsa_poly_chknorm needs: ["aarch64_utils.ml"] + - name: poly_use_hint_32_aarch64_asm + needs: ["mldsa_specs.ml", "aarch64_utils.ml", "subroutine_signatures.ml"] + - name: poly_use_hint_88_aarch64_asm + needs: ["mldsa_specs.ml", "aarch64_utils.ml", "subroutine_signatures.ml"] - name: keccak_f1600_x1_scalar needs: ["keccak_spec.ml"] - name: keccak_f1600_x1_v84a diff --git a/dev/aarch64_clean/src/arith_native_aarch64.h b/dev/aarch64_clean/src/arith_native_aarch64.h index 4b848d974..11cf75223 100644 --- a/dev/aarch64_clean/src/arith_native_aarch64.h +++ b/dev/aarch64_clean/src/arith_native_aarch64.h @@ -117,10 +117,32 @@ __contract__( #if !defined(MLD_CONFIG_NO_VERIFY_API) #define mld_poly_use_hint_32_asm MLD_NAMESPACE(poly_use_hint_32_asm) -void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h); +void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/poly_use_hint_32_aarch64_asm.ml */ +__contract__( + requires(memory_no_alias(b, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(b, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(b, 0, MLDSA_N, 0, 16)) +); #define mld_poly_use_hint_88_asm MLD_NAMESPACE(poly_use_hint_88_asm) -void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h); +void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/poly_use_hint_88_aarch64_asm.ml */ +__contract__( + requires(memory_no_alias(b, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(b, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(b, 0, MLDSA_N, 0, 44)) +); #endif /* !MLD_CONFIG_NO_VERIFY_API */ #define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm) diff --git a/dev/aarch64_opt/src/arith_native_aarch64.h b/dev/aarch64_opt/src/arith_native_aarch64.h index 4b848d974..11cf75223 100644 --- a/dev/aarch64_opt/src/arith_native_aarch64.h +++ b/dev/aarch64_opt/src/arith_native_aarch64.h @@ -117,10 +117,32 @@ __contract__( #if !defined(MLD_CONFIG_NO_VERIFY_API) #define mld_poly_use_hint_32_asm MLD_NAMESPACE(poly_use_hint_32_asm) -void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h); +void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/poly_use_hint_32_aarch64_asm.ml */ +__contract__( + requires(memory_no_alias(b, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(b, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(b, 0, MLDSA_N, 0, 16)) +); #define mld_poly_use_hint_88_asm MLD_NAMESPACE(poly_use_hint_88_asm) -void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h); +void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/poly_use_hint_88_aarch64_asm.ml */ +__contract__( + requires(memory_no_alias(b, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(b, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(b, 0, MLDSA_N, 0, 44)) +); #endif /* !MLD_CONFIG_NO_VERIFY_API */ #define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm) diff --git a/mldsa/src/native/aarch64/src/arith_native_aarch64.h b/mldsa/src/native/aarch64/src/arith_native_aarch64.h index 4b848d974..11cf75223 100644 --- a/mldsa/src/native/aarch64/src/arith_native_aarch64.h +++ b/mldsa/src/native/aarch64/src/arith_native_aarch64.h @@ -117,10 +117,32 @@ __contract__( #if !defined(MLD_CONFIG_NO_VERIFY_API) #define mld_poly_use_hint_32_asm MLD_NAMESPACE(poly_use_hint_32_asm) -void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h); +void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/poly_use_hint_32_aarch64_asm.ml */ +__contract__( + requires(memory_no_alias(b, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(b, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(b, 0, MLDSA_N, 0, 16)) +); #define mld_poly_use_hint_88_asm MLD_NAMESPACE(poly_use_hint_88_asm) -void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h); +void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h) +/* This must be kept in sync with the HOL-Light specification + * in proofs/hol_light/aarch64/proofs/poly_use_hint_88_aarch64_asm.ml */ +__contract__( + requires(memory_no_alias(b, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(a, sizeof(int32_t) * MLDSA_N)) + requires(memory_no_alias(h, sizeof(int32_t) * MLDSA_N)) + requires(array_bound(a, 0, MLDSA_N, 0, MLDSA_Q)) + requires(array_bound(h, 0, MLDSA_N, 0, 2)) + assigns(memory_slice(b, sizeof(int32_t) * MLDSA_N)) + ensures(array_bound(b, 0, MLDSA_N, 0, 44)) +); #endif /* !MLD_CONFIG_NO_VERIFY_API */ #define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm) diff --git a/proofs/cbmc/poly_use_hint_native_aarch64/Makefile b/proofs/cbmc/poly_use_hint_native_aarch64/Makefile new file mode 100644 index 000000000..b5bed0406 --- /dev/null +++ b/proofs/cbmc/poly_use_hint_native_aarch64/Makefile @@ -0,0 +1,53 @@ +# Copyright (c) The mldsa-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +include ../Makefile_params.common + +HARNESS_ENTRY = harness +HARNESS_FILE = poly_use_hint_native_aarch64_harness + +# This should be a unique identifier for this proof, and will appear on the +# Litani dashboard. It can be human-readable and contain spaces if you wish. +PROOF_UID = poly_use_hint_native_aarch64 + +# We need to set MLD_CHECK_APIS as otherwise mldsa/src/native/api.h won't be +# included, which contains the CBMC specifications. +DEFINES += -DMLD_CONFIG_USE_NATIVE_BACKEND_ARITH -DMLD_CONFIG_ARITH_BACKEND_FILE="\"$(SRCDIR)/mldsa/src/native/aarch64/meta.h\"" -DMLD_CHECK_APIS +INCLUDES += + +REMOVE_FUNCTION_BODY += +UNWINDSET += + +PROOF_SOURCES += $(PROOFDIR)/$(HARNESS_FILE).c +PROJECT_SOURCES += $(SRCDIR)/mldsa/src/poly_kl.c + +ifeq ($(MLD_CONFIG_PARAMETER_SET),44) + CHECK_FUNCTION_CONTRACTS=mld_poly_use_hint_88_native + USE_FUNCTION_CONTRACTS=mld_poly_use_hint_88_asm +else ifeq ($(MLD_CONFIG_PARAMETER_SET),65) + CHECK_FUNCTION_CONTRACTS=mld_poly_use_hint_32_native + USE_FUNCTION_CONTRACTS=mld_poly_use_hint_32_asm +else ifeq ($(MLD_CONFIG_PARAMETER_SET),87) + CHECK_FUNCTION_CONTRACTS=mld_poly_use_hint_32_native + USE_FUNCTION_CONTRACTS=mld_poly_use_hint_32_asm +endif +USE_FUNCTION_CONTRACTS+=mld_sys_check_capability +APPLY_LOOP_CONTRACTS=on +USE_DYNAMIC_FRAMES=1 + +# Disable any setting of EXTERNAL_SAT_SOLVER, and choose SMT backend instead +EXTERNAL_SAT_SOLVER= +CBMCFLAGS=--smt2 + +FUNCTION_NAME = poly_use_hint_native_aarch64 + +# If this proof is found to consume huge amounts of RAM, you can set the +# EXPENSIVE variable. With new enough versions of the proof tools, this will +# restrict the number of EXPENSIVE CBMC jobs running at once. See the +# documentation in Makefile.common under the "Job Pools" heading for details. +# EXPENSIVE = true + +# This function is large enough to need... +CBMC_OBJECT_BITS = 8 + +include ../Makefile.common diff --git a/proofs/cbmc/poly_use_hint_native_aarch64/poly_use_hint_native_aarch64_harness.c b/proofs/cbmc/poly_use_hint_native_aarch64/poly_use_hint_native_aarch64_harness.c new file mode 100644 index 000000000..3b0e943dd --- /dev/null +++ b/proofs/cbmc/poly_use_hint_native_aarch64/poly_use_hint_native_aarch64_harness.c @@ -0,0 +1,24 @@ +// Copyright (c) The mldsa-native project authors +// SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +#include +#include "cbmc.h" +#include "params.h" + +#if MLDSA_GAMMA2 == ((MLDSA_Q - 1) / 88) +int mld_poly_use_hint_88_native(int32_t *b, const int32_t *a, const int32_t *h); +#else +int mld_poly_use_hint_32_native(int32_t *b, const int32_t *a, const int32_t *h); +#endif + +void harness(void) +{ + int32_t *b, *a, *h; + int t; + +#if MLDSA_GAMMA2 == ((MLDSA_Q - 1) / 88) + t = mld_poly_use_hint_88_native(b, a, h); +#else + t = mld_poly_use_hint_32_native(b, a, h); +#endif +} diff --git a/proofs/hol_light/README.md b/proofs/hol_light/README.md index c03451796..dfa1bebd6 100644 --- a/proofs/hol_light/README.md +++ b/proofs/hol_light/README.md @@ -54,6 +54,8 @@ echo '1+1;;' | nc -w 5 127.0.0.1 2012 - ML-DSA Arithmetic: * AArch64 poly_caddq: [mldsa_poly_caddq.S](aarch64/mldsa/mldsa_poly_caddq.S) * AArch64 poly_chknorm: [mldsa_poly_chknorm.S](aarch64/mldsa/mldsa_poly_chknorm.S) + * AArch64 poly_use_hint (l=5,7): [poly_use_hint_32_aarch64_asm.S](aarch64/mldsa/poly_use_hint_32_aarch64_asm.S) + * AArch64 poly_use_hint (l=4): [poly_use_hint_88_aarch64_asm.S](aarch64/mldsa/poly_use_hint_88_aarch64_asm.S) * AArch64 pointwise multiplication: [mldsa_pointwise.S](aarch64/mldsa/mldsa_pointwise.S) * AArch64 pointwise multiplication-accumulation (l=4): [mldsa_pointwise_acc_l4.S](aarch64/mldsa/mldsa_pointwise_acc_l4.S) * AArch64 pointwise multiplication-accumulation (l=5): [mldsa_pointwise_acc_l5.S](aarch64/mldsa/mldsa_pointwise_acc_l5.S) diff --git a/proofs/hol_light/aarch64/Makefile b/proofs/hol_light/aarch64/Makefile index 61766d30c..f8a78fbf0 100644 --- a/proofs/hol_light/aarch64/Makefile +++ b/proofs/hol_light/aarch64/Makefile @@ -56,17 +56,19 @@ SPLIT=tr ';' '\n' OBJ = mldsa/mldsa_ntt.o \ mldsa/mldsa_pointwise.o \ mldsa/mldsa_poly_caddq.o \ - mldsa/mldsa_poly_chknorm.o \ - mldsa/mldsa_pointwise_acc_l4.o \ - mldsa/mldsa_pointwise_acc_l5.o \ - mldsa/mldsa_pointwise_acc_l7.o \ - mldsa/keccak_f1600_x1_scalar.o \ - mldsa/keccak_f1600_x1_v84a.o \ - mldsa/keccak_f1600_x2_v84a.o \ - mldsa/keccak_f1600_x4_v8a_scalar.o \ - mldsa/keccak_f1600_x4_v8a_v84a_scalar.o \ - mldsa/mldsa_polyz_unpack_17.o \ - mldsa/mldsa_polyz_unpack_19.o + mldsa/mldsa_poly_chknorm.o \ + mldsa/mldsa_pointwise_acc_l4.o \ + mldsa/mldsa_pointwise_acc_l5.o \ + mldsa/mldsa_pointwise_acc_l7.o \ + mldsa/keccak_f1600_x1_scalar.o \ + mldsa/keccak_f1600_x1_v84a.o \ + mldsa/keccak_f1600_x2_v84a.o \ + mldsa/keccak_f1600_x4_v8a_scalar.o \ + mldsa/keccak_f1600_x4_v8a_v84a_scalar.o \ + mldsa/mldsa_polyz_unpack_17.o \ + mldsa/mldsa_polyz_unpack_19.o \ + mldsa/poly_use_hint_32_aarch64_asm.o \ + mldsa/poly_use_hint_88_aarch64_asm.o # According to # https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms, diff --git a/proofs/hol_light/aarch64/mldsa/poly_use_hint_32_aarch64_asm.S b/proofs/hol_light/aarch64/mldsa/poly_use_hint_32_aarch64_asm.S new file mode 100644 index 000000000..3d788eeef --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/poly_use_hint_32_aarch64_asm.S @@ -0,0 +1,98 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/poly_use_hint_32_asm.S using scripts/simpasm. Do not modify it directly. + */ + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_32_asm +_PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_32_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_32_asm +PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_32_asm: +#endif + + .cfi_startproc + mov w4, #0xe001 // =57345 + movk w4, #0x7f, lsl #16 + dup v20.4s, w4 + mov w5, #0xe100 // =57600 + movk w5, #0x7b, lsl #16 + dup v21.4s, w5 + mov w7, #0xfe00 // =65024 + movk w7, #0x7, lsl #16 + dup v22.4s, w7 + mov w11, #0x401 // =1025 + movk w11, #0x4010, lsl #16 + dup v23.4s, w11 + movi v24.4s, #0xf + mov x3, #0x10 // =16 + +Lpoly_use_hint_32_loop: + ldr q1, [x1, #0x10] + ldr q2, [x1, #0x20] + ldr q3, [x1, #0x30] + ldr q0, [x1], #0x40 + ldr q5, [x2, #0x10] + ldr q6, [x2, #0x20] + ldr q7, [x2, #0x30] + ldr q4, [x2], #0x40 + sqdmulh v17.4s, v1.4s, v23.4s + srshr v17.4s, v17.4s, #0x12 + cmgt v25.4s, v1.4s, v21.4s + mls v1.4s, v17.4s, v22.4s + bic v17.16b, v17.16b, v25.16b + add v1.4s, v1.4s, v25.4s + cmle v1.4s, v1.4s, #0 + orr v1.4s, #0x1 + mla v17.4s, v1.4s, v5.4s + and v17.16b, v17.16b, v24.16b + sqdmulh v18.4s, v2.4s, v23.4s + srshr v18.4s, v18.4s, #0x12 + cmgt v25.4s, v2.4s, v21.4s + mls v2.4s, v18.4s, v22.4s + bic v18.16b, v18.16b, v25.16b + add v2.4s, v2.4s, v25.4s + cmle v2.4s, v2.4s, #0 + orr v2.4s, #0x1 + mla v18.4s, v2.4s, v6.4s + and v18.16b, v18.16b, v24.16b + sqdmulh v19.4s, v3.4s, v23.4s + srshr v19.4s, v19.4s, #0x12 + cmgt v25.4s, v3.4s, v21.4s + mls v3.4s, v19.4s, v22.4s + bic v19.16b, v19.16b, v25.16b + add v3.4s, v3.4s, v25.4s + cmle v3.4s, v3.4s, #0 + orr v3.4s, #0x1 + mla v19.4s, v3.4s, v7.4s + and v19.16b, v19.16b, v24.16b + sqdmulh v16.4s, v0.4s, v23.4s + srshr v16.4s, v16.4s, #0x12 + cmgt v25.4s, v0.4s, v21.4s + mls v0.4s, v16.4s, v22.4s + bic v16.16b, v16.16b, v25.16b + add v0.4s, v0.4s, v25.4s + cmle v0.4s, v0.4s, #0 + orr v0.4s, #0x1 + mla v16.4s, v0.4s, v4.4s + and v16.16b, v16.16b, v24.16b + str q17, [x0, #0x10] + str q18, [x0, #0x20] + str q19, [x0, #0x30] + str q16, [x0], #0x40 + subs x3, x3, #0x1 + b.ne Lpoly_use_hint_32_loop + ret + .cfi_endproc + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/proofs/hol_light/aarch64/mldsa/poly_use_hint_88_aarch64_asm.S b/proofs/hol_light/aarch64/mldsa/poly_use_hint_88_aarch64_asm.S new file mode 100644 index 000000000..d40d410dd --- /dev/null +++ b/proofs/hol_light/aarch64/mldsa/poly_use_hint_88_aarch64_asm.S @@ -0,0 +1,106 @@ +/* + * Copyright (c) The mldsa-native project authors + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + */ + + +/* + * WARNING: This file is auto-derived from the mldsa-native source file + * dev/aarch64_opt/src/poly_use_hint_88_asm.S using scripts/simpasm. Do not modify it directly. + */ + +.text +.balign 4 +#ifdef __APPLE__ +.global _PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_88_asm +_PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_88_asm: +#else +.global PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_88_asm +PQCP_MLDSA_NATIVE_MLDSA44_poly_use_hint_88_asm: +#endif + + .cfi_startproc + mov w4, #0xe001 // =57345 + movk w4, #0x7f, lsl #16 + dup v20.4s, w4 + mov w5, #0x6c00 // =27648 + movk w5, #0x7e, lsl #16 + dup v21.4s, w5 + mov w7, #0xe800 // =59392 + movk w7, #0x2, lsl #16 + dup v22.4s, w7 + mov w11, #0x581 // =1409 + movk w11, #0x5816, lsl #16 + dup v23.4s, w11 + movi v24.4s, #0x2b + mov x3, #0x10 // =16 + +Lpoly_use_hint_88_loop: + ldr q1, [x1, #0x10] + ldr q2, [x1, #0x20] + ldr q3, [x1, #0x30] + ldr q0, [x1], #0x40 + ldr q5, [x2, #0x10] + ldr q6, [x2, #0x20] + ldr q7, [x2, #0x30] + ldr q4, [x2], #0x40 + sqdmulh v17.4s, v1.4s, v23.4s + srshr v17.4s, v17.4s, #0x11 + cmgt v25.4s, v1.4s, v21.4s + mls v1.4s, v17.4s, v22.4s + bic v17.16b, v17.16b, v25.16b + add v1.4s, v1.4s, v25.4s + cmle v1.4s, v1.4s, #0 + orr v1.4s, #0x1 + mla v17.4s, v1.4s, v5.4s + cmgt v25.4s, v17.4s, v24.4s + bic v17.16b, v17.16b, v25.16b + umin v17.4s, v17.4s, v24.4s + sqdmulh v18.4s, v2.4s, v23.4s + srshr v18.4s, v18.4s, #0x11 + cmgt v25.4s, v2.4s, v21.4s + mls v2.4s, v18.4s, v22.4s + bic v18.16b, v18.16b, v25.16b + add v2.4s, v2.4s, v25.4s + cmle v2.4s, v2.4s, #0 + orr v2.4s, #0x1 + mla v18.4s, v2.4s, v6.4s + cmgt v25.4s, v18.4s, v24.4s + bic v18.16b, v18.16b, v25.16b + umin v18.4s, v18.4s, v24.4s + sqdmulh v19.4s, v3.4s, v23.4s + srshr v19.4s, v19.4s, #0x11 + cmgt v25.4s, v3.4s, v21.4s + mls v3.4s, v19.4s, v22.4s + bic v19.16b, v19.16b, v25.16b + add v3.4s, v3.4s, v25.4s + cmle v3.4s, v3.4s, #0 + orr v3.4s, #0x1 + mla v19.4s, v3.4s, v7.4s + cmgt v25.4s, v19.4s, v24.4s + bic v19.16b, v19.16b, v25.16b + umin v19.4s, v19.4s, v24.4s + sqdmulh v16.4s, v0.4s, v23.4s + srshr v16.4s, v16.4s, #0x11 + cmgt v25.4s, v0.4s, v21.4s + mls v0.4s, v16.4s, v22.4s + bic v16.16b, v16.16b, v25.16b + add v0.4s, v0.4s, v25.4s + cmle v0.4s, v0.4s, #0 + orr v0.4s, #0x1 + mla v16.4s, v0.4s, v4.4s + cmgt v25.4s, v16.4s, v24.4s + bic v16.16b, v16.16b, v25.16b + umin v16.4s, v16.4s, v24.4s + str q17, [x0, #0x10] + str q18, [x0, #0x20] + str q19, [x0, #0x30] + str q16, [x0], #0x40 + subs x3, x3, #0x1 + b.ne Lpoly_use_hint_88_loop + ret + .cfi_endproc + +#if defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/proofs/hol_light/aarch64/proofs/aarch64_utils.ml b/proofs/hol_light/aarch64/proofs/aarch64_utils.ml index 34d06c696..34769c333 100644 --- a/proofs/hol_light/aarch64/proofs/aarch64_utils.ml +++ b/proofs/hol_light/aarch64/proofs/aarch64_utils.ml @@ -5,6 +5,18 @@ needs "common/mldsa_specs.ml";; +let ENSURES_STRENGTHEN_POST = prove( + `!P (Q:armstate->bool) Q' R. + ensures arm P Q' R /\ (!s. Q' s ==> Q s) ==> ensures arm P Q R`, + REPEAT GEN_TAC THEN DISCH_THEN(CONJUNCTS_THEN2 MP_TAC ASSUME_TAC) THEN + REWRITE_TAC[ensures] THEN MATCH_MP_TAC MONO_FORALL THEN + X_GEN_TAC `s0:armstate` THEN MATCH_MP_TAC MONO_IMP THEN REWRITE_TAC[] THEN + MP_TAC(BETA_RULE(ISPECL [`arm`; + `\s':armstate. (Q':armstate->bool) s' /\ (R:armstate->armstate->bool) (s0:armstate) s'`; + `\s':armstate. (Q:armstate->bool) s' /\ (R:armstate->armstate->bool) (s0:armstate) s'`] + EVENTUALLY_MONO)) THEN + ANTS_TAC THENL [ASM_MESON_TAC[]; MESON_TAC[]]);; + (* Merge 4 x bytes32 reads into bytes128 reads *) let MEMORY_128_FROM_32_TAC = let a_tm = `a:int64` and n_tm = `n:num` and i64_ty = `:int64` diff --git a/proofs/hol_light/aarch64/proofs/dump_bytecode.ml b/proofs/hol_light/aarch64/proofs/dump_bytecode.ml index 5daabd90d..fb2a45fff 100644 --- a/proofs/hol_light/aarch64/proofs/dump_bytecode.ml +++ b/proofs/hol_light/aarch64/proofs/dump_bytecode.ml @@ -30,6 +30,14 @@ print_string "=== bytecode start: aarch64/mldsa/mldsa_poly_caddq.o ===\n";; print_literal_from_elf "aarch64/mldsa/mldsa_poly_caddq.o";; print_string "==== bytecode end =====================================\n\n";; +print_string "=== bytecode start: aarch64/mldsa/poly_use_hint_32_aarch64_asm.o ===\n";; +print_literal_from_elf "aarch64/mldsa/poly_use_hint_32_aarch64_asm.o";; +print_string "==== bytecode end =====================================\n\n";; + +print_string "=== bytecode start: aarch64/mldsa/poly_use_hint_88_aarch64_asm.o ===\n";; +print_literal_from_elf "aarch64/mldsa/poly_use_hint_88_aarch64_asm.o";; +print_string "==== bytecode end =====================================\n\n";; + print_string "=== bytecode start: aarch64/mldsa/mldsa_poly_chknorm.o ===\n";; print_literal_from_elf "aarch64/mldsa/mldsa_poly_chknorm.o";; print_string "==== bytecode end =====================================\n\n";; diff --git a/proofs/hol_light/aarch64/proofs/poly_use_hint_32_aarch64_asm.ml b/proofs/hol_light/aarch64/proofs/poly_use_hint_32_aarch64_asm.ml new file mode 100644 index 000000000..8273007e3 --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/poly_use_hint_32_aarch64_asm.ml @@ -0,0 +1,951 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Use hint to correct high bits of decomposition (ML-DSA, param 65/87). *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "common/mldsa_specs.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; + + +(**** print_literal_from_elf "aarch64/mldsa/poly_use_hint_32_aarch64_asm.o";; + ****) + +let poly_use_hint_32_aarch64_asm_mc = define_assert_from_elf + "poly_use_hint_32_aarch64_asm_mc" "aarch64/mldsa/poly_use_hint_32_aarch64_asm.o" +(*** BYTECODE START ***) +[ + 0x529c0024; (* arm_MOV W4 (rvalue (word 57345)) *) + 0x72a00fe4; (* arm_MOVK W4 (word 127) 16 *) + 0x4e040c94; (* arm_DUP_GEN Q20 X4 32 128 *) + 0x529c2005; (* arm_MOV W5 (rvalue (word 57600)) *) + 0x72a00f65; (* arm_MOVK W5 (word 123) 16 *) + 0x4e040cb5; (* arm_DUP_GEN Q21 X5 32 128 *) + 0x529fc007; (* arm_MOV W7 (rvalue (word 65024)) *) + 0x72a000e7; (* arm_MOVK W7 (word 7) 16 *) + 0x4e040cf6; (* arm_DUP_GEN Q22 X7 32 128 *) + 0x5280802b; (* arm_MOV W11 (rvalue (word 1025)) *) + 0x72a8020b; (* arm_MOVK W11 (word 16400) 16 *) + 0x4e040d77; (* arm_DUP_GEN Q23 X11 32 128 *) + 0x4f0005f8; (* arm_MOVI Q24 (word 64424509455) *) + 0xd2800203; (* arm_MOV X3 (rvalue (word 16)) *) + 0x3dc00421; (* arm_LDR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3dc00822; (* arm_LDR Q2 X1 (Immediate_Offset (word 32)) *) + 0x3dc00c23; (* arm_LDR Q3 X1 (Immediate_Offset (word 48)) *) + 0x3cc40420; (* arm_LDR Q0 X1 (Postimmediate_Offset (word 64)) *) + 0x3dc00445; (* arm_LDR Q5 X2 (Immediate_Offset (word 16)) *) + 0x3dc00846; (* arm_LDR Q6 X2 (Immediate_Offset (word 32)) *) + 0x3dc00c47; (* arm_LDR Q7 X2 (Immediate_Offset (word 48)) *) + 0x3cc40444; (* arm_LDR Q4 X2 (Postimmediate_Offset (word 64)) *) + 0x4eb7b431; (* arm_SQDMULH_VEC Q17 Q1 Q23 32 128 *) + 0x4f2e2631; (* arm_SRSHR_VEC Q17 Q17 18 32 128 *) + 0x4eb53439; (* arm_CMGT_VEC Q25 Q1 Q21 32 128 *) + 0x6eb69621; (* arm_MLS_VEC Q1 Q17 Q22 32 128 *) + 0x4e791e31; (* arm_BIC_VEC Q17 Q17 Q25 128 *) + 0x4eb98421; (* arm_ADD_VEC Q1 Q1 Q25 32 128 *) + 0x6ea09821; (* arm_CMLE_VEC_ZERO Q1 Q1 32 128 *) + 0x4f001421; (* arm_ORR_VEC Q1 Q1 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea59431; (* arm_MLA_VEC Q17 Q1 Q5 32 128 *) + 0x4e381e31; (* arm_AND_VEC Q17 Q17 Q24 128 *) + 0x4eb7b452; (* arm_SQDMULH_VEC Q18 Q2 Q23 32 128 *) + 0x4f2e2652; (* arm_SRSHR_VEC Q18 Q18 18 32 128 *) + 0x4eb53459; (* arm_CMGT_VEC Q25 Q2 Q21 32 128 *) + 0x6eb69642; (* arm_MLS_VEC Q2 Q18 Q22 32 128 *) + 0x4e791e52; (* arm_BIC_VEC Q18 Q18 Q25 128 *) + 0x4eb98442; (* arm_ADD_VEC Q2 Q2 Q25 32 128 *) + 0x6ea09842; (* arm_CMLE_VEC_ZERO Q2 Q2 32 128 *) + 0x4f001422; (* arm_ORR_VEC Q2 Q2 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea69452; (* arm_MLA_VEC Q18 Q2 Q6 32 128 *) + 0x4e381e52; (* arm_AND_VEC Q18 Q18 Q24 128 *) + 0x4eb7b473; (* arm_SQDMULH_VEC Q19 Q3 Q23 32 128 *) + 0x4f2e2673; (* arm_SRSHR_VEC Q19 Q19 18 32 128 *) + 0x4eb53479; (* arm_CMGT_VEC Q25 Q3 Q21 32 128 *) + 0x6eb69663; (* arm_MLS_VEC Q3 Q19 Q22 32 128 *) + 0x4e791e73; (* arm_BIC_VEC Q19 Q19 Q25 128 *) + 0x4eb98463; (* arm_ADD_VEC Q3 Q3 Q25 32 128 *) + 0x6ea09863; (* arm_CMLE_VEC_ZERO Q3 Q3 32 128 *) + 0x4f001423; (* arm_ORR_VEC Q3 Q3 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea79473; (* arm_MLA_VEC Q19 Q3 Q7 32 128 *) + 0x4e381e73; (* arm_AND_VEC Q19 Q19 Q24 128 *) + 0x4eb7b410; (* arm_SQDMULH_VEC Q16 Q0 Q23 32 128 *) + 0x4f2e2610; (* arm_SRSHR_VEC Q16 Q16 18 32 128 *) + 0x4eb53419; (* arm_CMGT_VEC Q25 Q0 Q21 32 128 *) + 0x6eb69600; (* arm_MLS_VEC Q0 Q16 Q22 32 128 *) + 0x4e791e10; (* arm_BIC_VEC Q16 Q16 Q25 128 *) + 0x4eb98400; (* arm_ADD_VEC Q0 Q0 Q25 32 128 *) + 0x6ea09800; (* arm_CMLE_VEC_ZERO Q0 Q0 32 128 *) + 0x4f001420; (* arm_ORR_VEC Q0 Q0 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea49410; (* arm_MLA_VEC Q16 Q0 Q4 32 128 *) + 0x4e381e10; (* arm_AND_VEC Q16 Q16 Q24 128 *) + 0x3d800411; (* arm_STR Q17 X0 (Immediate_Offset (word 16)) *) + 0x3d800812; (* arm_STR Q18 X0 (Immediate_Offset (word 32)) *) + 0x3d800c13; (* arm_STR Q19 X0 (Immediate_Offset (word 48)) *) + 0x3c840410; (* arm_STR Q16 X0 (Postimmediate_Offset (word 64)) *) + 0xf1000463; (* arm_SUBS X3 X3 (rvalue (word 1)) *) + 0x54fff961; (* arm_BNE (word 2096940) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let POLY_USE_HINT_32_AARCH64_ASM_EXEC = ARM_MK_EXEC_RULE poly_use_hint_32_aarch64_asm_mc;; + +(* Per-element word function matching the assembly computation *) +let mldsa_use_hint_32_asm = new_definition + `mldsa_use_hint_32_asm (a:int32) (h:int32) : int32 = + let a1 = word_ishr_round (word_2smulh a (word 1074791425)) 18 in + let m:int32 = word_neg(word(bitval(word_igt a (word 8118528)))) in + let a0 = word_add (word_sub a (word_mul a1 (word 523776))) m in + let a1' = word_and a1 (word_not m) in + let delta:int32 = word_or (word_neg(word(bitval(word_ile a0 (word 0))))) (word 1) in + word_and (word_add a1' (word_mul delta h)) (word 15)`;; + +(* Numeric description of the assembly's UseHint path, exposing the Barrett + approximation used by the code. Connected to the FIPS 204 definition + mldsa_use_hint_32 via MLDSA_USE_HINT_32_EQUIV below. *) +let mldsa_use_hint_32_code = new_definition + `mldsa_use_hint_32_code (a:num) (h:num) = + let a1 = ((((a + 127) DIV 128) * 1025 + 2097152) DIV 4194304) MOD 16 in + let a0:int = &a - &a1 * &523776 in + let a0' = if a0 > &4190208 then a0 - &8380417 else a0 in + if h = 0 then a1 + else if a0' > &0 then (a1 + 1) MOD 16 + else (a1 + 15) MOD 16`;; + +(* ========================================================================= *) +(* Functional correctness helper lemmas *) +(* ========================================================================= *) + +let WORD_2SMULH_NOSATURATE_32 = prove( + `!a:int32. val a < 8380417 + ==> word_2smulh a (word 1074791425:int32) : int32 = + iword((&2 * &(val a) * &1074791425) div &2 pow 32)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[word_2smulh; DIMINDEX_32] THEN + ASM_SIMP_TAC[MLDSA_IVAL_VAL] THEN + CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[iword_saturate; word_INT_MIN; word_INT_MAX; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN + SIMP_TAC[INT_OF_NUM_DIV] THEN + REWRITE_TAC[INT_POS_NEG_BOUND] THEN + SUBGOAL_THEN `~(&((2 * val(a:int32) * 1074791425) DIV 4294967296):int > &2147483647)` + (fun th -> REWRITE_TAC[th]) THEN + SUBGOAL_THEN `(2 * val(a:int32) * 1074791425) DIV 4294967296 <= 2147483647` + (fun th -> MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LE] th) THEN INT_ARITH_TAC) THEN + TRANS_TAC LE_TRANS `(2 * 8380416 * 1074791425) DIV 4294967296` THEN + CONJ_TAC THENL + [MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV]);; + +let VAL_DECOMPOSE_A1 = prove( + `!a:int32. val a < 8380417 + ==> val(word_ishr_round (word_2smulh a (word 1074791425:int32)) 18 : int32) + = ((2 * val a * 1074791425) DIV 4294967296 + 131072) DIV 262144`, + GEN_TAC THEN DISCH_TAC THEN + ASM_SIMP_TAC[WORD_2SMULH_NOSATURATE_32] THEN + SUBGOAL_THEN `(2 * val(a:int32) * 1074791425) DIV 4294967296 < 2147483648` + ASSUME_TAC THENL + [TRANS_TAC LT_TRANS `(2 * 8380416 * 1074791425) DIV 4294967296 + 1` THEN + CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_SIMP_TAC[VAL_IWORD_NUM_32] THEN + ABBREV_TAC `t:int32 = iword(&((2 * val(a:int32) * 1074791425) DIV 4294967296))` THEN + SUBGOAL_THEN `val(t:int32) = (2 * val(a:int32) * 1074791425) DIV 4294967296` + ASSUME_TAC THENL + [EXPAND_TAC "t" THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `val(t:int32) < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[word_ishr_round] THEN CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV THEN + SUBGOAL_THEN `ival(t:int32) = &(val t)` ASSUME_TAC THENL + [SIMP_TAC[ival; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `(val(t:int32) + 131072) DIV 262144 < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + TRANS_TAC LT_TRANS `(4194303 + 131072) DIV 262144 + 1` THEN CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + ASM_SIMP_TAC[VAL_IWORD_NUM_32] THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN + UNDISCH_THEN `val(t:int32) = (2 * val(a:int32) * 1074791425) DIV 4294967296` + (SUBST1_TAC o SYM) THEN ASM_REWRITE_TAC[]);; + +let WORD_IGT_THRESHOLD_32 = BITBLAST_RULE + `!a:int32. val a < 8380417 + ==> word_igt a (word 8118528:int32) <=> val a > 8118528`;; + +let A1_BOUND = prove( + `!a. a < 8380417 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 <= 16`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8380416 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `69205952` (SPEC `d * 1025 + 2097152` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 1025 <= 65472 * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A1_WRAP = prove( + `!a. 8118528 < a /\ a < 8380417 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 = 16`, + GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `16 <= ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304` + ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `a + 127` (SPEC `8118529 + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `d * 1025 + 2097152` (SPEC `67108977` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `63427 * 1025 <= d * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + MP_TAC(SPEC `a:num` A1_BOUND) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_ARITH_TAC);; + +let A1_BOUND_NOWRAP = prove( + `!a. a <= 8118528 + ==> ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 <= 15`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8118528 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `4194304` (SPEC `67108802` (SPEC `d * 1025 + 2097152` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 1025 <= 63426 * 1025` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A0_UPPER_32 = prove( + `!a. a <= 8118528 + ==> a < (((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 + 1) * 523776`, + GEN_TAC THEN DISCH_TAC THEN + ABBREV_TAC `nv = ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304` THEN + SUBGOAL_THEN `nv * 4194304 <= (a + 127) DIV 128 * 1025 + 2097152` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN + MP_TAC(SPECL [`(a + 127) DIV 128 * 1025 + 2097152`; `4194304`] (CONJUNCT1 DIVISION_SIMP)) THEN + ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(a + 127) DIV 128 <= 63426` ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `8118528 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 4194304 <= 63426 * 1025 + 2097152` ASSUME_TAC THENL + [SUBGOAL_THEN `(a + 127) DIV 128 * 1025 <= 63426 * 1025` MP_TAC THENL + [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ASM_ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 15` ASSUME_TAC THENL + [CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_ARITH_TAC);; + +let WORD_SUB_SIGN_32 = BITBLAST_RULE + `!a:int32 b:int32. val b <= 7856640 /\ val a <= 8118528 ==> + ((bit 31 (word_sub a b) \/ word_sub a b = word 0) <=> val a <= val b)`;; + +let WRAP_A0_NEGATIVE = BITBLAST_RULE + `!a:int32. val a < 8380417 /\ val a > 8118528 + ==> bit 31 (word_add (word_sub a (word 8380416:int32)) (word 4294967295:int32))`;; + +(* Barrett equivalence: assembly and C decomposition formulas agree. + Both compute round_half_down(a / 523776) via different Barrett + approximation paths. Proved by case analysis on 17 output intervals + using DIV_MONO to sandwich both LHS and RHS to the same constant. *) +let BARRETT_INTERVAL_32 = prove( + `!a lo hi k. + lo <= a /\ a <= hi /\ + k * 262144 <= (2 * lo * 1074791425) DIV 4294967296 + 131072 /\ + (2 * hi * 1074791425) DIV 4294967296 + 131072 < (k + 1) * 262144 /\ + k * 4194304 <= (lo + 127) DIV 128 * 1025 + 2097152 /\ + (hi + 127) DIV 128 * 1025 + 2097152 < (k + 1) * 4194304 + ==> ((2 * a * 1074791425) DIV 4294967296 + 131072) DIV 262144 = k /\ + ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304 = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + CONJ_TAC THEN MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THENL + [CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(2 * lo * 1074791425) DIV 4294967296 + 131072` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 131072 <= y + 131072 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(2 * hi * 1074791425) DIV 4294967296 + 131072` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 131072 <= y + 131072 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]; + CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(lo + 127) DIV 128 * 1025 + 2097152` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 2097152 <= y + 2097152 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(hi + 127) DIV 128 * 1025 + 2097152` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 2097152 <= y + 2097152 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]]);; + +let BARRETT_EQUIV = prove( + `!a. a < 8380417 ==> + ((2 * a * 1074791425) DIV 4294967296 + 131072) DIV 262144 = + ((a + 127) DIV 128 * 1025 + 2097152) DIV 4194304`, + GEN_TAC THEN DISCH_TAC THEN + let intervals = [ + (0, 261888); (261889, 785664); (785665, 1309440); (1309441, 1833216); + (1833217, 2356992); (2356993, 2880768); (2880769, 3404544); + (3404545, 3928320); (3928321, 4452096); (4452097, 4975872); + (4975873, 5499648); (5499649, 6023424); (6023425, 6547200); + (6547201, 7070976); (7070977, 7594752); (7594753, 8118528); + (8118529, 8380416)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("a",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`a:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_32 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + + + +(* ========================================================================= *) +(* Element-level functional correctness *) +(* ========================================================================= *) + +let ELEMENT_CORRECT = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> val(mldsa_use_hint_32_asm a h) = mldsa_use_hint_32_code (val a) (val h)`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[mldsa_use_hint_32_asm; mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ABBREV_TAC `nv = ((val(a:int32) + 127) DIV 128 * 1025 + 2097152) DIV 4194304` THEN + SUBGOAL_THEN `val(word_ishr_round (word_2smulh (a:int32) (word 1074791425)) 18 : int32) = nv` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN TRANS_TAC EQ_TRANS `((2 * val(a:int32) * 1074791425) DIV 4294967296 + 131072) DIV 262144` THEN CONJ_TAC THENL [MATCH_MP_TAC VAL_DECOMPOSE_A1 THEN ASM_REWRITE_TAC[]; MATCH_MP_TAC BARRETT_EQUIV THEN ASM_REWRITE_TAC[]]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 16` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND) THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_igt (a:int32) (word 8118528:int32) <=> val a > 8118528` SUBST1_TAC THENL [MP_TAC(SPEC `a:int32` WORD_IGT_THRESHOLD_32) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(a:int32) > 8118528` THEN ASM_REWRITE_TAC[bitval] THENL + [ + SUBGOAL_THEN `nv = 16` SUBST_ALL_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_WRAP) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_ishr_round (word_2smulh (a:int32) (word 1074791425)) 18 = (word 16:int32)` (fun th -> REWRITE_TAC[th]) THENL [ONCE_REWRITE_TAC[GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[WORD_ILE_ZERO_32] THEN + MP_TAC(SPEC `a:int32` WRAP_A0_NEGATIVE) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + SUBGOAL_THEN `~((if int_gt (&(val(a:int32))) (&4190208) then &(val a) - &8380417 else &(val a):int) > &0)` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN COND_CASES_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT] (ASSUME `val(a:int32) < 8380417`)) THEN INT_ARITH_TAC; + POP_ASSUM(MP_TAC o REWRITE_RULE[INT_GT; INT_NOT_LT]) THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_GT; INT_GT] (ASSUME `val(a:int32) > 8118528`)) THEN INT_ARITH_TAC]; + ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN CONV_TAC WORD_REDUCE_CONV; + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN CONV_TAC WORD_REDUCE_CONV] + ; + SUBGOAL_THEN `nv <= 15` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND_NOWRAP) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `nv MOD 16 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `word_ishr_round (word_2smulh (a:int32) (word 1074791425)) 18 = (word nv:int32)` (fun th -> REWRITE_TAC[th]) THENL [GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_REFL] THEN REWRITE_TAC[WORD_ILE_ZERO_32; WORD_ADD_0] THEN + SUBGOAL_THEN `nv * 523776 <= 7856640` ASSUME_TAC THENL [SUBGOAL_THEN `nv * 523776 <= 15 * 523776` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 523776:int32)) = nv * 523776` ASSUME_TAC THENL [REWRITE_TAC[VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 523776:int32)) <= 7856640` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(bit 31 (word_sub (a:int32) (word_mul (word nv:int32) (word 523776:int32))) \/ word_sub a (word_mul (word nv) (word 523776)) = word 0) <=> ~(&(val a) - &nv * &523776 > &0)` SUBST1_TAC THENL + [MP_TAC(ISPECL [`a:int32`; `word_mul (word nv:int32) (word 523776:int32)`] WORD_SUB_SIGN_32) THEN ASM_REWRITE_TAC[] THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 523776` THENL + [ASM_REWRITE_TAC[] THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] (REWRITE_RULE[GSYM INT_OF_NUM_LE] (ASSUME `val(a:int32) <= nv * 523776`))) THEN INT_ARITH_TAC; + ASM_REWRITE_TAC[] THEN SUBGOAL_THEN `nv * 523776 < val(a:int32)` ASSUME_TAC THENL [UNDISCH_TAC `~(val(a:int32) <= nv * 523776)` THEN ARITH_TAC; ALL_TAC] THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_MUL] (REWRITE_RULE[GSYM INT_OF_NUM_LT] (ASSUME `nv * 523776 < val(a:int32)`))) THEN INT_ARITH_TAC]; ALL_TAC] THEN + REWRITE_TAC[bitval] THEN + ASM_CASES_TAC `val(h:int32) = 0` THEN ASM_REWRITE_TAC[] THENL + [SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN REWRITE_TAC[WORD_MUL_0; WORD_ADD_0; WORD_AND_ONES_32] THEN REWRITE_TAC[VAL_WORD_AND_15_32] THEN SUBGOAL_THEN `val(word nv:int32) = nv` SUBST1_TAC THENL [MATCH_MP_TAC VAL_WORD_EQ THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; + ALL_TAC] THEN + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_1_32; WORD_AND_ONES_32] THEN + SUBGOAL_THEN `val(word nv:int32) = nv` ASSUME_TAC THENL [MATCH_MP_TAC VAL_WORD_EQ THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word nv:int32) <= 15` ASSUME_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(int_gt (&(val(a:int32)) - &nv * &523776) (&4190208))` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN MP_TAC(SPEC `val(a:int32)` A0_UPPER_32) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD] (ASSUME `val(a:int32) < (nv + 1) * 523776`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 523776` THENL + [MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `523776`] REAL_INT_GT_BRIDGE) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]; + MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `523776`] REAL_INT_GT_BRIDGE_POS) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]] THEN + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[VAL_WORD_AND_15_32; VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `16 = 2 EXP 4`; ARITH_RULE `4294967296 = 2 EXP 32`; MOD_MOD_EXP_MIN] THEN CONV_TAC NUM_REDUCE_CONV THEN + REWRITE_TAC[ARITH_RULE `4294967295 = 15 + 268435455 * 16`; ARITH_RULE `n + (15 + 268435455 * 16) = (n + 15) + 268435455 * 16`; MOD_MULT_ADD] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` (fun th -> REWRITE_TAC[th]) THEN TRY(MATCH_MP_TAC MOD_LT) THEN ASM_ARITH_TAC]);; + + +let ELEMENT_CORRECT_WORD = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> mldsa_use_hint_32_asm a h = + word(mldsa_use_hint_32_code (val a) (val h))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN + AP_TERM_TAC THEN MP_TAC(SPECL [`a:int32`; `h:int32`] ELEMENT_CORRECT) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]));; + +(* ========================================================================= *) +(* Correctness proof, code-aligned spec (intermediate) *) +(* ========================================================================= *) + +let POLY_USE_HINT_32_AARCH64_ASM_CORRECT_CODE = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH poly_use_hint_32_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_32_aarch64_asm_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH poly_use_hint_32_aarch64_asm_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32_code (val(x i)) (val(y i))))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + + (* Setup *) + MAP_EVERY X_GEN_TAC + [`b:int64`; `a:int64`; `h:int64`; + `x:num->int32`; `y:num->int32`; `pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; + NONOVERLAPPING_CLAUSES; ALL; + fst POLY_USE_HINT_32_AARCH64_ASM_EXEC] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + GLOBALIZE_PRECONDITION_TAC THEN + CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV EXPAND_CASES_CONV))) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC[SOME_FLAGS; MODIFIABLE_SIMD_REGS] THEN + + (* Initialize and merge memory *) + ENSURES_INIT_TAC "s0" THEN + MEMORY_128_FROM_32_TAC "a" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + MEMORY_128_FROM_32_TAC "h" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + + (* Simulate 878 instructions (excluding RET) *) + MAP_EVERY (fun n -> ARM_STEPS_TAC POLY_USE_HINT_32_AARCH64_ASM_EXEC [n] THEN + SIMD_SIMPLIFY_TAC[]) + (1--878) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + (* Split bytes128 -> bytes32 for output memory *) + REPEAT(FIRST_X_ASSUM(STRIP_ASSUME_TAC o + CONV_RULE (SIMD_SIMPLIFY_CONV []) o + CONV_RULE(READ_MEMORY_SPLIT_CONV 2) o + check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN + + (* Expand output cases, substitute *) + CONV_TAC(TOP_DEPTH_CONV EXPAND_CASES_CONV) THEN + CONV_TAC(DEPTH_CONV NUM_MULT_CONV THENC DEPTH_CONV NUM_ADD_CONV) THEN + REWRITE_TAC[WORD_ADD_0] THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN ASM_REWRITE_TAC[] THEN + + (* Push word_subword through SIMD ops to per-element form *) + REWRITE_TAC[WORD_SUBWORD_AND; WORD_SUBWORD_OR] THEN + let WSN_TAC = REWRITE_TAC(map (fun n -> prove( + subst [mk_small_numeral n, `n:num`] + `!x:int128. word_subword(word_not x) (n,32):int32 = word_not(word_subword x (n,32))`, + GEN_TAC THEN MATCH_MP_TAC WORD_SUBWORD_NOT THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_128] THEN ARITH_TAC)) [0;32;64;96]) in + WSN_TAC THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + (* Match expanded ival/iword form *) + let EC_DEEP = + CONV_RULE(DEPTH_CONV WORD_NUM_RED_CONV) + (CONV_RULE(DEPTH_CONV(INT_RED_CONV ORELSEC NUM_RED_CONV)) + (CONV_RULE(TOP_DEPTH_CONV let_CONV) + (REWRITE_RULE[mldsa_use_hint_32_asm; word_2smulh; word_ishr_round; + DIMINDEX_32] ELEMENT_CORRECT_WORD))) in + let EC_FINAL = ONCE_REWRITE_RULE[WORD_AND_SYM] + (ONCE_REWRITE_RULE[WORD_OR_SYM] EC_DEEP) in + REPEAT CONJ_TAC THEN + MATCH_MP_TAC EC_FINAL THEN + CONJ_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ARITH_TAC);; + + +(* ========================================================================= *) +(* Subroutine form (intermediate, code-aligned) *) +(* ========================================================================= *) + +let POLY_USE_HINT_32_AARCH64_ASM_CORRECT_BOUND_CODE = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH poly_use_hint_32_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_32_aarch64_asm_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH poly_use_hint_32_aarch64_asm_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32_code (val(x i)) (val(y i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 16)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + MATCH_MP_TAC ENSURES_STRENGTHEN_POST THEN + EXISTS_TAC + `\s. read PC s = word(pc + LENGTH poly_use_hint_32_aarch64_asm_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32_code (val(x i:int32)) (val(y i:int32))))` THEN + CONJ_TAC THENL + [MATCH_MP_TAC POLY_USE_HINT_32_AARCH64_ASM_CORRECT_CODE THEN ASM_REWRITE_TAC[]; + REWRITE_TAC[] THEN REPEAT STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + FIRST_X_ASSUM(MP_TAC o SPEC `i:num`) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC(ARITH_RULE `x < 16 ==> x MOD 4294967296 < 16`) THEN + REWRITE_TAC[mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN ASM_ARITH_TAC]);; + +(* Intermediate subroutine correctness against the code-aligned spec. + Bridged to the public FIPS 204-aligned theorem below via + MLDSA_USE_HINT_32_EQUIV. *) +let POLY_USE_HINT_32_AARCH64_ASM_SUBROUTINE_CORRECT_CODE = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH poly_use_hint_32_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_32_aarch64_asm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32_code (val(x i)) (val(y i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 16)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REWRITE_TAC[fst POLY_USE_HINT_32_AARCH64_ASM_EXEC] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ARM_ADD_RETURN_NOSTACK_TAC POLY_USE_HINT_32_AARCH64_ASM_EXEC + (CONV_RULE(ONCE_DEPTH_CONV NUM_REDUCE_CONV) + (REWRITE_RULE[fst POLY_USE_HINT_32_AARCH64_ASM_EXEC] + POLY_USE_HINT_32_AARCH64_ASM_CORRECT_BOUND_CODE)));; + + +(* ========================================================================= *) +(* FIPS 204 = code-aligned equivalence *) +(* ========================================================================= *) + +let LINEARIZE_DIV_MOD_TAC = + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +(* Prove r DIV 523776 = k via DIV_SANDWICH + LE_MULT_RCANCEL *) +let DIV_523776_TAC k = + let k_num = mk_small_numeral k and km1 = mk_small_numeral (k-1) + and kp1 = mk_small_numeral (k+1) + and q = mk_var("q",`:num`) and le = `(<=):num->num->bool` + and lt = `(<):num->num->bool` + and c = `523776` in + let mk_mul a b = mk_binop (rator(rator `0*0`)) a b in + MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; c] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, q) THEN + REPEAT GEN_TAC THEN STRIP_TAC THEN + ASM_CASES_TAC(mk_comb(mk_comb(le, q), km1)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul q c), + mk_mul km1 c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul k_num c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_CASES_TAC(mk_comb(mk_comb(lt, k_num), q)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul kp1 c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC]];; + +(* Replace (r - r MOD 523776) DIV 523776 with r DIV 523776 *) +let DIV_MOD_TO_DIV_TAC = + SUBGOAL_THEN `(r - r MOD 523776) DIV 523776 = r DIV 523776` SUBST1_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 = 523776 * r DIV 523776` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV; ALL_TAC];; + +(* Lower half nowrap: dismiss wrap cond, reduce, prove r DIV 523776 = k *) +let DECOMPOSE_R1_LOWER_TAC = + SUBGOAL_THEN `~((&r:int) - &(r MOD 523776) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN LINEARIZE_DIV_MOD_TAC; + ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + DIV_MOD_TO_DIV_TAC THEN + CONV_TAC SYM_CONV THEN + LINEARIZE_DIV_MOD_TAC;; + +(* Upper half nowrap: dismiss wrap cond, reduce, prove r DIV 523776 + 1 = k *) +let DECOMPOSE_R1_UPPER_TAC = + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~((&r:int) - (&(r MOD 523776) - &523776) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [REWRITE_TAC[INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`] THEN + ASM_SIMP_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + LINEARIZE_DIV_MOD_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 + 523776 = 523776 * (r DIV 523776 + 1)` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776 + 1`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN SUBST1_TAC THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 523776` (concl th) || + free_in `r DIV 523776` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 523776`, `m:num`) THEN + SPEC_TAC(`r DIV 523776`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +let DECOMPOSE_R1_NOWRAP_TAC = + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THEN + TRY DECOMPOSE_R1_LOWER_TAC THEN TRY DECOMPOSE_R1_UPPER_TAC;; + +let DECOMPOSE_32_R1_EQUIV = time prove( + `!r. r < 8380417 ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = + decompose_32_r1 r`, + GEN_TAC THEN DISCH_TAC THEN + ASM_CASES_TAC `r <= 8118528` THENL + [ALL_TAC; + (* Wrap zone *) + SUBGOAL_THEN `8118528 < r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `decompose_32_r1 r = 0` SUBST1_TAC THENL + [REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THENL + [(* Lower wrap: r DIV 523776 = 16 *) + SUBGOAL_THEN `r DIV 523776 = 16` ASSUME_TAC THENL + [DIV_523776_TAC 16; ALL_TAC] THEN + SUBGOAL_THEN `16 * 523776 + r MOD 523776 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + ASM_ARITH_TAC; + (* Upper wrap: r DIV 523776 = 15 *) + SUBGOAL_THEN `r DIV 523776 = 15` ASSUME_TAC THENL + [DIV_523776_TAC 15; ALL_TAC] THEN + SUBGOAL_THEN `15 * 523776 + r MOD 523776 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + ALL_TAC] THEN + MP_TAC(SPEC `r:num` A1_WRAP) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONV_TAC NUM_REDUCE_CONV] THEN + (* Nowrap zone: unfold and do interval cascade *) + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + let intervals = [ + (0, 261888); (261889, 785664); (785665, 1309440); + (1309441, 1833216); (1833217, 2356992); (2356993, 2880768); + (2880769, 3404544); (3404545, 3928320); (3928321, 4452096); + (4452097, 4975872); (4975873, 5499648); (5499649, 6023424); + (6023425, 6547200); (6547201, 7070976); (7070977, 7594752); + (7594753, 8118528)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("r",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`r:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_32 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + DECOMPOSE_R1_NOWRAP_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + +let R1_IS_DIV_LOWER = prove( + `!r. r < 8380417 /\ r MOD 523776 * 2 <= 523776 /\ + ~((&r:int) - &(r MOD 523776) = &8380416) ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = r DIV 523776`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + MP_TAC(SPEC `r:num` LOWER_NONWRAP_R1) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +let R1_IS_DIV_PLUS1_UPPER = prove( + `!r. r < 8380417 /\ ~(r MOD 523776 * 2 <= 523776) /\ + ~((&r:int) - (&(r MOD 523776) - &523776) = &8380416) ==> + (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 = + r DIV 523776 + 1`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + MP_TAC(SPEC `r:num` UPPER_NONWRAP_R1) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +(* Upper nowrap: substitute Barrett = r DIV 523776 + 1, use INT_MOD_RESIDUE *) +let R0_SIGN_UPPER_NOWRAP_TAC = + MP_TAC(SPEC `r:num` R1_IS_DIV_PLUS1_UPPER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_MUL] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `(a:int) - (b + &1) * c = a - b * c - c`] THEN + REWRITE_TAC[INT_ARITH `x - &523776 > &0 <=> x > &523776`; + INT_ARITH `x - &523776 - &8380417 > &0 <=> x > &8904193`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Lower nowrap: substitute Barrett = r DIV 523776, use INT_MOD_RESIDUE *) +let R0_SIGN_LOWER_NOWRAP_TAC = + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Wrap: derive 8118528 < r, use DECOMPOSE_32_R1_EQUIV to get Barrett = 0 *) +let R0_SIGN_WRAP_TAC = + SUBGOAL_THEN `8118528 < r` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> t = `&8380416:int`)) (concl th) && + not(is_neg(concl th)))) THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ; + INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`; + GSYM INT_OF_NUM_ADD] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + REWRITE_TAC[INT_ARITH `x - &1 > &0 <=> x > &1`; + INT_ARITH `(x - &523776) - &1 > &0 <=> x > &523777`; + INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let DECOMPOSE_32_R0_SIGN = time prove( + `!r. r < 8380417 ==> + let a1 = (((r + 127) DIV 128 * 1025 + 2097152) DIV 4194304) MOD 16 in + let a0':int = if (&r:int) - &a1 * &523776 > &4190208 + then &r - &a1 * &523776 - &8380417 + else &r - &a1 * &523776 in + (decompose_32_r0 r > &0 <=> a0' > &0) /\ + (decompose_32_r0 r <= &0 <=> ~(a0' > &0))`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[INT_ARITH `(x:int) <= &0 <=> ~(x > &0)`] THEN + MATCH_MP_TAC(TAUT `(p <=> q) ==> (p <=> q) /\ (~p <=> ~q)`) THEN + REWRITE_TAC[decompose_32_r0; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ONCE_REWRITE_TAC[COND_RAND] THEN REWRITE_TAC[SND; FST] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 523776 * 2 <= 523776` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + TRY R0_SIGN_LOWER_NOWRAP_TAC THEN + TRY R0_SIGN_UPPER_NOWRAP_TAC THEN + TRY R0_SIGN_WRAP_TAC THEN + TRY( + (* Contradiction: lower nowrap with > 4190208 *) + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (SPECL [`r:num`; `523776`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - &((((r + 127) DIV 128 * 1025 + 2097152) DIV + 4194304) MOD 16) * &523776 = &(r MOD 523776)` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(&(r MOD 523776) > (&4190208:int))` MP_TAC THENL + [REWRITE_TAC[INT_NOT_LT; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[INT_OF_NUM_GT] THEN ASM_ARITH_TAC + ));; + +let MLDSA_USE_HINT_32_EQUIV = prove( + `!r h. r < 8380417 /\ h <= 1 + ==> mldsa_use_hint_32 h r = mldsa_use_hint_32_code r h`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[MLDSA_USE_HINT_32_UNFOLD] THEN + REWRITE_TAC[mldsa_use_hint_32_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_32_R0_SIGN) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN STRIP_TAC THEN + ASM_CASES_TAC `h = 0` THENL + [ASM_REWRITE_TAC[ARITH_RULE `~(0 = 1)`]; ALL_TAC] THEN + SUBGOAL_THEN `h = 1` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `decompose_32_r0 r > &0` THEN ASM_REWRITE_TAC[] THEN + ASM_MESON_TAC[]);; + +(* ========================================================================= *) +(* Public subroutine correctness (FIPS 204-aligned) *) +(* ========================================================================= *) + +(* Postcondition is stated in terms of mldsa_use_hint_32 from FIPS 204 + (Algorithm 40), with the output bound < 16 as a corollary. + Derived from POLY_USE_HINT_32_AARCH64_ASM_SUBROUTINE_CORRECT_CODE by + rewriting mldsa_use_hint_32_code -> mldsa_use_hint_32 via + MLDSA_USE_HINT_32_EQUIV. *) +let POLY_USE_HINT_32_AARCH64_ASM_SUBROUTINE_CORRECT = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH poly_use_hint_32_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) /\ + (!i. i < 256 ==> val((x:num->int32) i) < 8380417) /\ + (!i. i < 256 ==> val((y:num->int32) i) <= 1) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_32_aarch64_asm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_32 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 16)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REPEAT GEN_TAC THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + SUBGOAL_THEN + `!i. i < 256 ==> + mldsa_use_hint_32 (val((y:num->int32) i)) (val((x:num->int32) i)) = + mldsa_use_hint_32_code (val(x i)) (val(y i))` + (fun th -> SIMP_TAC[th]) THENL + [REPEAT STRIP_TAC THEN MATCH_MP_TAC MLDSA_USE_HINT_32_EQUIV THEN + CONJ_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]; + MATCH_MP_TAC POLY_USE_HINT_32_AARCH64_ASM_SUBROUTINE_CORRECT_CODE THEN + ASM_REWRITE_TAC[]]);; + + +(* ========================================================================= *) +(* Constant-time and memory safety proof. *) +(* ========================================================================= *) + +needs "arm/proofs/consttime.ml";; +needs "aarch64/proofs/subroutine_signatures.ml";; + + +let full_spec,public_vars = mk_safety_spec + ~keep_maychanges:false + (assoc "poly_use_hint_32_aarch64_asm" subroutine_signatures) + POLY_USE_HINT_32_AARCH64_ASM_SUBROUTINE_CORRECT_CODE + POLY_USE_HINT_32_AARCH64_ASM_EXEC;; + +let POLY_USE_HINT_32_AARCH64_ASM_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e b a h pc returnaddress. + nonoverlapping (word pc,LENGTH poly_use_hint_32_aarch64_asm_mc) (b,1024) /\ + nonoverlapping (b,1024) (a,1024) /\ + nonoverlapping (b,1024) (h,1024) + ==> ensures arm + (\s. + aligned_bytes_loaded s (word pc) + poly_use_hint_32_aarch64_asm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + read events s = e) + (\s. + read PC s = returnaddress /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h b pc returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; b,1024] + [b,1024])) + (\s s'. true)`, + ASSERT_CONCL_TAC full_spec THEN + PROVE_SAFETY_SPEC_TAC ~public_vars:public_vars POLY_USE_HINT_32_AARCH64_ASM_EXEC);; + diff --git a/proofs/hol_light/aarch64/proofs/poly_use_hint_88_aarch64_asm.ml b/proofs/hol_light/aarch64/proofs/poly_use_hint_88_aarch64_asm.ml new file mode 100644 index 000000000..f2ed72dca --- /dev/null +++ b/proofs/hol_light/aarch64/proofs/poly_use_hint_88_aarch64_asm.ml @@ -0,0 +1,1020 @@ +(* + * Copyright (c) The mldsa-native project authors + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0 + *) + +(* ========================================================================= *) +(* Use hint to correct high bits of decomposition (ML-DSA, param 44). *) +(* ========================================================================= *) + +needs "arm/proofs/base.ml";; +needs "common/mldsa_specs.ml";; +needs "aarch64/proofs/aarch64_utils.ml";; + + +(**** print_literal_from_elf "aarch64/mldsa/poly_use_hint_88_aarch64_asm.o";; + ****) + +let poly_use_hint_88_aarch64_asm_mc = define_assert_from_elf + "poly_use_hint_88_aarch64_asm_mc" "aarch64/mldsa/poly_use_hint_88_aarch64_asm.o" +(*** BYTECODE START ***) +[ + 0x529c0024; (* arm_MOV W4 (rvalue (word 57345)) *) + 0x72a00fe4; (* arm_MOVK W4 (word 127) 16 *) + 0x4e040c94; (* arm_DUP_GEN Q20 X4 32 128 *) + 0x528d8005; (* arm_MOV W5 (rvalue (word 27648)) *) + 0x72a00fc5; (* arm_MOVK W5 (word 126) 16 *) + 0x4e040cb5; (* arm_DUP_GEN Q21 X5 32 128 *) + 0x529d0007; (* arm_MOV W7 (rvalue (word 59392)) *) + 0x72a00047; (* arm_MOVK W7 (word 2) 16 *) + 0x4e040cf6; (* arm_DUP_GEN Q22 X7 32 128 *) + 0x5280b02b; (* arm_MOV W11 (rvalue (word 1409)) *) + 0x72ab02cb; (* arm_MOVK W11 (word 22550) 16 *) + 0x4e040d77; (* arm_DUP_GEN Q23 X11 32 128 *) + 0x4f010578; (* arm_MOVI Q24 (word 184683593771) *) + 0xd2800203; (* arm_MOV X3 (rvalue (word 16)) *) + 0x3dc00421; (* arm_LDR Q1 X1 (Immediate_Offset (word 16)) *) + 0x3dc00822; (* arm_LDR Q2 X1 (Immediate_Offset (word 32)) *) + 0x3dc00c23; (* arm_LDR Q3 X1 (Immediate_Offset (word 48)) *) + 0x3cc40420; (* arm_LDR Q0 X1 (Postimmediate_Offset (word 64)) *) + 0x3dc00445; (* arm_LDR Q5 X2 (Immediate_Offset (word 16)) *) + 0x3dc00846; (* arm_LDR Q6 X2 (Immediate_Offset (word 32)) *) + 0x3dc00c47; (* arm_LDR Q7 X2 (Immediate_Offset (word 48)) *) + 0x3cc40444; (* arm_LDR Q4 X2 (Postimmediate_Offset (word 64)) *) + 0x4eb7b431; (* arm_SQDMULH_VEC Q17 Q1 Q23 32 128 *) + 0x4f2f2631; (* arm_SRSHR_VEC Q17 Q17 17 32 128 *) + 0x4eb53439; (* arm_CMGT_VEC Q25 Q1 Q21 32 128 *) + 0x6eb69621; (* arm_MLS_VEC Q1 Q17 Q22 32 128 *) + 0x4e791e31; (* arm_BIC_VEC Q17 Q17 Q25 128 *) + 0x4eb98421; (* arm_ADD_VEC Q1 Q1 Q25 32 128 *) + 0x6ea09821; (* arm_CMLE_VEC_ZERO Q1 Q1 32 128 *) + 0x4f001421; (* arm_ORR_VEC Q1 Q1 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea59431; (* arm_MLA_VEC Q17 Q1 Q5 32 128 *) + 0x4eb83639; (* arm_CMGT_VEC Q25 Q17 Q24 32 128 *) + 0x4e791e31; (* arm_BIC_VEC Q17 Q17 Q25 128 *) + 0x6eb86e31; (* arm_UMIN_VEC Q17 Q17 Q24 32 128 *) + 0x4eb7b452; (* arm_SQDMULH_VEC Q18 Q2 Q23 32 128 *) + 0x4f2f2652; (* arm_SRSHR_VEC Q18 Q18 17 32 128 *) + 0x4eb53459; (* arm_CMGT_VEC Q25 Q2 Q21 32 128 *) + 0x6eb69642; (* arm_MLS_VEC Q2 Q18 Q22 32 128 *) + 0x4e791e52; (* arm_BIC_VEC Q18 Q18 Q25 128 *) + 0x4eb98442; (* arm_ADD_VEC Q2 Q2 Q25 32 128 *) + 0x6ea09842; (* arm_CMLE_VEC_ZERO Q2 Q2 32 128 *) + 0x4f001422; (* arm_ORR_VEC Q2 Q2 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea69452; (* arm_MLA_VEC Q18 Q2 Q6 32 128 *) + 0x4eb83659; (* arm_CMGT_VEC Q25 Q18 Q24 32 128 *) + 0x4e791e52; (* arm_BIC_VEC Q18 Q18 Q25 128 *) + 0x6eb86e52; (* arm_UMIN_VEC Q18 Q18 Q24 32 128 *) + 0x4eb7b473; (* arm_SQDMULH_VEC Q19 Q3 Q23 32 128 *) + 0x4f2f2673; (* arm_SRSHR_VEC Q19 Q19 17 32 128 *) + 0x4eb53479; (* arm_CMGT_VEC Q25 Q3 Q21 32 128 *) + 0x6eb69663; (* arm_MLS_VEC Q3 Q19 Q22 32 128 *) + 0x4e791e73; (* arm_BIC_VEC Q19 Q19 Q25 128 *) + 0x4eb98463; (* arm_ADD_VEC Q3 Q3 Q25 32 128 *) + 0x6ea09863; (* arm_CMLE_VEC_ZERO Q3 Q3 32 128 *) + 0x4f001423; (* arm_ORR_VEC Q3 Q3 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea79473; (* arm_MLA_VEC Q19 Q3 Q7 32 128 *) + 0x4eb83679; (* arm_CMGT_VEC Q25 Q19 Q24 32 128 *) + 0x4e791e73; (* arm_BIC_VEC Q19 Q19 Q25 128 *) + 0x6eb86e73; (* arm_UMIN_VEC Q19 Q19 Q24 32 128 *) + 0x4eb7b410; (* arm_SQDMULH_VEC Q16 Q0 Q23 32 128 *) + 0x4f2f2610; (* arm_SRSHR_VEC Q16 Q16 17 32 128 *) + 0x4eb53419; (* arm_CMGT_VEC Q25 Q0 Q21 32 128 *) + 0x6eb69600; (* arm_MLS_VEC Q0 Q16 Q22 32 128 *) + 0x4e791e10; (* arm_BIC_VEC Q16 Q16 Q25 128 *) + 0x4eb98400; (* arm_ADD_VEC Q0 Q0 Q25 32 128 *) + 0x6ea09800; (* arm_CMLE_VEC_ZERO Q0 Q0 32 128 *) + 0x4f001420; (* arm_ORR_VEC Q0 Q0 (rvalue (word 79228162532711081671548469249)) 128 *) + 0x4ea49410; (* arm_MLA_VEC Q16 Q0 Q4 32 128 *) + 0x4eb83619; (* arm_CMGT_VEC Q25 Q16 Q24 32 128 *) + 0x4e791e10; (* arm_BIC_VEC Q16 Q16 Q25 128 *) + 0x6eb86e10; (* arm_UMIN_VEC Q16 Q16 Q24 32 128 *) + 0x3d800411; (* arm_STR Q17 X0 (Immediate_Offset (word 16)) *) + 0x3d800812; (* arm_STR Q18 X0 (Immediate_Offset (word 32)) *) + 0x3d800c13; (* arm_STR Q19 X0 (Immediate_Offset (word 48)) *) + 0x3c840410; (* arm_STR Q16 X0 (Postimmediate_Offset (word 64)) *) + 0xf1000463; (* arm_SUBS X3 X3 (rvalue (word 1)) *) + 0x54fff861; (* arm_BNE (word 2096908) *) + 0xd65f03c0 (* arm_RET X30 *) +];; +(*** BYTECODE END ***) + +let POLY_USE_HINT_88_AARCH64_ASM_EXEC = ARM_MK_EXEC_RULE poly_use_hint_88_aarch64_asm_mc;; + +(* Per-element word function matching the assembly computation *) +let mldsa_use_hint_88_asm = new_definition + `mldsa_use_hint_88_asm (a:int32) (h:int32) : int32 = + let a1 = word_ishr_round (word_2smulh a (word 1477838209)) 17 in + let m:int32 = word_neg(word(bitval(word_igt a (word 8285184)))) in + let a0 = word_add (word_sub a (word_mul a1 (word 190464))) m in + let a1' = word_and a1 (word_not m) in + let delta:int32 = word_or (word_neg(word(bitval(word_ile a0 (word 0))))) (word 1) in + let tmp = word_add a1' (word_mul delta h) in + let neg_mask:int32 = word_neg(word(bitval(word_igt tmp (word 43)))) in + let tmp' = word_and tmp (word_not neg_mask) in + word_umin tmp' (word 43)`;; + +(* Numeric description of the assembly's UseHint path, exposing the Barrett + approximation used by the code. Connected to the FIPS 204 definition + mldsa_use_hint_88 via MLDSA_USE_HINT_88_EQUIV below. *) +let mldsa_use_hint_88_code = new_definition + `mldsa_use_hint_88_code (a:num) (h:num) = + let a1_raw = ((((a + 127) DIV 128) * 11275 + 8388608) DIV 16777216) in + let a1 = if a1_raw > 43 then 0 else a1_raw in + let a0:int = &a - &a1 * &190464 in + let a0' = if a0 > &4190208 then a0 - &8380417 else a0 in + if h = 0 then a1 + else if a0' > &0 then if a1 = 43 then 0 else a1 + 1 + else if a1 = 0 then 43 else a1 - 1`;; + +(* ========================================================================= *) +(* Helper lemmas *) +(* ========================================================================= *) + +let WORD_2SMULH_NOSATURATE_88 = prove( + `!a:int32. val a < 8380417 + ==> word_2smulh a (word 1477838209:int32) : int32 = + iword((&2 * &(val a) * &1477838209) div &2 pow 32)`, + GEN_TAC THEN DISCH_TAC THEN + REWRITE_TAC[word_2smulh; DIMINDEX_32] THEN + ASM_SIMP_TAC[MLDSA_IVAL_VAL] THEN + CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[iword_saturate; word_INT_MIN; word_INT_MAX; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN + SIMP_TAC[INT_OF_NUM_DIV] THEN + REWRITE_TAC[INT_POS_NEG_BOUND] THEN + SUBGOAL_THEN `~(&((2 * val(a:int32) * 1477838209) DIV 4294967296):int > &2147483647)` + (fun th -> REWRITE_TAC[th]) THEN + REWRITE_TAC[INT_ARITH `~(x:int > y) <=> x <= y`; INT_OF_NUM_LE] THEN + TRANS_TAC LE_TRANS `(2 * 8380416 * 1477838209) DIV 4294967296` THEN + CONJ_TAC THENL + [MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV]);; + +let VAL_DECOMPOSE_A1_88 = prove( + `!a:int32. val a < 8380417 + ==> val(word_ishr_round (word_2smulh a (word 1477838209:int32)) 17 : int32) + = ((2 * val a * 1477838209) DIV 4294967296 + 65536) DIV 131072`, + GEN_TAC THEN DISCH_TAC THEN + ASM_SIMP_TAC[WORD_2SMULH_NOSATURATE_88] THEN + SUBGOAL_THEN `(2 * val(a:int32) * 1477838209) DIV 4294967296 < 2147483648` + ASSUME_TAC THENL + [TRANS_TAC LT_TRANS `(2 * 8380416 * 1477838209) DIV 4294967296 + 1` THEN + CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ABBREV_TAC `t:int32 = iword(&((2 * val(a:int32) * 1477838209) DIV 4294967296))` THEN + SUBGOAL_THEN `val(t:int32) = (2 * val(a:int32) * 1477838209) DIV 4294967296` + ASSUME_TAC THENL + [EXPAND_TAC "t" THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `val(t:int32) < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[word_ishr_round] THEN CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC INT_REDUCE_CONV THEN + SUBGOAL_THEN `ival(t:int32) = &(val t)` ASSUME_TAC THENL + [SIMP_TAC[ival; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_OF_NUM_CLAUSES] THEN SIMP_TAC[INT_OF_NUM_DIV] THEN CONV_TAC NUM_REDUCE_CONV THEN + SUBGOAL_THEN `(val(t:int32) + 65536) DIV 131072 < 2147483648` ASSUME_TAC THENL + [ASM_REWRITE_TAC[] THEN + TRANS_TAC LT_TRANS `(5767167 + 65536) DIV 131072 + 1` THEN CONJ_TAC THENL + [MATCH_MP_TAC(ARITH_RULE `x <= y ==> x < y + 1`) THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV]; ALL_TAC] THEN + ASM_SIMP_TAC[VAL_IWORD_NUM_32] THEN MATCH_MP_TAC VAL_IWORD_NUM_32 THEN + UNDISCH_THEN `val(t:int32) = (2 * val(a:int32) * 1477838209) DIV 4294967296` + (SUBST1_TAC o SYM) THEN ASM_REWRITE_TAC[]);; + +let WORD_IGT_THRESHOLD_88 = BITBLAST_RULE + `!a:int32. val a < 8380417 + ==> word_igt a (word 8285184:int32) <=> val a > 8285184`;; + +let A1_BOUND_88 = prove( + `!a. a < 8380417 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 44`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8380416 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `751819508` (SPEC `d * 11275 + 8388608` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 11275 <= 65472 * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +let A1_BOUND_NOWRAP_88 = prove( + `!a. a <= 8285184 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 43`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPEC `128` (SPEC `8285184 + 127` (SPEC `a + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `738196808` (SPEC `d * 11275 + 8388608` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `d * 11275 <= 64728 * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]);; + +(* Barrett equivalence for _88: 45-interval case analysis *) +let BARRETT_INTERVAL_88 = prove( + `!a lo hi k. + lo <= a /\ a <= hi /\ + k * 131072 <= (2 * lo * 1477838209) DIV 4294967296 + 65536 /\ + (2 * hi * 1477838209) DIV 4294967296 + 65536 < (k + 1) * 131072 /\ + k * 16777216 <= (lo + 127) DIV 128 * 11275 + 8388608 /\ + (hi + 127) DIV 128 * 11275 + 8388608 < (k + 1) * 16777216 + ==> ((2 * a * 1477838209) DIV 4294967296 + 65536) DIV 131072 = k /\ + ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + CONJ_TAC THEN MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THENL + [CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(2 * lo * 1477838209) DIV 4294967296 + 65536` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 65536 <= y + 65536 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(2 * hi * 1477838209) DIV 4294967296 + 65536` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 65536 <= y + 65536 <=> x <= y`] THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]; + CONJ_TAC THENL + [TRANS_TAC LE_TRANS `(lo + 127) DIV 128 * 11275 + 8388608` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 8388608 <= y + 8388608 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC; + TRANS_TAC LET_TRANS `(hi + 127) DIV 128 * 11275 + 8388608` THEN + ASM_REWRITE_TAC[] THEN + REWRITE_TAC[ARITH_RULE `x + 8388608 <= y + 8388608 <=> x <= y`] THEN + REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN + MATCH_MP_TAC DIV_MONO THEN ASM_ARITH_TAC]]);; + +let BARRETT_EQUIV_88 = prove( + `!a. a < 8380417 ==> + ((2 * a * 1477838209) DIV 4294967296 + 65536) DIV 131072 = + ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216`, + GEN_TAC THEN DISCH_TAC THEN + let intervals = [ + (0, 95232); (95233, 285696); (285697, 476160); (476161, 666624); + (666625, 857088); (857089, 1047552); (1047553, 1238016); + (1238017, 1428480); (1428481, 1618944); (1618945, 1809408); + (1809409, 1999872); (1999873, 2190336); (2190337, 2380800); + (2380801, 2571264); (2571265, 2761728); (2761729, 2952192); + (2952193, 3142656); (3142657, 3333120); (3333121, 3523584); + (3523585, 3714048); (3714049, 3904512); (3904513, 4094976); + (4094977, 4285440); (4285441, 4475904); (4475905, 4666368); + (4666369, 4856832); (4856833, 5047296); (5047297, 5237760); + (5237761, 5428224); (5428225, 5618688); (5618689, 5809152); + (5809153, 5999616); (5999617, 6190080); (6190081, 6380544); + (6380545, 6571008); (6571009, 6761472); (6761473, 6951936); + (6951937, 7142400); (7142401, 7332864); (7332865, 7523328); + (7523329, 7713792); (7713793, 7904256); (7904257, 8094720); + (8094721, 8285184); (8285185, 8380416)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("a",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`a:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_88 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + + + +let WORD_SUB_SIGN_88 = BITBLAST_RULE + `!a:int32 b:int32. val b <= 8189952 /\ val a <= 8285184 ==> + ((bit 31 (word_sub a b) \/ word_sub a b = word 0) <=> val a <= val b)`;; + +let WRAP_A0_NEGATIVE_88 = BITBLAST_RULE + `!a:int32. val a < 8380417 /\ val a > 8285184 + ==> bit 31 (word_add (word_sub a (word 8380416:int32)) (word 4294967295:int32))`;; + +let A1_WRAP_88 = prove( + `!a. 8285184 < a /\ a < 8380417 + ==> ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 = 44`, + GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `44 <= ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216` + ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `a + 127` (SPEC `8285185 + 127` DIV_MONO))) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_TAC THEN + ABBREV_TAC `d = (a + 127) DIV 128` THEN + MP_TAC(SPEC `16777216` (SPEC `d * 11275 + 8388608` (SPEC `738208083` DIV_MONO))) THEN + ANTS_TAC THENL + [SUBGOAL_THEN `64729 * 11275 <= d * 11275` MP_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; + CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + MP_TAC(SPEC `a:num` A1_BOUND_88) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_ARITH_TAC);; + +let A0_UPPER_88 = prove( + `!a. a <= 8285184 + ==> a < (((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216 + 1) * 190464`, + GEN_TAC THEN DISCH_TAC THEN + ABBREV_TAC `nv = ((a + 127) DIV 128 * 11275 + 8388608) DIV 16777216` THEN + SUBGOAL_THEN `nv * 16777216 <= (a + 127) DIV 128 * 11275 + 8388608` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN MP_TAC(SPECL [`(a + 127) DIV 128 * 11275 + 8388608`; `16777216`] (CONJUNCT1 DIVISION_SIMP)) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(a + 127) DIV 128 <= 64728` ASSUME_TAC THENL + [MP_TAC(SPEC `128` (SPEC `8285184 + 127` (SPEC `a + 127` DIV_MONO))) THEN ANTS_TAC THENL [ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv * 16777216 <= 64728 * 11275 + 8388608` ASSUME_TAC THENL + [SUBGOAL_THEN `(a + 127) DIV 128 * 11275 <= 64728 * 11275` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; ASM_ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 43` ASSUME_TAC THENL [CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_ARITH_TAC);; + +let WORD_IGT_43_BOUND = BITBLAST_RULE + `!a:int32. val a <= 43 ==> ~(word_igt a (word 43:int32))`;; +let WORD_IGT_43_ADD1 = BITBLAST_RULE + `!a:int32. val a <= 42 ==> ~(word_igt (word_add a (word 1:int32)) (word 43:int32))`;; +let WORD_IGT_43_SUB1 = BITBLAST_RULE + `!a:int32. val a <= 43 /\ ~(val a = 0) ==> + ~(word_igt (word_add a (word 4294967295:int32)) (word 43:int32))`;; +let WORD_IGT_43_TRUE = BITBLAST_RULE + `word_igt (word 44:int32) (word 43:int32)`;; + +let ELEMENT_CORRECT_88 = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> val(mldsa_use_hint_88_asm a h) = mldsa_use_hint_88_code (val a) (val h)`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[mldsa_use_hint_88_asm; mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ABBREV_TAC `nv = ((val(a:int32) + 127) DIV 128 * 11275 + 8388608) DIV 16777216` THEN + SUBGOAL_THEN `val(word_ishr_round (word_2smulh (a:int32) (word 1477838209)) 17 : int32) = nv` ASSUME_TAC THENL + [EXPAND_TAC "nv" THEN TRANS_TAC EQ_TRANS `((2 * val(a:int32) * 1477838209) DIV 4294967296 + 65536) DIV 131072` THEN CONJ_TAC THENL [MATCH_MP_TAC VAL_DECOMPOSE_A1_88 THEN ASM_REWRITE_TAC[]; MATCH_MP_TAC BARRETT_EQUIV_88 THEN ASM_REWRITE_TAC[]]; ALL_TAC] THEN + SUBGOAL_THEN `nv <= 44` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND_88) THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `word_igt (a:int32) (word 8285184:int32) <=> val a > 8285184` SUBST1_TAC THENL [MP_TAC(SPEC `a:int32` WORD_IGT_THRESHOLD_88) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(a:int32) > 8285184` THENL + [(* Wrap case: val a > 8285184, nv = 44 *) + REWRITE_TAC[ASSUME `val(a:int32) > 8285184`; bitval] THEN + SUBGOAL_THEN `nv = 44` SUBST_ALL_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_WRAP_88) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + ABBREV_TAC `a1w = word_ishr_round (word_2smulh (a:int32) (word 1477838209)) 17 : int32` THEN + SUBGOAL_THEN `a1w = (word 44:int32)` SUBST_ALL_TAC THENL [EXPAND_TAC "a1w" THEN ONCE_REWRITE_TAC[GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN + CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_ILE_ZERO_32] THEN + MP_TAC(SPEC `a:int32` WRAP_A0_NEGATIVE_88) THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN + REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + SUBGOAL_THEN `~((if int_gt (&(val(a:int32))) (&4190208) then &(val a) - &8380417 else &(val a):int) > &0)` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN COND_CASES_TAC THENL + [MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT] (ASSUME `val(a:int32) < 8380417`)) THEN INT_ARITH_TAC; + POP_ASSUM(MP_TAC o REWRITE_RULE[INT_GT; INT_NOT_LT]) THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_GT; INT_GT] (ASSUME `val(a:int32) > 8285184`)) THEN INT_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(h:int32) = 0` THENL + [REWRITE_TAC[ASSUME `val(h:int32) = 0`] THEN + SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV; + REWRITE_TAC[ASSUME `~(val(h:int32) = 0)`] THEN + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV] + ;(* Nowrap case: val a <= 8285184 *) + REWRITE_TAC[ASSUME `~(val(a:int32) > 8285184)`; bitval] THEN + SUBGOAL_THEN `nv <= 43` ASSUME_TAC THENL [MP_TAC(SPEC `val(a:int32)` A1_BOUND_NOWRAP_88) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_MESON_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `(if nv > 43 then 0 else nv) = nv` SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ABBREV_TAC `a1w = word_ishr_round (word_2smulh (a:int32) (word 1477838209)) 17 : int32` THEN + SUBGOAL_THEN `a1w = (word nv:int32)` SUBST_ALL_TAC THENL [EXPAND_TAC "a1w" THEN GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN AP_TERM_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_REFL; WORD_ILE_ZERO_32; WORD_ADD_0; WORD_AND_ONES_32] THEN + SUBGOAL_THEN `nv * 190464 <= 8189952` ASSUME_TAC THENL [SUBGOAL_THEN `nv * 190464 <= 43 * 190464` MP_TAC THENL [REWRITE_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; CONV_TAC NUM_REDUCE_CONV THEN ARITH_TAC]; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 190464:int32)) = nv * 190464` ASSUME_TAC THENL [REWRITE_TAC[VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN MATCH_MP_TAC MOD_LT THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `val(word_mul (word nv:int32) (word 190464:int32)) <= 8189952` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(bit 31 (word_sub (a:int32) (word_mul (word nv:int32) (word 190464:int32))) \/ word_sub a (word_mul (word nv) (word 190464)) = word 0) <=> val a <= nv * 190464` SUBST1_TAC THENL + [MP_TAC(ISPECL [`a:int32`; `word_mul (word nv:int32) (word 190464:int32)`] WORD_SUB_SIGN_88) THEN ASM_REWRITE_TAC[] THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN DISCH_THEN SUBST1_TAC THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[bitval] THEN + SUBGOAL_THEN `val(word nv:int32) = nv` ASSUME_TAC THENL [MATCH_MP_TAC VAL_WORD_EQ THEN REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~(word_igt (word nv:int32) (word 43:int32))` ASSUME_TAC THENL [MP_TAC(SPEC `word nv:int32` WORD_IGT_43_BOUND) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + ASM_CASES_TAC `val(h:int32) = 0` THENL + [(* h = 0 nowrap *) + REWRITE_TAC[ASSUME `val(h:int32) = 0`] THEN + SUBGOAL_THEN `h:int32 = word 0` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_0; WORD_ADD_0] THEN + ASM_REWRITE_TAC[] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32; VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[MIN] THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC + ;(* h = 1 nowrap *) + REWRITE_TAC[ASSUME `~(val(h:int32) = 0)`] THEN + SUBGOAL_THEN `h:int32 = word 1` SUBST1_TAC THENL [REWRITE_TAC[GSYM VAL_EQ_1] THEN ASM_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[WORD_MUL_1_32; WORD_AND_ONES_32] THEN + SUBGOAL_THEN `val(word nv:int32) <= 43` ASSUME_TAC THENL [ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(int_gt (&(val(a:int32)) - &nv * &190464) (&4190208))` ASSUME_TAC THENL + [REWRITE_TAC[INT_GT; INT_NOT_LT] THEN MP_TAC(SPEC `val(a:int32)` A0_UPPER_88) THEN ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN ASM_REWRITE_TAC[] THEN DISCH_TAC THEN MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD] (ASSUME `val(a:int32) < (nv + 1) * 190464`)) THEN INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `val(a:int32) <= nv * 190464` THENL + [MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `190464`] REAL_INT_GT_BRIDGE) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]; + MP_TAC(SPECL [`val(a:int32)`; `nv:num`; `190464`] REAL_INT_GT_BRIDGE_POS) THEN ASM_REWRITE_TAC[] THEN STRIP_TAC THEN ASM_REWRITE_TAC[]] THENL + [(* delta = -1: a0' <= 0 *) + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32] THEN + ASM_CASES_TAC `nv = 0` THENL + [ASM_REWRITE_TAC[] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV + ;REWRITE_TAC[ASSUME `~(nv = 0)`] THEN + SUBGOAL_THEN `~word_igt (word_add (word nv:int32) (word 4294967295)) (word 43:int32)` ASSUME_TAC THENL + [MP_TAC(SPEC `word nv:int32` WORD_IGT_43_SUB1) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN MATCH_MP_TAC THEN CONJ_TAC THENL [ASM_REWRITE_TAC[]; REWRITE_TAC[GSYM VAL_EQ_0] THEN ASM_REWRITE_TAC[] THEN ASM_ARITH_TAC]; ALL_TAC] THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32; VAL_WORD_UMIN; VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(nv + 4294967295) MOD 4294967296 = nv - 1` SUBST1_TAC THENL [SUBGOAL_THEN `nv + 4294967295 = (nv - 1) + 1 * 4294967296` SUBST1_TAC THENL [UNDISCH_TAC `~(nv = 0)` THEN ARITH_TAC; REWRITE_TAC[MOD_MULT_ADD] THEN MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC]; ALL_TAC] THEN + REWRITE_TAC[MIN] THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC] + ;(* delta = +1: a0' > 0 *) + REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32] THEN + ASM_CASES_TAC `nv = 43` THENL + [ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN CONV_TAC WORD_REDUCE_CONV THEN MP_TAC WORD_IGT_43_TRUE THEN DISCH_TAC THEN ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[VAL_WORD_UMIN; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV + ;REWRITE_TAC[ASSUME `~(nv = 43)`] THEN + SUBGOAL_THEN `~word_igt (word_add (word nv:int32) (word 1)) (word 43:int32)` ASSUME_TAC THENL + [MP_TAC(SPEC `word nv:int32` WORD_IGT_43_ADD1) THEN ASM_REWRITE_TAC[] THEN DISCH_THEN MATCH_MP_TAC THEN UNDISCH_TAC `nv <= 43` THEN UNDISCH_TAC `~(nv = 43)` THEN UNDISCH_TAC `val(word nv:int32) = nv` THEN ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[bitval] THEN CONV_TAC WORD_REDUCE_CONV THEN REWRITE_TAC[WORD_AND_ONES_32; VAL_WORD_UMIN; VAL_WORD_ADD; VAL_WORD; DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `nv MOD 4294967296 = nv` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(nv + 1) MOD 4294967296 = nv + 1` SUBST1_TAC THENL [MATCH_MP_TAC MOD_LT THEN UNDISCH_TAC `nv <= 43` THEN ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[MIN] THEN UNDISCH_TAC `nv <= 43` THEN UNDISCH_TAC `~(nv = 43)` THEN ARITH_TAC]]]]);; + +let ELEMENT_CORRECT_WORD_88 = prove( + `!a:int32 h:int32. + val a < 8380417 /\ val h <= 1 + ==> mldsa_use_hint_88_asm a h = + word(mldsa_use_hint_88_code (val a) (val h))`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + GEN_REWRITE_TAC LAND_CONV [GSYM WORD_VAL] THEN + AP_TERM_TAC THEN MP_TAC(SPECL [`a:int32`; `h:int32`] ELEMENT_CORRECT_88) THEN + ASM_REWRITE_TAC[] THEN DISCH_THEN(fun th -> REWRITE_TAC[th]));; + +(* ========================================================================= *) +(* Correctness proof, code-aligned spec (intermediate) *) +(* ========================================================================= *) + +let POLY_USE_HINT_88_AARCH64_ASM_CORRECT_CODE = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH poly_use_hint_88_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_88_aarch64_asm_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH poly_use_hint_88_aarch64_asm_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i)) (val(y i))))) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + + MAP_EVERY X_GEN_TAC + [`b:int64`; `a:int64`; `h:int64`; + `x:num->int32`; `y:num->int32`; `pc:num`] THEN + REWRITE_TAC[MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI; C_ARGUMENTS; + NONOVERLAPPING_CLAUSES; ALL; + fst POLY_USE_HINT_88_AARCH64_ASM_EXEC] THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + GLOBALIZE_PRECONDITION_TAC THEN + CONV_TAC(RATOR_CONV(LAND_CONV(ONCE_DEPTH_CONV EXPAND_CASES_CONV))) THEN + CONV_TAC NUM_REDUCE_CONV THEN + REPEAT STRIP_TAC THEN + REWRITE_TAC[SOME_FLAGS; MODIFIABLE_SIMD_REGS] THEN + + ENSURES_INIT_TAC "s0" THEN + MEMORY_128_FROM_32_TAC "a" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + MEMORY_128_FROM_32_TAC "h" 0 64 THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN CONV_TAC WORD_REDUCE_CONV THEN + STRIP_TAC THEN + DISCARD_MATCHING_ASSUMPTIONS [`read (memory :> bytes32 a) s = x`] THEN + + MAP_EVERY (fun n -> ARM_STEPS_TAC POLY_USE_HINT_88_AARCH64_ASM_EXEC [n] THEN + SIMD_SIMPLIFY_TAC[]) + (1--1006) THEN + ENSURES_FINAL_STATE_TAC THEN ASM_REWRITE_TAC[] THEN + + REPEAT(FIRST_X_ASSUM(STRIP_ASSUME_TAC o + CONV_RULE (SIMD_SIMPLIFY_CONV []) o + CONV_RULE(READ_MEMORY_SPLIT_CONV 2) o + check (can (term_match [] `read qqq s:int128 = xxx`) o concl))) THEN + + CONV_TAC(TOP_DEPTH_CONV EXPAND_CASES_CONV) THEN + CONV_TAC(DEPTH_CONV NUM_MULT_CONV THENC DEPTH_CONV NUM_ADD_CONV) THEN + REWRITE_TAC[WORD_ADD_0] THEN + ASM_REWRITE_TAC[WORD_ADD_0] THEN ASM_REWRITE_TAC[] THEN + + (* Push word_subword through SIMD ops to per-element form *) + REWRITE_TAC[WORD_SUBWORD_AND; WORD_SUBWORD_OR] THEN + let WSN_TAC = REWRITE_TAC(map (fun n -> prove( + subst [mk_small_numeral n, `n:num`] + `!x:int128. word_subword(word_not x) (n,32):int32 = word_not(word_subword x (n,32))`, + GEN_TAC THEN MATCH_MP_TAC WORD_SUBWORD_NOT THEN + REWRITE_TAC[DIMINDEX_32; DIMINDEX_128] THEN ARITH_TAC)) [0;32;64;96]) in + WSN_TAC THEN + CONV_TAC(DEPTH_CONV WORD_SIMPLE_SUBWORD_CONV) THEN + CONV_TAC(DEPTH_CONV WORD_NUM_RED_CONV) THEN + let EC_DEEP = CONV_RULE(DEPTH_CONV WORD_NUM_RED_CONV) + (CONV_RULE(DEPTH_CONV(INT_RED_CONV ORELSEC NUM_RED_CONV)) + (CONV_RULE(TOP_DEPTH_CONV let_CONV) + (REWRITE_RULE[mldsa_use_hint_88_asm; word_2smulh; word_ishr_round; + DIMINDEX_32] ELEMENT_CORRECT_WORD_88))) in + let EC_OR = ONCE_REWRITE_RULE[WORD_OR_SYM] EC_DEEP in + REPEAT CONJ_TAC THEN + (MATCH_MP_TAC EC_OR ORELSE MATCH_MP_TAC EC_DEEP) THEN + CONJ_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ARITH_TAC);; + +(* ========================================================================= *) +(* Subroutine form (intermediate, code-aligned) *) +(* ========================================================================= *) + +let POLY_USE_HINT_88_AARCH64_ASM_CORRECT_BOUND_CODE = prove + (`!b a h x y pc. + nonoverlapping (word pc, LENGTH poly_use_hint_88_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_88_aarch64_asm_mc /\ + read PC s = word pc /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = word(pc + LENGTH poly_use_hint_88_aarch64_asm_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i)) (val(y i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + MATCH_MP_TAC ENSURES_STRENGTHEN_POST THEN + EXISTS_TAC + `\s. read PC s = word(pc + LENGTH poly_use_hint_88_aarch64_asm_mc - 4) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i:int32)) (val(y i:int32))))` THEN + CONJ_TAC THENL + [MATCH_MP_TAC POLY_USE_HINT_88_AARCH64_ASM_CORRECT_CODE THEN ASM_REWRITE_TAC[]; + REWRITE_TAC[] THEN REPEAT STRIP_TAC THEN ASM_REWRITE_TAC[] THEN + FIRST_X_ASSUM(MP_TAC o SPEC `i:num`) THEN ASM_REWRITE_TAC[] THEN + DISCH_THEN SUBST1_TAC THEN REWRITE_TAC[VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + MATCH_MP_TAC(ARITH_RULE `x < 44 ==> x MOD 4294967296 < 44`) THEN + REWRITE_TAC[mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REPEAT(COND_CASES_TAC THEN ASM_REWRITE_TAC[]) THEN ASM_ARITH_TAC]);; + +(* Intermediate subroutine correctness against the code-aligned spec. + Bridged to the public FIPS 204-aligned theorem below via + MLDSA_USE_HINT_88_EQUIV. *) +let POLY_USE_HINT_88_AARCH64_ASM_SUBROUTINE_CORRECT_CODE = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH poly_use_hint_88_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_88_aarch64_asm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88_code (val(x i)) (val(y i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REWRITE_TAC[fst POLY_USE_HINT_88_AARCH64_ASM_EXEC] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ARM_ADD_RETURN_NOSTACK_TAC POLY_USE_HINT_88_AARCH64_ASM_EXEC + (CONV_RULE(ONCE_DEPTH_CONV NUM_REDUCE_CONV) + (REWRITE_RULE[fst POLY_USE_HINT_88_AARCH64_ASM_EXEC] + POLY_USE_HINT_88_AARCH64_ASM_CORRECT_BOUND_CODE)));; + + +(* ========================================================================= *) +(* FIPS 204 = code-aligned equivalence *) +(* ========================================================================= *) + +let LINEARIZE_DIV_MOD_TAC_88 = + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +(* Prove r DIV 190464 = k via DIV_SANDWICH + LE_MULT_RCANCEL *) +let DIV_190464_TAC k = + let k_num = mk_small_numeral k and km1 = mk_small_numeral (k-1) + and kp1 = mk_small_numeral (k+1) + and q = mk_var("q",`:num`) and le = `(<=):num->num->bool` + and lt = `(<):num->num->bool` + and c = `190464` in + let mk_mul a b = mk_binop (rator(rator `0*0`)) a b in + MATCH_MP_TAC DIV_SANDWICH THEN CONV_TAC NUM_REDUCE_CONV THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; c] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, q) THEN + REPEAT GEN_TAC THEN STRIP_TAC THEN + ASM_CASES_TAC(mk_comb(mk_comb(le, q), km1)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul q c), + mk_mul km1 c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL]; ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul k_num c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + ASM_CASES_TAC(mk_comb(mk_comb(lt, k_num), q)) THENL + [SUBGOAL_THEN(mk_comb(mk_comb(le, mk_mul kp1 c), + mk_mul q c)) ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_MULT_RCANCEL] THEN DISJ1_TAC THEN ASM_ARITH_TAC; + ALL_TAC] THEN + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC; + CONV_TAC NUM_REDUCE_CONV THEN ASM_ARITH_TAC]];; + +(* Replace (r - r MOD 190464) DIV 190464 with r DIV 190464 *) +let DIV_MOD_TO_DIV_TAC_88 = + SUBGOAL_THEN `(r - r MOD 190464) DIV 190464 = r DIV 190464` SUBST1_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 = 190464 * r DIV 190464` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV; ALL_TAC];; + +(* Lower half nowrap: dismiss wrap cond, reduce, prove r DIV 190464 = k *) +let DECOMPOSE_R1_LOWER_TAC_88 = + SUBGOAL_THEN `~((&r:int) - &(r MOD 190464) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN LINEARIZE_DIV_MOD_TAC_88; + ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + DIV_MOD_TO_DIV_TAC_88 THEN + CONV_TAC SYM_CONV THEN + LINEARIZE_DIV_MOD_TAC_88;; + +(* Upper half nowrap: dismiss wrap cond, reduce, prove r DIV 190464 + 1 = k *) +let DECOMPOSE_R1_UPPER_TAC_88 = + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `~((&r:int) - (&(r MOD 190464) - &190464) = &8380416)` (fun th -> REWRITE_TAC[th]) THENL + [REWRITE_TAC[INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`] THEN + ASM_SIMP_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + LINEARIZE_DIV_MOD_TAC_88; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 + 190464 = 190464 * (r DIV 190464 + 1)` + SUBST1_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464 + 1`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV THEN DISCH_THEN SUBST1_TAC THEN + REPEAT(FIRST_X_ASSUM(MP_TAC o check (fun th -> + free_in `r MOD 190464` (concl th) || + free_in `r DIV 190464` (concl th)))) THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + SPEC_TAC(`r MOD 190464`, `m:num`) THEN + SPEC_TAC(`r DIV 190464`, `q:num`) THEN + REPEAT GEN_TAC THEN ASM_ARITH_TAC;; + +let DECOMPOSE_R1_NOWRAP_TAC_88 = + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THEN + TRY DECOMPOSE_R1_LOWER_TAC_88 THEN TRY DECOMPOSE_R1_UPPER_TAC_88;; + +let DECOMPOSE_88_R1_EQUIV = time prove( + `!r. r < 8380417 ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = + decompose_88_r1 r`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ASM_CASES_TAC `r <= 8285184` THENL + [ALL_TAC; + (* Wrap zone *) + SUBGOAL_THEN `8285184 < r` ASSUME_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `decompose_88_r1 r = 0` SUBST1_TAC THENL + [REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; + ALL_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THENL + [(* Lower wrap: r DIV 190464 = 44 *) + SUBGOAL_THEN `r DIV 190464 = 44` ASSUME_TAC THENL + [DIV_190464_TAC 44; ALL_TAC] THEN + SUBGOAL_THEN `44 * 190464 + r MOD 190464 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ] THEN + ASM_ARITH_TAC; + (* Upper wrap: r DIV 190464 = 43 *) + SUBGOAL_THEN `r DIV 190464 = 43` ASSUME_TAC THENL + [DIV_190464_TAC 43; ALL_TAC] THEN + SUBGOAL_THEN `43 * 190464 + r MOD 190464 = r` MP_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` SUBST1_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + REWRITE_TAC[INT_OF_NUM_EQ] THEN ASM_ARITH_TAC]; + ALL_TAC] THEN + MP_TAC(SPEC `r:num` A1_WRAP_88) THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + DISCH_THEN SUBST1_TAC THEN CONV_TAC NUM_REDUCE_CONV] THEN + (* Nowrap zone: Barrett <= 43, so if > 43 then 0 else Barrett = Barrett *) + SUBGOAL_THEN `((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 <= 43` + ASSUME_TAC THENL + [MATCH_MP_TAC A1_BOUND_NOWRAP_88 THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 > 43)` + (fun th -> REWRITE_TAC[th]) THENL [ASM_ARITH_TAC; ALL_TAC] THEN + (* Nowrap zone: unfold and do interval cascade *) + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + let intervals = [ + (0, 95232); (95233, 285696); (285697, 476160); (476161, 666624); + (666625, 857088); (857089, 1047552); (1047553, 1238016); + (1238017, 1428480); (1428481, 1618944); (1618945, 1809408); + (1809409, 1999872); (1999873, 2190336); (2190337, 2380800); + (2380801, 2571264); (2571265, 2761728); (2761729, 2952192); + (2952193, 3142656); (3142657, 3333120); (3333121, 3523584); + (3523585, 3714048); (3714049, 3904512); (3904513, 4094976); + (4094977, 4285440); (4285441, 4475904); (4475905, 4666368); + (4666369, 4856832); (4856833, 5047296); (5047297, 5237760); + (5237761, 5428224); (5428225, 5618688); (5618689, 5809152); + (5809153, 5999616); (5999617, 6190080); (6190081, 6380544); + (6380545, 6571008); (6571009, 6761472); (6761473, 6951936); + (6951937, 7142400); (7142401, 7332864); (7332865, 7523328); + (7523329, 7713792); (7713793, 7904256); (7904257, 8094720); + (8094721, 8285184)] in + let mk_le hi = + mk_comb(mk_comb(`(<=):num->num->bool`, mk_var("r",`:num`)), + mk_small_numeral hi) in + let apply_interval k (lo, hi) = + let th = SPECL [`r:num`; mk_small_numeral lo; + mk_small_numeral hi; mk_small_numeral k] + BARRETT_INTERVAL_88 in + MP_TAC th THEN CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + STRIP_TAC THEN ASM_REWRITE_TAC[] THEN CONV_TAC NUM_REDUCE_CONV THEN + DECOMPOSE_R1_NOWRAP_TAC_88 in + let rec cascade k = function + | [(lo,hi)] -> apply_interval k (lo,hi) + | (lo,hi)::rest -> + ASM_CASES_TAC (mk_le hi) THENL + [apply_interval k (lo,hi); cascade (k+1) rest] + | [] -> failwith "empty" in + cascade 0 intervals);; + +let R1_IS_DIV_LOWER_88 = prove( + `!r. r < 8380417 /\ r MOD 190464 * 2 <= 190464 /\ + ~((&r:int) - &(r MOD 190464) = &8380416) ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = r DIV 190464`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` LOWER_NONWRAP_R1_88) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +let R1_IS_DIV_PLUS1_UPPER_88 = prove( + `!r. r < 8380417 /\ ~(r MOD 190464 * 2 <= 190464) /\ + ~((&r:int) - (&(r MOD 190464) - &190464) = &8380416) ==> + (let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw) = r DIV 190464 + 1`, + GEN_TAC THEN STRIP_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` UPPER_NONWRAP_R1_88) THEN ASM_REWRITE_TAC[] THEN + REPEAT DISCH_TAC THEN ASM_REWRITE_TAC[]);; + +(* Upper nowrap: substitute Barrett = r DIV 190464 + 1, use INT_MOD_RESIDUE *) +let R0_SIGN_UPPER_NOWRAP_TAC_88 = + MP_TAC(SPEC `r:num` R1_IS_DIV_PLUS1_UPPER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_ADD; GSYM INT_OF_NUM_MUL] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `(a:int) - (b + &1) * c = a - b * c - c`] THEN + REWRITE_TAC[INT_ARITH `x - &190464 > &0 <=> x > &190464`; + INT_ARITH `x - &190464 - &8380417 > &0 <=> x > &8570881`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Lower nowrap: substitute Barrett = r DIV 190464, use INT_MOD_RESIDUE *) +let R0_SIGN_LOWER_NOWRAP_TAC_88 = + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_THEN SUBST1_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN ASM_REWRITE_TAC[] THEN + REWRITE_TAC[INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +(* Wrap: derive 8285184 < r, use DECOMPOSE_88_R1_EQUIV to get Barrett = 0 *) +let R0_SIGN_WRAP_TAC_88 = + SUBGOAL_THEN `8285184 < r` ASSUME_TAC THENL + [FIRST_X_ASSUM(MP_TAC o check (fun th -> + can (find_term (fun t -> t = `&8380416:int`)) (concl th) && + not(is_neg(concl th)))) THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_EQ; + INT_ARITH `(a:int) - (b - c) = d <=> a + c - b = d`; + GSYM INT_OF_NUM_ADD] THEN ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + DISCH_TAC THEN ASM_REWRITE_TAC[INT_MUL_LZERO; INT_SUB_RZERO] THEN + REWRITE_TAC[INT_ARITH `x - &1 > &0 <=> x > &1`; + INT_ARITH `(x - &190464) - &1 > &0 <=> x > &190465`; + INT_ARITH `x - &8380417 > &0 <=> x > &8380417`; + INT_OF_NUM_GT] THEN + ASM_ARITH_TAC;; + +let DECOMPOSE_88_R0_SIGN = time prove( + `!r. r < 8380417 ==> + let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + let a1 = if a1_raw > 43 then 0 else a1_raw in + let a0':int = if (&r:int) - &a1 * &190464 > &4190208 + then &r - &a1 * &190464 - &8380417 + else &r - &a1 * &190464 in + (decompose_88_r0 r > &0 <=> a0' > &0) /\ + (decompose_88_r0 r <= &0 <=> ~(a0' > &0))`, + GEN_TAC THEN DISCH_TAC THEN CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + REWRITE_TAC[INT_ARITH `(x:int) <= &0 <=> ~(x > &0)`] THEN + MATCH_MP_TAC(TAUT `(p <=> q) ==> (p <=> q) /\ (~p <=> ~q)`) THEN + REWRITE_TAC[decompose_88_r0; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + ONCE_REWRITE_TAC[COND_RAND] THEN REWRITE_TAC[SND; FST] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + ASM_CASES_TAC `r MOD 190464 * 2 <= 190464` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[] THEN + TRY R0_SIGN_LOWER_NOWRAP_TAC_88 THEN + TRY R0_SIGN_UPPER_NOWRAP_TAC_88 THEN + TRY R0_SIGN_WRAP_TAC_88 THEN + TRY( + (* Contradiction: lower nowrap with > 4190208 *) + MP_TAC(SPEC `r:num` R1_IS_DIV_LOWER_88) THEN + ANTS_TAC THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN DISCH_TAC THEN + MP_TAC(CONV_RULE NUM_REDUCE_CONV + (SPECL [`r:num`; `190464`] INT_MOD_RESIDUE)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `(&r:int) - &((let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw)) * &190464 = &(r MOD 190464)` ASSUME_TAC THENL + [CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[]; ALL_TAC] THEN + SUBGOAL_THEN `~(&(r MOD 190464) > (&4190208:int))` MP_TAC THENL + [REWRITE_TAC[INT_NOT_LT; INT_OF_NUM_LE] THEN ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN REWRITE_TAC[INT_OF_NUM_GT] THEN ASM_ARITH_TAC + ));; + +(* R1 equivalence for _88 (needed for MLDSA_USE_HINT_88_EQUIV below). + + Show decompose_88_r1 r equals the a1 value from the code-aligned spec: + let a1_raw = ((r + 127) DIV 128 * 11275 + 8388608) DIV 16777216 in + if a1_raw > 43 then 0 else a1_raw *) + +let MLDSA_USE_HINT_88_EQUIV = prove( + `!r h. r < 8380417 /\ h <= 1 + ==> mldsa_use_hint_88 h r = mldsa_use_hint_88_code r h`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[MLDSA_USE_HINT_88_UNFOLD] THEN + REWRITE_TAC[mldsa_use_hint_88_code] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R1_EQUIV) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN + DISCH_TAC THEN + MP_TAC(SPEC `r:num` DECOMPOSE_88_R0_SIGN) THEN ASM_REWRITE_TAC[] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN STRIP_TAC THEN + ASM_CASES_TAC `h = 0` THENL + [ASM_REWRITE_TAC[ARITH_RULE `~(0 = 1)`]; ALL_TAC] THEN + SUBGOAL_THEN `h = 1` SUBST_ALL_TAC THENL [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + ASM_CASES_TAC `decompose_88_r0 r > &0` THEN ASM_REWRITE_TAC[] THEN + COND_CASES_TAC THEN ASM_REWRITE_TAC[]);; + +(* ========================================================================= *) +(* Public subroutine correctness (FIPS 204-aligned) *) +(* ========================================================================= *) + +(* Postcondition is stated in terms of mldsa_use_hint_88 from FIPS 204 + (Algorithm 40), with the output bound < 44 as a corollary. + Derived from POLY_USE_HINT_88_AARCH64_ASM_SUBROUTINE_CORRECT_CODE by + rewriting mldsa_use_hint_88_code -> mldsa_use_hint_88 via + MLDSA_USE_HINT_88_EQUIV. *) +let POLY_USE_HINT_88_AARCH64_ASM_SUBROUTINE_CORRECT = prove + (`!b a h x y pc returnaddress. + nonoverlapping (word pc, LENGTH poly_use_hint_88_aarch64_asm_mc) (b, 1024) /\ + nonoverlapping (b, 1024) (a, 1024) /\ + nonoverlapping (b, 1024) (h, 1024) /\ + (!i. i < 256 ==> val((x:num->int32) i) < 8380417) /\ + (!i. i < 256 ==> val((y:num->int32) i) <= 1) + ==> ensures arm + (\s. aligned_bytes_loaded s (word pc) poly_use_hint_88_aarch64_asm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + (!i. i < 256 ==> val(x i) < 8380417) /\ + (!i. i < 256 ==> val(y i) <= 1) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add a (word(4 * i)))) s = x i) /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add h (word(4 * i)))) s = y i)) + (\s. read PC s = returnaddress /\ + (!i. i < 256 ==> + read(memory :> bytes32(word_add b (word(4 * i)))) s = + word(mldsa_use_hint_88 (val(y i)) (val(x i)))) /\ + (!i. i < 256 ==> + val(read(memory :> bytes32(word_add b (word(4 * i)))) s) < 44)) + (MAYCHANGE_REGS_AND_FLAGS_PERMITTED_BY_ABI ,, + MAYCHANGE [memory :> bytes(b, 1024)])`, + REPEAT GEN_TAC THEN + DISCH_THEN(REPEAT_TCL CONJUNCTS_THEN ASSUME_TAC) THEN + SUBGOAL_THEN + `!i. i < 256 ==> + mldsa_use_hint_88 (val((y:num->int32) i)) (val((x:num->int32) i)) = + mldsa_use_hint_88_code (val(x i)) (val(y i))` + (fun th -> SIMP_TAC[th]) THENL + [REPEAT STRIP_TAC THEN MATCH_MP_TAC MLDSA_USE_HINT_88_EQUIV THEN + CONJ_TAC THEN FIRST_X_ASSUM MATCH_MP_TAC THEN ASM_REWRITE_TAC[]; + MATCH_MP_TAC POLY_USE_HINT_88_AARCH64_ASM_SUBROUTINE_CORRECT_CODE THEN + ASM_REWRITE_TAC[]]);; + + +(* ========================================================================= *) +(* Constant-time and memory safety proof. *) +(* ========================================================================= *) + +needs "arm/proofs/consttime.ml";; +needs "aarch64/proofs/subroutine_signatures.ml";; + + +let full_spec,public_vars = mk_safety_spec + ~keep_maychanges:false + (assoc "poly_use_hint_88_aarch64_asm" subroutine_signatures) + POLY_USE_HINT_88_AARCH64_ASM_SUBROUTINE_CORRECT_CODE + POLY_USE_HINT_88_AARCH64_ASM_EXEC;; + +let POLY_USE_HINT_88_AARCH64_ASM_SUBROUTINE_SAFE = time prove + (`exists f_events. + forall e b a h pc returnaddress. + nonoverlapping (word pc,LENGTH poly_use_hint_88_aarch64_asm_mc) (b,1024) /\ + nonoverlapping (b,1024) (a,1024) /\ + nonoverlapping (b,1024) (h,1024) + ==> ensures arm + (\s. + aligned_bytes_loaded s (word pc) + poly_use_hint_88_aarch64_asm_mc /\ + read PC s = word pc /\ + read X30 s = returnaddress /\ + C_ARGUMENTS [b; a; h] s /\ + read events s = e) + (\s. + read PC s = returnaddress /\ + (exists e2. + read events s = APPEND e2 e /\ + e2 = f_events a h b pc returnaddress /\ + memaccess_inbounds e2 [a,1024; h,1024; b,1024] + [b,1024])) + (\s s'. true)`, + ASSERT_CONCL_TAC full_spec THEN + PROVE_SAFETY_SPEC_TAC ~public_vars:public_vars POLY_USE_HINT_88_AARCH64_ASM_EXEC);; diff --git a/proofs/hol_light/aarch64/proofs/subroutine_signatures.ml b/proofs/hol_light/aarch64/proofs/subroutine_signatures.ml index d1c0f9efe..41e54856d 100644 --- a/proofs/hol_light/aarch64/proofs/subroutine_signatures.ml +++ b/proofs/hol_light/aarch64/proofs/subroutine_signatures.ml @@ -81,6 +81,42 @@ let subroutine_signatures = [ ]) ); +("poly_use_hint_32_aarch64_asm", + ([(*args*) + ("b", "int32_t[static 256]", (*is const?*)"false"); + ("a", "int32_t[static 256]", (*is const?*)"true"); + ("h", "int32_t[static 256]", (*is const?*)"true"); + ], + "void", + [(* input buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ("h", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* output buffers *) + ("b", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + +("poly_use_hint_88_aarch64_asm", + ([(*args*) + ("b", "int32_t[static 256]", (*is const?*)"false"); + ("a", "int32_t[static 256]", (*is const?*)"true"); + ("h", "int32_t[static 256]", (*is const?*)"true"); + ], + "void", + [(* input buffers *) + ("a", "256"(* num elems *), 4(* elem bytesize *)); + ("h", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* output buffers *) + ("b", "256"(* num elems *), 4(* elem bytesize *)); + ], + [(* temporary buffers *) + ]) +); + ("mldsa_pointwise_acc_l7", ([(*args*) ("r", "int32_t[static 256]", (*is const?*)"false"); diff --git a/proofs/hol_light/common/mldsa_specs.ml b/proofs/hol_light/common/mldsa_specs.ml index 1758774db..6dba9bde4 100644 --- a/proofs/hol_light/common/mldsa_specs.ml +++ b/proofs/hol_light/common/mldsa_specs.ml @@ -1458,3 +1458,243 @@ let READ_BYTES_SPLIT_ANY = prove( REWRITE_TAC[READ_BYTES_COMBINE] THEN REWRITE_TAC[MATCH_MP NUM_BIT_DECOMPOSE_UNIQ bound]);; +(* ========================================================================= *) +(* ML-DSA use_hint shared infrastructure lemmas *) +(* Used by both poly_use_hint_32 and poly_use_hint_88 proofs *) +(* ========================================================================= *) + + + +let MLDSA_IVAL_VAL = prove( + `!a:int32. val a < 8380417 ==> ival a = &(val a)`, + GEN_TAC THEN DISCH_TAC THEN + SIMP_TAC[ival; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN + COND_CASES_TAC THEN ASM_ARITH_TAC);; + +let INT_POS_NEG_BOUND = prove(`!n. ~((&n:int) < --(&2147483648))`, + GEN_TAC THEN REWRITE_TAC[INT_NOT_LT] THEN + MP_TAC(SPEC `n:num` INT_POS) THEN INT_ARITH_TAC);; + +let VAL_IWORD_NUM_32 = prove( + `!n. n < 2147483648 ==> val(iword(&n):int32) = n`, + GEN_TAC THEN DISCH_TAC THEN + MP_TAC(ISPECL [`&n:int`] (INST_TYPE [`:32`,`:N`] INT_VAL_IWORD)) THEN + REWRITE_TAC[DIMINDEX_32; INT_POS] THEN + CONV_TAC NUM_REDUCE_CONV THEN + ANTS_TAC THENL + [REWRITE_TAC[INT_OF_NUM_LT] THEN ASM_ARITH_TAC; + REWRITE_TAC[INT_OF_NUM_EQ] THEN SIMP_TAC[]]);; + +let WORD_ILE_ZERO_32 = BITBLAST_RULE + `!x:int32. word_ile x (word 0) <=> bit 31 x \/ x = word 0`;; + +let VAL_WORD_AND_15_32 = BITBLAST_RULE + `!x:int32. val(word_and x (word 15:int32)) = val x MOD 16`;; + +let WORD_AND_ONES_32 = prove( + `!x:int32. word_and x (word 4294967295) = x`, + GEN_TAC THEN SUBGOAL_THEN `(word 4294967295 : int32) = word_not(word 0)` SUBST1_TAC THENL + [CONV_TAC WORD_REDUCE_CONV; REWRITE_TAC[WORD_AND_NOT0]]);; + +let WORD_MUL_1_32 = prove( + `!x:int32. word_mul x (word 1) = x`, + GEN_TAC THEN REWRITE_TAC[GSYM VAL_EQ; VAL_WORD_MUL; VAL_WORD; DIMINDEX_32] THEN + CONV_TAC NUM_REDUCE_CONV THEN REWRITE_TAC[MULT_CLAUSES] THEN + MATCH_MP_TAC MOD_LT THEN MP_TAC(ISPEC `x:int32` VAL_BOUND) THEN + REWRITE_TAC[DIMINDEX_32] THEN CONV_TAC NUM_REDUCE_CONV);; + +let REAL_INT_GT_BRIDGE = prove( + `!a:num b c. a <= b * c ==> + ~(real_gt (&a - &b * &c) (&0)) /\ ~(int_gt (&a - &b * &c) (&0))`, + REPEAT GEN_TAC THEN DISCH_TAC THEN CONJ_TAC THENL + [REWRITE_TAC[real_gt; REAL_NOT_LT] THEN + MP_TAC(REWRITE_RULE[GSYM REAL_OF_NUM_LE; GSYM REAL_OF_NUM_MUL] (ASSUME `a <= b * c`)) THEN REAL_ARITH_TAC; + REWRITE_TAC[INT_GT; INT_NOT_LT] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LE; GSYM INT_OF_NUM_MUL] (ASSUME `a <= b * c`)) THEN INT_ARITH_TAC]);; + +let REAL_INT_GT_BRIDGE_POS = prove( + `!a:num b c. ~(a <= b * c) ==> + real_gt (&a - &b * &c) (&0) /\ int_gt (&a - &b * &c) (&0)`, + REPEAT GEN_TAC THEN REWRITE_TAC[NOT_LE] THEN DISCH_TAC THEN CONJ_TAC THENL + [REWRITE_TAC[real_gt] THEN + MP_TAC(REWRITE_RULE[GSYM REAL_OF_NUM_LT; GSYM REAL_OF_NUM_MUL] (ASSUME `b * c < a`)) THEN REAL_ARITH_TAC; + REWRITE_TAC[INT_GT] THEN + MP_TAC(REWRITE_RULE[GSYM INT_OF_NUM_LT; GSYM INT_OF_NUM_MUL] (ASSUME `b * c < a`)) THEN INT_ARITH_TAC]);; + +(* ========================================================================= *) +(* Shared helper lemmas for UseHint proofs *) +(* ========================================================================= *) + +let DIV_SANDWICH = prove( + `!x d k. ~(d = 0) /\ k * d <= x /\ x < (k + 1) * d ==> x DIV d = k`, + REPEAT GEN_TAC THEN STRIP_TAC THEN + SUBGOAL_THEN `k <= x DIV d` ASSUME_TAC THENL + [ASM_SIMP_TAC[LE_RDIV_EQ] THEN ASM_ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `x DIV d < k + 1` ASSUME_TAC THENL + [ASM_SIMP_TAC[RDIV_LT_EQ] THEN ASM_ARITH_TAC; ASM_ARITH_TAC]);; + +let INT_MOD_RESIDUE = prove( + `!r m. ~(m = 0) ==> (&r:int) - &(r DIV m) * &m = &(r MOD m)`, + REPEAT GEN_TAC THEN DISCH_TAC THEN + MP_TAC(SPECL [`r:num`; `m:num`] (CONJUNCT1 DIVISION_SIMP)) THEN + REWRITE_TAC[GSYM INT_OF_NUM_MUL; GSYM INT_OF_NUM_ADD; + GSYM INT_OF_NUM_EQ] THEN + INT_ARITH_TAC);; + +(* ========================================================================= *) +(* FIPS 204 UseHint definitions (Algorithms 36 and 40) *) +(* ========================================================================= *) + +let mldsa_cmod = new_definition + `mldsa_cmod (r:num) (m:num) : int = + if (r MOD m) * 2 <= m then &(r MOD m) else &(r MOD m) - &m`;; + +let mldsa_decompose_32 = new_definition + `mldsa_decompose_32 (r:num) : num # int = + let r0 = mldsa_cmod r 523776 in + if &r - r0 = &8380416 then (0, r0 - &1) + else (num_of_int((&r - r0) div &523776), r0)`;; + +let decompose_32_r1 = new_definition + `decompose_32_r1 (r:num) : num = FST(mldsa_decompose_32 r)`;; + +let decompose_32_r0 = new_definition + `decompose_32_r0 (r:num) : int = SND(mldsa_decompose_32 r)`;; + +let mldsa_decompose_88 = new_definition + `mldsa_decompose_88 (r:num) : num # int = + let r0 = mldsa_cmod r 190464 in + if &r - r0 = &8380416 then (0, r0 - &1) + else (num_of_int((&r - r0) div &190464), r0)`;; + +let decompose_88_r1 = new_definition + `decompose_88_r1 (r:num) : num = FST(mldsa_decompose_88 r)`;; + +let decompose_88_r0 = new_definition + `decompose_88_r0 (r:num) : int = SND(mldsa_decompose_88 r)`;; + +let mldsa_use_hint_32 = new_definition + `mldsa_use_hint_32 (h:num) (r:num) : num = + let (r1, r0) = mldsa_decompose_32 r in + if h = 1 /\ r0 > &0 then (r1 + 1) MOD 16 + else if h = 1 /\ r0 <= &0 then (r1 + 15) MOD 16 + else r1`;; + +let mldsa_use_hint_88 = new_definition + `mldsa_use_hint_88 (h:num) (r:num) : num = + let (r1, r0) = mldsa_decompose_88 r in + if h = 1 /\ r0 > &0 then if r1 = 43 then 0 else r1 + 1 + else if h = 1 /\ r0 <= &0 then if r1 = 0 then 43 else r1 - 1 + else r1`;; + +let LOWER_NONWRAP_R1 = prove( + `!r. r MOD 523776 * 2 <= 523776 /\ + ~((&r:int) - &(r MOD 523776) = &8380416) ==> + decompose_32_r1 r = r DIV 523776`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; + NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 = 523776 * r DIV 523776` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`523776`; `r DIV 523776`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV);; + +let UPPER_NONWRAP_R1 = prove( + `!r. ~(r MOD 523776 * 2 <= 523776) /\ + ~((&r:int) - (&(r MOD 523776) - &523776) = &8380416) ==> + decompose_32_r1 r = r DIV 523776 + 1`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_32_r1; mldsa_decompose_32; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 523776 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 523776 < 523776` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `523776`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 523776) - &523776) = + &(r - r MOD 523776 + 523776)` ASSUME_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `523776`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 523776 + 523776 = (r DIV 523776 + 1) * 523776` + ASSUME_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + MP_TAC(SPECL [`(r DIV 523776 + 1) * 523776`; `523776`] DIV_MULT) THEN + ARITH_TAC);; + +let LOWER_NONWRAP_R1_88 = prove( + `!r. r MOD 190464 * 2 <= 190464 /\ + ~((&r:int) - &(r MOD 190464) = &8380416) ==> + decompose_88_r1 r = r DIV 190464`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + ASM_SIMP_TAC[INT_OF_NUM_SUB; INT_OF_NUM_DIV; + NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 = 190464 * r DIV 190464` SUBST1_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + MP_TAC(SPECL [`190464`; `r DIV 190464`] DIV_MULT) THEN + CONV_TAC NUM_REDUCE_CONV);; + +let UPPER_NONWRAP_R1_88 = prove( + `!r. ~(r MOD 190464 * 2 <= 190464) /\ + ~((&r:int) - (&(r MOD 190464) - &190464) = &8380416) ==> + decompose_88_r1 r = r DIV 190464 + 1`, + GEN_TAC THEN STRIP_TAC THEN + REWRITE_TAC[decompose_88_r1; mldsa_decompose_88; mldsa_cmod] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN ASM_REWRITE_TAC[] THEN + SUBGOAL_THEN `r MOD 190464 <= r` ASSUME_TAC THENL + [MESON_TAC[MOD_LE]; ALL_TAC] THEN + SUBGOAL_THEN `r MOD 190464 < 190464` ASSUME_TAC THENL + [MP_TAC(SPECL [`r:num`; `190464`] MOD_LT_EQ) THEN ARITH_TAC; ALL_TAC] THEN + SUBGOAL_THEN `(&r:int) - (&(r MOD 190464) - &190464) = + &(r - r MOD 190464 + 190464)` ASSUME_TAC THENL + [ASM_SIMP_TAC[GSYM INT_OF_NUM_SUB; GSYM INT_OF_NUM_ADD] THEN + INT_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[INT_OF_NUM_DIV; NUM_OF_INT_OF_NUM; INT_OF_NUM_EQ] THEN + MP_TAC(SPECL [`r:num`; `190464`] (CONJUNCT1 DIVISION_SIMP)) THEN + DISCH_TAC THEN + SUBGOAL_THEN `r - r MOD 190464 + 190464 = (r DIV 190464 + 1) * 190464` + ASSUME_TAC THENL + [ASM_ARITH_TAC; ALL_TAC] THEN + ASM_REWRITE_TAC[] THEN + MP_TAC(SPECL [`(r DIV 190464 + 1) * 190464`; `190464`] DIV_MULT) THEN + ARITH_TAC);; + +let MLDSA_USE_HINT_32_UNFOLD = prove( + `!h r. mldsa_use_hint_32 h r = + (if h = 1 /\ decompose_32_r0 r > &0 then (decompose_32_r1 r + 1) MOD 16 + else if h = 1 /\ decompose_32_r0 r <= &0 + then (decompose_32_r1 r + 15) MOD 16 + else decompose_32_r1 r)`, + REPEAT GEN_TAC THEN + REWRITE_TAC[mldsa_use_hint_32; decompose_32_r1; decompose_32_r0] THEN + SPEC_TAC(`mldsa_decompose_32 r`, `p:num#int`) THEN + REWRITE_TAC[FORALL_PAIR_THM] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN REWRITE_TAC[]);; + +let MLDSA_USE_HINT_88_UNFOLD = prove( + `!h r. mldsa_use_hint_88 h r = + (if h = 1 /\ decompose_88_r0 r > &0 + then if decompose_88_r1 r = 43 then 0 else decompose_88_r1 r + 1 + else if h = 1 /\ decompose_88_r0 r <= &0 + then if decompose_88_r1 r = 0 then 43 else decompose_88_r1 r - 1 + else decompose_88_r1 r)`, + REPEAT GEN_TAC THEN + REWRITE_TAC[mldsa_use_hint_88; decompose_88_r1; decompose_88_r0] THEN + SPEC_TAC(`mldsa_decompose_88 r`, `p:num#int`) THEN + REWRITE_TAC[FORALL_PAIR_THM] THEN + CONV_TAC(TOP_DEPTH_CONV let_CONV) THEN REWRITE_TAC[]);; diff --git a/scripts/autogen b/scripts/autogen index 807918a1a..d85ca2fb9 100755 --- a/scripts/autogen +++ b/scripts/autogen @@ -2402,6 +2402,20 @@ def gen_hol_light_asm(): f"-Imldsa/src/native/aarch64/src {aarch64_flags}", "aarch64", ), + ( + "poly_use_hint_32_asm.S", + "poly_use_hint_32_aarch64_asm.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), + ( + "poly_use_hint_88_asm.S", + "poly_use_hint_88_aarch64_asm.S", + "dev/aarch64_opt/src", + f"-Imldsa/src/native/aarch64/src {aarch64_flags}", + "aarch64", + ), ( "mld_polyvecl_pointwise_acc_montgomery_l4.S", "mldsa_pointwise_acc_l4.S",