From 1a06a7fdf21085bfc8a06cd4956caed138b9dd84 Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 24 Mar 2026 21:27:29 +0000 Subject: [PATCH 01/11] proofs/isabelle: add AutoCorrode submodule Add the DominicPM/AutoCorrode fork as a submodule. --- .gitmodules | 3 +++ AutoCorrode | 1 + 2 files changed, 4 insertions(+) create mode 100644 .gitmodules create mode 160000 AutoCorrode diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000..8d5bf96036 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "AutoCorrode"] + path = AutoCorrode + url = https://github.com/DominicPM/AutoCorrode diff --git a/AutoCorrode b/AutoCorrode new file mode 160000 index 0000000000..416c81072a --- /dev/null +++ b/AutoCorrode @@ -0,0 +1 @@ +Subproject commit 416c81072adb1de96757c1059361b0814d6b5173 From 85ce61d814da6b09884e2d39e79c7bac52964e93 Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 21:59:30 +0000 Subject: [PATCH 02/11] proofs/isabelle: add abstract ML-KEM polynomial specification and refinement Add MLKEM_Spec.thy with core mathematical specifications (barrett_reduce_int, montgomery_reduce_int, fqmul_int, polynomial operations) and MLKEM_Refinement.thy with the refinement relation connecting C word types to abstract integers. --- proofs/isabelle/MLKEM_Refinement.thy | 290 +++++++++++++++++++++++++++ proofs/isabelle/MLKEM_Spec.thy | 281 ++++++++++++++++++++++++++ 2 files changed, 571 insertions(+) create mode 100644 proofs/isabelle/MLKEM_Refinement.thy create mode 100644 proofs/isabelle/MLKEM_Spec.thy diff --git a/proofs/isabelle/MLKEM_Refinement.thy b/proofs/isabelle/MLKEM_Refinement.thy new file mode 100644 index 0000000000..476b6da4a8 --- /dev/null +++ b/proofs/isabelle/MLKEM_Refinement.thy @@ -0,0 +1,290 @@ +(*<*) +theory MLKEM_Refinement + imports MLKEM_Spec +begin +(*>*) + +text \ + Refinement between C word-level types and abstract integers. + Defines the @{text refines_mlk_poly} predicate connecting concrete + @{typ c_mlk_poly} values (16-bit signed word coefficients) to abstract + @{typ int_poly} specifications, together with no-overflow conditions + required for safe arithmetic. +\ + +text_raw \ +\begin{figure}[ht] +\centering +\begin{tikzpicture}[>=Stealth, node distance=1.4cm and 4.5cm, + box/.style={draw=mlkblue, fill=mlklightblue, rounded corners, + minimum width=2.6cm, minimum height=0.8cm, + align=center, font=\small}, + lbl/.style={font=\small\itshape, midway}] + % Left column: C (word-level) + \node[box] (cpre) {C pre-state\\[-1pt]\footnotesize\texttt{c\_mlk\_poly}}; + \node[box, below=of cpre] (cpost) {C post-state\\[-1pt]\footnotesize\texttt{c\_mlk\_poly}}; + % Right column: abstract (int-level) + \node[box, right=of cpre] (apre) {Abstract pre-state\\[-1pt]\footnotesize\texttt{int\_poly}}; + \node[box, right=of cpost] (apost) {Abstract post-state\\[-1pt]\footnotesize\texttt{int\_poly}}; + % Vertical arrows: execution + \draw[->,thick,mlkblue] (cpre) -- node[left,lbl] {C execution} (cpost); + \draw[->,thick,mlkblue] (apre) -- node[right,lbl] {spec} (apost); + % Horizontal arrows: refinement + \draw[<->,densely dashed,thick,mlkblue] (cpre) -- + node[above,font=\small] {\textsf{refines}} (apre); + \draw[<->,densely dashed,thick,mlkblue] (cpost) -- + node[below,font=\small] {\textsf{refines}} (apost); +\end{tikzpicture} +\caption{Refinement-based verification. Each C function is executed in + lockstep with its abstract specification. The dashed arrows denote + the refinement relation @{text refines_mlk_poly}, which requires that + the signed interpretation of concrete 16-bit coefficients equals the + abstract integer list. A function contract establishes that if the + refinement holds before execution (pre-state), it is preserved after + execution (post-state).} +\label{fig:refinement} +\end{figure} +\ + +section \Concrete Types and Refinement\ + +text \ + Refinement relation: a concrete @{typ c_mlk_poly} refines an abstract + @{typ int_poly} when its coefficient list has the correct length and its + signed interpretation matches the abstract polynomial. +\ +definition refines_mlk_poly :: \c_mlk_poly \ int_poly \ bool\ where + \refines_mlk_poly cp ap \ + length (c_mlk_poly_coeffs cp) = MLKEM_N \ + List.map sint (c_mlk_poly_coeffs cp) = ap\ + +text \ + No-overflow condition: the mathematical sum of each coefficient pair + fits in a signed 16-bit integer. This is required both for the C code + to not abort (since @{const c_signed_add} checks for overflow) and for + the refinement to integer arithmetic to hold. +\ +definition no_overflow_add :: \int_poly \ int_poly \ bool\ where + \no_overflow_add ps qs \ + (\i < min (length ps) (length qs). + ps ! i + qs ! i \ {-(2^15) ..< 2^15})\ + +definition no_overflow_sub :: \int_poly \ int_poly \ bool\ where + \no_overflow_sub ps qs \ + (\i < min (length ps) (length qs). + ps ! i - qs ! i \ {-(2^15) ..< 2^15})\ + +text \ + The concrete (word-level) result of polynomial addition — internal helper + for proofs. +\ +definition poly_add_c :: \c_mlk_poly \ c_mlk_poly \ c_mlk_poly\ where + \poly_add_c p q \ update_c_mlk_poly_coeffs + (\_. map2 (+) (c_mlk_poly_coeffs p) (c_mlk_poly_coeffs q)) p\ + +subsection \Refinement Lemmas\ + +lemma sint_plus_no_overflow: + fixes a b :: \'l::{len} sword\ + assumes \sint a + sint b \ {-(2^(LENGTH('l) - 1)) ..< 2^(LENGTH('l) - 1)}\ + shows \sint (a + b) = sint a + sint b\ +using assms by (intro signed_arith_sint) (auto simp: word_size) + +lemma sint_minus_no_overflow: + fixes a b :: \'l::{len} sword\ + assumes \sint a - sint b \ {-(2^(LENGTH('l) - 1)) ..< 2^(LENGTH('l) - 1)}\ + shows \sint (a - b) = sint a - sint b\ +using assms by (intro signed_arith_sint) (auto simp: word_size) + +text \ + The key refinement theorem: under the no-overflow condition, the concrete + word-level addition produces a result that refines the abstract integer sum. +\ +theorem poly_add_c_refines: + assumes \refines_mlk_poly p ap\ + and \refines_mlk_poly q aq\ + and \no_overflow_add ap aq\ + shows \refines_mlk_poly (poly_add_c p q) (poly_add_int ap aq)\ +using assms by (auto simp add: c_mlk_poly.record_simps map2_map_map word_size refines_mlk_poly_def + poly_add_c_def poly_add_int_def no_overflow_add_def intro!: nth_equalityI sint_plus_no_overflow) + +subsection \Auxiliary List Lemmas\ + +lemma nth_map2: + assumes \i < length xs\ + and \i < length ys\ + shows \map2 f xs ys ! i = f (xs ! i) (ys ! i)\ +using assms by (induction xs arbitrary: i ys) (auto simp: less_Suc_eq_0_disj split: list.splits) + +(*<*) +lemma inv_list_step: + assumes \n < length xs\ + and \n < length ys\ + and \length xs = length ys\ + shows \(take n (map2 f xs ys) @ drop n xs)[n := f (xs ! n) (ys ! n)] = + take (Suc n) (map2 f xs ys) @ drop (Suc n) xs\ +proof - + let ?zs = \map2 f xs ys\ + from assms have ln: \n < length ?zs\ + by simp + from assms have zn: \?zs ! n = f (xs ! n) (ys ! n)\ + by (simp add: nth_map2) + from assms have drop_eq: \drop n xs = xs ! n # drop (Suc n) xs\ + by (metis Cons_nth_drop_Suc) + have \(take n ?zs @ drop n xs)[n := ?zs ! n] = take n ?zs @ (drop n xs)[0 := ?zs ! n]\ + using ln by (simp add: list_update_append) + also have \\ = take n ?zs @ ?zs ! n # drop (Suc n) xs\ + using drop_eq by simp + also have \\ = take (Suc n) ?zs @ drop (Suc n) xs\ + using ln by (simp add: take_Suc_conv_app_nth) + finally show ?thesis + using zn by simp +qed + +(*>*) + +(*<*) +lemma no_overflow_add_bounds: + assumes \refines_mlk_poly vr ar\ + and \refines_mlk_poly vb ab\ + and \no_overflow_add ar ab\ + and \i < MLKEM_N\ + shows \sint (c_mlk_poly_coeffs vr ! i) + sint (c_mlk_poly_coeffs vb ! i) < 2 ^ 15\ + and \- (2 ^ 15) \ sint (c_mlk_poly_coeffs vr ! i) + sint (c_mlk_poly_coeffs vb ! i)\ +proof - + from assms(1) have lr: \length (c_mlk_poly_coeffs vr) = MLKEM_N\ + and mr: \List.map sint (c_mlk_poly_coeffs vr) = ar\ + unfolding refines_mlk_poly_def by auto + from assms(2) have lb: \length (c_mlk_poly_coeffs vb) = MLKEM_N\ + and mb: \List.map sint (c_mlk_poly_coeffs vb) = ab\ + unfolding refines_mlk_poly_def by auto + have \ar ! i + ab ! i \ {-(2^15) ..< 2^15}\ + using assms(3,4) lr lb mr mb unfolding no_overflow_add_def by auto + moreover have \ar ! i = sint (c_mlk_poly_coeffs vr ! i)\ + using mr lr assms(4) by (simp add: nth_map[symmetric]) + moreover have \ab ! i = sint (c_mlk_poly_coeffs vb ! i)\ + using mb lb assms(4) by (simp add: nth_map[symmetric]) + ultimately show \sint (c_mlk_poly_coeffs vr ! i) + sint (c_mlk_poly_coeffs vb ! i) < 2 ^ 15\ + and \- (2 ^ 15) \ sint (c_mlk_poly_coeffs vr ! i) + sint (c_mlk_poly_coeffs vb ! i)\ + by auto +qed + +lemma no_overflow_sub_bounds: + assumes \refines_mlk_poly vr ar\ + and \refines_mlk_poly vb ab\ + and \no_overflow_sub ar ab\ + and \i < MLKEM_N\ + shows \sint (c_mlk_poly_coeffs vr ! i) - sint (c_mlk_poly_coeffs vb ! i) < 2 ^ 15\ + and \- (2 ^ 15) \ sint (c_mlk_poly_coeffs vr ! i) - sint (c_mlk_poly_coeffs vb ! i)\ +proof - + from assms(1) have lr: \length (c_mlk_poly_coeffs vr) = MLKEM_N\ + and mr: \List.map sint (c_mlk_poly_coeffs vr) = ar\ + unfolding refines_mlk_poly_def by auto + from assms(2) have lb: \length (c_mlk_poly_coeffs vb) = MLKEM_N\ + and mb: \List.map sint (c_mlk_poly_coeffs vb) = ab\ + unfolding refines_mlk_poly_def by auto + have \ar ! i - ab ! i \ {-(2^15) ..< 2^15}\ + using assms(3,4) lr lb mr mb unfolding no_overflow_sub_def by auto + moreover have \ar ! i = sint (c_mlk_poly_coeffs vr ! i)\ + using mr lr assms(4) by (simp add: nth_map[symmetric]) + moreover have \ab ! i = sint (c_mlk_poly_coeffs vb ! i)\ + using mb lb assms(4) by (simp add: nth_map[symmetric]) + ultimately show \sint (c_mlk_poly_coeffs vr ! i) - sint (c_mlk_poly_coeffs vb ! i) < 2 ^ 15\ + and \- (2 ^ 15) \ sint (c_mlk_poly_coeffs vr ! i) - sint (c_mlk_poly_coeffs vb ! i)\ + by auto +qed + +lemma MLKEM_N_sub_step [simp]: + assumes \k < MLKEM_N\ + shows \MLKEM_N - k = Suc (255 - k)\ +using assms by simp + +lemma mlkem_rev_index_bound [simp]: + shows \255 - k < MLKEM_N\ +by simp +(*>*) + +text \Roundtrip: \sint (word_of_int x)\ equals \x\ when \x\ fits in 16-bit signed range.\ + +lemma sint_word_of_int_16: + assumes \- (2^15) \ x\ + and \x < 2^15\ + shows \sint (word_of_int x :: 16 sword) = x\ +proof - + have \signed_take_bit 15 x = x\ + using assms by (intro signed_take_bit_int_eq_self) auto + moreover have \sint (word_of_int x :: 16 sword) = signed_take_bit 15 x\ + by transfer simp + ultimately show ?thesis + by simp +qed + +text \The sint of \word_of_int (montgomery_reduce_int (sint a * sint b))\ equals + the Montgomery reduction, for any 16-bit signed inputs.\ + +lemma sint_word_of_montgomery_fqmul: + fixes a :: \16 sword\ + and b :: \16 sword\ + shows \sint (word_of_int (montgomery_reduce_int (sint a * sint b)) :: 16 sword) = + montgomery_reduce_int (sint a * sint b)\ +proof - + have ab: \\sint a\ \ 2^15\ \\sint b\ \ 2^15\ + using sint_range_size[where w=a] sint_range_size[where w=b] by (auto simp: word_size) + have \\sint a * sint b\ \ 2^30\ + proof - + have \\sint a * sint b\ = \sint a\ * \sint b\\ + by (rule abs_mult) + also have \\ \ 2^15 * 2^15\ + using ab by (intro mult_mono) auto + finally show ?thesis + by simp + qed + hence \\sint a * sint b\ < 2^31 - 2^15 * MLKEM_Q\ + by simp + hence \\montgomery_reduce_int (sint a * sint b)\ < 2^15\ + by (rule montgomery_reduce_int_bound) + thus ?thesis + by (intro sint_word_of_int_16) auto +qed + +(*<*) +lemma inv_list_step_map: + assumes \n < length xs\ + shows \(take n (List.map f xs) @ drop n xs)[n := f (xs ! n)] = + take (Suc n) (List.map f xs) @ drop (Suc n) xs\ +proof - + let ?zs = \List.map f xs\ + from assms have ln: \n < length ?zs\ + by simp + from assms have zn: \?zs ! n = f (xs ! n)\ + by simp + from assms have drop_eq: \drop n xs = xs ! n # drop (Suc n) xs\ + by (metis Cons_nth_drop_Suc) + have \(take n ?zs @ drop n xs)[n := ?zs ! n] = take n ?zs @ (drop n xs)[0 := ?zs ! n]\ + using ln by (simp add: list_update_append) + also have \\ = take n ?zs @ ?zs ! n # drop (Suc n) xs\ + using drop_eq by simp + also have \\ = take (Suc n) ?zs @ drop (Suc n) xs\ + using ln by (simp add: take_Suc_conv_app_nth) + finally show ?thesis + using zn by simp +qed +(*>*) + +subsection \Additional Refinement Predicates\ + +definition refines_mlk_poly_mulcache :: \c_mlk_poly_mulcache \ int list \ bool\ where + \refines_mlk_poly_mulcache cm am \ + length (c_mlk_poly_mulcache_coeffs cm) = 128 \ + List.map sint (c_mlk_poly_mulcache_coeffs cm) = am\ + +definition refines_coeffs :: \c_short list \ int list \ bool\ where + \refines_coeffs ccs acs \ length ccs = MLKEM_N \ List.map sint ccs = acs\ + +lemma refines_mlk_poly_coeffs: + shows \refines_mlk_poly cp ap \ refines_coeffs (c_mlk_poly_coeffs cp) ap\ +unfolding refines_mlk_poly_def refines_coeffs_def .. + +(*<*) +end +(*>*) diff --git a/proofs/isabelle/MLKEM_Spec.thy b/proofs/isabelle/MLKEM_Spec.thy new file mode 100644 index 0000000000..87bdb4e0d6 --- /dev/null +++ b/proofs/isabelle/MLKEM_Spec.thy @@ -0,0 +1,281 @@ +(*<*) +theory MLKEM_Spec + imports MLKEM_Poly_Definitions "Micro_C_Examples.C_While_Examples" +begin +(*>*) + +text \ + Core mathematical specification of ML-KEM polynomial arithmetic: + coefficient-list polynomials over the integers, Barrett reduction, + Montgomery reduction, and field multiplication (@{text fqmul}). + All definitions operate on unbounded integers; word-level refinement + is handled in @{text MLKEM_Refinement}. +\ + +section \Abstract Polynomial Arithmetic\ + +text \We model mlkem-native polynomials abstractly as fixed-size coefficient + lists over the integers. This gives a clean mathematical specification + independent of machine word sizes.\ + +abbreviation MLKEM_N :: nat where + \MLKEM_N \ 256\ + +type_synonym int_poly = \int list\ + +definition poly_add_int :: \int_poly \ int_poly \ int_poly\ where + \poly_add_int ps qs \ map2 (+) ps qs\ + +definition poly_sub_int :: \int_poly \ int_poly \ int_poly\ where + \poly_sub_int ps qs \ map2 (-) ps qs\ + +subsection \Barrett Reduction\ + +text \Barrett reduction approximates division by @{text Q} using a + pre-computed multiplier, replacing an expensive division with a + multiplication and shift. The result is congruent to the input + modulo @{text Q} but not necessarily fully reduced.\ + +abbreviation MLKEM_Q :: int where + \MLKEM_Q \ 3329\ + +definition barrett_reduce_int :: \int \ int\ where + \barrett_reduce_int a \ a - ((20159 * a + 2^25) div 2^26) * MLKEM_Q\ + +text \Correctness: @{const barrett_reduce_int} preserves the residue class + modulo @{const MLKEM_Q}.\ + +theorem barrett_reduce_mod: + shows \barrett_reduce_int a mod MLKEM_Q = a mod MLKEM_Q\ +unfolding barrett_reduce_int_def by algebra + +subsection \Montgomery Reduction\ + +text \Abstract Montgomery reduction on integers. Given an integer \a\, + returns a value \r\ such that \r * 2^16 \ a (mod Q)\.\ + +definition montgomery_reduce_int :: \int \ int\ where + \montgomery_reduce_int a \ + (let t = signed_take_bit 15 ((a mod 2^16) * 62209 mod 2^16) + in (a - t * 3329) div 2^16)\ + +text \Key refinement properties of @{const montgomery_reduce_int}. + All proofs hide the complex \signed_take_bit 15 (...)\ subterm + behind a local \define\ to keep automated provers fast.\ + +lemma montgomery_reduce_int_result_eq: + shows \montgomery_reduce_int a * (65536::int) = + a - signed_take_bit 15 (a mod 65536 * 62209 mod 65536) * 3329\ +proof - + define u :: int where + \u = a mod 65536 * 62209 mod 65536\ + define t where + \t = signed_take_bit 15 u\ + have mt_def: \montgomery_reduce_int a = (a - t * 3329) div 65536\ + unfolding montgomery_reduce_int_def Let_def t_def u_def by simp + \ \signed\_take\_bit preserves residue mod \2^16\\ + have t_cong_u: \t mod 65536 = u mod 65536\ + proof - + have \t = u mod 65536 - 65536 * of_bool (bit u 15)\ + unfolding t_def by (simp add: signed_take_bit_eq_take_bit_minus take_bit_eq_mod) + thus ?thesis by simp + qed + \ \Modular congruence chain: \t * Q \ a (mod 2^16)\\ + have s1: \t * 3329 mod 65536 = u * 3329 mod 65536\ + using t_cong_u by (metis mod_mult_left_eq) + have s2: \u * 3329 mod 65536 = (a mod 65536 * 62209) * 3329 mod 65536\ + unfolding u_def using mod_mult_left_eq[of \a mod 65536 * 62209\ 65536 3329] by linarith + have s3: \(a mod 65536 * 62209) * 3329 mod 65536 = a * (62209 * 3329) mod 65536\ + using mod_mult_left_eq[of a 65536 \62209 * 3329\] by (simp only: mult.assoc) + have qinv_q_mod: \(62209::int) * 3329 mod 65536 = 1\ + by simp + have s4: \a * ((62209::int) * 3329) mod 65536 = a mod 65536\ + using mod_mult_right_eq[of a \(62209::int) * 3329\ 65536] qinv_q_mod by simp + have tQ_cong: \t * 3329 mod 65536 = a mod 65536\ + using s1 s2 s3 s4 by presburger + \ \Exact divisibility: \2^16\ divides \a - t * Q\\ + have \(a - t * 3329) mod 65536 = (a - a) mod 65536\ + by (rule mod_diff_cong[OF refl tQ_cong]) + hence div_exact: \65536 dvd (a - t * 3329)\ + by (simp add: dvd_eq_mod_eq_0) + \ \Combine: \r * 2^16 = a - t * Q\\ + have \montgomery_reduce_int a * 65536 = a - t * 3329\ + using mt_def dvd_div_mult_self[OF div_exact] by simp + thus ?thesis + unfolding t_def u_def . +qed + +theorem montgomery_reduce_int_correct: + shows \montgomery_reduce_int a * 2^16 mod MLKEM_Q = a mod MLKEM_Q\ +proof - + define t :: int where + \t = signed_take_bit 15 (a mod 65536 * 62209 mod 65536)\ + have result_eq: \montgomery_reduce_int a * 65536 = a - t * 3329\ + unfolding t_def by (rule montgomery_reduce_int_result_eq) + have \montgomery_reduce_int a * 2 ^ 16 mod MLKEM_Q = (a - t * MLKEM_Q) mod MLKEM_Q\ + using result_eq by simp + also have \\ = a mod MLKEM_Q\ + by (metis mod_diff_right_eq mod_mult_self2_is_0 diff_zero) + finally show ?thesis . +qed + +text \Output bound for @{const montgomery_reduce_int}: with the overflow + precondition \|a| < 2^31 - 2^15 * Q\, the result fits in a signed 16-bit + integer: \|r| < 2^15\.\ +theorem montgomery_reduce_int_bound: + assumes \\a\ < 2^31 - 2^15 * MLKEM_Q\ + shows \\montgomery_reduce_int a\ < 2^15\ +proof - + define t :: int where \t = signed_take_bit 15 (a mod 65536 * 62209 mod 65536)\ + have result_eq: \montgomery_reduce_int a * 65536 = a - t * 3329\ + unfolding t_def by (rule montgomery_reduce_int_result_eq) + have t_lb: \t \ -32768\ + proof - + have \t \ -(2^15)\ + unfolding t_def by (rule signed_take_bit_int_greater_eq_minus_exp) + thus ?thesis + by simp + qed + have t_ub: \t < 32768\ + proof - + have \t < 2^15\ + unfolding t_def by (rule signed_take_bit_int_less_exp) + thus ?thesis + by simp + qed + have a_bounds: \a > -2038398976\ \a < 2038398976\ + using assms by auto + have \montgomery_reduce_int a * 65536 > -2147483648\ + using result_eq a_bounds t_ub by linarith + moreover have \montgomery_reduce_int a * 65536 < 2147483648\ + using result_eq a_bounds t_lb by linarith + ultimately show ?thesis + by simp +qed + +subsection \Field Multiplication\ + +text \Field multiplication in Montgomery domain: multiply two integers + and apply @{const montgomery_reduce_int} to the product. The result + satisfies @{text "fqmul a b * R \ a * b (mod Q)"} where @{text "R = 2^16"}.\ + +definition fqmul_int :: \int \ int \ int\ where + \fqmul_int a b \ montgomery_reduce_int (a * b)\ + +subsection \Polynomial Operations\ + +text \Coefficient-wise polynomial operations used by ML-KEM: + Montgomery pre-scaling (by the constant @{text "1353 = R^2 mod Q"}) + and Barrett reduction modulo @{const MLKEM_Q}.\ + +definition poly_tomont_int :: \int_poly \ int_poly\ where + \poly_tomont_int ap \ List.map (\a. montgomery_reduce_int (a * 1353)) ap\ + +definition poly_reduce_int :: \int_poly \ int_poly\ where + \poly_reduce_int ap \ List.map (\a. a mod MLKEM_Q) ap\ + +subsection \Barrett Reduce Bounds\ + +text \When the input fits in a signed 16-bit word, @{const barrett_reduce_int} + produces a result in @{text "[-1664, 1664]"}, and hence within @{text "(-Q, Q)"}. + These bounds also hold for any integer input in that range.\ + +lemma barrett_reduce_int_bounds: + fixes a :: \16 sword\ + shows \-1664 \ barrett_reduce_int (sint a)\ + and \barrett_reduce_int (sint a) \ 1664\ +(*<*) +proof - + define q where + \q = (20159 * sint a + 33554432) div 67108864\ + have sint_range: \-32768 \ sint a\ \sint a \ 32767\ + using sint_range_size[of a] by (auto simp: word_size) + have result: \barrett_reduce_int (sint a) = sint a - q * 3329\ + unfolding barrett_reduce_int_def q_def by simp + have mod_eq: \q * 67108864 + (20159 * sint a + 33554432) mod 67108864 = 20159 * sint a + 33554432\ + unfolding q_def by (rule div_mult_mod_eq) + have mod_lb: \0 \ (20159 * sint a + 33554432) mod (67108864 :: int)\ + by simp + have mod_ub: \(20159 * sint a + 33554432) mod (67108864 :: int) < 67108864\ + by simp + have div_lb: \q * 67108864 \ 20159 * sint a + 33554432\ + using mod_eq mod_lb by linarith + have div_ub: \20159 * sint a + 33554432 < q * 67108864 + 67108864\ + using mod_eq mod_ub by linarith + have q_lb: \q \ -10\ + using div_ub sint_range by linarith + have q_ub: \q \ 10\ + using div_lb sint_range by linarith + have lower: \20159 * (sint a - q * 3329) \ -33558902\ + using div_lb q_ub by (simp add: algebra_simps) + have upper: \20159 * (sint a - q * 3329) < 33558902\ + using div_ub q_lb by (simp add: algebra_simps) + show \-1664 \ barrett_reduce_int (sint a)\ + using result lower by simp + show \barrett_reduce_int (sint a) \ 1664\ + using result upper by simp +qed +(*>*) + +(*<*) +lemma barrett_reduce_int_in_range: + fixes a :: \16 sword\ + shows \- MLKEM_Q < barrett_reduce_int (sint a)\ + and \barrett_reduce_int (sint a) < MLKEM_Q\ +using barrett_reduce_int_bounds[of a] by linarith+ +(*>*) + +lemma barrett_reduce_int_abs_bound: + fixes a :: \16 sword\ + shows \\barrett_reduce_int (sint a)\ < MLKEM_Q\ +(*<*) +using barrett_reduce_int_in_range[of a] by linarith +(*>*) + +(*<*) +lemma barrett_reduce_int_bounds_int: + assumes \-32768 \ x\ + and \x \ 32767\ + shows \-1664 \ barrett_reduce_int x\ + and \barrett_reduce_int x \ 1664\ +proof - + define q where + \q = (20159 * x + 33554432) div 67108864\ + have result: \barrett_reduce_int x = x - q * 3329\ + unfolding barrett_reduce_int_def q_def by simp + have mod_eq: \q * 67108864 + (20159 * x + 33554432) mod 67108864 = 20159 * x + 33554432\ + unfolding q_def by (rule div_mult_mod_eq) + have mod_lb: \0 \ (20159 * x + 33554432) mod (67108864 :: int)\ + by simp + have mod_ub: \(20159 * x + 33554432) mod (67108864 :: int) < 67108864\ + by simp + have div_lb: \q * 67108864 \ 20159 * x + 33554432\ + using mod_eq mod_lb by linarith + have div_ub: \20159 * x + 33554432 < q * 67108864 + 67108864\ + using mod_eq mod_ub by linarith + have q_lb: \q \ -10\ + using div_ub assms by linarith + have q_ub: \q \ 10\ + using div_lb assms by linarith + have lower: \20159 * (x - q * 3329) \ -33558902\ + using div_lb q_ub by (simp add: algebra_simps) + have upper: \20159 * (x - q * 3329) < 33558902\ + using div_ub q_lb by (simp add: algebra_simps) + show \-1664 \ barrett_reduce_int x\ + using result lower by simp + show \barrett_reduce_int x \ 1664\ + using result upper by simp +qed +(*>*) + +lemma barrett_reduce_int_abs_bound_int: + assumes \-32768 \ x\ + and \x \ 32767\ + shows \\barrett_reduce_int x\ < MLKEM_Q\ +(*<*) +using barrett_reduce_int_bounds_int[OF assms] by linarith +(*>*) + +(*<*) +end +(*>*) From 00a4982999e784eefd1173b747db8ce5647435c8 Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 21:59:45 +0000 Subject: [PATCH 03/11] proofs/isabelle: add NTT zeta factor table and properties Add MLKEM_Zetas.thy with precomputed 128 Montgomery-form twiddle factors and their bound and divisibility properties. --- proofs/isabelle/MLKEM_Zetas.thy | 334 ++++++++++++++++++++++++++++++++ 1 file changed, 334 insertions(+) create mode 100644 proofs/isabelle/MLKEM_Zetas.thy diff --git a/proofs/isabelle/MLKEM_Zetas.thy b/proofs/isabelle/MLKEM_Zetas.thy new file mode 100644 index 0000000000..d29f51c24f --- /dev/null +++ b/proofs/isabelle/MLKEM_Zetas.thy @@ -0,0 +1,334 @@ +(*<*) +theory MLKEM_Zetas + imports MLKEM_Spec +begin +(*>*) + +text \ + Precomputed twiddle factors for the NTT: 128 centered Montgomery-form + powers of the primitive 256th root of unity modulo @{const MLKEM_Q}. + Connects the C global @{verbatim "mlk_zetas[128]"} to its mathematical + derivation via @{const montgomery_reduce_int} and bit reversal, and + provides the word-level table used by @{const fqmul_int}. +\ + +section \Zetas Table\ + +text \128 precomputed twiddle factors (signed), matching the C global + @{verbatim "mlk_zetas[128]"} from @{file "../../mlkem/src/poly.c"}.\ + +(*<*) +definition zetas_int :: \int list\ where + \zetas_int \ [ + -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, + 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, + -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, + 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, + -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, + -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, + 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, + -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, + 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, + -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, + -308, 996, 991, 958, -1460, 1522, 1628]\ + +lemma length_zetas_int [simp]: + shows \length zetas_int = 128\ +by eval +(*>*) + +subsection \Mathematical Characterization\ + +text \The zetas table is derived from a primitive 256th root of unity + modulo @{const MLKEM_Q}. The root @{text "\ = 17"} satisfies + @{text "\^256 \ 1 (mod Q)"} and @{text "\^128 \ -1 (mod Q)"}, + confirming that its multiplicative order is exactly 256.\ + +definition mlkem_zeta :: int where + \mlkem_zeta \ 17\ + +lemma mlkem_zeta_is_root_of_unity: + shows \mlkem_zeta ^ 256 mod MLKEM_Q = 1\ +by eval + +lemma mlkem_zeta_half_order: + shows \mlkem_zeta ^ 128 mod MLKEM_Q = MLKEM_Q - 1\ +by eval + +lemma mlkem_zeta_primitive: + shows \list_all (\k. mlkem_zeta ^ k mod MLKEM_Q \ 1) [1..<256]\ +by eval + +text \Bit reversal on @{term n} bits: reverses the @{term n} least significant + bits of a natural number. Maps linear indices to the bit-reversed order + used by the NTT butterfly decomposition.\ + +fun bit_reverse :: \nat \ nat \ nat\ where + \bit_reverse 0 _ = 0\ +| \bit_reverse (Suc n) k = (k mod 2) * 2 ^ n + bit_reverse n (k div 2)\ + +text \Centered modular reduction: maps @{term x} to its unique representative + in @{text "{-\q/2\, \, \q/2\}"}.\ + +definition centered_mod :: \int \ int \ int\ where + \centered_mod x q \ let r = x mod q in if 2 * r > q then r - q else r\ + +text \Each entry of the zetas table is the centered Montgomery-form power of + the primitive root @{const mlkem_zeta} in bit-reversed order: + @{text "zetas_int ! i = centered_mod (\^(bit_reverse 7 i) \ 2^16) Q"}. + The factor @{text "2^16"} is the Montgomery radix used by + @{const montgomery_reduce_int}.\ + +lemma zetas_int_roots_of_unity: + shows \zetas_int = + List.map (\i. centered_mod (mlkem_zeta ^ bit_reverse 7 i * 2 ^ 16) MLKEM_Q) [0..<128]\ +by eval + +subsection \Word-Level Zetas\ + +text \Word-level zetas table, derived from the canonical @{const zetas_int}.\ + +definition zetas_sword :: \16 sword list\ where + \zetas_sword \ List.map word_of_int zetas_int\ + +(*<*) +lemma zetas_sword_unfold: + shows \zetas_sword = [0xFBEC, 0xFD0A, 0xFE99, 0xFA13, 0x5D5, 0x58E, 0x11F, 0xCA, + 0xFF55, 0x26E, 0x629, 0xB6, 0x3C2, 0xFB4E, 0xFA3E, 0x5BC, + 0x23D, 0xFAD3, 0x108, 0x17F, 0xFCC3, 0x5B2, 0xF9BE, 0xFF7E, + 0xFD57, 0x3F9, 0x2DC, 0x260, 0xF9FA, 0x19B, 0xFF33, 0xF9DD, + 0x4C7, 0x28C, 0xFDD8, 0x3F7, 0xFAF3, 0x5D3, 0xFEE6, 0xF9F8, + 0x204, 0xFFF8, 0xFEC0, 0xFD66, 0xF9AE, 0xFB76, 0x7E, 0x5BD, + 0xFCAB, 0xFFA6, 0xFEF1, 0x33E, 0x6B, 0xFA73, 0xFF09, 0xFC49, + 0xFE72, 0x3C1, 0xFA1C, 0xFD2B, 0x1C0, 0xFBD7, 0x2A5, 0xFB05, + 0xFBB1, 0x1AE, 0x22B, 0x34B, 0xFB1D, 0x367, 0x60E, 0x69, + 0x1A6, 0x24B, 0xB1, 0xFF15, 0xFEDD, 0xFE34, 0x626, 0x675, + 0xFF0A, 0x30A, 0x487, 0xFF6D, 0xFCF7, 0x5CB, 0xFDA6, 0x45F, + 0xF9CA, 0x284, 0xFC98, 0x15D, 0x1A2, 0x149, 0xFF64, 0xFFB5, + 0x331, 0x449, 0x25B, 0x262, 0x52A, 0xFAFB, 0xFA47, 0x180, + 0xFB41, 0xFF78, 0x4C2, 0xFAC9, 0xFC96, 0xDC, 0xFB5D, 0xF985, + 0xFB5F, 0xFA06, 0xFB02, 0x31A, 0xFA1A, 0xFCAA, 0xFC9A, 0x1DE, + 0xFF94, 0xFECC, 0x3E4, 0x3DF, 0x3BE, 0xFA4C, 0x5F2, 0x65C]\ +by eval +(*>*) + +lemma length_zetas_sword [simp]: + shows \length zetas_sword = 128\ +by (simp add: zetas_sword_def) + +(*<*) +lemma map_sint_zetas_sword: + shows \List.map sint zetas_sword = zetas_int\ +by eval +(*>*) + +lemma zetas_sword_sint: + assumes \i < 128\ + shows \sint (zetas_sword ! i) = zetas_int ! i\ +using assms nth_map[of i zetas_sword sint] map_sint_zetas_sword by simp + +(*<*) +lemma map_sint_neg_scast_zetas_sword: + shows \List.map (\w :: 16 sword. sint (scast (- (scast w :: 32 sword)) :: 16 sword)) + zetas_sword = List.map uminus zetas_int\ +by eval +(*>*) + +lemma zetas_neg_scast_sint: + assumes \i < 128\ + shows \sint (scast (- (scast (zetas_sword ! i) :: 32 sword)) :: 16 sword) = + - (zetas_int ! i)\ +using assms nth_map[of i zetas_sword \\w. sint (scast (- (scast w :: 32 sword)) :: 16 sword)\] + nth_map[of i zetas_int uminus] map_sint_neg_scast_zetas_sword by simp + +(*<*) +lemma drop_64_zetas_sword: + shows \drop 64 zetas_sword = + [0xFBB1, 0x1AE, 0x22B, 0x34B, 0xFB1D, 0x367, 0x60E, 0x69, + 0x1A6, 0x24B, 0xB1, 0xFF15, 0xFEDD, 0xFE34, 0x626, 0x675, + 0xFF0A, 0x30A, 0x487, 0xFF6D, 0xFCF7, 0x5CB, 0xFDA6, 0x45F, + 0xF9CA, 0x284, 0xFC98, 0x15D, 0x1A2, 0x149, 0xFF64, 0xFFB5, + 0x331, 0x449, 0x25B, 0x262, 0x52A, 0xFAFB, 0xFA47, 0x180, + 0xFB41, 0xFF78, 0x4C2, 0xFAC9, 0xFC96, 0xDC, 0xFB5D, 0xF985, + 0xFB5F, 0xFA06, 0xFB02, 0x31A, 0xFA1A, 0xFCAA, 0xFC9A, 0x1DE, + 0xFF94, 0xFECC, 0x3E4, 0x3DF, 0x3BE, 0xFA4C, 0x5F2, 0x65C]\ +by eval +(*>*) + +subsection \Zetas Bounds\ + +text \Bounds on the abstract zetas coefficients.\ + +lemma zetas_int_abs_bound: + assumes \i < 128\ + shows \\zetas_int ! i\ \ 1659\ +proof - + have \list_all (\z. \z\ \ 1659) zetas_int\ + by eval + thus ?thesis + using assms by (simp add: list_all_length) +qed + +lemma zetas_int_bound: + assumes \i < 128\ + shows \zetas_int ! i \ 1659\ \- (zetas_int ! i) \ 1659\ +using zetas_int_abs_bound[OF assms] by auto + +lemma zetas_int_i32_bound_from_k: + assumes \k < 64\ + shows \zetas_int ! (127 - k) \ 2147483648\ + and \- (zetas_int ! (127 - k)) < 2147483648\ +proof - + have \127 - k < 128\ + by simp + from zetas_int_bound[OF this] show \zetas_int ! (127 - k) \ 2147483648\ + by simp + from zetas_int_bound[OF \127 - k < 128\] show \- (zetas_int ! (127 - k)) < 2147483648\ + by simp +qed + +subsection \C Global Zetas\ + +text \Connecting the C global zetas array to the abstract spec.\ + +lemma c_global_mlk_zetas_eq_zetas_sword: + shows \c_global_mlk_zetas = zetas_sword\ +by (simp add: c_global_mlk_zetas_def zetas_sword_unfold) + +section \Mulcache Computation\ + +text \Abstract mulcache: for each block of 4 coefficients, compute two + fqmul products with the corresponding zeta factor.\ + +definition mulcache_compute_int :: \int_poly \ int list\ where + \mulcache_compute_int ap \ + concat (List.map (\i. [fqmul_int (ap ! (4*i + 1)) (zetas_int ! (64 + i)), + fqmul_int (ap ! (4*i + 3)) (- (zetas_int ! (64 + i)))]) + [0..<64])\ + +lemma length_concat_map_pair: + shows \length (concat (List.map (\j. [f j, g j]) [0.. +by (induct n) simp_all + +lemma length_mulcache_compute_int [simp]: + shows \length (mulcache_compute_int ap) = 128\ +unfolding mulcache_compute_int_def by (simp add: length_concat_map_pair) + +lemma concat_map_pair_nth_aux: + assumes \i < n\ + shows \concat (List.map (\j. [f j, g j]) [0.. concat (List.map (\j. [f j, g j]) [0.. +using assms proof (induct n arbitrary: i) + case (Suc n) + then show ?case + by (cases \i < n\) (auto simp: nth_append less_Suc_eq length_concat_map_pair) +qed auto + +lemma concat_map_pair_nth: + assumes \i < n\ + shows \concat (List.map (\j. [f j, g j]) [0.. + and \concat (List.map (\j. [f j, g j]) [0.. +using concat_map_pair_nth_aux[OF assms] by auto + +lemma mulcache_compute_int_nth_even: + assumes \i < 64\ + shows \mulcache_compute_int ap ! (2*i) = + fqmul_int (ap ! (4*i + 1)) (zetas_int ! (64 + i))\ +unfolding mulcache_compute_int_def using assms by (rule concat_map_pair_nth) + +lemma mulcache_compute_int_nth_odd: + assumes \i < 64\ + shows \mulcache_compute_int ap ! (2*i + 1) = + fqmul_int (ap ! (4*i + 3)) (- (zetas_int ! (64 + i)))\ +unfolding mulcache_compute_int_def using assms by (rule concat_map_pair_nth) + +(*<*) +subsection \Word Arithmetic Helpers\ + +lemma word_of_nat_mult_numeral: + shows \(numeral n :: 'a::len word) * word_of_nat k = word_of_nat (numeral n * k)\ +by (metis of_nat_mult of_nat_numeral) + +lemma unat_word_sub_word_of_nat: + fixes c :: \32 word\ + assumes \unat c = n\ \m \ n\ + shows \unat (c - word_of_nat m) = n - m\ +proof - + have \n < 2 ^ 32\ + using assms(1) unat_lt2p[of c] by simp + hence \m < 2 ^ 32\ + using assms(2) by linarith + hence u: \unat (word_of_nat m :: 32 word) = m\ + by (simp add: unat_of_nat) + have le: \(word_of_nat m :: 32 word) \ c\ + using assms by (simp add: word_le_nat_alt u) + show ?thesis + using unat_sub[OF le] u assms by simp +qed + +lemma unat_0xFD_sub_4k [simp]: + assumes \k < 64\ + shows \unat ((0xFD :: 32 word) - 4 * word_of_nat k) = 253 - 4 * k\ +using assms by (simp del: of_nat_mult of_nat_numeral + add: word_of_nat_mult_numeral unat_of_nat unat_sub word_le_nat_alt) + +lemma unat_0xFF_sub_4k [simp]: + assumes \k < 64\ + shows \unat ((0xFF :: 32 word) - 4 * word_of_nat k) = 255 - 4 * k\ +using assms by (simp del: of_nat_mult of_nat_numeral + add: word_of_nat_mult_numeral unat_of_nat unat_sub word_le_nat_alt) + +lemma unat_0x7E_sub_2k [simp]: + assumes \k < 64\ + shows \unat ((0x7E :: 32 word) - 2 * word_of_nat k) = 126 - 2 * k\ +using assms by (simp del: of_nat_mult of_nat_numeral + add: word_of_nat_mult_numeral unat_of_nat unat_sub word_le_nat_alt) + +lemma unat_0x7F_sub_2k [simp]: + assumes \k < 64\ + shows \unat ((0x7F :: 32 word) - 2 * word_of_nat k) = 127 - 2 * k\ +using assms by (simp del: of_nat_mult of_nat_numeral + add: word_of_nat_mult_numeral unat_of_nat unat_sub word_le_nat_alt) + +lemma unat_0x7F_sub_k [simp]: + assumes \k < 128\ + shows \unat ((0x7F :: 32 word) - word_of_nat k) = 127 - k\ +using assms by (intro unat_word_sub_word_of_nat) simp_all + +subsection \Downward-Counting Indexing\ + +lemma mulcache_compute_int_nth_even': + assumes \k < 64\ + shows \mulcache_compute_int ap ! (126 - 2 * k) = + fqmul_int (ap ! (253 - 4 * k)) (zetas_int ! (127 - k))\ +proof - + define i where + \i = 63 - k\ + with assms have \i < 64\ + by simp + have idx: \126 - 2 * k = 2 * i\ \253 - 4 * k = 4 * i + 1\ \127 - k = 64 + i\ + unfolding i_def using assms by auto + show ?thesis unfolding idx + by (rule mulcache_compute_int_nth_even[OF \i < 64\]) +qed + +lemma mulcache_compute_int_nth_odd': + assumes \k < 64\ + shows \mulcache_compute_int ap ! (127 - 2 * k) = + fqmul_int (ap ! (255 - 4 * k)) (- (zetas_int ! (127 - k)))\ +proof - + define i where + \i = 63 - k\ + with assms have \i < 64\ + by simp + have idx: \127 - 2 * k = 2 * i + 1\ \255 - 4 * k = 4 * i + 3\ \127 - k = 64 + i\ + unfolding i_def using assms by auto + show ?thesis unfolding idx + by (rule mulcache_compute_int_nth_odd[OF \i < 64\]) +qed +(*>*) + +(*<*) +end +(*>*) From fa0daf9fbee4275d6f6b033ef17cf24083ee547d Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 21:59:55 +0000 Subject: [PATCH 04/11] proofs/isabelle: add forward and inverse NTT abstract specifications Add MLKEM_NTT_Spec.thy and MLKEM_InvNTT_Spec.thy with butterfly, inner loop, middle loop, layer, and outer loop structure, plus bound propagation and overflow safety lemmas. --- proofs/isabelle/MLKEM_InvNTT_Spec.thy | 529 ++++++++++++++++ proofs/isabelle/MLKEM_NTT_Spec.thy | 856 ++++++++++++++++++++++++++ 2 files changed, 1385 insertions(+) create mode 100644 proofs/isabelle/MLKEM_InvNTT_Spec.thy create mode 100644 proofs/isabelle/MLKEM_NTT_Spec.thy diff --git a/proofs/isabelle/MLKEM_InvNTT_Spec.thy b/proofs/isabelle/MLKEM_InvNTT_Spec.thy new file mode 100644 index 0000000000..fab2b174fa --- /dev/null +++ b/proofs/isabelle/MLKEM_InvNTT_Spec.thy @@ -0,0 +1,529 @@ +(*<*) +theory MLKEM_InvNTT_Spec + imports MLKEM_NTT_Spec +begin +(*>*) + +text \ + Abstract specification of the inverse NTT (Number Theoretic Transform) + for ML-KEM polynomials. Mirrors the structure of @{text MLKEM_NTT_Spec} + with reversed butterfly direction: the sum is Barrett-reduced while + the difference is Montgomery-multiplied by the twiddle factor. + Coefficient-bound propagation and overflow safety are established + for all loop levels. +\ + +section \Inverse NTT Specification\ + +text_raw \ +\begin{figure}[ht] +\centering +\begin{tikzpicture}[>=Stealth, node distance=2.8cm and 3.5cm, + every node/.style={font=\small}, + io/.style={fill=mlklightblue, draw=mlkblue, rounded corners=3pt, + inner sep=4pt}] + % Input nodes + \node[io] (a) {$a$}; + \node[io, below=1.6cm of a] (b) {$b$}; + % Output nodes + \node[io, right=5cm of a] (out1) {$\mathrm{barrett}(a + b)$}; + \node[io, right=5cm of b] (out2) {$\mathit{fqmul}(b - a,\;\zeta)$}; + % Arrows + \draw[->,thick,mlkblue] (a) -- (out1); + \draw[->,thick,mlkblue] (a) -- (out2); + \draw[->,thick,mlkblue] (b) -- (out1); + \draw[->,thick,mlkblue] (b) -- node[below,pos=0.4] {$\times\,\zeta$} (out2); +\end{tikzpicture} +\caption{Inverse NTT butterfly: the sum is Barrett-reduced to keep + coefficients bounded by~$q$, while the difference is + Montgomery-multiplied by the twiddle factor~$\zeta$.} +\label{fig:invntt-butterfly} +\end{figure} +\ + +subsection \Definitions\ + +text \The inverse NTT butterfly applies Barrett reduction to the sum + and Montgomery multiplication to the difference.\ + +definition invntt_butterfly_int :: \int \ nat \ nat \ int list \ int list\ where + \invntt_butterfly_int zeta j blen cs \ + let t = cs ! j; + cs' = cs[j := barrett_reduce_int (t + cs ! (j + blen))] + in cs'[j + blen := fqmul_int (cs ! (j + blen) - t) zeta]\ + +fun invntt_inner_loop_int :: \int \ nat \ nat \ nat \ int list \ int list\ where + \invntt_inner_loop_int zeta off blen 0 cs = cs\ +| \invntt_inner_loop_int zeta off blen (Suc cnt) cs = + invntt_inner_loop_int zeta (Suc off) blen cnt + (invntt_butterfly_int zeta off blen cs)\ + +fun invntt_middle_loop_int :: \nat \ nat \ nat \ nat \ int list \ nat \ int list\ where + \invntt_middle_loop_int k blen 0 num_blocks cs = (k, cs)\ +| \invntt_middle_loop_int k blen (Suc remaining) num_blocks cs = + (let block = num_blocks - Suc remaining; + off = block * (2 * blen); + zeta = zetas_int ! k; + cs' = invntt_inner_loop_int zeta off blen blen cs + in invntt_middle_loop_int (k - 1) blen remaining num_blocks cs')\ + +fun invntt_outer_loop_int :: \nat \ int list \ int list\ where + \invntt_outer_loop_int 0 cs = cs\ +| \invntt_outer_loop_int (Suc n) cs = + (let layer = Suc n; + blen = 2 ^ (8 - layer); + k = 2 ^ layer - 1; + num_blocks = 2 ^ (layer - 1); + (_, cs') = invntt_middle_loop_int k blen num_blocks num_blocks cs + in invntt_outer_loop_int n cs')\ + +text \Full inverse NTT with Montgomery pre-scaling by 1441.\ + +definition poly_invntt_tomont_int :: \int_poly \ int_poly\ where + \poly_invntt_tomont_int cs \ + invntt_outer_loop_int 7 (List.map (\c. fqmul_int c 1441) cs)\ + +text \Convenience: single inverse NTT layer by layer number.\ + +definition invntt_layer_int :: \nat \ int list \ int list\ where + \invntt_layer_int layer cs \ + snd (invntt_middle_loop_int (2^layer - 1) (2^(8 - layer)) + (2^(layer - 1)) (2^(layer - 1)) cs)\ + +subsection \Inner Loop Properties\ + +text \Structural lemmas for the inverse NTT inner loop: length preservation, + snoc decomposition, and per-position value characterisation.\ + +(*<*) +lemma invntt_butterfly_int_length: + shows \length (invntt_butterfly_int z j blen cs) = length cs\ +unfolding invntt_butterfly_int_def Let_def by simp + +lemma invntt_inner_loop_int_snoc: + shows \invntt_inner_loop_int z off blen (Suc m) cs = + invntt_butterfly_int z (off + m) blen (invntt_inner_loop_int z off blen m cs)\ +proof (induct m arbitrary: off cs) + case 0 + then show ?case by simp +next + case (Suc m) + have \invntt_inner_loop_int z off blen (Suc (Suc m)) cs = + invntt_inner_loop_int z (Suc off) blen (Suc m) (invntt_butterfly_int z off blen cs)\ + by simp + also have \\ = invntt_butterfly_int z (Suc off + m) blen + (invntt_inner_loop_int z (Suc off) blen m (invntt_butterfly_int z off blen cs))\ + by (rule Suc) + also have \Suc off + m = off + Suc m\ + by simp + also have \invntt_inner_loop_int z (Suc off) blen m (invntt_butterfly_int z off blen cs) + = invntt_inner_loop_int z off blen (Suc m) cs\ + by simp + finally show ?case + . +qed + +lemma invntt_inner_loop_int_length: + shows \length (invntt_inner_loop_int z off blen cnt cs) = length cs\ +by (induction cnt arbitrary: off cs) (simp_all add: invntt_butterfly_int_length) + +lemma invntt_butterfly_int_nth_other: + assumes \i \ j\ + and \i \ j + blen\ + shows \invntt_butterfly_int zeta j blen cs ! i = cs ! i\ +unfolding invntt_butterfly_int_def Let_def using assms by simp + +lemma invntt_inner_loop_int_nth_unchanged: + assumes \i \ {off.. + and \i \ {off+blen.. + shows \invntt_inner_loop_int z off blen cnt cs ! i = cs ! i\ +using assms proof (induct cnt arbitrary: off cs) + case 0 then show ?case by simp +next + case (Suc cnt) + from Suc.prems have \i \ off\ \i \ off + blen\ + by auto + from Suc.prems have ih1: \i \ {Suc off.. + by auto + from Suc.prems have ih2: \i \ {Suc off+blen.. + by auto + have \invntt_inner_loop_int z off blen (Suc cnt) cs ! i = + invntt_inner_loop_int z (Suc off) blen cnt (invntt_butterfly_int z off blen cs) ! i\ + by simp + also have \\ = (invntt_butterfly_int z off blen cs) ! i\ + by (rule Suc.hyps[OF ih1 ih2]) + also have \\ = cs ! i\ + by (rule invntt_butterfly_int_nth_other[OF \i \ off\ \i \ off + blen\]) + finally show ?case . +qed + +lemma invntt_inner_loop_int_low_val: + assumes \m < cnt\ + and \cnt \ blen\ + and \off + 2 * blen \ length cs\ + shows \invntt_inner_loop_int z off blen cnt cs ! (off + m) = + barrett_reduce_int (cs ! (off + m) + cs ! (off + m + blen))\ +using assms proof (induct cnt arbitrary: m) + case 0 then show ?case by simp +next + case (Suc cnt) + have snoc: \invntt_inner_loop_int z off blen (Suc cnt) cs = + invntt_butterfly_int z (off + cnt) blen (invntt_inner_loop_int z off blen cnt cs)\ + by (rule invntt_inner_loop_int_snoc) + define prev where \prev = invntt_inner_loop_int z off blen cnt cs\ + have len_prev: \length prev = length cs\ unfolding prev_def + by (rule invntt_inner_loop_int_length) + have p1: \prev ! (off + cnt) = cs ! (off + cnt)\ + unfolding prev_def by (rule invntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + have p2: \prev ! (off + cnt + blen) = cs ! (off + cnt + blen)\ + unfolding prev_def by (rule invntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + show ?case + proof (cases \m = cnt\) + case True + have \off + cnt + blen < length prev\ + using Suc.prems len_prev by simp + have ne: \off + cnt + blen \ off + cnt\ + using Suc.prems by simp + show ?thesis using True snoc p1 p2 ne \off + cnt + blen < length prev\ + by (simp add: invntt_butterfly_int_def Let_def prev_def[symmetric]) + next + case False + with Suc.prems have \m < cnt\ by simp + have \off + m \ off + cnt\ using \m < cnt\ by simp + have \off + m \ off + cnt + blen\ using \m < cnt\ Suc.prems by simp + have \invntt_inner_loop_int z off blen (Suc cnt) cs ! (off + m) = prev ! (off + m)\ + using snoc invntt_butterfly_int_nth_other[OF \off + m \ off + cnt\ \off + m \ off + cnt + blen\] + by (simp add: prev_def) + also have \\ = barrett_reduce_int (cs ! (off + m) + cs ! (off + m + blen))\ + unfolding prev_def by (rule Suc.hyps[OF \m < cnt\]) (use Suc.prems in auto) + finally show ?thesis . + qed +qed + +lemma invntt_inner_loop_int_high_val: + assumes \m < cnt\ \cnt \ blen\ \off + 2 * blen \ length cs\ + shows \invntt_inner_loop_int z off blen cnt cs ! (off + m + blen) = + fqmul_int (cs ! (off + m + blen) - cs ! (off + m)) z\ +using assms proof (induct cnt arbitrary: m) + case 0 then show ?case by simp +next + case (Suc cnt) + have snoc: \invntt_inner_loop_int z off blen (Suc cnt) cs = + invntt_butterfly_int z (off + cnt) blen (invntt_inner_loop_int z off blen cnt cs)\ + by (rule invntt_inner_loop_int_snoc) + define prev where \prev = invntt_inner_loop_int z off blen cnt cs\ + have len_prev: \length prev = length cs\ unfolding prev_def + by (rule invntt_inner_loop_int_length) + have p1: \prev ! (off + cnt) = cs ! (off + cnt)\ + unfolding prev_def by (rule invntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + have p2: \prev ! (off + cnt + blen) = cs ! (off + cnt + blen)\ + unfolding prev_def by (rule invntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + show ?case + proof (cases \m = cnt\) + case True + have len1: \off + cnt < length prev\ using Suc.prems len_prev by simp + have len2: \off + cnt + blen < length prev\ using Suc.prems len_prev by simp + have ne: \off + cnt + blen \ off + cnt\ using Suc.prems by simp + show ?thesis using True snoc p1 p2 ne len1 len2 + by (simp add: invntt_butterfly_int_def Let_def prev_def[symmetric]) + next + case False + with Suc.prems have mc: \m < cnt\ by simp + have ne1: \off + m + blen \ off + cnt\ using Suc.prems by simp + have ne2: \off + m + blen \ off + cnt + blen\ using mc by simp + have \invntt_inner_loop_int z off blen (Suc cnt) cs ! (off + m + blen) = prev ! (off + m + blen)\ + using snoc invntt_butterfly_int_nth_other[OF ne1 ne2] by (simp add: prev_def) + also have \\ = fqmul_int (cs ! (off + m + blen) - cs ! (off + m)) z\ + unfolding prev_def by (rule Suc.hyps[OF mc]) (use Suc.prems in auto) + finally show ?thesis . + qed +qed +(*>*) + +subsection \Middle Loop Properties\ + +text \Structural lemmas for the middle and outer loops: length preservation, + k-index tracking, snoc decomposition, and position-unchanged results.\ + +(*<*) +lemma invntt_middle_loop_int_length: + shows \length (snd (invntt_middle_loop_int k blen rem nb cs)) = length cs\ +by (induction rem arbitrary: k cs) (auto simp: case_prod_beta Let_def invntt_inner_loop_int_length) + +lemma invntt_layer_int_length: + shows \length (invntt_layer_int l cs) = length cs\ +unfolding invntt_layer_int_def by (rule invntt_middle_loop_int_length) + +lemma invntt_outer_loop_int_length: + shows \length (invntt_outer_loop_int n cs) = length cs\ +by (induction n arbitrary: cs) (auto simp: case_prod_beta Let_def invntt_middle_loop_int_length) + +text \Inverse NTT middle loop: k-index tracking (k decrements).\ + +lemma invntt_middle_loop_int_fst: + shows \fst (invntt_middle_loop_int k blen rem nb cs) = k - rem\ +by (induction rem arbitrary: k cs) (auto simp: case_prod_beta Let_def) + +text \Snoc decomposition for the inverse NTT middle loop: processing @{term \Suc j\} blocks + equals processing @{term j} blocks then applying one more inner loop + at the next block offset. Analogous to @{thm ntt_middle_loop_int_snoc_gen} + but with k decrementing instead of incrementing.\ + +lemma invntt_middle_loop_int_snoc_gen: + shows \snd (invntt_middle_loop_int k blen (Suc j) (s + Suc j) cs) = + invntt_inner_loop_int (zetas_int ! (k - j)) ((s + j) * (2 * blen)) blen blen + (snd (invntt_middle_loop_int k blen j (s + j) cs))\ +proof (induct j arbitrary: k s cs) + case 0 + then show ?case + by (simp add: case_prod_beta Let_def) +next + case (Suc j) + \ \Unfold one step: processes block at offset s\ + have lhs: \invntt_middle_loop_int k blen (Suc (Suc j)) (s + Suc (Suc j)) cs = + invntt_middle_loop_int (k - 1) blen (Suc j) (s + Suc (Suc j)) + (invntt_inner_loop_int (zetas_int ! k) (s * (2 * blen)) blen blen cs)\ + by (simp add: case_prod_beta Let_def) + define cs' where + \cs' = invntt_inner_loop_int (zetas_int ! k) (s * (2 * blen)) blen blen cs\ + \ \Rewrite @{term \s + Suc (Suc j)\} = @{term \Suc s + Suc j\}\ + have \snd (invntt_middle_loop_int (k - 1) blen (Suc j) (Suc s + Suc j) cs') = + invntt_inner_loop_int (zetas_int ! ((k - 1) - j)) ((Suc s + j) * (2 * blen)) blen blen + (snd (invntt_middle_loop_int (k - 1) blen j (Suc s + j) cs'))\ + by (rule Suc) + \ \RHS unfolds the same way\ + moreover have \invntt_middle_loop_int k blen (Suc j) (s + Suc j) cs = + invntt_middle_loop_int (k - 1) blen j (s + Suc j) + (invntt_inner_loop_int (zetas_int ! k) (s * (2 * blen)) blen blen cs)\ + by (simp add: case_prod_beta Let_def) + ultimately show ?case + using lhs by (simp add: cs'_def) +qed + +corollary invntt_middle_loop_int_snoc: + shows \snd (invntt_middle_loop_int k blen (Suc j) (Suc j) cs) = + invntt_inner_loop_int (zetas_int ! (k - j)) (j * (2 * blen)) blen blen + (snd (invntt_middle_loop_int k blen j j cs))\ +using invntt_middle_loop_int_snoc_gen[where s=0] by simp + +lemma invntt_middle_loop_int_nth_unchanged: + assumes \j * (2 * blen) \ i\ + shows \snd (invntt_middle_loop_int k blen j j cs) ! i = cs ! i\ +using assms proof (induct j arbitrary: k cs) + case 0 + then show ?case by simp +next + case (Suc j) + have snoc: \snd (invntt_middle_loop_int k blen (Suc j) (Suc j) cs) = + invntt_inner_loop_int (zetas_int ! (k - j)) (j * (2 * blen)) blen blen + (snd (invntt_middle_loop_int k blen j j cs))\ + by (rule invntt_middle_loop_int_snoc) + define prev where \prev = snd (invntt_middle_loop_int k blen j j cs)\ + from Suc.prems have \j * (2 * blen) \ i\ by simp + hence ih: \prev ! i = cs ! i\ + unfolding prev_def by (rule Suc.hyps) + from Suc.prems have \i \ {j * (2 * blen).. by auto + moreover from Suc.prems have \i \ {j * (2 * blen) + blen.. by auto + ultimately have \invntt_inner_loop_int (zetas_int ! (k - j)) (j * (2 * blen)) blen blen prev ! i = prev ! i\ + by (rule invntt_inner_loop_int_nth_unchanged) + with ih snoc show ?case by (simp add: prev_def) +qed +(*>*) + +subsection \Bound Propagation\ + +text \fqmul bound when the second argument is bounded by the max zetas value. + Since @{term \\b\ \ 1659\} and @{term \\a\ < 32768\}, the product + @{term \\a * b\ < 32768 * MLKEM_Q\} and @{thm fqmul_int_bound_Q} applies.\ + +lemma fqmul_prescale_bound: + assumes \\b\ \ 1659\ and \\a\ \ 32768\ + shows \\fqmul_int a b\ < MLKEM_Q\ +proof (rule fqmul_int_bound_Q) + have \\a * b\ = \a\ * \b\\ + by (rule abs_mult) + also have \\ \ 32768 * 1659\ + proof (rule mult_mono) + show \\a\ \ 32768\ + by (rule assms(2)) + show \\b\ \ 1659\ + by (rule assms(1)) + qed auto + also have \\ < 32768 * MLKEM_Q\ + by simp + finally show \\a * b\ < 32768 * MLKEM_Q\ + . +qed + +lemma fqmul_prescale_bound_sint: + fixes a :: \16 sword\ + assumes \\b\ \ 1659\ + shows \\fqmul_int (sint a) b\ < MLKEM_Q\ +proof (rule fqmul_prescale_bound[OF assms]) + have \sint a \ -(2^(size a - 1))\ and \sint a < 2^(size a - 1)\ + using sint_range_size[of a] by auto + thus \\sint a\ \ 32768\ + by (auto simp: word_size) +qed + +text \Coefficient bound preservation through inverse NTT butterfly and loops.\ + +lemma invntt_butterfly_int_coeff_bound: + assumes \coeff_bound MLKEM_Q cs\ + and \j + blen < length cs\ + and \j < length cs\ + and \\zeta\ \ 1659\ + shows \coeff_bound MLKEM_Q (invntt_butterfly_int zeta j blen cs)\ +proof - + from assms(1,2,3) have j_bound: \\cs ! j\ < MLKEM_Q\ + and jb_bound: \\cs ! (j + blen)\ < MLKEM_Q\ + unfolding coeff_bound_def by auto + have sum_range: \-32768 \ cs ! j + cs ! (j + blen)\ + \cs ! j + cs ! (j + blen) \ 32767\ + using j_bound jb_bound by linarith+ + have barrett_bound: \\barrett_reduce_int (cs ! j + cs ! (j + blen))\ < MLKEM_Q\ + by (rule barrett_reduce_int_abs_bound_int[OF sum_range]) + have diff_bound: \\cs ! (j + blen) - cs ! j\ \ 32768\ + using j_bound jb_bound by linarith + have fqmul_bound: \\fqmul_int (cs ! (j + blen) - cs ! j) zeta\ < MLKEM_Q\ + by (rule fqmul_prescale_bound[OF assms(4) diff_bound]) + show ?thesis + unfolding invntt_butterfly_int_def Let_def coeff_bound_def + invntt_butterfly_int_length + using barrett_bound fqmul_bound assms(1,2,3) + by (auto simp: nth_list_update coeff_bound_def) +qed + +lemma invntt_inner_loop_int_coeff_bound: + assumes \coeff_bound MLKEM_Q cs\ + and \off + 2 * blen \ length cs\ + and \cnt \ blen\ + and \\zeta\ \ 1659\ + shows \coeff_bound MLKEM_Q (invntt_inner_loop_int zeta off blen cnt cs)\ +proof - + from assms(2,3) have \off + cnt + blen \ length cs\ + by linarith + thus ?thesis using assms(1) + proof (induct cnt arbitrary: off cs) + case 0 + then show ?case by simp + next + case (Suc cnt) + have off_lt: \off < length cs\ + using Suc.prems(1) by linarith + have off_blen_lt: \off + blen < length cs\ + using Suc.prems(1) by linarith + have cb': \coeff_bound MLKEM_Q (invntt_butterfly_int zeta off blen cs)\ + by (rule invntt_butterfly_int_coeff_bound[OF Suc.prems(2) off_blen_lt off_lt assms(4)]) + have len': \Suc off + cnt + blen \ length (invntt_butterfly_int zeta off blen cs)\ + using Suc.prems(1) by (simp add: invntt_butterfly_int_length) + show ?case + by simp (rule Suc.hyps[OF len' cb']) + qed +qed + +lemma invntt_middle_loop_int_coeff_bound: + assumes \coeff_bound MLKEM_Q cs\ + and \1 \ l\ + and \l \ 7\ + and \length cs = MLKEM_N\ + and \j \ 2 ^ (l - 1)\ + shows \coeff_bound MLKEM_Q (snd (invntt_middle_loop_int (2^l - 1) (2^(8 - l)) j j cs))\ +using assms(5) proof (induct j) + case 0 + show ?case using assms(1) by simp +next + case (Suc j) + define blen where + \blen = (2::nat) ^ (8 - l)\ + define k0 where + \k0 = (2::nat) ^ l - 1\ + define nb where + \nb = (2::nat) ^ (l - 1)\ + define prev where + \prev = snd (invntt_middle_loop_int k0 blen j j cs)\ + define z where + \z = zetas_int ! (k0 - j)\ + have snoc: \snd (invntt_middle_loop_int k0 blen (Suc j) (Suc j) cs) = + invntt_inner_loop_int z (j * (2 * blen)) blen blen prev\ + unfolding z_def prev_def by (rule invntt_middle_loop_int_snoc) + from Suc.prems have j_le: \j \ 2^(l-1)\ + by simp + hence ih: \coeff_bound MLKEM_Q prev\ + unfolding prev_def k0_def blen_def by (rule Suc.hyps) + have len_prev: \length prev = length cs\ + unfolding prev_def k0_def blen_def by (rule invntt_middle_loop_int_length) + have block_fits: \j * (2 * blen) + 2 * blen \ length prev\ + proof - + have \(l - 1) + (8 - l) = 7\ + using assms(2,3) by simp + hence \(2::nat) ^ (l-1) * 2^(8-l) = 2^7\ + by (metis power_add) + hence nb_2blen: \nb * (2 * blen) = 256\ + unfolding nb_def blen_def by simp + from Suc.prems have \Suc j \ nb\ + unfolding nb_def by simp + hence \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1) + hence \j * (2 * blen) + 2 * blen \ 256\ + using nb_2blen by simp + thus ?thesis + using assms(4) len_prev by simp + qed + have z_bound: \\z\ \ 1659\ + proof - + have \k0 < 128\ + proof - + have \(2::nat)^l \ 2^7\ + using assms(3) by (intro power_increasing) simp_all + thus ?thesis + unfolding k0_def by simp + qed + hence \k0 - j < 128\ + using diff_le_self le_less_trans by blast + thus ?thesis + unfolding z_def by (rule zetas_int_abs_bound) + qed + have \coeff_bound MLKEM_Q (invntt_inner_loop_int z (j * (2 * blen)) blen blen prev)\ + by (rule invntt_inner_loop_int_coeff_bound[OF ih block_fits le_refl z_bound]) + thus ?case + unfolding k0_def[symmetric] blen_def[symmetric] using snoc by simp +qed + +subsection \Layer and Outer Loop\ + +theorem invntt_layer_int_coeff_bound: + assumes \coeff_bound MLKEM_Q cs\ + and \1 \ l\ + and \l \ 7\ + and \length cs = MLKEM_N\ + shows \coeff_bound MLKEM_Q (invntt_layer_int l cs)\ +unfolding invntt_layer_int_def by (rule invntt_middle_loop_int_coeff_bound[OF assms]) simp + +subsection \Overflow Safety\ + +text \Overflow safety: if all coefficients are bounded by @{const MLKEM_Q}, sums and + differences of two coefficients fit in 16-bit signed range.\ + +lemma invntt_coeff_bound_sum_bounds: + assumes \coeff_bound MLKEM_Q cs\ + and \j < length cs\ + and \j + blen < length cs\ + shows \-32768 \ cs ! j + cs ! (j + blen)\ + and \cs ! j + cs ! (j + blen) \ 32767\ + and \-32768 \ cs ! (j + blen) - cs ! j\ + and \cs ! (j + blen) - cs ! j \ 32767\ +proof - + from assms have j_bound: \\cs ! j\ < MLKEM_Q\ + and jb_bound: \\cs ! (j + blen)\ < MLKEM_Q\ + unfolding coeff_bound_def by auto + show \-32768 \ cs ! j + cs ! (j + blen)\ + and \cs ! j + cs ! (j + blen) \ 32767\ + and \-32768 \ cs ! (j + blen) - cs ! j\ + and \cs ! (j + blen) - cs ! j \ 32767\ + using j_bound jb_bound by linarith+ +qed + +(*<*) +end +(*>*) diff --git a/proofs/isabelle/MLKEM_NTT_Spec.thy b/proofs/isabelle/MLKEM_NTT_Spec.thy new file mode 100644 index 0000000000..d1da56638a --- /dev/null +++ b/proofs/isabelle/MLKEM_NTT_Spec.thy @@ -0,0 +1,856 @@ +(*<*) +theory MLKEM_NTT_Spec + imports MLKEM_Zetas +begin +(*>*) + +text \ + Abstract specification of the Number Theoretic Transform (NTT) matching + the butterfly\,\\\\,inner-loop\,\\\\,middle-loop\,\\\\,outer-loop + structure of the C implementation. All operations on unbounded integers; + overflow analysis is separate. +\ + +section \NTT Specification\ + +text_raw \ +\begin{figure}[ht] +\centering +\begin{tikzpicture}[>=Stealth, node distance=2.8cm and 3.5cm, + every node/.style={font=\small}, + io/.style={fill=mlklightblue, draw=mlkblue, rounded corners=3pt, + inner sep=4pt}] + % Input nodes + \node[io] (a) {$a$}; + \node[io, below=1.6cm of a] (b) {$b$}; + % Output nodes + \node[io, right=4cm of a] (out1) {$a + \zeta \cdot b$}; + \node[io, right=4cm of b] (out2) {$a - \zeta \cdot b$}; + % Arrows + \draw[->,thick,mlkblue] (a) -- (out1); + \draw[->,thick,mlkblue] (a) -- (out2); + \draw[->,thick,mlkblue] (b) -- node[above,pos=0.4] {$\times\,\zeta$} (out1); + \draw[->,thick,mlkblue] (b) -- node[below,pos=0.4] {$\times\,(-\zeta)$} (out2); +\end{tikzpicture} +\caption{Forward NTT butterfly: given inputs $a$ and $b$ with twiddle + factor $\zeta$, the outputs are $a + \mathit{fqmul}(b, \zeta)$ and + $a - \mathit{fqmul}(b, \zeta)$, where \emph{fqmul} denotes + Montgomery multiplication modulo~$q$.} +\label{fig:ntt-butterfly} +\end{figure} +\ + +subsection \Definitions\ + +text \Structural abstract NTT following the C implementation. + All operations on unbounded integers; overflow analysis is separate.\ + +definition ntt_butterfly_int :: \int \ nat \ nat \ int list \ int list\ where + \ntt_butterfly_int zeta j blen cs \ + let t = fqmul_int (cs ! (j + blen)) zeta in + (cs[j + blen := cs ! j - t])[j := cs ! j + t]\ + +fun ntt_inner_loop_int :: \int \ nat \ nat \ nat \ int list \ int list\ where + \ntt_inner_loop_int zeta off blen 0 cs = cs\ +| \ntt_inner_loop_int zeta off blen (Suc cnt) cs = + ntt_inner_loop_int zeta (Suc off) blen cnt + (ntt_butterfly_int zeta off blen cs)\ + +fun ntt_middle_loop_int :: \nat \ nat \ nat \ nat \ int list \ nat \ int list\ where + \ntt_middle_loop_int k blen 0 num_blocks cs = (k, cs)\ +| \ntt_middle_loop_int k blen (Suc remaining) num_blocks cs = + (let block = num_blocks - Suc remaining; + off = block * (2 * blen); + zeta = zetas_int ! k; + cs' = ntt_inner_loop_int zeta off blen blen cs + in ntt_middle_loop_int (Suc k) blen remaining num_blocks cs')\ + +fun ntt_outer_loop_int :: \nat \ nat \ int list \ int list\ where + \ntt_outer_loop_int k 0 cs = cs\ +| \ntt_outer_loop_int k (Suc layer_rem) cs = + (let blen = 2 ^ (Suc layer_rem); + num_blocks = 2 ^ (6 - layer_rem); + (k', cs') = ntt_middle_loop_int k blen num_blocks num_blocks cs + in ntt_outer_loop_int k' layer_rem cs')\ + +definition poly_ntt_int :: \int_poly \ int_poly\ where + \poly_ntt_int cs \ ntt_outer_loop_int 1 7 cs\ + +text \Convenience: single NTT layer by layer number.\ + +definition ntt_layer_int :: \nat \ int list \ int list\ where + \ntt_layer_int layer cs \ + snd (ntt_middle_loop_int (2^(layer - 1)) (2^(8 - layer)) + (2^(layer - 1)) (2^(layer - 1)) cs)\ + +subsection \Inner Loop Properties\ + +(*<*) +lemma ntt_butterfly_int_length: + shows \length (ntt_butterfly_int z j blen cs) = length cs\ +unfolding ntt_butterfly_int_def Let_def by simp + +lemma ntt_inner_loop_int_length: + shows \length (ntt_inner_loop_int z off blen cnt cs) = length cs\ +by (induct cnt arbitrary: off cs) (simp_all add: ntt_butterfly_int_length) + +text \Snoc decomposition: processing @{term \Suc m\} butterflies equals + processing @{term m} then applying one more butterfly at position + @{term \off + m\}.\ + +lemma ntt_inner_loop_int_snoc: + shows \ntt_inner_loop_int z off blen (Suc m) cs = + ntt_butterfly_int z (off + m) blen (ntt_inner_loop_int z off blen m cs)\ +proof (induct m arbitrary: off cs) + case 0 + then show ?case by simp +next + case (Suc m) + have \ntt_inner_loop_int z off blen (Suc (Suc m)) cs = + ntt_inner_loop_int z (Suc off) blen (Suc m) (ntt_butterfly_int z off blen cs)\ + by simp + also have \\ = ntt_butterfly_int z (Suc off + m) blen + (ntt_inner_loop_int z (Suc off) blen m (ntt_butterfly_int z off blen cs))\ + by (rule Suc) + also have \Suc off + m = off + Suc m\ + by simp + also have \ntt_inner_loop_int z (Suc off) blen m (ntt_butterfly_int z off blen cs) + = ntt_inner_loop_int z off blen (Suc m) cs\ + by simp + finally show ?case + . +qed + +lemma ntt_butterfly_int_nth_other: + assumes \i \ j\ + and \i \ j + blen\ + shows \ntt_butterfly_int zeta j blen cs ! i = cs ! i\ +unfolding ntt_butterfly_int_def Let_def using assms by simp + +lemma ntt_inner_loop_int_nth_unchanged: + assumes \i \ {off.. + and \i \ {off+blen.. + shows \ntt_inner_loop_int z off blen cnt cs ! i = cs ! i\ +using assms proof (induct cnt arbitrary: off cs) + case 0 then show ?case by simp +next + case (Suc cnt) + from Suc.prems have \i \ off\ \i \ off + blen\ + by auto + from Suc.prems have ih1: \i \ {Suc off.. + by auto + from Suc.prems have ih2: \i \ {Suc off+blen.. + by auto + have \ntt_inner_loop_int z off blen (Suc cnt) cs = + ntt_inner_loop_int z (Suc off) blen cnt (ntt_butterfly_int z off blen cs)\ + by simp + hence \ntt_inner_loop_int z off blen (Suc cnt) cs ! i = + ntt_inner_loop_int z (Suc off) blen cnt (ntt_butterfly_int z off blen cs) ! i\ + by simp + also have \\ = (ntt_butterfly_int z off blen cs) ! i\ + by (rule Suc.hyps[OF ih1 ih2]) + also have \\ = cs ! i\ + by (rule ntt_butterfly_int_nth_other[OF \i \ off\ \i \ off + blen\]) + finally show ?case + . +qed + +lemma ntt_inner_loop_int_low_val: + assumes \m < cnt\ + and \cnt \ blen\ + and \off + 2 * blen \ length cs\ + shows \ntt_inner_loop_int z off blen cnt cs ! (off + m) = + cs ! (off + m) + fqmul_int (cs ! (off + m + blen)) z\ +using assms proof (induct cnt arbitrary: m) + case 0 + then show ?case + by simp +next + case (Suc cnt) + have snoc: \ntt_inner_loop_int z off blen (Suc cnt) cs = + ntt_butterfly_int z (off + cnt) blen (ntt_inner_loop_int z off blen cnt cs)\ + by (rule ntt_inner_loop_int_snoc) + define prev where \prev = ntt_inner_loop_int z off blen cnt cs\ + have len_prev: \length prev = length cs\ unfolding prev_def + by (rule ntt_inner_loop_int_length) + have p1: \prev ! (off + cnt) = cs ! (off + cnt)\ + unfolding prev_def by (rule ntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + have p2: \prev ! (off + cnt + blen) = cs ! (off + cnt + blen)\ + unfolding prev_def by (rule ntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + show ?case + proof (cases \m = cnt\) + case True + have \off + cnt < length prev\ + using Suc.prems len_prev by simp + thus ?thesis using True snoc p1 p2 + by (simp add: ntt_butterfly_int_def Let_def prev_def[symmetric]) + next + case False + with Suc.prems have \m < cnt\ + by simp + have \off + m \ off + cnt\ + using \m < cnt\ by simp + have \off + m \ off + cnt + blen\ + using \m < cnt\ Suc.prems by simp + have \ntt_inner_loop_int z off blen (Suc cnt) cs ! (off + m) = + prev ! (off + m)\ + using snoc ntt_butterfly_int_nth_other[OF \off + m \ off + cnt\ \off + m \ off + cnt + blen\] + by (simp add: prev_def) + also have \\ = cs ! (off + m) + fqmul_int (cs ! (off + m + blen)) z\ + unfolding prev_def by (rule Suc.hyps[OF \m < cnt\]) (use Suc.prems in auto) + finally show ?thesis + . + qed +qed + +lemma ntt_inner_loop_int_high_val: + assumes \m < cnt\ \cnt \ blen\ \off + 2 * blen \ length cs\ + shows \ntt_inner_loop_int z off blen cnt cs ! (off + m + blen) = + cs ! (off + m) - fqmul_int (cs ! (off + m + blen)) z\ +using assms proof (induct cnt arbitrary: m) + case 0 + then show ?case + by simp +next + case (Suc cnt) + have snoc: \ntt_inner_loop_int z off blen (Suc cnt) cs = + ntt_butterfly_int z (off + cnt) blen (ntt_inner_loop_int z off blen cnt cs)\ + by (rule ntt_inner_loop_int_snoc) + define prev where \prev = ntt_inner_loop_int z off blen cnt cs\ + have len_prev: \length prev = length cs\ + unfolding prev_def by (rule ntt_inner_loop_int_length) + have p1: \prev ! (off + cnt) = cs ! (off + cnt)\ + unfolding prev_def by (rule ntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + have p2: \prev ! (off + cnt + blen) = cs ! (off + cnt + blen)\ + unfolding prev_def by (rule ntt_inner_loop_int_nth_unchanged) (use Suc.prems in auto) + show ?case + proof (cases \m = cnt\) + case True + have len1: \off + cnt < length prev\ + using Suc.prems len_prev by simp + have len2: \off + cnt + blen < length prev\ + using Suc.prems len_prev by simp + have ne: \off + cnt + blen \ off + cnt\ + using Suc.prems by simp + thus ?thesis + using True snoc p1 p2 ne len1 len2 by (simp add: ntt_butterfly_int_def Let_def prev_def[symmetric]) + next + case False + with Suc.prems have mc: \m < cnt\ + by simp + have ne1: \off + m + blen \ off + cnt\ + using Suc.prems by simp + have ne2: \off + m + blen \ off + cnt + blen\ + using mc by simp + have \ntt_inner_loop_int z off blen (Suc cnt) cs ! (off + m + blen) = + prev ! (off + m + blen)\ + using snoc ntt_butterfly_int_nth_other[OF ne1 ne2] by (simp add: prev_def) + also have \\ = cs ! (off + m) - fqmul_int (cs ! (off + m + blen)) z\ + unfolding prev_def by (rule Suc.hyps[OF \m < cnt\]) (use Suc.prems in auto) + finally show ?thesis + . + qed +qed + +(*>*) + +subsection \Coefficient Bounds and Overflow Predicates\ + +text \The NTT butterfly adds and subtracts values, so coefficient magnitudes + grow with each layer. To guarantee that intermediate results stay within + the 16-bit machine word, we define a coefficient-bound predicate and a + no-overflow predicate for individual butterfly steps. The C verification + uses these to thread overflow safety through all seven layers.\ + +text \Coefficient bound predicate: all coefficients have absolute value less than B.\ + +definition coeff_bound :: \int \ int list \ bool\ where + \coeff_bound B cs \ \i < length cs. \cs ! i\ < B\ + +text \No-overflow predicate for the inner loop. + Ensures each butterfly step produces values in 16-bit range.\ + +definition ntt_inner_no_overflow :: \int \ nat \ nat \ nat \ int list \ bool\ where + \ntt_inner_no_overflow zeta off blen cnt acs \ + (\m < cnt. + let cs' = ntt_inner_loop_int zeta off blen m acs; + j = off + m; t = fqmul_int (cs' ! (j + blen)) zeta + in - 32768 \ cs' ! j + t \ cs' ! j + t \ 32767 \ + - 32768 \ cs' ! j - t \ cs' ! j - t \ 32767)\ + +text \No-overflow predicate for a full NTT layer.\ + +definition ntt_layer_no_overflow :: \nat \ int list \ bool\ where + \ntt_layer_no_overflow l acs \ + (let k0 = 2 ^ (l - 1); blen = 2 ^ (8 - l) + in \j < 2 ^ (l - 1). + ntt_inner_no_overflow (zetas_int ! (k0 + j)) (j * 2 * blen) blen blen + (snd (ntt_middle_loop_int k0 blen j j acs)))\ + +lemma ntt_layer_no_overflow_block: + assumes \ntt_layer_no_overflow l acs\ \j < 2 ^ (l - 1)\ + shows \ntt_inner_no_overflow (zetas_int ! (2 ^ (l - 1) + j)) + (j * 2 * 2 ^ (8 - l)) (2 ^ (8 - l)) (2 ^ (8 - l)) + (snd (ntt_middle_loop_int (2 ^ (l - 1)) (2 ^ (8 - l)) j j acs))\ +using assms unfolding ntt_layer_no_overflow_def Let_def by auto + +text \Key fqmul bound: if the product is small enough, + the Montgomery-reduced result is strictly less than Q.\ + +lemma fqmul_int_bound_Q: + assumes \\a * b\ < 32768 * MLKEM_Q\ + shows \\fqmul_int a b\ < MLKEM_Q\ +proof - + define t where \t = signed_take_bit 15 ((a * b) mod 65536 * 62209 mod 65536)\ + have result_eq: \fqmul_int a b * 65536 = a * b - t * 3329\ + unfolding fqmul_int_def t_def by (rule montgomery_reduce_int_result_eq) + have t_lb: \t \ -32768\ + proof - + have \t \ -(2^15)\ + unfolding t_def by (rule signed_take_bit_int_greater_eq_minus_exp) + thus ?thesis + by simp + qed + have t_ub: \t < 32768\ + proof - + have \t < 2^15\ + unfolding t_def by (rule signed_take_bit_int_less_exp) + thus ?thesis + by simp + qed + from assms have \a * b < 32768 * 3329\ \a * b > -(32768 * 3329)\ + by (auto simp: abs_less_iff) + have \fqmul_int a b * 65536 < 65536 * 3329\ + using result_eq \a * b < _\ t_lb by linarith + moreover have \fqmul_int a b * 65536 > -(65536 * 3329)\ + using result_eq \a * b > _\ t_ub by linarith + ultimately show ?thesis + by simp +qed + +(*<*) +subsection \Middle Loop Properties\ + +lemma ntt_middle_loop_int_fst: + shows \fst (ntt_middle_loop_int k blen rem nb cs) = k + rem\ +by (induct rem arbitrary: k cs) (auto simp: case_prod_beta Let_def) + +lemma ntt_middle_loop_int_length: + shows \length (snd (ntt_middle_loop_int k blen rem nb cs)) = length cs\ +by (induct rem arbitrary: k cs) (auto simp: case_prod_beta Let_def ntt_inner_loop_int_length) + +text \Snoc decomposition for the middle loop: processing @{term \Suc j\} blocks + equals processing @{term j} blocks then applying one more inner loop + at the next block offset.\ + +lemma ntt_middle_loop_int_snoc_gen: + shows \snd (ntt_middle_loop_int k blen (Suc j) (s + Suc j) cs) = + ntt_inner_loop_int (zetas_int ! (k + j)) ((s + j) * (2 * blen)) blen blen + (snd (ntt_middle_loop_int k blen j (s + j) cs))\ +proof (induct j arbitrary: k s cs) + case 0 + then show ?case by (simp add: case_prod_beta Let_def) +next + case (Suc j) + \ \Unfold one step: processes block at offset s\ + have lhs: \ntt_middle_loop_int k blen (Suc (Suc j)) (s + Suc (Suc j)) cs = + ntt_middle_loop_int (Suc k) blen (Suc j) (s + Suc (Suc j)) + (ntt_inner_loop_int (zetas_int ! k) (s * (2 * blen)) blen blen cs)\ + by (simp add: case_prod_beta Let_def) + define cs' where \cs' = ntt_inner_loop_int (zetas_int ! k) (s * (2 * blen)) blen blen cs\ + \ \Rewrite @{term \s + Suc (Suc j)\} = @{term \Suc s + Suc j\}\ + have \snd (ntt_middle_loop_int (Suc k) blen (Suc j) (Suc s + Suc j) cs') = + ntt_inner_loop_int (zetas_int ! (Suc k + j)) ((Suc s + j) * (2 * blen)) blen blen + (snd (ntt_middle_loop_int (Suc k) blen j (Suc s + j) cs'))\ + by (rule Suc) + \ \RHS unfolds the same way\ + moreover have \ntt_middle_loop_int k blen (Suc j) (s + Suc j) cs = + ntt_middle_loop_int (Suc k) blen j (s + Suc j) + (ntt_inner_loop_int (zetas_int ! k) (s * (2 * blen)) blen blen cs)\ + by (simp add: case_prod_beta Let_def) + ultimately show ?case + using lhs by (simp add: cs'_def) +qed + +corollary ntt_middle_loop_int_snoc: + shows \snd (ntt_middle_loop_int k blen (Suc j) (Suc j) cs) = + ntt_inner_loop_int (zetas_int ! (k + j)) (j * (2 * blen)) blen blen + (snd (ntt_middle_loop_int k blen j j cs))\ +using ntt_middle_loop_int_snoc_gen[where s=0] by simp + +(*>*) + +subsection \Bound Propagation Through NTT Layers\ + +lemma coeff_bound_mono: + assumes \coeff_bound B cs\ \B \ B'\ + shows \coeff_bound B' cs\ +using assms unfolding coeff_bound_def by (meson order_less_le_trans) + +lemma ntt_middle_loop_int_nth_unchanged: + assumes \j * (2 * blen) \ i\ + shows \snd (ntt_middle_loop_int k blen j j cs) ! i = cs ! i\ +using assms proof (induct j arbitrary: k cs) + case 0 + then show ?case + by simp +next + case (Suc j) + have snoc: \snd (ntt_middle_loop_int k blen (Suc j) (Suc j) cs) = + ntt_inner_loop_int (zetas_int ! (k + j)) (j * (2 * blen)) blen blen + (snd (ntt_middle_loop_int k blen j j cs))\ + by (rule ntt_middle_loop_int_snoc) + define prev where \prev = snd (ntt_middle_loop_int k blen j j cs)\ + from Suc.prems have \j * (2 * blen) \ i\ + by simp + hence ih: \prev ! i = cs ! i\ + unfolding prev_def by (rule Suc.hyps) + from Suc.prems have \i \ {j * (2 * blen).. + by auto + moreover from Suc.prems have \i \ {j * (2 * blen) + blen.. + by auto + ultimately have \ntt_inner_loop_int (zetas_int ! (k + j)) (j * (2 * blen)) blen blen prev ! i = prev ! i\ + by (rule ntt_inner_loop_int_nth_unchanged) + with ih snoc show ?case + by (simp add: prev_def) +qed + +lemma ntt_middle_loop_int_coeff_bound: + assumes cb: \coeff_bound (int l * MLKEM_Q) cs\ + and l_ge: \1 \ l\ + and l_le: \l \ 7\ + and len: \length cs = MLKEM_N\ + and j_le: \j \ 2 ^ (l - 1)\ + shows \coeff_bound (int (l + 1) * MLKEM_Q) + (snd (ntt_middle_loop_int (2^(l-1)) (2^(8-l)) j j cs))\ +using j_le proof (induct j) + case 0 + show ?case + using coeff_bound_mono[OF cb] by simp +next + case (Suc j) + define blen where \blen = (2::nat) ^ (8 - l)\ + define k0 where \k0 = (2::nat) ^ (l - 1)\ + define prev where \prev = snd (ntt_middle_loop_int k0 blen j j cs)\ + define z where \z = zetas_int ! (k0 + j)\ + have snoc: \snd (ntt_middle_loop_int k0 blen (Suc j) (Suc j) cs) = + ntt_inner_loop_int z (j * (2 * blen)) blen blen prev\ + unfolding z_def prev_def by (rule ntt_middle_loop_int_snoc) + from Suc.prems have \j \ 2^(l-1)\ + by simp + hence ih: \coeff_bound (int (l + 1) * MLKEM_Q) prev\ + unfolding prev_def k0_def blen_def by (rule Suc.hyps) + have len_prev: \length prev = length cs\ + unfolding prev_def k0_def blen_def by (rule ntt_middle_loop_int_length) + have block_fits: \j * (2 * blen) + 2 * blen \ length cs\ + proof - + have k0_2blen: \k0 * (2 * blen) = 256\ + proof - + have \l - 1 + (8 - l) = 7\ + using l_ge l_le by simp + hence \(2::nat) ^ (l - 1) * 2 ^ (8 - l) = 2 ^ 7\ + by (metis power_add) + thus ?thesis + unfolding k0_def blen_def by simp + qed + from Suc.prems have \Suc j \ k0\ + unfolding k0_def by simp + hence \Suc j * (2 * blen) \ k0 * (2 * blen)\ + by (intro mult_le_mono1) + thus ?thesis + using k0_2blen len by simp + qed + have z_bound: \\z\ \ 1659\ + proof - + have \k0 + j < 128\ + proof - + have \2 * k0 \ 128\ + proof - + have \l - 1 \ 6\ + using l_le by simp + hence \(2::nat) ^ (l - 1) \ 2 ^ 6\ + by (intro power_increasing) simp_all + thus ?thesis + unfolding k0_def by simp + qed + moreover from Suc.prems have \j < k0\ + unfolding k0_def by simp + ultimately show ?thesis + by simp + qed + thus ?thesis + unfolding z_def by (rule zetas_int_abs_bound) + qed + show ?case unfolding snoc k0_def[symmetric] blen_def[symmetric] + unfolding coeff_bound_def ntt_inner_loop_int_length len_prev + proof (intro allI impI) + fix i assume i_lt: \i < length cs\ + have fqmul_bound: \\fqmul_int x z\ < MLKEM_Q\ if \\x\ < int l * MLKEM_Q\ for x + proof (rule fqmul_int_bound_Q) + have \\x * z\ = \x\ * \z\\ + by (rule abs_mult) + also have \\ \ (int l * 3329 - 1) * 1659\ + proof (rule mult_mono) + from that show \\x\ \ int l * 3329 - 1\ + by linarith + qed (use z_bound l_ge in auto) + also have \\ < 32768 * 3329\ + using l_ge l_le by simp + finally show \\x * z\ < 32768 * MLKEM_Q\ + by simp + qed + show \\ntt_inner_loop_int z (j * (2 * blen)) blen blen prev ! i\ < int (l + 1) * MLKEM_Q\ + proof (cases \j * (2 * blen) \ i \ i < j * (2 * blen) + blen\) + case True + then obtain m where i_eq: \i = j * (2 * blen) + m\ and m_lt: \m < blen\ + by (metis le_add_diff_inverse nat_add_left_cancel_less) + have \ntt_inner_loop_int z (j * (2 * blen)) blen blen prev ! i = + prev ! i + fqmul_int (prev ! (i + blen)) z\ + unfolding i_eq by (rule ntt_inner_loop_int_low_val[OF m_lt le_refl]) (use block_fits len_prev in simp) + also have \prev ! i = cs ! i\ + unfolding prev_def by (rule ntt_middle_loop_int_nth_unchanged) (use True in simp) + also have \prev ! (i + blen) = cs ! (i + blen)\ + unfolding prev_def by (rule ntt_middle_loop_int_nth_unchanged) (use True in simp) + finally have val: \ntt_inner_loop_int z (j * (2 * blen)) blen blen prev ! i = + cs ! i + fqmul_int (cs ! (i + blen)) z\ . + have ci: \\cs ! i\ < int l * MLKEM_Q\ + using cb i_lt unfolding coeff_bound_def by auto + have cib: \\cs ! (i + blen)\ < int l * MLKEM_Q\ + using cb block_fits True unfolding coeff_bound_def blen_def by auto + have int_expand: \int (l + 1) * MLKEM_Q = int l * MLKEM_Q + MLKEM_Q\ + by (simp add: algebra_simps) + show ?thesis unfolding val int_expand + using ci fqmul_bound[OF cib] abs_triangle_ineq[of \cs!i\ \fqmul_int (cs!(i+blen)) z\] by linarith + next + case False + show ?thesis + proof (cases \j * (2 * blen) + blen \ i \ i < j * (2 * blen) + 2 * blen\) + case True + then obtain m where i_eq: \i = j * (2 * blen) + m + blen\ and m_lt: \m < blen\ + by (metis add.commute group_cancel.add1 le_add_diff_inverse left_add_twice nat_add_left_cancel_less) + have \ntt_inner_loop_int z (j * (2 * blen)) blen blen prev ! i = + prev ! (j * (2 * blen) + m) - fqmul_int (prev ! i) z\ + unfolding i_eq by (rule ntt_inner_loop_int_high_val[OF m_lt le_refl]) (use block_fits len_prev in simp) + also have \prev ! (j * (2 * blen) + m) = cs ! (j * (2 * blen) + m)\ + unfolding prev_def by (rule ntt_middle_loop_int_nth_unchanged) simp + also have \prev ! i = cs ! i\ + unfolding prev_def by (rule ntt_middle_loop_int_nth_unchanged) (use True in simp) + finally have val: \ntt_inner_loop_int z (j * (2 * blen)) blen blen prev ! i = + cs ! (j * (2 * blen) + m) - fqmul_int (cs ! i) z\ + . + have ci: \\cs ! (j * (2 * blen) + m)\ < int l * MLKEM_Q\ + using cb block_fits m_lt unfolding coeff_bound_def by auto + have cib: \\cs ! i\ < int l * MLKEM_Q\ + using cb i_lt unfolding coeff_bound_def by auto + have int_expand: \int (l + 1) * MLKEM_Q = int l * MLKEM_Q + MLKEM_Q\ + by (simp add: algebra_simps) + show ?thesis unfolding val int_expand + using ci fqmul_bound[OF cib] + abs_triangle_ineq[of \cs!(j*(2*blen)+m)\ \- fqmul_int (cs!i) z\] by linarith + next + case False + with \\ (j * (2 * blen) \ i \ i < j * (2 * blen) + blen)\ + have \i \ {j * (2 * blen).. + \i \ {j * (2 * blen) + blen.. + by auto + hence \ntt_inner_loop_int z (j * (2 * blen)) blen blen prev ! i = prev ! i\ + by (rule ntt_inner_loop_int_nth_unchanged) + moreover have \\prev ! i\ < int (l + 1) * MLKEM_Q\ + using ih i_lt len_prev unfolding coeff_bound_def by auto + ultimately show ?thesis + by simp + qed + qed + qed +qed + +theorem ntt_layer_int_bound: + assumes \coeff_bound (int l * MLKEM_Q) cs\ + and \1 \ l\ + and \l \ 7\ + and \length cs = MLKEM_N\ + shows \coeff_bound (int (l + 1) * MLKEM_Q) (ntt_layer_int l cs)\ +unfolding ntt_layer_int_def by (rule ntt_middle_loop_int_coeff_bound[OF assms]) simp + +(*<*) +lemma ntt_layer_int_length: + shows \length (ntt_layer_int l cs) = length cs\ +unfolding ntt_layer_int_def by (rule ntt_middle_loop_int_length) + +lemma ntt_outer_loop_int_length: + shows \length (ntt_outer_loop_int k n cs) = length cs\ +by (induction n arbitrary: k cs) (auto simp: case_prod_beta Let_def ntt_middle_loop_int_length) +(*>*) + +lemma ntt_outer_loop_int_bound: + assumes \coeff_bound (int l * MLKEM_Q) cs\ + and \1 \ l\ and \l + lr = 8\ + and \length cs = MLKEM_N\ + and \k = 2^(l-1)\ + shows \coeff_bound (int (l + lr) * MLKEM_Q) (ntt_outer_loop_int k lr cs)\ +using assms proof (induct lr arbitrary: l k cs) + case 0 + then show ?case + by simp +next + case (Suc lr) + define blen where \blen = (2::nat) ^ Suc lr\ + define nb where \nb = (2::nat) ^ (6 - lr)\ + obtain k' cs' where mid: \ntt_middle_loop_int k blen nb nb cs = (k', cs')\ + by (cases \ntt_middle_loop_int k blen nb nb cs\) + have cs'_eq: \cs' = snd (ntt_middle_loop_int k blen nb nb cs)\ + using mid by simp + have k'_eq: \k' = k + nb\ + using ntt_middle_loop_int_fst[of k blen nb nb cs] mid by simp + have len_cs': \length cs' = MLKEM_N\ + using cs'_eq ntt_middle_loop_int_length[of k blen nb nb cs] Suc.prems(4) by simp + have l_eq: \Suc lr = 8 - l\ + using Suc.prems(2,3) by simp + have exp_eq: \6 - lr = l - 1\ + proof - + from l_eq Suc.prems(2) have \l - 1 + lr = 6\ + by simp + thus ?thesis + by simp + qed + have nb_eq: \nb = 2^(l - 1)\ + unfolding nb_def using exp_eq by simp + have blen_eq: \blen = 2^(8-l)\ + unfolding blen_def using l_eq by simp + have k_eq: \k = 2^(l-1)\ + using Suc.prems(5) . + have cb': \coeff_bound (int (l+1) * MLKEM_Q) cs'\ + unfolding cs'_eq blen_eq nb_eq k_eq by (rule ntt_middle_loop_int_coeff_bound[OF Suc.prems(1,2)]) + (use Suc.prems in simp_all) + have k'_val: \k' = 2 ^ l\ + proof - + have \Suc (l - 1) = l\ + using Suc.prems(2) by simp + thus ?thesis + unfolding k'_eq k_eq nb_eq by (metis power_Suc mult_2) + qed + have unroll: \ntt_outer_loop_int k (Suc lr) cs = + ntt_outer_loop_int k' lr cs'\ + by (simp only: ntt_outer_loop_int.simps blen_def[symmetric] nb_def[symmetric] + Let_def case_prod_beta mid prod.sel) + show ?case unfolding unroll + apply (subst add_Suc_right) + apply (subst add_Suc[symmetric]) + apply (rule Suc.hyps) + apply (use cb' Suc.prems k'_val len_cs' in simp_all) + done +qed + +lemma poly_ntt_int_bound: + assumes \coeff_bound MLKEM_Q cs\ + and \length cs = MLKEM_N\ + shows \coeff_bound (8 * MLKEM_Q) (poly_ntt_int cs)\ +proof - + have \coeff_bound (int 1 * MLKEM_Q) cs\ + using assms(1) by simp + hence \coeff_bound (int (1 + 7) * MLKEM_Q) (ntt_outer_loop_int (2^(1-1)) 7 cs)\ + by (rule ntt_outer_loop_int_bound) (use assms(2) in simp_all) + thus ?thesis + unfolding poly_ntt_int_def by simp +qed + +subsection \No-Overflow From Coefficient Bounds\ + +text \Coefficient bound implies no overflow for one NTT layer. + If all coefficients are bounded by @{term \l * Q\} with @{term \l \ 7\}, + then each butterfly in the layer produces values in 16-bit range.\ + +theorem coeff_bound_implies_ntt_layer_no_overflow: + assumes cb: \coeff_bound (int l * MLKEM_Q) acs\ + and l_ge: \1 \ l\ + and l_le: \l \ 7\ + and len: \length acs = MLKEM_N\ + shows \ntt_layer_no_overflow l acs\ +proof - + define k0 where \k0 = (2::nat) ^ (l - 1)\ + define blen where \blen = (2::nat) ^ (8 - l)\ + have k0_2blen: \k0 * (2 * blen) = 256\ + proof - + have \l - 1 + (8 - l) = 7\ + using l_ge l_le by simp + hence \(2::nat) ^ (l - 1) * 2 ^ (8 - l) = 2 ^ 7\ + by (metis power_add) + thus ?thesis + unfolding k0_def blen_def by simp + qed + have inner: \ntt_inner_no_overflow (zetas_int ! (k0 + j)) (j * 2 * blen) blen blen + (snd (ntt_middle_loop_int k0 blen j j acs))\ + if j_lt: \j < k0\ for j + proof - + define zeta where \zeta = zetas_int ! (k0 + j)\ + define off where \off = j * 2 * blen\ + define prev where \prev = snd (ntt_middle_loop_int k0 blen j j acs)\ + have block_fits: \off + 2 * blen \ 256\ + proof - + from j_lt have \(j + 1) * (2 * blen) \ k0 * (2 * blen)\ + by (intro mult_le_mono1) simp + thus ?thesis + using k0_2blen unfolding off_def by (simp add: algebra_simps) + qed + have z_bound: \\zeta\ \ 1659\ + proof - + have \k0 + j < 128\ + proof - + have \2 * k0 \ 128\ + proof - + have \l - 1 \ 6\ + using l_le by simp + hence \(2::nat) ^ (l - 1) \ 2 ^ 6\ + by (intro power_increasing) simp_all + thus ?thesis + unfolding k0_def by simp + qed + with j_lt show ?thesis + by simp + qed + thus ?thesis + unfolding zeta_def by (rule zetas_int_abs_bound) + qed + have fqmul_bound: \\fqmul_int x zeta\ < MLKEM_Q\ if \\x\ < int l * MLKEM_Q\ for x + proof (rule fqmul_int_bound_Q) + have \\x * zeta\ = \x\ * \zeta\\ + by (rule abs_mult) + also have \\ \ (int l * 3329 - 1) * 1659\ + proof (rule mult_mono) + from that show \\x\ \ int l * 3329 - 1\ + by linarith + qed (use z_bound l_ge in auto) + also have \\ < 32768 * 3329\ + using l_ge l_le by simp + finally show \\x * zeta\ < 32768 * MLKEM_Q\ + by simp + qed + show ?thesis + unfolding ntt_inner_no_overflow_def Let_def zeta_def[symmetric] + off_def[symmetric] prev_def[symmetric] + proof (intro allI impI conjI) + fix m + assume m_lt: \m < blen\ + define cs' where \cs' = ntt_inner_loop_int zeta off blen m prev\ + define idx where \idx = off + m\ + \ \Positions are unchanged by inner loop (not yet processed)\ + have cs'_idx: \cs' ! idx = prev ! idx\ + unfolding cs'_def idx_def + by (rule ntt_inner_loop_int_nth_unchanged) (use m_lt in auto) + have cs'_idx_blen: \cs' ! (idx + blen) = prev ! (idx + blen)\ + unfolding cs'_def idx_def + by (rule ntt_inner_loop_int_nth_unchanged) (use m_lt in auto) + \ \Positions are unchanged by middle loop (in later block)\ + have prev_idx: \prev ! idx = acs ! idx\ + unfolding prev_def k0_def blen_def idx_def off_def + by (rule ntt_middle_loop_int_nth_unchanged) simp + have prev_idx_blen: \prev ! (idx + blen) = acs ! (idx + blen)\ + unfolding prev_def k0_def blen_def idx_def off_def + by (rule ntt_middle_loop_int_nth_unchanged) simp + \ \Index bounds\ + have idx_lt: \idx < 256\ + using m_lt block_fits unfolding idx_def by simp + have idx_blen_lt: \idx + blen < 256\ + using m_lt block_fits unfolding idx_def by simp + \ \Original values are bounded\ + have cb_idx: \\acs ! idx\ < int l * MLKEM_Q\ + using cb idx_lt len unfolding coeff_bound_def by auto + have cb_idx_blen: \\acs ! (idx + blen)\ < int l * MLKEM_Q\ + using cb idx_blen_lt len unfolding coeff_bound_def by auto + \ \The fqmul result\ + define t where \t = fqmul_int (acs ! (idx + blen)) zeta\ + have t_eq: \fqmul_int (cs' ! (idx + blen)) zeta = t\ + using cs'_idx_blen prev_idx_blen unfolding t_def by simp + have t_bound: \\t\ < MLKEM_Q\ + unfolding t_def by (rule fqmul_bound[OF cb_idx_blen]) + \ \Range bounds\ + have range: \(int l + 1) * MLKEM_Q \ 32768\ + using l_le by simp + have val_idx: \cs' ! idx = acs ! idx\ + using cs'_idx prev_idx by simp + have plus_bound: \\acs ! idx + t\ < (int l + 1) * MLKEM_Q\ + proof - + have \\acs ! idx + t\ \ \acs ! idx\ + \t\\ + by (rule abs_triangle_ineq) + also have \\ < int l * MLKEM_Q + MLKEM_Q\ + using cb_idx t_bound by linarith + finally show ?thesis + by (simp add: ring_distribs) + qed + have minus_bound: \\acs ! idx - t\ < (int l + 1) * MLKEM_Q\ + proof - + have \\acs ! idx - t\ \ \acs ! idx\ + \t\\ + using abs_triangle_ineq[of \acs ! idx\ \-t\] by simp + also have \\ < int l * MLKEM_Q + MLKEM_Q\ + using cb_idx t_bound by linarith + finally show ?thesis + by (simp add: ring_distribs) + qed + show \- 32768 \ ntt_inner_loop_int zeta off blen m prev ! (off + m) + + fqmul_int (ntt_inner_loop_int zeta off blen m prev ! (off + m + blen)) zeta\ + using plus_bound range val_idx t_eq unfolding cs'_def idx_def by linarith + show \ntt_inner_loop_int zeta off blen m prev ! (off + m) + + fqmul_int (ntt_inner_loop_int zeta off blen m prev ! (off + m + blen)) zeta \ 32767\ + using plus_bound range val_idx t_eq unfolding cs'_def idx_def by linarith + show \- 32768 \ ntt_inner_loop_int zeta off blen m prev ! (off + m) - + fqmul_int (ntt_inner_loop_int zeta off blen m prev ! (off + m + blen)) zeta\ + using minus_bound range val_idx t_eq unfolding cs'_def idx_def by linarith + show \ntt_inner_loop_int zeta off blen m prev ! (off + m) - + fqmul_int (ntt_inner_loop_int zeta off blen m prev ! (off + m + blen)) zeta \ 32767\ + using minus_bound range val_idx t_eq unfolding cs'_def idx_def by linarith + qed + qed + thus ?thesis + unfolding ntt_layer_no_overflow_def Let_def + k0_def[symmetric] blen_def[symmetric] by auto +qed + +subsection \Outer Loop Composition\ + +text \Outer loop layer composition.\ + +lemma ntt_outer_loop_int_layer: + assumes \layer_rem \ 7\ + shows \ntt_outer_loop_int k (Suc layer_rem) cs = + (let blen = 2 ^ (Suc layer_rem); + nb = 2 ^ (6 - layer_rem); + (k', cs') = ntt_middle_loop_int k blen nb nb cs + in ntt_outer_loop_int k' layer_rem cs')\ +by simp + +text \Stepping the outer loop: one layer via @{const ntt_layer_int}.\ + +lemma ntt_outer_loop_step_layer: + assumes \1 \ l\ \l \ 7\ + shows \ntt_outer_loop_int (2^(l-1)) (8-l) cs = + ntt_outer_loop_int (2^l) (7-l) (ntt_layer_int l cs)\ +proof - + define blen where \blen = (2::nat) ^ Suc (7 - l)\ + define nb where \nb = (2::nat) ^ (6 - (7 - l))\ + from assms have sl: \8 - l = Suc (7 - l)\ + by simp + from assms have exp_eq: \6 - (7 - l) = l - 1\ + by simp + obtain k' cs' where mid: \ntt_middle_loop_int (2^(l-1)) blen nb nb cs = (k', cs')\ + by (cases \ntt_middle_loop_int (2^(l-1)) blen nb nb cs\) + have nb_eq: \nb = 2^(l-1)\ + unfolding nb_def using exp_eq by simp + have blen_eq: \blen = 2^(8-l)\ + unfolding blen_def using sl by simp + have cs'_eq: \cs' = ntt_layer_int l cs\ + using mid unfolding ntt_layer_int_def nb_eq blen_eq by simp + have k'_eq: \k' = 2^l\ + proof - + have \k' = 2^(l-1) + 2^(l-1)\ + using ntt_middle_loop_int_fst[of \2^(l-1)\ blen nb nb cs] mid nb_eq by simp + thus ?thesis using assms(1) + by (simp add: mult_2[symmetric] power_Suc[symmetric]) + qed + show ?thesis + unfolding sl by (simp only: ntt_outer_loop_int.simps blen_def[symmetric] nb_def[symmetric] + Let_def case_prod_beta mid prod.sel k'_eq cs'_eq) +qed + +(*<*) +end +(*>*) From d4d4645bb629a88b5e060e1ba1e133ef7f2dce4c Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:00:07 +0000 Subject: [PATCH 05/11] proofs/isabelle: prove NTT/inverse-NTT are two-sided inverses mod q Add MLKEM_NTT_Correctness.thy with the 7-tier proof establishing that forward and inverse NTT compose to scaling by 2^16 mod q. Final theorems: poly_ntt_invntt_tomont and poly_invntt_tomont_ntt. --- proofs/isabelle/MLKEM_NTT_Correctness.thy | 2426 +++++++++++++++++++++ 1 file changed, 2426 insertions(+) create mode 100644 proofs/isabelle/MLKEM_NTT_Correctness.thy diff --git a/proofs/isabelle/MLKEM_NTT_Correctness.thy b/proofs/isabelle/MLKEM_NTT_Correctness.thy new file mode 100644 index 0000000000..f424f03a7c --- /dev/null +++ b/proofs/isabelle/MLKEM_NTT_Correctness.thy @@ -0,0 +1,2426 @@ +(*<*) +theory MLKEM_NTT_Correctness + imports MLKEM_InvNTT_Spec +begin +(*>*) + +section \NTT / Inverse-NTT Inverse Relationship\ + +text \ + The forward NTT and inverse NTT with Montgomery prescaling are two-sided + inverses modulo @{const MLKEM_Q}, up to a multiplicative factor of + @{text "2^16"} (the Montgomery radix R). The proof is structured in + tiers, from low-level infrastructure through butterfly, loop, layer, + and linearity lemmas, culminating in the full-transform composition + theorems @{text poly_ntt_invntt_tomont} and @{text poly_invntt_tomont_ntt}. +\ + +text_raw \ +\begin{figure}[ht] +\centering +\begin{tikzpicture}[>=Stealth, + tier/.style={draw=mlkblue, fill=mlklightblue, rounded corners, + minimum width=3.8cm, minimum height=0.7cm, + align=center, font=\small}, + node distance=0.6cm] + \node[tier] (t0) {Tier 0: Infrastructure}; + \node[tier, below=of t0] (t1) {Tier 1: Zeta Properties}; + \node[tier, below=of t1] (t2) {Tier 2: Bound Propagation}; + \node[tier, below=of t2] (t3) {Tier 3: Butterfly Composition}; + \node[tier, below=of t3] (t4) {Tier 4: Loop Composition}; + \node[tier, below=of t4] (t5) {Tier 5: Layer Inverse}; + \node[tier, below=of t5] (t6) {Tier 6: Linearity}; + \node[tier, below=of t6] (t7) {Tier 7: Full Composition}; + \draw[->,thick,mlkblue] (t0) -- (t1); + \draw[->,thick,mlkblue] (t1) -- (t2); + \draw[->,thick,mlkblue] (t2) -- (t3); + \draw[->,thick,mlkblue] (t3) -- (t4); + \draw[->,thick,mlkblue] (t4) -- (t5); + \draw[->,thick,mlkblue] (t5) -- (t6); + \draw[->,thick,mlkblue] (t6) -- (t7); +\end{tikzpicture} +\caption{Proof tier hierarchy. Each tier builds on results from the tier + above; the final composition theorems at Tier~7 connect the forward and + inverse NTT as two-sided inverses modulo~$q$.} +\label{fig:proof-tiers} +\end{figure} +\ + +subsection \Tier 0 — Infrastructure\ + +(*<*) +lemma centered_mod_mod: + assumes \q > 0\ + shows \centered_mod x q mod q = x mod q\ +unfolding centered_mod_def Let_def by (simp add: mod_diff_right_eq) + +lemma fqmul_int_cong: + shows \fqmul_int a b * 2^16 mod MLKEM_Q = a * b mod MLKEM_Q\ +unfolding fqmul_int_def by (rule montgomery_reduce_int_correct) +(*>*) + +subsection \Tier 1 — Zeta Properties\ + +(*<*) + +lemma zetas_int_mod: + assumes \i < 128\ + shows \zetas_int ! i mod MLKEM_Q = mlkem_zeta ^ bit_reverse 7 i * 2^16 mod MLKEM_Q\ +using assms by (simp add: zetas_int_roots_of_unity centered_mod_mod) + +text \The product of conjugate zeta powers equals @{text "-1 mod Q"}, i.e. + @{text "\^{br(k)} \ \^{br(k')} \ -1 (mod Q)"} when bit-reversed indices + sum to 128.\ + +lemma zeta_power_sum_128: + assumes \a + b = 128\ + shows \mlkem_zeta ^ a * mlkem_zeta ^ b mod MLKEM_Q = MLKEM_Q - 1\ +proof - + have \mlkem_zeta ^ a * mlkem_zeta ^ b = mlkem_zeta ^ (a + b)\ + by (simp add: power_add) + thus ?thesis + using assms mlkem_zeta_half_order by simp +qed + +lemma bit_reverse_bound: + assumes \k < 2^n\ + shows \bit_reverse n k < 2^n\ +using assms proof (induction n arbitrary: k) + case 0 + then show ?case by simp +next + case (Suc n) + have \k div 2 < 2^n\ + using Suc.prems by simp + hence IH: \bit_reverse n (k div 2) < 2^n\ + by (rule Suc.IH) + have \k mod 2 * 2^n + bit_reverse n (k div 2) < 2 * 2^n\ + proof - + have \k mod 2 \ 1\ + by simp + hence \k mod 2 * 2^n \ 2^n\ + by (simp add: mult_le_cancel2) + thus ?thesis + using IH by linarith + qed + thus ?case + by simp +qed + +text \Bit-reverse complement: for indices in the NTT butterfly pattern, + the bit-reversed forward and inverse zeta indices sum to 128.\ + +lemma bit_reverse_complement: + assumes \j < 2^(l-1)\ + and \1 \ l\ + and \l \ 7\ + shows \bit_reverse 7 (2^(l-1) + j) + bit_reverse 7 (2^l - 1 - j) = 128\ +proof - + have \list_all (\l. list_all + (\j. bit_reverse 7 (2^(l-1) + j) + bit_reverse 7 (2^l - 1 - j) = 128) + [0..<2^(l-1)]) [1..<8::nat]\ + by eval + thus ?thesis + using assms by (auto simp: list_all_iff) +qed +(*>*) + +subsection \Tier 2 — Butterfly Values and Montgomery Cancellation\ + +text \Montgomery-radix inverse: @{text "2^16 \ 169 \ 1 (mod Q)"}.\ + +lemma R_inv_mod_Q: + shows \(2::int)^16 * 169 mod MLKEM_Q = 1\ +by eval + +text \Key cancellation: when a zeta argument carries a Montgomery factor + @{text "2^16"}, it cancels with the @{text "R^{-1}"} inside @{const fqmul_int}.\ + +lemma fqmul_zeta_cancel: + assumes \zeta mod MLKEM_Q = z_math * 2^16 mod MLKEM_Q\ + shows \fqmul_int a zeta mod MLKEM_Q = a * z_math mod MLKEM_Q\ +proof - + have eq: \fqmul_int a zeta * 2^16 mod MLKEM_Q = a * z_math * 2^16 mod MLKEM_Q\ + proof - + have \fqmul_int a zeta * 2^16 mod MLKEM_Q = a * zeta mod MLKEM_Q\ + by (rule fqmul_int_cong) + also have \a * zeta mod MLKEM_Q = a * (zeta mod MLKEM_Q) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = a * (z_math * 2^16 mod MLKEM_Q) mod MLKEM_Q\ + using assms by simp + also have \\ = a * (z_math * 2^16) mod MLKEM_Q\ + by (rule mod_mult_right_eq) + also have \\ = a * z_math * 2^16 mod MLKEM_Q\ + by (simp add: mult.assoc) + finally show ?thesis . + qed + from eq show ?thesis + proof - + have lhs: \fqmul_int a zeta mod MLKEM_Q = (fqmul_int a zeta * 2^16) * 169 mod MLKEM_Q\ + proof - + have \fqmul_int a zeta mod MLKEM_Q = fqmul_int a zeta * 1 mod MLKEM_Q\ + by simp + also have \\ = fqmul_int a zeta * (2^16 * 169 mod MLKEM_Q) mod MLKEM_Q\ + using R_inv_mod_Q by simp + also have \\ = fqmul_int a zeta * (2^16 * 169) mod MLKEM_Q\ + by (rule mod_mult_right_eq) + finally show ?thesis + by (simp add: mult.assoc) + qed + have rhs: \a * z_math mod MLKEM_Q = (a * z_math * 2^16) * 169 mod MLKEM_Q\ + proof - + have \a * z_math mod MLKEM_Q = a * z_math * 1 mod MLKEM_Q\ + by simp + also have \\ = a * z_math * (2^16 * 169 mod MLKEM_Q) mod MLKEM_Q\ + using R_inv_mod_Q by simp + also have \\ = a * z_math * (2^16 * 169) mod MLKEM_Q\ + by (rule mod_mult_right_eq) + finally show ?thesis + by (simp add: mult.assoc) + qed + have mid: \(fqmul_int a zeta * 2^16) * 169 mod MLKEM_Q = (a * z_math * 2^16) * 169 mod MLKEM_Q\ + proof - + have \(fqmul_int a zeta * 2^16) * 169 mod MLKEM_Q = + (fqmul_int a zeta * 2^16 mod MLKEM_Q) * 169 mod MLKEM_Q\ + by (rule mod_mult_left_eq[symmetric]) + also have \\ = (a * z_math * 2^16 mod MLKEM_Q) * 169 mod MLKEM_Q\ + using eq by simp + also have \\ = (a * z_math * 2^16) * 169 mod MLKEM_Q\ + by (rule mod_mult_left_eq) + finally show ?thesis . + qed + show ?thesis + using lhs mid rhs by simp + qed +qed + +text \Butterfly exact values — unfolding the list-update structure.\ + +lemma ntt_butterfly_int_val_low: + assumes \j + blen < length cs\ + and \blen > 0\ + shows \ntt_butterfly_int zeta j blen cs ! j = + cs ! j + fqmul_int (cs ! (j + blen)) zeta\ +unfolding ntt_butterfly_int_def Let_def using assms by simp + +lemma ntt_butterfly_int_val_high: + assumes \j + blen < length cs\ + and \blen > 0\ + shows \ntt_butterfly_int zeta j blen cs ! (j + blen) = + cs ! j - fqmul_int (cs ! (j + blen)) zeta\ +unfolding ntt_butterfly_int_def Let_def using assms by auto + +lemma invntt_butterfly_int_val_low: + assumes \j + blen < length cs\ + and \blen > 0\ + shows \invntt_butterfly_int zeta j blen cs ! j = + barrett_reduce_int (cs ! j + cs ! (j + blen))\ +unfolding invntt_butterfly_int_def Let_def using assms by auto + +lemma invntt_butterfly_int_val_high: + assumes \j + blen < length cs\ + and \blen > 0\ + shows \invntt_butterfly_int zeta j blen cs ! (j + blen) = + fqmul_int (cs ! (j + blen) - cs ! j) zeta\ +unfolding invntt_butterfly_int_def Let_def using assms by simp + +text \Mod-Q butterfly characterization (combining exact values with + @{thm fqmul_zeta_cancel} and @{thm barrett_reduce_mod}).\ + +lemma ntt_butterfly_int_mod_low: + assumes \j + blen < length cs\ + and \blen > 0\ + and \zeta mod MLKEM_Q = z_math * 2^16 mod MLKEM_Q\ + shows \ntt_butterfly_int zeta j blen cs ! j mod MLKEM_Q = + (cs ! j + cs ! (j + blen) * z_math) mod MLKEM_Q\ +using ntt_butterfly_int_val_low[OF assms(1,2)] + fqmul_zeta_cancel[OF assms(3), of \cs ! (j + blen)\] +proof - + assume val: \\zeta. ntt_butterfly_int zeta j blen cs ! j = cs ! j + fqmul_int (cs ! (j + blen)) zeta\ + assume cancel: \fqmul_int (cs ! (j + blen)) zeta mod MLKEM_Q = cs ! (j + blen) * z_math mod MLKEM_Q\ + have \ntt_butterfly_int zeta j blen cs ! j mod MLKEM_Q = + (cs ! j + fqmul_int (cs ! (j + blen)) zeta) mod MLKEM_Q\ + using val by simp + also have \\ = (cs ! j + fqmul_int (cs ! (j + blen)) zeta mod MLKEM_Q) mod MLKEM_Q\ + by (rule mod_add_right_eq[symmetric]) + also have \\ = (cs ! j + cs ! (j + blen) * z_math mod MLKEM_Q) mod MLKEM_Q\ + using cancel by simp + also have \\ = (cs ! j + cs ! (j + blen) * z_math) mod MLKEM_Q\ + by (rule mod_add_right_eq) + finally show ?thesis . +qed + +lemma ntt_butterfly_int_mod_high: + assumes \j + blen < length cs\ + and \blen > 0\ + and \zeta mod MLKEM_Q = z_math * 2^16 mod MLKEM_Q\ + shows \ntt_butterfly_int zeta j blen cs ! (j + blen) mod MLKEM_Q = + (cs ! j - cs ! (j + blen) * z_math) mod MLKEM_Q\ +using ntt_butterfly_int_val_high[OF assms(1,2)] + fqmul_zeta_cancel[OF assms(3), of \cs ! (j + blen)\] +proof - + assume val: \\zeta. ntt_butterfly_int zeta j blen cs ! (j + blen) = cs ! j - fqmul_int (cs ! (j + blen)) zeta\ + assume cancel: \fqmul_int (cs ! (j + blen)) zeta mod MLKEM_Q = cs ! (j + blen) * z_math mod MLKEM_Q\ + have \ntt_butterfly_int zeta j blen cs ! (j + blen) mod MLKEM_Q = + (cs ! j - fqmul_int (cs ! (j + blen)) zeta) mod MLKEM_Q\ + using val by simp + also have \\ = (cs ! j - fqmul_int (cs ! (j + blen)) zeta mod MLKEM_Q) mod MLKEM_Q\ + by (rule mod_diff_right_eq[symmetric]) + also have \\ = (cs ! j - cs ! (j + blen) * z_math mod MLKEM_Q) mod MLKEM_Q\ + using cancel by simp + also have \\ = (cs ! j - cs ! (j + blen) * z_math) mod MLKEM_Q\ + by (rule mod_diff_right_eq) + finally show ?thesis . +qed + +lemma invntt_butterfly_int_mod_low: + assumes \j + blen < length cs\ + and \blen > 0\ + shows \invntt_butterfly_int zeta j blen cs ! j mod MLKEM_Q = + (cs ! j + cs ! (j + blen)) mod MLKEM_Q\ +using invntt_butterfly_int_val_low[OF assms] barrett_reduce_mod by simp + +lemma invntt_butterfly_int_mod_high: + assumes \j + blen < length cs\ + and \blen > 0\ + and \zeta mod MLKEM_Q = z_math * 2^16 mod MLKEM_Q\ + shows \invntt_butterfly_int zeta j blen cs ! (j + blen) mod MLKEM_Q = + (cs ! (j + blen) - cs ! j) * z_math mod MLKEM_Q\ +using invntt_butterfly_int_val_high[OF assms(1,2)] + fqmul_zeta_cancel[OF assms(3), of \cs ! (j + blen) - cs ! j\] by simp + +subsection \Tier 3 — Butterfly Composition\ + +text \Composing inverse-then-forward (or forward-then-inverse) butterflies + with conjugate zetas doubles each coefficient mod Q.\ + +lemma butterfly_inverse_composition: + assumes \j + blen < length cs\ + and \blen > 0\ + and \mlkem_zeta ^ br_fwd * mlkem_zeta ^ br_inv mod MLKEM_Q = MLKEM_Q - 1\ + and \zf mod MLKEM_Q = mlkem_zeta ^ br_fwd * 2^16 mod MLKEM_Q\ + and \zi mod MLKEM_Q = mlkem_zeta ^ br_inv * 2^16 mod MLKEM_Q\ + shows \ntt_butterfly_int zf j blen (invntt_butterfly_int zi j blen cs) ! j mod MLKEM_Q = + 2 * cs ! j mod MLKEM_Q\ + and \ntt_butterfly_int zf j blen (invntt_butterfly_int zi j blen cs) ! (j + blen) mod MLKEM_Q = + 2 * cs ! (j + blen) mod MLKEM_Q\ +proof - + define cs' where + \cs' = invntt_butterfly_int zi j blen cs\ + let ?D = \cs ! (j + blen) - cs ! j\ + let ?Zf = \mlkem_zeta ^ br_fwd\ + let ?Zi = \mlkem_zeta ^ br_inv\ + have jb: \j + blen < length cs'\ + using assms(1) unfolding cs'_def by (simp add: invntt_butterfly_int_length) + have cs'_j_mod: \cs' ! j mod MLKEM_Q = (cs ! j + cs ! (j + blen)) mod MLKEM_Q\ + unfolding cs'_def using invntt_butterfly_int_mod_low[OF assms(1,2)] by simp + have cs'_jb_mod: \cs' ! (j + blen) mod MLKEM_Q = ?D * ?Zi mod MLKEM_Q\ + unfolding cs'_def using invntt_butterfly_int_mod_high[OF assms(1,2,5)] by simp + have product_mod: \(cs' ! (j + blen) * ?Zf) mod MLKEM_Q = (?D * (MLKEM_Q - 1)) mod MLKEM_Q\ + proof - + have \(cs' ! (j + blen) * ?Zf) mod MLKEM_Q = + (cs' ! (j + blen) mod MLKEM_Q * ?Zf) mod MLKEM_Q\ + by (rule mod_mult_left_eq[symmetric]) + also have \\ = (?D * ?Zi mod MLKEM_Q * ?Zf) mod MLKEM_Q\ + using cs'_jb_mod by simp + also have \\ = (?D * ?Zi * ?Zf) mod MLKEM_Q\ + by (rule mod_mult_left_eq) + also have \\ = (?D * (?Zi * ?Zf)) mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = (?D * (?Zi * ?Zf mod MLKEM_Q)) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = (?D * (?Zf * ?Zi mod MLKEM_Q)) mod MLKEM_Q\ + by (simp add: mult.commute[of ?Zi ?Zf]) + also have \\ = (?D * (MLKEM_Q - 1)) mod MLKEM_Q\ + using assms(3) by simp + finally show ?thesis . + qed + show \ntt_butterfly_int zf j blen (invntt_butterfly_int zi j blen cs) ! j mod MLKEM_Q = + 2 * cs ! j mod MLKEM_Q\ + proof - + have \ntt_butterfly_int zf j blen cs' ! j mod MLKEM_Q = + (cs' ! j + cs' ! (j + blen) * ?Zf) mod MLKEM_Q\ + using ntt_butterfly_int_mod_low[OF jb assms(2,4)] by simp + also have \\ = ((cs ! j + cs ! (j + blen)) + ?D * (MLKEM_Q - 1)) mod MLKEM_Q\ + by (rule mod_add_cong[OF cs'_j_mod product_mod]) + also have \\ = 2 * cs ! j mod MLKEM_Q\ + by (simp add: mod_eq_dvd_iff algebra_simps) + finally show ?thesis + unfolding cs'_def . + qed + show \ntt_butterfly_int zf j blen (invntt_butterfly_int zi j blen cs) ! (j + blen) mod MLKEM_Q = + 2 * cs ! (j + blen) mod MLKEM_Q\ + proof - + have \ntt_butterfly_int zf j blen cs' ! (j + blen) mod MLKEM_Q = + (cs' ! j - cs' ! (j + blen) * ?Zf) mod MLKEM_Q\ + using ntt_butterfly_int_mod_high[OF jb assms(2,4)] by simp + also have \\ = ((cs ! j + cs ! (j + blen)) - ?D * (MLKEM_Q - 1)) mod MLKEM_Q\ + by (rule mod_diff_cong[OF cs'_j_mod product_mod]) + also have \\ = 2 * cs ! (j + blen) mod MLKEM_Q\ + by (simp add: mod_eq_dvd_iff algebra_simps) + finally show ?thesis + unfolding cs'_def . + qed +qed + +lemma butterfly_forward_composition: + assumes \j + blen < length cs\ + and \blen > 0\ + and \mlkem_zeta ^ br_fwd * mlkem_zeta ^ br_inv mod MLKEM_Q = MLKEM_Q - 1\ + and \zf mod MLKEM_Q = mlkem_zeta ^ br_fwd * 2^16 mod MLKEM_Q\ + and \zi mod MLKEM_Q = mlkem_zeta ^ br_inv * 2^16 mod MLKEM_Q\ + shows \invntt_butterfly_int zi j blen (ntt_butterfly_int zf j blen cs) ! j mod MLKEM_Q = + 2 * cs ! j mod MLKEM_Q\ + and \invntt_butterfly_int zi j blen (ntt_butterfly_int zf j blen cs) ! (j + blen) mod MLKEM_Q = + 2 * cs ! (j + blen) mod MLKEM_Q\ +proof - + define cs' where + \cs' = ntt_butterfly_int zf j blen cs\ + have jb: \j + blen < length cs'\ + using assms(1) unfolding cs'_def by (simp add: ntt_butterfly_int_length) + have val_low: \cs' ! j = cs ! j + fqmul_int (cs ! (j + blen)) zf\ + unfolding cs'_def using ntt_butterfly_int_val_low[OF assms(1,2)] by simp + have val_high: \cs' ! (j + blen) = cs ! j - fqmul_int (cs ! (j + blen)) zf\ + unfolding cs'_def using ntt_butterfly_int_val_high[OF assms(1,2)] by simp + have sum_eq: \cs' ! j + cs' ! (j + blen) = 2 * cs ! j\ + using val_low val_high by simp + have diff_eq: \cs' ! (j + blen) - cs' ! j = (-2) * fqmul_int (cs ! (j + blen)) zf\ + using val_low val_high by simp + have fqmul_cancel: \fqmul_int (cs ! (j + blen)) zf mod MLKEM_Q = + cs ! (j + blen) * (mlkem_zeta ^ br_fwd) mod MLKEM_Q\ + using fqmul_zeta_cancel[OF assms(4), of \cs ! (j + blen)\] . + have neg2_Q_mod: \(-2) * (x::int) * (MLKEM_Q - 1) mod MLKEM_Q = 2 * x mod MLKEM_Q\ for x + by (simp add: mod_eq_dvd_iff algebra_simps) + show \invntt_butterfly_int zi j blen (ntt_butterfly_int zf j blen cs) ! j mod MLKEM_Q = + 2 * cs ! j mod MLKEM_Q\ + proof - + have \invntt_butterfly_int zi j blen cs' ! j mod MLKEM_Q = + (cs' ! j + cs' ! (j + blen)) mod MLKEM_Q\ + using invntt_butterfly_int_mod_low[OF jb assms(2)] by simp + also have \\ = 2 * cs ! j mod MLKEM_Q\ + using sum_eq by simp + finally show ?thesis + unfolding cs'_def . + qed + show \invntt_butterfly_int zi j blen (ntt_butterfly_int zf j blen cs) ! (j + blen) mod MLKEM_Q = + 2 * cs ! (j + blen) mod MLKEM_Q\ + proof - + let ?X = \cs ! (j + blen)\ + let ?Zf = \mlkem_zeta ^ br_fwd\ + let ?Zi = \mlkem_zeta ^ br_inv\ + have \invntt_butterfly_int zi j blen cs' ! (j + blen) mod MLKEM_Q = + (cs' ! (j + blen) - cs' ! j) * ?Zi mod MLKEM_Q\ + using invntt_butterfly_int_mod_high[OF jb assms(2,5)] by simp + also have \\ = (-2) * fqmul_int ?X zf * ?Zi mod MLKEM_Q\ + using diff_eq by simp + also have \\ = (-2) * (?X * ?Zf) * ?Zi mod MLKEM_Q\ + proof - + have \(-2) * fqmul_int ?X zf * ?Zi mod MLKEM_Q = + (((-2) * fqmul_int ?X zf mod MLKEM_Q) * ?Zi) mod MLKEM_Q\ + by (rule mod_mult_left_eq[symmetric]) + also have \(-2) * fqmul_int ?X zf mod MLKEM_Q = (-2) * (?X * ?Zf) mod MLKEM_Q\ + proof - + have \(-2) * fqmul_int ?X zf mod MLKEM_Q = + ((-2) * (fqmul_int ?X zf mod MLKEM_Q)) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = ((-2) * (?X * ?Zf mod MLKEM_Q)) mod MLKEM_Q\ + using fqmul_cancel by simp + also have \\ = (-2) * (?X * ?Zf) mod MLKEM_Q\ + by (rule mod_mult_right_eq) + finally show ?thesis . + qed + hence \(((-2) * fqmul_int ?X zf mod MLKEM_Q) * ?Zi) mod MLKEM_Q = + (((-2) * (?X * ?Zf) mod MLKEM_Q) * ?Zi) mod MLKEM_Q\ + by simp + also have \\ = ((-2) * (?X * ?Zf) * ?Zi) mod MLKEM_Q\ + by (rule mod_mult_left_eq) + finally show ?thesis + by simp + qed + also have \\ = (-2) * ?X * (?Zf * ?Zi) mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = (-2) * ?X * (MLKEM_Q - 1) mod MLKEM_Q\ + proof - + have \(-2) * ?X * (?Zf * ?Zi) mod MLKEM_Q = + ((-2) * ?X * (?Zf * ?Zi mod MLKEM_Q)) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = ((-2) * ?X * (MLKEM_Q - 1)) mod MLKEM_Q\ + using assms(3) by simp + finally show ?thesis . + qed + also have \\ = 2 * ?X mod MLKEM_Q\ + by (rule neg2_Q_mod) + finally show ?thesis + unfolding cs'_def . + qed +qed + +subsection \Tier 4 — Loop Composition\ + +text \Composing forward and inverse inner/middle loops with matching zeta + pairs yields pointwise scaling by 2 mod Q.\ + +text \Helper: composing two @{const fqmul_int} with conjugate zetas cancels + Montgomery factors and yields multiplication by @{text "Q - 1 mod Q"}.\ + +lemma fqmul_fqmul_cancel: + assumes \zf mod MLKEM_Q = mlkem_zeta ^ br_fwd * 2^16 mod MLKEM_Q\ + and \zi mod MLKEM_Q = mlkem_zeta ^ br_inv * 2^16 mod MLKEM_Q\ + and \mlkem_zeta ^ br_fwd * mlkem_zeta ^ br_inv mod MLKEM_Q = MLKEM_Q - 1\ + shows \fqmul_int (fqmul_int a zi) zf mod MLKEM_Q = + a * (MLKEM_Q - 1) mod MLKEM_Q\ +proof - + let ?Zf = \mlkem_zeta ^ br_fwd\ + let ?Zi = \mlkem_zeta ^ br_inv\ + have \fqmul_int (fqmul_int a zi) zf mod MLKEM_Q = + fqmul_int a zi * ?Zf mod MLKEM_Q\ + by (rule fqmul_zeta_cancel[OF assms(1)]) + also have \\ = (fqmul_int a zi mod MLKEM_Q * ?Zf) mod MLKEM_Q\ + by (rule mod_mult_left_eq[symmetric]) + also have \\ = (a * ?Zi mod MLKEM_Q * ?Zf) mod MLKEM_Q\ + using fqmul_zeta_cancel[OF assms(2)] by simp + also have \\ = (a * ?Zi * ?Zf) mod MLKEM_Q\ + by (rule mod_mult_left_eq) + also have \\ = (a * (?Zf * ?Zi)) mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = (a * (?Zf * ?Zi mod MLKEM_Q)) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = a * (MLKEM_Q - 1) mod MLKEM_Q\ + using assms(3) by simp + finally show ?thesis . +qed + +lemma inner_loop_composition_invntt_ntt: + assumes \length cs = MLKEM_N\ + and \off + 2 * blen \ MLKEM_N\ + and \cnt \ blen\ + and \mlkem_zeta ^ br_fwd * mlkem_zeta ^ br_inv mod MLKEM_Q = MLKEM_Q - 1\ + and \zf mod MLKEM_Q = mlkem_zeta ^ br_fwd * 2^16 mod MLKEM_Q\ + and \zi mod MLKEM_Q = mlkem_zeta ^ br_inv * 2^16 mod MLKEM_Q\ + and \i < length cs\ + shows \ntt_inner_loop_int zf off blen cnt (invntt_inner_loop_int zi off blen cnt cs) ! i mod MLKEM_Q = + (if i \ {off.. i \ {off+blen.. +proof - + define inv_cs where + \inv_cs = invntt_inner_loop_int zi off blen cnt cs\ + have len_inv: \length inv_cs = length cs\ + by (simp add: inv_cs_def invntt_inner_loop_int_length) + have off_bound: \off + 2 * blen \ length cs\ + using assms(1,2) by simp + have case_low: \i \ {off.. ?thesis\ + proof - + assume low: \i \ {off.. + define m where + \m = i - off\ + have im1: \i = off + m\ and im2: \m < cnt\ + using low by (auto simp: m_def) + let ?a = \cs ! (off + m)\ + let ?b = \cs ! (off + m + blen)\ + let ?D = \?b - ?a\ + have ntt_val: \ntt_inner_loop_int zf off blen cnt inv_cs ! (off + m) = + inv_cs ! (off + m) + fqmul_int (inv_cs ! (off + m + blen)) zf\ + by (rule ntt_inner_loop_int_low_val[OF im2 assms(3)]) + (use len_inv off_bound in simp) + have inv_low: \inv_cs ! (off + m) = barrett_reduce_int (?a + ?b)\ + unfolding inv_cs_def by (rule invntt_inner_loop_int_low_val[OF im2 assms(3) off_bound]) + have inv_high: \inv_cs ! (off + m + blen) = fqmul_int ?D zi\ + unfolding inv_cs_def by (rule invntt_inner_loop_int_high_val[OF im2 assms(3) off_bound]) + have \ntt_inner_loop_int zf off blen cnt inv_cs ! (off + m) mod MLKEM_Q = + (barrett_reduce_int (?a + ?b) + fqmul_int (fqmul_int ?D zi) zf) mod MLKEM_Q\ + using ntt_val inv_low inv_high by simp + also have \\ = ((?a + ?b) + ?D * (MLKEM_Q - 1)) mod MLKEM_Q\ + by (rule mod_add_cong[OF barrett_reduce_mod fqmul_fqmul_cancel[OF assms(5,6,4)]]) + also have \\ = 2 * ?a mod MLKEM_Q\ + by (simp add: mod_eq_dvd_iff algebra_simps) + finally show ?thesis + using im1 im2 low by (simp add: inv_cs_def) + qed + have case_high: \i \ {off+blen.. i \ {off.. ?thesis\ + proof - + assume high: \i \ {off+blen.. + and not_low: \i \ {off.. + define m where + \m = i - off - blen\ + have im1: \i = off + m + blen\ and im2: \m < cnt\ + using high assms(3) by (auto simp: m_def) + let ?a = \cs ! (off + m)\ + let ?b = \cs ! (off + m + blen)\ + let ?D = \?b - ?a\ + have ntt_val: \ntt_inner_loop_int zf off blen cnt inv_cs ! (off + m + blen) = + inv_cs ! (off + m) - fqmul_int (inv_cs ! (off + m + blen)) zf\ + by (rule ntt_inner_loop_int_high_val[OF im2 assms(3)]) + (use len_inv off_bound in simp) + have inv_low: \inv_cs ! (off + m) = barrett_reduce_int (?a + ?b)\ + unfolding inv_cs_def by (rule invntt_inner_loop_int_low_val[OF im2 assms(3) off_bound]) + have inv_high: \inv_cs ! (off + m + blen) = fqmul_int ?D zi\ + unfolding inv_cs_def by (rule invntt_inner_loop_int_high_val[OF im2 assms(3) off_bound]) + have \ntt_inner_loop_int zf off blen cnt inv_cs ! (off + m + blen) mod MLKEM_Q = + (barrett_reduce_int (?a + ?b) - fqmul_int (fqmul_int ?D zi) zf) mod MLKEM_Q\ + using ntt_val inv_low inv_high by simp + also have \\ = ((?a + ?b) - ?D * (MLKEM_Q - 1)) mod MLKEM_Q\ + by (rule mod_diff_cong[OF barrett_reduce_mod fqmul_fqmul_cancel[OF assms(5,6,4)]]) + also have \\ = 2 * ?b mod MLKEM_Q\ + by (simp add: mod_eq_dvd_iff algebra_simps) + finally show ?thesis + using im1 im2 not_low by (simp add: inv_cs_def) + qed + have case_out: \i \ {off.. i \ {off+blen.. ?thesis\ + proof - + assume nl: \i \ {off.. and nh: \i \ {off+blen.. + have \ntt_inner_loop_int zf off blen cnt inv_cs ! i = inv_cs ! i\ + by (rule ntt_inner_loop_int_nth_unchanged) (use nl nh in auto) + also have \\ = cs ! i\ + unfolding inv_cs_def by (rule invntt_inner_loop_int_nth_unchanged) (use nl nh in auto) + finally show ?thesis + using nl nh by (auto simp: inv_cs_def) + qed + show ?thesis + using case_low case_high case_out by blast +qed + +lemma inner_loop_composition_ntt_invntt: + assumes \length cs = MLKEM_N\ + and \off + 2 * blen \ MLKEM_N\ + and \cnt \ blen\ + and \mlkem_zeta ^ br_fwd * mlkem_zeta ^ br_inv mod MLKEM_Q = MLKEM_Q - 1\ + and \zf mod MLKEM_Q = mlkem_zeta ^ br_fwd * 2^16 mod MLKEM_Q\ + and \zi mod MLKEM_Q = mlkem_zeta ^ br_inv * 2^16 mod MLKEM_Q\ + and \i < length cs\ + shows \invntt_inner_loop_int zi off blen cnt (ntt_inner_loop_int zf off blen cnt cs) ! i mod MLKEM_Q = + (if i \ {off.. i \ {off+blen.. +proof - + define ntt_cs where + \ntt_cs = ntt_inner_loop_int zf off blen cnt cs\ + have len_ntt: \length ntt_cs = length cs\ + by (simp add: ntt_cs_def ntt_inner_loop_int_length) + have off_bound: \off + 2 * blen \ length cs\ + using assms(1,2) by simp + have case_low: \i \ {off.. ?thesis\ + proof - + assume low: \i \ {off.. + define m where + \m = i - off\ + let ?a = \cs ! (off + m)\ + let ?b = \cs ! (off + m + blen)\ + have im1: \i = off + m\ and im2: \m < cnt\ + using low by (auto simp: m_def) + have ntt_low: \ntt_cs ! (off + m) = ?a + fqmul_int ?b zf\ + unfolding ntt_cs_def by (rule ntt_inner_loop_int_low_val[OF im2 assms(3) off_bound]) + have ntt_high: \ntt_cs ! (off + m + blen) = ?a - fqmul_int ?b zf\ + unfolding ntt_cs_def by (rule ntt_inner_loop_int_high_val[OF im2 assms(3) off_bound]) + have inv_val: \invntt_inner_loop_int zi off blen cnt ntt_cs ! (off + m) = + barrett_reduce_int (ntt_cs ! (off + m) + ntt_cs ! (off + m + blen))\ + by (rule invntt_inner_loop_int_low_val[OF im2 assms(3)]) + (use len_ntt off_bound in simp) + have sum_eq: \ntt_cs ! (off + m) + ntt_cs ! (off + m + blen) = 2 * ?a\ + using ntt_low ntt_high by simp + have \invntt_inner_loop_int zi off blen cnt ntt_cs ! (off + m) mod MLKEM_Q = + barrett_reduce_int (2 * ?a) mod MLKEM_Q\ + using inv_val sum_eq by simp + also have \\ = 2 * ?a mod MLKEM_Q\ + by (rule barrett_reduce_mod) + finally show ?thesis + using im1 im2 low by (simp add: ntt_cs_def) + qed + have case_high: \i \ {off+blen.. i \ {off.. ?thesis\ + proof - + assume high: \i \ {off+blen.. + and not_low: \i \ {off.. + define m where + \m = i - off - blen\ + let ?a = \cs ! (off + m)\ + let ?b = \cs ! (off + m + blen)\ + let ?Zf = \mlkem_zeta ^ br_fwd\ + let ?Zi = \mlkem_zeta ^ br_inv\ + have im1: \i = off + m + blen\ and im2: \m < cnt\ + using high assms(3) by (auto simp: m_def) + have ntt_low: \ntt_cs ! (off + m) = ?a + fqmul_int ?b zf\ + unfolding ntt_cs_def by (rule ntt_inner_loop_int_low_val[OF im2 assms(3) off_bound]) + have ntt_high: \ntt_cs ! (off + m + blen) = ?a - fqmul_int ?b zf\ + unfolding ntt_cs_def by (rule ntt_inner_loop_int_high_val[OF im2 assms(3) off_bound]) + have inv_val: \invntt_inner_loop_int zi off blen cnt ntt_cs ! (off + m + blen) = + fqmul_int (ntt_cs ! (off + m + blen) - ntt_cs ! (off + m)) zi\ + by (rule invntt_inner_loop_int_high_val[OF im2 assms(3)]) + (use len_ntt off_bound in simp) + have diff_eq: \ntt_cs ! (off + m + blen) - ntt_cs ! (off + m) = - 2 * fqmul_int ?b zf\ + using ntt_low ntt_high by simp + have fqmul_cancel: \fqmul_int ?b zf mod MLKEM_Q = ?b * ?Zf mod MLKEM_Q\ + by (rule fqmul_zeta_cancel[OF assms(5)]) + have neg2_Q_mod: \(-2) * (x::int) * (MLKEM_Q - 1) mod MLKEM_Q = 2 * x mod MLKEM_Q\ for x + by (simp add: mod_eq_dvd_iff algebra_simps) + have key: \fqmul_int (- 2 * fqmul_int ?b zf) zi mod MLKEM_Q = 2 * ?b mod MLKEM_Q\ + proof - + have \fqmul_int (- 2 * fqmul_int ?b zf) zi mod MLKEM_Q = + (- 2 * fqmul_int ?b zf) * ?Zi mod MLKEM_Q\ + by (rule fqmul_zeta_cancel[OF assms(6)]) + also have \\ = (-2) * (?b * ?Zf) * ?Zi mod MLKEM_Q\ + proof - + have \(- 2 * fqmul_int ?b zf) * ?Zi mod MLKEM_Q = + ((- 2 * fqmul_int ?b zf mod MLKEM_Q) * ?Zi) mod MLKEM_Q\ + by (rule mod_mult_left_eq[symmetric]) + also have \- 2 * fqmul_int ?b zf mod MLKEM_Q = (- 2) * (?b * ?Zf) mod MLKEM_Q\ + proof - + have \- 2 * fqmul_int ?b zf mod MLKEM_Q = + ((- 2) * (fqmul_int ?b zf mod MLKEM_Q)) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = ((- 2) * (?b * ?Zf mod MLKEM_Q)) mod MLKEM_Q\ + using fqmul_cancel by simp + also have \\ = (- 2) * (?b * ?Zf) mod MLKEM_Q\ + by (rule mod_mult_right_eq) + finally show ?thesis . + qed + hence \((- 2 * fqmul_int ?b zf mod MLKEM_Q) * ?Zi) mod MLKEM_Q = + (((- 2) * (?b * ?Zf) mod MLKEM_Q) * ?Zi) mod MLKEM_Q\ + by simp + also have \\ = ((- 2) * (?b * ?Zf) * ?Zi) mod MLKEM_Q\ + by (rule mod_mult_left_eq) + finally show ?thesis + by simp + qed + also have \\ = (-2) * ?b * (?Zf * ?Zi) mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = (-2) * ?b * (MLKEM_Q - 1) mod MLKEM_Q\ + proof - + have \(-2) * ?b * (?Zf * ?Zi) mod MLKEM_Q = + ((-2) * ?b * (?Zf * ?Zi mod MLKEM_Q)) mod MLKEM_Q\ + by (rule mod_mult_right_eq[symmetric]) + also have \\ = ((-2) * ?b * (MLKEM_Q - 1)) mod MLKEM_Q\ + using assms(4) by simp + finally show ?thesis . + qed + also have \\ = 2 * ?b mod MLKEM_Q\ + by (rule neg2_Q_mod) + finally show ?thesis . + qed + have \invntt_inner_loop_int zi off blen cnt ntt_cs ! (off + m + blen) mod MLKEM_Q = + fqmul_int (- 2 * fqmul_int ?b zf) zi mod MLKEM_Q\ + using inv_val diff_eq by simp + also have \\ = 2 * ?b mod MLKEM_Q\ + by (rule key) + finally show ?thesis + using im1 im2 not_low by (simp add: ntt_cs_def) + qed + have case_out: \i \ {off.. i \ {off+blen.. ?thesis\ + proof - + assume nl: \i \ {off.. + and nh: \i \ {off+blen.. + have \invntt_inner_loop_int zi off blen cnt ntt_cs ! i = ntt_cs ! i\ + by (rule invntt_inner_loop_int_nth_unchanged) (use nl nh in auto) + also have \\ = cs ! i\ + unfolding ntt_cs_def by (rule ntt_inner_loop_int_nth_unchanged) (use nl nh in auto) + finally show ?thesis + using nl nh by (auto simp: ntt_cs_def) + qed + show ?thesis + using case_low case_high case_out by blast +qed + +subsection \Tier 5 — Layer-Level Inverse\ + +text \A full NTT layer composed with the corresponding inverse NTT layer + (or vice versa) scales every coefficient by 2 mod Q.\ + +text \Inner loop congruence: if two lists agree on the block positions, + the inner loop results agree on those positions.\ + +lemma ntt_inner_loop_int_cong: + assumes agree: \\p. off \ p \ p < off + 2 * blen \ xs ! p = ys ! p\ + and len_xs: \off + 2 * blen \ length xs\ + and len_ys: \off + 2 * blen \ length ys\ + and i_lo: \off \ i\ + and i_hi: \i < off + 2 * blen\ + shows \ntt_inner_loop_int z off blen blen xs ! i = + ntt_inner_loop_int z off blen blen ys ! i\ +proof (cases \i < off + blen\) + case True + define m where + \m = i - off\ + have im: \i = off + m\ and mb: \m < blen\ + using i_lo True by (auto simp: m_def) + have eq1: \xs ! (off + m) = ys ! (off + m)\ + using agree mb by auto + have eq2: \xs ! (off + m + blen) = ys ! (off + m + blen)\ + using agree mb by auto + show ?thesis unfolding im + using ntt_inner_loop_int_low_val[OF mb le_refl len_xs] + ntt_inner_loop_int_low_val[OF mb le_refl len_ys] eq1 eq2 by simp +next + case False + define m where + \m = i - off - blen\ + have im: \i = off + m + blen\ and mb: \m < blen\ + using i_lo i_hi False by (auto simp: m_def) + have eq1: \xs ! (off + m) = ys ! (off + m)\ + using agree mb by auto + have eq2: \xs ! (off + m + blen) = ys ! (off + m + blen)\ + using agree mb by auto + show ?thesis unfolding im + using ntt_inner_loop_int_high_val[OF mb le_refl len_xs] + ntt_inner_loop_int_high_val[OF mb le_refl len_ys] eq1 eq2 by simp +qed + +lemma invntt_inner_loop_int_cong: + assumes agree: \\p. off \ p \ p < off + 2 * blen \ xs ! p = ys ! p\ + and len_xs: \off + 2 * blen \ length xs\ + and len_ys: \off + 2 * blen \ length ys\ + and i_lo: \off \ i\ + and i_hi: \i < off + 2 * blen\ + shows \invntt_inner_loop_int z off blen blen xs ! i = + invntt_inner_loop_int z off blen blen ys ! i\ +proof (cases \i < off + blen\) + case True + define m where + \m = i - off\ + have im: \i = off + m\ and mb: \m < blen\ + using i_lo True by (auto simp: m_def) + have eq1: \xs ! (off + m) = ys ! (off + m)\ + using agree mb by auto + have eq2: \xs ! (off + m + blen) = ys ! (off + m + blen)\ + using agree mb by auto + show ?thesis + unfolding im using invntt_inner_loop_int_low_val[OF mb le_refl len_xs] + invntt_inner_loop_int_low_val[OF mb le_refl len_ys] eq1 eq2 by simp +next + case False + define m where + \m = i - off - blen\ + have im: \i = off + m + blen\ and mb: \m < blen\ + using i_lo i_hi False by (auto simp: m_def) + have eq1: \xs ! (off + m) = ys ! (off + m)\ + using agree mb by auto + have eq2: \xs ! (off + m + blen) = ys ! (off + m + blen)\ + using agree mb by auto + show ?thesis + unfolding im using invntt_inner_loop_int_high_val[OF mb le_refl len_xs] + invntt_inner_loop_int_high_val[OF mb le_refl len_ys] eq1 eq2 by simp +qed + +text \Middle loop block decomposition: the full middle loop at a position + in block j equals just the inner loop for block j applied to the original list.\ + +lemma ntt_middle_loop_at_block: + assumes \j < nb\ \nb * (2 * blen) \ length cs\ + and \j * (2 * blen) \ i\ + and \i < Suc j * (2 * blen)\ + shows \snd (ntt_middle_loop_int k blen nb nb cs) ! i = + ntt_inner_loop_int (zetas_int ! (k + j)) (j * (2 * blen)) blen blen cs ! i\ +using assms proof (induction nb arbitrary: k cs) + case 0 then + show ?case by simp +next + case (Suc nb') + have snoc: \snd (ntt_middle_loop_int k blen (Suc nb') (Suc nb') cs) = + ntt_inner_loop_int (zetas_int ! (k + nb')) (nb' * (2 * blen)) blen blen + (snd (ntt_middle_loop_int k blen nb' nb' cs))\ + by (rule ntt_middle_loop_int_snoc) + show ?case + proof (cases \j < nb'\) + case True + hence \Suc j \ nb'\ + by simp + hence \Suc j * (2 * blen) \ nb' * (2 * blen)\ + by (rule mult_le_mono1) + with Suc.prems(4) have i_bound: \i < nb' * (2 * blen)\ by simp + have i_not_lo: \i \ {nb' * (2 * blen).. + using i_bound by auto + have i_not_hi: \i \ {nb' * (2 * blen) + blen.. + using i_bound by auto + have len_nb': \nb' * (2 * blen) \ length cs\ + using Suc.prems(2) by simp + have \snd (ntt_middle_loop_int k blen (Suc nb') (Suc nb') cs) ! i = + snd (ntt_middle_loop_int k blen nb' nb' cs) ! i\ + unfolding snoc by (rule ntt_inner_loop_int_nth_unchanged[OF i_not_lo i_not_hi]) + also have \\ = ntt_inner_loop_int (zetas_int ! (k + j)) (j * (2 * blen)) blen blen cs ! i\ + by (rule Suc.IH[OF True len_nb' Suc.prems(3) Suc.prems(4)]) + finally show ?thesis . + next + case False + hence j_eq: \j = nb'\ using Suc.prems(1) + by simp + define prev where + \prev = snd (ntt_middle_loop_int k blen nb' nb' cs)\ + have agree: \\p. nb' * (2 * blen) \ p \ p < nb' * (2 * blen) + 2 * blen \ prev ! p = cs ! p\ + proof (intro allI impI) + fix p + assume \nb' * (2 * blen) \ p\ \p < nb' * (2 * blen) + 2 * blen\ + show \prev ! p = cs ! p\ + unfolding prev_def by (rule ntt_middle_loop_int_nth_unchanged[OF \nb' * (2 * blen) \ p\]) + qed + have len_prev: \nb' * (2 * blen) + 2 * blen \ length prev\ + unfolding prev_def using Suc.prems(2) by (simp add: ntt_middle_loop_int_length) + have len_cs: \nb' * (2 * blen) + 2 * blen \ length cs\ + using Suc.prems(2) by simp + show ?thesis + unfolding snoc j_eq prev_def[symmetric] + by (rule ntt_inner_loop_int_cong[OF agree len_prev len_cs]) + (use Suc.prems(3,4) j_eq in auto) + qed +qed + +lemma invntt_middle_loop_at_block: + assumes \j < nb\ \nb * (2 * blen) \ length cs\ + and \j * (2 * blen) \ i\ + and \i < Suc j * (2 * blen)\ + shows \snd (invntt_middle_loop_int k blen nb nb cs) ! i = + invntt_inner_loop_int (zetas_int ! (k - j)) (j * (2 * blen)) blen blen cs ! i\ +using assms proof (induction nb arbitrary: k cs) + case 0 + then show ?case by simp +next + case (Suc nb') + have snoc: \snd (invntt_middle_loop_int k blen (Suc nb') (Suc nb') cs) = + invntt_inner_loop_int (zetas_int ! (k - nb')) (nb' * (2 * blen)) blen blen + (snd (invntt_middle_loop_int k blen nb' nb' cs))\ + by (rule invntt_middle_loop_int_snoc) + show ?case + proof (cases \j < nb'\) + case True + hence \Suc j \ nb'\ + by simp + hence \Suc j * (2 * blen) \ nb' * (2 * blen)\ + by (rule mult_le_mono1) + with Suc.prems(4) have i_bound: \i < nb' * (2 * blen)\ + by simp + have i_not_lo: \i \ {nb' * (2 * blen).. + using i_bound by auto + have i_not_hi: \i \ {nb' * (2 * blen) + blen.. + using i_bound by auto + have len_nb': \nb' * (2 * blen) \ length cs\ + using Suc.prems(2) by simp + have \snd (invntt_middle_loop_int k blen (Suc nb') (Suc nb') cs) ! i = + snd (invntt_middle_loop_int k blen nb' nb' cs) ! i\ + unfolding snoc by (rule invntt_inner_loop_int_nth_unchanged[OF i_not_lo i_not_hi]) + also have \\ = invntt_inner_loop_int (zetas_int ! (k - j)) (j * (2 * blen)) blen blen cs ! i\ + by (rule Suc.IH[OF True len_nb' Suc.prems(3) Suc.prems(4)]) + finally show ?thesis . + next + case False + hence j_eq: \j = nb'\ + using Suc.prems(1) by simp + define prev where + \prev = snd (invntt_middle_loop_int k blen nb' nb' cs)\ + have agree: \\p. nb' * (2 * blen) \ p \ p < nb' * (2 * blen) + 2 * blen \ prev ! p = cs ! p\ + proof (intro allI impI) + fix p + assume \nb' * (2 * blen) \ p\ \p < nb' * (2 * blen) + 2 * blen\ + show \prev ! p = cs ! p\ + unfolding prev_def by (rule invntt_middle_loop_int_nth_unchanged[OF \nb' * (2 * blen) \ p\]) + qed + have len_prev: \nb' * (2 * blen) + 2 * blen \ length prev\ + unfolding prev_def using Suc.prems(2) by (simp add: invntt_middle_loop_int_length) + have len_cs: \nb' * (2 * blen) + 2 * blen \ length cs\ + using Suc.prems(2) by simp + show ?thesis + unfolding snoc j_eq prev_def[symmetric] + by (rule invntt_inner_loop_int_cong[OF agree len_prev len_cs]) + (use Suc.prems(3,4) j_eq in auto) + qed +qed + +text \Arithmetic helper: the number of blocks times the block size equals @{text MLKEM_N}.\ + +lemma ntt_nb_blen_eq: + assumes \1 \ l\ + and \l \ 7\ + shows \(2::nat)^(l-1) * (2 * 2^(8-l)) = MLKEM_N\ +proof - + have \(2::nat)^(l-1) * (2 * 2^(8-l)) = 2^(l-1) * 2^(9-l)\ + proof - + have \(9::nat) - l = Suc (8 - l)\ + using assms by auto + thus ?thesis + by simp + qed + also have \\ = (2::nat)^((l-1) + (9-l))\ + by (simp add: power_add) + also have \(l-1) + (9-l) = (8::nat)\ + using assms by auto + finally show ?thesis + by simp +qed + +lemma ntt_invntt_layer_inverse: + assumes \length cs = MLKEM_N\ + and \1 \ l\ + and \l \ 7\ + and \i < MLKEM_N\ + shows \ntt_layer_int l (invntt_layer_int l cs) ! i mod MLKEM_Q = + 2 * cs ! i mod MLKEM_Q\ +proof - + define nb where + \nb = (2::nat)^(l-1)\ + define blen where + \blen = (2::nat)^(8-l)\ + define inv_cs where + \inv_cs = invntt_layer_int l cs\ + define j where + \j = i div (2 * blen)\ + define ki where + \ki = (2::nat)^l - 1 - j\ + define local_inv where + \local_inv = invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen cs\ + have nb_blen: \nb * (2 * blen) = MLKEM_N\ + using ntt_nb_blen_eq[OF assms(2) assms(3)] by (simp add: nb_def blen_def) + have i_lt: \i < nb * (2 * blen)\ + using assms(4) nb_blen by linarith + have j_lt_nb: \j < nb\ + unfolding j_def using i_lt by (simp add: less_mult_imp_div_less) + have sj_le_nb: \Suc j \ nb\ + using j_lt_nb by simp + have sj_le: \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + have j_lo: \j * (2 * blen) \ i\ + unfolding j_def by simp + have j_hi_aux: \i < j * (2 * blen) + 2 * blen\ + proof - + have blen_pos: \(0::nat) < 2 * blen\ + by (simp add: blen_def) + have mod_bound: \i mod (2 * blen) < 2 * blen\ + using blen_pos by simp + have eq: \i div (2 * blen) * (2 * blen) + i mod (2 * blen) = i\ + by (rule div_mult_mod_eq) + show ?thesis unfolding j_def using mod_bound eq by linarith + qed + hence j_hi: \i < Suc j * (2 * blen)\ + by simp + have len_cs: \length cs = MLKEM_N\ + by (rule assms(1)) + have block_len_cs: \j * (2 * blen) + 2 * blen \ length cs\ + using sj_le nb_blen len_cs by simp + have off_bound: \j * (2 * blen) + 2 * blen \ MLKEM_N\ + using block_len_cs assms(1) by simp + have l_suc: \Suc (l - 1) = l\ + using assms(2) by simp + have two_nb: \2 * nb = 2^l\ + proof - + have \(2::nat) * 2^(l-1) = 2^Suc (l-1)\ by simp + also have \Suc (l-1) = l\ using assms(2) by simp + finally show ?thesis by (simp add: nb_def) + qed + have l_le_8: \l \ 8\ + using assms(3) by simp + have two_l_le: \(2::nat)^l \ 128\ + proof - + have \(2::nat)^l \ 2^7\ + by (rule power_increasing) (use assms(3) in auto) + thus ?thesis by simp + qed + have nb_j_lt: \nb + j < 128\ + using j_lt_nb two_nb two_l_le by linarith + have ki_bound: \ki < 2^l\ + unfolding ki_def using j_lt_nb two_nb by linarith + have ki_j_lt: \ki < 128\ + using ki_bound two_l_le by linarith + have ki_eq: \ki = 2^l - 1 - j\ + unfolding ki_def by simp + have j_lt_pow: \j < 2^(l-1)\ + using j_lt_nb nb_def by simp + have ntt_decomp: \ntt_layer_int l inv_cs ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen inv_cs ! i\ + proof - + have \snd (ntt_middle_loop_int nb blen nb nb inv_cs) ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen inv_cs ! i\ + by (rule ntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen inv_cs_def invntt_layer_int_length assms(1)) + thus ?thesis + by (simp add: ntt_layer_int_def nb_def blen_def inv_cs_def) + qed + have inv_agree: \inv_cs ! p = local_inv ! p\ + if \j * (2 * blen) \ p\ \p < j * (2 * blen) + 2 * blen\ for p + proof - + have step1: \snd (invntt_middle_loop_int (2^l - 1) blen nb nb cs) ! p = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen cs ! p\ + by (rule invntt_middle_loop_at_block[OF j_lt_nb _ that(1)]) + (use nb_blen len_cs in simp, use that(2) in simp) + have \invntt_layer_int l cs ! p = + invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen cs ! p\ + using step1 by (simp add: invntt_layer_int_def nb_def blen_def ki_def) + thus ?thesis + by (simp add: inv_cs_def local_inv_def) + qed + have cong_step: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen inv_cs ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen local_inv ! i\ + proof (rule ntt_inner_loop_int_cong[OF _ _ _ j_lo]) + show \\p. j * (2 * blen) \ p \ p < j * (2 * blen) + 2 * blen \ inv_cs ! p = local_inv ! p\ + using inv_agree by blast + show \j * (2 * blen) + 2 * blen \ length inv_cs\ + unfolding inv_cs_def using invntt_layer_int_length assms(1) off_bound by simp + show \j * (2 * blen) + 2 * blen \ length local_inv\ + unfolding local_inv_def using invntt_inner_loop_int_length assms(1) off_bound by simp + show \i < j * (2 * blen) + 2 * blen\ + using j_hi_aux by simp + qed + have br_sum: \bit_reverse 7 (nb + j) + bit_reverse 7 ki = 128\ + unfolding nb_def ki_eq + by (rule bit_reverse_complement[OF j_lt_pow assms(2) assms(3)]) + have zeta_cancel: \mlkem_zeta ^ bit_reverse 7 (nb + j) * mlkem_zeta ^ bit_reverse 7 ki mod MLKEM_Q = MLKEM_Q - 1\ + by (rule zeta_power_sum_128[OF br_sum]) + have zf_mod: \zetas_int ! (nb + j) mod MLKEM_Q = mlkem_zeta ^ bit_reverse 7 (nb + j) * 2^16 mod MLKEM_Q\ + by (rule zetas_int_mod[OF nb_j_lt]) + have zi_mod: \zetas_int ! ki mod MLKEM_Q = mlkem_zeta ^ bit_reverse 7 ki * 2^16 mod MLKEM_Q\ + by (rule zetas_int_mod[OF ki_j_lt]) + have i_lt_cs: \i < length cs\ + using assms(1) assms(4) by simp + have comp: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen local_inv ! i mod MLKEM_Q = + 2 * cs ! i mod MLKEM_Q\ + proof - + have \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen + (invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen cs) ! i mod MLKEM_Q = + (if i \ {j * (2 * blen).. + i \ {j * (2 * blen) + blen.. + by (rule inner_loop_composition_invntt_ntt[OF assms(1) off_bound le_refl + zeta_cancel zf_mod zi_mod i_lt_cs]) + moreover have \i \ {j * (2 * blen).. + i \ {j * (2 * blen) + blen.. + using j_lo j_hi_aux by auto + ultimately show ?thesis + unfolding local_inv_def by simp + qed + have \ntt_layer_int l (invntt_layer_int l cs) ! i mod MLKEM_Q = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen local_inv ! i mod MLKEM_Q\ + unfolding inv_cs_def[symmetric] using ntt_decomp cong_step by simp + also have \\ = 2 * cs ! i mod MLKEM_Q\ + by (rule comp) + finally show ?thesis . +qed + +lemma invntt_ntt_layer_inverse: + assumes \length cs = MLKEM_N\ + and \1 \ l\ + and \l \ 7\ + and \i < MLKEM_N\ + shows \invntt_layer_int l (ntt_layer_int l cs) ! i mod MLKEM_Q = + 2 * cs ! i mod MLKEM_Q\ +proof - + define nb where + \nb = (2::nat)^(l-1)\ + define blen where + \blen = (2::nat)^(8-l)\ + define ntt_cs where + \ntt_cs = ntt_layer_int l cs\ + define j where + \j = i div (2 * blen)\ + define ki where + \ki = (2::nat)^l - 1 - j\ + define local_ntt where + \local_ntt = ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs\ + have nb_blen: \nb * (2 * blen) = MLKEM_N\ + using ntt_nb_blen_eq[OF assms(2) assms(3)] by (simp add: nb_def blen_def) + have i_lt: \i < nb * (2 * blen)\ + using assms(4) nb_blen by linarith + have j_lt_nb: \j < nb\ + unfolding j_def using i_lt by (simp add: less_mult_imp_div_less) + have sj_le_nb: \Suc j \ nb\ + using j_lt_nb by simp + have sj_le: \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + have j_lo: \j * (2 * blen) \ i\ + unfolding j_def by simp + have j_hi_aux: \i < j * (2 * blen) + 2 * blen\ + proof - + have blen_pos: \(0::nat) < 2 * blen\ + by (simp add: blen_def) + have mod_bound: \i mod (2 * blen) < 2 * blen\ + using blen_pos by simp + have eq: \i div (2 * blen) * (2 * blen) + i mod (2 * blen) = i\ + by (rule div_mult_mod_eq) + show ?thesis + unfolding j_def using mod_bound eq by linarith + qed + hence j_hi: \i < Suc j * (2 * blen)\ + by simp + have len_cs: \length cs = MLKEM_N\ + by (rule assms(1)) + have block_len_cs: \j * (2 * blen) + 2 * blen \ length cs\ + using sj_le nb_blen len_cs by simp + have off_bound: \j * (2 * blen) + 2 * blen \ MLKEM_N\ + using block_len_cs assms(1) by simp + have l_suc: \Suc (l - 1) = l\ + using assms(2) by simp + have two_nb: \2 * nb = 2^l\ + proof - + have \(2::nat) * 2^(l-1) = 2^Suc (l-1)\ + by simp + also have \Suc (l-1) = l\ + using assms(2) by simp + finally show ?thesis + by (simp add: nb_def) + qed + have two_l_le: \(2::nat)^l \ 128\ + proof - + have \(2::nat)^l \ 2^7\ + by (rule power_increasing) (use assms(3) in auto) + thus ?thesis + by simp + qed + have nb_j_lt: \nb + j < 128\ + using j_lt_nb two_nb two_l_le by linarith + have ki_bound: \ki < 2^l\ + unfolding ki_def using j_lt_nb two_nb by linarith + have ki_j_lt: \ki < 128\ + using ki_bound two_l_le by linarith + have ki_eq: \ki = 2^l - 1 - j\ + unfolding ki_def by simp + have j_lt_pow: \j < 2^(l-1)\ + using j_lt_nb nb_def by simp + have invntt_decomp: \invntt_layer_int l ntt_cs ! i = + invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen ntt_cs ! i\ + proof - + have step1: \snd (invntt_middle_loop_int (2^l - 1) blen nb nb ntt_cs) ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen ntt_cs ! i\ + by (rule invntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen ntt_cs_def ntt_layer_int_length assms(1)) + thus ?thesis + by (simp add: invntt_layer_int_def nb_def blen_def ntt_cs_def ki_def) + qed + have ntt_agree: \ntt_cs ! p = local_ntt ! p\ + if \j * (2 * blen) \ p\ \p < j * (2 * blen) + 2 * blen\ for p + proof - + have step1: \snd (ntt_middle_loop_int nb blen nb nb cs) ! p = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs ! p\ + by (rule ntt_middle_loop_at_block[OF j_lt_nb _ that(1)]) + (use nb_blen len_cs in simp, use that(2) in simp) + have \ntt_layer_int l cs ! p = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs ! p\ + using step1 by (simp add: ntt_layer_int_def nb_def blen_def) + thus ?thesis + by (simp add: ntt_cs_def local_ntt_def) + qed + have cong_step: \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen ntt_cs ! i = + invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen local_ntt ! i\ + proof (rule invntt_inner_loop_int_cong[OF _ _ _ j_lo]) + show \\p. j * (2 * blen) \ p \ p < j * (2 * blen) + 2 * blen \ ntt_cs ! p = local_ntt ! p\ + using ntt_agree by blast + show \j * (2 * blen) + 2 * blen \ length ntt_cs\ + unfolding ntt_cs_def using ntt_layer_int_length assms(1) off_bound by simp + show \j * (2 * blen) + 2 * blen \ length local_ntt\ + unfolding local_ntt_def using ntt_inner_loop_int_length assms(1) off_bound by simp + show \i < j * (2 * blen) + 2 * blen\ + using j_hi_aux by simp + qed + have br_sum: \bit_reverse 7 (nb + j) + bit_reverse 7 ki = 128\ + unfolding nb_def ki_eq by (rule bit_reverse_complement[OF j_lt_pow assms(2) assms(3)]) + have zeta_cancel: \mlkem_zeta ^ bit_reverse 7 (nb + j) * mlkem_zeta ^ bit_reverse 7 ki mod MLKEM_Q = MLKEM_Q - 1\ + by (rule zeta_power_sum_128[OF br_sum]) + have zf_mod: \zetas_int ! (nb + j) mod MLKEM_Q = mlkem_zeta ^ bit_reverse 7 (nb + j) * 2^16 mod MLKEM_Q\ + by (rule zetas_int_mod[OF nb_j_lt]) + have zi_mod: \zetas_int ! ki mod MLKEM_Q = mlkem_zeta ^ bit_reverse 7 ki * 2^16 mod MLKEM_Q\ + by (rule zetas_int_mod[OF ki_j_lt]) + have i_lt_cs: \i < length cs\ + using assms(1) assms(4) by simp + have comp: \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen local_ntt ! i mod MLKEM_Q = + 2 * cs ! i mod MLKEM_Q\ + proof - + have \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen + (ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs) ! i mod MLKEM_Q = + (if i \ {j * (2 * blen).. + i \ {j * (2 * blen) + blen.. + by (rule inner_loop_composition_ntt_invntt[OF assms(1) off_bound le_refl + zeta_cancel zf_mod zi_mod i_lt_cs]) + moreover have \i \ {j * (2 * blen).. + i \ {j * (2 * blen) + blen.. + using j_lo j_hi_aux by auto + ultimately show ?thesis + unfolding local_ntt_def by simp + qed + have \invntt_layer_int l (ntt_layer_int l cs) ! i mod MLKEM_Q = + invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen local_ntt ! i mod MLKEM_Q\ + unfolding ntt_cs_def[symmetric] using invntt_decomp cong_step by simp + also have \\ = 2 * cs ! i mod MLKEM_Q\ + by (rule comp) + finally show ?thesis . +qed + +subsection \Tier 6 — Linearity\ + +text \NTT layers are linear mod Q: scaling inputs by a constant scales + outputs by the same constant. Needed to commute the Montgomery + prescaling past the NTT layers in one direction of the proof.\ + +text \Helper: fqmul distributes scaling mod Q.\ + +lemma fqmul_int_linear_mod: + shows \fqmul_int (a * k) b mod MLKEM_Q = fqmul_int a b * k mod MLKEM_Q\ +proof - + have lhs: \fqmul_int (a * k) b * 2^16 mod MLKEM_Q = a * k * b mod MLKEM_Q\ + by (rule fqmul_int_cong) + have rhs: \fqmul_int a b * 2^16 mod MLKEM_Q = a * b mod MLKEM_Q\ + by (rule fqmul_int_cong) + have eq16: \fqmul_int (a * k) b * 2^16 mod MLKEM_Q = fqmul_int a b * k * 2^16 mod MLKEM_Q\ + proof - + have \fqmul_int a b * 2^16 * k mod MLKEM_Q = a * b * k mod MLKEM_Q\ + by (rule mod_mult_cong[OF rhs refl]) + thus ?thesis + using lhs by (simp add: algebra_simps) + qed + have cop: \coprime MLKEM_Q ((2::int)^16)\ + by eval + show ?thesis + by (rule mult_mod_cancel_right[OF eq16 cop]) +qed + +lemma ntt_layer_int_linear_mod: + assumes \length cs = MLKEM_N\ + and \1 \ l\ + and \l \ 7\ + and \i < MLKEM_N\ + shows \ntt_layer_int l (List.map (\c. c * k) cs) ! i mod MLKEM_Q = + ntt_layer_int l cs ! i * k mod MLKEM_Q\ +proof - + define nb where + \nb = (2::nat)^(l-1)\ + define blen where + \blen = (2::nat)^(8-l)\ + define scaled where + \scaled = List.map (\c. c * k) cs\ + define j where + \j = i div (2 * blen)\ + have nb_blen: \nb * (2 * blen) = MLKEM_N\ + using ntt_nb_blen_eq[OF assms(2) assms(3)] unfolding nb_def blen_def . + have len_scaled: \length scaled = MLKEM_N\ + unfolding scaled_def using assms(1) by simp + have i_lt: \i < nb * (2 * blen)\ + using assms(4) nb_blen by linarith + have j_lt_nb: \j < nb\ + unfolding j_def using i_lt by (simp add: less_mult_imp_div_less) + have sj_le_nb: \Suc j \ nb\ + using j_lt_nb by simp + have j_lo: \j * (2 * blen) \ i\ + unfolding j_def by simp + have j_hi_aux: \i < j * (2 * blen) + 2 * blen\ + proof - + have blen_pos: \(0::nat) < 2 * blen\ + by (simp add: blen_def) + have \i mod (2 * blen) < 2 * blen\ + using blen_pos by simp + have \i div (2 * blen) * (2 * blen) + i mod (2 * blen) = i\ + by (rule div_mult_mod_eq) + thus ?thesis unfolding j_def using \i mod (2 * blen) < 2 * blen\ + by linarith + qed + hence j_hi: \i < Suc j * (2 * blen)\ + by simp + have block_len_cs: \j * (2 * blen) + 2 * blen \ length cs\ + proof - + have \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + thus ?thesis + using nb_blen assms(1) by simp + qed + have block_len_scaled: \j * (2 * blen) + 2 * blen \ length scaled\ + using block_len_cs len_scaled assms(1) by simp + \ \Decompose layer to inner loop for block j\ + have decomp_scaled: \ntt_layer_int l scaled ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen scaled ! i\ + proof - + have \snd (ntt_middle_loop_int nb blen nb nb scaled) ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen scaled ! i\ + by (rule ntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen len_scaled) + thus ?thesis + unfolding ntt_layer_int_def nb_def blen_def . + qed + have decomp_cs: \ntt_layer_int l cs ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs ! i\ + proof - + have \snd (ntt_middle_loop_int nb blen nb nb cs) ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs ! i\ + by (rule ntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen assms(1)) + thus ?thesis + unfolding ntt_layer_int_def nb_def blen_def . + qed + show ?thesis + proof (cases \i < j * (2 * blen) + blen\) + case True + define m where + \m = i - j * (2 * blen)\ + have im: \i = j * (2 * blen) + m\ and mb: \m < blen\ + using j_lo True by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < length cs\ + using block_len_cs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < length cs\ + using block_len_cs mb by linarith + have scaled_lo: \scaled ! (j * (2 * blen) + m) = cs ! (j * (2 * blen) + m) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_lo]) + have scaled_hi: \scaled ! (j * (2 * blen) + m + blen) = cs ! (j * (2 * blen) + m + blen) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_hi]) + have val_scaled: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen scaled ! i = + cs ! (j * (2 * blen) + m) * k + fqmul_int (cs ! (j * (2 * blen) + m + blen) * k) (zetas_int ! (nb + j))\ + unfolding im using ntt_inner_loop_int_low_val[OF mb le_refl block_len_scaled] + scaled_lo scaled_hi by simp + have val_cs: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs ! i = + cs ! (j * (2 * blen) + m) + fqmul_int (cs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))\ + unfolding im by (rule ntt_inner_loop_int_low_val[OF mb le_refl block_len_cs]) + have \ntt_layer_int l scaled ! i mod MLKEM_Q = + (cs ! (j * (2 * blen) + m) * k + + fqmul_int (cs ! (j * (2 * blen) + m + blen) * k) (zetas_int ! (nb + j))) mod MLKEM_Q\ + using decomp_scaled val_scaled by simp + also have \\ = (cs ! (j * (2 * blen) + m) * k + + fqmul_int (cs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) * k) mod MLKEM_Q\ + by (rule mod_add_cong[OF refl fqmul_int_linear_mod]) + also have \\ = (cs ! (j * (2 * blen) + m) + + fqmul_int (cs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))) * k mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = ntt_layer_int l cs ! i * k mod MLKEM_Q\ + using decomp_cs val_cs by simp + finally show ?thesis + unfolding scaled_def . + next + case False + hence i_ge: \j * (2 * blen) + blen \ i\ + by simp + define m where + \m = i - j * (2 * blen) - blen\ + have im: \i = j * (2 * blen) + m + blen\ and mb: \m < blen\ + using j_lo i_ge j_hi_aux by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < length cs\ + using block_len_cs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < length cs\ + using block_len_cs mb by linarith + have scaled_lo: \scaled ! (j * (2 * blen) + m) = cs ! (j * (2 * blen) + m) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_lo]) + have scaled_hi: \scaled ! (j * (2 * blen) + m + blen) = cs ! (j * (2 * blen) + m + blen) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_hi]) + have val_scaled: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen scaled ! i = + cs ! (j * (2 * blen) + m) * k - fqmul_int (cs ! (j * (2 * blen) + m + blen) * k) (zetas_int ! (nb + j))\ + unfolding im using ntt_inner_loop_int_high_val[OF mb le_refl block_len_scaled] + scaled_lo scaled_hi by simp + have val_cs: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen cs ! i = + cs ! (j * (2 * blen) + m) - fqmul_int (cs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))\ + unfolding im by (rule ntt_inner_loop_int_high_val[OF mb le_refl block_len_cs]) + have \ntt_layer_int l scaled ! i mod MLKEM_Q = + (cs ! (j * (2 * blen) + m) * k - + fqmul_int (cs ! (j * (2 * blen) + m + blen) * k) (zetas_int ! (nb + j))) mod MLKEM_Q\ + using decomp_scaled val_scaled by simp + also have \\ = (cs ! (j * (2 * blen) + m) * k - + fqmul_int (cs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) * k) mod MLKEM_Q\ + by (rule mod_diff_cong[OF refl fqmul_int_linear_mod]) + also have \\ = (cs ! (j * (2 * blen) + m) - + fqmul_int (cs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))) * k mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = ntt_layer_int l cs ! i * k mod MLKEM_Q\ + using decomp_cs val_cs by simp + finally show ?thesis + unfolding scaled_def . + qed +qed + +lemma invntt_layer_int_linear_mod: + assumes \length cs = MLKEM_N\ + and \1 \ l\ + and \l \ 7\ + and \i < MLKEM_N\ + shows \invntt_layer_int l (List.map (\c. c * k) cs) ! i mod MLKEM_Q = + invntt_layer_int l cs ! i * k mod MLKEM_Q\ +proof - + define nb where + \nb = (2::nat)^(l-1)\ + define blen where + \blen = (2::nat)^(8-l)\ + define scaled where + \scaled = List.map (\c. c * k) cs\ + define j where + \j = i div (2 * blen)\ + have nb_blen: \nb * (2 * blen) = MLKEM_N\ + using ntt_nb_blen_eq[OF assms(2) assms(3)] unfolding nb_def blen_def . + have len_scaled: \length scaled = MLKEM_N\ + unfolding scaled_def using assms(1) by simp + have i_lt: \i < nb * (2 * blen)\ + using assms(4) nb_blen by linarith + have j_lt_nb: \j < nb\ + unfolding j_def using i_lt by (simp add: less_mult_imp_div_less) + have sj_le_nb: \Suc j \ nb\ + using j_lt_nb by simp + have j_lo: \j * (2 * blen) \ i\ + unfolding j_def by simp + have j_hi_aux: \i < j * (2 * blen) + 2 * blen\ + proof - + have blen_pos: \(0::nat) < 2 * blen\ + by (simp add: blen_def) + have \i mod (2 * blen) < 2 * blen\ + using blen_pos by simp + have \i div (2 * blen) * (2 * blen) + i mod (2 * blen) = i\ + by (rule div_mult_mod_eq) + thus ?thesis + unfolding j_def using \i mod (2 * blen) < 2 * blen\ by linarith + qed + hence j_hi: \i < Suc j * (2 * blen)\ + by simp + have block_len_cs: \j * (2 * blen) + 2 * blen \ length cs\ + proof - + have \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + thus ?thesis + using nb_blen assms(1) by simp + qed + have block_len_scaled: \j * (2 * blen) + 2 * blen \ length scaled\ + using block_len_cs len_scaled assms(1) by simp + \ \Decompose layer to inner loop for block j\ + have decomp_scaled: \invntt_layer_int l scaled ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen scaled ! i\ + proof - + have \snd (invntt_middle_loop_int (2^l - 1) blen nb nb scaled) ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen scaled ! i\ + by (rule invntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen len_scaled) + thus ?thesis + unfolding invntt_layer_int_def nb_def blen_def . + qed + have decomp_cs: \invntt_layer_int l cs ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen cs ! i\ + proof - + have \snd (invntt_middle_loop_int (2^l - 1) blen nb nb cs) ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen cs ! i\ + by (rule invntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen assms(1)) + thus ?thesis + unfolding invntt_layer_int_def nb_def blen_def . + qed + show ?thesis + proof (cases \i < j * (2 * blen) + blen\) + case True + define m where + \m = i - j * (2 * blen)\ + have im: \i = j * (2 * blen) + m\ and mb: \m < blen\ + using j_lo True by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < length cs\ + using block_len_cs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < length cs\ + using block_len_cs mb by linarith + have scaled_lo: \scaled ! (j * (2 * blen) + m) = cs ! (j * (2 * blen) + m) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_lo]) + have scaled_hi: \scaled ! (j * (2 * blen) + m + blen) = cs ! (j * (2 * blen) + m + blen) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_hi]) + have val_scaled: \invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen scaled ! i = + barrett_reduce_int (cs ! (j * (2 * blen) + m) * k + cs ! (j * (2 * blen) + m + blen) * k)\ + unfolding im using invntt_inner_loop_int_low_val[OF mb le_refl block_len_scaled] + scaled_lo scaled_hi by simp + have val_cs: \invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen cs ! i = + barrett_reduce_int (cs ! (j * (2 * blen) + m) + cs ! (j * (2 * blen) + m + blen))\ + unfolding im by (rule invntt_inner_loop_int_low_val[OF mb le_refl block_len_cs]) + have factor: \cs ! (j * (2 * blen) + m) * k + cs ! (j * (2 * blen) + m + blen) * k = + (cs ! (j * (2 * blen) + m) + cs ! (j * (2 * blen) + m + blen)) * k\ + by (simp add: algebra_simps) + have \invntt_layer_int l scaled ! i mod MLKEM_Q = + barrett_reduce_int ((cs ! (j * (2 * blen) + m) + cs ! (j * (2 * blen) + m + blen)) * k) mod MLKEM_Q\ + using decomp_scaled val_scaled factor by simp + also have \\ = (cs ! (j * (2 * blen) + m) + cs ! (j * (2 * blen) + m + blen)) * k mod MLKEM_Q\ + by (rule barrett_reduce_mod) + also have \\ = barrett_reduce_int (cs ! (j * (2 * blen) + m) + cs ! (j * (2 * blen) + m + blen)) * k mod MLKEM_Q\ + by (rule mod_mult_cong[OF barrett_reduce_mod[symmetric] refl]) + also have \\ = invntt_layer_int l cs ! i * k mod MLKEM_Q\ + using decomp_cs val_cs by simp + finally show ?thesis + unfolding scaled_def . + next + case False + hence i_ge: \j * (2 * blen) + blen \ i\ + by simp + define m where + \m = i - j * (2 * blen) - blen\ + have im: \i = j * (2 * blen) + m + blen\ and mb: \m < blen\ + using j_lo i_ge j_hi_aux by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < length cs\ + using block_len_cs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < length cs\ + using block_len_cs mb by linarith + have scaled_lo: \scaled ! (j * (2 * blen) + m) = cs ! (j * (2 * blen) + m) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_lo]) + have scaled_hi: \scaled ! (j * (2 * blen) + m + blen) = cs ! (j * (2 * blen) + m + blen) * k\ + unfolding scaled_def by (simp add: nth_map[OF idx_hi]) + have val_scaled: \invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen scaled ! i = + fqmul_int (cs ! (j * (2 * blen) + m + blen) * k - cs ! (j * (2 * blen) + m) * k) (zetas_int ! (2^l - 1 - j))\ + unfolding im using invntt_inner_loop_int_high_val[OF mb le_refl block_len_scaled] + scaled_lo scaled_hi by simp + have val_cs: \invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen cs ! i = + fqmul_int (cs ! (j * (2 * blen) + m + blen) - cs ! (j * (2 * blen) + m)) (zetas_int ! (2^l - 1 - j))\ + unfolding im by (rule invntt_inner_loop_int_high_val[OF mb le_refl block_len_cs]) + have factor: \cs ! (j * (2 * blen) + m + blen) * k - cs ! (j * (2 * blen) + m) * k = + (cs ! (j * (2 * blen) + m + blen) - cs ! (j * (2 * blen) + m)) * k\ + by (simp add: algebra_simps) + have \invntt_layer_int l scaled ! i mod MLKEM_Q = + fqmul_int ((cs ! (j * (2 * blen) + m + blen) - cs ! (j * (2 * blen) + m)) * k) + (zetas_int ! (2^l - 1 - j)) mod MLKEM_Q\ + using decomp_scaled val_scaled factor by simp + also have \\ = fqmul_int (cs ! (j * (2 * blen) + m + blen) - cs ! (j * (2 * blen) + m)) + (zetas_int ! (2^l - 1 - j)) * k mod MLKEM_Q\ + by (rule fqmul_int_linear_mod) + also have \\ = invntt_layer_int l cs ! i * k mod MLKEM_Q\ + using decomp_cs val_cs by simp + finally show ?thesis + unfolding scaled_def . + qed +qed + +subsection \Tier 7 — Full Transform Composition\ + +text \Helper: @{const fqmul_int} preserves mod-Q congruence in the first argument.\ + +lemma fqmul_int_mod_cong: + assumes \a mod MLKEM_Q = b mod MLKEM_Q\ + shows \fqmul_int a z mod MLKEM_Q = fqmul_int b z mod MLKEM_Q\ +proof - + have la: \fqmul_int a z * 2^16 mod MLKEM_Q = a * z mod MLKEM_Q\ + by (rule fqmul_int_cong) + have lb: \fqmul_int b z * 2^16 mod MLKEM_Q = b * z mod MLKEM_Q\ + by (rule fqmul_int_cong) + have \a * z mod MLKEM_Q = b * z mod MLKEM_Q\ + using assms by (rule mod_mult_cong[OF _ refl]) + hence \fqmul_int a z * 2^16 mod MLKEM_Q = fqmul_int b z * 2^16 mod MLKEM_Q\ + using la lb by simp + thus ?thesis + by (rule mult_mod_cancel_right) eval +qed + +text \Combined congruence + linearity for a single NTT layer.\ + +lemma ntt_layer_int_linear_cong: + assumes cong: \\j < MLKEM_N. xs ! j mod MLKEM_Q = ys ! j * k mod MLKEM_Q\ + and len_xs: \length xs = MLKEM_N\ + and len_ys: \length ys = MLKEM_N\ + and l_ge: \1 \ l\ + and l_le: \l \ 7\ + and i_lt: \i < MLKEM_N\ + shows \ntt_layer_int l xs ! i mod MLKEM_Q = + ntt_layer_int l ys ! i * k mod MLKEM_Q\ +proof - + define nb where \nb = (2::nat)^(l-1)\ + define blen where \blen = (2::nat)^(8-l)\ + define j where \j = i div (2 * blen)\ + have nb_blen: \nb * (2 * blen) = MLKEM_N\ + using ntt_nb_blen_eq[OF l_ge l_le] unfolding nb_def blen_def . + have i_lt2: \i < nb * (2 * blen)\ + using i_lt nb_blen by linarith + have j_lt_nb: \j < nb\ + unfolding j_def using i_lt2 by (simp add: less_mult_imp_div_less) + have sj_le_nb: \Suc j \ nb\ + using j_lt_nb by simp + have j_lo: \j * (2 * blen) \ i\ + unfolding j_def by simp + have j_hi_aux: \i < j * (2 * blen) + 2 * blen\ + proof - + have \(0::nat) < 2 * blen\ by (simp add: blen_def) + have \i mod (2 * blen) < 2 * blen\ using \0 < 2 * blen\ by simp + have \i div (2 * blen) * (2 * blen) + i mod (2 * blen) = i\ + by (rule div_mult_mod_eq) + thus ?thesis unfolding j_def using \i mod (2 * blen) < 2 * blen\ + by linarith + qed + hence j_hi: \i < Suc j * (2 * blen)\ by simp + have block_len_xs: \j * (2 * blen) + 2 * blen \ length xs\ + proof - + have \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + thus ?thesis using nb_blen len_xs by simp + qed + have block_len_ys: \j * (2 * blen) + 2 * blen \ length ys\ + proof - + have \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + thus ?thesis using nb_blen len_ys by simp + qed + have decomp_xs: \ntt_layer_int l xs ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen xs ! i\ + proof - + have \snd (ntt_middle_loop_int nb blen nb nb xs) ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen xs ! i\ + by (rule ntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen len_xs) + thus ?thesis + unfolding ntt_layer_int_def nb_def blen_def . + qed + have decomp_ys: \ntt_layer_int l ys ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen ys ! i\ + proof - + have \snd (ntt_middle_loop_int nb blen nb nb ys) ! i = + ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen ys ! i\ + by (rule ntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen len_ys) + thus ?thesis + unfolding ntt_layer_int_def nb_def blen_def . + qed + show ?thesis + proof (cases \i < j * (2 * blen) + blen\) + case True + define m where \m = i - j * (2 * blen)\ + have im: \i = j * (2 * blen) + m\ and mb: \m < blen\ + using j_lo True by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have cong_lo: \xs ! (j * (2 * blen) + m) mod MLKEM_Q = + ys ! (j * (2 * blen) + m) * k mod MLKEM_Q\ + using cong idx_lo by auto + have cong_hi: \xs ! (j * (2 * blen) + m + blen) mod MLKEM_Q = + ys ! (j * (2 * blen) + m + blen) * k mod MLKEM_Q\ + using cong idx_hi by auto + have fqmul_cong: \fqmul_int (xs ! (j * (2 * blen) + m + blen)) + (zetas_int ! (nb + j)) mod MLKEM_Q = + fqmul_int (ys ! (j * (2 * blen) + m + blen)) + (zetas_int ! (nb + j)) * k mod MLKEM_Q\ + proof - + have \fqmul_int (xs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) mod MLKEM_Q = + fqmul_int (ys ! (j * (2 * blen) + m + blen) * k) (zetas_int ! (nb + j)) mod MLKEM_Q\ + by (rule fqmul_int_mod_cong[OF cong_hi]) + also have \\ = fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) * k mod MLKEM_Q\ + by (rule fqmul_int_linear_mod) + finally show ?thesis . + qed + have val_xs: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen xs ! i = + xs ! (j * (2 * blen) + m) + + fqmul_int (xs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))\ + unfolding im by (rule ntt_inner_loop_int_low_val[OF mb le_refl block_len_xs]) + have val_ys: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen ys ! i = + ys ! (j * (2 * blen) + m) + + fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))\ + unfolding im by (rule ntt_inner_loop_int_low_val[OF mb le_refl block_len_ys]) + have \ntt_layer_int l xs ! i mod MLKEM_Q = + (xs ! (j * (2 * blen) + m) + + fqmul_int (xs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))) mod MLKEM_Q\ + using decomp_xs val_xs by simp + also have \\ = (ys ! (j * (2 * blen) + m) * k + + fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) * k) mod MLKEM_Q\ + by (rule mod_add_cong[OF cong_lo fqmul_cong]) + also have \\ = (ys ! (j * (2 * blen) + m) + + fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))) * k mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = ntt_layer_int l ys ! i * k mod MLKEM_Q\ + using decomp_ys val_ys by simp + finally show ?thesis . + next + case False + hence i_ge: \j * (2 * blen) + blen \ i\ by simp + define m where \m = i - j * (2 * blen) - blen\ + have im: \i = j * (2 * blen) + m + blen\ and mb: \m < blen\ + using j_lo i_ge j_hi_aux by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have cong_lo: \xs ! (j * (2 * blen) + m) mod MLKEM_Q = + ys ! (j * (2 * blen) + m) * k mod MLKEM_Q\ + using cong idx_lo by auto + have cong_hi: \xs ! (j * (2 * blen) + m + blen) mod MLKEM_Q = + ys ! (j * (2 * blen) + m + blen) * k mod MLKEM_Q\ + using cong idx_hi by auto + have fqmul_cong: \fqmul_int (xs ! (j * (2 * blen) + m + blen)) + (zetas_int ! (nb + j)) mod MLKEM_Q = + fqmul_int (ys ! (j * (2 * blen) + m + blen)) + (zetas_int ! (nb + j)) * k mod MLKEM_Q\ + proof - + have \fqmul_int (xs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) mod MLKEM_Q = + fqmul_int (ys ! (j * (2 * blen) + m + blen) * k) (zetas_int ! (nb + j)) mod MLKEM_Q\ + by (rule fqmul_int_mod_cong[OF cong_hi]) + also have \\ = fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) * k mod MLKEM_Q\ + by (rule fqmul_int_linear_mod) + finally show ?thesis . + qed + have val_xs: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen xs ! i = + xs ! (j * (2 * blen) + m) - + fqmul_int (xs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))\ + unfolding im by (rule ntt_inner_loop_int_high_val[OF mb le_refl block_len_xs]) + have val_ys: \ntt_inner_loop_int (zetas_int ! (nb + j)) (j * (2 * blen)) blen blen ys ! i = + ys ! (j * (2 * blen) + m) - + fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))\ + unfolding im by (rule ntt_inner_loop_int_high_val[OF mb le_refl block_len_ys]) + have \ntt_layer_int l xs ! i mod MLKEM_Q = + (xs ! (j * (2 * blen) + m) - + fqmul_int (xs ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))) mod MLKEM_Q\ + using decomp_xs val_xs by simp + also have \\ = (ys ! (j * (2 * blen) + m) * k - + fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j)) * k) mod MLKEM_Q\ + by (rule mod_diff_cong[OF cong_lo fqmul_cong]) + also have \\ = (ys ! (j * (2 * blen) + m) - + fqmul_int (ys ! (j * (2 * blen) + m + blen)) (zetas_int ! (nb + j))) * k mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = ntt_layer_int l ys ! i * k mod MLKEM_Q\ + using decomp_ys val_ys by simp + finally show ?thesis . + qed +qed + +text \Analogous combined congruence + linearity for a single inverse NTT layer.\ + +lemma invntt_layer_int_linear_cong: + assumes cong: \\j < MLKEM_N. xs ! j mod MLKEM_Q = ys ! j * k mod MLKEM_Q\ + and len_xs: \length xs = MLKEM_N\ + and len_ys: \length ys = MLKEM_N\ + and l_ge: \1 \ l\ + and l_le: \l \ 7\ + and i_lt: \i < MLKEM_N\ + shows \invntt_layer_int l xs ! i mod MLKEM_Q = + invntt_layer_int l ys ! i * k mod MLKEM_Q\ +proof - + define nb where + \nb = (2::nat)^(l-1)\ + define blen where + \blen = (2::nat)^(8-l)\ + define j where + \j = i div (2 * blen)\ + define ki where + \ki = (2::nat)^l - 1 - j\ + have nb_blen: \nb * (2 * blen) = MLKEM_N\ + using ntt_nb_blen_eq[OF l_ge l_le] unfolding nb_def blen_def . + have i_lt2: \i < nb * (2 * blen)\ + using i_lt nb_blen by linarith + have j_lt_nb: \j < nb\ + unfolding j_def using i_lt2 by (simp add: less_mult_imp_div_less) + have sj_le_nb: \Suc j \ nb\ + using j_lt_nb by simp + have j_lo: \j * (2 * blen) \ i\ + unfolding j_def by simp + have j_hi_aux: \i < j * (2 * blen) + 2 * blen\ + proof - + have \(0::nat) < 2 * blen\ + by (simp add: blen_def) + have \i mod (2 * blen) < 2 * blen\ using \0 < 2 * blen\ + by simp + have \i div (2 * blen) * (2 * blen) + i mod (2 * blen) = i\ + by (rule div_mult_mod_eq) + thus ?thesis unfolding j_def using \i mod (2 * blen) < 2 * blen\ + by linarith + qed + hence j_hi: \i < Suc j * (2 * blen)\ + by simp + have block_len_xs: \j * (2 * blen) + 2 * blen \ length xs\ + proof - + have \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + thus ?thesis + using nb_blen len_xs by simp + qed + have block_len_ys: \j * (2 * blen) + 2 * blen \ length ys\ + proof - + have \Suc j * (2 * blen) \ nb * (2 * blen)\ + by (rule mult_le_mono1[OF sj_le_nb]) + thus ?thesis + using nb_blen len_ys by simp + qed + have two_nb: \2 * nb = 2^l\ + proof - + have \(2::nat) * 2^(l-1) = 2^Suc (l-1)\ + by simp + also have \Suc (l-1) = l\ + using l_ge by simp + finally show ?thesis + by (simp add: nb_def) + qed + have ki_eq: \ki = 2^l - 1 - j\ + unfolding ki_def by simp + have decomp_xs: \invntt_layer_int l xs ! i = + invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen xs ! i\ + proof - + have \snd (invntt_middle_loop_int (2^l - 1) blen nb nb xs) ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen xs ! i\ + by (rule invntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen len_xs) + thus ?thesis + unfolding invntt_layer_int_def nb_def blen_def ki_eq using two_nb nb_def by simp + qed + have decomp_ys: \invntt_layer_int l ys ! i = + invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen ys ! i\ + proof - + have \snd (invntt_middle_loop_int (2^l - 1) blen nb nb ys) ! i = + invntt_inner_loop_int (zetas_int ! (2^l - 1 - j)) (j * (2 * blen)) blen blen ys ! i\ + by (rule invntt_middle_loop_at_block[OF j_lt_nb _ j_lo j_hi]) + (simp add: nb_blen len_ys) + thus ?thesis + unfolding invntt_layer_int_def nb_def blen_def ki_eq using two_nb nb_def by simp + qed + show ?thesis + proof (cases \i < j * (2 * blen) + blen\) + case True + define m where + \m = i - j * (2 * blen)\ + have im: \i = j * (2 * blen) + m\ and mb: \m < blen\ + using j_lo True by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have cong_lo: \xs ! (j * (2 * blen) + m) mod MLKEM_Q = + ys ! (j * (2 * blen) + m) * k mod MLKEM_Q\ + using cong idx_lo by auto + have cong_hi: \xs ! (j * (2 * blen) + m + blen) mod MLKEM_Q = + ys ! (j * (2 * blen) + m + blen) * k mod MLKEM_Q\ + using cong idx_hi by auto + \ \Low position: barrett_reduce(xs[lo] + xs[hi])\ + have val_xs: \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen xs ! i = + barrett_reduce_int (xs ! (j * (2 * blen) + m) + xs ! (j * (2 * blen) + m + blen))\ + unfolding im by (rule invntt_inner_loop_int_low_val[OF mb le_refl block_len_xs]) + have val_ys: \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen ys ! i = + barrett_reduce_int (ys ! (j * (2 * blen) + m) + ys ! (j * (2 * blen) + m + blen))\ + unfolding im by (rule invntt_inner_loop_int_low_val[OF mb le_refl block_len_ys]) + have \invntt_layer_int l xs ! i mod MLKEM_Q = + barrett_reduce_int (xs ! (j * (2 * blen) + m) + xs ! (j * (2 * blen) + m + blen)) mod MLKEM_Q\ + using decomp_xs val_xs by simp + also have \\ = (xs ! (j * (2 * blen) + m) + xs ! (j * (2 * blen) + m + blen)) mod MLKEM_Q\ + by (rule barrett_reduce_mod) + also have \\ = (ys ! (j * (2 * blen) + m) * k + ys ! (j * (2 * blen) + m + blen) * k) mod MLKEM_Q\ + by (rule mod_add_cong[OF cong_lo cong_hi]) + also have \\ = (ys ! (j * (2 * blen) + m) + ys ! (j * (2 * blen) + m + blen)) * k mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = barrett_reduce_int (ys ! (j * (2 * blen) + m) + ys ! (j * (2 * blen) + m + blen)) * k mod MLKEM_Q\ + proof - + have \(ys ! (j * (2 * blen) + m) + ys ! (j * (2 * blen) + m + blen)) mod MLKEM_Q = + barrett_reduce_int (ys ! (j * (2 * blen) + m) + ys ! (j * (2 * blen) + m + blen)) mod MLKEM_Q\ + by (rule barrett_reduce_mod[symmetric]) + thus ?thesis + by (rule mod_mult_cong[OF _ refl]) + qed + also have \\ = invntt_layer_int l ys ! i * k mod MLKEM_Q\ + using decomp_ys val_ys by simp + finally show ?thesis . + next + case False + hence i_ge: \j * (2 * blen) + blen \ i\ + by simp + define m where + \m = i - j * (2 * blen) - blen\ + have im: \i = j * (2 * blen) + m + blen\ and mb: \m < blen\ + using j_lo i_ge j_hi_aux by (auto simp: m_def) + have idx_lo: \j * (2 * blen) + m < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have idx_hi: \j * (2 * blen) + m + blen < MLKEM_N\ + using block_len_xs len_xs mb by linarith + have cong_lo: \xs ! (j * (2 * blen) + m) mod MLKEM_Q = + ys ! (j * (2 * blen) + m) * k mod MLKEM_Q\ + using cong idx_lo by auto + have cong_hi: \xs ! (j * (2 * blen) + m + blen) mod MLKEM_Q = + ys ! (j * (2 * blen) + m + blen) * k mod MLKEM_Q\ + using cong idx_hi by auto + \ \High position: fqmul(xs[hi] - xs[lo], zeta)\ + have fqmul_cong: \fqmul_int (xs ! (j * (2 * blen) + m + blen) - xs ! (j * (2 * blen) + m)) + (zetas_int ! ki) mod MLKEM_Q = + fqmul_int (ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) + (zetas_int ! ki) * k mod MLKEM_Q\ + proof - + have diff_cong: \(xs ! (j * (2 * blen) + m + blen) - xs ! (j * (2 * blen) + m)) mod MLKEM_Q = + (ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) * k mod MLKEM_Q\ + proof - + have \(xs ! (j * (2 * blen) + m + blen) - xs ! (j * (2 * blen) + m)) mod MLKEM_Q = + (ys ! (j * (2 * blen) + m + blen) * k - ys ! (j * (2 * blen) + m) * k) mod MLKEM_Q\ + by (rule mod_diff_cong[OF cong_hi cong_lo]) + also have \\ = (ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) * k mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + have \fqmul_int (xs ! (j * (2 * blen) + m + blen) - xs ! (j * (2 * blen) + m)) + (zetas_int ! ki) mod MLKEM_Q = + fqmul_int ((ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) * k) + (zetas_int ! ki) mod MLKEM_Q\ + by (rule fqmul_int_mod_cong[OF diff_cong]) + also have \\ = fqmul_int (ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) + (zetas_int ! ki) * k mod MLKEM_Q\ + by (rule fqmul_int_linear_mod) + finally show ?thesis . + qed + have val_xs: \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen xs ! i = + fqmul_int (xs ! (j * (2 * blen) + m + blen) - xs ! (j * (2 * blen) + m)) (zetas_int ! ki)\ + unfolding im by (rule invntt_inner_loop_int_high_val[OF mb le_refl block_len_xs]) + have val_ys: \invntt_inner_loop_int (zetas_int ! ki) (j * (2 * blen)) blen blen ys ! i = + fqmul_int (ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) (zetas_int ! ki)\ + unfolding im by (rule invntt_inner_loop_int_high_val[OF mb le_refl block_len_ys]) + have \invntt_layer_int l xs ! i mod MLKEM_Q = + fqmul_int (xs ! (j * (2 * blen) + m + blen) - xs ! (j * (2 * blen) + m)) (zetas_int ! ki) mod MLKEM_Q\ + using decomp_xs val_xs by simp + also have \\ = fqmul_int (ys ! (j * (2 * blen) + m + blen) - ys ! (j * (2 * blen) + m)) (zetas_int ! ki) * k mod MLKEM_Q\ + by (rule fqmul_cong) + also have \\ = invntt_layer_int l ys ! i * k mod MLKEM_Q\ + using decomp_ys val_ys by simp + finally show ?thesis . + qed +qed + +text \Prescaling identity: 7 layer pairs contribute @{text "2^7 = 128"}, + and @{text "fqmul(x, 1441)"} contributes @{text "1441 / 2^{16} \ 512 (mod Q)"}, + giving @{text "128 \ 512 = 2^{16}"}.\ + +lemma prescale_factor: + shows \(128::int) * fqmul_int x 1441 mod MLKEM_Q = x * 2^16 mod MLKEM_Q\ +proof - + have s1: \fqmul_int x 1441 * 2^16 mod MLKEM_Q = x * 1441 mod MLKEM_Q\ + by (rule fqmul_int_cong) + have s2: \128 * (fqmul_int x 1441 * 2^16) mod MLKEM_Q = 128 * (x * 1441) mod MLKEM_Q\ + by (rule mod_mult_cong[OF refl s1]) + have s3: \(128 * 1441 :: int) mod MLKEM_Q = 2^32 mod MLKEM_Q\ + by eval + have s4: \x * (128 * 1441) mod MLKEM_Q = x * 2^32 mod MLKEM_Q\ + by (rule mod_mult_cong[OF refl s3]) + have pow: \(2::int)^32 = 2^16 * 2^16\ + by eval + have \128 * fqmul_int x 1441 * 2^16 mod MLKEM_Q = x * 2^16 * 2^16 mod MLKEM_Q\ + using s2 s4 pow by (simp add: algebra_simps) + thus ?thesis + by (rule mult_mod_cancel_right) eval +qed + +lemma invntt_outer_loop_step: + shows \invntt_outer_loop_int (Suc n) cs = invntt_outer_loop_int n (invntt_layer_int (Suc n) cs)\ +by (simp add: case_prod_beta Let_def invntt_layer_int_def) + +text \The NTT outer loop preserves mod-Q linear congruence: if every coefficient + of @{term xs} agrees mod Q with the corresponding coefficient of @{term ys} + scaled by @{term k}, then the same holds after applying any suffix of NTT layers.\ + +lemma ntt_outer_loop_linear_cong: + assumes cong: \\j < MLKEM_N. xs ! j mod MLKEM_Q = ys ! j * k mod MLKEM_Q\ + and len_xs: \length xs = MLKEM_N\ + and len_ys: \length ys = MLKEM_N\ + and i_lt: \i < MLKEM_N\ + and n_ge: \1 \ n\ + and n_le: \n \ 7\ + shows \ntt_outer_loop_int (2^(7-n)) n xs ! i mod MLKEM_Q = + ntt_outer_loop_int (2^(7-n)) n ys ! i * k mod MLKEM_Q\ +using assms proof (induction n arbitrary: xs ys) + case 0 + then show ?case by simp +next + case (Suc m) + define l where + \l = (8 - Suc m :: nat)\ + have l_ge: \1 \ l\ and l_le: \l \ 7\ + using Suc.prems unfolding l_def by auto + have k_eq: \2^(7 - Suc m) = (2::nat)^(l - 1)\ + unfolding l_def using Suc.prems by auto + have k_eq2: \2^(7 - m) = (2::nat)^l\ + unfolding l_def using Suc.prems by auto + have m_eq: \Suc m = 8 - l\ + unfolding l_def using Suc.prems by auto + have m_eq2: \m = 7 - l\ + unfolding l_def using Suc.prems by auto + \ \Peel layer l from NTT outer loop\ + have xs_step: \ntt_outer_loop_int (2^(l-1)) (8-l) xs = + ntt_outer_loop_int (2^l) (7-l) (ntt_layer_int l xs)\ + by (rule ntt_outer_loop_step_layer[OF l_ge l_le]) + have ys_step: \ntt_outer_loop_int (2^(l-1)) (8-l) ys = + ntt_outer_loop_int (2^l) (7-l) (ntt_layer_int l ys)\ + by (rule ntt_outer_loop_step_layer[OF l_ge l_le]) + \ \Layer l preserves the congruence\ + have layer_cong: \\j < MLKEM_N. ntt_layer_int l xs ! j mod MLKEM_Q = + ntt_layer_int l ys ! j * k mod MLKEM_Q\ + using ntt_layer_int_linear_cong[OF Suc.prems(1) Suc.prems(2) Suc.prems(3) l_ge l_le] by auto + have len_lx: \length (ntt_layer_int l xs) = MLKEM_N\ + using Suc.prems(2) ntt_layer_int_length by auto + have len_ly: \length (ntt_layer_int l ys) = MLKEM_N\ + using Suc.prems(3) ntt_layer_int_length by auto + show ?case + proof (cases \m = 0\) + case True + \ \Base: Suc m = 1, layer 7, loop terminates after one step\ + have \l = 7\ + using True l_def Suc.prems by auto + hence \ntt_outer_loop_int (2^(7 - Suc m)) (Suc m) xs = + ntt_layer_int 7 xs\ + using xs_step k_eq m_eq by simp + moreover have \ntt_outer_loop_int (2^(7 - Suc m)) (Suc m) ys = + ntt_layer_int 7 ys\ + using ys_step k_eq m_eq \l = 7\ by simp + moreover have \ntt_layer_int 7 xs ! i mod MLKEM_Q = + ntt_layer_int 7 ys ! i * k mod MLKEM_Q\ + using layer_cong \l = 7\ Suc.prems(4) by auto + ultimately show ?thesis + by simp + next + case False + hence m_ge: \1 \ m\ + by simp + have m_le: \m \ 7\ + using Suc.prems by simp + \ \Apply IH to the remaining layers\ + have IH: \ntt_outer_loop_int (2^(7-m)) m (ntt_layer_int l xs) ! i mod MLKEM_Q = + ntt_outer_loop_int (2^(7-m)) m (ntt_layer_int l ys) ! i * k mod MLKEM_Q\ + by (rule Suc.IH[OF layer_cong len_lx len_ly Suc.prems(4) m_ge m_le]) + show ?thesis + using xs_step ys_step IH k_eq k_eq2 m_eq m_eq2 by simp + qed +qed + +text \Analogous linear congruence for the inverse NTT outer loop.\ + +lemma invntt_outer_loop_linear_cong: + assumes cong: \\j < MLKEM_N. xs ! j mod MLKEM_Q = ys ! j * k mod MLKEM_Q\ + and len_xs: \length xs = MLKEM_N\ + and len_ys: \length ys = MLKEM_N\ + and i_lt: \i < MLKEM_N\ + and n_le: \n \ 7\ + shows \invntt_outer_loop_int n xs ! i mod MLKEM_Q = + invntt_outer_loop_int n ys ! i * k mod MLKEM_Q\ +using assms proof (induction n arbitrary: xs ys) + case 0 + then show ?case by simp +next + case (Suc m) + define l where + \l = Suc m\ + have l_ge: \1 \ l\ and l_le: \l \ 7\ + using Suc.prems unfolding l_def by auto + have layer_cong: \\j < MLKEM_N. invntt_layer_int l xs ! j mod MLKEM_Q = + invntt_layer_int l ys ! j * k mod MLKEM_Q\ + using invntt_layer_int_linear_cong[OF Suc.prems(1) Suc.prems(2) Suc.prems(3) l_ge l_le] by auto + have len_lx: \length (invntt_layer_int l xs) = MLKEM_N\ + using Suc.prems(2) invntt_layer_int_length by auto + have len_ly: \length (invntt_layer_int l ys) = MLKEM_N\ + using Suc.prems(3) invntt_layer_int_length by auto + have m_le: \m \ 7\ + using Suc.prems by simp + have \invntt_outer_loop_int (Suc m) xs = invntt_outer_loop_int m (invntt_layer_int l xs)\ + unfolding l_def by (rule invntt_outer_loop_step) + moreover have \invntt_outer_loop_int (Suc m) ys = invntt_outer_loop_int m (invntt_layer_int l ys)\ + unfolding l_def by (rule invntt_outer_loop_step) + moreover have \invntt_outer_loop_int m (invntt_layer_int l xs) ! i mod MLKEM_Q = + invntt_outer_loop_int m (invntt_layer_int l ys) ! i * k mod MLKEM_Q\ + by (rule Suc.IH[OF layer_cong len_lx len_ly Suc.prems(4) m_le]) + ultimately show ?case + by simp +qed + +text \Full NTT/invNTT composition: 7 layer pairs each contribute factor 2.\ + +lemma ntt_invntt_outer_compose: + assumes len: \length cs = MLKEM_N\ + and i_lt: \i < MLKEM_N\ + shows \ntt_outer_loop_int 1 7 (invntt_outer_loop_int 7 cs) ! i mod MLKEM_Q = + (2::int)^7 * cs ! i mod MLKEM_Q\ +proof - + \ \Unfold invNTT layers 7 down to 1\ + define I7 where \I7 = invntt_layer_int 7 cs\ + define I6 where \I6 = invntt_layer_int 6 I7\ + define I5 where \I5 = invntt_layer_int 5 I6\ + define I4 where \I4 = invntt_layer_int 4 I5\ + define I3 where \I3 = invntt_layer_int 3 I4\ + define I2 where \I2 = invntt_layer_int 2 I3\ + define I1 where \I1 = invntt_layer_int 1 I2\ + have invntt_eq: \invntt_outer_loop_int 7 cs = I1\ + unfolding I1_def I2_def I3_def I4_def I5_def I6_def I7_def + by (simp only: numeral_eq_Suc pred_numeral_simps Num.BitM.simps + One_nat_def invntt_outer_loop_step invntt_outer_loop_int.simps(1)) + \ \Length lemmas\ + have [simp]: \length I7 = MLKEM_N\ \length I6 = MLKEM_N\ \length I5 = MLKEM_N\ + \length I4 = MLKEM_N\ \length I3 = MLKEM_N\ \length I2 = MLKEM_N\ \length I1 = MLKEM_N\ + unfolding I1_def I2_def I3_def I4_def I5_def I6_def I7_def + using len by (simp_all add: invntt_layer_int_length) + \ \Build accumulated result from innermost to outermost layer\ + \ \Step 7 (base): NTT layer 7 cancels directly\ + have h7: \ntt_outer_loop_int 64 1 I7 ! i mod MLKEM_Q = 2 * cs ! i mod MLKEM_Q\ + proof - + have \ntt_outer_loop_int 64 1 I7 = ntt_outer_loop_int 128 0 (ntt_layer_int 7 I7)\ + using ntt_outer_loop_step_layer[where l=7] by simp + moreover have \ntt_outer_loop_int 128 0 (ntt_layer_int 7 I7) = ntt_layer_int 7 I7\ by simp + moreover have \\j + unfolding I7_def using ntt_invntt_layer_inverse[of cs 7] len by auto + ultimately show ?thesis using i_lt by auto + qed + \ \Step 6: accumulate 2^2\ + have h6: \ntt_outer_loop_int 32 2 I6 ! i mod MLKEM_Q = 2^2 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding I6_def using ntt_invntt_layer_inverse[of I7 6] by (auto simp: mult.commute[of 2]) + have \ntt_outer_loop_int 32 2 I6 ! i mod MLKEM_Q = + ntt_outer_loop_int 64 1 (ntt_layer_int 6 I6) ! i mod MLKEM_Q\ + using ntt_outer_loop_step_layer[where l=6] by simp + also have \\ = ntt_outer_loop_int 64 1 I7 ! i * 2 mod MLKEM_Q\ + using ntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=1] + by (simp add: ntt_layer_int_length) + also have \\ = 2 * cs ! i * 2 mod MLKEM_Q\ by (rule mod_mult_cong[OF h7 refl]) + also have \\ = 2^2 * cs ! i mod MLKEM_Q\ by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 5: accumulate 2^3\ + have h5: \ntt_outer_loop_int 16 3 I5 ! i mod MLKEM_Q = 2^3 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding I5_def using ntt_invntt_layer_inverse[of I6 5] by (auto simp: mult.commute[of 2]) + have \ntt_outer_loop_int 16 3 I5 ! i mod MLKEM_Q = + ntt_outer_loop_int 32 2 (ntt_layer_int 5 I5) ! i mod MLKEM_Q\ + using ntt_outer_loop_step_layer[where l=5] by simp + also have \\ = ntt_outer_loop_int 32 2 I6 ! i * 2 mod MLKEM_Q\ + using ntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=2] + by (simp add: ntt_layer_int_length) + also have \\ = 2^2 * cs ! i * 2 mod MLKEM_Q\ by (rule mod_mult_cong[OF h6 refl]) + also have \\ = 2^3 * cs ! i mod MLKEM_Q\ by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 4: accumulate 2^4\ + have h4: \ntt_outer_loop_int 8 4 I4 ! i mod MLKEM_Q = 2^4 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding I4_def using ntt_invntt_layer_inverse[of I5 4] by (auto simp: mult.commute[of 2]) + have \ntt_outer_loop_int 8 4 I4 ! i mod MLKEM_Q = + ntt_outer_loop_int 16 3 (ntt_layer_int 4 I4) ! i mod MLKEM_Q\ + using ntt_outer_loop_step_layer[where l=4] by simp + also have \\ = ntt_outer_loop_int 16 3 I5 ! i * 2 mod MLKEM_Q\ + using ntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=3] + by (simp add: ntt_layer_int_length) + also have \\ = 2^3 * cs ! i * 2 mod MLKEM_Q\ by (rule mod_mult_cong[OF h5 refl]) + also have \\ = 2^4 * cs ! i mod MLKEM_Q\ by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 3: accumulate 2^5\ + have h3: \ntt_outer_loop_int 4 5 I3 ! i mod MLKEM_Q = 2^5 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding I3_def using ntt_invntt_layer_inverse[of I4 3] by (auto simp: mult.commute[of 2]) + have \ntt_outer_loop_int 4 5 I3 ! i mod MLKEM_Q = + ntt_outer_loop_int 8 4 (ntt_layer_int 3 I3) ! i mod MLKEM_Q\ + using ntt_outer_loop_step_layer[where l=3] by simp + also have \\ = ntt_outer_loop_int 8 4 I4 ! i * 2 mod MLKEM_Q\ + using ntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=4] + by (simp add: ntt_layer_int_length) + also have \\ = 2^4 * cs ! i * 2 mod MLKEM_Q\ by (rule mod_mult_cong[OF h4 refl]) + also have \\ = 2^5 * cs ! i mod MLKEM_Q\ by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 2: accumulate 2^6\ + have h2: \ntt_outer_loop_int 2 6 I2 ! i mod MLKEM_Q = 2^6 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding I2_def using ntt_invntt_layer_inverse[of I3 2] by (auto simp: mult.commute[of 2]) + have \ntt_outer_loop_int 2 6 I2 ! i mod MLKEM_Q = + ntt_outer_loop_int 4 5 (ntt_layer_int 2 I2) ! i mod MLKEM_Q\ + using ntt_outer_loop_step_layer[where l=2] by simp + also have \\ = ntt_outer_loop_int 4 5 I3 ! i * 2 mod MLKEM_Q\ + using ntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=5] + by (simp add: ntt_layer_int_length) + also have \\ = 2^5 * cs ! i * 2 mod MLKEM_Q\ by (rule mod_mult_cong[OF h3 refl]) + also have \\ = 2^6 * cs ! i mod MLKEM_Q\ by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 1: accumulate 2^7\ + have h1: \ntt_outer_loop_int 1 7 I1 ! i mod MLKEM_Q = 2^7 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding I1_def using ntt_invntt_layer_inverse[of I2 1] by (auto simp: mult.commute[of 2]) + have \ntt_outer_loop_int 1 7 I1 ! i mod MLKEM_Q = + ntt_outer_loop_int 2 6 (ntt_layer_int 1 I1) ! i mod MLKEM_Q\ + using ntt_outer_loop_step_layer[where l=1] by simp + also have \\ = ntt_outer_loop_int 2 6 I2 ! i * 2 mod MLKEM_Q\ + using ntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=6] + by (simp add: ntt_layer_int_length) + also have \\ = 2^6 * cs ! i * 2 mod MLKEM_Q\ by (rule mod_mult_cong[OF h2 refl]) + also have \\ = 2^7 * cs ! i mod MLKEM_Q\ by (simp add: algebra_simps) + finally show ?thesis . + qed + show ?thesis using h1 invntt_eq by simp +qed + +text \Reverse composition: invNTT(NTT(cs)) gives the same @{text "2^7"} factor.\ + +lemma invntt_ntt_outer_compose: + assumes len: \length cs = MLKEM_N\ + and i_lt: \i < MLKEM_N\ + shows \invntt_outer_loop_int 7 (ntt_outer_loop_int 1 7 cs) ! i mod MLKEM_Q = + (2::int)^7 * cs ! i mod MLKEM_Q\ +proof - + \ \Unfold NTT layers 1 through 7\ + define N1 where \N1 = ntt_layer_int 1 cs\ + define N2 where \N2 = ntt_layer_int 2 N1\ + define N3 where \N3 = ntt_layer_int 3 N2\ + define N4 where \N4 = ntt_layer_int 4 N3\ + define N5 where \N5 = ntt_layer_int 5 N4\ + define N6 where \N6 = ntt_layer_int 6 N5\ + define N7 where \N7 = ntt_layer_int 7 N6\ + have ntt_eq: \ntt_outer_loop_int 1 7 cs = N7\ + proof - + have \ntt_outer_loop_int 1 7 cs = ntt_outer_loop_int 2 6 N1\ + unfolding N1_def using ntt_outer_loop_step_layer[where l=1] by simp + also have \\ = ntt_outer_loop_int 4 5 N2\ + unfolding N2_def using ntt_outer_loop_step_layer[where l=2] by simp + also have \\ = ntt_outer_loop_int 8 4 N3\ + unfolding N3_def using ntt_outer_loop_step_layer[where l=3] by simp + also have \\ = ntt_outer_loop_int 16 3 N4\ + unfolding N4_def using ntt_outer_loop_step_layer[where l=4] by simp + also have \\ = ntt_outer_loop_int 32 2 N5\ + unfolding N5_def using ntt_outer_loop_step_layer[where l=5] by simp + also have \\ = ntt_outer_loop_int 64 1 N6\ + unfolding N6_def using ntt_outer_loop_step_layer[where l=6] by simp + also have \\ = N7\ + unfolding N7_def using ntt_outer_loop_step_layer[where l=7] by simp + finally show ?thesis . + qed + \ \Length lemmas\ + have [simp]: \length N1 = MLKEM_N\ \length N2 = MLKEM_N\ \length N3 = MLKEM_N\ + \length N4 = MLKEM_N\ \length N5 = MLKEM_N\ \length N6 = MLKEM_N\ \length N7 = MLKEM_N\ + unfolding N1_def N2_def N3_def N4_def N5_def N6_def N7_def using len + by (simp_all add: ntt_layer_int_length) + \ \Base case: invNTT layer 1 cancels NTT layer 1\ + have h1: \invntt_outer_loop_int 1 N1 ! i mod MLKEM_Q = 2 * cs ! i mod MLKEM_Q\ + proof - + have \invntt_outer_loop_int 1 N1 = invntt_layer_int 1 N1\ + by (simp only: One_nat_def invntt_outer_loop_step invntt_outer_loop_int.simps(1)) + moreover have \\j + unfolding N1_def using invntt_ntt_layer_inverse[of cs 1] len by auto + ultimately show ?thesis + using i_lt by auto + qed + \ \Step 2: accumulate 2^2\ + have h2: \invntt_outer_loop_int 2 N2 ! i mod MLKEM_Q = 2^2 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding N2_def using invntt_ntt_layer_inverse[of N1 2] by (auto simp: mult.commute[of 2]) + have \invntt_outer_loop_int 2 N2 ! i mod MLKEM_Q = + invntt_outer_loop_int 1 (invntt_layer_int 2 N2) ! i mod MLKEM_Q\ + using invntt_outer_loop_step[where n=1] by (simp add: eval_nat_numeral del: invntt_outer_loop_int.simps(2)) + also have \\ = invntt_outer_loop_int 1 N1 ! i * 2 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=1] by (simp add: invntt_layer_int_length) + also have \\ = 2 * cs ! i * 2 mod MLKEM_Q\ + by (rule mod_mult_cong[OF h1 refl]) + also have \\ = 2^2 * cs ! i mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 3: accumulate 2^3\ + have h3: \invntt_outer_loop_int 3 N3 ! i mod MLKEM_Q = 2^3 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding N3_def using invntt_ntt_layer_inverse[of N2 3] by (auto simp: mult.commute[of 2]) + have \invntt_outer_loop_int 3 N3 ! i mod MLKEM_Q = + invntt_outer_loop_int 2 (invntt_layer_int 3 N3) ! i mod MLKEM_Q\ + using invntt_outer_loop_step[where n=2] by (simp add: eval_nat_numeral del: invntt_outer_loop_int.simps(2)) + also have \\ = invntt_outer_loop_int 2 N2 ! i * 2 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=2] by (simp add: invntt_layer_int_length) + also have \\ = 2^2 * cs ! i * 2 mod MLKEM_Q\ + by (rule mod_mult_cong[OF h2 refl]) + also have \\ = 2^3 * cs ! i mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 4: accumulate 2^4\ + have h4: \invntt_outer_loop_int 4 N4 ! i mod MLKEM_Q = 2^4 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding N4_def using invntt_ntt_layer_inverse[of N3 4] by (auto simp: mult.commute[of 2]) + have \invntt_outer_loop_int 4 N4 ! i mod MLKEM_Q = + invntt_outer_loop_int 3 (invntt_layer_int 4 N4) ! i mod MLKEM_Q\ + using invntt_outer_loop_step[where n=3] by (simp add: eval_nat_numeral del: invntt_outer_loop_int.simps(2)) + also have \\ = invntt_outer_loop_int 3 N3 ! i * 2 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=3] by (simp add: invntt_layer_int_length) + also have \\ = 2^3 * cs ! i * 2 mod MLKEM_Q\ + by (rule mod_mult_cong[OF h3 refl]) + also have \\ = 2^4 * cs ! i mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 5: accumulate 2^5\ + have h5: \invntt_outer_loop_int 5 N5 ! i mod MLKEM_Q = 2^5 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding N5_def using invntt_ntt_layer_inverse[of N4 5] by (auto simp: mult.commute[of 2]) + have \invntt_outer_loop_int 5 N5 ! i mod MLKEM_Q = + invntt_outer_loop_int 4 (invntt_layer_int 5 N5) ! i mod MLKEM_Q\ + using invntt_outer_loop_step[where n=4] by (simp add: eval_nat_numeral del: invntt_outer_loop_int.simps(2)) + also have \\ = invntt_outer_loop_int 4 N4 ! i * 2 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=4] by (simp add: invntt_layer_int_length) + also have \\ = 2^4 * cs ! i * 2 mod MLKEM_Q\ + by (rule mod_mult_cong[OF h4 refl]) + also have \\ = 2^5 * cs ! i mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 6: accumulate 2^6\ + have h6: \invntt_outer_loop_int 6 N6 ! i mod MLKEM_Q = 2^6 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding N6_def using invntt_ntt_layer_inverse[of N5 6] by (auto simp: mult.commute[of 2]) + have \invntt_outer_loop_int 6 N6 ! i mod MLKEM_Q = + invntt_outer_loop_int 5 (invntt_layer_int 6 N6) ! i mod MLKEM_Q\ + using invntt_outer_loop_step[where n=5] by (simp add: eval_nat_numeral del: invntt_outer_loop_int.simps(2)) + also have \\ = invntt_outer_loop_int 5 N5 ! i * 2 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=5] by (simp add: invntt_layer_int_length) + also have \\ = 2^5 * cs ! i * 2 mod MLKEM_Q\ + by (rule mod_mult_cong[OF h5 refl]) + also have \\ = 2^6 * cs ! i mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + \ \Step 7: accumulate 2^7\ + have h7: \invntt_outer_loop_int 7 N7 ! i mod MLKEM_Q = 2^7 * cs ! i mod MLKEM_Q\ + proof - + have cancel: \\j + unfolding N7_def using invntt_ntt_layer_inverse[of N6 7] by (auto simp: mult.commute[of 2]) + have \invntt_outer_loop_int 7 N7 ! i mod MLKEM_Q = + invntt_outer_loop_int 6 (invntt_layer_int 7 N7) ! i mod MLKEM_Q\ + using invntt_outer_loop_step[where n=6] by (simp add: eval_nat_numeral del: invntt_outer_loop_int.simps(2)) + also have \\ = invntt_outer_loop_int 6 N6 ! i * 2 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cancel _ _ i_lt, where n=6] by (simp add: invntt_layer_int_length) + also have \\ = 2^6 * cs ! i * 2 mod MLKEM_Q\ + by (rule mod_mult_cong[OF h6 refl]) + also have \\ = 2^7 * cs ! i mod MLKEM_Q\ + by (simp add: algebra_simps) + finally show ?thesis . + qed + show ?thesis + using h7 ntt_eq by simp +qed + +text \Direction 2 (easier): @{text "NTT(invNTT(cs)) \ cs \ R (mod Q)"}. + The 7 layer pairs each contribute a factor of 2, giving @{text "2^7 = 128"}. + The prescaling by 1441 contributes @{text "1441 \ R^{-1} \ 2^9 (mod Q)"}. + Together: @{text "128 \ 2^9 = 2^16 = R"}.\ + +theorem poly_ntt_invntt_tomont: + assumes \length cs = 256\ + and \i < 256\ + shows \poly_ntt_int (poly_invntt_tomont_int cs) ! i mod MLKEM_Q = + cs ! i * 2^16 mod MLKEM_Q\ +unfolding poly_ntt_int_def poly_invntt_tomont_int_def + using ntt_invntt_outer_compose[of \List.map (\c. fqmul_int c 1441) cs\ i] + assms prescale_factor[of \cs ! i\] by simp + +text \Direction 1: @{text "invNTT(NTT(cs)) \ cs \ R (mod Q)"}. + Uses linearity to move the prescaling past the NTT layers.\ + +theorem poly_invntt_tomont_ntt: + assumes \length cs = 256\ + and \i < 256\ + shows \poly_invntt_tomont_int (poly_ntt_int cs) ! i mod MLKEM_Q = + cs ! i * 2^16 mod MLKEM_Q\ +proof - + define ntt_cs where + \ntt_cs = ntt_outer_loop_int 1 7 cs\ + have len_ntt: \length ntt_cs = MLKEM_N\ + unfolding ntt_cs_def using assms by (simp add: ntt_outer_loop_int_length) + have cong: \\j < MLKEM_N. (List.map (\c. fqmul_int c 1441) ntt_cs) ! j mod MLKEM_Q = + ntt_cs ! j * fqmul_int 1 1441 mod MLKEM_Q\ + proof (intro allI impI) + fix j + assume \j < MLKEM_N\ + then show \(List.map (\c. fqmul_int c 1441) ntt_cs) ! j mod MLKEM_Q = + ntt_cs ! j * fqmul_int 1 1441 mod MLKEM_Q\ + using len_ntt fqmul_int_linear_mod[of 1 \ntt_cs ! j\ 1441] by (simp add: mult.commute) + qed + have step1: \invntt_outer_loop_int 7 (List.map (\c. fqmul_int c 1441) ntt_cs) ! i mod MLKEM_Q = + invntt_outer_loop_int 7 ntt_cs ! i * fqmul_int 1 1441 mod MLKEM_Q\ + using invntt_outer_loop_linear_cong[OF cong _ len_ntt assms(2), where n=7] by (simp add: len_ntt) + also have \\ = 2^7 * cs ! i * fqmul_int 1 1441 mod MLKEM_Q\ + by (rule mod_mult_cong[OF invntt_ntt_outer_compose[OF assms, folded ntt_cs_def] refl]) + also have \\ = cs ! i * (128 * fqmul_int 1 1441) mod MLKEM_Q\ + by (simp add: algebra_simps) + also have \\ = cs ! i * 2^16 mod MLKEM_Q\ + proof - + have eq: \128 * fqmul_int 1 1441 mod MLKEM_Q = 2^16 mod MLKEM_Q\ + using prescale_factor[of 1] by simp + show ?thesis + using mod_mult_right_eq[of \cs ! i\ \128 * fqmul_int 1 1441\ MLKEM_Q] + mod_mult_right_eq[of \cs ! i\ \2^16\ MLKEM_Q] eq by simp + qed + finally show ?thesis + unfolding poly_invntt_tomont_int_def poly_ntt_int_def ntt_cs_def . +qed + +(*<*) +end +(*>*) + From 4236224c61065591834451d63d946bafba5d4269 Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:00:19 +0000 Subject: [PATCH 06/11] proofs/isabelle: add C functional correctness proofs for scalar helpers Add MLKEM_FC_Scalar.thy with contracts for barrett_reduce, poly_add, poly_sub, value_barrier, cast helpers, ct_sel, and ct_cmask. --- proofs/isabelle/MLKEM_FC_Scalar.thy | 676 ++++++++++++++++++++++++++++ 1 file changed, 676 insertions(+) create mode 100644 proofs/isabelle/MLKEM_FC_Scalar.thy diff --git a/proofs/isabelle/MLKEM_FC_Scalar.thy b/proofs/isabelle/MLKEM_FC_Scalar.thy new file mode 100644 index 0000000000..152fdb6bfe --- /dev/null +++ b/proofs/isabelle/MLKEM_FC_Scalar.thy @@ -0,0 +1,676 @@ +(*<*) +theory MLKEM_FC_Scalar + imports MLKEM_Refinement +begin +(*>*) + +text \ + Functional correctness proofs for scalar C helper functions from + @{verbatim \poly.c\}: Barrett reduction, cast helpers, constant-time + primitives, and polynomial addition/subtraction. Each function is + verified against its abstract specification from @{text MLKEM_Spec} + and @{text MLKEM_Refinement}. +\ + +section \Scalar C Verification\ + +(*<*) +context c_mlk_machine_model +begin + +declare c_mlk_cast_uint16_to_int16_def [micro_rust_simps del] +declare c_mlk_cast_int16_to_uint16_def [micro_rust_simps del] +declare c_mlk_ct_cmask_neg_i16_def [micro_rust_simps del] +declare c_mlk_ct_cmask_nonzero_u16_def [micro_rust_simps del] +declare nondet_choice_def [micro_rust_simps del] +declare bind2_unseq_def [micro_rust_simps del] +declare c_mlk_ct_sel_int16_def [micro_rust_simps del] +declare c_mlk_barrett_reduce_def [micro_rust_simps del] +declare c_mlk_value_barrier_i32_def [micro_rust_simps del] +declare c_mlk_cast_int32_to_uint16_def [micro_rust_simps del] +(*>*) + +subsection \@{text mlk_barrett_reduce} contract\ + +text \The C implementation of Barrett reduction is verified against + @{const barrett_reduce_int} from @{text MLKEM_Spec}. The contract + guarantees that the result equals the abstract specification at the + @{const sint} level.\ + +definition c_mlk_barrett_reduce_contract :: \c_short \ ('s::{sepalg}, c_short, 'b) function_contract\ where + \c_mlk_barrett_reduce_contract a \ + let pre = can_alloc_reference; + post = \r. can_alloc_reference \ + \sint r = barrett_reduce_int (sint a)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_barrett_reduce_contract(*>*) + +lemma barrett_sint_bounds: + fixes a :: \16 sword\ + defines \sa \ sint a\ + shows \sa \ -32768\ \sa < 32768\ + and \20159 * sa \ -(2^31)\ + and \20159 * sa < 2^31\ + and \20159 * sa + 2^25 \ -(2^31)\ + and \20159 * sa + 2^25 < 2^31\ +proof - + have \sa \ -32768\ and \sa < 32768\ + using sint_range_size[where w=a] by (auto simp: sa_def word_size) + then show \sa \ -32768\ and \sa < 32768\ and + \20159 * sa \ -(2^31)\ and \20159 * sa < 2^31\ and + \20159 * sa + 2^25 \ -(2^31)\ and \20159 * sa + 2^25 < 2^31\ + by auto +qed + +lemma barrett_stb31_prod: + shows \signed_take_bit 31 (20159 * sint (a :: 16 sword)) = 20159 * sint a\ +using barrett_sint_bounds[of a] by (intro signed_take_bit_int_eq_self) auto + +lemma barrett_stb31_sum: + shows \signed_take_bit 31 (20159 * sint (a :: 16 sword) + 33554432) = + 20159 * sint a + 33554432\ +using barrett_sint_bounds[of a] by (intro signed_take_bit_int_eq_self) auto + +lemma barrett_quotient_bounds: + fixes a :: \16 sword\ + defines \t \ (20159 * sint a + 33554432) div (67108864 :: int)\ + shows \t \ -10\ + and \t \ 10\ +proof - + have sa: \sint a \ -32768\ \sint a < 32768\ + using barrett_sint_bounds[of a] by auto + have m1: \(20159::int) * (-32768) \ 20159 * sint a\ + using sa by (intro mult_left_mono) auto + have m2: \(20159::int) * sint a \ 20159 * 32767\ + using sa by (intro mult_left_mono) auto + have lower: \(-627015680::int) \ 20159 * sint a + 33554432\ + using m1 by simp + have upper: \20159 * sint a + 33554432 \ (694104385::int)\ + using m2 by simp + have \(-627015680::int) div 67108864 \ (20159 * sint a + 33554432) div 67108864\ + using lower by (rule zdiv_mono1) auto + thus \t \ -10\ + unfolding t_def by simp + have \(20159 * sint a + 33554432) div 67108864 \ (694104385::int) div 67108864\ + using upper by (rule zdiv_mono1) auto + thus \t \ 10\ + unfolding t_def by simp +qed + +lemma barrett_sint_woi: + shows \sint ((word_of_int ((20159 * sint (a::16 sword) + 33554432) div 67108864)) :: 32 sword) = + (20159 * sint a + 33554432) div 67108864\ +proof - + have t: \(20159 * sint a + 33554432) div 67108864 \ -10\ + \(20159 * sint a + 33554432) div 67108864 \ 10\ + using barrett_quotient_bounds[of a] by auto + show ?thesis + by (rule sint_of_int_eq) (use t in auto) +qed + +lemma barrett_stb31_tq: + shows \signed_take_bit 31 ((20159 * sint (a::16 sword) + 33554432) div 67108864 * 3329) = + (20159 * sint a + 33554432) div 67108864 * 3329\ +proof - + have \(20159 * sint a + 33554432) div 67108864 \ -10\ and \(20159 * sint a + 33554432) div 67108864 \ 10\ + using barrett_quotient_bounds[of a] by auto + moreover from this have \(3329::int) * (-10) \ 3329 * ((20159 * sint a + 33554432) div 67108864)\ + by (intro mult_left_mono) auto + moreover from calculation have \(3329::int) * ((20159 * sint a + 33554432) div 67108864) \ 3329 * 10\ + by (intro mult_left_mono) auto + ultimately show ?thesis + by (intro signed_take_bit_int_eq_self) (simp_all add: algebra_simps) +qed + +lemma barrett_result_bounds: + fixes a :: \16 sword\ + defines \sa \ sint a\ + defines \t \ (20159 * sa + 33554432) div (67108864 :: int)\ + shows \-1664 \ sa - t * 3329\ + and \sa - t * 3329 \ 1664\ +proof - + define r where + \r = (20159 * sa + 33554432) mod (67108864 :: int)\ + have sa: \sa \ -32768\ \sa < 32768\ + using barrett_sint_bounds[of a] by (auto simp: sa_def) + have r_ge: \r \ 0\ + unfolding r_def by (intro pos_mod_sign) auto + have r_lt: \r < 67108864\ + unfolding r_def by (intro pos_mod_bound) auto + have t_r_eq: \t * 67108864 + r = 20159 * sa + 33554432\ + unfolding t_def r_def using div_mult_mod_eq[of \20159 * sa + 33554432\ 67108864] + by (simp add: algebra_simps) + \ \Key: 67108864 * (sa - t * 3329) = 3329 * r - 447 * sa - 111702704128 + using 20159 * 3329 = 67108864 + 447\ + have key: \67108864 * (sa - t * 3329) = 3329 * r - 447 * sa - 111702704128\ + proof - + from t_r_eq have \3329 * (t * 67108864 + r) = 3329 * (20159 * sa + 33554432)\ + by auto + thus ?thesis + by (simp add: algebra_simps) + qed + show \-1664 \ sa - t * 3329\ + proof (rule ccontr) + assume \\ (-1664 \ sa - t * 3329)\ + hence \sa - t * 3329 \ -1665\ + by auto + hence \67108864 * (sa - t * 3329) \ 67108864 * (-1665)\ + by (intro mult_left_mono) auto + moreover have \3329 * r - 447 * sa - 111702704128 \ 0 - 447 * 32767 - 111702704128\ + using r_ge sa by (intro diff_mono diff_mono) auto + ultimately show False + using key by simp + qed + show \sa - t * 3329 \ 1664\ + proof (rule ccontr) + assume \\ (sa - t * 3329 \ 1664)\ + hence \sa - t * 3329 \ 1665\ + by auto + hence \67108864 * (sa - t * 3329) \ 67108864 * 1665\ + by (intro mult_left_mono) auto + moreover have \3329 * r - 447 * sa - 111702704128 \ 3329 * 67108863 + 447 * 32768 - 111702704128\ + using r_lt sa by linarith + ultimately show False + using key by simp + qed +qed + +lemma barrett_result_stb31: + fixes a :: \16 sword\ + defines \t \ (20159 * sint a + 33554432) div (67108864 :: int)\ + shows \-2147483648 \ sint a - t * 3329\ + and \sint a - t * 3329 < 2147483648\ +using barrett_result_bounds[of a] by (auto simp: t_def) + +lemma barrett_quotientQ_stb31: + fixes a :: \16 sword\ + defines \t \ (20159 * sint a + 33554432) div (67108864 :: int)\ + shows \-2147483648 \ t * 3329\ + and \t * 3329 < 2147483648\ +using barrett_result_bounds[of a] barrett_sint_bounds[of a] by (auto simp: t_def) + +lemma barrett_result_sint: + fixes a :: \16 sword\ + defines \t \ (20159 * sint a + 33554432) div (67108864 :: int)\ + shows \sint (SCAST(32 signed \ 16 signed) + (SCAST(16 signed \ 32 signed) a - word_of_int t * (0xD01 :: 32 sword))) = sint a - t * 3329\ +proof - + have res: \-1664 \ sint a - t * 3329\ \sint a - t * 3329 \ 1664\ + using barrett_result_bounds[of a] by (auto simp: t_def) + \ \sint of the multiplication\ + have mult_sint: \sint (word_of_int t * (0xD01 :: 32 sword)) = t * 3329\ + proof - + have \sint (word_of_int t :: 32 sword) = t\ + using barrett_sint_woi[of a] by (simp add: t_def) + moreover have \sint (0xD01 :: 32 sword) = (3329 :: int)\ + by simp + moreover have \signed_take_bit 31 (t * 3329) = t * 3329\ + using barrett_stb31_tq[of a] by (simp add: t_def) + ultimately show ?thesis + by (simp add: sint_word_mult word_size) + qed + \ \sint of the subtraction\ + have sub_sint: \sint (SCAST(16 signed \ 32 signed) a - word_of_int t * 0xD01) = sint a - t * 3329\ + proof - + have up: \sint (SCAST(16 signed \ 32 signed) a) = sint a\ + by (simp add: sint_up_scast is_up) + have \signed_take_bit 31 (sint a - t * 3329) = sint a - t * 3329\ + using res by (intro signed_take_bit_int_eq_self) auto + thus ?thesis + by (simp add: sint_word_diff word_size up mult_sint) + qed + \ \sint of the downcast\ + have scast_sint: \sint (SCAST(32 signed \ 16 signed) w) = signed_take_bit 15 (sint w)\ for w :: \32 sword\ + by (simp only: of_int_sint_scast[symmetric] Word.sint_sbintrunc') simp + have \signed_take_bit 15 (sint a - t * 3329) = sint a - t * 3329\ + using res by (intro signed_take_bit_int_eq_self) auto + thus ?thesis + by (simp add: scast_sint sub_sint) +qed + +(*<*) +lemma bounded_while_literal_false [micro_rust_simps]: + shows \bounded_while n (\False) body = \()\ +by (induction n) (simp_all add: micro_rust_simps) + +lemma c_signed_truthy_zero [micro_rust_simps]: + shows \c_signed_truthy 0 = \False\ +by (simp add: c_signed_truthy_def) +(*>*) + +lemma c_mlk_barrett_reduce_spec [crush_specs]: + shows \\; c_mlk_barrett_reduce while_fuel a \\<^sub>F + c_mlk_barrett_reduce_contract a\ + apply (crush_boot f: c_mlk_barrett_reduce_def contract: c_mlk_barrett_reduce_contract_def) + apply (crush_base simp add: sint_up_scast is_up sint_word_ariths + barrett_stb31_prod barrett_stb31_sum + signed_take_bit_int_eq_self barrett_sint_bounds + barrett_sint_woi barrett_stb31_tq + c_signed_truthy_zero bounded_while_literal_false) + apply (all \(insert barrett_result_bounds[of a] barrett_quotient_bounds[of a] + barrett_sint_bounds[of a]; linarith)?\) + apply (simp_all add: barrett_result_sint barrett_reduce_int_def) + done + +subsection \@{text mlk_poly_add} contract\ + +text \ + The contract is self-contained: refinement, well-formedness, and + no-overflow are all expressed as pure assertions in the precondition. + The postcondition asserts the result refines the abstract polynomial sum. + No external assumptions needed on the specification theorem. +\ +definition c_mlk_poly_add_contract :: \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ c_mlk_poly \ + int_poly \ ('s, 'a, 'b) function_contract\ where + \c_mlk_poly_add_contract r gr vr ar b gb vb ab \ + let pre = can_alloc_reference \ + r \\\\ gr\vr \ + \refines_mlk_poly vr ar\ \ + b \\\\ gb\vb \ + \refines_mlk_poly vb ab\ \ + \no_overflow_add ar ab\; + post = \_. can_alloc_reference \ + (\gr' vr'. r \\\\ gr'\vr' \ + \refines_mlk_poly vr' (poly_add_int ar ab)\) \ + b \\\\ gb\vb + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_add_contract(*>*) + +lemma c_mlk_poly_add_spec: + shows \\; c_mlk_poly_add MLKEM_N r b \\<^sub>F c_mlk_poly_add_contract r gr vr ar b gb vb ab\ + apply (crush_boot f: c_mlk_poly_add_def contract: c_mlk_poly_add_contract_def) + apply crush_base + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr'. r \\\\ gr'\(update_c_mlk_poly_coeffs + (\_. take (MLKEM_N - k) (map2 (+) (c_mlk_poly_coeffs vr) (c_mlk_poly_coeffs vb)) + @ drop (MLKEM_N - k) (c_mlk_poly_coeffs vr)) vr)) + \ b \\\\ gb\vb + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\ \ \refines_mlk_poly vb ab\ + \ \no_overflow_add ar ab\\ + and INV'=\\k. (\gr'. r \\\\ gr'\(update_c_mlk_poly_coeffs + (\_. take (MLKEM_N - Suc k) (map2 (+) (c_mlk_poly_coeffs vr) (c_mlk_poly_coeffs vb)) + @ drop (MLKEM_N - Suc k) (c_mlk_poly_coeffs vr)) vr)) + \ b \\\\ gb\vb + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - Suc k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\ \ \refines_mlk_poly vb ab\ + \ \no_overflow_add ar ab\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Initialization + frame\ + by (crush_base simp: refines_mlk_poly_def c_mlk_poly.record_simps + poly_add_int_def no_overflow_add_def map2_map_map word_size + intro!: nth_equalityI sint_plus_no_overflow) + subgoal \ \Condition\ + by crush_base (simp_all add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat) + subgoal \ \Loop body\ + by (crush_base simp add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat + c_mlk_poly.record_simps nth_append refines_mlk_poly_def inv_list_step) + subgoal \ \Fuel exhaust\ + by crush_base + done + +subsection \@{text mlk_poly_sub} contract\ + +text \Coefficient-wise subtraction, symmetric to @{text poly_add} above. + Each output coefficient satisfies + @{text "r!i = a!i - b!i"} at the integer level.\ + +definition c_mlk_poly_sub_contract :: \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ c_mlk_poly \ + int_poly \ ('s, 'a, 'b) function_contract\ where + \c_mlk_poly_sub_contract r gr vr ar b gb vb ab \ + let pre = can_alloc_reference \ + r \\\\ gr\vr \ + \refines_mlk_poly vr ar\ \ + b \\\\ gb\vb \ + \refines_mlk_poly vb ab\ \ + \no_overflow_sub ar ab\; + post = \_. can_alloc_reference \ + (\gr' vr'. r \\\\ gr'\vr' \ + \refines_mlk_poly vr' (poly_sub_int ar ab)\) \ + b \\\\ gb\vb + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_sub_contract(*>*) + +lemma c_mlk_poly_sub_spec: + shows \\; c_mlk_poly_sub MLKEM_N r b \\<^sub>F c_mlk_poly_sub_contract r gr vr ar b gb vb ab\ + apply (crush_boot f: c_mlk_poly_sub_def contract: c_mlk_poly_sub_contract_def) + apply crush_base + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr'. r \\\\ gr'\(update_c_mlk_poly_coeffs + (\_. take (MLKEM_N - k) (map2 (-) (c_mlk_poly_coeffs vr) (c_mlk_poly_coeffs vb)) + @ drop (MLKEM_N - k) (c_mlk_poly_coeffs vr)) vr)) + \ b \\\\ gb\vb + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\ \ \refines_mlk_poly vb ab\ + \ \no_overflow_sub ar ab\\ + and INV'=\\k. (\gr'. r \\\\ gr'\(update_c_mlk_poly_coeffs + (\_. take (MLKEM_N - Suc k) (map2 (-) (c_mlk_poly_coeffs vr) (c_mlk_poly_coeffs vb)) + @ drop (MLKEM_N - Suc k) (c_mlk_poly_coeffs vr)) vr)) + \ b \\\\ gb\vb + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - Suc k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\ \ \refines_mlk_poly vb ab\ + \ \no_overflow_sub ar ab\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Initialization + frame\ + by (crush_base simp: refines_mlk_poly_def c_mlk_poly.record_simps + poly_sub_int_def no_overflow_sub_def map2_map_map word_size + intro!: nth_equalityI sint_minus_no_overflow) + subgoal \ \Condition\ + by crush_base (simp_all add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat) + subgoal \ \Loop body\ + by (crush_base simp add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat + c_mlk_poly.record_simps nth_append refines_mlk_poly_def inv_list_step) + subgoal \ \Fuel exhaust\ + by crush_base + done + +subsection \@{text mlk_value_barrier_i32} — identity (opt-blocker simplified to 0)\ + +text \After preprocessing with @{verbatim "mlk_ct_get_optblocker_i32() = 0"}, + the value barrier reduces to @{term \b XOR 0 = b\}.\ + +definition c_mlk_value_barrier_i32_contract :: + \c_int \ ('s::{sepalg}, c_int, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_value_barrier_i32_contract b \ + let pre = \True\; + post = \r. \r = b\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_value_barrier_i32_contract(*>*) + +lemma c_mlk_value_barrier_i32_spec [crush_specs]: + shows \\; c_mlk_value_barrier_i32 b \\<^sub>F + c_mlk_value_barrier_i32_contract b\ + by (crush_boot f: c_mlk_value_barrier_i32_def + contract: c_mlk_value_barrier_i32_contract_def) crush_base + +subsection \@{text mlk_cast_int32_to_uint16}\ + +text \Truncation cast from 32-bit signed to 16-bit unsigned, + masking with @{text "0xFFFF"}. Used throughout the NTT and + reduction routines to narrow intermediate 32-bit results + back to coefficient width.\ + +definition c_mlk_cast_int32_to_uint16_contract :: + \c_int \ ('s::{sepalg}, c_ushort, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_cast_int32_to_uint16_contract x \ + let pre = \True\; + post = \r. \r = ucast (x AND 0xFFFF)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_cast_int32_to_uint16_contract(*>*) + +lemma c_mlk_cast_int32_to_uint16_spec [crush_specs]: + shows \\; c_mlk_cast_int32_to_uint16 x \\<^sub>F + c_mlk_cast_int32_to_uint16_contract x\ + by (crush_boot f: c_mlk_cast_int32_to_uint16_def + contract: c_mlk_cast_int32_to_uint16_contract_def) + (crush_base simp add: scast_ucast_down_same) + +subsection \@{text mlk_cast_uint16_to_int16}\ + +text \Bit-pattern-preserving reinterpretation from unsigned to signed + 16-bit. The postcondition is simply @{text "r = scast x"}.\ + +definition c_mlk_cast_uint16_to_int16_contract :: \c_ushort \ + ('s::{sepalg}, c_short, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_cast_uint16_to_int16_contract x \ + let pre = \True\; + post = \r. \r = scast x\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_cast_uint16_to_int16_contract(*>*) + +lemma c_mlk_cast_uint16_to_int16_spec [crush_specs]: + shows \\; c_mlk_cast_uint16_to_int16 x \\<^sub>F + c_mlk_cast_uint16_to_int16_contract x\ + by (crush_boot f: c_mlk_cast_uint16_to_int16_def + contract: c_mlk_cast_uint16_to_int16_contract_def) crush_base + +subsection \@{text mlk_ct_cmask_neg_i16} — negative mask\ + +text \Returns @{term \0xFFFF :: c_ushort\} when @{term \sint x < 0\}, + @{term \0 :: c_ushort\} otherwise. + Implements: @{term \(int32_t)x >> 16\} after value-barrier identity.\ + +definition c_mlk_ct_cmask_neg_i16_contract :: \c_short \ + ('s::{sepalg}, c_ushort, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_ct_cmask_neg_i16_contract x \ + let pre = can_alloc_reference; + post = \r. can_alloc_reference \ + \r = (if sint x < 0 then 0xFFFF else 0)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_ct_cmask_neg_i16_contract(*>*) + +text \After @{const c_signed_shr}, the WP produces @{term \word_of_int (sint x div 2 ^ n)\}. + For a 16-bit signed value shifted by 16, the quotient is @{term \-1\} (negative) + or @{term \0\} (non-negative).\ + +lemma sint16_div_65536: + fixes x :: \16 sword\ + shows \sint x div 65536 = (if sint x < 0 then -1 else 0)\ +proof - + have \sint x \ -32768\ \sint x < 32768\ + using sint_range_size[where w=x] by (auto simp: word_size) + thus ?thesis by auto +qed + +lemma c_mlk_ct_cmask_neg_i16_spec [crush_specs]: + shows \\; c_mlk_ct_cmask_neg_i16 x \\<^sub>F + c_mlk_ct_cmask_neg_i16_contract x\ + apply (crush_boot f: c_mlk_ct_cmask_neg_i16_def + contract: c_mlk_ct_cmask_neg_i16_contract_def) + apply (crush_base simp add: sint_up_scast is_up + scast_ucast_down_same sint16_div_65536) + done + +subsection \@{text mlk_cast_int16_to_uint16}\ + +text \Bit-pattern-preserving reinterpretation from signed 32-bit to + unsigned 16-bit via masking and @{const ucast}.\ + +definition c_mlk_cast_int16_to_uint16_contract :: \c_int \ + ('s::{sepalg}, c_ushort, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_cast_int16_to_uint16_contract x \ + let pre = \True\; + post = \r. \r = ucast x\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_cast_int16_to_uint16_contract(*>*) + +lemma ucast_and_0xFFFF: + shows \UCAST(32 signed \ 16) (x AND 0xFFFF) = UCAST(32 signed \ 16) x\ +proof (rule word_eqI) + fix n + assume \n < size (UCAST(32 signed \ 16) (x AND 0xFFFF))\ + hence n16: \n < 16\ + by (simp add: word_size) + have mask_eq: \(0xFFFF :: 32 signed word) = mask 16\ + by eval + have \(0xFFFF :: 32 signed word) !! n\ + using n16 by (simp add: mask_eq nth_mask word_size) + thus \UCAST(32 signed \ 16) (x AND 0xFFFF) !! n = UCAST(32 signed \ 16) x !! n\ + by (simp add: nth_ucast word_ops_nth_size word_size is_down) +qed + +lemma c_mlk_cast_int16_to_uint16_spec [crush_specs]: + shows \\; c_mlk_cast_int16_to_uint16 x \\<^sub>F + c_mlk_cast_int16_to_uint16_contract x\ + by (crush_boot f: c_mlk_cast_int16_to_uint16_def + contract: c_mlk_cast_int16_to_uint16_contract_def) + (crush_base simp add: scast_ucast_down_same ucast_and_0xFFFF) + +subsection \@{text mlk_ct_cmask_nonzero_u16} — nonzero mask\ + +text \Returns @{term \0xFFFF :: c_ushort\} when @{term \x \ 0\}, + @{term \0 :: c_ushort\} otherwise. + Implements: @{term \(-(int32_t)x) >> 16\} after value-barrier identity.\ +definition c_mlk_ct_cmask_nonzero_u16_contract :: + \c_ushort \ ('s::{sepalg}, c_ushort, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_ct_cmask_nonzero_u16_contract x \ + let pre = can_alloc_reference; + post = \r. can_alloc_reference \ + \r = (if x \ 0 then 0xFFFF else 0)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_ct_cmask_nonzero_u16_contract(*>*) + +lemma neg_uint16_div_65536: + fixes x :: \16 word\ + shows \(- uint x) div 65536 = (if x \ 0 then -1 else 0)\ +proof (cases \x = 0\) + case True + thus ?thesis + by simp +next + case False + hence \uint x \ 0\ + by (simp add: uint_0_iff) + hence lb: \1 \ uint x\ + using uint_ge_0[where x=x] by linarith + have ub: \uint x \ 65535\ + using uint_lt2p[where x=x] by simp + have \(- uint x) div 65536 = -1\ + by (rule int_div_pos_eq[where r=\65536 - uint x\]) + (use lb ub in linarith, use ub in linarith, use lb in linarith) + with False show ?thesis + by simp +qed + +lemma sint_neg_ucast_16_32: + fixes x :: \16 word\ + shows \sint (- UCAST(16 \ 32 signed) x) = - uint x\ +proof - + have s: \sint (UCAST(16 \ 32 signed) x) = uint x\ + by (simp add: sint_ucast_eq_uint is_down) + have lb: \0 \ uint x\ + by (rule uint_ge_0) + have ub: \uint x \ 2147483647\ + using uint_lt2p[where x=x] by simp + have \- (2147483648 :: int) \ - uint x\ + using ub by linarith + moreover have \- uint x < (2147483648 :: int)\ + using lb by linarith + ultimately show ?thesis + unfolding sint_word_minus s by (simp add: signed_take_bit_int_eq_self_iff) +qed + +lemma c_mlk_ct_cmask_nonzero_u16_spec [crush_specs]: + shows \\; c_mlk_ct_cmask_nonzero_u16 x \\<^sub>F + c_mlk_ct_cmask_nonzero_u16_contract x\ + apply (crush_boot f: c_mlk_ct_cmask_nonzero_u16_def + contract: c_mlk_ct_cmask_nonzero_u16_contract_def) + apply (crush_base simp add: sint_up_scast is_up is_down + scast_ucast_down_same neg_uint16_div_65536 + sint_ucast_eq_uint ucast_and_0xFFFF + sint_neg_ucast_16_32) + subgoal using uint_lt2p[where x=x] uint_ge_0[where x=x] by linarith + subgoal using uint_lt2p[where x=x] by simp + done + +subsection \@{text mlk_ct_sel_int16} — conditional select\ + +text \Branch-free conditional select: returns @{term \a\} when + @{term \cond \ 0\} and @{term \b\} otherwise, using XOR-and-mask + rather than a conditional branch. The proof requires several + word-level cast round-trip lemmas and a bitwise identity showing + that @{text "b XOR (0xFFFF AND (a XOR b)) = a"}.\ +lemma ct_sel_cast_roundtrip': + fixes x :: \16 signed word\ + shows \UCAST(32 signed \ 16) (SCAST(16 signed \ 32 signed) x) = UCAST(16 signed \ 16) x\ +proof (rule word_eqI) + fix n + assume \n < size (UCAST(32 signed \ 16) (SCAST(16 signed \ 32 signed) x))\ + hence n16: \n < 16\ + by (simp add: word_size) + show \UCAST(32 signed \ 16) (SCAST(16 signed \ 32 signed) x) !! n = UCAST(16 signed \ 16) x !! n\ + proof (cases \n < 15\) + case True + thus ?thesis + by (simp add: nth_ucast nth_scast word_size is_up is_down) + next + case False + with n16 have \n = 15\ + by linarith + thus ?thesis + by (simp add: nth_ucast nth_scast word_size is_up is_down) + qed +qed + +lemma ct_sel_cast_roundtrip: + fixes x :: \16 signed word\ + shows \SCAST(16 \ 16 signed) + (SCAST(32 signed \ 16) + (UCAST(16 \ 32 signed) + (UCAST(32 signed \ 16) + (SCAST(16 signed \ 32 signed) x)))) = x\ +by (simp add: ct_sel_cast_roundtrip' ucast_down_ucast_id is_down scast_ucast_down_same) + +lemma ct_sel_xor_identity: + fixes a b :: \16 signed word\ + shows \SCAST(16 \ 16 signed) + (SCAST(32 signed \ 16) + (UCAST(16 \ 32 signed) (UCAST(32 signed \ 16) (SCAST(16 signed \ 32 signed) b)) + xor (0xFFFF :: 32 signed word) + AND (UCAST(16 \ 32 signed) (UCAST(32 signed \ 16) (SCAST(16 signed \ 32 signed) a)) + xor UCAST(16 \ 32 signed) (UCAST(32 signed \ 16) (SCAST(16 signed \ 32 signed) b))))) = a\ +proof - + have mask_eq: \(0xFFFF :: 32 signed word) = mask 16\ + by eval + show ?thesis + proof (simp only: ct_sel_cast_roundtrip' mask_eq, rule word_eqI) + fix n + let ?lhs = \SCAST(16 \ 16 signed) + (SCAST(32 signed \ 16) + (UCAST(16 \ 32 signed) (UCAST(16 signed \ 16) b) xor mask 16 AND + (UCAST(16 \ 32 signed) (UCAST(16 signed \ 16) a) xor + UCAST(16 \ 32 signed) (UCAST(16 signed \ 16) b))))\ + assume \n < size ?lhs\ + hence n16: \n < 16\ + by (simp add: word_size) + show \?lhs !! n = a !! n\ + proof (cases \n < 15\) + case True + thus ?thesis + by (auto simp add: word_ops_nth_size word_size nth_ucast nth_scast is_up is_down nth_mask) + next + case False + with n16 have \n = 15\ + by linarith + thus ?thesis + by (auto simp add: word_ops_nth_size word_size nth_ucast nth_scast is_up is_down nth_mask) + qed + qed +qed + +text \Returns @{term \a\} when @{term \cond \ 0\}, @{term \b\} otherwise.\ +definition c_mlk_ct_sel_int16_contract :: \c_short \ c_short \ c_ushort \ + ('s::{sepalg}, c_short, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_ct_sel_int16_contract a b cond \ + let pre = can_alloc_reference; + post = \r. can_alloc_reference \ + \r = (if cond \ 0 then a else b)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_ct_sel_int16_contract(*>*) + +lemma c_mlk_ct_sel_int16_spec [crush_specs]: + shows \\; c_mlk_ct_sel_int16 a b cond \\<^sub>F + c_mlk_ct_sel_int16_contract a b cond\ + apply (crush_boot f: c_mlk_ct_sel_int16_def + contract: c_mlk_ct_sel_int16_contract_def) + apply crush_base + apply (simp_all add: ct_sel_cast_roundtrip ct_sel_xor_identity) + done + +(*<*) +end + +end +(*>*) From 4ac46e6d0e08fe928064f04fdec52961ad379acd Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:00:45 +0000 Subject: [PATCH 07/11] proofs/isabelle: add C functional correctness proofs for Montgomery reduction Add MLKEM_FC_Montgomery.thy with contracts for montgomery_reduce, fqmul, and scalar_signed_to_unsigned_q. --- proofs/isabelle/MLKEM_FC_Montgomery.thy | 371 ++++++++++++++++++++++++ 1 file changed, 371 insertions(+) create mode 100644 proofs/isabelle/MLKEM_FC_Montgomery.thy diff --git a/proofs/isabelle/MLKEM_FC_Montgomery.thy b/proofs/isabelle/MLKEM_FC_Montgomery.thy new file mode 100644 index 0000000000..0cf603b68c --- /dev/null +++ b/proofs/isabelle/MLKEM_FC_Montgomery.thy @@ -0,0 +1,371 @@ +(*<*) +theory MLKEM_FC_Montgomery + imports MLKEM_FC_Scalar +begin +(*>*) + +text \ + Functional correctness proofs for Montgomery reduction + (@{verbatim \mlk_montgomery_reduce\}), field multiplication + (@{verbatim \mlk_fqmul\}), and the signed-to-unsigned conversion + (@{verbatim \mlk_scalar_signed_to_unsigned_q\}). +\ + +section \Montgomery Reduction and Field Multiplication\ + +(*<*) +context c_mlk_machine_model +begin + +declare c_mlk_montgomery_reduce_def [micro_rust_simps del] +declare c_mlk_fqmul_def [micro_rust_simps del] +declare c_mlk_scalar_signed_to_unsigned_q_def [micro_rust_simps del] +(*>*) + +subsection \@{text mlk_montgomery_reduce} contract\ + +text \The contract states that the result refines the abstract Montgomery + reduction: the signed interpretation of the return value equals + @{const montgomery_reduce_int} applied to the signed interpretation + of the input. The precondition bounds the input to prevent overflow.\ +definition c_mlk_montgomery_reduce_contract :: + \c_int \ ('s::{sepalg}, c_short, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_montgomery_reduce_contract a \ + let pre = can_alloc_reference \ \\sint a\ < 2^31 - 2^15 * MLKEM_Q\; + post = \r. can_alloc_reference \ + \sint r = montgomery_reduce_int (sint a)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_montgomery_reduce_contract(*>*) + +text \Bound on t * Q: since \|sint t| \ 32768\ and + \32768 * 3329 < 2^31\, multiplication does not overflow.\ +lemma montgomery_t_mul_Q_bounds: + fixes t :: \16 signed word\ + shows \sint (SCAST(16 signed \ 32 signed) t) * MLKEM_Q < 2147483648\ + and \-2147483648 \ sint (SCAST(16 signed \ 32 signed) t) * MLKEM_Q\ +proof - + have \sint t \ -32768\ \sint t < 32768\ + using sint_range_size[where w=t] by (auto simp: word_size) + hence st: \sint (SCAST(16 signed \ 32 signed) t) \ -32768\ + \sint (SCAST(16 signed \ 32 signed) t) < 32768\ + by (simp_all add: sint_up_scast is_up) + show \sint (SCAST(16 signed \ 32 signed) t) * MLKEM_Q < 2147483648\ + using st by (auto intro: order_le_less_trans[OF mult_right_mono]) + show \-2147483648 \ sint (SCAST(16 signed \ 32 signed) t) * MLKEM_Q\ + using st by (auto intro: le_trans[OF _ mult_right_mono]) +qed + +text \Bound on \a - t * Q\: with \|sint a| < 2^31 - 2^15 * Q\ and + \|sint t * Q| \ 32768 * 3329 \ 10^8\, the subtraction fits in 32 bits.\ +lemma montgomery_sub_bounds: + fixes a :: \32 signed word\ + and t :: \16 signed word\ + assumes \\sint a\ < 2^31 - 2^15 * MLKEM_Q\ + shows \sint a - sint (SCAST(16 signed \ 32 signed) t * 0xD01) < 2147483648\ + and \-2147483648 \ sint a - sint (SCAST(16 signed \ 32 signed) t * 0xD01)\ +proof - + have ta: \sint a > -2038398976\ \sint a < 2038398976\ + using assms by auto + have st: \sint t \ -32768\ \sint t < 32768\ + using sint_range_size[where w=t] by (auto simp: word_size) + hence st32: \sint (SCAST(16 signed \ 32 signed) t) \ -32768\ + \sint (SCAST(16 signed \ 32 signed) t) < 32768\ + by (simp_all add: sint_up_scast is_up) + have mul_lo: \sint (SCAST(16 signed \ 32 signed) t) * 3329 \ (-32768) * 3329\ + using st32 by (intro mult_right_mono) auto + have mul_hi: \sint (SCAST(16 signed \ 32 signed) t) * 3329 \ 32767 * 3329\ + using st32 by (intro mult_right_mono) auto + have stb: \signed_take_bit 31 (sint (SCAST(16 signed \ 32 signed) t) * 3329) = + sint (SCAST(16 signed \ 32 signed) t) * 3329\ + using mul_lo mul_hi by (intro signed_take_bit_int_eq_self) auto + have sint_mul: \sint (SCAST(16 signed \ 32 signed) t * 0xD01) = + sint (SCAST(16 signed \ 32 signed) t) * 3329\ + by (simp add: sint_word_mult word_size stb) + show \sint a - sint (SCAST(16 signed \ 32 signed) t * 0xD01) < 2147483648\ + using ta mul_lo by (simp add: sint_mul) + show \-2147483648 \ sint a - sint (SCAST(16 signed \ 32 signed) t * 0xD01)\ + using ta mul_hi by (simp add: sint_mul) +qed + +text \The word-level computation of Montgomery's \t\ equals the abstract + integer-level computation.\ +lemma montgomery_t_sint: + fixes a :: \32 signed word\ + defines \t_w \ SCAST(16 \ 16 signed) (UCAST(32 \ 16) + (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF)) + * 0xF301 AND 0xFFFF))\ + shows \sint t_w = signed_take_bit 15 ((sint a mod 2^16) * 62209 mod 2^16)\ +proof - + \ \SCAST(16 \ 16 signed) preserves sint (same-length cast)\ + let ?inner = \UCAST(32 \ 16) + (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF)) + * 0xF301 AND 0xFFFF)\ + have step1: \sint t_w = sint ?inner\ + unfolding t_w_def by (simp add: sint_up_scast is_up) + \ \For a 16-bit word, sint = signed_take_bit 15 . uint\ + have step2: \sint ?inner = signed_take_bit 15 (uint ?inner)\ + by (simp add: sint_uint word_size) + \ \UCAST(32\16) preserves signed_take_bit 15 of uint\ + let ?prod = \UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF)) + * 0xF301 AND (0xFFFF :: 32 word)\ + have step3: \signed_take_bit 15 (uint (UCAST(32 \ 16) ?prod)) = + signed_take_bit 15 (uint ?prod)\ + proof - + have rw: \UCAST(32 \ 16) ?prod = word_of_int (uint ?prod)\ + by (rule ucast_eq) + have len: \LENGTH(16) = Suc 15\ + by (simp add: word_size) + show ?thesis + by (simp only: rw uint_word_of_int_eq len sbintrunc_bintrunc) + qed + \ \uint of AND 0xFFFF = uint mod 2^16\ + let ?mul = \UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF)) * (0xF301 :: 32 word)\ + have step4: \uint (?mul AND 0xFFFF) = uint ?mul mod 65536\ + proof - + have \(0xFFFF :: 32 word) = mask 16\ + by eval + show ?thesis + by (subst \(0xFFFF :: 32 word) = mask 16\) + (simp add: and_mask_mod_2p uint_word_of_int word_size) + qed + \ \uint of multiplication\ + have step5: \uint ?mul = uint (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF))) * 62209 mod 2^32\ + by (simp add: uint_word_ariths word_size) + \ \uint of up-cast UCAST(16\32) = uint of inner\ + have step6: \uint (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF))) = + uint (UCAST(32 signed \ 16) (a AND 0xFFFF))\ + by (simp add: uint_up_ucast is_up) + \ \uint of down-cast UCAST(32 signed\16) = uint mod 2^16\ + have step7: \uint (UCAST(32 signed \ 16) (a AND 0xFFFF)) = uint (a AND 0xFFFF) mod 2^16\ + proof - + have rw: \UCAST(32 signed \ 16) (a AND 0xFFFF) = word_of_int (uint (a AND 0xFFFF))\ + by (rule ucast_eq) + show ?thesis + by (simp only: rw uint_word_of_int_eq take_bit_eq_mod) (simp add: word_size) + qed + \ \uint of AND 0xFFFF = uint a mod 2^16\ + have step8: \uint (a AND (0xFFFF :: 32 signed word)) = uint a mod 65536\ + proof - + have \(0xFFFF :: 32 signed word) = mask 16\ + by eval + show ?thesis + by (subst \(0xFFFF :: 32 signed word) = mask 16\) + (simp add: and_mask_mod_2p uint_word_of_int word_size take_bit_eq_mod) + qed + \ \Chain: uint of the up-cast = uint a mod 65536\ + have uint_chain: \uint (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF))) = uint a mod 65536\ + using step6 step7 step8 by simp + \ \The product does not overflow 32 bits: (uint a mod 65536) * 62209 < 2^32\ + have no_overflow: \(uint a mod 65536) * 62209 < 2^32\ + proof - + have \uint a mod 65536 < 65536\ + by simp + hence \(uint a mod 65536) * 62209 < 65536 * 62209\ + by (rule mult_strict_right_mono) simp + also have \\ < 2^32\ + by simp + finally show ?thesis + . + qed + \ \Combine: signed_take_bit 15 of the uint chain\ + have lhs_eq: \sint t_w = signed_take_bit 15 ((uint a mod 65536) * 62209 mod 65536)\ + proof - + have \uint ?mul = (uint a mod 65536) * 62209\ + using step5 uint_chain no_overflow by simp + hence \uint (?mul AND 0xFFFF) = (uint a mod 65536) * 62209 mod 65536\ + using step4 by simp + hence \signed_take_bit 15 (uint ?prod) = + signed_take_bit 15 ((uint a mod 65536) * 62209 mod 65536)\ + by simp + thus ?thesis + using step1 step2 step3 by simp + qed + \ \Finally: uint a mod 65536 = sint a mod 65536 (since 2^16 | 2^32)\ + have \uint a mod 65536 = sint a mod 65536\ + by (simp add: uint_sint word_size take_bit_eq_mod mod_mod_cancel) + thus ?thesis + using lhs_eq by simp +qed + +text \The full Montgomery result at the word level equals the abstract + \montgomery_reduce_int\ applied to \sint a\.\ +lemma montgomery_result_sint: + fixes a :: \32 signed word\ + assumes \\sint a\ < 2^31 - 2^15 * MLKEM_Q\ + shows \sint (SCAST(32 signed \ 16 signed) + (word_of_int (sint (a - SCAST(16 signed \ 32 signed) + (SCAST(16 \ 16 signed) (UCAST(32 \ 16) + (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF)) + * 0xF301 AND 0xFFFF))) * 0xD01) div 65536))) = + montgomery_reduce_int (sint a)\ +proof - + define t_w where + \t_w \ SCAST(16 \ 16 signed) (UCAST(32 \ 16) + (UCAST(16 \ 32) (UCAST(32 signed \ 16) (a AND 0xFFFF)) + * 0xF301 AND 0xFFFF))\ + \ \sint of t_w equals the abstract montgomery t\ + have t_sint: \sint t_w = signed_take_bit 15 ((sint a mod 2^16) * 62209 mod 2^16)\ + unfolding t_w_def by (rule montgomery_t_sint) + \ \sint of the upcast preserves value\ + have sint_cast: \sint (SCAST(16 signed \ 32 signed) t_w) = sint t_w\ + by (simp add: sint_up_scast is_up) + \ \Subtraction bounds from montgomery_sub_bounds\ + note sub_bounds = montgomery_sub_bounds[OF assms, where t=t_w, folded t_w_def] + \ \sint of the subtraction\ + have sint_sub: \sint (a - SCAST(16 signed \ 32 signed) t_w * 0xD01) = + sint a - sint (SCAST(16 signed \ 32 signed) t_w * 0xD01)\ + using sub_bounds by (simp add: sint_word_diff word_size signed_take_bit_int_eq_self) + \ \sint of multiplication t_w * Q\ + have sint_mul: \sint (SCAST(16 signed \ 32 signed) t_w * 0xD01) = + sint t_w * 3329\ + proof - + note mul_bounds = montgomery_t_mul_Q_bounds[where t=t_w] + have \signed_take_bit 31 (sint (SCAST(16 signed \ 32 signed) t_w) * 3329) = + sint (SCAST(16 signed \ 32 signed) t_w) * 3329\ + using mul_bounds by (intro signed_take_bit_int_eq_self) auto + thus ?thesis + by (simp add: sint_word_mult word_size sint_cast) + qed + \ \The div expression equals montgomery_reduce_int\ + have div_eq: \sint (a - SCAST(16 signed \ 32 signed) t_w * 0xD01) div 65536 = + montgomery_reduce_int (sint a)\ + using sint_sub sint_mul t_sint + unfolding montgomery_reduce_int_def Let_def by simp + \ \Result bound ensures the final cast preserves value\ + have result_bound: \\montgomery_reduce_int (sint a)\ < 2^15\ + using assms by (rule montgomery_reduce_int_bound) + \ \sint of SCAST(32s\16s)(word_of_int k) = k when |k| < 2^15\ + define r where + \r = montgomery_reduce_int (sint a)\ + have \sint (SCAST(32 signed \ 16 signed) (word_of_int r :: 32 signed word)) = r\ + proof - + have \sint (word_of_int r :: 32 signed word) = r\ + using result_bound unfolding r_def by (intro sint_of_int_eq) auto + hence \SCAST(32 signed \ 16 signed) (word_of_int r :: 32 signed word) = + (word_of_int r :: 16 signed word)\ + by (simp add: scast_def) + moreover have \sint (word_of_int r :: 16 signed word) = r\ + using result_bound unfolding r_def by (intro sint_of_int_eq) (auto simp: word_size) + ultimately show ?thesis + by simp + qed + thus ?thesis + using div_eq by (simp add: t_w_def r_def) +qed + +lemma c_mlk_montgomery_reduce_spec [crush_specs]: + shows \\; c_mlk_montgomery_reduce while_fuel a \\<^sub>F + c_mlk_montgomery_reduce_contract a\ + apply (crush_boot f: c_mlk_montgomery_reduce_def + contract: c_mlk_montgomery_reduce_contract_def) + apply crush_base + apply (all \(insert montgomery_t_mul_Q_bounds montgomery_sub_bounds[OF _]; simp; fail)?\) + apply (simp add: montgomery_result_sint) + done + +subsection \@{text mlk_fqmul} contract\ + +text \Field multiplication: @{term \fqmul a b = montgomery_reduce (a * b)\}. + Since both inputs are 16-bit signed, the product fits in 32-bit signed + and satisfies the Montgomery precondition unconditionally.\ +lemma fqmul_product_bound: + fixes a b :: \16 signed word\ + shows \\sint a * sint b\ < 2 ^ 31 - 2 ^ 15 * MLKEM_Q\ +proof - + have \sint a \ -32768\ \sint a \ 32767\ + using sint_range_size[where w=a] by (auto simp: word_size) + hence a_abs: \\sint a\ \ 32768\ + by auto + have \sint b \ -32768\ \sint b \ 32767\ + using sint_range_size[where w=b] by (auto simp: word_size) + hence b_abs: \\sint b\ \ 32768\ + by auto + have \\sint a * sint b\ = \sint a\ * \sint b\\ + by (rule abs_mult) + also have \\ \ 32768 * 32768\ + using a_abs b_abs by (intro mult_mono) auto + also have \(32768 * 32768 :: int) < 2 ^ 31 - 2 ^ 15 * MLKEM_Q\ + by simp + finally show ?thesis . +qed + +lemma fqmul_sint_product: + fixes a b :: \16 signed word\ + shows \sint (SCAST(16 signed \ 32 signed) a * SCAST(16 signed \ 32 signed) b) = + sint a * sint b\ +using fqmul_product_bound[of a b] by (simp add: sint_word_mult sint_up_scast is_up + sbintrunc_eq_in_range range_sbintrunc word_size, linarith) + +definition c_mlk_fqmul_contract :: + \c_short \ c_short \ ('s::{sepalg}, c_short, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_fqmul_contract a b \ + let pre = can_alloc_reference; + post = \r. can_alloc_reference \ + \sint r = montgomery_reduce_int (sint a * sint b)\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_fqmul_contract(*>*) + +lemma c_mlk_fqmul_spec [crush_specs]: + shows \\; c_mlk_fqmul while_fuel while_fuel a b \\<^sub>F + c_mlk_fqmul_contract a b\ + apply (crush_boot f: c_mlk_fqmul_def contract: c_mlk_fqmul_contract_def) + apply crush_base + apply (insert fqmul_product_bound[of a b]) + apply (simp_all add: fqmul_sint_product sint_up_scast is_up) + done + +subsection \@{text mlk_scalar_signed_to_unsigned_q} contract\ + +text \Converts a signed value in \(-MLKEM_Q, MLKEM_Q)\ to its + canonical unsigned representative in \[0, MLKEM_Q)\. + Adds \MLKEM_Q\ if the input is negative.\ +lemma scalar_unsigned_q_cast_sint: + assumes \-MLKEM_Q < sint c\ \sint c < 0\ + shows \sint (SCAST(32 signed \ 16 signed) + (SCAST(16 signed \ 32 signed) c + 0xD01)) = sint c mod MLKEM_Q\ +proof - + have scast_sint: \sint (SCAST(32 signed \ 16 signed) w) = signed_take_bit 15 (sint w)\ for w :: \32 sword\ + by (simp only: of_int_sint_scast[symmetric] Word.sint_sbintrunc') simp + have up: \sint (SCAST(16 signed \ 32 signed) c) = sint c\ + by (simp add: sint_up_scast is_up) + have add_sint: \sint (SCAST(16 signed \ 32 signed) c + 0xD01) = sint c + 3329\ + proof - + have \signed_take_bit 31 (sint c + 3329) = sint c + 3329\ + using sint_range_size[where w=c] by (intro signed_take_bit_int_eq_self) (auto simp: word_size) + thus ?thesis + by (simp add: sint_word_ariths word_size up) + qed + have stb15: \signed_take_bit 15 (sint c + 3329) = sint c + 3329\ + using assms by (intro signed_take_bit_int_eq_self) auto + have cast_eq: \sint (SCAST(32 signed \ 16 signed) (SCAST(16 signed \ 32 signed) c + 0xD01)) = sint c + 3329\ + by (simp add: scast_sint add_sint stb15) + have mod_eq: \sint c mod 3329 = sint c + 3329\ + using assms by (intro int_mod_pos_eq[where q="-1"]) auto + show ?thesis + by (simp add: cast_eq mod_eq) +qed + +definition c_mlk_scalar_signed_to_unsigned_q_contract :: + \c_short \ ('s::{sepalg}, c_short, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_scalar_signed_to_unsigned_q_contract c \ + let pre = can_alloc_reference \ \-MLKEM_Q < sint c \ sint c < MLKEM_Q\; + post = \r. can_alloc_reference \ \sint r = sint c mod MLKEM_Q\ + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_scalar_signed_to_unsigned_q_contract(*>*) + +lemma c_mlk_scalar_signed_to_unsigned_q_spec [crush_specs]: + shows \\; c_mlk_scalar_signed_to_unsigned_q while_fuel c \\<^sub>F + c_mlk_scalar_signed_to_unsigned_q_contract c\ + apply (crush_boot f: c_mlk_scalar_signed_to_unsigned_q_def + contract: c_mlk_scalar_signed_to_unsigned_q_contract_def) + apply crush_base + apply (simp_all add: scalar_unsigned_q_cast_sint sint_up_scast is_up) + done + +(*<*) +end +(*>*) + +(*<*) +end +(*>*) From fdd11d7677fbb7ea7003d712654abf11903ccef0 Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:00:56 +0000 Subject: [PATCH 08/11] proofs/isabelle: add C functional correctness proofs for polynomial loops Add MLKEM_FC_PolyLoop.thy with contracts for poly_tomont, poly_reduce, and mulcache_compute (both _c inner loops and wrappers). --- proofs/isabelle/MLKEM_FC_PolyLoop.thy | 312 ++++++++++++++++++++++++++ 1 file changed, 312 insertions(+) create mode 100644 proofs/isabelle/MLKEM_FC_PolyLoop.thy diff --git a/proofs/isabelle/MLKEM_FC_PolyLoop.thy b/proofs/isabelle/MLKEM_FC_PolyLoop.thy new file mode 100644 index 0000000000..40d56a78ba --- /dev/null +++ b/proofs/isabelle/MLKEM_FC_PolyLoop.thy @@ -0,0 +1,312 @@ +(*<*) +theory MLKEM_FC_PolyLoop + imports MLKEM_FC_Montgomery MLKEM_Zetas +begin +(*>*) + +text \ + Functional correctness proofs for polynomial loop operations: + Montgomery pre-scaling (@{verbatim \mlk_poly_tomont\}), Barrett + reduction (@{verbatim \mlk_poly_reduce\}), and mulcache computation + (@{verbatim \mlk_poly_mulcache_compute\}). +\ + +section \Polynomial Loop Operations\ + +(*<*) +context c_mlk_machine_model +begin + +declare c_mlk_poly_mulcache_compute_c_def [micro_rust_simps del] +declare c_mlk_poly_reduce_c_def [micro_rust_simps del] +declare c_global_mlk_zetas_def [micro_rust_simps del] +(*>*) + +subsection \@{text mlk_poly_tomont_c} contract\ + +text \Montgomery-domain conversion loop. Each coefficient is multiplied + by the constant 1353 via @{verbatim \fqmul\}, mapping the polynomial + into Montgomery representation.\ + +definition c_mlk_poly_tomont_c_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_tomont_c_contract r gr vr ar \ + let pre = can_alloc_reference \ + r \\\\ gr\vr \ + \refines_mlk_poly vr ar\; + post = \_. can_alloc_reference \ + (\gr' vr'. r \\\\ gr'\vr' \ + \refines_mlk_poly vr' (poly_tomont_int ar)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_tomont_c_contract(*>*) + +lemma c_mlk_poly_tomont_c_spec [crush_specs]: + shows \\; c_mlk_poly_tomont_c while_fuel while_fuel MLKEM_N r \\<^sub>F + c_mlk_poly_tomont_c_contract r gr vr ar\ + apply (crush_boot f: c_mlk_poly_tomont_c_def contract: c_mlk_poly_tomont_c_contract_def) + apply crush_base + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr' vr'. r \\\\ gr'\vr' \ + \length (c_mlk_poly_coeffs vr') = MLKEM_N\ \ + \\j < MLKEM_N - k. sint (c_mlk_poly_coeffs vr' ! j) = + montgomery_reduce_int (sint (c_mlk_poly_coeffs vr ! j) * 1353)\ \ + \\j. MLKEM_N - k \ j \ j < MLKEM_N \ + c_mlk_poly_coeffs vr' ! j = c_mlk_poly_coeffs vr ! j\) + \ (\gf. xa \\\\ gf\(0x549 :: c_short)) + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\\ + and INV'=\\k. (\gr' vr'. r \\\\ gr'\vr' \ + \length (c_mlk_poly_coeffs vr') = MLKEM_N\ \ + \\j < MLKEM_N - Suc k. sint (c_mlk_poly_coeffs vr' ! j) = + montgomery_reduce_int (sint (c_mlk_poly_coeffs vr ! j) * 1353)\ \ + \\j. MLKEM_N - Suc k \ j \ j < MLKEM_N \ + c_mlk_poly_coeffs vr' ! j = c_mlk_poly_coeffs vr ! j\) + \ (\gf. xa \\\\ gf\(0x549 :: c_short)) + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - Suc k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + apply crush_base + apply (crush_base simp add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat + c_mlk_poly.record_simps nth_append nth_list_update refines_mlk_poly_def + poly_tomont_int_def sint_word_of_montgomery_fqmul) + apply (auto intro!: nth_equalityI) + done + +subsection \@{text mlk_poly_tomont} contract\ + +text \Top-level wrapper: delegates to the inner loop via the + @{const refines_mlk_poly} abstraction boundary.\ + +definition c_mlk_poly_tomont_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_tomont_contract r gr vr ar \ + let pre = can_alloc_reference \ + r \\\\ gr\vr \ + \refines_mlk_poly vr ar\; + post = \_. can_alloc_reference \ + (\gr' vr'. r \\\\ gr'\vr' \ + \refines_mlk_poly vr' (poly_tomont_int ar)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_tomont_contract(*>*) + +declare c_mlk_poly_tomont_c_def [micro_rust_simps del] + +lemma c_mlk_poly_tomont_spec [crush_specs]: + shows \\; c_mlk_poly_tomont MLKEM_N while_fuel while_fuel r \\<^sub>F + c_mlk_poly_tomont_contract r gr vr ar\ + apply (crush_boot f: c_mlk_poly_tomont_def contract: c_mlk_poly_tomont_contract_def) + apply (rule wp_callI[OF c_mlk_poly_tomont_c_spec[where gr=gr and vr=vr and ar=ar]]) + apply (simp add: c_mlk_poly_tomont_c_contract_def) + apply (crush_base simp add: c_mlk_poly_tomont_c_contract_def) + done + +subsection \@{text mlk_poly_reduce_c} contract\ + +text \Barrett reduction loop. Applies @{const barrett_reduce_int} to every + coefficient, bringing each into a centered representative modulo @{const MLKEM_Q}.\ + +definition c_mlk_poly_reduce_c_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_reduce_c_contract r gr vr ar \ + let pre = can_alloc_reference \ + r \\\\ gr\vr \ + \refines_mlk_poly vr ar\; + post = \_. can_alloc_reference \ + (\gr' vr'. r \\\\ gr'\vr' \ + \refines_mlk_poly vr' (poly_reduce_int ar)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_reduce_c_contract(*>*) + +lemma c_mlk_poly_reduce_c_spec [crush_specs]: + shows \\; c_mlk_poly_reduce_c while_fuel while_fuel MLKEM_N r \\<^sub>F + c_mlk_poly_reduce_c_contract r gr vr ar\ + apply (crush_boot f: c_mlk_poly_reduce_c_def contract: c_mlk_poly_reduce_c_contract_def) + apply crush_base + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr' vr'. r \\\\ gr'\vr' \ + \length (c_mlk_poly_coeffs vr') = MLKEM_N\ \ + \\j < MLKEM_N - k. sint (c_mlk_poly_coeffs vr' ! j) = + sint (c_mlk_poly_coeffs vr ! j) mod MLKEM_Q\ \ + \\j. MLKEM_N - k \ j \ j < MLKEM_N \ + c_mlk_poly_coeffs vr' ! j = c_mlk_poly_coeffs vr ! j\) + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\\ + and INV'=\\k. (\gr' vr'. r \\\\ gr'\vr' \ + \length (c_mlk_poly_coeffs vr') = MLKEM_N\ \ + \\j < MLKEM_N - Suc k. sint (c_mlk_poly_coeffs vr' ! j) = + sint (c_mlk_poly_coeffs vr ! j) mod MLKEM_Q\ \ + \\j. MLKEM_N - Suc k \ j \ j < MLKEM_N \ + c_mlk_poly_coeffs vr' ! j = c_mlk_poly_coeffs vr ! j\) + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - Suc k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vr ar\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + apply (crush_base simp add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat + c_mlk_poly.record_simps nth_append nth_list_update refines_mlk_poly_def + poly_reduce_int_def + c_mlk_barrett_reduce_contract_def c_mlk_scalar_signed_to_unsigned_q_contract_def + c_signed_truthy_zero bounded_while_literal_false) + apply (auto intro!: nth_equalityI intro: barrett_reduce_int_in_range simp add: barrett_reduce_mod) + done + +subsection \@{text mlk_poly_reduce} contract\ + +text \Top-level Barrett reduction wrapper, lifting the coefficient-level + loop through the @{const refines_mlk_poly} abstraction.\ + +definition c_mlk_poly_reduce_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_reduce_contract r gr vr ar \ + let pre = can_alloc_reference \ + r \\\\ gr\vr \ + \refines_mlk_poly vr ar\; + post = \_. can_alloc_reference \ + (\gr' vr'. r \\\\ gr'\vr' \ + \refines_mlk_poly vr' (poly_reduce_int ar)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_reduce_contract(*>*) + +lemma c_mlk_poly_reduce_spec [crush_specs]: + shows \\; c_mlk_poly_reduce MLKEM_N while_fuel while_fuel r \\<^sub>F + c_mlk_poly_reduce_contract r gr vr ar\ + apply (crush_boot f: c_mlk_poly_reduce_def contract: c_mlk_poly_reduce_contract_def) + apply (rule wp_callI[OF c_mlk_poly_reduce_c_spec[where gr=gr and vr=vr and ar=ar]]) + apply (simp add: c_mlk_poly_reduce_c_contract_def) + apply (crush_base simp add: c_mlk_poly_reduce_c_contract_def) + done + +subsection \@{text mlk_poly_mulcache_compute_c} contract\ + +text \Computes the multiplication cache: 128 products of odd-indexed + coefficients with corresponding zeta factors.\ +definition c_mlk_poly_mulcache_compute_c_contract :: + \('addr, 8 word list, c_mlk_poly_mulcache) Global_Store.ref \ 8 word list \ + c_mlk_poly_mulcache \ + ('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_mulcache_compute_c_contract x gx vx a ga va aa \ + let pre = can_alloc_reference \ + x \\\\ gx\vx \ + a \\\\ ga\va \ + \refines_mlk_poly va aa\ \ + \length (c_mlk_poly_mulcache_coeffs vx) = 128\; + post = \_. can_alloc_reference \ + (\gx' vx'. x \\\\ gx'\vx' \ + \refines_mlk_poly_mulcache vx' (mulcache_compute_int aa)\) \ + a \\\\ ga\va + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_mulcache_compute_c_contract(*>*) + +lemma c_mlk_poly_mulcache_compute_c_spec [crush_specs]: + shows \\; c_mlk_poly_mulcache_compute_c while_fuel while_fuel 64 x a \\<^sub>F + c_mlk_poly_mulcache_compute_c_contract x gx vx a ga va aa\ + apply (crush_boot f: c_mlk_poly_mulcache_compute_c_def + contract: c_mlk_poly_mulcache_compute_c_contract_def) + apply crush_base + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gx' vx'. x \\\\ gx'\vx' \ + \length (c_mlk_poly_mulcache_coeffs vx') = 128\ \ + \\j < 2 * (64 - k). sint (c_mlk_poly_mulcache_coeffs vx' ! j) = + mulcache_compute_int aa ! j\) + \ a \\\\ ga\va + \ (\gi. xa \\\\ gi\(of_nat (64 - k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly va aa\\ + and INV'=\\k. (\gx' vx'. x \\\\ gx'\vx' \ + \length (c_mlk_poly_mulcache_coeffs vx') = 128\ \ + \\j < 2 * (64 - Suc k). sint (c_mlk_poly_mulcache_coeffs vx' ! j) = + mulcache_compute_int aa ! j\) + \ a \\\\ ga\va + \ (\gi. xa \\\\ gi\(of_nat (64 - Suc k) :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly va aa\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + by (crush_base simp add: refines_mlk_poly_mulcache_def + c_mlk_poly_mulcache.record_simps length_mulcache_compute_int) + (auto intro!: nth_equalityI) + subgoal \ \Condition check\ + by (crush_base simp add: word_less_nat_alt unat_of_nat + c_trunc_div_int_def + c_signed_truthy_zero bounded_while_literal_false) auto + subgoal \ \Loop body\ + apply (crush_base simp add: word_less_nat_alt unat_of_nat + c_mlk_poly_mulcache.record_simps c_mlk_poly.record_simps + nth_list_update refines_mlk_poly_def + c_mlk_fqmul_contract_def fqmul_int_def + mulcache_compute_int_nth_even mulcache_compute_int_nth_odd + c_trunc_div_int_def + c_signed_truthy_zero bounded_while_literal_false + length_mulcache_compute_int) + using [[linarith_split_limit = 20]] + apply (simp_all add: word_less_nat_alt unat_of_nat word_le_nat_alt + sint_up_scast is_up c_global_mlk_zetas_def + mulcache_compute_int_nth_even' mulcache_compute_int_nth_odd' + fqmul_int_def nth_map + drop_64_zetas_sword[symmetric] nth_drop length_zetas_sword + zetas_sword_sint zetas_neg_scast_sint + zetas_int_i32_bound_from_k) + done + subgoal \ \Fuel exhaust\ + by (crush_base simp add: c_trunc_div_int_def) + done + +subsection \@{text mlk_poly_mulcache_compute} contract\ + +text \Top-level wrapper for the multiplication-cache computation, + lifting the inner loop through the poly abstraction.\ + +definition c_mlk_poly_mulcache_compute_contract :: + \('addr, 8 word list, c_mlk_poly_mulcache) Global_Store.ref \ 8 word list \ + c_mlk_poly_mulcache \ + ('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_mulcache_compute_contract x gx vx a ga va aa \ + let pre = can_alloc_reference \ + x \\\\ gx\vx \ + a \\\\ ga\va \ + \refines_mlk_poly va aa\ \ + \length (c_mlk_poly_mulcache_coeffs vx) = 128\; + post = \_. can_alloc_reference \ + (\gx' vx'. x \\\\ gx'\vx' \ + \refines_mlk_poly_mulcache vx' (mulcache_compute_int aa)\) \ + a \\\\ ga\va + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_mulcache_compute_contract(*>*) + +lemma c_mlk_poly_mulcache_compute_spec [crush_specs]: + shows \\; c_mlk_poly_mulcache_compute 64 while_fuel while_fuel x a \\<^sub>F + c_mlk_poly_mulcache_compute_contract x gx vx a ga va aa\ + apply (crush_boot f: c_mlk_poly_mulcache_compute_def + contract: c_mlk_poly_mulcache_compute_contract_def) + apply (rule wp_callI[OF c_mlk_poly_mulcache_compute_c_spec + [where gx=gx and vx=vx and ga=ga and va=va and aa=aa]]) + apply (simp add: c_mlk_poly_mulcache_compute_c_contract_def) + apply (crush_base simp add: c_mlk_poly_mulcache_compute_c_contract_def) + done + +(*<*) +end +(*>*) + +(*<*) +end +(*>*) From 58ad2280f0618ec8d39cb9510f9a80b0aec86c5e Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:01:03 +0000 Subject: [PATCH 09/11] proofs/isabelle: add C functional correctness proof of forward NTT Add MLKEM_FC_NTT.thy with contracts for ntt_butterfly_block, ntt_layer, poly_ntt_c, and poly_ntt. --- proofs/isabelle/MLKEM_FC_NTT.thy | 752 +++++++++++++++++++++++++++++++ 1 file changed, 752 insertions(+) create mode 100644 proofs/isabelle/MLKEM_FC_NTT.thy diff --git a/proofs/isabelle/MLKEM_FC_NTT.thy b/proofs/isabelle/MLKEM_FC_NTT.thy new file mode 100644 index 0000000000..5f1ce0bfad --- /dev/null +++ b/proofs/isabelle/MLKEM_FC_NTT.thy @@ -0,0 +1,752 @@ +(*<*) +theory MLKEM_FC_NTT + imports MLKEM_FC_Montgomery MLKEM_NTT_Spec +begin +(*>*) + +text \ + Functional correctness proof of the forward NTT implementation. + Verifies the butterfly block, single-layer, and outer-loop functions + against the abstract NTT specification from @{text MLKEM_NTT_Spec}. +\ + +section \Forward NTT Verification\ + +(*<*) +context c_mlk_machine_model +begin + +declare c_mlk_ntt_butterfly_block_def [micro_rust_simps del] +declare c_mlk_ntt_layer_def [micro_rust_simps del] +declare c_mlk_poly_ntt_c_def [micro_rust_simps del] +declare c_global_mlk_zetas_def [micro_rust_simps del] +(*>*) + +subsection \@{text mlk_ntt_butterfly_block} contract\ + +text \Inner loop: applies butterflies to one block of coefficients. + Works at the coefficient-list level via @{const refines_coeffs}.\ +definition c_mlk_ntt_butterfly_block_contract :: + \('addr, 8 word list, c_short list) Global_Store.ref \ 8 word list \ + c_short list \ int list \ + c_short \ c_uint \ c_uint \ c_uint \ nat \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_ntt_butterfly_block_contract r gr cs acs zeta start_v len_v bound_v n \ + let off = unat start_v; + blen = unat len_v; + pre = can_alloc_reference \ + r \\\\ gr\cs \ + \refines_coeffs cs acs\ \ + \off + 2 * blen \ MLKEM_N\ \ + \blen > 0\ \ + \ntt_inner_no_overflow (sint zeta) off blen blen acs\ \ + \blen \ n\; + post = \_. can_alloc_reference \ + (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (ntt_inner_loop_int (sint zeta) off blen blen acs)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_ntt_butterfly_block_contract(*>*) + +lemma unat_add_no_overflow_from_bound: + assumes \unat (a :: 32 word) + 2 * unat (b :: 32 word) \ 256\ + shows \unat (a + b) = unat a + unat b\ +proof - + from assms have \unat a + unat b \ 256\ + by linarith + hence \unat a + unat b < 4294967296\ + by linarith + thus ?thesis + by (simp add: unat_word_ariths) +qed + +lemma unat_butterfly_idx: + fixes start_v len_v :: \32 word\ + assumes \unat start_v + 2 * unat len_v \ MLKEM_N\ + and \n - Suc k < unat len_v\ + and \Suc k \ n\ + and \unat len_v \ n\ + shows \unat (start_v + (word_of_nat n - (1 + word_of_nat k)) + len_v) = + unat start_v + (n - Suc k) + unat len_v\ + and \unat (start_v + (word_of_nat n - (1 + word_of_nat k))) = + unat start_v + (n - Suc k)\ + and \unat start_v + (n - Suc k) + unat len_v < MLKEM_N\ + and \unat start_v + (n - Suc k) < MLKEM_N\ +proof - + from assms have diff_bound: \n - Suc k < 256\ + by linarith + have sum_bound: \unat start_v + (n - Suc k) + unat len_v < 256\ + using assms diff_bound by linarith + have sum_bound2: \unat start_v + (n - Suc k) < 256\ + using sum_bound by linarith + show bound1: \unat start_v + (n - Suc k) + unat len_v < MLKEM_N\ + using sum_bound by simp + show bound2: \unat start_v + (n - Suc k) < MLKEM_N\ + using sum_bound2 by simp + have word_sub: \(word_of_nat n :: 32 word) - (1 + word_of_nat k) = word_of_nat (n - Suc k)\ + using assms(3) by (simp add: word_of_nat_eq_iff of_nat_diff) + hence sub_unat: \unat ((word_of_nat n :: 32 word) - (1 + word_of_nat k)) = (n - Suc k) mod 4294967296\ + by (simp add: unat_of_nat) + have sub_eq: \unat ((word_of_nat n :: 32 word) - (1 + word_of_nat k)) = n - Suc k\ + using sub_unat diff_bound by simp + show mid: \unat (start_v + (word_of_nat n - (1 + word_of_nat k))) = unat start_v + (n - Suc k)\ + using sub_eq sum_bound2 by (simp add: unat_word_ariths) + show \unat (start_v + (word_of_nat n - (1 + word_of_nat k)) + len_v) + = unat start_v + (n - Suc k) + unat len_v\ + using mid sum_bound by (simp add: unat_word_ariths) +qed + +lemma ntt_butterfly_block_fc_step: + fixes ra :: \c_short list\ + and rb :: \c_short\ + and cs :: \c_short list\ + and z :: \int\ + and off blen n k :: \nat\ + assumes ra_eq: \list.map sint ra = ntt_inner_loop_int z off blen (n - Suc k) (list.map sint cs)\ + and rb_eq: \sint rb = montgomery_reduce_int (sint (ra ! (off + n + blen - Suc k)) * z)\ + and overflow: \ntt_inner_no_overflow z off blen blen (list.map sint cs)\ + and m_bound: \n - Suc k < blen\ + and k_bound: \Suc k \ n\ + and sum_bound: \off + 2 * blen \ MLKEM_N\ + and len_ra: \length ra = MLKEM_N\ + shows \list.map sint (ra[off + n + blen - Suc k := ra ! (off + n - Suc k) - rb, + off + n - Suc k := ra ! (off + n - Suc k) + rb]) = + ntt_inner_loop_int z off blen (n - k) (list.map sint cs)\ +proof - + define m where + \m = n - Suc k\ + with k_bound have m_suc: \n - k = Suc m\ + by simp + from m_def k_bound have idx_hi: \off + n + blen - Suc k = off + m + blen\ + by simp + from m_def k_bound have idx_lo: \off + n - Suc k = off + m\ + by simp + from m_bound m_def sum_bound len_ra have hi_bound: \off + m + blen < length ra\ and lo_bound: \off + m < length ra\ + by auto + \ \Abstract state after m iterations\ + define L where \L = ntt_inner_loop_int z off blen m (list.map sint cs)\ + have L_eq: \L = list.map sint ra\ + using ra_eq m_def L_def by simp + have L_lo: \L ! (off + m) = sint (ra ! (off + m))\ + using L_eq lo_bound by (simp add: nth_map) + have L_hi: \L ! (off + m + blen) = sint (ra ! (off + m + blen))\ + using L_eq hi_bound by (simp add: nth_map) + \ \Montgomery result\ + define t where \t = fqmul_int (L ! (off + m + blen)) z\ + have t_eq: \t = sint rb\ + unfolding t_def fqmul_int_def L_hi using rb_eq idx_hi by simp + \ \Extract overflow bounds from predicate\ + have ob: \- 32768 \ L ! (off + m) + t \ L ! (off + m) + t \ 32767 \ + - 32768 \ L ! (off + m) - t \ L ! (off + m) - t \ 32767\ + using overflow m_bound m_def unfolding ntt_inner_no_overflow_def Let_def L_def t_def by auto + \ \sint distributes over + and -\ + have sint_add: \sint (ra ! (off + m) + rb) = sint (ra ! (off + m)) + sint rb\ + by (rule sint_plus_no_overflow) (use ob L_lo t_eq in auto) + have sint_sub: \sint (ra ! (off + m) - rb) = sint (ra ! (off + m)) - sint rb\ + by (rule sint_minus_no_overflow) (use ob L_lo t_eq in auto) + \ \RHS: expand via snoc + butterfly\ + have rhs_eq: \ntt_inner_loop_int z off blen (Suc m) (list.map sint cs) = + (L[off + m + blen := L ! (off + m) - t])[off + m := L ! (off + m) + t]\ + unfolding L_def ntt_inner_loop_int_snoc ntt_butterfly_int_def Let_def t_def by simp + \ \LHS: push map sint through updates, substitute\ + have lhs_eq: \list.map sint (ra[off + m + blen := ra ! (off + m) - rb, + off + m := ra ! (off + m) + rb]) = + (L[off + m + blen := L ! (off + m) - t])[off + m := L ! (off + m) + t]\ + using L_eq sint_add sint_sub L_lo t_eq lo_bound hi_bound by (simp add: List.map_update) + show ?thesis + unfolding idx_hi idx_lo m_suc using lhs_eq rhs_eq by simp +qed + +lemma c_mlk_ntt_butterfly_block_spec [crush_specs]: + shows \\; c_mlk_ntt_butterfly_block while_fuel while_fuel n + r zeta start_v len_v bound_v \\<^sub>F + c_mlk_ntt_butterfly_block_contract r gr cs acs zeta start_v len_v bound_v n\ + apply (crush_boot f: c_mlk_ntt_butterfly_block_def + contract: c_mlk_ntt_butterfly_block_contract_def) + apply crush_base + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (ntt_inner_loop_int (sint zeta) (unat start_v) (unat len_v) + (min (n - k) (unat len_v)) acs)\) + \ (\gj. x \\\\ gj\(of_nat (unat start_v + min (n - k) (unat len_v)) :: c_uint)) + \ can_alloc_reference + \ \ntt_inner_no_overflow (sint zeta) (unat start_v) (unat len_v) (unat len_v) acs\ + \ \unat start_v + 2 * unat len_v \ MLKEM_N\ + \ \k \ n\\ + and INV'=\\k. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (ntt_inner_loop_int (sint zeta) (unat start_v) (unat len_v) + (n - Suc k) acs)\) + \ (\gj. x \\\\ gj\(of_nat (unat start_v + (n - Suc k)) :: c_uint)) + \ can_alloc_reference + \ \ntt_inner_no_overflow (sint zeta) (unat start_v) (unat len_v) (unat len_v) acs\ + \ \unat start_v + 2 * unat len_v \ MLKEM_N\ + \ \Suc k \ n\ + \ \n - Suc k < unat len_v\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + by (crush_base simp add: min_absorb2) + subgoal \ \Condition check\ + by crush_base (auto simp add: word_less_nat_alt unat_of_nat + unat_add_no_overflow_from_bound min_le_iff_disj min_less_iff_conj) + subgoal \ \Loop body\ + apply (crush_base simp add: word_less_nat_alt unat_of_nat + unat_butterfly_idx ntt_inner_no_overflow_def + refines_coeffs_def ntt_inner_loop_int_snoc + ntt_butterfly_int_def fqmul_int_def nth_list_update + ntt_inner_loop_int_length) + apply (auto simp: ntt_inner_no_overflow_def Let_def fqmul_int_def + intro: ntt_butterfly_block_fc_step) + done + subgoal \ \Fuel exhaust\ + by (crush_base simp add: min_absorb2) + done + +subsection \@{text mlk_ntt_layer} contract\ + +text \Middle loop: applies one NTT layer (all butterfly blocks for a given len).\ + +definition c_mlk_ntt_layer_contract :: + \('addr, 8 word list, c_short list) Global_Store.ref \ 8 word list \ + c_short list \ int list \ + c_uint \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_ntt_layer_contract r gr cs acs layer_val \ + let l = unat layer_val; + pre = can_alloc_reference \ + r \\\\ gr\cs \ + \refines_coeffs cs acs\ \ + \1 \ l\ \ + \l \ 7\ \ + \ntt_layer_no_overflow l acs\; + post = \_. can_alloc_reference \ + (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' (ntt_layer_int l acs)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_ntt_layer_contract(*>*) + +lemma c_global_mlk_zetas_sint: + assumes \i < 128\ + shows \sint (c_global_mlk_zetas ! i) = zetas_int ! i\ +using assms by (simp add: c_global_mlk_zetas_def zetas_sword_unfold[symmetric] zetas_sword_sint) + +(*<*) +lemma ntt_layer_total_size: + assumes \1 \ l\ + and \l \ 7\ + shows \2 ^ (l - 1) * (2 * 2 ^ (8 - l)) = MLKEM_N\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis + by auto +qed + +lemma ntt_layer_k_bound: + assumes \1 \ l\ + and \l \ 7\ + and \j < 2 ^ (l - 1)\ + shows \2 ^ (l - 1) + j < (128 :: nat)\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed + +lemma ntt_layer_nb_le_N: + assumes \Suc 0 \ l\ + and \l \ 7\ + shows \2 ^ (l - Suc 0) \ MLKEM_N\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis + by auto +qed + +lemma drop_bit_0x100_eq: + assumes \Suc 0 \ l\ + and \l \ 7\ + shows \drop_bit l (0x100 :: 32 word) = 2 ^ (8 - l)\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis + by auto +qed + +lemma ntt_layer_total_size_word: + assumes \Suc 0 \ l\ \l \ 7\ + shows \(2 :: 32 word) ^ (l - Suc 0) * (2 * 2 ^ (8 - l)) = 0x100\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis + by auto +qed + +lemma word_of_nat_mult_numeral_right: + shows \word_of_nat k * (numeral n :: 'a::len word) = word_of_nat (k * numeral n)\ +by (metis of_nat_mult of_nat_numeral) + +lemma ntt_layer_start_cond: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \k \ 255\ + shows \((word_of_nat (min (255 - k) (2 ^ (l - Suc 0))) :: 32 word) * (2 * 2 ^ (8 - l)) < 0x100) = + (255 - k < 2 ^ (l - Suc 0))\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis using assms + by (auto simp del: of_nat_mult + simp add: word_of_nat_mult_numeral word_of_nat_mult_numeral_right + word_less_nat_alt unat_of_nat unat_sub word_le_nat_alt + min_def ntt_layer_total_size_word) +qed + +lemma ntt_layer_cond_false_min: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \\ (255 - k < 2 ^ (l - Suc 0))\ + shows \min (255 - (k::nat)) (2 ^ (l - Suc 0)) = min MLKEM_N (2 ^ (l - Suc 0))\ +using assms ntt_layer_nb_le_N[of l] by (simp add: min_def) + +lemma word_of_nat_255_sub: + assumes \k \ 255\ + shows \(word_of_nat (255 - k) :: 32 word) = 0xFF - word_of_nat k\ +using assms by (simp add: of_nat_diff) + +lemma unat_minus_1_bound: + assumes \Suc 0 \ unat v\ + and \unat (v :: 32 word) \ 7\ + shows \unat (v - 1) < 32\ +using assms by (simp add: unat_sub word_le_nat_alt) + +lemma min_from_Suc_less: + assumes \n - Suc k < m\ + and \Suc k \ n\ + shows \min (n - k) m = n - (k :: nat)\ +proof - + from assms have \n - k \ m\ + by linarith + thus ?thesis + by (simp add: min_absorb1) +qed + +lemma ntt_layer_blen_le_N: + assumes \Suc 0 \ l\ + and \l \ 7\ + shows \2 ^ (8 - l) \ MLKEM_N\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis + by auto +qed + +lemma ntt_layer_unat_k_index: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \unat ((2::32 word) ^ (l - Suc 0) + (0xFF - word_of_nat k)) = 2 ^ (l - Suc 0) + (255 - k)\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by (auto simp: unat_of_nat of_nat_diff unat_sub word_le_nat_alt) +qed + +lemma ntt_layer_unat_offset: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \k \ 255\ + shows \unat ((0xFF - word_of_nat k) * (2 * (2::32 word) ^ (8 - l))) = (255 - k) * (2 * 2 ^ (8 - l))\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ + by auto + have eq: \(0xFF :: 32 word) - word_of_nat k = word_of_nat (255 - k)\ + using assms by (simp add: of_nat_diff) + have w2: \(2 :: 32 word) * 2 ^ (8 - l) = word_of_nat (2 * 2 ^ (8 - l))\ + using l_cases by (elim insertE emptyE; simp) + have bound: \(255 - k) * (2 * 2 ^ (8 - l)) < 2 ^ 32\ + using l_cases assms by (elim insertE emptyE; simp) + show ?thesis + unfolding eq w2 of_nat_mult[symmetric] unat_of_nat using bound by (simp add: mod_less) +qed + +lemma ntt_middle_loop_int_step: + shows \ntt_inner_loop_int (zetas_int ! (k + j)) (j * (2 * blen)) blen blen + (snd (ntt_middle_loop_int k blen j j cs)) = + snd (ntt_middle_loop_int (Suc k) blen j (Suc j) + (ntt_inner_loop_int (zetas_int ! k) 0 blen blen cs))\ +by (simp add: ntt_middle_loop_int_snoc[symmetric] ntt_middle_loop_int.simps(2) Let_def) + +lemma ntt_layer_fuel_not_exhausted: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \2 ^ (l - Suc 0) + 255 - k < (128 :: nat)\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed + +lemma ntt_layer_offset_bound: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \(255 - k) * 2 * 2 ^ (8 - l) + 2 * 2 ^ (8 - l) \ MLKEM_N\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed +(*>*) + +lemma c_mlk_ntt_layer_spec [crush_specs]: + shows \\; c_mlk_ntt_layer MLKEM_N while_fuel while_fuel MLKEM_N + r layer_val \\<^sub>F + c_mlk_ntt_layer_contract r gr cs acs layer_val\ + apply (crush_boot f: c_mlk_ntt_layer_def + contract: c_mlk_ntt_layer_contract_def) + apply crush_base + subgoal for x xa xb \ \Main while loop\ + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (snd (ntt_middle_loop_int (2^(unat layer_val - 1)) (2^(8 - unat layer_val)) + (min (MLKEM_N - k) (2^(unat layer_val - 1))) + (min (MLKEM_N - k) (2^(unat layer_val - 1))) acs))\) + \ (\gs. x \\\\ gs\(of_nat (min (MLKEM_N - k) (2^(unat layer_val - 1)) * (2 * 2^(8 - unat layer_val))) :: c_uint)) + \ (\gk. xa \\\\ gk\(of_nat (2^(unat layer_val - 1) + min (MLKEM_N - k) (2^(unat layer_val - 1))) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2^(8 - unat layer_val)) :: c_uint)) + \ can_alloc_reference + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \ntt_layer_no_overflow (unat layer_val) acs\ + \ \k \ MLKEM_N\\ + and INV'=\\k. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (snd (ntt_middle_loop_int (2^(unat layer_val - 1)) (2^(8 - unat layer_val)) + (MLKEM_N - Suc k) (MLKEM_N - Suc k) acs))\) + \ (\gs. x \\\\ gs\(of_nat ((MLKEM_N - Suc k) * (2 * 2^(8 - unat layer_val))) :: c_uint)) + \ (\gk. xa \\\\ gk\(of_nat (2^(unat layer_val - 1) + (MLKEM_N - Suc k)) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2^(8 - unat layer_val)) :: c_uint)) + \ can_alloc_reference + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \ntt_layer_no_overflow (unat layer_val) acs\ + \ \Suc k \ MLKEM_N\ + \ \MLKEM_N - Suc k < 2^(unat layer_val - 1)\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + by (crush_base simp add: min_absorb2 ntt_layer_int_def ntt_layer_total_size ntt_layer_nb_le_N + unat_sub word_le_nat_alt drop_bit_0x100_eq) + subgoal \ \Condition check\ + by crush_base (simp_all add: ntt_layer_start_cond ntt_layer_cond_false_min + word_of_nat_255_sub) + subgoal \ \Loop body\ + \ \Phase 1: VCG without zetas literal\ + apply (crush_base simp add: bind2_unseq_def refines_coeffs_def min_absorb2 + word_less_nat_alt unat_of_nat word_le_nat_alt unat_sub + ntt_layer_k_bound ntt_layer_total_size ntt_layer_total_size_word + ntt_layer_nb_le_N + ntt_middle_loop_int_length ntt_inner_loop_int_length) + \ \Phase 2: finish VCG, folding zetas literal to zetas_sword\ + apply (crush_base simp add: + c_global_mlk_zetas_def zetas_sword_unfold[symmetric] + zetas_sword_sint + bind2_unseq_def refines_coeffs_def min_absorb2 + word_less_nat_alt unat_of_nat word_le_nat_alt unat_sub + ntt_layer_k_bound ntt_layer_total_size ntt_layer_total_size_word + ntt_layer_no_overflow_block ntt_layer_nb_le_N + ntt_middle_loop_int_snoc[symmetric] + ntt_middle_loop_int_length ntt_inner_loop_int_length) + \ \Phase 3: close remaining pure goals\ + apply (simp_all only: ntt_layer_blen_le_N ntt_layer_k_bound ntt_layer_nb_le_N + ntt_layer_total_size ntt_layer_unat_k_index ntt_layer_unat_offset + ntt_middle_loop_int_step[symmetric]) + apply (simp_all only: One_nat_def[symmetric] mult.assoc[symmetric] + ntt_layer_no_overflow_block) + \ \6 goals: offset, start_arith, fuel (\2)\ + apply (auto simp add: ntt_layer_offset_bound ntt_layer_fuel_not_exhausted ring_distribs) + done + subgoal \ \Fuel exhaust\ + by (crush_base simp add: min_absorb2 ntt_layer_nb_le_N ntt_layer_total_size_word) + done + subgoal \ \Shift amount bound\ + by (simp add: unat_sub word_le_nat_alt) + done + +subsection \@{text mlk_poly_ntt_c} contract\ + +text \Forward NTT on a polynomial. Applies 7 NTT layers in-place. + Requires input coefficients bounded by @{term MLKEM_Q} for overflow safety.\ +definition c_mlk_poly_ntt_c_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_ntt_c_contract p gp vp ap \ + let pre = can_alloc_reference \ + p \\\\ gp\vp \ + \refines_mlk_poly vp ap\ \ + \coeff_bound MLKEM_Q ap\; + post = \_. can_alloc_reference \ + (\gp' vp'. p \\\\ gp'\vp' \ + \refines_mlk_poly vp' (poly_ntt_int ap)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_ntt_c_contract(*>*) + +text \Lifted ntt\_layer contract: operates on the poly reference directly, + avoiding the focused reference and its schematic resolution issues.\ + +definition c_mlk_ntt_layer_poly_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int list \ c_uint \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_ntt_layer_poly_contract p gp vp acs layer_val \ + let l = unat layer_val; + pre = can_alloc_reference \ + p \\\\ gp\vp \ + \refines_coeffs (c_mlk_poly_coeffs vp) acs\ \ + \1 \ l\ \ + \l \ 7\ \ + \ntt_layer_no_overflow l acs\; + post = \_. can_alloc_reference \ + (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \refines_coeffs cs' (ntt_layer_int l acs)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_ntt_layer_poly_contract(*>*) + +text \Lens validity and focused view/modify lemmas for @{type c_mlk_poly}.\ + +lemma is_valid_lens_view_modify_c_mlk_poly: + shows \is_valid_lens_view_modify c_mlk_poly_coeffs update_c_mlk_poly_coeffs\ + unfolding is_valid_lens_view_modify_def +proof (intro conjI allI impI) + fix f s + show \c_mlk_poly_coeffs (update_c_mlk_poly_coeffs f s) = f (c_mlk_poly_coeffs s)\ + by (cases s) (simp add: Datatype_Records.datatype_record_update(41) c_mlk_poly.simps) +next + fix f s + assume \f (c_mlk_poly_coeffs s) = c_mlk_poly_coeffs s\ + then show \update_c_mlk_poly_coeffs f s = s\ + by (cases s) (simp add: Datatype_Records.datatype_record_update(41) c_mlk_poly.simps) +next + fix f g s + show \update_c_mlk_poly_coeffs f (update_c_mlk_poly_coeffs g s) = + update_c_mlk_poly_coeffs (\x. f (g x)) s\ + by (cases s) (simp add: Datatype_Records.datatype_record_update(41) c_mlk_poly.simps) +qed + +lemma focus_view_make_c_mlk_poly_eq: + shows \\{Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs + update_c_mlk_poly_coeffs))} (make_c_mlk_poly cs) \ cs\ + apply (subst lens_to_focus_raw_components'(1)[OF is_valid_lens_via_modifyI'[OF is_valid_lens_view_modify_c_mlk_poly]]) + apply (simp add: make_lens_via_view_modify_components c_mlk_poly.sel) + done + +lemma focus_view_make_c_mlk_poly [simp]: + shows \(\{Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs + update_c_mlk_poly_coeffs))} v0 \ v1) \ (v0 = make_c_mlk_poly v1)\ +using focus_view_make_c_mlk_poly_eq[of \c_mlk_poly_coeffs v0\] c_mlk_poly.collapse[of v0] + by (auto simp add: c_mlk_poly.sel) + +lemma focus_modify_make_c_mlk_poly [simp]: + shows \\{Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs update_c_mlk_poly_coeffs))} + f (make_c_mlk_poly cs) = make_c_mlk_poly (f cs)\ + apply (subst lens_to_focus_raw_components'(3)[OF is_valid_lens_via_modifyI'[OF is_valid_lens_view_modify_c_mlk_poly]]) + apply (subst make_lens_via_view_modify_components(3)[OF is_valid_lens_view_modify_c_mlk_poly]) + apply (simp add: Datatype_Records.datatype_record_update(41) c_mlk_poly.simps) + done + +lemma focusedL_c_mlk_poly: + shows \aentails_conditional_crule_strong + (focus_reference (Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs + update_c_mlk_poly_coeffs))) r \\sh\ g1\v1) + (\g0 = g1\ \ \points_to_localizes r g0 (make_c_mlk_poly v1)\) + (\points_to_localizes (focus_reference (Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs + update_c_mlk_poly_coeffs))) r) g1 v1\) + (r \\sh\ g0\(make_c_mlk_poly v1))\ +unfolding aentails_conditional_crule_strong_def by (crush_base simp add: points_to_def + focus_view_make_c_mlk_poly_eq) + +lemma focusedL_c_mlk_poly_aentails: + shows \focus_reference (Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs + update_c_mlk_poly_coeffs))) r \\sh\ g\v \ r \\sh\ g\(make_c_mlk_poly v)\ +unfolding points_to_def + apply (simp only: untype_ref_focus) + apply (rule asepconj_mono) + apply (simp add: apure_def aentails_def is_valid_ref_for_def focus_reference_def + focus_focused_get_focus focus_compose_components bind_eq_Some_conv focus_view_make_c_mlk_poly) + apply (clarsimp simp add: focus_dom.rep_eq focus_raw_dom_def focus_compose.rep_eq + focus_raw_compose_def make_focus_raw_via_view_modify_def dom_def bind_eq_Some_conv + focus_view.rep_eq[symmetric] focus_view_make_c_mlk_poly c_mlk_poly.collapse) + apply (fastforce dest: subsetD intro: exI[where x="c_mlk_poly_coeffs _"]) + done + +lemma focusedR_c_mlk_poly_aentails: + shows \r \\sh\ g\v \ focus_reference (Abs_focus (\'\<^sub>l (make_lens_via_view_modify c_mlk_poly_coeffs + update_c_mlk_poly_coeffs))) r \\sh\ g\(c_mlk_poly_coeffs v)\ +unfolding points_to_def + apply (simp only: untype_ref_focus) + apply (rule asepconj_mono) + apply (simp add: apure_def aentails_def is_valid_ref_for_def focus_reference_def + focus_focused_get_focus focus_compose_components bind_eq_Some_conv focus_view_make_c_mlk_poly + c_mlk_poly.sel c_mlk_poly.collapse) + apply (clarsimp simp add: focus_dom.rep_eq focus_raw_dom_def focus_compose.rep_eq + focus_raw_compose_def make_focus_raw_via_view_modify_def dom_def bind_eq_Some_conv + focus_view.rep_eq[symmetric] focus_view_make_c_mlk_poly c_mlk_poly.collapse) + apply fastforce + done + +lemma c_mlk_ntt_layer_poly_spec: + shows \\; c_mlk_ntt_layer MLKEM_N while_fuel while_fuel MLKEM_N + (focus_reference (Abs_focus (\'\<^sub>l (make_lens_via_view_modify + c_mlk_poly_coeffs update_c_mlk_poly_coeffs))) p) layer_val \\<^sub>F + c_mlk_ntt_layer_poly_contract p gp vp acs layer_val\ + apply (rule satisfies_function_contract_weaken[OF c_mlk_ntt_layer_spec[where + gr=gp and cs=\c_mlk_poly_coeffs vp\ and acs=acs]]) + apply (simp_all add: c_mlk_ntt_layer_poly_contract_def c_mlk_ntt_layer_contract_def Let_def) + subgoal + by crush_base + subgoal + by crush_base + subgoal + by (crush_base intro: focusedR_c_mlk_poly_aentails) + subgoal + apply (rule asepconj_mono) + apply (rule aexists_entailsL)+ + apply (rule aentails_intro(7))+ + apply (rule asepconj_mono2[OF focusedL_c_mlk_poly_aentails]) + done + apply (rule aentails_refl) + done + +lemma c_mlk_poly_ntt_c_spec [crush_specs]: + notes focusedL_c_mlk_poly[crush_aentails_cond_crules, crush_points_to_cond_crules] + shows \\; c_mlk_poly_ntt_c MLKEM_N while_fuel while_fuel MLKEM_N 7 p \\<^sub>F + c_mlk_poly_ntt_c_contract p gp vp ap\ + apply (crush_boot f: c_mlk_poly_ntt_c_def contract: c_mlk_poly_ntt_c_contract_def) + apply crush_base + subgoal for x + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \length cs' = MLKEM_N\ \ + \ntt_outer_loop_int (2 ^ (7 - k)) k (list.map sint cs') = poly_ntt_int ap\ \ + \coeff_bound (int (8 - k) * MLKEM_Q) (list.map sint cs')\) + \ (\gx. x \\\\ gx\(of_nat (8 - k) :: c_uint)) + \ can_alloc_reference\ + and INV'=\\k. (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \length cs' = MLKEM_N\ \ + \ntt_outer_loop_int (2 ^ (6 - k)) (Suc k) (list.map sint cs') = poly_ntt_int ap\ \ + \coeff_bound (int (7 - k) * MLKEM_Q) (list.map sint cs')\) + \ (\gx. x \\\\ gx\(of_nat (7 - k) :: c_uint)) + \ can_alloc_reference\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + by (crush_base simp add: poly_ntt_int_def refines_mlk_poly_def + c_mlk_poly.collapse[symmetric] points_to_def) + subgoal \ \Condition check\ + by crush_base (simp_all add: word_le_nat_alt unat_of_nat unat_sub_if_size) + subgoal for k \ \Loop body\ + apply (crush_base no_schematics) + apply (ucincl_discharge \rule wp_callI[OF dereference_spec]\) + apply (force simp add: dereference_contract_def) + apply (crush_base no_schematics simp add: dereference_contract_def) + apply (ucincl_discharge \rule wp_callI[OF c_mlk_ntt_layer_poly_spec]\) + apply (force simp add: c_mlk_ntt_layer_poly_contract_def Let_def c_mlk_poly.sel + refines_coeffs_def) + apply (crush_base no_schematics simp add: c_mlk_ntt_layer_poly_contract_def) + apply (ucincl_discharge \rule wp_callI[OF dereference_spec]\) + apply (force simp add: dereference_contract_def) + apply (crush_base no_schematics simp add: dereference_contract_def word_le_nat_alt + refines_coeffs_def) + apply (ucincl_discharge \rule wp_callI[OF update_spec]\) + apply (force simp add: update_contract_def) + apply (crush_base no_schematics simp add: update_contract_def) + apply (auto simp add: unat_of_nat word_less_nat_alt unat_sub_if_size word_size + intro: coeff_bound_implies_ntt_layer_no_overflow[where l = \7 - k\])[3] + subgoal for cs' cs'' cs''' cs'''' cs''''' + apply (rule_tac x="\{\\<^sub>p c_uint_prism} (\_. 8 - word_of_nat k) cs'" in aentails_intro(7)) + apply (rule_tac x="cs''''" in aentails_intro(7)) + apply (rule_tac x="cs'''''" in aentails_intro(7)) + apply crush_base + subgoal \ \coeff_bound\ + apply (subgoal_tac \unat (7 - word_of_nat k :: 32 word) = 7 - k\) + prefer 2 apply (simp add: unat_of_nat word_less_nat_alt unat_sub_if_size word_size) + apply (simp only:) + apply (subgoal_tac \coeff_bound (int ((7-k)+1) * MLKEM_Q) (ntt_layer_int (7-k) (list.map sint cs'''))\) + apply (simp add: of_nat_diff algebra_simps) + apply (rule ntt_layer_int_bound[where l=\7 - k\]) + apply (auto simp add: of_nat_diff algebra_simps) + done + subgoal \ \ntt_outer_loop\ + apply (subgoal_tac \unat (7 - word_of_nat k :: 32 word) = 7 - k\) + prefer 2 apply (simp add: unat_of_nat word_less_nat_alt unat_sub_if_size word_size) + apply (simp only:) + apply (simp only: ntt_layer_int_def) + apply (simp only: ntt_middle_loop_int_fst case_prod_beta) + apply (subgoal_tac \(2::nat) ^ (7 - k) = 2 ^ (6 - k) + 2 ^ (6 - k)\) + apply simp + apply (subgoal_tac \(7::nat) - k = Suc (6 - k)\) + apply simp + apply simp + done + done + done + subgoal \ \Fuel exhaust\ + by crush_base + done +done + +subsection \@{text mlk_poly_ntt} contract\ + +text \Top-level wrapper: delegates to @{verbatim \mlk_poly_ntt_c\} via + the @{const refines_mlk_poly} abstraction boundary.\ + +definition c_mlk_poly_ntt_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_ntt_contract p gp vp ap \ + let pre = can_alloc_reference \ + p \\\\ gp\vp \ + \refines_mlk_poly vp ap\ \ + \coeff_bound MLKEM_Q ap\; + post = \_. can_alloc_reference \ + (\gp' vp'. p \\\\ gp'\vp' \ + \refines_mlk_poly vp' (poly_ntt_int ap)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_ntt_contract(*>*) + +lemma c_mlk_poly_ntt_spec [crush_specs]: + shows \\; c_mlk_poly_ntt 7 MLKEM_N while_fuel while_fuel MLKEM_N p \\<^sub>F + c_mlk_poly_ntt_contract p gp vp ap\ + apply (crush_boot f: c_mlk_poly_ntt_def contract: c_mlk_poly_ntt_contract_def) + apply (rule wp_callI[OF c_mlk_poly_ntt_c_spec[where gp=gp and vp=vp and ap=ap]]) + apply (simp add: c_mlk_poly_ntt_c_contract_def) + apply (crush_base simp add: c_mlk_poly_ntt_c_contract_def) + done + +(*<*) +end +(*>*) + +(*<*) +end +(*>*) From e907b757c88ba6bdf92e829273dfa386b0825552 Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:01:08 +0000 Subject: [PATCH 10/11] proofs/isabelle: add C functional correctness proof of inverse NTT Add MLKEM_FC_InvNTT.thy with contracts for invntt_layer, poly_invntt_tomont_c, and poly_invntt_tomont. --- proofs/isabelle/MLKEM_FC_InvNTT.thy | 859 ++++++++++++++++++++++++++++ 1 file changed, 859 insertions(+) create mode 100644 proofs/isabelle/MLKEM_FC_InvNTT.thy diff --git a/proofs/isabelle/MLKEM_FC_InvNTT.thy b/proofs/isabelle/MLKEM_FC_InvNTT.thy new file mode 100644 index 0000000000..bdd74c35c0 --- /dev/null +++ b/proofs/isabelle/MLKEM_FC_InvNTT.thy @@ -0,0 +1,859 @@ +(*<*) +theory MLKEM_FC_InvNTT + imports MLKEM_FC_NTT MLKEM_InvNTT_Spec +begin +(*>*) + +text \ + Functional correctness proof of the inverse NTT implementation. + Verifies the single-layer and outer-loop functions (including fqmul + prescaling) against the abstract specification from + @{text MLKEM_InvNTT_Spec}. +\ + +section \Inverse NTT Verification\ + +(*<*) +context c_mlk_machine_model +begin + +declare c_mlk_invntt_layer_def [micro_rust_simps del] +declare c_mlk_poly_invntt_tomont_c_def [micro_rust_simps del] +(*>*) + +subsection \@{text mlk_invntt_layer} contract\ + +text \Single inverse NTT layer. Applies Gentleman--Sande butterflies + (Barrett-reduced sum, Montgomery-multiplied difference) across all + blocks at a given stride.\ + +definition c_mlk_invntt_layer_contract :: + \('addr, 8 word list, c_short list) Global_Store.ref \ 8 word list \ + c_short list \ int list \ + c_uint \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_invntt_layer_contract r gr cs acs layer_val \ + let l = unat layer_val; + pre = can_alloc_reference \ + r \\\\ gr\cs \ + \refines_coeffs cs acs\ \ + \1 \ l\ \ + \l \ 7\ \ + \coeff_bound MLKEM_Q acs\; + post = \_. can_alloc_reference \ + (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' (invntt_layer_int l acs)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_invntt_layer_contract(*>*) + +text \Helper: the invNTT middle loop step lemma, folding the loop body result back + to the abstract spec form. Analogous to @{thm ntt_middle_loop_int_step}.\ + +lemma invntt_middle_loop_int_step: + shows \invntt_inner_loop_int (zetas_int ! (k - j)) (j * (2 * blen)) blen blen + (snd (invntt_middle_loop_int k blen j j cs)) = + snd (invntt_middle_loop_int (k - 1) blen j (Suc j) + (invntt_inner_loop_int (zetas_int ! k) 0 blen blen cs))\ +by (simp add: invntt_middle_loop_int_snoc[symmetric] invntt_middle_loop_int.simps(2) Let_def) + +text \Specialization of @{thm invntt_middle_loop_int_step} for the concrete + zeta index expression arising in the invNTT layer proof.\ + +lemma invntt_layer_step: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - ko < 2 ^ (l - Suc 0)\ + and \ko \ 255\ + shows \invntt_inner_loop_int (zetas_int ! (2 ^ l + ko - MLKEM_N)) + ((255 - ko) * (2 * blen)) blen blen + (snd (invntt_middle_loop_int (2 ^ l - Suc 0) blen (255 - ko) (255 - ko) cs)) = + snd (invntt_middle_loop_int (2 ^ l - Suc (Suc 0)) blen (255 - ko) (Suc (255 - ko)) + (invntt_inner_loop_int (zetas_int ! (2 ^ l - Suc 0)) 0 blen blen cs))\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ + by auto + hence idx: \2 ^ l + ko - MLKEM_N = (2 ^ l - Suc 0) - (255 - ko)\ + using assms by auto + have dec: \2 ^ l - Suc (Suc 0) = (2 ^ l - Suc 0) - 1\ + using l_cases by auto + show ?thesis + by (simp only: idx dec invntt_middle_loop_int_step) +qed + +lemma invntt_layer_k_dec_word: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - ko < 2 ^ (l - Suc 0)\ + and \ko \ 255\ + shows \(word_of_nat (2 ^ l - Suc (Suc (255 - ko))) :: 32 word) = + word_of_nat (2 ^ l + ko - MLKEM_N) - 1\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ + by auto + have eq: \2 ^ l - Suc (Suc (255 - ko)) = (2 ^ l + ko - MLKEM_N) - 1\ + using l_cases assms by auto + show ?thesis + unfolding eq using l_cases assms by auto +qed + +lemma invntt_layer_total_bound: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - ko < 2 ^ (l - Suc 0)\ + and \ko \ 255\ + shows \(255 - ko) * (2 * 2 ^ (8 - l)) + 2 * 2 ^ (8 - l) \ MLKEM_N\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed + +lemma invntt_inner_loop_cond: + assumes \Suc 0 \ l\ \l \ 7\ \255 - ko < 2 ^ (l - Suc 0)\ \ko \ 255\ \k \ 255\ + shows \(unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - l)) + + word_of_nat (min (255 - k) (2 ^ (8 - l)))) < + unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - l)) + 2 ^ (8 - l))) = + (255 - k < 2 ^ (8 - l))\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ by auto + show ?thesis + proof + assume cond: \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - l)) + + word_of_nat (min (255 - k) (2 ^ (8 - l)))) < + unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - l)) + 2 ^ (8 - l))\ + show \255 - k < 2 ^ (8 - l)\ + proof (rule ccontr) + assume \\ 255 - k < 2 ^ (8 - l)\ + hence \min (255 - k) (2 ^ (8 - l)) = 2 ^ (8 - l)\ by (simp add: min_def) + with cond show False by simp + qed + next + assume cond: \255 - k < 2 ^ (8 - l)\ + show \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - l)) + + word_of_nat (min (255 - k) (2 ^ (8 - l)))) < + unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - l)) + 2 ^ (8 - l))\ + using l_cases assms cond + by (auto simp del: of_nat_mult of_nat_add of_nat_diff + simp add: word_of_nat_mult_numeral word_of_nat_mult_numeral_right + unat_word_ariths(1) unat_sub word_le_nat_alt unat_of_nat + min_def) + qed +qed + +text \Zeta index bound for the inverse NTT: the k variable (starting at @{term \2^l - 1\} + and decrementing) always stays below 128.\ + +lemma invntt_layer_k_bound: + assumes \Suc 0 \ l\ + and \l \ 7\ + shows \2 ^ l - Suc 0 - j < (128 :: nat)\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + then show ?thesis + by auto +qed + +lemma invntt_layer_unat_k_index: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \unat ((2::32 word) ^ l - 1 - (0xFF - word_of_nat k)) = 2 ^ l - Suc 0 - (255 - k)\ +proof - + from assms have l_cases: \l \ {1,2,3,4,5,6,7}\ + by auto + have eq0: \(0xFF :: 32 word) - word_of_nat k = word_of_nat (255 - k)\ + using assms(4) by (simp add: of_nat_diff) + have le: \(255 :: nat) - k \ 2 ^ l - 1\ + using l_cases assms(3) by auto + have eq1: \(2 :: 32 word) ^ l - 1 = word_of_nat (2 ^ l - 1)\ + using l_cases by (elim insertE emptyE; simp) + have eq2: \word_of_nat (2 ^ l - 1) - word_of_nat (255 - k) = + (word_of_nat (2 ^ l - 1 - (255 - k)) :: 32 word)\ + by (rule of_nat_diff[OF le, symmetric]) + have bound: \2 ^ l - 1 - (255 - k) < 2 ^ LENGTH(32)\ + using l_cases by auto + show ?thesis + unfolding eq0 eq1 eq2 unat_of_nat using bound by (simp add: mod_less) +qed + +lemma invntt_layer_zeta_index_bound: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \(2 ^ l + k - MLKEM_N) mod 4294967296 < 128\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed + +lemma invntt_layer_zeta_index_eq: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \(2 ^ l + k - MLKEM_N) mod 4294967296 = 2 ^ l - Suc 0 - (255 - k)\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed + +lemma invntt_layer_zeta_index_bound_raw: + assumes \Suc 0 \ l\ + and \l \ 7\ + and \255 - k < 2 ^ (l - Suc 0)\ + and \k \ 255\ + shows \2 ^ l + k - MLKEM_N < 128\ +proof - + from assms have \l \ {1,2,3,4,5,6,7}\ + by auto + with assms show ?thesis + by auto +qed + +lemma coeff_bound_map_sint_update2: + assumes \coeff_bound B (list.map sint cs)\ + and \\sint v1\ < B\ and \\sint v2\ < B\ + and \i < length cs\ and \j < length cs\ + shows \coeff_bound B (list.map sint (cs[i := v1, j := v2]))\ +using assms unfolding coeff_bound_def by (auto simp: nth_list_update) + +lemma invntt_unat_idx: + fixes layer_val :: \32 word\ + assumes \unat layer_val \ 7\ \Suc 0 \ unat layer_val\ + and \ko \ 255\ \ki \ 255\ + and \255 - ki < 2 ^ (8 - unat layer_val)\ + and \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + 2 * 2 ^ (8 - unat layer_val) \ MLKEM_N\ + shows \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - unat layer_val)) + + (0xFF - word_of_nat ki)) = + (255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (255 - ki)\ + and \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - unat layer_val)) + + (0xFF - word_of_nat ki) + 2 ^ (8 - unat layer_val)) = + (255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (255 - ki) + 2 ^ (8 - unat layer_val)\ +proof - + have ki_bound: \ki < 2 ^ LENGTH(32)\ using assms by simp + have ko_bound: \ko < 2 ^ LENGTH(32)\ using assms by simp + have ki_le: \word_of_nat ki \ (0xFF :: 32 word)\ + using assms ki_bound by (simp add: word_le_nat_alt unat_of_nat mod_less) + have ko_le: \word_of_nat ko \ (0xFF :: 32 word)\ + using assms ko_bound by (simp add: word_le_nat_alt unat_of_nat mod_less) + have h_ki: \unat (0xFF - word_of_nat ki :: 32 word) = 255 - ki\ + using unat_sub[OF ki_le] ki_bound by (simp add: unat_of_nat) + have h_mul: \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - unat layer_val))) + = (255 - ko) * (2 * 2 ^ (8 - unat layer_val))\ + by (rule ntt_layer_unat_offset[OF assms(2) assms(1) assms(3)]) + have sum_bound: \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (255 - ki) < 2 ^ LENGTH(32)\ + using assms by simp + have sum_bound2: \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (255 - ki) + + 2 ^ (8 - unat layer_val) < 2 ^ LENGTH(32)\ + using assms by simp + show \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - unat layer_val)) + + (0xFF - word_of_nat ki)) = + (255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (255 - ki)\ + using sum_bound by (simp only: unat_word_ariths(1) h_mul h_ki) (simp add: mod_less) + show \unat ((0xFF - word_of_nat ko :: 32 word) * (2 * 2 ^ (8 - unat layer_val)) + + (0xFF - word_of_nat ki) + 2 ^ (8 - unat layer_val)) = + (255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (255 - ki) + 2 ^ (8 - unat layer_val)\ + using sum_bound2 by (simp only: unat_word_ariths(1) h_mul h_ki) (simp add: mod_less) +qed + +lemma invntt_butterfly_block_fc_step: + fixes ra :: \c_short list\ + and r_br r_fq :: c_short + and z :: \int\ + and off blen m :: \nat\ + assumes ra_eq: \list.map sint ra = invntt_inner_loop_int z off blen m base_cs\ + and br_eq: \sint r_br = barrett_reduce_int (sint (ra!(off+m) + ra!(off+m+blen)))\ + and fq_eq: \sint r_fq = fqmul_int (sint (ra!(off+m+blen) - ra!(off+m))) z\ + and m_bound: \m < blen\ + and sum_bound: \off + 2 * blen \ MLKEM_N\ + and len_ra: \length ra = MLKEM_N\ + and cb: \coeff_bound MLKEM_Q (list.map sint ra)\ + shows \list.map sint (ra[off+m := r_br, off+m+blen := r_fq]) = + invntt_inner_loop_int z off blen (Suc m) base_cs\ +proof - + \ \Abstract state after m iterations\ + define L where + \L = list.map sint ra\ + from m_bound sum_bound len_ra have lo_bound: \off + m < length ra\ and hi_bound: \off + m + blen < length ra\ + by auto + have L_eq: \L = invntt_inner_loop_int z off blen m base_cs\ + using ra_eq L_def by simp + have L_lo: \L ! (off + m) = sint (ra ! (off + m))\ + using lo_bound L_def by (simp add: nth_map) + have L_hi: \L ! (off + m + blen) = sint (ra ! (off + m + blen))\ + using hi_bound L_def by (simp add: nth_map) + \ \Overflow safety from coeff_bound\ + have ob: \- 32768 \ L ! (off + m) + L ! (off + m + blen)\ + \L ! (off + m) + L ! (off + m + blen) \ 32767\ + \- 32768 \ L ! (off + m + blen) - L ! (off + m)\ + \L ! (off + m + blen) - L ! (off + m) \ 32767\ + using invntt_coeff_bound_sum_bounds[of L \off + m\ blen] cb lo_bound hi_bound by (auto simp: L_def) + \ \sint distributes over + and -\ + have sint_add: \sint (ra ! (off + m) + ra ! (off + m + blen)) = + sint (ra ! (off + m)) + sint (ra ! (off + m + blen))\ + by (rule sint_plus_no_overflow) (use ob L_lo L_hi in auto) + have sint_sub: \sint (ra ! (off + m + blen) - ra ! (off + m)) = + sint (ra ! (off + m + blen)) - sint (ra ! (off + m))\ + by (rule sint_minus_no_overflow) (use ob L_lo L_hi in auto) + \ \RHS: expand via snoc + butterfly\ + have rhs_eq: \invntt_inner_loop_int z off blen (Suc m) base_cs = + (L[off + m := barrett_reduce_int (L ! (off + m) + L ! (off + m + blen))]) + [off + m + blen := fqmul_int (L ! (off + m + blen) - L ! (off + m)) z]\ + unfolding L_eq invntt_inner_loop_int_snoc invntt_butterfly_int_def Let_def by simp + \ \LHS: push map sint through updates, substitute\ + have lhs_eq: \list.map sint (ra[off + m := r_br, off + m + blen := r_fq]) = + (L[off + m := sint r_br])[off + m + blen := sint r_fq]\ + using L_def lo_bound hi_bound by (simp add: List.map_update) + show ?thesis + using lhs_eq rhs_eq br_eq fq_eq sint_add sint_sub L_lo L_hi by simp +qed + +lemma invntt_butterfly_block_fc_step': + fixes ra :: \c_short list\ + and r_br r_fq :: c_short + and z :: \int\ + and off blen m :: \nat\ + assumes ra_eq: \list.map sint ra = invntt_inner_loop_int z off blen m base_cs\ + and br_eq: \sint r_br = barrett_reduce_int (sint (ra!(off+m) + ra!(off+m+blen)))\ + and fq_eq: \sint r_fq = fqmul_int (sint (ra!(off+m+blen) - ra!(off+m))) z\ + and m_bound: \m < blen\ + and sum_bound: \off + 2 * blen \ MLKEM_N\ + and len_ra: \length ra = MLKEM_N\ + and cb: \coeff_bound MLKEM_Q (list.map sint ra)\ + shows \list.map sint (ra[off+m := r_br, off+m+blen := r_fq]) = + invntt_inner_loop_int z (Suc off) blen m + (invntt_butterfly_int z off blen base_cs)\ +using invntt_butterfly_block_fc_step[OF assms] by (simp add: invntt_inner_loop_int_snoc) + +lemma c_mlk_invntt_layer_spec [crush_specs]: + shows \\; c_mlk_invntt_layer while_fuel while_fuel while_fuel MLKEM_N MLKEM_N + r layer_val \\<^sub>F + c_mlk_invntt_layer_contract r gr cs acs layer_val\ + apply (crush_boot f: c_mlk_invntt_layer_def + contract: c_mlk_invntt_layer_contract_def) + apply crush_base + subgoal for x xa xb \ \Main while loop\ + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (snd (invntt_middle_loop_int (2^(unat layer_val) - 1) (2^(8 - unat layer_val)) + (min (MLKEM_N - k) (2^(unat layer_val - 1))) + (min (MLKEM_N - k) (2^(unat layer_val - 1))) acs))\) + \ (\gs. x \\\\ gs\(of_nat (min (MLKEM_N - k) (2^(unat layer_val - 1)) * (2 * 2^(8 - unat layer_val))) :: c_uint)) + \ (\gk. xa \\\\ gk\(of_nat (2^(unat layer_val) - 1 - min (MLKEM_N - k) (2^(unat layer_val - 1))) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2^(8 - unat layer_val)) :: c_uint)) + \ can_alloc_reference + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \k \ MLKEM_N\\ + and INV'=\\k. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (snd (invntt_middle_loop_int (2^(unat layer_val) - 1) (2^(8 - unat layer_val)) + (MLKEM_N - Suc k) (MLKEM_N - Suc k) acs))\) + \ (\gs. x \\\\ gs\(of_nat ((MLKEM_N - Suc k) * (2 * 2^(8 - unat layer_val))) :: c_uint)) + \ (\gk. xa \\\\ gk\(of_nat (2^(unat layer_val) - 1 - (MLKEM_N - Suc k)) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2^(8 - unat layer_val)) :: c_uint)) + \ can_alloc_reference + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \Suc k \ MLKEM_N\ + \ \MLKEM_N - Suc k < 2^(unat layer_val - 1)\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + by (crush_base simp add: min_absorb2 invntt_layer_int_def ntt_layer_total_size ntt_layer_nb_le_N + unat_sub word_le_nat_alt drop_bit_0x100_eq ntt_layer_total_size_word) + subgoal \ \Condition check\ + by crush_base (simp_all add: ntt_layer_start_cond ntt_layer_cond_false_min + word_of_nat_255_sub) + subgoal for ko \ \Loop body\ + \ \Phase 1: VCG without zetas literal\ + apply (crush_base simp add: bind2_unseq_def refines_coeffs_def min_absorb2 + word_less_nat_alt unat_of_nat word_le_nat_alt unat_sub + invntt_layer_k_bound ntt_layer_total_size ntt_layer_total_size_word + ntt_layer_nb_le_N + invntt_middle_loop_int_length invntt_inner_loop_int_length) + \ \Phase 2: finish VCG with zeta index bounds and folding\ + apply (crush_base simp add: + c_global_mlk_zetas_def zetas_sword_unfold[symmetric] + zetas_sword_sint + invntt_layer_zeta_index_bound invntt_layer_zeta_index_eq + bind2_unseq_def refines_coeffs_def min_absorb2 + word_less_nat_alt unat_of_nat word_le_nat_alt unat_sub + invntt_layer_k_bound ntt_layer_total_size ntt_layer_total_size_word + ntt_layer_nb_le_N + invntt_middle_loop_int_snoc[symmetric] + invntt_middle_loop_int_length invntt_inner_loop_int_length) + \ \Phase 3: rewrite zeta index bound to True\ + apply (simp_all only: invntt_layer_zeta_index_bound_raw) + \ \4 goals: WP1, Abort1(\True), WP2, Abort2(\True)\ + subgoal for gl0 gr0 gs0 cs0 gk0 xj xz \ \Inner while loop WP1\ + apply (ucincl_discharge\ + rule_tac + INV=\\ki. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (invntt_inner_loop_int (zetas_int ! (2 ^ unat layer_val + ko - MLKEM_N)) + ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) + (2 ^ (8 - unat layer_val)) + (min (MLKEM_N - ki) (2 ^ (8 - unat layer_val))) + (snd (invntt_middle_loop_int (2 ^ unat layer_val - 1) (2 ^ (8 - unat layer_val)) + (255 - ko) (255 - ko) (list.map sint cs))))\ + \ \length cs' = MLKEM_N\ + \ \coeff_bound MLKEM_Q (list.map sint cs')\) + \ (\gj. xj \\\\ gj\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + min (MLKEM_N - ki) (2 ^ (8 - unat layer_val))) :: c_uint)) + \ (\gs. x \\\\ gs\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2 ^ (8 - unat layer_val)) :: c_uint)) + \ (\gz. xz \\\\ gz\(zetas_sword ! (2 ^ unat layer_val + ko - MLKEM_N))) + \ can_alloc_reference + \ \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + 2 * 2 ^ (8 - unat layer_val) \ MLKEM_N\ + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \ko \ 255\ \ \255 - ko < 2 ^ (unat layer_val - Suc 0)\ + \ \length cs = MLKEM_N\ + \ \ki \ MLKEM_N\ + \ \coeff_bound MLKEM_Q acs\\ + and INV'=\\ki. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (invntt_inner_loop_int (zetas_int ! (2 ^ unat layer_val + ko - MLKEM_N)) + ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) + (2 ^ (8 - unat layer_val)) + (MLKEM_N - Suc ki) + (snd (invntt_middle_loop_int (2 ^ unat layer_val - 1) (2 ^ (8 - unat layer_val)) + (255 - ko) (255 - ko) (list.map sint cs))))\ + \ \length cs' = MLKEM_N\ + \ \coeff_bound MLKEM_Q (list.map sint cs')\) + \ (\gj. xj \\\\ gj\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (MLKEM_N - Suc ki)) :: c_uint)) + \ (\gs. x \\\\ gs\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2 ^ (8 - unat layer_val)) :: c_uint)) + \ (\gz. xz \\\\ gz\(zetas_sword ! (2 ^ unat layer_val + ko - MLKEM_N))) + \ can_alloc_reference + \ \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + 2 * 2 ^ (8 - unat layer_val) \ MLKEM_N\ + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \ko \ 255\ \ \255 - ko < 2 ^ (unat layer_val - Suc 0)\ + \ \length cs = MLKEM_N\ + \ \Suc ki \ MLKEM_N\ + \ \MLKEM_N - Suc ki < 2 ^ (8 - unat layer_val)\ + \ \coeff_bound MLKEM_Q acs\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + apply (crush_base simp add: min_absorb2 refines_coeffs_def + invntt_middle_loop_int_snoc[symmetric] + invntt_middle_loop_int_length invntt_inner_loop_int_length + ntt_layer_total_size ntt_layer_total_size_word ntt_layer_nb_le_N + word_less_nat_alt unat_of_nat word_le_nat_alt unat_sub) + apply (simp_all add: invntt_layer_step invntt_layer_k_dec_word + invntt_layer_total_bound ring_distribs + ntt_layer_total_size ntt_layer_nb_le_N min_absorb2) + apply (rule invntt_middle_loop_int_coeff_bound[simplified One_nat_def]) + apply auto + done + subgoal \ \Condition\ + by (crush_base simp add: word_less_nat_alt unat_of_nat + refines_coeffs_def invntt_inner_loop_int_snoc + invntt_butterfly_int_def fqmul_int_def Let_def + nth_list_update invntt_inner_loop_int_length + invntt_inner_loop_cond min_absorb2 + word_le_nat_alt unat_sub) + subgoal \ \Body\ + using [[linarith_split_limit = 40]] + apply (crush_base simp add: word_less_nat_alt unat_of_nat + c_mlk_barrett_reduce_contract_def unat_butterfly_idx + refines_coeffs_def invntt_inner_loop_int_snoc + invntt_butterfly_int_def fqmul_int_def Let_def + nth_list_update invntt_inner_loop_int_length + invntt_coeff_bound_sum_bounds + min_absorb2 word_le_nat_alt unat_sub + zetas_sword_sint) + subgoal \ \OOB r[j]\ + by (simp add: unat_word_ariths unat_of_nat + word_le_nat_alt word_less_nat_alt unat_sub) + subgoal \ \OOB r[j+blen]\ + by (simp add: unat_word_ariths unat_of_nat + word_le_nat_alt word_less_nat_alt unat_sub) + subgoal \ \coeff_bound 1\ + by (rule coeff_bound_map_sint_update2; + simp add: barrett_reduce_int_abs_bound + fqmul_int_def[symmetric] fqmul_prescale_bound_sint + zetas_sword_sint zetas_int_abs_bound + invntt_layer_zeta_index_bound_raw) + subgoal \ \refinement 1\ + apply (simp only: invntt_unat_idx fqmul_int_def[symmetric] + zetas_sword_sint invntt_layer_zeta_index_bound_raw) + apply (rule invntt_butterfly_block_fc_step'[unfolded + invntt_butterfly_int_def Let_def]) + apply (assumption | simp)+ + done + subgoal \ \coeff_bound 2\ + by (rule coeff_bound_map_sint_update2; + simp add: barrett_reduce_int_abs_bound + fqmul_int_def[symmetric] fqmul_prescale_bound_sint + zetas_sword_sint zetas_int_abs_bound + invntt_layer_zeta_index_bound_raw) + subgoal \ \refinement 2\ + apply (simp only: invntt_unat_idx fqmul_int_def[symmetric] + zetas_sword_sint invntt_layer_zeta_index_bound_raw) + apply (rule invntt_butterfly_block_fc_step'[unfolded + invntt_butterfly_int_def Let_def]) + apply (assumption | simp)+ + done + done + subgoal + by (crush_base simp add: min_absorb2) \ \Fuel\ + done + subgoal + by simp \ \Abort1\ + subgoal for gl1 gr1 gs1 cs1 gk1 xj1 xz1 \ \Inner while loop WP2\ + apply (ucincl_discharge\ + rule_tac + INV=\\ki. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (invntt_inner_loop_int (zetas_int ! (2 ^ unat layer_val + ko - MLKEM_N)) + ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) + (2 ^ (8 - unat layer_val)) + (min (MLKEM_N - ki) (2 ^ (8 - unat layer_val))) + (snd (invntt_middle_loop_int (2 ^ unat layer_val - 1) (2 ^ (8 - unat layer_val)) + (255 - ko) (255 - ko) (list.map sint cs))))\ + \ \length cs' = MLKEM_N\ + \ \coeff_bound MLKEM_Q (list.map sint cs')\) + \ (\gj. xj1 \\\\ gj\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + min (MLKEM_N - ki) (2 ^ (8 - unat layer_val))) :: c_uint)) + \ (\gs. x \\\\ gs\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2 ^ (8 - unat layer_val)) :: c_uint)) + \ (\gz. xz1 \\\\ gz\(zetas_sword ! (2 ^ unat layer_val + ko - MLKEM_N))) + \ can_alloc_reference + \ \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + 2 * 2 ^ (8 - unat layer_val) \ MLKEM_N\ + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \ko \ 255\ \ \255 - ko < 2 ^ (unat layer_val - Suc 0)\ + \ \length cs = MLKEM_N\ + \ \ki \ MLKEM_N\ + \ \coeff_bound MLKEM_Q acs\\ + and INV'=\\ki. (\gr' cs'. r \\\\ gr'\cs' \ + \refines_coeffs cs' + (invntt_inner_loop_int (zetas_int ! (2 ^ unat layer_val + ko - MLKEM_N)) + ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) + (2 ^ (8 - unat layer_val)) + (MLKEM_N - Suc ki) + (snd (invntt_middle_loop_int (2 ^ unat layer_val - 1) (2 ^ (8 - unat layer_val)) + (255 - ko) (255 - ko) (list.map sint cs))))\ + \ \length cs' = MLKEM_N\ + \ \coeff_bound MLKEM_Q (list.map sint cs')\) + \ (\gj. xj1 \\\\ gj\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + (MLKEM_N - Suc ki)) :: c_uint)) + \ (\gs. x \\\\ gs\(of_nat ((255 - ko) * (2 * 2 ^ (8 - unat layer_val))) :: c_uint)) + \ (\gl. xb \\\\ gl\(of_nat (2 ^ (8 - unat layer_val)) :: c_uint)) + \ (\gz. xz1 \\\\ gz\(zetas_sword ! (2 ^ unat layer_val + ko - MLKEM_N))) + \ can_alloc_reference + \ \(255 - ko) * (2 * 2 ^ (8 - unat layer_val)) + 2 * 2 ^ (8 - unat layer_val) \ MLKEM_N\ + \ \Suc 0 \ unat layer_val\ \ \unat layer_val \ 7\ + \ \ko \ 255\ \ \255 - ko < 2 ^ (unat layer_val - Suc 0)\ + \ \length cs = MLKEM_N\ + \ \Suc ki \ MLKEM_N\ + \ \MLKEM_N - Suc ki < 2 ^ (8 - unat layer_val)\ + \ \coeff_bound MLKEM_Q acs\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + apply (crush_base simp add: min_absorb2 refines_coeffs_def + invntt_middle_loop_int_snoc[symmetric] + invntt_middle_loop_int_length invntt_inner_loop_int_length + ntt_layer_total_size ntt_layer_total_size_word ntt_layer_nb_le_N + word_less_nat_alt unat_of_nat word_le_nat_alt unat_sub) + apply (simp_all add: invntt_layer_step invntt_layer_k_dec_word + invntt_layer_total_bound ring_distribs + ntt_layer_total_size ntt_layer_nb_le_N min_absorb2) + apply (rule invntt_middle_loop_int_coeff_bound[simplified One_nat_def]) + apply auto + done + subgoal \ \Condition\ + by (crush_base simp add: word_less_nat_alt unat_of_nat + refines_coeffs_def invntt_inner_loop_int_snoc + invntt_butterfly_int_def fqmul_int_def Let_def + nth_list_update invntt_inner_loop_int_length + invntt_inner_loop_cond min_absorb2 + word_le_nat_alt unat_sub) + subgoal \ \Body\ + using [[linarith_split_limit = 40]] + apply (crush_base simp add: word_less_nat_alt unat_of_nat + c_mlk_barrett_reduce_contract_def unat_butterfly_idx + refines_coeffs_def invntt_inner_loop_int_snoc + invntt_butterfly_int_def fqmul_int_def Let_def + nth_list_update invntt_inner_loop_int_length + invntt_coeff_bound_sum_bounds + min_absorb2 word_le_nat_alt unat_sub + zetas_sword_sint) + subgoal \ \OOB r[j]\ + by (simp add: unat_word_ariths unat_of_nat + word_le_nat_alt word_less_nat_alt unat_sub) + subgoal \ \OOB r[j+blen]\ + by (simp add: unat_word_ariths unat_of_nat + word_le_nat_alt word_less_nat_alt unat_sub) + subgoal \ \coeff_bound 1\ + by (rule coeff_bound_map_sint_update2; + simp add: barrett_reduce_int_abs_bound + fqmul_int_def[symmetric] fqmul_prescale_bound_sint + zetas_sword_sint zetas_int_abs_bound + invntt_layer_zeta_index_bound_raw) + subgoal \ \refinement 1\ + apply (simp only: invntt_unat_idx fqmul_int_def[symmetric] + zetas_sword_sint invntt_layer_zeta_index_bound_raw) + apply (rule invntt_butterfly_block_fc_step'[unfolded + invntt_butterfly_int_def Let_def]) + apply (assumption | simp)+ + done + subgoal \ \coeff_bound 2\ + by (rule coeff_bound_map_sint_update2; + simp add: barrett_reduce_int_abs_bound + fqmul_int_def[symmetric] fqmul_prescale_bound_sint + zetas_sword_sint zetas_int_abs_bound + invntt_layer_zeta_index_bound_raw) + subgoal \ \refinement 2\ + apply (simp only: invntt_unat_idx fqmul_int_def[symmetric] + zetas_sword_sint invntt_layer_zeta_index_bound_raw) + apply (rule invntt_butterfly_block_fc_step'[unfolded + invntt_butterfly_int_def Let_def]) + apply (assumption | simp)+ + done + done + subgoal + by (crush_base simp add: min_absorb2) \ \Fuel\ + done + subgoal + by simp \ \Abort2\ + done + subgoal \ \Fuel exhaust\ + by (crush_base simp add: min_absorb2 ntt_layer_nb_le_N ntt_layer_total_size_word) + done + done + +text \Lifted invntt\_layer contract: operates on the poly reference directly.\ + +definition c_mlk_invntt_layer_poly_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int list \ c_uint \ + ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_invntt_layer_poly_contract p gp vp acs layer_val \ + let l = unat layer_val; + pre = can_alloc_reference \ + p \\\\ gp\vp \ + \refines_coeffs (c_mlk_poly_coeffs vp) acs\ \ + \1 \ l\ \ + \l \ 7\ \ + \coeff_bound MLKEM_Q acs\; + post = \_. can_alloc_reference \ + (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \refines_coeffs cs' (invntt_layer_int l acs)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_invntt_layer_poly_contract(*>*) + +lemma c_mlk_invntt_layer_poly_spec: + shows \\; c_mlk_invntt_layer while_fuel while_fuel while_fuel MLKEM_N MLKEM_N + (focus_reference (Abs_focus (\'\<^sub>l (make_lens_via_view_modify + c_mlk_poly_coeffs update_c_mlk_poly_coeffs))) p) layer_val \\<^sub>F + c_mlk_invntt_layer_poly_contract p gp vp acs layer_val\ + apply (rule satisfies_function_contract_weaken[OF c_mlk_invntt_layer_spec[where + gr=gp and cs=\c_mlk_poly_coeffs vp\ and acs=acs]]) + apply (simp_all add: c_mlk_invntt_layer_poly_contract_def c_mlk_invntt_layer_contract_def Let_def) + subgoal + by crush_base + subgoal + by crush_base + subgoal + by (crush_base intro: focusedR_c_mlk_poly_aentails) + subgoal + apply (rule asepconj_mono) + apply (rule aexists_entailsL)+ + apply (rule aentails_intro(7))+ + apply (rule asepconj_mono2[OF focusedL_c_mlk_poly_aentails]) + done + apply (rule aentails_refl) + done + +text \Helper lemmas for inverse NTT layer.\ + +lemma invntt_outer_loop_int_step: + shows \invntt_outer_loop_int (Suc n) cs = + invntt_outer_loop_int n (invntt_layer_int (Suc n) cs)\ +by (simp add: invntt_layer_int_def Let_def case_prod_beta) + +lemma invntt_outer_loop_int_step': + assumes \l \ 1\ + shows \invntt_outer_loop_int l cs = invntt_outer_loop_int (l - 1) (invntt_layer_int l cs)\ +proof - + from assms obtain n where \l = Suc n\ + using Suc_le_D by auto + then show ?thesis + using invntt_outer_loop_int_step by simp +qed + +subsection \@{text mlk_poly_invntt_tomont_c} contract\ + +text \Inverse NTT with Montgomery post-scaling. Applies fqmul by 1441 + to all coefficients, then 7 inverse NTT layers.\ +definition c_mlk_poly_invntt_tomont_c_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_invntt_tomont_c_contract p gp vp ap \ + let pre = can_alloc_reference \ + p \\\\ gp\vp \ + \refines_mlk_poly vp ap\; + post = \_. can_alloc_reference \ + (\gp' vp'. p \\\\ gp'\vp' \ + \refines_mlk_poly vp' (poly_invntt_tomont_int ap)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_invntt_tomont_c_contract(*>*) + +lemma c_mlk_poly_invntt_tomont_c_spec [crush_specs]: + notes focusedL_c_mlk_poly[crush_aentails_cond_crules, crush_points_to_cond_crules] + shows \\; c_mlk_poly_invntt_tomont_c MLKEM_N MLKEM_N while_fuel while_fuel while_fuel + while_fuel while_fuel MLKEM_N p \\<^sub>F + c_mlk_poly_invntt_tomont_c_contract p gp vp ap\ + apply (crush_boot f: c_mlk_poly_invntt_tomont_c_def contract: c_mlk_poly_invntt_tomont_c_contract_def) + apply (crush_base simp add: c_mlk_fqmul_contract_def fqmul_int_def) + subgoal for x xa xb \ \x=j ref, xa=layer ref, xb=f ref\ + \ \Phase 1: fqmul prescaling by 1441\ + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \length cs' = MLKEM_N\ \ + \\j < MLKEM_N - k. sint (cs' ! j) = fqmul_int (ap ! j) 1441\ \ + \\j. MLKEM_N - k \ j \ j < MLKEM_N \ + cs' ! j = c_mlk_poly_coeffs vp ! j\) + \ (\gf. xb \\\\ gf\(0x5A1 :: c_short)) + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - k) :: c_uint)) + \ (\gl. xa \\\\ gl\(c_uninitialized :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vp ap\\ + and INV'=\\k. (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \length cs' = MLKEM_N\ \ + \\j < MLKEM_N - Suc k. sint (cs' ! j) = fqmul_int (ap ! j) 1441\ \ + \\j. MLKEM_N - Suc k \ j \ j < MLKEM_N \ + cs' ! j = c_mlk_poly_coeffs vp ! j\) + \ (\gf. xb \\\\ gf\(0x5A1 :: c_short)) + \ (\gx. x \\\\ gx\(of_nat (MLKEM_N - Suc k) :: c_uint)) + \ (\gl. xa \\\\ gl\(c_uninitialized :: c_uint)) + \ can_alloc_reference + \ \refines_mlk_poly vp ap\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization\ + apply (crush_base simp add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat + c_mlk_poly.record_simps nth_append nth_list_update refines_mlk_poly_def + fqmul_int_def c_mlk_fqmul_contract_def) + \ \After fqmul finalization, process layer = 7, then layer while loop\ + subgoal + apply (ucincl_discharge\ + rule_tac + INV=\\k. (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \length cs' = MLKEM_N\ \ + \invntt_outer_loop_int (k + 7 - MLKEM_N) + (list.map sint cs') = poly_invntt_tomont_int ap\ \ + \coeff_bound MLKEM_Q (list.map sint cs')\) + \ (\gx. xa \\\\ gx\(of_nat (k + 7 - MLKEM_N) :: c_uint)) + \ can_alloc_reference\ + and INV'=\\k. (\gp' cs'. p \\\\ gp'\(make_c_mlk_poly cs') \ + \length cs' = MLKEM_N\ \ + \invntt_outer_loop_int (Suc k + 7 - MLKEM_N) + (list.map sint cs') = poly_invntt_tomont_int ap\ \ + \coeff_bound MLKEM_Q (list.map sint cs')\) + \ (\gx. xa \\\\ gx\(of_nat (Suc k + 7 - MLKEM_N) :: c_uint)) + \ can_alloc_reference + \ \1 \ Suc k + 7 - MLKEM_N \ Suc k + 7 - MLKEM_N \ 7\\ + and \=\\_. \False\\ + and \=\\_. \False\\ + in wp_bounded_while_framedI\) + subgoal \ \Init + Finalization of layer loop\ + apply (crush_base simp add: poly_invntt_tomont_int_def refines_mlk_poly_def + c_mlk_poly.collapse[symmetric] points_to_def + c_signed_truthy_zero bounded_while_literal_false) + subgoal + apply (simp add: coeff_bound_def) + by (auto simp add: fqmul_int_def intro: fqmul_prescale_bound_sint[where b=1441, simplified]) + subgoal + apply (rule arg_cong[where f=\invntt_outer_loop_int 7\]) + by (auto intro!: nth_equalityI simp add: fqmul_int_def) + done + subgoal \ \Condition check\ + by crush_base (auto simp add: word_le_nat_alt unat_of_nat unat_sub_if_size word_less_nat_alt) + subgoal for k \ \Loop body\ + apply (crush_base no_schematics) + apply (ucincl_discharge \rule wp_callI[OF dereference_spec]\) + apply (force simp add: dereference_contract_def) + apply (crush_base no_schematics simp add: dereference_contract_def) + apply (ucincl_discharge \rule wp_callI[OF c_mlk_invntt_layer_poly_spec]\) + apply (force simp add: c_mlk_invntt_layer_poly_contract_def Let_def c_mlk_poly.sel + refines_coeffs_def unat_of_nat coeff_bound_def) + apply (crush_base no_schematics simp add: c_mlk_invntt_layer_poly_contract_def) + \ \Close pure goals from invntt_layer precondition\ + apply (auto simp add: unat_of_nat word_less_nat_alt unat_sub_if_size word_size + invntt_outer_loop_int_step' refines_coeffs_def invntt_layer_int_length + intro: invntt_layer_int_coeff_bound)[4] + \ \Process dereference+update for layer counter decrement\ + apply crush_base + by (auto simp add: unat_of_nat word_less_nat_alt unat_sub_if_size word_size + invntt_outer_loop_int_step' refines_coeffs_def invntt_layer_int_length + intro: invntt_layer_int_coeff_bound) + subgoal \ \Fuel exhaust\ + by crush_base + done + done + \ \Condition\ + subgoal + by (crush_base simp add: word_less_nat_alt unat_of_nat + c_signed_truthy_zero bounded_while_literal_false) + \ \Body\ + subgoal for k + by (crush_base simp add: word_less_nat_alt unat_sub word_le_nat_alt unat_of_nat + c_mlk_poly.record_simps nth_append nth_list_update refines_mlk_poly_def + fqmul_int_def c_mlk_fqmul_contract_def + c_signed_truthy_zero bounded_while_literal_false) + \ \Fuel exhaust\ + subgoal + by (crush_base simp add: word_less_nat_alt unat_of_nat + c_signed_truthy_zero bounded_while_literal_false) + done + done + +subsection \@{text mlk_poly_invntt_tomont} contract\ + +text \Top-level wrapper: delegates to @{verbatim \mlk_poly_invntt_tomont_c\} + via the @{const refines_mlk_poly} abstraction boundary.\ + +definition c_mlk_poly_invntt_tomont_contract :: + \('addr, 8 word list, c_mlk_poly) Global_Store.ref \ 8 word list \ + c_mlk_poly \ int_poly \ ('s, unit, c_abort) function_contract\ where + [crush_contracts]: \c_mlk_poly_invntt_tomont_contract p gp vp ap \ + let pre = can_alloc_reference \ + p \\\\ gp\vp \ + \refines_mlk_poly vp ap\; + post = \_. can_alloc_reference \ + (\gp' vp'. p \\\\ gp'\vp' \ + \refines_mlk_poly vp' (poly_invntt_tomont_int ap)\) + in make_function_contract pre post\ +(*<*)ucincl_auto c_mlk_poly_invntt_tomont_contract(*>*) + +lemma c_mlk_poly_invntt_tomont_spec [crush_specs]: + shows \\; c_mlk_poly_invntt_tomont MLKEM_N while_fuel while_fuel while_fuel while_fuel + while_fuel MLKEM_N MLKEM_N p \\<^sub>F + c_mlk_poly_invntt_tomont_contract p gp vp ap\ + apply (crush_boot f: c_mlk_poly_invntt_tomont_def contract: c_mlk_poly_invntt_tomont_contract_def) + apply (rule wp_callI[OF c_mlk_poly_invntt_tomont_c_spec[where gp=gp and vp=vp and ap=ap]]) + apply (simp add: c_mlk_poly_invntt_tomont_c_contract_def) + apply (crush_base simp add: c_mlk_poly_invntt_tomont_c_contract_def) + done + +(*<*) +end +(*>*) + +(*<*) +end +(*>*) From e9bc583419871a430cf1a0a190a5c1e74eb551bc Mon Sep 17 00:00:00 2001 From: Dominic Mulligan Date: Tue, 17 Mar 2026 22:01:23 +0000 Subject: [PATCH 11/11] proofs/isabelle: wire up split theories, add PDF document generation, remove monolithic files Update ROOT with PDF output, document sections, and theory ordering. Update Makefile with build/pdf/clean targets. Add document/root.tex with LaTeX preamble. Update imports in MLKEM_Functional_Correctness.thy, MLKEM_Machine_Model.thy, and MLKEM_Poly_Definitions.thy. Remove superseded monolithic Common.thy and MLKEM_Poly_Functional_Correctness.thy. --- proofs/isabelle/.gitignore | 4 + .../isabelle/MLKEM_Functional_Correctness.thy | 15 + proofs/isabelle/MLKEM_Machine_Model.thy | 58 +++ .../isabelle/MLKEM_Machine_Model_Instance.thy | 230 ++++++++++ proofs/isabelle/MLKEM_Poly_Definitions.thy | 22 + proofs/isabelle/MLKEM_Verify_Definitions.thy | 16 + proofs/isabelle/Makefile | 136 ++++++ proofs/isabelle/README.md | 47 ++ proofs/isabelle/ROOT | 28 ++ proofs/isabelle/document/root.tex | 70 +++ proofs/isabelle/generated/.gitignore | 3 + proofs/isabelle/generated/.gitkeep | 0 proofs/isabelle/pipeline/README.md | 61 +++ .../isabelle/pipeline/config/proof_config.h | 62 +++ proofs/isabelle/pipeline/config/stdint.h | 38 ++ proofs/isabelle/pipeline/generate.py | 408 ++++++++++++++++++ .../pipeline/templates/definitions.thy.tpl | 15 + .../templates/functional_correctness.thy.tpl | 12 + proofs/isabelle/pipeline/units.json | 154 +++++++ 19 files changed, 1379 insertions(+) create mode 100644 proofs/isabelle/.gitignore create mode 100644 proofs/isabelle/MLKEM_Functional_Correctness.thy create mode 100644 proofs/isabelle/MLKEM_Machine_Model.thy create mode 100644 proofs/isabelle/MLKEM_Machine_Model_Instance.thy create mode 100644 proofs/isabelle/MLKEM_Poly_Definitions.thy create mode 100644 proofs/isabelle/MLKEM_Verify_Definitions.thy create mode 100644 proofs/isabelle/Makefile create mode 100644 proofs/isabelle/README.md create mode 100644 proofs/isabelle/ROOT create mode 100644 proofs/isabelle/document/root.tex create mode 100644 proofs/isabelle/generated/.gitignore create mode 100644 proofs/isabelle/generated/.gitkeep create mode 100644 proofs/isabelle/pipeline/README.md create mode 100644 proofs/isabelle/pipeline/config/proof_config.h create mode 100644 proofs/isabelle/pipeline/config/stdint.h create mode 100755 proofs/isabelle/pipeline/generate.py create mode 100644 proofs/isabelle/pipeline/templates/definitions.thy.tpl create mode 100644 proofs/isabelle/pipeline/templates/functional_correctness.thy.tpl create mode 100644 proofs/isabelle/pipeline/units.json diff --git a/proofs/isabelle/.gitignore b/proofs/isabelle/.gitignore new file mode 100644 index 0000000000..143592f88b --- /dev/null +++ b/proofs/isabelle/.gitignore @@ -0,0 +1,4 @@ +# Isabelle/jEdit backup files +*~ +# Isabelle build output +output/ diff --git a/proofs/isabelle/MLKEM_Functional_Correctness.thy b/proofs/isabelle/MLKEM_Functional_Correctness.thy new file mode 100644 index 0000000000..2f840fc633 --- /dev/null +++ b/proofs/isabelle/MLKEM_Functional_Correctness.thy @@ -0,0 +1,15 @@ +theory MLKEM_Functional_Correctness + imports + MLKEM_Machine_Model_Instance + MLKEM_FC_InvNTT + MLKEM_FC_PolyLoop + MLKEM_NTT_Correctness +begin + +text \ + Top-level entry point for \<^verbatim>\mlkem-native\ functional-correctness proofs. + + Intentionally empty, barring imports. +\ + +end diff --git a/proofs/isabelle/MLKEM_Machine_Model.thy b/proofs/isabelle/MLKEM_Machine_Model.thy new file mode 100644 index 0000000000..92271a237f --- /dev/null +++ b/proofs/isabelle/MLKEM_Machine_Model.thy @@ -0,0 +1,58 @@ +theory MLKEM_Machine_Model + imports "Micro_C_Examples.Simple_C_Functions" +begin + +(*<*) +text \ + Machine-model assumptions used by generated \<^verbatim>\mlkem-native\ C definitions. + Types are extracted first so locale assumptions can depend on the generated + C record/datatype names. +\ + +micro_c_file manifest: "generated/manifests/poly.types.manifest" "generated/c/poly.pre.c" + +locale c_mlk_machine_model = + c_pointer_model c_ptr_add c_ptr_shift_signed c_ptr_diff + c_ptr_less c_ptr_le c_ptr_greater c_ptr_ge + c_ptr_to_uintptr c_uintptr_to_ptr + + reference reference_types + + ref_c_mlk_poly: reference_allocatable reference_types _ _ _ _ _ _ _ c_mlk_poly_prism + + ref_c_mlk_poly_mulcache: reference_allocatable reference_types _ _ _ _ _ _ _ c_mlk_poly_mulcache_prism + + ref_c_uint: reference_allocatable reference_types _ _ _ _ _ _ _ c_uint_prism + + ref_c_int: reference_allocatable reference_types _ _ _ _ _ _ _ c_int_prism + + ref_c_short: reference_allocatable reference_types _ _ _ _ _ _ _ c_short_prism + + ref_c_ushort: reference_allocatable reference_types _ _ _ _ _ _ _ c_ushort_prism + + ref_c_short_list: reference_allocatable reference_types _ _ _ _ _ _ _ c_short_list_prism + for c_ptr_add :: \('addr, 8 word list) gref \ nat \ nat \ ('addr, 8 word list) gref\ + and c_ptr_shift_signed :: \('addr, 8 word list) gref \ int \ nat \ ('addr, 8 word list) gref\ + and c_ptr_diff :: \('addr, 8 word list) gref \ ('addr, 8 word list) gref \ nat \ int\ + and c_ptr_less :: \('addr, 8 word list) gref \ ('addr, 8 word list) gref \ bool\ + and c_ptr_le :: \('addr, 8 word list) gref \ ('addr, 8 word list) gref \ bool\ + and c_ptr_greater :: \('addr, 8 word list) gref \ ('addr, 8 word list) gref \ bool\ + and c_ptr_ge :: \('addr, 8 word list) gref \ ('addr, 8 word list) gref \ bool\ + and c_ptr_to_uintptr :: \('addr, 8 word list) gref \ int\ + and c_uintptr_to_ptr :: \int \ ('addr, 8 word list) gref\ + and reference_types :: \'s::{sepalg} \ 'addr \ 8 word list \ c_abort \ 'i prompt \ + 'o prompt_output \ unit\ + and c_mlk_poly_prism :: \(8 word list, c_mlk_poly) prism\ + and c_mlk_poly_mulcache_prism :: \(8 word list, c_mlk_poly_mulcache) prism\ + and c_uint_prism :: \(8 word list, c_uint) prism\ + and c_int_prism :: \(8 word list, c_int) prism\ + and c_short_prism :: \(8 word list, c_short) prism\ + and c_ushort_prism :: \(8 word list, c_ushort) prism\ + and c_short_list_prism :: \(8 word list, c_short list) prism\ +begin + +adhoc_overloading store_reference_const \ ref_c_mlk_poly.new +adhoc_overloading store_reference_const \ ref_c_mlk_poly_mulcache.new +adhoc_overloading store_reference_const \ ref_c_uint.new +adhoc_overloading store_reference_const \ ref_c_int.new +adhoc_overloading store_reference_const \ ref_c_short.new +adhoc_overloading store_reference_const \ ref_c_ushort.new +adhoc_overloading store_reference_const \ ref_c_short_list.new +adhoc_overloading store_update_const \ update_fun + +end +(*>*) + +end diff --git a/proofs/isabelle/MLKEM_Machine_Model_Instance.thy b/proofs/isabelle/MLKEM_Machine_Model_Instance.thy new file mode 100644 index 0000000000..fee7adf10c --- /dev/null +++ b/proofs/isabelle/MLKEM_Machine_Model_Instance.thy @@ -0,0 +1,230 @@ +theory MLKEM_Machine_Model_Instance + imports + MLKEM_Machine_Model + "Micro_Rust_Runtime.Runtime_Heap" + "Shallow_Micro_C.C_Byte_Encoding" +begin + +section \Consistency model for @{locale c_mlk_machine_model}\ + +subsection \Default instance for byte lists\ + +text \ + The heap model @{type mem} requires @{class default} on the stored value type. + We instantiate it for lists with the empty list as default value. +\ + +instantiation list :: (type) default +begin + +definition default_list :: \'a list\ where + \default_list = []\ + +instance .. + +end + +text \ + We provide a concrete interpretation of the @{locale c_mlk_machine_model} locale + using the AutoCorrode heap model from @{theory Micro_Rust_Runtime.Runtime_Heap}, + together with byte-level prisms for all C types. + The successful processing of this theory proves that the locale assumptions + are consistent (satisfiable). +\ + +subsection \Byte prism for @{type c_short} lists\ + +fun decode_c_short_list :: \8 word list \ c_short list option\ where + \decode_c_short_list [] = Some []\ +| \decode_c_short_list [_] = None\ +| \decode_c_short_list (a # b # rest) = + Option.bind (prism_project c_short_byte_prism [a, b]) + (\c. map_option ((#) c) (decode_c_short_list rest))\ + +definition c_short_list_byte_prism :: \(8 word list, c_short list) prism\ where + \c_short_list_byte_prism \ make_prism + (\cs. concat (List.map (prism_embed c_short_byte_prism) cs)) + decode_c_short_list\ + +text \ + Validity of @{const c_short_list_byte_prism}: Each @{type c_short} is encoded as exactly + 2 bytes via @{const c_short_byte_prism}. Decoding splits the byte list into 2-byte chunks + and decodes each chunk. +\ + +lemma c_short_byte_prism_embed_length: + shows \length (prism_embed c_short_byte_prism c) = 2\ +unfolding c_short_byte_prism_def prism_compose_def word_sword_iso_prism_def iso_prism_def + word16_byte_list_prism_le_def list_fixlen_prism_def by (simp add: list_fixlen_embed_def) + +lemma decode_encode_c_short_list: + shows \decode_c_short_list (concat (List.map (prism_embed c_short_byte_prism) cs)) = Some cs\ +proof (induction cs) + case Nil + then show ?case by simp +next + case (Cons c cs) + obtain a b where ab: \prism_embed c_short_byte_prism c = [a, b]\ + using c_short_byte_prism_embed_length[of c] by (metis (no_types, opaque_lifting) length_0_conv + length_Cons length_tl list.sel(3) numeral_2_eq_2 Suc_length_conv) + have proj: \prism_project c_short_byte_prism [a, b] = Some c\ + using is_valid_prism_def[of c_short_byte_prism] c_byte_prism_valid(4) ab by metis + show ?case + using Cons.IH by (simp add: ab proj) +qed + +lemma encode_decode_c_short_list: + shows \decode_c_short_list bs = Some cs \ + bs = concat (List.map (prism_embed c_short_byte_prism) cs)\ +proof (induction bs arbitrary: cs rule: decode_c_short_list.induct) + case 1 + then show ?case by simp +next + case (2 v) + then show ?case by simp +next + case (3 a b rest) + from 3(2) obtain c where proj: \prism_project c_short_byte_prism [a, b] = Some c\ and + mc: \map_option ((#) c) (decode_c_short_list rest) = Some cs\ + by (cases \prism_project c_short_byte_prism [a, b]\) auto + from mc obtain cs' where rest_decode: \decode_c_short_list rest = Some cs'\ and + cs_eq: \cs = c # cs'\ + by (cases \decode_c_short_list rest\) auto + have embed: \[a, b] = prism_embed c_short_byte_prism c\ + using is_valid_prism_def[of c_short_byte_prism] c_byte_prism_valid(4) proj by metis + have \rest = concat (List.map (prism_embed c_short_byte_prism) cs')\ + using 3(1)[OF proj rest_decode] . + then show ?case + by (simp add: cs_eq embed) +qed + +lemma c_short_list_byte_prism_valid: + shows \is_valid_prism c_short_list_byte_prism\ +unfolding is_valid_prism_def c_short_list_byte_prism_def by (auto simp: decode_encode_c_short_list + dest: encode_decode_c_short_list) + +subsection \Struct iso prisms\ + +definition c_mlk_poly_struct_prism :: \(c_short list, c_mlk_poly) prism\ where + \c_mlk_poly_struct_prism \ iso_prism make_c_mlk_poly c_mlk_poly_coeffs\ + +definition c_mlk_poly_mulcache_struct_prism :: \(c_short list, c_mlk_poly_mulcache) prism\ where + \c_mlk_poly_mulcache_struct_prism \ iso_prism make_c_mlk_poly_mulcache c_mlk_poly_mulcache_coeffs\ + +lemma c_mlk_poly_struct_prism_valid: + shows \is_valid_prism c_mlk_poly_struct_prism\ +unfolding c_mlk_poly_struct_prism_def by (rule iso_prism_valid) auto + +lemma c_mlk_poly_mulcache_struct_prism_valid: + shows \is_valid_prism c_mlk_poly_mulcache_struct_prism\ +unfolding c_mlk_poly_mulcache_struct_prism_def by (rule iso_prism_valid) auto + +subsection \Composed byte prisms for struct types\ + +definition c_mlk_poly_byte_prism :: \(8 word list, c_mlk_poly) prism\ where + \c_mlk_poly_byte_prism \ prism_compose c_short_list_byte_prism c_mlk_poly_struct_prism\ + +definition c_mlk_poly_mulcache_byte_prism :: \(8 word list, c_mlk_poly_mulcache) prism\ where + \c_mlk_poly_mulcache_byte_prism \ + prism_compose c_short_list_byte_prism c_mlk_poly_mulcache_struct_prism\ + +lemma c_mlk_poly_byte_prism_valid: + shows \is_valid_prism c_mlk_poly_byte_prism\ +unfolding c_mlk_poly_byte_prism_def by (intro prism_compose_valid c_short_list_byte_prism_valid + c_mlk_poly_struct_prism_valid) + +lemma c_mlk_poly_mulcache_byte_prism_valid: + shows \is_valid_prism c_mlk_poly_mulcache_byte_prism\ +unfolding c_mlk_poly_mulcache_byte_prism_def by (intro prism_compose_valid + c_short_list_byte_prism_valid c_mlk_poly_mulcache_struct_prism_valid) + +subsection \Global interpretation of @{locale c_mlk_machine_model}\ + +text \ + The parameter order follows Isabelle's locale convention: implicit parameters + from parent locales (here: @{locale reference}) come first, then the explicit + parameters from the \for\ clause. + We provide trivial dummy implementations for @{locale c_pointer_model} parameters. +\ + +global_interpretation mlk_instance: c_mlk_machine_model + \ \Implicit @{locale reference} parameters\ + urust_heap_update_raw_fun + urust_heap_dereference_raw_fun + urust_heap_reference_raw_fun + urust_heap_points_to_raw' + \\_. UNIV\ UNIV + urust_heap_can_alloc_reference + \ \@{locale c_pointer_model} for-clause parameters (dummy implementations)\ + \\p _ _. p\ + \\p _ _. p\ + \\_ _ _. (0::int)\ + \\_ _. False\ + \\_ _. True\ + \\_ _. True\ + \\_ _. True\ + \\_. (0::int)\ + \\_. undefined\ + \ \@{locale reference} type-constraining parameter\ + \\_ _ _ _ _ _. ()\ + \ \Prism parameters for @{locale reference_allocatable} instances\ + c_mlk_poly_byte_prism + c_mlk_poly_mulcache_byte_prism + c_uint_byte_prism + c_int_byte_prism + c_short_byte_prism + c_ushort_byte_prism + c_short_list_byte_prism + rewrites \reference_defs.dereference_fun urust_heap_dereference_raw_fun = + urust_heap_dereference_fun\ + and \reference_defs.update_fun urust_heap_update_raw_fun urust_heap_dereference_raw_fun = + urust_heap_update_fun\ + and \reference_defs.reference_fun urust_heap_reference_raw_fun = + urust_heap_reference_fun\ +proof - + show \c_mlk_machine_model + urust_heap_update_raw_fun urust_heap_dereference_raw_fun + urust_heap_reference_raw_fun urust_heap_points_to_raw' + (\_. UNIV) UNIV urust_heap_can_alloc_reference + (\p _ _. p) (\_ _ _. (0::int)) (\_ _. False) (\_ _. True) (\_ _. True) + c_mlk_poly_byte_prism c_mlk_poly_mulcache_byte_prism + c_uint_byte_prism c_int_byte_prism c_short_byte_prism + c_ushort_byte_prism c_short_list_byte_prism\ + proof + \ \Prism validity for each @{locale reference_allocatable} instance\ + show \is_valid_prism c_mlk_poly_byte_prism\ + by (rule c_mlk_poly_byte_prism_valid) + show \is_valid_prism c_mlk_poly_mulcache_byte_prism\ + by (rule c_mlk_poly_mulcache_byte_prism_valid) + show \is_valid_prism c_uint_byte_prism\ + by (rule c_byte_prism_valid(5)) + show \is_valid_prism c_int_byte_prism\ + by (rule c_byte_prism_valid(6)) + show \is_valid_prism c_short_byte_prism\ + by (rule c_byte_prism_valid(4)) + show \is_valid_prism c_ushort_byte_prism\ + by (rule c_byte_prism_valid(3)) + show \is_valid_prism c_short_list_byte_prism\ + by (rule c_short_list_byte_prism_valid) + \ \Allocatability: @{term \prism_dom p \ UNIV\} is trivially true for each prism\ + show \References.can_create_gref_for_prism c_mlk_poly_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + show \References.can_create_gref_for_prism c_mlk_poly_mulcache_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + show \References.can_create_gref_for_prism c_uint_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + show \References.can_create_gref_for_prism c_int_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + show \References.can_create_gref_for_prism c_short_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + show \References.can_create_gref_for_prism c_ushort_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + show \References.can_create_gref_for_prism c_short_list_byte_prism\ + by (simp add: References.can_create_gref_for_prism_def) + \ \@{locale c_pointer_model} axioms: trivially satisfied by dummy implementations\ + qed auto +qed (auto simp: urust_heap_dereference_fun_def urust_heap_update_fun_def + urust_heap_reference_fun_def) + +end + diff --git a/proofs/isabelle/MLKEM_Poly_Definitions.thy b/proofs/isabelle/MLKEM_Poly_Definitions.thy new file mode 100644 index 0000000000..be244b1c9b --- /dev/null +++ b/proofs/isabelle/MLKEM_Poly_Definitions.thy @@ -0,0 +1,22 @@ +theory MLKEM_Poly_Definitions + imports + MLKEM_Machine_Model +begin + +(*<*) +text \ + Auto-generated by @{file "pipeline/generate.py"} from preprocessed C. + Source unit: @{file "../../mlkem/src/poly.c"} + Preprocessed unit: @{file "generated/c/poly.pre.c"} +\ + +context c_mlk_machine_model +begin + +micro_c_file addr: 'addr gv: "8 word list" manifest: "generated/manifests/poly.functions.manifest" "generated/c/poly.pre.c" + +end +(*>*) + + +end diff --git a/proofs/isabelle/MLKEM_Verify_Definitions.thy b/proofs/isabelle/MLKEM_Verify_Definitions.thy new file mode 100644 index 0000000000..7e38795f17 --- /dev/null +++ b/proofs/isabelle/MLKEM_Verify_Definitions.thy @@ -0,0 +1,16 @@ +theory MLKEM_Verify_Definitions + imports + MLKEM_Machine_Model +begin + +text \ + Auto-generated by @{file "pipeline/generate.py"} from preprocessed C. + Source unit: @{file "../../mlkem/src/verify.c"} + Preprocessed unit: @{file "generated/c/verify.pre.c"} +\ + +micro_c_file "generated/c/verify.pre.c" + + + +end diff --git a/proofs/isabelle/Makefile b/proofs/isabelle/Makefile new file mode 100644 index 0000000000..f7495eb390 --- /dev/null +++ b/proofs/isabelle/Makefile @@ -0,0 +1,136 @@ +# Copyright (c) The mlkem-native project authors +# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT + +.DEFAULT_GOAL := jedit +.PHONY: check-autocorrode check-afp setup-afp register-afp-components build pdf jedit \ + pipeline pipeline-flags pipeline-preprocess pipeline-theory pipeline-correctness + +# Set this to the directory containing the Isabelle binary. +ISABELLE_HOME ?= /Applications/Isabelle2025-2.app/bin +# Root of the mlkem-native repository. +REPO_ROOT ?= $(abspath ../..) +# Set this to an AutoCorrode checkout. +AUTOCORRODE_DIR ?= $(REPO_ROOT)/AutoCorrode +# Set this to an AFP checkout containing Word_Lib and Isabelle_C +# (either flat or a full mirror-afp clone with thys/ subdirectory). +AFP_COMPONENT_BASE ?= $(AUTOCORRODE_DIR)/dependencies/afp +# Main session provided by this directory. +ISABELLE_SESSION ?= MLKEM_Native_AutoCorrode + +HOST := $(shell uname -s) +ifeq ($(HOST),Darwin) + AVAILABLE_CORES ?= $(shell sysctl -n hw.physicalcpu) +else ifeq ($(HOST),Linux) + AVAILABLE_CORES ?= $(shell nproc) +else + $(error Unsupported host platform) +endif + +# Build one session job using all available cores. +ISABELLE_BUILD_FLAGS ?= -b -j 1 -o "threads=$(AVAILABLE_CORES)" -v +ISABELLE_JEDIT_FLAGS ?= + +ISABELLE_BUILD_FLAGS += $(ISABELLE_REMOTE) +ISABELLE_JEDIT_FLAGS += $(ISABELLE_REMOTE) + +# Pipeline controls (proof translation units -> preprocessed C -> generated theories). +PIPELINE_SCRIPT ?= $(CURDIR)/pipeline/generate.py +PIPELINE_UNIT ?= poly +PIPELINE_PARAMETER_SET ?= 512 +PIPELINE_OPT ?= 0 +PIPELINE_AUTO ?= 0 + +# Resolve AFP entry directory at shell level (handles ~ expansion and thys/ layout). +# Usage: $(call afp-entry-dir) expands AFP_COMPONENT_BASE in shell. +AFP_RESOLVE = base="$$(eval echo '$(AFP_COMPONENT_BASE)')"; \ + if [ -d "$$base/Word_Lib" ]; then echo "$$base"; \ + elif [ -d "$$base/thys/Word_Lib" ]; then echo "$$base/thys"; \ + else echo "$$base"; fi + +ISABELLE_DIRS = \ + -d . \ + -d $(AUTOCORRODE_DIR) \ + -d $$($(AFP_RESOLVE))/Word_Lib \ + -d $$($(AFP_RESOLVE))/Isabelle_C + +check-autocorrode: + @if [ ! -f "$(AUTOCORRODE_DIR)/ROOT" ]; then \ + echo "Missing AutoCorrode checkout at $(AUTOCORRODE_DIR)."; \ + echo "Set AUTOCORRODE_DIR to your AutoCorrode checkout."; \ + exit 1; \ + fi + +check-afp: check-autocorrode + @afp_dir=$$($(AFP_RESOLVE)); \ + if [ ! -d "$$afp_dir/Word_Lib" ] || [ ! -d "$$afp_dir/Isabelle_C" ]; then \ + echo "Missing AFP entries at $(AFP_COMPONENT_BASE)."; \ + echo "Run 'make setup-afp' or set AFP_COMPONENT_BASE to an AFP checkout."; \ + exit 1; \ + fi + +setup-afp: + @afp_dir=$$($(AFP_RESOLVE)); \ + if [ -d "$$afp_dir/Word_Lib" ] && [ -d "$$afp_dir/Isabelle_C" ]; then \ + echo "AFP already available at $$afp_dir."; \ + else \ + base="$$(eval echo '$(AFP_COMPONENT_BASE)')"; \ + if [ -d "$$base" ] && [ ! -d "$$base/thys" ] && \ + [ "$$(ls -A "$$base" 2>/dev/null)" != "" ]; then \ + echo "Directory $$base exists but does not contain Word_Lib or Isabelle_C."; \ + echo "Set AFP_COMPONENT_BASE to a valid AFP checkout."; \ + exit 1; \ + else \ + git clone --depth 1 https://github.com/isabelle-prover/mirror-afp-2025-1 "$$base"; \ + fi; \ + fi + +register-afp-components: check-afp + @afp_dir=$$($(AFP_RESOLVE)); \ + $(ISABELLE_HOME)/isabelle components -u "$$afp_dir/Word_Lib"; \ + $(ISABELLE_HOME)/isabelle components -u "$$afp_dir/Isabelle_C" + +build: register-afp-components + @afp_dir=$$($(AFP_RESOLVE)); \ + $(ISABELLE_HOME)/isabelle build $(ISABELLE_BUILD_FLAGS) \ + -d . -d $(AUTOCORRODE_DIR) -d "$$afp_dir/Word_Lib" -d "$$afp_dir/Isabelle_C" \ + $(ISABELLE_SESSION) + +pdf: register-afp-components + @afp_dir=$$($(AFP_RESOLVE)); \ + $(ISABELLE_HOME)/isabelle build $(ISABELLE_BUILD_FLAGS) \ + -o document=pdf -o document_build=pdflatex \ + -o "document_output=$(CURDIR)/output" \ + -d . -d $(AUTOCORRODE_DIR) -d "$$afp_dir/Word_Lib" -d "$$afp_dir/Isabelle_C" \ + $(ISABELLE_SESSION); \ + echo "PDF written to $(CURDIR)/output/" + +jedit: register-afp-components + @afp_dir=$$($(AFP_RESOLVE)); \ + $(ISABELLE_HOME)/isabelle jedit $(ISABELLE_JEDIT_FLAGS) -l HOL \ + -d . -d $(AUTOCORRODE_DIR) -d "$$afp_dir/Word_Lib" -d "$$afp_dir/Isabelle_C" \ + ./MLKEM_Functional_Correctness.thy & + +pipeline-flags: + python3 $(PIPELINE_SCRIPT) --repo-root $(REPO_ROOT) --proof-root $(CURDIR) \ + --unit $(PIPELINE_UNIT) --parameter-set $(PIPELINE_PARAMETER_SET) \ + --opt $(PIPELINE_OPT) --auto $(PIPELINE_AUTO) --action flags + +pipeline-preprocess: + python3 $(PIPELINE_SCRIPT) --repo-root $(REPO_ROOT) --proof-root $(CURDIR) \ + --unit $(PIPELINE_UNIT) --parameter-set $(PIPELINE_PARAMETER_SET) \ + --opt $(PIPELINE_OPT) --auto $(PIPELINE_AUTO) --action preprocess + +pipeline-theory: + python3 $(PIPELINE_SCRIPT) --repo-root $(REPO_ROOT) --proof-root $(CURDIR) \ + --unit $(PIPELINE_UNIT) --parameter-set $(PIPELINE_PARAMETER_SET) \ + --opt $(PIPELINE_OPT) --auto $(PIPELINE_AUTO) --action theory + +pipeline-correctness: + python3 $(PIPELINE_SCRIPT) --repo-root $(REPO_ROOT) --proof-root $(CURDIR) \ + --unit $(PIPELINE_UNIT) --parameter-set $(PIPELINE_PARAMETER_SET) \ + --opt $(PIPELINE_OPT) --auto $(PIPELINE_AUTO) --action correctness + +pipeline: + python3 $(PIPELINE_SCRIPT) --repo-root $(REPO_ROOT) --proof-root $(CURDIR) \ + --unit $(PIPELINE_UNIT) --parameter-set $(PIPELINE_PARAMETER_SET) \ + --opt $(PIPELINE_OPT) --auto $(PIPELINE_AUTO) --action all diff --git a/proofs/isabelle/README.md b/proofs/isabelle/README.md new file mode 100644 index 0000000000..389611b60f --- /dev/null +++ b/proofs/isabelle/README.md @@ -0,0 +1,47 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +# Isabelle proofs + +This directory contains Isabelle/HOL proofs for `mlkem-native`, built on top of AutoCorrode. + +## Prerequisites + +- Isabelle2025-2 (`isabelle` binary available via `ISABELLE_HOME`) +- An AutoCorrode checkout (set `AUTOCORRODE_DIR`) +- AFP mirror checkout containing `Word_Lib` and `Isabelle_C` (set `AFP_COMPONENT_BASE`) + +Default Makefile assumptions: + +- `AUTOCORRODE_DIR=../../AutoCorrode` +- `AFP_COMPONENT_BASE=$(AUTOCORRODE_DIR)/dependencies/afp` + +## Usage + +From the repository root: + +```bash +make -C proofs/isabelle pipeline +make -C proofs/isabelle build +make -C proofs/isabelle jedit +``` + +## Pipeline overview + +Translation units are preprocessed from `mlkem/src/*.c` using compile flags discovered from the root Makefile. +Generated artifacts are written under `generated/`. + +Useful controls: + +- `PIPELINE_UNIT` (default `poly`) +- `PIPELINE_PARAMETER_SET` (default `512`) +- `PIPELINE_OPT` (default `0`) +- `PIPELINE_AUTO` (default `0`) + +## Current theory split + +- `Common.thy`: shared abstractions and refinement lemmas. +- `MLKEM_Machine_Model.thy`: shared machine-model locale; imports type-only extraction (`poly.types.manifest`) so generated C record types are available for reference assumptions. +- `MLKEM_Poly_Definitions.thy`: auto-generated function definitions from `poly.c`, produced inside `context c_mlk_machine_model`. +- `MLKEM_Poly_Functional_Correctness.thy`: contracts and WP proofs for generated `poly.c` definitions. +- `MLKEM_Verify_Definitions.thy`: extracted translation of `verify.c` (proofs pending). +- `MLKEM_Functional_Correctness.thy`: top-level aggregation theory. diff --git a/proofs/isabelle/ROOT b/proofs/isabelle/ROOT new file mode 100644 index 0000000000..fe7d06030d --- /dev/null +++ b/proofs/isabelle/ROOT @@ -0,0 +1,28 @@ +session MLKEM_Native_AutoCorrode = HOL + + options [document = pdf, document_output = "output"] + sessions + "HOL-Library" + Micro_C_Examples + Micro_Rust_Runtime + Byte_Level_Encoding + theories + MLKEM_Spec + MLKEM_Refinement + MLKEM_Zetas + MLKEM_NTT_Spec + MLKEM_InvNTT_Spec + MLKEM_NTT_Correctness + theories [document = false] + MLKEM_Machine_Model + MLKEM_Machine_Model_Instance + MLKEM_Poly_Definitions + theories + MLKEM_FC_Scalar + MLKEM_FC_Montgomery + MLKEM_FC_PolyLoop + MLKEM_FC_NTT + MLKEM_FC_InvNTT + theories [document = false] + MLKEM_Functional_Correctness + document_files + "root.tex" diff --git a/proofs/isabelle/document/root.tex b/proofs/isabelle/document/root.tex new file mode 100644 index 0000000000..0a5a4c972d --- /dev/null +++ b/proofs/isabelle/document/root.tex @@ -0,0 +1,70 @@ +\documentclass[10pt,a4paper]{article} + +\usepackage{microtype} +\usepackage{a4wide} +\usepackage{tgpagella} +\usepackage[euler-digits]{eulervm} +\usepackage{mdwlist} +\usepackage[T1]{fontenc} +\usepackage{isabelle,isabellesym} + +\usepackage{amssymb} +\usepackage{amsmath} +\usepackage{eurosym} +\usepackage[only,bigsqcap,bigparallel,fatsemi,interleave,sslash]{stmaryrd} +\usepackage{eufrak} +\usepackage{textcomp} +\usepackage{wasysym} + +\usepackage{tikz} +\usetikzlibrary{arrows.meta,positioning,calc} + +\usepackage{xcolor} +\definecolor{mlkblue}{HTML}{2563EB} +\definecolor{mlklightblue}{HTML}{DBEAFE} + +% this should be the last package used +\usepackage{pdfsetup} + +% urls in roman style, theory text in math-similar italics +\urlstyle{rm} +\isabellestyle{it} + +% for uniform font size +\renewcommand{\isastyle}{\isastyleminor} + +% suppress proof bodies, keep theorem statements +\isadroptag{proof} + +\renewcommand{\isasymlonglonglongrightarrow}{\longlonglongrightarrow} + +\begin{document} + +\title{ML-KEM Functional Correctness in Isabelle/HOL} +\author{mlkem-native contributors} +\maketitle + +\begin{abstract} +Machine-checked functional correctness proofs for the polynomial +arithmetic routines in \texttt{mlkem-native}, an ML-KEM (FIPS~203) +implementation in portable C\@. The C code is automatically translated +to Isabelle/HOL via \emph{AutoCorrode}, and each function is verified +against an abstract integer-level specification. The Number Theoretic +Transform (NTT) and its inverse are further shown to be two-sided +inverses modulo~$q$. +\end{abstract} + +\tableofcontents + +% sane default for proof documents +\parindent 0pt\parskip 0.5ex + +% generated text of all theories +\input{session} + +\end{document} + +%%% Local Variables: +%%% mode: latex +%%% TeX-master: t +%%% End: diff --git a/proofs/isabelle/generated/.gitignore b/proofs/isabelle/generated/.gitignore new file mode 100644 index 0000000000..c3d07e57fe --- /dev/null +++ b/proofs/isabelle/generated/.gitignore @@ -0,0 +1,3 @@ +* +!.gitkeep +!.gitignore diff --git a/proofs/isabelle/generated/.gitkeep b/proofs/isabelle/generated/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/proofs/isabelle/pipeline/README.md b/proofs/isabelle/pipeline/README.md new file mode 100644 index 0000000000..869069817c --- /dev/null +++ b/proofs/isabelle/pipeline/README.md @@ -0,0 +1,61 @@ +[//]: # (SPDX-License-Identifier: CC-BY-4.0) + +# Isabelle pipeline (C -> preprocessed C -> theory) + +This directory contains the source-driven generation pipeline for `proofs/isabelle`. + +## Configured units + +- `mlkem/src/poly.c` -> `generated/c/poly.pre.c` -> `MLKEM_Poly_Definitions.thy` +- `mlkem/src/verify.c` -> `generated/c/verify.pre.c` -> `MLKEM_Verify_Definitions.thy` + +## Driver + +Use `pipeline/generate.py`. +It queries the root `Makefile` for compile flags, extracts preprocessor-relevant flags, +and runs preprocessing. + +From `proofs/isabelle`: + +```bash +make pipeline +make pipeline-flags +make pipeline-correctness +``` + +To regenerate a single unit: + +```bash +make pipeline PIPELINE_UNIT=poly +make pipeline PIPELINE_UNIT=verify +``` + +## Filtering layers + +Filtering is split intentionally: + +- `extract` (in `units.json`): script-level trimming of preprocessed C before Isabelle parsing. +- `manifest` (in `units.json`): declarative function/type filtering passed to `micro_c_file`. +- `type_manifest` (optional, in `units.json`): emits a type-only manifest used for two-phase translation. + +For `poly`, the flow is: + +1. `MLKEM_Machine_Model.thy` runs `micro_c_file` with `generated/manifests/poly.types.manifest` + to declare required C-derived types (for example `c_mlk_poly`). +2. `MLKEM_Poly_Definitions.thy` runs `micro_c_file` with `generated/manifests/poly.functions.manifest` + inside `context c_mlk_machine_model` to generate function definitions. + +## Manifest format + +Manifest files are plain text with optional `functions:` and `types:` sections. +Entries can be bare names or `-`-prefixed names: + +```text +functions: +- mlk_barrett_reduce +- mlk_poly_add + +types: +- mlk_poly +- int16_t +``` diff --git a/proofs/isabelle/pipeline/config/proof_config.h b/proofs/isabelle/pipeline/config/proof_config.h new file mode 100644 index 0000000000..7310796c62 --- /dev/null +++ b/proofs/isabelle/pipeline/config/proof_config.h @@ -0,0 +1,62 @@ +#ifndef MLK_PROOF_CONFIG_H +#define MLK_PROOF_CONFIG_H + +/* Proof-oriented configuration for Isabelle translation. */ +/* Keep symbol names compact and stable for theorem references. */ +#define MLK_CONFIG_NAMESPACE_PREFIX mlk + +/* Avoid accidental native backend assumptions in translated C semantics. */ +#define MLK_CONFIG_NO_ASM + +/* Prevent libc string header pull-in from common.h. */ +#define MLK_CONFIG_CUSTOM_MEMCPY +#define MLK_CONFIG_CUSTOM_MEMSET + +/* Provide simple stand-ins used only during translation extraction. */ +#if !defined(__ASSEMBLER__) +typedef unsigned long size_t; + +static inline void *mlk_memcpy(void *dst, const void *src, size_t n) +{ + unsigned char *d = (unsigned char *)dst; + const unsigned char *s = (const unsigned char *)src; + size_t i; + for (i = 0; i < n; i++) + { + d[i] = s[i]; + } + return dst; +} + +static inline void *mlk_memset(void *dst, int c, size_t n) +{ + unsigned char *d = (unsigned char *)dst; + size_t i; + for (i = 0; i < n; i++) + { + d[i] = (unsigned char)c; + } + return dst; +} +#endif + +/* Simplify constant-time opt-blocker globals to constant 0 for proof builds. + * This makes value barriers reduce to identity (b ^ 0 = b), which is + * functionally correct: barriers only prevent compiler optimizations. + * Variadic macros accept the (void) parameter in function declarations. */ +#define mlk_ct_get_optblocker_u32(...) ((uint32_t)0) +#define mlk_ct_get_optblocker_i32(...) ((int32_t)0) +#define mlk_ct_get_optblocker_u8(...) ((uint8_t)0) +#define mlk_ct_get_optblocker_u64(...) ((uint64_t)0) + +/* Provide a minimal zeroization implementation for preprocessing-only builds. */ +#define MLK_CONFIG_CUSTOM_ZEROIZE +#if !defined(__ASSEMBLER__) +static inline void mlk_zeroize(void *ptr, size_t len) +{ + (void)ptr; + (void)len; +} +#endif + +#endif diff --git a/proofs/isabelle/pipeline/config/stdint.h b/proofs/isabelle/pipeline/config/stdint.h new file mode 100644 index 0000000000..910ed024cc --- /dev/null +++ b/proofs/isabelle/pipeline/config/stdint.h @@ -0,0 +1,38 @@ +#ifndef MLK_PROOF_STDINT_H +#define MLK_PROOF_STDINT_H + +typedef signed char int8_t; +typedef short int16_t; +typedef int int32_t; +typedef long long int64_t; + +typedef unsigned char uint8_t; +typedef unsigned short uint16_t; +typedef unsigned int uint32_t; +typedef unsigned long long uint64_t; + +typedef int8_t int_least8_t; +typedef int16_t int_least16_t; +typedef int32_t int_least32_t; +typedef int64_t int_least64_t; +typedef uint8_t uint_least8_t; +typedef uint16_t uint_least16_t; +typedef uint32_t uint_least32_t; +typedef uint64_t uint_least64_t; +typedef int8_t int_fast8_t; +typedef int16_t int_fast16_t; +typedef int32_t int_fast32_t; +typedef int64_t int_fast64_t; +typedef uint8_t uint_fast8_t; +typedef uint16_t uint_fast16_t; +typedef uint32_t uint_fast32_t; +typedef uint64_t uint_fast64_t; + +typedef long int intptr_t; +typedef unsigned long uintptr_t; +typedef long int intmax_t; +typedef unsigned long uintmax_t; + +#define UINT16_MAX 65535u + +#endif diff --git a/proofs/isabelle/pipeline/generate.py b/proofs/isabelle/pipeline/generate.py new file mode 100755 index 0000000000..0a1b211c9d --- /dev/null +++ b/proofs/isabelle/pipeline/generate.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +"""Generate preprocessed C and corresponding Isabelle theory files for proofs/isabelle.""" + +from __future__ import annotations + +import argparse +import json +import shlex +import subprocess +from pathlib import Path +from typing import Iterable + + +def run(cmd: list[str], cwd: Path | None = None) -> str: + proc = subprocess.run(cmd, cwd=str(cwd) if cwd else None, text=True, capture_output=True) + if proc.returncode != 0: + raise RuntimeError( + f"Command failed ({proc.returncode}): {' '.join(cmd)}\n" + f"stdout:\n{proc.stdout}\n" + f"stderr:\n{proc.stderr}" + ) + return proc.stdout + + +def find_compile_command( + repo_root: Path, + build_dir: Path, + source: str, + parameter_set: int, + opt: int, + auto: int, +) -> list[str]: + object_target = f"{build_dir}/mlkem{parameter_set}/{source}.o" + out = run( + [ + "make", + "-C", + str(repo_root), + "-n", + f"BUILD_DIR={build_dir}", + f"PARAMETER_SET={parameter_set}", + f"OPT={opt}", + f"AUTO={auto}", + object_target, + ] + ) + + compile_line = None + for line in out.splitlines(): + stripped = line.strip() + if " -c " in stripped and " -o " in stripped and stripped.endswith(source): + compile_line = stripped + if compile_line is None: + raise RuntimeError( + "Could not locate compile command in make -n output for " + f"source {source}." + ) + return shlex.split(compile_line) + + +def extract_cpp_flags(tokens: list[str], source: str) -> list[str]: + keep = [] + i = 1 + while i < len(tokens): + tok = tokens[i] + + if tok in {"-c", "-S", "-o"}: + if tok == "-o": + i += 2 + else: + i += 1 + continue + + if tok == source or tok.endswith(f"/{source}"): + i += 1 + continue + + if tok in {"-D", "-U", "-I", "-include", "-isystem", "-std"}: + if i + 1 >= len(tokens): + raise RuntimeError(f"Missing argument after {tok}") + keep.extend([tok, tokens[i + 1]]) + i += 2 + continue + + if tok.startswith(("-D", "-U", "-I", "-include", "-isystem", "-std=")): + keep.append(tok) + i += 1 + continue + + if tok in {"-nostdinc", "-nostdinc++"}: + keep.append(tok) + i += 1 + continue + + i += 1 + + return keep + + +def render_template(path: Path, replacements: dict[str, str]) -> str: + content = path.read_text() + for key, value in replacements.items(): + content = content.replace("{" + key + "}", value) + return content + + +def quote_isabelle_string(raw: str) -> str: + escaped = raw.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + + +def render_manifest_text(manifest_cfg: dict) -> str: + functions = list(manifest_cfg.get("functions", [])) + types = list(manifest_cfg.get("types", [])) + if not functions and not types: + raise RuntimeError("manifest must contain at least one of: functions, types") + + lines: list[str] = [] + if functions: + lines.append("functions:") + lines.extend(f"- {name}" for name in functions) + lines.append("") + if types: + lines.append("types:") + lines.extend(f"- {name}" for name in types) + lines.append("") + return "\n".join(lines).rstrip() + "\n" + + +def maybe_write_manifest_cfg( + proof_root: Path, unit_name: str, manifest_cfg: dict | None, action: str, label: str +) -> str | None: + if not manifest_cfg: + return None + + output_rel = manifest_cfg.get("output") + if not output_rel: + raise RuntimeError(f"[{unit_name}] {label}.output is required") + + if action in {"preprocess", "theory", "all"}: + manifest_text = render_manifest_text(manifest_cfg) + out_path = proof_root / output_rel + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(manifest_text) + print(f"[{unit_name}] wrote {out_path}") + + return output_rel + + +def mk_micro_c_file_options(unit: dict, manifest_output: str | None) -> str: + parts: list[str] = [] + prefix = unit.get("prefix") + if prefix: + parts.append(f"prefix: {prefix}") + addr_ty = unit.get("addr_ty") + if addr_ty: + parts.append(f"addr: {addr_ty}") + gv_ty = unit.get("gv_ty") + if gv_ty: + parts.append(f"gv: {gv_ty}") + if manifest_output: + parts.append(f"manifest: {quote_isabelle_string(manifest_output)}") + if not parts: + return "" + return " ".join(parts) + " " + + +def mk_micro_c_file_command(unit: dict, c_file: str, manifest_output: str | None) -> str: + """Build a single micro_c_file command string.""" + parts: list[str] = [] + prefix = unit.get("prefix") + if prefix: + parts.append(f"prefix: {prefix}") + addr_ty = unit.get("addr_ty") + if addr_ty: + parts.append(f"addr: {addr_ty}") + gv_ty = unit.get("gv_ty") + if gv_ty: + parts.append(f"gv: {gv_ty}") + if manifest_output: + parts.append(f"manifest: {quote_isabelle_string(manifest_output)}") + opts = " ".join(parts) + if opts: + opts += " " + return f"micro_c_file {opts}{quote_isabelle_string(c_file)}\n" + + +def iter_units(config_units: list[dict], unit_filter: str | None) -> Iterable[dict]: + for unit in config_units: + if unit_filter is None or unit["name"] == unit_filter: + yield unit + + +def _find_function_with_body(source: str, name: str) -> str: + needle = f"{name}(" + pos = 0 + while True: + i = source.find(needle, pos) + if i < 0: + raise RuntimeError(f"Could not find function definition for {name}") + + line_start = source.rfind("\n", 0, i) + line_start = 0 if line_start < 0 else line_start + 1 + + semi = source.find(";", i) + brace = source.find("{", i) + if brace < 0: + raise RuntimeError(f"No function body found for {name}") + if semi >= 0 and semi < brace: + # Prototype/call-site; keep searching. + pos = semi + 1 + continue + + depth = 0 + j = brace + in_str = None + in_line_comment = False + in_block_comment = False + + while j < len(source): + ch = source[j] + nxt = source[j + 1] if j + 1 < len(source) else "" + + if in_line_comment: + if ch == "\n": + in_line_comment = False + j += 1 + continue + + if in_block_comment: + if ch == "*" and nxt == "/": + in_block_comment = False + j += 2 + else: + j += 1 + continue + + if in_str is not None: + if ch == "\\": + j += 2 + continue + if ch == in_str: + in_str = None + j += 1 + continue + + if ch == "/" and nxt == "/": + in_line_comment = True + j += 2 + continue + + if ch == "/" and nxt == "*": + in_block_comment = True + j += 2 + continue + + if ch in {'"', "'"}: + in_str = ch + j += 1 + continue + + if ch == "{": + depth += 1 + elif ch == "}": + depth -= 1 + if depth == 0: + end = j + 1 + return source[line_start:end].strip() + "\n" + j += 1 + + raise RuntimeError(f"Unbalanced braces while parsing {name}") + + +def extract_function_subset(preprocessed: str, extract_cfg: dict) -> str: + functions = extract_cfg.get("functions", []) + if not functions: + raise RuntimeError("extract.mode=function_subset requires a non-empty functions list") + + preamble_lines = extract_cfg.get("preamble", []) + out_parts = [] + if preamble_lines: + out_parts.append("\n".join(preamble_lines).rstrip() + "\n\n") + + for fn in functions: + out_parts.append(_find_function_with_body(preprocessed, fn).rstrip() + "\n\n") + + return "".join(out_parts).rstrip() + "\n" + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--repo-root", required=True, type=Path) + parser.add_argument("--proof-root", required=True, type=Path) + parser.add_argument("--unit", default=None) + parser.add_argument("--parameter-set", type=int, default=None) + parser.add_argument("--opt", type=int, default=0) + parser.add_argument("--auto", type=int, default=0) + parser.add_argument("--action", choices=["flags", "preprocess", "theory", "correctness", "all"], default="all") + args = parser.parse_args() + + repo_root = args.repo_root.resolve() + proof_root = args.proof_root.resolve() + pipeline_dir = proof_root / "pipeline" + units_path = pipeline_dir / "units.json" + cfg = json.loads(units_path.read_text()) + + build_dir = proof_root / "generated" / "make-build" + build_dir.mkdir(parents=True, exist_ok=True) + + template_path = pipeline_dir / "templates" / "definitions.thy.tpl" + correctness_template_path = pipeline_dir / "templates" / "functional_correctness.thy.tpl" + + selected = list(iter_units(cfg["units"], args.unit)) + if not selected: + raise RuntimeError("No units selected. Check --unit or pipeline/units.json") + + for unit in selected: + source = unit["source"] + param_set = args.parameter_set if args.parameter_set is not None else int(unit.get("parameter_set", 512)) + manifest_output = maybe_write_manifest_cfg( + proof_root, unit["name"], unit.get("manifest"), args.action, "manifest" + ) + type_manifest_output = maybe_write_manifest_cfg( + proof_root, unit["name"], unit.get("type_manifest"), args.action, "type_manifest" + ) + + compile_tokens = find_compile_command( + repo_root=repo_root, + build_dir=build_dir, + source=source, + parameter_set=param_set, + opt=args.opt, + auto=args.auto, + ) + compiler = compile_tokens[0] + cpp_flags = extract_cpp_flags(compile_tokens, source) + cpp_overrides = unit.get("cpp_overrides", []) + if cpp_overrides: + cpp_flags = [*cpp_flags, *cpp_overrides] + + if args.action in {"flags", "all"}: + print(f"[{unit['name']}] compiler: {compiler}") + print(f"[{unit['name']}] parameter_set: {param_set}, opt={args.opt}, auto={args.auto}") + print(f"[{unit['name']}] cpp flags: {' '.join(cpp_flags)}") + + pre_out = proof_root / unit["output_c"] + thy_out = proof_root / unit["output_theory"] + + if args.action in {"preprocess", "all"}: + pre_out.parent.mkdir(parents=True, exist_ok=True) + src_abs = repo_root / source + preprocessed = run([compiler, "-E", "-P", *cpp_flags, str(src_abs)], cwd=repo_root) + extract_cfg = unit.get("extract") + banner = ( + "/* Auto-generated by proofs/isabelle/pipeline/generate.py. */\n" + f"/* Source: {source}; PARAMETER_SET={param_set}; OPT={args.opt}; AUTO={args.auto} */\n\n" + ) + if extract_cfg: + mode = extract_cfg.get("mode") + if mode == "function_subset": + preprocessed = extract_function_subset(preprocessed, extract_cfg) + else: + raise RuntimeError(f"Unsupported extract mode: {mode}") + pre_out.write_text(banner + preprocessed) + print(f"[{unit['name']}] wrote {pre_out}") + + if args.action in {"theory", "all"}: + thy_out.parent.mkdir(parents=True, exist_ok=True) + micro_c_file_commands = mk_micro_c_file_command( + unit, unit["output_c"], manifest_output + ) + rendered = render_template( + template_path, + { + "THEORY_NAME": unit["theory_name"], + "SOURCE_FILE": source, + "PREPROCESSED_FILE": unit["output_c"], + "ISABELLE_PRELUDE": ("\n".join(unit.get("isabelle_prelude", [])) + "\n\n") if unit.get("isabelle_prelude") else "", + "ISABELLE_POSTLUDE": ("\n".join(unit.get("isabelle_postlude", [])) + "\n") if unit.get("isabelle_postlude") else "", + "MICRO_C_FILE_COMMANDS": micro_c_file_commands, + "TYPE_MANIFEST_OUTPUT": type_manifest_output or "", + }, + ) + thy_out.write_text(rendered) + print(f"[{unit['name']}] wrote {thy_out}") + + if args.action in {"correctness", "all"}: + out_fc = unit.get("output_correctness_theory") + name_fc = unit.get("correctness_theory_name") + if out_fc and name_fc: + out_fc_path = proof_root / out_fc + out_fc_path.parent.mkdir(parents=True, exist_ok=True) + rendered_fc = render_template( + correctness_template_path, + { + "CORRECTNESS_THEORY_NAME": name_fc, + "DEFINITIONS_THEORY_NAME": unit["theory_name"], + "SOURCE_FILE": source, + }, + ) + out_fc_path.write_text(rendered_fc) + print(f"[{unit['name']}] wrote {out_fc_path}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/proofs/isabelle/pipeline/templates/definitions.thy.tpl b/proofs/isabelle/pipeline/templates/definitions.thy.tpl new file mode 100644 index 0000000000..ff484b7d46 --- /dev/null +++ b/proofs/isabelle/pipeline/templates/definitions.thy.tpl @@ -0,0 +1,15 @@ +theory {THEORY_NAME} + imports + MLKEM_Machine_Model +begin + +text \ + Auto-generated by @{file "pipeline/generate.py"} from preprocessed C. + Source unit: @{file "../../{SOURCE_FILE}"} + Preprocessed unit: @{file "{PREPROCESSED_FILE}"} +\ + +{ISABELLE_PRELUDE}{MICRO_C_FILE_COMMANDS} +{ISABELLE_POSTLUDE} + +end diff --git a/proofs/isabelle/pipeline/templates/functional_correctness.thy.tpl b/proofs/isabelle/pipeline/templates/functional_correctness.thy.tpl new file mode 100644 index 0000000000..9764c53466 --- /dev/null +++ b/proofs/isabelle/pipeline/templates/functional_correctness.thy.tpl @@ -0,0 +1,12 @@ +theory {CORRECTNESS_THEORY_NAME} + imports Common {DEFINITIONS_THEORY_NAME} +begin + +text \ + Auto-generated scaffold for source-driven functional-correctness proofs. + + Source unit: @{file "{SOURCE_FILE}"} + Definitions theory: @{theory {DEFINITIONS_THEORY_NAME}} +\ + +end diff --git a/proofs/isabelle/pipeline/units.json b/proofs/isabelle/pipeline/units.json new file mode 100644 index 0000000000..9b5b217222 --- /dev/null +++ b/proofs/isabelle/pipeline/units.json @@ -0,0 +1,154 @@ +{ + "units": [ + { + "name": "poly", + "source": "mlkem/src/poly.c", + "parameter_set": 512, + "theory_name": "MLKEM_Poly_Definitions", + "output_c": "generated/c/poly.pre.c", + "output_theory": "MLKEM_Poly_Definitions.thy", + "addr_ty": "'addr", + "gv_ty": "\"8 word list\"", + "cpp_overrides": [ + "-Iproofs/isabelle/pipeline/config", + "-DMLK_CONFIG_FILE=\"proof_config.h\"", + "-D__attribute__(x)=", + "-D__extension__=", + "-D__asm__(x)=" + ], + "extract": { + "mode": "function_subset", + "functions": [ + "mlk_cast_int32_to_uint16", + "mlk_cast_uint16_to_int16", + "mlk_cast_int16_to_uint16", + "mlk_value_barrier_i32", + "mlk_montgomery_reduce", + "mlk_ct_cmask_neg_i16", + "mlk_ct_cmask_nonzero_u16", + "mlk_ct_sel_int16", + "mlk_fqmul", + "mlk_barrett_reduce", + "mlk_scalar_signed_to_unsigned_q", + "mlk_poly_add", + "mlk_poly_sub", + "mlk_poly_tomont_c", + "mlk_poly_tomont", + "mlk_poly_reduce_c", + "mlk_poly_reduce", + "mlk_poly_mulcache_compute_c", + "mlk_poly_mulcache_compute", + "mlk_ntt_butterfly_block", + "mlk_ntt_layer", + "mlk_poly_ntt_c", + "mlk_poly_ntt", + "mlk_invntt_layer", + "mlk_poly_invntt_tomont_c", + "mlk_poly_invntt_tomont" + ], + "preamble": [ + "typedef short int16_t;", + "typedef int int32_t;", + "typedef unsigned short uint16_t;", + "typedef unsigned int uint32_t;", + "typedef unsigned char uint8_t;", + "typedef signed char int8_t;", + "typedef unsigned long uint64_t;", + "", + "typedef struct mlk_poly", + "{", + " int16_t coeffs[256];", + "} mlk_poly;", + "", + "typedef struct mlk_poly_mulcache", + "{", + " int16_t coeffs[128];", + "} mlk_poly_mulcache;", + "", + "static const int16_t mlk_zetas[128] = {", + " -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577,", + " 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458,", + " -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223,", + " 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666,", + " -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247,", + " -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430,", + " 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291,", + " -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119,", + " -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603,", + " 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220,", + " -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108,", + " -308, 996, 991, 958, -1460, 1522, 1628,", + "};", + "" + ] + }, + "manifest": { + "output": "generated/manifests/poly.functions.manifest", + "functions": [ + "mlk_cast_int32_to_uint16", + "mlk_cast_uint16_to_int16", + "mlk_cast_int16_to_uint16", + "mlk_value_barrier_i32", + "mlk_montgomery_reduce", + "mlk_ct_cmask_neg_i16", + "mlk_ct_cmask_nonzero_u16", + "mlk_ct_sel_int16", + "mlk_fqmul", + "mlk_barrett_reduce", + "mlk_scalar_signed_to_unsigned_q", + "mlk_poly_add", + "mlk_poly_sub", + "mlk_poly_tomont_c", + "mlk_poly_tomont", + "mlk_poly_reduce_c", + "mlk_poly_reduce", + "mlk_poly_mulcache_compute_c", + "mlk_poly_mulcache_compute", + "mlk_ntt_butterfly_block", + "mlk_ntt_layer", + "mlk_poly_ntt_c", + "mlk_poly_ntt", + "mlk_invntt_layer", + "mlk_poly_invntt_tomont_c", + "mlk_poly_invntt_tomont" + ] + }, + "type_manifest": { + "output": "generated/manifests/poly.types.manifest", + "functions": [ + "__mlkem_pipeline_no_functions__" + ], + "types": [ + "mlk_poly", + "mlk_poly_mulcache", + "int16_t", + "int32_t", + "uint16_t", + "uint32_t" + ] + }, + "isabelle_prelude": [ + "context c_mlk_machine_model", + "begin" + ], + "isabelle_postlude": [ + "end" + ] + }, + { + "name": "verify", + "source": "mlkem/src/verify.c", + "parameter_set": 512, + "theory_name": "MLKEM_Verify_Definitions", + "output_c": "generated/c/verify.pre.c", + "output_theory": "MLKEM_Verify_Definitions.thy", + "cpp_overrides": [ + "-Iproofs/isabelle/pipeline/config", + "-DMLK_CONFIG_FILE=\"proof_config.h\"", + "-D__attribute__(x)=", + "-D__extension__=", + "-D__asm__(x)=" + ] + } + ] +}